feat(token-tracking): record AI token consumption per task (CUST-WP-0029)

Introduces end-to-end token consumption tracking so agent work is
visible as a cost/effort metric alongside tasks and workplans.

- Migration o2j3k4l5m6n7: token_events table with FK indexes on
  task_id, workstream_id, repo_id, created_at
- ORM model, Pydantic schemas (TokenEventCreate, TokenEventRead with
  computed tokens_total, TokenSummary)
- Router: POST /token-events/, GET /token-events/ (7 filters),
  GET /token-events/summary/ (task|workstream|repo|commit|release scope)
- MCP tools: record_token_event, get_token_summary (formatted table)
- update_task_status enriched with optional tokens_in/tokens_out
  passthrough — one call creates status update + token event
- Dashboard token-cost.md page: by-repo bar, by-workplan table,
  by-model bar, top-10 tasks by tokens
- ralph-workplan skill updated with token reporting guidance and
  per-task heuristics for estimating counts
- Tests: test_token_events.py + test_token_passthrough.py (182 pass)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-29 17:46:46 +02:00
parent a486c63603
commit 58e1bafce9
15 changed files with 983 additions and 2 deletions

View File

@@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session
from api.models.task import Task, TaskStatus
from api.models.token_event import TokenEvent
from api.schemas.task import TaskCreate, TaskRead, TaskUpdate
router = APIRouter(prefix="/tasks", tags=["tasks"])
@@ -72,10 +73,33 @@ async def update_task(
task = await session.get(Task, task_id)
if task is None:
raise HTTPException(status_code=404, detail="Task not found")
for field, value in body.model_dump(exclude_unset=True).items():
# Separate token fields from task fields
token_fields = {"tokens_in", "tokens_out", "model", "agent", "session_id"}
update_data = body.model_dump(exclude_unset=True)
token_data = {k: update_data.pop(k) for k in list(update_data.keys()) if k in token_fields}
for field, value in update_data.items():
setattr(task, field, value)
await session.commit()
await session.refresh(task)
# Create token event if token passthrough fields provided
if "tokens_in" in token_data and "tokens_out" in token_data:
event = TokenEvent(
task_id=task_id,
workstream_id=task.workstream_id,
tokens_in=token_data["tokens_in"],
tokens_out=token_data["tokens_out"],
model=token_data.get("model"),
agent=token_data.get("agent"),
session_id=token_data.get("session_id"),
ref_type="task",
ref_id=str(task_id),
)
session.add(event)
await session.commit()
return task

122
api/routers/token_events.py Normal file
View File

@@ -0,0 +1,122 @@
import uuid
from collections import defaultdict
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session
from api.models.task import Task
from api.models.token_event import TokenEvent
from api.schemas.token_event import TokenEventCreate, TokenEventRead, TokenSummary
router = APIRouter(prefix="/token-events", tags=["token-events"])
@router.post("/", response_model=TokenEventRead, status_code=status.HTTP_201_CREATED)
async def create_token_event(
body: TokenEventCreate,
session: AsyncSession = Depends(get_session),
) -> TokenEvent:
data = body.model_dump()
# Auto-populate workstream_id from task if not provided
if data.get("task_id") and not data.get("workstream_id"):
task = await session.get(Task, data["task_id"])
if task:
data["workstream_id"] = task.workstream_id
event = TokenEvent(**data)
session.add(event)
await session.commit()
await session.refresh(event)
return event
@router.get("/summary/", response_model=TokenSummary)
async def get_token_summary(
scope: str = Query(..., description="task|workstream|repo|commit|release|session"),
id: str = Query(..., description="FK value or ref_id depending on scope"),
session: AsyncSession = Depends(get_session),
) -> TokenSummary:
q = select(TokenEvent)
if scope == "task":
try:
uid = uuid.UUID(id)
except ValueError:
raise HTTPException(status_code=422, detail="id must be a valid UUID for scope=task")
q = q.where(TokenEvent.task_id == uid)
elif scope == "workstream":
try:
uid = uuid.UUID(id)
except ValueError:
raise HTTPException(status_code=422, detail="id must be a valid UUID for scope=workstream")
q = q.where(TokenEvent.workstream_id == uid)
elif scope == "repo":
try:
uid = uuid.UUID(id)
except ValueError:
raise HTTPException(status_code=422, detail="id must be a valid UUID for scope=repo")
q = q.where(TokenEvent.repo_id == uid)
elif scope in ("commit", "release", "session"):
q = q.where(TokenEvent.ref_type == scope, TokenEvent.ref_id == id)
else:
raise HTTPException(status_code=422, detail=f"Unknown scope: {scope!r}")
result = await session.execute(q)
events = list(result.scalars().all())
tokens_in = sum(e.tokens_in for e in events)
tokens_out = sum(e.tokens_out for e in events)
by_model: dict[str, int] = defaultdict(int)
by_agent: dict[str, int] = defaultdict(int)
for e in events:
if e.model:
by_model[e.model] += e.tokens_in + e.tokens_out
if e.agent:
by_agent[e.agent] += e.tokens_in + e.tokens_out
return TokenSummary(
scope=scope,
scope_id=id,
tokens_in=tokens_in,
tokens_out=tokens_out,
tokens_total=tokens_in + tokens_out,
event_count=len(events),
by_model=dict(by_model),
by_agent=dict(by_agent),
)
@router.get("/", response_model=list[TokenEventRead])
async def list_token_events(
task_id: uuid.UUID | None = None,
workstream_id: uuid.UUID | None = None,
repo_id: uuid.UUID | None = None,
ref_type: str | None = None,
ref_id: str | None = None,
model: str | None = None,
agent: str | None = None,
limit: int = Query(100, le=1000),
session: AsyncSession = Depends(get_session),
) -> list[TokenEvent]:
q = select(TokenEvent)
if task_id:
q = q.where(TokenEvent.task_id == task_id)
if workstream_id:
q = q.where(TokenEvent.workstream_id == workstream_id)
if repo_id:
q = q.where(TokenEvent.repo_id == repo_id)
if ref_type:
q = q.where(TokenEvent.ref_type == ref_type)
if ref_id:
q = q.where(TokenEvent.ref_id == ref_id)
if model:
q = q.where(TokenEvent.model == model)
if agent:
q = q.where(TokenEvent.agent == agent)
q = q.order_by(TokenEvent.created_at.desc()).limit(limit)
result = await session.execute(q)
return list(result.scalars().all())