Files
the-custodian/state-hub/api/routers/token_events.py
tegwick d65bc701da feat(token-tracking): record AI token consumption per task (CUST-WP-0029)
Introduces end-to-end token consumption tracking so agent work is
visible as a cost/effort metric alongside tasks and workplans.

- Migration o2j3k4l5m6n7: token_events table with FK indexes on
  task_id, workstream_id, repo_id, created_at
- ORM model, Pydantic schemas (TokenEventCreate, TokenEventRead with
  computed tokens_total, TokenSummary)
- Router: POST /token-events/, GET /token-events/ (7 filters),
  GET /token-events/summary/ (task|workstream|repo|commit|release|session scope)
- MCP tools: record_token_event, get_token_summary (formatted table)
- update_task_status enriched with optional tokens_in/tokens_out
  passthrough — one call creates status update + token event
- Dashboard token-cost.md page: by-repo bar, by-workplan table,
  by-model bar, top-10 tasks by tokens
- ralph-workplan skill updated with token reporting guidance and
  per-task heuristics for estimating counts
- Tests: test_token_events.py + test_token_passthrough.py (182 pass)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-29 17:46:46 +02:00

123 lines
4.1 KiB
Python

import uuid
from collections import defaultdict
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session
from api.models.task import Task
from api.models.token_event import TokenEvent
from api.schemas.token_event import TokenEventCreate, TokenEventRead, TokenSummary
# Router for token-event endpoints; every route declared on it is served under
# the "/token-events" prefix and tagged "token-events".
router = APIRouter(prefix="/token-events", tags=["token-events"])
@router.post("/", response_model=TokenEventRead, status_code=status.HTTP_201_CREATED)
async def create_token_event(
    body: TokenEventCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Persist a new token event and return the stored row.

    When the payload names a task but carries no workstream, the workstream is
    copied from that task (if the task can be loaded) before the event is saved.

    Args:
        body: Validated creation payload.
        session: Async database session supplied by the ``get_session`` dependency.

    Returns:
        The newly committed ``TokenEvent``, refreshed from the database.
    """
    payload = body.model_dump()

    task_ref = payload.get("task_id")
    workstream_missing = not payload.get("workstream_id")
    if task_ref and workstream_missing:
        # Inherit the workstream from the owning task; a task that cannot be
        # loaded simply leaves the field as it was.
        owning_task = await session.get(Task, task_ref)
        if owning_task:
            payload["workstream_id"] = owning_task.workstream_id

    new_event = TokenEvent(**payload)
    session.add(new_event)
    await session.commit()
    # Reload so database-populated columns are present on the returned object.
    await session.refresh(new_event)
    return new_event
@router.get("/summary/", response_model=TokenSummary)
async def get_token_summary(
    scope: str = Query(..., description="task|workstream|repo|commit|release|session"),
    id: str = Query(..., description="FK value or ref_id depending on scope"),
    session: AsyncSession = Depends(get_session),
) -> TokenSummary:
    """Aggregate token usage for one task, workstream, repo or external reference.

    Args:
        scope: Selects how ``id`` is interpreted. For ``task``, ``workstream``
            and ``repo`` the ``id`` is parsed as a UUID and matched against the
            corresponding foreign-key column. For ``commit``, ``release`` and
            ``session`` the ``id`` is matched verbatim against ``ref_id`` on
            events whose ``ref_type`` equals ``scope``.
        id: UUID string or reference identifier, depending on ``scope``.
        session: Async database session supplied by the ``get_session`` dependency.

    Returns:
        A ``TokenSummary`` with input/output/total token counts, the number of
        matching events, and per-model / per-agent totals. Events without a
        model (or agent) are counted in the overall totals but omitted from the
        respective breakdown.

    Raises:
        HTTPException: 422 if ``scope`` is not recognised, or if ``id`` is not
            a valid UUID for a UUID-keyed scope.
    """
    # NOTE: ``id`` shadows the builtin, but it is the public query-parameter
    # name, so it is kept for backward compatibility.

    # Scopes whose ``id`` is a UUID foreign key, mapped to the column to filter on.
    uuid_columns = {
        "task": TokenEvent.task_id,
        "workstream": TokenEvent.workstream_id,
        "repo": TokenEvent.repo_id,
    }
    # Scopes identified by a free-form (ref_type, ref_id) pair.
    ref_scopes = ("commit", "release", "session")

    q = select(TokenEvent)
    if scope in uuid_columns:
        try:
            uid = uuid.UUID(id)
        except ValueError:
            # ``from None``: the client only needs the 422, not the parse traceback.
            raise HTTPException(
                status_code=422,
                detail=f"id must be a valid UUID for scope={scope}",
            ) from None
        q = q.where(uuid_columns[scope] == uid)
    elif scope in ref_scopes:
        q = q.where(TokenEvent.ref_type == scope, TokenEvent.ref_id == id)
    else:
        raise HTTPException(status_code=422, detail=f"Unknown scope: {scope!r}")

    result = await session.execute(q)
    events = list(result.scalars().all())

    # Single pass over the events: overall totals plus per-model/per-agent totals.
    tokens_in = 0
    tokens_out = 0
    by_model: defaultdict[str, int] = defaultdict(int)
    by_agent: defaultdict[str, int] = defaultdict(int)
    for e in events:
        tokens_in += e.tokens_in
        tokens_out += e.tokens_out
        event_total = e.tokens_in + e.tokens_out
        if e.model:
            by_model[e.model] += event_total
        if e.agent:
            by_agent[e.agent] += event_total

    return TokenSummary(
        scope=scope,
        scope_id=id,
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        tokens_total=tokens_in + tokens_out,
        event_count=len(events),
        by_model=dict(by_model),
        by_agent=dict(by_agent),
    )
@router.get("/", response_model=list[TokenEventRead])
async def list_token_events(
    task_id: uuid.UUID | None = None,
    workstream_id: uuid.UUID | None = None,
    repo_id: uuid.UUID | None = None,
    ref_type: str | None = None,
    ref_id: str | None = None,
    model: str | None = None,
    agent: str | None = None,
    # ge=0: a negative value would otherwise reach the SQL LIMIT clause, where,
    # depending on the database backend, it is either an error or means
    # "no limit" and would bypass the 1000-row cap.
    limit: int = Query(100, ge=0, le=1000),
    session: AsyncSession = Depends(get_session),
) -> list[TokenEvent]:
    """List token events, newest first, narrowed by optional equality filters.

    Args:
        task_id: Keep only events for this task.
        workstream_id: Keep only events for this workstream.
        repo_id: Keep only events for this repo.
        ref_type: Keep only events with this ``ref_type``.
        ref_id: Keep only events with this ``ref_id``.
        model: Keep only events recorded for this model.
        agent: Keep only events recorded for this agent.
        limit: Maximum number of rows to return (0-1000, default 100).
        session: Async database session supplied by the ``get_session`` dependency.

    Returns:
        Matching ``TokenEvent`` rows ordered by ``created_at`` descending.
    """
    # (column, requested value) pairs. Only truthy values are applied, so an
    # omitted (None) or empty-string parameter leaves that column unconstrained.
    filters = (
        (TokenEvent.task_id, task_id),
        (TokenEvent.workstream_id, workstream_id),
        (TokenEvent.repo_id, repo_id),
        (TokenEvent.ref_type, ref_type),
        (TokenEvent.ref_id, ref_id),
        (TokenEvent.model, model),
        (TokenEvent.agent, agent),
    )

    q = select(TokenEvent)
    for column, wanted in filters:
        if wanted:
            q = q.where(column == wanted)
    q = q.order_by(TokenEvent.created_at.desc()).limit(limit)

    result = await session.execute(q)
    return list(result.scalars().all())