Files
the-custodian/state-hub/api/routers/token_events.py
tegwick f794d8c47f feat(token-events): auto-capture real token counts via PostToolUse hook
- Add PATCH /token-events/{id} endpoint to correct heuristic events
- Add `note` filter to GET /token-events/ list
- Add TokenEventPatch schema
- Add task_token_hook.py: PostToolUse hook that reads the Claude Code
  session transcript, computes per-task token delta, and replaces the
  heuristic token event with real measured counts (note="measured")
- Register hook in ~/.claude/settings.json on mcp__state-hub__update_task_status
  Covers both interactive sessions and ralph-workplan loops

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-01 22:38:45 +02:00

229 lines
8.0 KiB
Python

import uuid
from collections import defaultdict
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session
from api.models.managed_repo import ManagedRepo
from api.models.task import Task
from api.models.token_event import TokenEvent
from api.models.workstream import Workstream
from api.schemas.token_event import RepoTokenSummary, TokenEventCreate, TokenEventPatch, TokenEventRead, TokenSummary
# Router for the token-event endpoints; every route below is mounted under /token-events.
router = APIRouter(prefix="/token-events", tags=["token-events"])
@router.post("/", response_model=TokenEventRead, status_code=status.HTTP_201_CREATED)
async def create_token_event(
    body: TokenEventCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Persist a new token event, back-filling missing parent links.

    When the payload names a task but no workstream, the workstream is copied
    from that task (if the task exists). When a workstream is then known but
    no repo is given, the repo is copied from that workstream (if it exists
    and has one). The stored row is re-read after commit and returned.
    """
    payload = body.model_dump()

    # Step 1: task -> workstream, only when the caller did not supply one.
    needs_workstream = payload.get("task_id") and not payload.get("workstream_id")
    if needs_workstream:
        parent_task = await session.get(Task, payload["task_id"])
        if parent_task:
            payload["workstream_id"] = parent_task.workstream_id

    # Step 2: workstream -> repo, only when the caller did not supply one.
    needs_repo = payload.get("workstream_id") and not payload.get("repo_id")
    if needs_repo:
        parent_ws = await session.get(Workstream, payload["workstream_id"])
        if parent_ws and parent_ws.repo_id:
            payload["repo_id"] = parent_ws.repo_id

    new_event = TokenEvent(**payload)
    session.add(new_event)
    await session.commit()
    # Reload so database-generated columns are present on the returned object.
    await session.refresh(new_event)
    return new_event
@router.get("/summary/", response_model=TokenSummary)
async def get_token_summary(
    scope: str = Query(..., description="task|workstream|repo|commit|release|session"),
    id: str = Query(..., description="FK value or ref_id depending on scope"),
    session: AsyncSession = Depends(get_session),
) -> TokenSummary:
    """Summarise token usage for a single scope.

    Args:
        scope: One of ``task``, ``workstream``, ``repo`` (``id`` is parsed as a
            UUID and matched against the corresponding foreign-key column) or
            ``commit``, ``release``, ``session`` (``id`` is matched verbatim
            against ``ref_id`` for rows whose ``ref_type`` equals ``scope``).
        id: The identifier to match; interpretation depends on ``scope``.
            (The name shadows the builtin but is the public query parameter.)
        session: Database session injected by FastAPI.

    Returns:
        Totals of input/output tokens, the number of matching events, and
        per-model / per-agent totals (events with no model or agent are left
        out of the respective breakdown).

    Raises:
        HTTPException: 422 when ``scope`` is unknown, or when ``id`` is not a
            valid UUID for a foreign-key scope.
    """
    # Scopes whose ``id`` is a UUID foreign key, mapped to the column filtered on.
    fk_columns = {
        "task": TokenEvent.task_id,
        "workstream": TokenEvent.workstream_id,
        "repo": TokenEvent.repo_id,
    }

    q = select(TokenEvent)
    if scope in fk_columns:
        try:
            uid = uuid.UUID(id)
        except ValueError as exc:
            # Chain the original error so the parse failure stays visible in logs.
            raise HTTPException(
                status_code=422,
                detail=f"id must be a valid UUID for scope={scope}",
            ) from exc
        q = q.where(fk_columns[scope] == uid)
    elif scope in ("commit", "release", "session"):
        q = q.where(TokenEvent.ref_type == scope, TokenEvent.ref_id == id)
    else:
        raise HTTPException(status_code=422, detail=f"Unknown scope: {scope!r}")

    result = await session.execute(q)
    events = list(result.scalars().all())

    # Single pass: accumulate grand totals and both breakdowns together.
    tokens_in = 0
    tokens_out = 0
    by_model: defaultdict[str, int] = defaultdict(int)
    by_agent: defaultdict[str, int] = defaultdict(int)
    for e in events:
        tokens_in += e.tokens_in
        tokens_out += e.tokens_out
        event_total = e.tokens_in + e.tokens_out
        if e.model:
            by_model[e.model] += event_total
        if e.agent:
            by_agent[e.agent] += event_total

    return TokenSummary(
        scope=scope,
        scope_id=id,
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        tokens_total=tokens_in + tokens_out,
        event_count=len(events),
        # Plain dicts so the response model does not receive defaultdicts.
        by_model=dict(by_model),
        by_agent=dict(by_agent),
    )
@router.get("/by-repo/", response_model=list[RepoTokenSummary])
async def get_tokens_by_repo(
    session: AsyncSession = Depends(get_session),
) -> list[RepoTokenSummary]:
    """Aggregate token consumption per repo, resolving via the full graph.

    Each event is attributed to a repo using the first rule that applies:
      1. the event's own ``repo_id``;
      2. the ``repo_id`` of the event's workstream;
      3. when the event has no workstream, the ``repo_id`` of the workstream
         its task belongs to.

    Events that resolve to no repo, or to a repo id that is not a known
    ManagedRepo, are skipped. The result is ordered by total tokens,
    largest first.
    """

    async def load_all(model: type) -> list:
        # One SELECT per table; all joins happen in memory afterwards.
        rows = await session.execute(select(model))
        return list(rows.scalars().all())

    # Four bulk queries (events, workstreams, tasks, repos) instead of
    # per-event lookups.
    all_events = await load_all(TokenEvent)
    workstreams_by_id: dict[uuid.UUID, Workstream] = {
        ws.id: ws for ws in await load_all(Workstream)
    }
    tasks_by_id: dict[uuid.UUID, Task] = {
        task.id: task for task in await load_all(Task)
    }
    repos_by_id: dict[uuid.UUID, ManagedRepo] = {
        repo.id: repo for repo in await load_all(ManagedRepo)
    }

    def owning_repo(event: TokenEvent) -> uuid.UUID | None:
        # A direct link on the event always wins.
        if event.repo_id:
            return event.repo_id
        workstream_id = event.workstream_id
        # Fall back to the task's workstream only when the event has none.
        if not workstream_id and event.task_id and event.task_id in tasks_by_id:
            workstream_id = tasks_by_id[event.task_id].workstream_id
        if workstream_id and workstream_id in workstreams_by_id:
            return workstreams_by_id[workstream_id].repo_id
        return None

    def empty_bucket(repo_id: uuid.UUID) -> dict:
        return {
            "repo_id": repo_id,
            "repo_slug": repos_by_id[repo_id].slug,
            "tokens_in": 0,
            "tokens_out": 0,
            "event_count": 0,
            "by_model": defaultdict(int),
            "by_note": defaultdict(int),
        }

    buckets: dict[uuid.UUID, dict] = {}
    for event in all_events:
        repo_id = owning_repo(event)
        if not repo_id or repo_id not in repos_by_id:
            continue
        bucket = buckets.get(repo_id)
        if bucket is None:
            bucket = buckets[repo_id] = empty_bucket(repo_id)

        event_total = event.tokens_in + event.tokens_out
        bucket["tokens_in"] += event.tokens_in
        bucket["tokens_out"] += event.tokens_out
        bucket["event_count"] += 1
        if event.model:
            bucket["by_model"][event.model] += event_total
        # Events without a note are grouped under the "unknown" key.
        bucket["by_note"][event.note or "unknown"] += event_total

    def bucket_total(bucket: dict) -> int:
        return bucket["tokens_in"] + bucket["tokens_out"]

    # Stable descending sort: repos with equal totals keep first-seen order.
    ranked = sorted(buckets.values(), key=bucket_total, reverse=True)

    summaries: list[RepoTokenSummary] = []
    for bucket in ranked:
        # Convert the defaultdict breakdowns to plain dicts for the schema.
        fields = {
            key: dict(value) if isinstance(value, defaultdict) else value
            for key, value in bucket.items()
        }
        summaries.append(
            RepoTokenSummary(**fields, tokens_total=bucket_total(bucket))
        )
    return summaries
@router.patch("/{event_id}", response_model=TokenEventRead)
async def patch_token_event(
    event_id: uuid.UUID,
    body: TokenEventPatch,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Apply a partial update to an existing token event.

    Only fields whose value in the patch body is not None are written; a
    None value therefore never clears a column. Responds with 404 when no
    event has the given id.
    """
    target = await session.get(TokenEvent, event_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Token event not found")

    # exclude_none drops every field that is None in the patch body.
    changes = body.model_dump(exclude_none=True)
    for attr_name in changes:
        setattr(target, attr_name, changes[attr_name])

    await session.commit()
    # Re-read the row so the response reflects what was actually stored.
    await session.refresh(target)
    return target
@router.get("/{event_id}", response_model=TokenEventRead)
async def get_token_event(
    event_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Fetch a single token event by primary key, or respond with 404."""
    found = await session.get(TokenEvent, event_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Token event not found")
@router.get("/", response_model=list[TokenEventRead])
async def list_token_events(
    task_id: uuid.UUID | None = None,
    workstream_id: uuid.UUID | None = None,
    repo_id: uuid.UUID | None = None,
    ref_type: str | None = None,
    ref_id: str | None = None,
    model: str | None = None,
    agent: str | None = None,
    note: str | None = None,
    # ge=0 rejects negative limits at validation time (422) so they never
    # reach the SQL LIMIT clause, where they would bypass the intended cap
    # or raise a database error, depending on the backend.
    limit: int = Query(100, ge=0, le=1000),
    session: AsyncSession = Depends(get_session),
) -> list[TokenEvent]:
    """List token events, newest first, optionally filtered by exact match.

    Every filter parameter is optional and compared for equality against the
    column of the same name. A filter is applied only when its value is
    truthy, so an omitted parameter and an empty string both mean "no filter".

    Args:
        task_id, workstream_id, repo_id: Match on the respective foreign key.
        ref_type, ref_id, model, agent, note: Match on the respective column.
        limit: Maximum number of rows returned (0-1000, default 100).
        session: Database session injected by FastAPI.

    Returns:
        Matching events ordered by ``created_at`` descending.
    """
    # (column, requested value) pairs; order matches the original filter order.
    filters = (
        (TokenEvent.task_id, task_id),
        (TokenEvent.workstream_id, workstream_id),
        (TokenEvent.repo_id, repo_id),
        (TokenEvent.ref_type, ref_type),
        (TokenEvent.ref_id, ref_id),
        (TokenEvent.model, model),
        (TokenEvent.agent, agent),
        (TokenEvent.note, note),
    )

    q = select(TokenEvent)
    for column, wanted in filters:
        if wanted:
            q = q.where(column == wanted)

    q = q.order_by(TokenEvent.created_at.desc()).limit(limit)
    result = await session.execute(q)
    return list(result.scalars().all())