Files
state-hub/api/routers/token_events.py

641 lines
24 KiB
Python

import uuid
from collections import defaultdict
from datetime import datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session
from api.models.managed_repo import ManagedRepo
from api.models.task import Task
from api.models.token_event import TokenEvent
from api.models.workstream import Workstream
from api.schemas.token_event import (
RepoTokenSummary,
TokenAggregateRow,
TokenAggregateSummary,
TokenEventCreate,
TokenEventPatch,
TokenEventRead,
TokenQualitySummary,
TokenSummary,
)
# Every endpoint in this module is mounted under /token-events.
router = APIRouter(
    prefix="/token-events",
    tags=["token-events"],
)
# Confidence assigned to an event when the payload does not provide one,
# keyed by measurement_kind.
DEFAULT_CONFIDENCE = dict(
    measured=1.0,
    allocated=0.70,
    estimated=0.35,
    superseded=0.0,
)
# Parser version recorded by default for events coming from a known
# ingestion source, keyed by source_provider.
SOURCE_PARSER_DEFAULTS = dict(
    codex_session="codex-desktop-v1",
    claude_transcript="claude-transcript-v1",
    llm_connect="llm-connect-v1",
)
def _event_total(event: TokenEvent) -> int:
return event.tokens_in + event.tokens_out
def _infer_measurement_kind(data: dict[str, Any]) -> str:
if data.get("measurement_kind"):
return str(data["measurement_kind"])
note = data.get("note")
if note == "heuristic_superseded_by_codex_backfill":
return "superseded"
if note == "workplan":
return "allocated"
if note == "heuristic":
return "estimated"
if note == "measured" or str(note or "").startswith("backfill:codex-session"):
return "measured"
provider = data.get("source_provider")
if provider in {"codex_session", "claude_transcript", "llm_connect"}:
return "measured"
return "estimated"
def _infer_source_provider(data: dict[str, Any], measurement_kind: str) -> str:
if data.get("source_provider"):
return str(data["source_provider"])
note = data.get("note")
ref_id = str(data.get("ref_id") or "")
agent = str(data.get("agent") or "").lower()
if note == "heuristic":
return "task_fallback"
if ref_id.startswith("codex:") or str(note or "").startswith("backfill:codex-session"):
return "codex_session"
if measurement_kind == "measured" and "claude" in agent:
return "claude_transcript"
return "manual"
def _apply_event_defaults(data: dict[str, Any]) -> dict[str, Any]:
    """Fill provenance and accounting defaults into an event payload, in place.

    Always (re)writes ``measurement_kind`` and ``source_provider`` with the
    inferred values. For transcript-backed providers a missing ``source_id`` is
    derived from ``ref_id`` or ``session_id``. ``source_created_at`` is copied
    from ``created_at`` when a source id is present. Confidence, token
    breakdown fields, raw metadata and (for known providers) parser version
    are only set when absent. Returns the same, mutated dict.
    """
    kind = _infer_measurement_kind(data)
    provider = _infer_source_provider(data, kind)
    data["measurement_kind"] = kind
    data["source_provider"] = provider

    transcript_backed = provider in {"codex_session", "claude_transcript", "llm_connect"}
    if transcript_backed and not data.get("source_id"):
        derived_id = data.get("ref_id") or data.get("session_id")
        if derived_id:
            data["source_id"] = str(derived_id)

    if data.get("source_id") and data.get("created_at") and not data.get("source_created_at"):
        data["source_created_at"] = data["created_at"]

    fallbacks: dict[str, Any] = {
        "confidence": DEFAULT_CONFIDENCE.get(kind, 0.35),
        "cached_input_tokens": 0,
        "reasoning_output_tokens": 0,
        "raw_total_tokens": (data.get("tokens_in") or 0) + (data.get("tokens_out") or 0),
        "raw_metadata": {},
    }
    parser = SOURCE_PARSER_DEFAULTS.get(provider)
    if parser is not None:
        fallbacks["parser_version"] = parser

    for key, value in fallbacks.items():
        data.setdefault(key, value)
    return data
async def _populate_relationship_defaults(data: dict[str, Any], session: AsyncSession) -> dict[str, Any]:
    """Backfill missing workstream_id / repo_id by walking task -> workstream -> repo.

    A key is only filled when it is falsy in *data*. The workstream lookup
    runs after the task step, so a workstream derived from the task is also
    used to derive the repo. Returns the same, mutated dict.
    """
    task_id = data.get("task_id")
    if task_id and not data.get("workstream_id"):
        parent_task = await session.get(Task, task_id)
        if parent_task:
            data["workstream_id"] = parent_task.workstream_id

    workstream_id = data.get("workstream_id")
    if workstream_id and not data.get("repo_id"):
        parent_ws = await session.get(Workstream, workstream_id)
        if parent_ws and parent_ws.repo_id:
            data["repo_id"] = parent_ws.repo_id

    return data
async def _find_source_event(data: dict[str, Any], session: AsyncSession) -> TokenEvent | None:
    """Return the stored event sharing this payload's source key, if any.

    The source key is (measurement_kind, source_provider, source_id). Payloads
    without a truthy ``source_id`` never match and return None. The payload
    must already contain ``measurement_kind`` and ``source_provider``
    (``_apply_event_defaults`` sets both).

    Duplicate source keys can exist in the table (the /quality endpoint counts
    them). In that case the most recently ingested row is returned instead of
    raising ``MultipleResultsFound``, which previously turned every upsert for
    that source into a 500.
    """
    source_id = data.get("source_id")
    if not source_id:
        return None
    result = await session.execute(
        select(TokenEvent)
        .where(
            TokenEvent.measurement_kind == data["measurement_kind"],
            TokenEvent.source_provider == data["source_provider"],
            TokenEvent.source_id == source_id,
        )
        .order_by(TokenEvent.ingested_at.desc())
        .limit(1)
    )
    return result.scalars().first()
async def _create_or_upsert_event(data: dict[str, Any], session: AsyncSession) -> TokenEvent:
    """Persist an event payload, updating the existing row for the same source key.

    Defaults from ``_apply_event_defaults`` and relationship ids from
    ``_populate_relationship_defaults`` are filled in first. If
    ``_find_source_event`` returns a row, every key of the resulting payload
    (including defaulted ones) is assigned onto it. Otherwise a new TokenEvent
    is added. The session is then committed, and the row is refreshed and
    returned.
    """
    payload = await _populate_relationship_defaults(_apply_event_defaults(data), session)
    target = await _find_source_event(payload, session)
    if target is None:
        target = TokenEvent(**payload)
        session.add(target)
    else:
        for attr, val in payload.items():
            setattr(target, attr, val)
    await session.commit()
    await session.refresh(target)
    return target
def _filter_query(
    q,
    *,
    task_id: uuid.UUID | None = None,
    workstream_id: uuid.UUID | None = None,
    repo_id: uuid.UUID | None = None,
    ref_type: str | None = None,
    ref_id: str | None = None,
    model: str | None = None,
    agent: str | None = None,
    note: str | None = None,
    measurement_kind: str | None = None,
    source_provider: str | None = None,
    since: datetime | None = None,
    until: datetime | None = None,
    include_superseded: bool = True,
    unattributed: bool = False,
):
    """Apply the optional TokenEvent filters shared by the list/report endpoints.

    Args:
        q: A SQLAlchemy ``select`` over ``TokenEvent`` to narrow.
        task_id, workstream_id, repo_id, ref_type, ref_id, model, agent, note,
        measurement_kind, source_provider: Equality filters. A falsy value
            (None or an empty string) means "do not filter on this column".
        since: Inclusive lower bound on ``created_at``.
        until: Exclusive upper bound on ``created_at``.
        include_superseded: When False, drop events whose measurement_kind is
            "superseded". Events with a NULL measurement_kind are kept.
        unattributed: When True, keep only events with no repo, workstream or
            task link.

    Returns:
        The narrowed select.
    """
    equality_filters = (
        (TokenEvent.task_id, task_id),
        (TokenEvent.workstream_id, workstream_id),
        (TokenEvent.repo_id, repo_id),
        (TokenEvent.ref_type, ref_type),
        (TokenEvent.ref_id, ref_id),
        (TokenEvent.model, model),
        (TokenEvent.agent, agent),
        (TokenEvent.note, note),
        (TokenEvent.measurement_kind, measurement_kind),
        (TokenEvent.source_provider, source_provider),
    )
    for column, value in equality_filters:
        if value:
            q = q.where(column == value)
    if since:
        q = q.where(TokenEvent.created_at >= since)
    if until:
        q = q.where(TokenEvent.created_at < until)
    if not include_superseded:
        # In SQL, `col != 'superseded'` is NULL (not true) when col is NULL, so
        # a plain inequality would silently drop unclassified rows.
        # IS DISTINCT FROM treats NULL as a comparable value and keeps them.
        q = q.where(TokenEvent.measurement_kind.is_distinct_from("superseded"))
    if unattributed:
        q = q.where(
            TokenEvent.repo_id.is_(None),
            TokenEvent.workstream_id.is_(None),
            TokenEvent.task_id.is_(None),
        )
    return q
@router.post("/", response_model=TokenEventRead, status_code=status.HTTP_201_CREATED)
async def create_token_event(
    body: TokenEventCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Record a token event.

    Fields left as None in the body are dropped before defaults are applied.
    This shares the create-or-upsert path with POST /upsert. A payload whose
    source key matches a stored row updates that row, and the response
    status is still 201.
    """
    payload = body.model_dump(exclude_none=True)
    saved = await _create_or_upsert_event(payload, session)
    return saved
@router.post("/upsert", response_model=TokenEventRead)
async def upsert_token_event(
    body: TokenEventCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Insert a token event, or update the stored one with the same source key.

    Fields left as None in the body are dropped before defaults are applied.
    Responds with the framework's default status (no explicit status_code).
    """
    payload = body.model_dump(exclude_none=True)
    saved = await _create_or_upsert_event(payload, session)
    return saved
@router.get("/summary/", response_model=TokenSummary)
async def get_token_summary(
    scope: str = Query(..., description="task|workstream|repo|commit|release|session"),
    id: str = Query(..., description="FK value or ref_id depending on scope"),
    session: AsyncSession = Depends(get_session),
    include_superseded: bool = Query(True),
) -> TokenSummary:
    """Summarise token usage for a single task, workstream, repo or ref.

    For scope task/workstream/repo, ``id`` must be a UUID and is matched
    against the event's direct FK column (no graph resolution). For scope
    commit/release/session, events are matched on ``ref_type == scope`` and
    ``ref_id == id``.

    ``include_superseded`` defaults to True, which preserves the original
    behaviour. Pass False to exclude superseded events, as /aggregate/ does.

    Raises:
        HTTPException(422): unknown scope, or a non-UUID id for a UUID scope.
    """
    uuid_scope_columns = {
        "task": TokenEvent.task_id,
        "workstream": TokenEvent.workstream_id,
        "repo": TokenEvent.repo_id,
    }
    q = select(TokenEvent)
    if scope in uuid_scope_columns:
        try:
            uid = uuid.UUID(id)
        except ValueError:
            raise HTTPException(
                status_code=422, detail=f"id must be a valid UUID for scope={scope}"
            ) from None
        q = q.where(uuid_scope_columns[scope] == uid)
    elif scope in ("commit", "release", "session"):
        q = q.where(TokenEvent.ref_type == scope, TokenEvent.ref_id == id)
    else:
        raise HTTPException(status_code=422, detail=f"Unknown scope: {scope!r}")
    q = _filter_query(q, include_superseded=include_superseded)

    result = await session.execute(q)
    events = list(result.scalars().all())

    tokens_in = sum(e.tokens_in for e in events)
    tokens_out = sum(e.tokens_out for e in events)
    by_model: dict[str, int] = defaultdict(int)
    by_agent: dict[str, int] = defaultdict(int)
    by_measurement_kind: dict[str, int] = defaultdict(int)
    by_source_provider: dict[str, int] = defaultdict(int)
    for e in events:
        total = _event_total(e)
        # Events without a model/agent contribute to the totals but not to
        # those breakdowns.
        if e.model:
            by_model[e.model] += total
        if e.agent:
            by_agent[e.agent] += total
        by_measurement_kind[e.measurement_kind] += total
        by_source_provider[e.source_provider] += total

    return TokenSummary(
        scope=scope,
        scope_id=id,
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        tokens_total=tokens_in + tokens_out,
        event_count=len(events),
        by_model=dict(by_model),
        by_agent=dict(by_agent),
        by_measurement_kind=dict(by_measurement_kind),
        by_source_provider=dict(by_source_provider),
    )
@router.get("/by-repo/", response_model=list[RepoTokenSummary])
async def get_tokens_by_repo(
    measurement_kind: str | None = None,
    source_provider: str | None = None,
    since: datetime | None = None,
    until: datetime | None = None,
    include_superseded: bool = Query(True),
    session: AsyncSession = Depends(get_session),
) -> list[RepoTokenSummary]:
    """Aggregate token consumption per repo, resolving via the full graph.

    Each event is attributed to a repo by the first rule that applies:
      1. token_events.repo_id (direct)
      2. workstreams.repo_id, via the event's workstream_id
      3. task.workstream_id -> workstreams.repo_id, via the event's task_id
         (only when the event has no workstream_id)
    Events that do not resolve to a known ManagedRepo are skipped. Results are
    ordered by total tokens, largest first.
    """
    # Four flat queries (events, workstreams, tasks, repos); repo resolution
    # then happens in memory instead of issuing per-event lookups.
    event_rows = await session.execute(
        _filter_query(
            select(TokenEvent),
            measurement_kind=measurement_kind,
            source_provider=source_provider,
            since=since,
            until=until,
            include_superseded=include_superseded,
        )
    )
    events = list(event_rows.scalars().all())
    ws_rows = await session.execute(select(Workstream))
    workstreams: dict[uuid.UUID, Workstream] = {w.id: w for w in ws_rows.scalars().all()}
    task_rows = await session.execute(select(Task))
    tasks: dict[uuid.UUID, Task] = {t.id: t for t in task_rows.scalars().all()}
    repo_rows = await session.execute(select(ManagedRepo))
    repos: dict[uuid.UUID, ManagedRepo] = {r.id: r for r in repo_rows.scalars().all()}

    def repo_for(event: TokenEvent) -> uuid.UUID | None:
        if event.repo_id:
            return event.repo_id
        ws_id = event.workstream_id
        if not ws_id and event.task_id and event.task_id in tasks:
            ws_id = tasks[event.task_id].workstream_id
        if ws_id and ws_id in workstreams:
            return workstreams[ws_id].repo_id
        return None

    buckets: dict[uuid.UUID, dict] = {}
    for event in events:
        rid = repo_for(event)
        if not rid or rid not in repos:
            continue
        bucket = buckets.get(rid)
        if bucket is None:
            bucket = buckets[rid] = {
                "repo_id": rid,
                "repo_slug": repos[rid].slug,
                "tokens_in": 0,
                "tokens_out": 0,
                "event_count": 0,
                "by_model": defaultdict(int),
                "by_note": defaultdict(int),
                "by_measurement_kind": defaultdict(int),
                "by_source_provider": defaultdict(int),
            }
        event_total = _event_total(event)
        bucket["tokens_in"] += event.tokens_in
        bucket["tokens_out"] += event.tokens_out
        bucket["event_count"] += 1
        if event.model:
            bucket["by_model"][event.model] += event_total
        bucket["by_note"][event.note or "unknown"] += event_total
        bucket["by_measurement_kind"][event.measurement_kind] += event_total
        bucket["by_source_provider"][event.source_provider] += event_total

    # reverse=True keeps sort stability, matching a key of -(in + out).
    ordered = sorted(
        buckets.values(),
        key=lambda b: b["tokens_in"] + b["tokens_out"],
        reverse=True,
    )
    summaries: list[RepoTokenSummary] = []
    for bucket in ordered:
        fields = {name: (dict(v) if isinstance(v, defaultdict) else v) for name, v in bucket.items()}
        summaries.append(
            RepoTokenSummary(**fields, tokens_total=bucket["tokens_in"] + bucket["tokens_out"])
        )
    return summaries
@router.get("/aggregate/", response_model=TokenAggregateSummary)
async def get_token_aggregate(
    measurement_kind: str | None = None,
    source_provider: str | None = None,
    since: datetime | None = None,
    until: datetime | None = None,
    include_superseded: bool = Query(False),
    session: AsyncSession = Depends(get_session),
) -> TokenAggregateSummary:
    """Roll filtered token events up into totals and per-dimension breakdowns.

    Superseded events are excluded unless ``include_superseded`` is true.
    Besides overall totals and first/last timestamps, events are grouped by:
      - repo: direct repo_id, else via the workstream (or the task's workstream)
      - workstream: direct workstream_id, else the task's workstream
      - task: direct task_id
      - model: the model name, or "unknown" when missing
    An event with no key for a dimension is omitted from that dimension only.
    A row's label comes from the first event that created it. Rows are
    ordered by total tokens, largest first.
    """
    filtered = _filter_query(
        select(TokenEvent),
        measurement_kind=measurement_kind,
        source_provider=source_provider,
        since=since,
        until=until,
        include_superseded=include_superseded,
    )
    events = list((await session.execute(filtered)).scalars().all())
    workstreams: dict[uuid.UUID, Workstream] = {
        w.id: w for w in (await session.execute(select(Workstream))).scalars().all()
    }
    tasks: dict[uuid.UUID, Task] = {
        t.id: t for t in (await session.execute(select(Task))).scalars().all()
    }
    repos: dict[uuid.UUID, ManagedRepo] = {
        r.id: r for r in (await session.execute(select(ManagedRepo))).scalars().all()
    }

    def repo_for(event: TokenEvent) -> uuid.UUID | None:
        if event.repo_id:
            return event.repo_id
        ws_id = event.workstream_id
        if not ws_id and event.task_id and event.task_id in tasks:
            ws_id = tasks[event.task_id].workstream_id
        if ws_id and ws_id in workstreams:
            return workstreams[ws_id].repo_id
        return None

    def tally(table: dict[str, dict[str, Any]], key: str | None, label: str | None, event: TokenEvent) -> None:
        """Add *event* to the row for *key* in *table*; falsy keys are ignored."""
        if not key:
            return
        row = table.get(key)
        if row is None:
            row = table[key] = {
                "scope_id": key,
                "label": label,
                "tokens_in": 0,
                "tokens_out": 0,
                "event_count": 0,
                "by_measurement_kind": defaultdict(int),
                "by_source_provider": defaultdict(int),
            }
        amount = _event_total(event)
        row["tokens_in"] += event.tokens_in
        row["tokens_out"] += event.tokens_out
        row["event_count"] += 1
        row["by_measurement_kind"][event.measurement_kind] += amount
        row["by_source_provider"][event.source_provider] += amount

    def as_rows(table: dict[str, dict[str, Any]]) -> list[TokenAggregateRow]:
        built = [
            TokenAggregateRow(
                **{k: (dict(v) if isinstance(v, defaultdict) else v) for k, v in row.items()},
                tokens_total=row["tokens_in"] + row["tokens_out"],
            )
            for row in table.values()
        ]
        # reverse=True keeps sort stability, matching a key of -tokens_total.
        built.sort(key=lambda item: item.tokens_total, reverse=True)
        return built

    repo_rows: dict[str, dict[str, Any]] = {}
    workstream_rows: dict[str, dict[str, Any]] = {}
    task_rows: dict[str, dict[str, Any]] = {}
    model_rows: dict[str, dict[str, Any]] = {}
    kind_totals: dict[str, int] = defaultdict(int)
    provider_totals: dict[str, int] = defaultdict(int)
    first_event_at = None
    last_event_at = None
    last_ingested_at = None
    sum_in = 0
    sum_out = 0

    for event in events:
        amount = _event_total(event)
        sum_in += event.tokens_in
        sum_out += event.tokens_out
        kind_totals[event.measurement_kind] += amount
        provider_totals[event.source_provider] += amount

        created = event.created_at
        if first_event_at is None or created < first_event_at:
            first_event_at = created
        if last_event_at is None or created > last_event_at:
            last_event_at = created
        if last_ingested_at is None or event.ingested_at > last_ingested_at:
            last_ingested_at = event.ingested_at

        rid = repo_for(event)
        repo = repos.get(rid) if rid else None
        tally(repo_rows, str(rid) if rid else None, repo.slug if repo else None, event)

        ws_id = event.workstream_id or (
            tasks[event.task_id].workstream_id if event.task_id in tasks else None
        )
        ws = workstreams.get(ws_id) if ws_id else None
        tally(workstream_rows, str(ws_id) if ws_id else None, ws.title if ws else None, event)

        task = tasks.get(event.task_id) if event.task_id else None
        tally(
            task_rows,
            str(event.task_id) if event.task_id else None,
            task.title if task else None,
            event,
        )

        model_name = event.model or "unknown"
        tally(model_rows, model_name, model_name, event)

    return TokenAggregateSummary(
        tokens_in=sum_in,
        tokens_out=sum_out,
        tokens_total=sum_in + sum_out,
        event_count=len(events),
        first_event_at=first_event_at,
        last_event_at=last_event_at,
        last_ingested_at=last_ingested_at,
        by_repo=as_rows(repo_rows),
        by_workstream=as_rows(workstream_rows),
        by_task=as_rows(task_rows),
        by_model=as_rows(model_rows),
        by_measurement_kind=dict(kind_totals),
        by_source_provider=dict(provider_totals),
    )
@router.get("/quality/", response_model=TokenQualitySummary)
async def get_token_quality(
    since: datetime | None = None,
    until: datetime | None = None,
    session: AsyncSession = Depends(get_session),
) -> TokenQualitySummary:
    """Report data-quality counters for token events in an optional time window.

    Counts are per event, not per token, and superseded events are included.
    ``duplicate_source_count`` is the number of distinct
    (measurement_kind, source_provider, source_id) keys that are shared by
    more than one event. ``last_reconciliation_at`` is always None here.
    """
    fetched = await session.execute(_filter_query(select(TokenEvent), since=since, until=until))
    events = list(fetched.scalars().all())

    kind_counts: dict[str, int] = defaultdict(int)
    provider_counts: dict[str, int] = defaultdict(int)
    events_per_source: dict[tuple[str, str, str], int] = defaultdict(int)
    # Latest ingested_at seen for each tracked source provider.
    latest_ingest: dict[str, Any] = {"codex_session": None, "claude_transcript": None}
    fallback_events = 0
    unattributed_measured = 0
    measured_without_source = 0

    for event in events:
        kind_counts[event.measurement_kind] += 1
        provider_counts[event.source_provider] += 1
        if event.source_id:
            events_per_source[(event.measurement_kind, event.source_provider, event.source_id)] += 1
        if event.source_provider == "task_fallback" or event.note == "heuristic":
            fallback_events += 1
        if event.measurement_kind == "measured":
            if not (event.repo_id or event.workstream_id or event.task_id):
                unattributed_measured += 1
            if not event.source_id:
                measured_without_source += 1
        provider = event.source_provider
        if provider == "codex_session" or provider == "claude_transcript":
            seen = latest_ingest[provider]
            if seen is None or event.ingested_at > seen:
                latest_ingest[provider] = event.ingested_at

    duplicated_sources = sum(1 for n in events_per_source.values() if n > 1)
    return TokenQualitySummary(
        event_count=len(events),
        measured_event_count=kind_counts.get("measured", 0),
        estimated_event_count=kind_counts.get("estimated", 0),
        allocated_event_count=kind_counts.get("allocated", 0),
        superseded_event_count=kind_counts.get("superseded", 0),
        fallback_event_count=fallback_events,
        unattributed_measured_event_count=unattributed_measured,
        missing_provenance_event_count=measured_without_source,
        duplicate_source_count=duplicated_sources,
        last_codex_ingested_at=latest_ingest["codex_session"],
        last_claude_ingested_at=latest_ingest["claude_transcript"],
        last_reconciliation_at=None,
        by_measurement_kind=dict(kind_counts),
        by_source_provider=dict(provider_counts),
    )
@router.patch("/{event_id}", response_model=TokenEventRead)
async def patch_token_event(
    event_id: uuid.UUID,
    body: TokenEventPatch,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Partially update a token event.

    Fields left as None in the body are ignored, so they cannot be cleared
    through this endpoint. When the patch touches note, measurement_kind or
    source_provider, the stored values are merged with the patch and passed
    through ``_apply_event_defaults``. Any of measurement_kind,
    source_provider, confidence and source_id that the patch did not set
    explicitly is then taken from that inference.

    NOTE(review): the stored measurement_kind/source_provider are part of the
    merged input, and an explicit kind/provider short-circuits inference. A
    note-only patch therefore keeps an already-set kind/provider. Confidence,
    however, is reset to DEFAULT_CONFIDENCE for that kind unless the patch
    supplies one.

    Raises:
        HTTPException(404): no event with ``event_id``.
    """
    event = await session.get(TokenEvent, event_id)
    if event is None:
        raise HTTPException(status_code=404, detail="Token event not found")

    changes = body.model_dump(exclude_none=True)
    if any(key in changes for key in ("note", "measurement_kind", "source_provider")):
        merged_fields = (
            "tokens_in",
            "tokens_out",
            "note",
            "agent",
            "ref_id",
            "session_id",
            "measurement_kind",
            "source_provider",
            "source_id",
        )
        current_or_patched = {name: changes.get(name, getattr(event, name)) for name in merged_fields}
        inferred = _apply_event_defaults(
            {name: value for name, value in current_or_patched.items() if value is not None}
        )
        for key in ("measurement_kind", "source_provider", "confidence"):
            changes.setdefault(key, inferred[key])
        if inferred.get("source_id"):
            changes.setdefault("source_id", inferred["source_id"])

    for field, value in changes.items():
        setattr(event, field, value)
    await session.commit()
    await session.refresh(event)
    return event
@router.get("/{event_id}", response_model=TokenEventRead)
async def get_token_event(
    event_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Fetch a single token event by primary key.

    Raises:
        HTTPException(404): no event with ``event_id``.
    """
    found = await session.get(TokenEvent, event_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Token event not found")
@router.get("/", response_model=list[TokenEventRead])
async def list_token_events(
    task_id: uuid.UUID | None = None,
    workstream_id: uuid.UUID | None = None,
    repo_id: uuid.UUID | None = None,
    ref_type: str | None = None,
    ref_id: str | None = None,
    model: str | None = None,
    agent: str | None = None,
    note: str | None = None,
    measurement_kind: str | None = None,
    source_provider: str | None = None,
    since: datetime | None = None,
    until: datetime | None = None,
    include_superseded: bool = Query(True),
    unattributed: bool = False,
    offset: int = Query(0, ge=0),
    # ge=0: a negative limit would otherwise reach the database as a negative
    # LIMIT (a server error) instead of being rejected with 422.
    limit: int = Query(100, ge=0, le=1000),
    session: AsyncSession = Depends(get_session),
) -> list[TokenEvent]:
    """List token events, newest first, with optional filters and paging.

    All filters are passed straight to ``_filter_query``. ``limit`` must be in
    0..1000 and ``offset`` must be non-negative. Rows with equal
    ``created_at`` are ordered by ``id`` so that offset-based pages do not
    overlap or skip rows between requests.
    """
    q = _filter_query(
        select(TokenEvent),
        task_id=task_id,
        workstream_id=workstream_id,
        repo_id=repo_id,
        ref_type=ref_type,
        ref_id=ref_id,
        model=model,
        agent=agent,
        note=note,
        measurement_kind=measurement_kind,
        source_provider=source_provider,
        since=since,
        until=until,
        include_superseded=include_superseded,
        unattributed=unattributed,
    )
    q = (
        q.order_by(TokenEvent.created_at.desc(), TokenEvent.id.desc())
        .offset(offset)
        .limit(limit)
    )
    result = await session.execute(q)
    return list(result.scalars().all())