- Add PATCH /token-events/{id} endpoint to correct heuristic events
- Add `note` filter to GET /token-events/ list
- Add TokenEventPatch schema
- Add task_token_hook.py: PostToolUse hook that reads the Claude Code
session transcript, computes per-task token delta, and replaces the
heuristic token event with real measured counts (note="measured")
- Register hook in ~/.claude/settings.json on mcp__state-hub__update_task_status
Covers both interactive sessions and ralph-workplan loops
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
229 lines
8.0 KiB
Python
229 lines
8.0 KiB
Python
import uuid
|
|
from collections import defaultdict
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from api.database import get_session
|
|
from api.models.managed_repo import ManagedRepo
|
|
from api.models.task import Task
|
|
from api.models.token_event import TokenEvent
|
|
from api.models.workstream import Workstream
|
|
from api.schemas.token_event import RepoTokenSummary, TokenEventCreate, TokenEventPatch, TokenEventRead, TokenSummary
|
|
|
|
# Router for all token-event endpoints; every route declared on it is served
# under the "/token-events" prefix and tagged "token-events".
router = APIRouter(prefix="/token-events", tags=["token-events"])
|
|
|
|
|
|
@router.post("/", response_model=TokenEventRead, status_code=status.HTTP_201_CREATED)
async def create_token_event(
    body: TokenEventCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Persist a new token event, back-filling missing parent links.

    When the request names a task but no workstream, the workstream is taken
    from the task. When a workstream is known (given or derived) but no repo
    is, the repo is taken from the workstream. Lookups that find nothing
    simply leave the corresponding field as submitted.

    Returns the stored event after it has been refreshed from the database.
    """
    payload = body.model_dump()

    # Derive the workstream from the task when the caller did not supply one.
    needs_workstream = payload.get("task_id") and not payload.get("workstream_id")
    if needs_workstream:
        parent_task = await session.get(Task, payload["task_id"])
        if parent_task:
            payload["workstream_id"] = parent_task.workstream_id

    # Derive the repo from the workstream when the caller did not supply one.
    needs_repo = payload.get("workstream_id") and not payload.get("repo_id")
    if needs_repo:
        parent_ws = await session.get(Workstream, payload["workstream_id"])
        if parent_ws and parent_ws.repo_id:
            payload["repo_id"] = parent_ws.repo_id

    new_event = TokenEvent(**payload)
    session.add(new_event)
    await session.commit()
    await session.refresh(new_event)
    return new_event
|
|
|
|
|
|
@router.get("/summary/", response_model=TokenSummary)
async def get_token_summary(
    scope: str = Query(..., description="task|workstream|repo|commit|release|session"),
    id: str = Query(..., description="FK value or ref_id depending on scope"),
    session: AsyncSession = Depends(get_session),
) -> TokenSummary:
    """Return aggregated token counts for one scope/id pair.

    Args:
        scope: One of ``task``, ``workstream``, ``repo`` (``id`` is parsed as
            a UUID and matched against the corresponding foreign-key column)
            or ``commit``, ``release``, ``session`` (matched against
            ``ref_type``/``ref_id`` with ``id`` used verbatim).
        id: The identifier to summarise; interpretation depends on ``scope``.
        session: Database session injected by FastAPI.

    Returns:
        A ``TokenSummary`` with in/out/total token counts, the number of
        matching events, and totals broken down by model and by agent.
        Events without a model (or agent) are left out of that breakdown
        but still count towards the overall totals.

    Raises:
        HTTPException: 422 when ``scope`` is unknown, or when a UUID-keyed
            scope is given an ``id`` that is not a valid UUID.
    """
    # Scopes whose id is a UUID foreign key stored directly on the event row.
    fk_columns = {
        "task": TokenEvent.task_id,
        "workstream": TokenEvent.workstream_id,
        "repo": TokenEvent.repo_id,
    }

    q = select(TokenEvent)
    if scope in fk_columns:
        try:
            uid = uuid.UUID(id)
        except ValueError:
            # The parse failure itself is not useful to the client; suppress
            # the chained context and surface a validation error instead.
            raise HTTPException(
                status_code=422,
                detail=f"id must be a valid UUID for scope={scope}",
            ) from None
        q = q.where(fk_columns[scope] == uid)
    elif scope in ("commit", "release", "session"):
        q = q.where(TokenEvent.ref_type == scope, TokenEvent.ref_id == id)
    else:
        raise HTTPException(status_code=422, detail=f"Unknown scope: {scope!r}")

    result = await session.execute(q)
    events = list(result.scalars().all())

    # Single pass over the rows to build the totals and both breakdowns.
    tokens_in = 0
    tokens_out = 0
    by_model: dict[str, int] = defaultdict(int)
    by_agent: dict[str, int] = defaultdict(int)
    for e in events:
        tokens_in += e.tokens_in
        tokens_out += e.tokens_out
        if e.model:
            by_model[e.model] += e.tokens_in + e.tokens_out
        if e.agent:
            by_agent[e.agent] += e.tokens_in + e.tokens_out

    return TokenSummary(
        scope=scope,
        scope_id=id,
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        tokens_total=tokens_in + tokens_out,
        event_count=len(events),
        by_model=dict(by_model),
        by_agent=dict(by_agent),
    )
|
|
|
|
|
|
@router.get("/by-repo/", response_model=list[RepoTokenSummary])
async def get_tokens_by_repo(
    session: AsyncSession = Depends(get_session),
) -> list[RepoTokenSummary]:
    """Summarise token usage for every repo that has attributable events.

    Each event is attributed to a repo by trying, in order:

      1. the event's own repo_id
      2. the repo_id of the workstream named by the event's workstream_id
      3. the repo_id of the workstream that owns the task named by task_id

    Events that resolve to no repo, or to a repo id that is not among the
    loaded ManagedRepo rows, are left out. The result is ordered by total
    tokens, largest first.
    """
    # Load each of the four tables once and join in memory, so no query is
    # issued per event.
    all_events = list((await session.execute(select(TokenEvent))).scalars().all())

    workstream_rows = (await session.execute(select(Workstream))).scalars().all()
    workstreams_by_id: dict[uuid.UUID, Workstream] = {row.id: row for row in workstream_rows}

    task_rows = (await session.execute(select(Task))).scalars().all()
    tasks_by_id: dict[uuid.UUID, Task] = {row.id: row for row in task_rows}

    repo_rows = (await session.execute(select(ManagedRepo))).scalars().all()
    repos_by_id: dict[uuid.UUID, ManagedRepo] = {row.id: row for row in repo_rows}

    def owning_repo_id(event: TokenEvent) -> uuid.UUID | None:
        # A direct link on the event wins outright.
        if event.repo_id:
            return event.repo_id
        workstream_id = event.workstream_id
        # Fall back to the task's workstream only when the event has none.
        if not workstream_id and event.task_id:
            owner_task = tasks_by_id.get(event.task_id)
            if owner_task is not None:
                workstream_id = owner_task.workstream_id
        if not workstream_id:
            return None
        owner_ws = workstreams_by_id.get(workstream_id)
        if owner_ws is None:
            return None
        return owner_ws.repo_id

    buckets: dict[uuid.UUID, dict] = {}
    for event in all_events:
        repo_id = owning_repo_id(event)
        if not repo_id or repo_id not in repos_by_id:
            continue

        bucket = buckets.get(repo_id)
        if bucket is None:
            bucket = {
                "repo_id": repo_id,
                "repo_slug": repos_by_id[repo_id].slug,
                "tokens_in": 0,
                "tokens_out": 0,
                "event_count": 0,
                "by_model": defaultdict(int),
                "by_note": defaultdict(int),
            }
            buckets[repo_id] = bucket

        bucket["tokens_in"] += event.tokens_in
        bucket["tokens_out"] += event.tokens_out
        bucket["event_count"] += 1

        event_total = event.tokens_in + event.tokens_out
        if event.model:
            bucket["by_model"][event.model] += event_total
        # Events without a note are grouped under "unknown".
        bucket["by_note"][event.note or "unknown"] += event_total

    ranked = sorted(
        buckets.values(),
        key=lambda b: b["tokens_in"] + b["tokens_out"],
        reverse=True,
    )

    summaries: list[RepoTokenSummary] = []
    for bucket in ranked:
        # Convert the defaultdict breakdowns to plain dicts for the schema.
        fields = {
            key: (dict(value) if isinstance(value, defaultdict) else value)
            for key, value in bucket.items()
        }
        summaries.append(
            RepoTokenSummary(
                **fields,
                tokens_total=bucket["tokens_in"] + bucket["tokens_out"],
            )
        )
    return summaries
|
|
|
|
|
|
@router.patch("/{event_id}", response_model=TokenEventRead)
async def patch_token_event(
    event_id: uuid.UUID,
    body: TokenEventPatch,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Apply a partial update to an existing token event.

    Only body fields whose value is not None are written to the event; a
    None value therefore leaves the stored value untouched. Responds with
    404 when no event has the given id, otherwise returns the event after
    it has been committed and refreshed.
    """
    target = await session.get(TokenEvent, event_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Token event not found")

    changes = body.model_dump(exclude_none=True)
    for attr_name in changes:
        setattr(target, attr_name, changes[attr_name])

    await session.commit()
    await session.refresh(target)
    return target
|
|
|
|
|
|
@router.get("/{event_id}", response_model=TokenEventRead)
async def get_token_event(
    event_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Fetch a single token event by primary key, or respond with 404."""
    found = await session.get(TokenEvent, event_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Token event not found")
|
|
|
|
|
|
@router.get("/", response_model=list[TokenEventRead])
async def list_token_events(
    task_id: uuid.UUID | None = None,
    workstream_id: uuid.UUID | None = None,
    repo_id: uuid.UUID | None = None,
    ref_type: str | None = None,
    ref_id: str | None = None,
    model: str | None = None,
    agent: str | None = None,
    note: str | None = None,
    limit: int = Query(100, ge=0, le=1000),
    session: AsyncSession = Depends(get_session),
) -> list[TokenEvent]:
    """List token events, newest first, optionally filtered by exact match.

    Args:
        task_id, workstream_id, repo_id, ref_type, ref_id, model, agent, note:
            Optional equality filters on the column of the same name. A
            filter that is omitted (or an empty string) is not applied;
            supplied filters are combined with AND.
        limit: Maximum number of rows to return, between 0 and 1000
            (default 100). Values outside that range are rejected during
            request validation instead of being passed to the database.
        session: Database session injected by FastAPI.

    Returns:
        Matching events ordered by ``created_at`` descending.
    """
    # Column/value pairs; only truthy values become WHERE clauses.
    filters = (
        (TokenEvent.task_id, task_id),
        (TokenEvent.workstream_id, workstream_id),
        (TokenEvent.repo_id, repo_id),
        (TokenEvent.ref_type, ref_type),
        (TokenEvent.ref_id, ref_id),
        (TokenEvent.model, model),
        (TokenEvent.agent, agent),
        (TokenEvent.note, note),
    )

    q = select(TokenEvent)
    for column, value in filters:
        if value:
            q = q.where(column == value)

    q = q.order_by(TokenEvent.created_at.desc()).limit(limit)
    result = await session.execute(q)
    return list(result.scalars().all())
|