diff --git a/state-hub/api/routers/tasks.py b/state-hub/api/routers/tasks.py index 27dbd1a..adf6cc6 100644 --- a/state-hub/api/routers/tasks.py +++ b/state-hub/api/routers/tasks.py @@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from api.database import get_session from api.models.task import Task, TaskStatus from api.models.token_event import TokenEvent +from api.models.workstream import Workstream from api.schemas.task import TaskCreate, TaskRead, TaskUpdate router = APIRouter(prefix="/tasks", tags=["tasks"]) @@ -104,9 +105,14 @@ async def update_task( # Tier 3: heuristic fallback tin, tout, tnote = 1000, 500, "heuristic" + # Resolve repo_id via workstream + ws = await session.get(Workstream, task.workstream_id) + repo_id = ws.repo_id if ws else None + event = TokenEvent( task_id=task_id, workstream_id=task.workstream_id, + repo_id=repo_id, tokens_in=tin, tokens_out=tout, model=token_data.get("model"), diff --git a/state-hub/api/routers/token_events.py b/state-hub/api/routers/token_events.py index 82e5850..c3f8ee2 100644 --- a/state-hub/api/routers/token_events.py +++ b/state-hub/api/routers/token_events.py @@ -6,9 +6,11 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from api.database import get_session +from api.models.managed_repo import ManagedRepo from api.models.task import Task from api.models.token_event import TokenEvent -from api.schemas.token_event import TokenEventCreate, TokenEventRead, TokenSummary +from api.models.workstream import Workstream +from api.schemas.token_event import RepoTokenSummary, TokenEventCreate, TokenEventRead, TokenSummary router = APIRouter(prefix="/token-events", tags=["token-events"]) @@ -26,6 +28,12 @@ async def create_token_event( if task: data["workstream_id"] = task.workstream_id + # Auto-populate repo_id from workstream if not provided + if data.get("workstream_id") and not data.get("repo_id"): + ws = await session.get(Workstream, data["workstream_id"]) + if ws and 
ws.repo_id: + data["repo_id"] = ws.repo_id + event = TokenEvent(**data) session.add(event) await session.commit() @@ -90,6 +98,74 @@ async def get_token_summary( ) +@router.get("/by-repo/", response_model=list[RepoTokenSummary]) +async def get_tokens_by_repo( + session: AsyncSession = Depends(get_session), +) -> list[RepoTokenSummary]: + """Aggregate token consumption per repo, resolving via the full graph. + + Resolution order for each event: + 1. token_events.repo_id (direct) + 2. → workstreams.repo_id (via workstream_id) + 3. → task.workstream_id → workstreams.repo_id (via task_id) + + Only events that resolve to a repo are included. + """ + # Fetch all events, workstreams, tasks, repos in four queries (avoids N+1) + events_result = await session.execute(select(TokenEvent)) + events = list(events_result.scalars().all()) + + ws_result = await session.execute(select(Workstream)) + ws_map: dict[uuid.UUID, Workstream] = {w.id: w for w in ws_result.scalars().all()} + + task_result = await session.execute(select(Task)) + task_map: dict[uuid.UUID, Task] = {t.id: t for t in task_result.scalars().all()} + + repo_result = await session.execute(select(ManagedRepo)) + repo_map: dict[uuid.UUID, ManagedRepo] = {r.id: r for r in repo_result.scalars().all()} + + def resolve_repo_id(e: TokenEvent) -> uuid.UUID | None: + if e.repo_id: + return e.repo_id + ws_id = e.workstream_id + if not ws_id and e.task_id and e.task_id in task_map: + ws_id = task_map[e.task_id].workstream_id + if ws_id and ws_id in ws_map: + return ws_map[ws_id].repo_id + return None + + groups: dict[uuid.UUID, dict] = {} + for e in events: + rid = resolve_repo_id(e) + if not rid or rid not in repo_map: + continue + if rid not in groups: + groups[rid] = { + "repo_id": rid, + "repo_slug": repo_map[rid].slug, + "tokens_in": 0, + "tokens_out": 0, + "event_count": 0, + "by_model": defaultdict(int), + "by_note": defaultdict(int), + } + g = groups[rid] + g["tokens_in"] += e.tokens_in + g["tokens_out"] += e.tokens_out +
g["event_count"] += 1 + if e.model: + g["by_model"][e.model] += e.tokens_in + e.tokens_out + g["by_note"][e.note or "unknown"] += e.tokens_in + e.tokens_out + + return [ + RepoTokenSummary( + **{k: (dict(v) if isinstance(v, defaultdict) else v) for k, v in g.items()}, + tokens_total=g["tokens_in"] + g["tokens_out"], + ) + for g in sorted(groups.values(), key=lambda x: -(x["tokens_in"] + x["tokens_out"])) + ] + + @router.get("/", response_model=list[TokenEventRead]) async def list_token_events( task_id: uuid.UUID | None = None, diff --git a/state-hub/api/schemas/token_event.py b/state-hub/api/schemas/token_event.py index 966bf12..6933228 100644 --- a/state-hub/api/schemas/token_event.py +++ b/state-hub/api/schemas/token_event.py @@ -50,3 +50,14 @@ class TokenSummary(BaseModel): event_count: int by_model: dict[str, int] by_agent: dict[str, int] + + +class RepoTokenSummary(BaseModel): + repo_id: uuid.UUID + repo_slug: str + tokens_in: int + tokens_out: int + tokens_total: int + event_count: int + by_model: dict[str, int] + by_note: dict[str, int] diff --git a/state-hub/dashboard/src/token-cost.md b/state-hub/dashboard/src/token-cost.md index dd012b3..1ad4b9b 100644 --- a/state-hub/dashboard/src/token-cost.md +++ b/state-hub/dashboard/src/token-cost.md @@ -8,19 +8,22 @@ const POLL = 60_000; ``` ```js -// Live poll for token data +// Fetch both /by-repo/ and raw events in parallel const tokenState = (async function*() { while (true) { - let data = {by_repo: [], by_workstream: [], top_tasks: [], by_model: [], total_events: 0}, ok = false; + let byRepo = [], events = [], ok = false; try { - const r = await fetch(`${API}/token-events/?limit=1000`); - ok = r.ok; + const [r1, r2] = await Promise.all([ + fetch(`${API}/token-events/by-repo/`), + fetch(`${API}/token-events/?limit=1000`), + ]); + ok = r1.ok && r2.ok; if (ok) { - const events = await r.json(); - data = buildSummary(events); + byRepo = await r1.json(); + events = await r2.json(); } } catch {} - yield {data, ok, 
ts: new Date()}; + yield {byRepo, events, ok, ts: new Date()}; await new Promise(res => setTimeout(res, POLL)); } })(); @@ -28,15 +31,9 @@ const tokenState = (async function*() { ```js function buildSummary(events) { - const byRepo = {}, byWs = {}, byModel = {}, byTask = {}; + const byWs = {}, byModel = {}, byTask = {}; for (const e of events) { const tot = (e.tokens_in || 0) + (e.tokens_out || 0); - if (e.repo_id) { - byRepo[e.repo_id] = byRepo[e.repo_id] || {scope_id: e.repo_id, tokens_in: 0, tokens_out: 0, event_count: 0}; - byRepo[e.repo_id].tokens_in += e.tokens_in || 0; - byRepo[e.repo_id].tokens_out += e.tokens_out || 0; - byRepo[e.repo_id].event_count++; - } if (e.workstream_id) { byWs[e.workstream_id] = byWs[e.workstream_id] || {scope_id: e.workstream_id, tokens_in: 0, tokens_out: 0, event_count: 0}; byWs[e.workstream_id].tokens_in += e.tokens_in || 0; @@ -55,7 +52,6 @@ function buildSummary(events) { .map(([k,v]) => typeof v === "number" ? {id: k, tokens_total: v} : {...v, tokens_total: (v.tokens_in||0)+(v.tokens_out||0)}) .sort((a,b) => b.tokens_total - a.tokens_total); return { - by_repo: sortDesc(byRepo), by_workstream: sortDesc(byWs), by_model: Object.entries(byModel).map(([model,tokens_total]) => ({model,tokens_total})).sort((a,b)=>b.tokens_total-a.tokens_total), top_tasks: sortDesc(byTask).slice(0,10), @@ -65,7 +61,8 @@ function buildSummary(events) { ``` ```js -const td = tokenState.data ?? {by_repo:[], by_workstream:[], top_tasks:[], by_model:[], total_events:0}; +const byRepo = tokenState.byRepo ?? []; +const summary = buildSummary(tokenState.events ?? []); const _ok = tokenState.ok ?? false; const _ts = tokenState.ts; ``` @@ -74,7 +71,7 @@ const _ts = tokenState.ts; ```js const _liveEl = html`
No token events recorded yet.
`); +if (byRepo.length === 0) { + display(html`No token events with repo association yet.
`); } else { display(Plot.plot({ title: "Token consumption by repo", @@ -94,9 +91,9 @@ if (td.by_repo.length === 0) { color: {legend: true, domain: ["tokens_in", "tokens_out"], range: ["#4e79a7","#f28e2b"]}, marks: [ Plot.barX( - td.by_repo.flatMap(r => [ - {repo: r.scope_id.slice(0,8), type: "tokens_in", value: r.tokens_in}, - {repo: r.scope_id.slice(0,8), type: "tokens_out", value: r.tokens_out}, + byRepo.flatMap(r => [ + {repo: r.repo_slug, type: "tokens_in", value: r.tokens_in}, + {repo: r.repo_slug, type: "tokens_out", value: r.tokens_out}, ]), {x: "value", y: "repo", fill: "type", tip: true} ), @@ -108,7 +105,7 @@ if (td.by_repo.length === 0) { ## By Workplan ```js -const wsRows = td.by_workstream.slice(0, 20); +const wsRows = summary.by_workstream.slice(0, 20); if (wsRows.length === 0) { display(html`No workstream data yet.
`); } else { @@ -135,7 +132,7 @@ if (wsRows.length === 0) { ## By Model ```js -if (td.by_model.length === 0) { +if (summary.by_model.length === 0) { display(html`No model data yet.
`); } else { display(Plot.plot({ @@ -144,7 +141,7 @@ if (td.by_model.length === 0) { width: Math.min(700, width), x: {label: "Total tokens", tickFormat: "~s"}, marks: [ - Plot.barX(td.by_model, {x: "tokens_total", y: "model", fill: "#4e79a7", tip: true}), + Plot.barX(summary.by_model, {x: "tokens_total", y: "model", fill: "#4e79a7", tip: true}), ], })); } @@ -153,10 +150,10 @@ if (td.by_model.length === 0) { ## Top 10 Tasks by Tokens ```js -if (td.top_tasks.length === 0) { +if (summary.top_tasks.length === 0) { display(html`No task-level data yet.
`); } else { - display(Inputs.table(td.top_tasks, { + display(Inputs.table(summary.top_tasks, { columns: ["task_id", "tokens_in", "tokens_out", "tokens_total"], header: {task_id: "Task ID", tokens_in: "In", tokens_out: "Out", tokens_total: "Total"}, format: {