"""Token event ingestion and reporting endpoints (mounted under ``/token-events``)."""

import uuid
from collections import defaultdict
from datetime import datetime
from typing import Any

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from api.database import get_session
from api.models.managed_repo import ManagedRepo
from api.models.task import Task
from api.models.token_event import TokenEvent
from api.models.workstream import Workstream
from api.schemas.token_event import (
    RepoTokenSummary,
    TokenAggregateRow,
    TokenAggregateSummary,
    TokenEventCreate,
    TokenEventPatch,
    TokenEventRead,
    TokenQualitySummary,
    TokenSummary,
)

router = APIRouter(prefix="/token-events", tags=["token-events"])

# Confidence assigned to an event when the caller does not supply one,
# keyed by the event's measurement kind.
DEFAULT_CONFIDENCE = {
    "measured": 1.0,
    "allocated": 0.70,
    "estimated": 0.35,
    "superseded": 0.0,
}

# Parser version assigned to an event when the caller does not supply one,
# keyed by the event's source provider.
SOURCE_PARSER_DEFAULTS = {
    "codex_session": "codex-desktop-v1",
    "claude_transcript": "claude-transcript-v1",
    "llm_connect": "llm-connect-v1",
}

# Providers whose events are inferred as "measured" and for which a
# source_id is derived from ref_id / session_id when missing.
_MEASURED_SOURCE_PROVIDERS = frozenset({"codex_session", "claude_transcript", "llm_connect"})


def _event_total(event: TokenEvent) -> int:
    """Return ``tokens_in + tokens_out`` for *event*."""
    return event.tokens_in + event.tokens_out


def _infer_measurement_kind(data: dict[str, Any]) -> str:
    """Return the measurement kind for an event payload.

    An explicit ``measurement_kind`` wins. Otherwise the kind is derived from
    ``note`` and then from ``source_provider``; ``"estimated"`` is the
    fallback.
    """
    if data.get("measurement_kind"):
        return str(data["measurement_kind"])

    note = data.get("note")
    if note == "heuristic_superseded_by_codex_backfill":
        return "superseded"
    if note == "workplan":
        return "allocated"
    if note == "heuristic":
        return "estimated"
    if note == "measured" or str(note or "").startswith("backfill:codex-session"):
        return "measured"

    if data.get("source_provider") in _MEASURED_SOURCE_PROVIDERS:
        return "measured"
    return "estimated"


def _infer_source_provider(data: dict[str, Any], measurement_kind: str) -> str:
    """Return the source provider for an event payload.

    An explicit ``source_provider`` wins. Otherwise it is derived from
    ``note``, ``ref_id`` and ``agent``; ``"manual"`` is the fallback.
    """
    if data.get("source_provider"):
        return str(data["source_provider"])

    note = data.get("note")
    ref_id = str(data.get("ref_id") or "")
    agent = str(data.get("agent") or "").lower()

    if note == "heuristic":
        return "task_fallback"
    if ref_id.startswith("codex:") or str(note or "").startswith("backfill:codex-session"):
        return "codex_session"
    if measurement_kind == "measured" and "claude" in agent:
        return "claude_transcript"
    return "manual"


def _apply_event_defaults(data: dict[str, Any]) -> dict[str, Any]:
    """Fill in inferred and default fields on an event payload.

    Mutates *data* in place and returns the same dict. ``measurement_kind``
    and ``source_provider`` are always (re)written with the inferred values;
    every other field is only set when absent.
    """
    measurement_kind = _infer_measurement_kind(data)
    source_provider = _infer_source_provider(data, measurement_kind)
    data["measurement_kind"] = measurement_kind
    data["source_provider"] = source_provider

    # Derive a source_id from ref_id / session_id for the known providers.
    if not data.get("source_id") and source_provider in _MEASURED_SOURCE_PROVIDERS:
        source_id = data.get("ref_id") or data.get("session_id")
        if source_id:
            data["source_id"] = str(source_id)

    if not data.get("source_created_at") and data.get("created_at") and data.get("source_id"):
        data["source_created_at"] = data["created_at"]

    data.setdefault("confidence", DEFAULT_CONFIDENCE.get(measurement_kind, 0.35))
    data.setdefault("cached_input_tokens", 0)
    data.setdefault("reasoning_output_tokens", 0)
    data.setdefault("raw_total_tokens", (data.get("tokens_in") or 0) + (data.get("tokens_out") or 0))
    data.setdefault("raw_metadata", {})
    if source_provider in SOURCE_PARSER_DEFAULTS:
        data.setdefault("parser_version", SOURCE_PARSER_DEFAULTS[source_provider])
    return data


async def _populate_relationship_defaults(data: dict[str, Any], session: AsyncSession) -> dict[str, Any]:
    """Fill ``workstream_id`` / ``repo_id`` from related rows when missing.

    Mutates *data* in place and returns the same dict.
    """
    # Auto-populate workstream_id from task if not provided
    if data.get("task_id") and not data.get("workstream_id"):
        task = await session.get(Task, data["task_id"])
        if task:
            data["workstream_id"] = task.workstream_id

    # Auto-populate repo_id from workstream if not provided
    if data.get("workstream_id") and not data.get("repo_id"):
        ws = await session.get(Workstream, data["workstream_id"])
        if ws and ws.repo_id:
            data["repo_id"] = ws.repo_id
    return data


async def _find_source_event(data: dict[str, Any], session: AsyncSession) -> TokenEvent | None:
    """Return an existing event with the same provenance key, if any.

    The key is ``(measurement_kind, source_provider, source_id)``. Returns
    ``None`` when the payload has no ``source_id`` or nothing matches.
    """
    source_id = data.get("source_id")
    if not source_id:
        return None

    result = await session.execute(
        select(TokenEvent)
        .where(
            TokenEvent.measurement_kind == data["measurement_kind"],
            TokenEvent.source_provider == data["source_provider"],
            TokenEvent.source_id == source_id,
        )
        # Duplicate rows for one provenance key can exist (the quality report
        # counts them). Pick the oldest deterministically instead of raising
        # MultipleResultsFound, which would fail every upsert for that source.
        .order_by(TokenEvent.created_at)
        .limit(1)
    )
    return result.scalars().first()


async def _create_or_upsert_event(data: dict[str, Any], session: AsyncSession) -> TokenEvent:
    """Insert a new event, or overwrite the event sharing its provenance key.

    Applies field defaults and relationship defaults first, then commits and
    returns the refreshed row.
    """
    data = _apply_event_defaults(data)
    data = await _populate_relationship_defaults(data, session)

    existing = await _find_source_event(data, session)
    if existing is not None:
        for field, value in data.items():
            setattr(existing, field, value)
        await session.commit()
        await session.refresh(existing)
        return existing

    event = TokenEvent(**data)
    session.add(event)
    await session.commit()
    await session.refresh(event)
    return event


def _filter_query(
    q,
    *,
    task_id: uuid.UUID | None = None,
    workstream_id: uuid.UUID | None = None,
    repo_id: uuid.UUID | None = None,
    ref_type: str | None = None,
    ref_id: str | None = None,
    model: str | None = None,
    agent: str | None = None,
    note: str | None = None,
    measurement_kind: str | None = None,
    source_provider: str | None = None,
    since: datetime | None = None,
    until: datetime | None = None,
    include_superseded: bool = True,
    unattributed: bool = False,
):
    """Apply the optional ``TokenEvent`` filters to *q* and return it.

    Falsy filter values are ignored. ``since`` is inclusive and ``until`` is
    exclusive, both compared against ``created_at``. ``unattributed`` keeps
    only events with no repo, workstream and task.
    """
    if task_id:
        q = q.where(TokenEvent.task_id == task_id)
    if workstream_id:
        q = q.where(TokenEvent.workstream_id == workstream_id)
    if repo_id:
        q = q.where(TokenEvent.repo_id == repo_id)
    if ref_type:
        q = q.where(TokenEvent.ref_type == ref_type)
    if ref_id:
        q = q.where(TokenEvent.ref_id == ref_id)
    if model:
        q = q.where(TokenEvent.model == model)
    if agent:
        q = q.where(TokenEvent.agent == agent)
    if note:
        q = q.where(TokenEvent.note == note)
    if measurement_kind:
        q = q.where(TokenEvent.measurement_kind == measurement_kind)
    if source_provider:
        q = q.where(TokenEvent.source_provider == source_provider)
    if since:
        q = q.where(TokenEvent.created_at >= since)
    if until:
        q = q.where(TokenEvent.created_at < until)
    if not include_superseded:
        q = q.where(TokenEvent.measurement_kind != "superseded")
    if unattributed:
        q = q.where(
            TokenEvent.repo_id.is_(None),
            TokenEvent.workstream_id.is_(None),
            TokenEvent.task_id.is_(None),
        )
    return q


@router.post("/", response_model=TokenEventRead, status_code=status.HTTP_201_CREATED)
async def create_token_event(
    body: TokenEventCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Create a token event (or update the one sharing its provenance key)."""
    data = body.model_dump(exclude_none=True)
    return await _create_or_upsert_event(data, session)


@router.post("/upsert", response_model=TokenEventRead)
async def upsert_token_event(
    body: TokenEventCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Upsert a token event; same logic as ``POST /`` but without the 201 status."""
    data = body.model_dump(exclude_none=True)
    return await _create_or_upsert_event(data, session)


@router.get("/summary/", response_model=TokenSummary)
async def get_token_summary(
    scope: str = Query(..., description="task|workstream|repo|commit|release|session"),
    id: str = Query(..., description="FK value or ref_id depending on scope"),
    session: AsyncSession = Depends(get_session),
) -> TokenSummary:
    """Summarise token usage for a single scope.

    For ``task`` / ``workstream`` / ``repo`` scopes *id* must be a UUID and is
    matched against the corresponding foreign key. For ``commit`` /
    ``release`` / ``session`` it is matched against ``ref_id`` with
    ``ref_type == scope``.

    Raises:
        HTTPException: 422 when *scope* is unknown or *id* is not a valid UUID
            for a UUID-keyed scope.
    """
    # Scopes whose id is a foreign-key UUID, mapped to the column to match.
    uuid_scope_columns = {
        "task": TokenEvent.task_id,
        "workstream": TokenEvent.workstream_id,
        "repo": TokenEvent.repo_id,
    }

    q = select(TokenEvent)
    if scope in uuid_scope_columns:
        try:
            uid = uuid.UUID(id)
        except ValueError as exc:
            raise HTTPException(
                status_code=422,
                detail=f"id must be a valid UUID for scope={scope}",
            ) from exc
        q = q.where(uuid_scope_columns[scope] == uid)
    elif scope in ("commit", "release", "session"):
        q = q.where(TokenEvent.ref_type == scope, TokenEvent.ref_id == id)
    else:
        raise HTTPException(status_code=422, detail=f"Unknown scope: {scope!r}")

    result = await session.execute(q)
    events = list(result.scalars().all())

    tokens_in = sum(e.tokens_in for e in events)
    tokens_out = sum(e.tokens_out for e in events)

    by_model: dict[str, int] = defaultdict(int)
    by_agent: dict[str, int] = defaultdict(int)
    by_measurement_kind: dict[str, int] = defaultdict(int)
    by_source_provider: dict[str, int] = defaultdict(int)
    for e in events:
        total = _event_total(e)
        # Events without a model / agent are left out of those breakdowns.
        if e.model:
            by_model[e.model] += total
        if e.agent:
            by_agent[e.agent] += total
        by_measurement_kind[e.measurement_kind] += total
        by_source_provider[e.source_provider] += total

    return TokenSummary(
        scope=scope,
        scope_id=id,
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        tokens_total=tokens_in + tokens_out,
        event_count=len(events),
        by_model=dict(by_model),
        by_agent=dict(by_agent),
        by_measurement_kind=dict(by_measurement_kind),
        by_source_provider=dict(by_source_provider),
    )


@router.get("/by-repo/", response_model=list[RepoTokenSummary])
async def get_tokens_by_repo(
    measurement_kind: str | None = None,
    source_provider: str | None = None,
    since: datetime | None = None,
    until: datetime | None = None,
    include_superseded: bool = Query(True),
    session: AsyncSession = Depends(get_session),
) -> list[RepoTokenSummary]:
    """Aggregate token consumption per repo, resolving via the full graph.

    Resolution order for each event:
      1. token_events.repo_id (direct)
      2. workstreams.repo_id (via workstream_id)
      3. task.workstream_id -> workstreams.repo_id (via task_id)

    Only events that resolve to a known repo are included. Results are sorted
    by total tokens, descending.
    """
    # Fetch events, workstreams, tasks and repos in four queries (avoids N+1).
    events_result = await session.execute(
        _filter_query(
            select(TokenEvent),
            measurement_kind=measurement_kind,
            source_provider=source_provider,
            since=since,
            until=until,
            include_superseded=include_superseded,
        )
    )
    events = list(events_result.scalars().all())

    ws_result = await session.execute(select(Workstream))
    ws_map: dict[uuid.UUID, Workstream] = {w.id: w for w in ws_result.scalars().all()}

    task_result = await session.execute(select(Task))
    task_map: dict[uuid.UUID, Task] = {t.id: t for t in task_result.scalars().all()}

    repo_result = await session.execute(select(ManagedRepo))
    repo_map: dict[uuid.UUID, ManagedRepo] = {r.id: r for r in repo_result.scalars().all()}

    def resolve_repo_id(e: TokenEvent) -> uuid.UUID | None:
        """Resolve an event's repo id via repo, workstream, then task."""
        if e.repo_id:
            return e.repo_id
        ws_id = e.workstream_id
        if not ws_id and e.task_id and e.task_id in task_map:
            ws_id = task_map[e.task_id].workstream_id
        if ws_id and ws_id in ws_map:
            return ws_map[ws_id].repo_id
        return None

    groups: dict[uuid.UUID, dict] = {}
    for e in events:
        rid = resolve_repo_id(e)
        if not rid or rid not in repo_map:
            continue
        if rid not in groups:
            groups[rid] = {
                "repo_id": rid,
                "repo_slug": repo_map[rid].slug,
                "tokens_in": 0,
                "tokens_out": 0,
                "event_count": 0,
                "by_model": defaultdict(int),
                "by_note": defaultdict(int),
                "by_measurement_kind": defaultdict(int),
                "by_source_provider": defaultdict(int),
            }
        g = groups[rid]
        g["tokens_in"] += e.tokens_in
        g["tokens_out"] += e.tokens_out
        g["event_count"] += 1
        total = _event_total(e)
        if e.model:
            g["by_model"][e.model] += total
        g["by_note"][e.note or "unknown"] += total
        g["by_measurement_kind"][e.measurement_kind] += total
        g["by_source_provider"][e.source_provider] += total

    return [
        RepoTokenSummary(
            # Convert the defaultdict accumulators to plain dicts for the schema.
            **{k: (dict(v) if isinstance(v, defaultdict) else v) for k, v in g.items()},
            tokens_total=g["tokens_in"] + g["tokens_out"],
        )
        for g in sorted(groups.values(), key=lambda x: -(x["tokens_in"] + x["tokens_out"]))
    ]
@router.get("/aggregate/", response_model=TokenAggregateSummary)
async def get_token_aggregate(
    measurement_kind: str | None = None,
    source_provider: str | None = None,
    since: datetime | None = None,
    until: datetime | None = None,
    include_superseded: bool = Query(False),
    session: AsyncSession = Depends(get_session),
) -> TokenAggregateSummary:
    """Return overall token totals plus per-repo/workstream/task/model rows.

    Superseded events are excluded unless ``include_superseded`` is true.
    Each row list is sorted by total tokens, descending.
    """
    events_result = await session.execute(
        _filter_query(
            select(TokenEvent),
            measurement_kind=measurement_kind,
            source_provider=source_provider,
            since=since,
            until=until,
            include_superseded=include_superseded,
        )
    )
    events = list(events_result.scalars().all())

    ws_result = await session.execute(select(Workstream))
    ws_map: dict[uuid.UUID, Workstream] = {w.id: w for w in ws_result.scalars().all()}

    task_result = await session.execute(select(Task))
    task_map: dict[uuid.UUID, Task] = {t.id: t for t in task_result.scalars().all()}

    repo_result = await session.execute(select(ManagedRepo))
    repo_map: dict[uuid.UUID, ManagedRepo] = {r.id: r for r in repo_result.scalars().all()}

    def resolve_repo_id(e: TokenEvent) -> uuid.UUID | None:
        """Resolve an event's repo id via repo, workstream, then task."""
        if e.repo_id:
            return e.repo_id
        ws_id = e.workstream_id
        if not ws_id and e.task_id and e.task_id in task_map:
            ws_id = task_map[e.task_id].workstream_id
        if ws_id and ws_id in ws_map:
            return ws_map[ws_id].repo_id
        return None

    def add(groups: dict[str, dict[str, Any]], key: str | None, label: str | None, e: TokenEvent) -> None:
        """Accumulate *e* into ``groups[key]``; events with no key are skipped."""
        if not key:
            return
        if key not in groups:
            groups[key] = {
                "scope_id": key,
                "label": label,
                "tokens_in": 0,
                "tokens_out": 0,
                "event_count": 0,
                "by_measurement_kind": defaultdict(int),
                "by_source_provider": defaultdict(int),
            }
        row = groups[key]
        total = _event_total(e)
        row["tokens_in"] += e.tokens_in
        row["tokens_out"] += e.tokens_out
        row["event_count"] += 1
        row["by_measurement_kind"][e.measurement_kind] += total
        row["by_source_provider"][e.source_provider] += total

    by_repo: dict[str, dict[str, Any]] = {}
    by_workstream: dict[str, dict[str, Any]] = {}
    by_task: dict[str, dict[str, Any]] = {}
    by_model: dict[str, dict[str, Any]] = {}
    by_measurement_kind: dict[str, int] = defaultdict(int)
    by_source_provider: dict[str, int] = defaultdict(int)
    first_event_at = last_event_at = last_ingested_at = None
    tokens_in = tokens_out = 0

    for e in events:
        total = _event_total(e)
        tokens_in += e.tokens_in
        tokens_out += e.tokens_out
        by_measurement_kind[e.measurement_kind] += total
        by_source_provider[e.source_provider] += total

        # Track the created_at range and the most recent ingestion time.
        if first_event_at is None or e.created_at < first_event_at:
            first_event_at = e.created_at
        if last_event_at is None or e.created_at > last_event_at:
            last_event_at = e.created_at
        if last_ingested_at is None or e.ingested_at > last_ingested_at:
            last_ingested_at = e.ingested_at

        rid = resolve_repo_id(e)
        repo = repo_map.get(rid) if rid else None
        add(by_repo, str(rid) if rid else None, repo.slug if repo else None, e)

        # Fall back to the task's workstream when the event has none of its own.
        ws_id = e.workstream_id or (task_map[e.task_id].workstream_id if e.task_id in task_map else None)
        ws = ws_map.get(ws_id) if ws_id else None
        add(by_workstream, str(ws_id) if ws_id else None, ws.title if ws else None, e)

        task = task_map.get(e.task_id) if e.task_id else None
        add(by_task, str(e.task_id) if e.task_id else None, task.title if task else None, e)

        add(by_model, e.model or "unknown", e.model or "unknown", e)

    def rows(groups: dict[str, dict[str, Any]]) -> list[TokenAggregateRow]:
        """Convert accumulated groups to schema rows, largest total first."""
        built = [
            TokenAggregateRow(
                # Convert the defaultdict accumulators to plain dicts for the schema.
                **{k: (dict(v) if isinstance(v, defaultdict) else v) for k, v in row.items()},
                tokens_total=row["tokens_in"] + row["tokens_out"],
            )
            for row in groups.values()
        ]
        return sorted(built, key=lambda item: -item.tokens_total)

    return TokenAggregateSummary(
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        tokens_total=tokens_in + tokens_out,
        event_count=len(events),
        first_event_at=first_event_at,
        last_event_at=last_event_at,
        last_ingested_at=last_ingested_at,
        by_repo=rows(by_repo),
        by_workstream=rows(by_workstream),
        by_task=rows(by_task),
        by_model=rows(by_model),
        by_measurement_kind=dict(by_measurement_kind),
        by_source_provider=dict(by_source_provider),
    )


@router.get("/quality/", response_model=TokenQualitySummary)
async def get_token_quality(
    since: datetime | None = None,
    until: datetime | None = None,
    session: AsyncSession = Depends(get_session),
) -> TokenQualitySummary:
    """Return event-count based data-quality indicators for token events.

    Unlike the usage endpoints, the ``by_*`` breakdowns here count events,
    not tokens. ``last_reconciliation_at`` is always ``None``.
    """
    result = await session.execute(_filter_query(select(TokenEvent), since=since, until=until))
    events = list(result.scalars().all())

    by_measurement_kind: dict[str, int] = defaultdict(int)
    by_source_provider: dict[str, int] = defaultdict(int)
    # Occurrences of each (measurement_kind, source_provider, source_id) key.
    source_counts: dict[tuple[str, str, str], int] = defaultdict(int)
    last_codex_ingested_at = None
    last_claude_ingested_at = None
    fallback_count = 0
    unattributed_measured_count = 0
    missing_provenance_count = 0

    for e in events:
        by_measurement_kind[e.measurement_kind] += 1
        by_source_provider[e.source_provider] += 1
        if e.source_id:
            source_counts[(e.measurement_kind, e.source_provider, e.source_id)] += 1
        if e.source_provider == "task_fallback" or e.note == "heuristic":
            fallback_count += 1
        if e.measurement_kind == "measured" and not (e.repo_id or e.workstream_id or e.task_id):
            unattributed_measured_count += 1
        if e.measurement_kind == "measured" and not e.source_id:
            missing_provenance_count += 1
        if e.source_provider == "codex_session" and (
            last_codex_ingested_at is None or e.ingested_at > last_codex_ingested_at
        ):
            last_codex_ingested_at = e.ingested_at
        if e.source_provider == "claude_transcript" and (
            last_claude_ingested_at is None or e.ingested_at > last_claude_ingested_at
        ):
            last_claude_ingested_at = e.ingested_at

    # Number of provenance keys that occur more than once (not extra rows).
    duplicate_source_count = sum(1 for count in source_counts.values() if count > 1)

    return TokenQualitySummary(
        event_count=len(events),
        measured_event_count=by_measurement_kind.get("measured", 0),
        estimated_event_count=by_measurement_kind.get("estimated", 0),
        allocated_event_count=by_measurement_kind.get("allocated", 0),
        superseded_event_count=by_measurement_kind.get("superseded", 0),
        fallback_event_count=fallback_count,
        unattributed_measured_event_count=unattributed_measured_count,
        missing_provenance_event_count=missing_provenance_count,
        duplicate_source_count=duplicate_source_count,
        last_codex_ingested_at=last_codex_ingested_at,
        last_claude_ingested_at=last_claude_ingested_at,
        last_reconciliation_at=None,
        by_measurement_kind=dict(by_measurement_kind),
        by_source_provider=dict(by_source_provider),
    )


@router.patch("/{event_id}", response_model=TokenEventRead)
async def patch_token_event(
    event_id: uuid.UUID,
    body: TokenEventPatch,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Partially update a token event.

    Fields that are ``None`` in the body are ignored. When ``note``,
    ``measurement_kind`` or ``source_provider`` is patched, the inferred
    ``measurement_kind`` / ``source_provider`` / ``confidence`` / ``source_id``
    are recomputed and applied for any of those fields the body did not set.

    Raises:
        HTTPException: 404 when no event has *event_id*.
    """
    event = await session.get(TokenEvent, event_id)
    if event is None:
        raise HTTPException(status_code=404, detail="Token event not found")

    data = body.model_dump(exclude_none=True)
    if "note" in data or "measurement_kind" in data or "source_provider" in data:
        # Overlay the patch on the stored values so inference sees the
        # event as it will look after the update.
        merged = {
            "tokens_in": data.get("tokens_in", event.tokens_in),
            "tokens_out": data.get("tokens_out", event.tokens_out),
            "note": data.get("note", event.note),
            "agent": data.get("agent", event.agent),
            "ref_id": data.get("ref_id", event.ref_id),
            "session_id": data.get("session_id", event.session_id),
            "measurement_kind": data.get("measurement_kind", event.measurement_kind),
            "source_provider": data.get("source_provider", event.source_provider),
            "source_id": data.get("source_id", event.source_id),
        }
        inferred = _apply_event_defaults({k: v for k, v in merged.items() if v is not None})
        data.setdefault("measurement_kind", inferred["measurement_kind"])
        data.setdefault("source_provider", inferred["source_provider"])
        data.setdefault("confidence", inferred["confidence"])
        if inferred.get("source_id"):
            data.setdefault("source_id", inferred["source_id"])

    for field, value in data.items():
        setattr(event, field, value)
    await session.commit()
    await session.refresh(event)
    return event


@router.get("/{event_id}", response_model=TokenEventRead)
async def get_token_event(
    event_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Return a single token event.

    Raises:
        HTTPException: 404 when no event has *event_id*.
    """
    event = await session.get(TokenEvent, event_id)
    if event is None:
        raise HTTPException(status_code=404, detail="Token event not found")
    return event


@router.get("/", response_model=list[TokenEventRead])
async def list_token_events(
    task_id: uuid.UUID | None = None,
    workstream_id: uuid.UUID | None = None,
    repo_id: uuid.UUID | None = None,
    ref_type: str | None = None,
    ref_id: str | None = None,
    model: str | None = None,
    agent: str | None = None,
    note: str | None = None,
    measurement_kind: str | None = None,
    source_provider: str | None = None,
    since: datetime | None = None,
    until: datetime | None = None,
    include_superseded: bool = Query(True),
    unattributed: bool = False,
    offset: int = Query(0, ge=0),
    # ge=0: a negative LIMIT is rejected by some databases and treated as
    # "no limit" by others, which would bypass the 1000-row cap.
    limit: int = Query(100, ge=0, le=1000),
    session: AsyncSession = Depends(get_session),
) -> list[TokenEvent]:
    """List token events matching the filters, newest ``created_at`` first."""
    q = _filter_query(
        select(TokenEvent),
        task_id=task_id,
        workstream_id=workstream_id,
        repo_id=repo_id,
        ref_type=ref_type,
        ref_id=ref_id,
        model=model,
        agent=agent,
        note=note,
        measurement_kind=measurement_kind,
        source_provider=source_provider,
        since=since,
        until=until,
        include_superseded=include_superseded,
        unattributed=unattributed,
    )
    q = q.order_by(TokenEvent.created_at.desc()).offset(offset).limit(limit)
    result = await session.execute(q)
    return list(result.scalars().all())