diff --git a/state-hub/api/routers/token_events.py b/state-hub/api/routers/token_events.py index f529957..6dd6bf7 100644 --- a/state-hub/api/routers/token_events.py +++ b/state-hub/api/routers/token_events.py @@ -10,7 +10,7 @@ from api.models.managed_repo import ManagedRepo from api.models.task import Task from api.models.token_event import TokenEvent from api.models.workstream import Workstream -from api.schemas.token_event import RepoTokenSummary, TokenEventCreate, TokenEventRead, TokenSummary +from api.schemas.token_event import RepoTokenSummary, TokenEventCreate, TokenEventPatch, TokenEventRead, TokenSummary router = APIRouter(prefix="/token-events", tags=["token-events"]) @@ -166,6 +166,22 @@ async def get_tokens_by_repo( ] +@router.patch("/{event_id}", response_model=TokenEventRead) +async def patch_token_event( + event_id: uuid.UUID, + body: TokenEventPatch, + session: AsyncSession = Depends(get_session), +) -> TokenEvent: + event = await session.get(TokenEvent, event_id) + if event is None: + raise HTTPException(status_code=404, detail="Token event not found") + for field, value in body.model_dump(exclude_none=True).items(): + setattr(event, field, value) + await session.commit() + await session.refresh(event) + return event + + @router.get("/{event_id}", response_model=TokenEventRead) async def get_token_event( event_id: uuid.UUID, @@ -186,6 +202,7 @@ async def list_token_events( ref_id: str | None = None, model: str | None = None, agent: str | None = None, + note: str | None = None, limit: int = Query(100, le=1000), session: AsyncSession = Depends(get_session), ) -> list[TokenEvent]: @@ -204,6 +221,8 @@ async def list_token_events( q = q.where(TokenEvent.model == model) if agent: q = q.where(TokenEvent.agent == agent) + if note: + q = q.where(TokenEvent.note == note) q = q.order_by(TokenEvent.created_at.desc()).limit(limit) result = await session.execute(q) return list(result.scalars().all()) diff --git a/state-hub/api/schemas/token_event.py b/state-hub/api/schemas/token_event.py index 6933228..60acbda 100644 --- a/state-hub/api/schemas/token_event.py +++ b/state-hub/api/schemas/token_event.py @@ -52,6 +52,14 @@ class TokenSummary(BaseModel): by_agent: dict[str, int] +class TokenEventPatch(BaseModel): + tokens_in: int | None = None + tokens_out: int | None = None + note: str | None = None + model: str | None = None + agent: str | None = None + + class RepoTokenSummary(BaseModel): repo_id: uuid.UUID repo_slug: str diff --git a/state-hub/scripts/task_token_hook.py b/state-hub/scripts/task_token_hook.py new file mode 100755 index 0000000..00a72f8 --- /dev/null +++ b/state-hub/scripts/task_token_hook.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python3 +"""PostToolUse hook: replace heuristic token events with real transcript-derived counts. + +Fires after mcp__state-hub__update_task_status when status=done. +Reads the Claude Code session transcript to compute the token delta since the +previous task completion, then PATCHes the heuristic event with real counts. + +State is persisted per session in /tmp/custodian_tokens_.json so +deltas are correctly scoped even when multiple tasks complete in one session. +""" +import json +import os +import sys +import urllib.error +import urllib.request +from pathlib import Path + +API = os.environ.get("CUSTODIAN_API", "http://127.0.0.1:8000") +STATE_DIR = Path(os.environ.get("TMPDIR", "/tmp")) + + +def read_transcript_totals(transcript_path: str) -> tuple[int, int]: + """Sum all usage entries in the transcript JSONL up to the current point.""" + total_in = total_out = 0 + try: + with open(transcript_path) as f: + for line in f: + try: + entry = json.loads(line) + usage = entry.get("message", {}).get("usage", {}) + if usage: + # Count all input token variants (direct + cache creation + cache read) + total_in += ( + usage.get("input_tokens", 0) + + usage.get("cache_creation_input_tokens", 0) + + usage.get("cache_read_input_tokens", 0) + ) + total_out += usage.get("output_tokens", 0) + except (json.JSONDecodeError, TypeError): + continue + except OSError: + pass + return total_in, total_out + + +def load_state(session_id: str) -> tuple[int, int]: + state_file = STATE_DIR / f"custodian_tokens_{session_id}.json" + try: + data = json.loads(state_file.read_text()) + return data.get("total_in", 0), data.get("total_out", 0) + except (OSError, json.JSONDecodeError): + return 0, 0 + + +def save_state(session_id: str, total_in: int, total_out: int) -> None: + state_file = STATE_DIR / f"custodian_tokens_{session_id}.json" + state_file.write_text(json.dumps({"total_in": total_in, "total_out": total_out})) + + +def api_get(path: str): + req = urllib.request.Request(f"{API}{path}") + with urllib.request.urlopen(req, timeout=5) as r: + return json.loads(r.read()) + + +def api_patch(path: str, data: dict): + body = json.dumps(data).encode() + req = urllib.request.Request( + f"{API}{path}", + data=body, + headers={"Content-Type": "application/json"}, + method="PATCH", + ) + with urllib.request.urlopen(req, timeout=5) as r: + return json.loads(r.read()) + + +def main() -> None: + try: + payload = json.loads(sys.stdin.read()) + except json.JSONDecodeError: + return + + tool_name = payload.get("tool_name", "") + if "update_task_status" not in tool_name: + return + + tool_input = payload.get("tool_input", {}) + if tool_input.get("status") != "done": + return + + task_id = tool_input.get("task_id") + if not task_id: + return + + transcript_path = payload.get("transcript_path", "") + session_id = payload.get("session_id", "unknown") + + # Compute token delta for this task + current_in, current_out = read_transcript_totals(transcript_path) + last_in, last_out = load_state(session_id) + delta_in = max(0, current_in - last_in) + delta_out = max(0, current_out - last_out) + save_state(session_id, current_in, current_out) + + if delta_in == 0 and delta_out == 0: + return # Nothing measurable — leave heuristic in place + + # Find the most recent heuristic event for this task and replace it + try: + events = api_get(f"/token-events/?task_id={task_id}¬e=heuristic&limit=5") + except (urllib.error.URLError, OSError): + return # API offline — leave heuristic as-is + + if not events: + return + + event_id = events[0]["id"] + model = tool_input.get("model") + agent = tool_input.get("agent") + + patch_body: dict = {"tokens_in": delta_in, "tokens_out": delta_out, "note": "measured"} + if model: + patch_body["model"] = model + if agent: + patch_body["agent"] = agent + + try: + api_patch(f"/token-events/{event_id}", patch_body) + except (urllib.error.URLError, OSError): + pass + + +if __name__ == "__main__": + main()