"""Routes for recording, reading, patching, and aggregating token-usage events."""

import uuid
from collections import defaultdict

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from api.database import get_session
from api.models.managed_repo import ManagedRepo
from api.models.task import Task
from api.models.token_event import TokenEvent
from api.models.workstream import Workstream
from api.schemas.token_event import (
    RepoTokenSummary,
    TokenEventCreate,
    TokenEventPatch,
    TokenEventRead,
    TokenSummary,
)

router = APIRouter(prefix="/token-events", tags=["token-events"])


@router.post("/", response_model=TokenEventRead, status_code=status.HTTP_201_CREATED)
async def create_token_event(
    body: TokenEventCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Persist a new token event, back-filling ``workstream_id`` and ``repo_id``.

    If ``task_id`` is given without ``workstream_id``, the workstream is copied
    from that task. If ``workstream_id`` is then known but ``repo_id`` is not,
    the repo is copied from the workstream. Explicitly supplied values are
    never overwritten, and a task/workstream that cannot be loaded simply
    skips the corresponding back-fill.
    """
    data = body.model_dump()

    if data.get("task_id") and not data.get("workstream_id"):
        task = await session.get(Task, data["task_id"])
        if task:
            data["workstream_id"] = task.workstream_id

    # Runs after the task step so a workstream_id derived from the task also
    # feeds the repo lookup.
    if data.get("workstream_id") and not data.get("repo_id"):
        ws = await session.get(Workstream, data["workstream_id"])
        if ws and ws.repo_id:
            data["repo_id"] = ws.repo_id

    event = TokenEvent(**data)
    session.add(event)
    await session.commit()
    await session.refresh(event)
    return event


def _summary_filters(scope: str, scope_id: str) -> tuple:
    """Return the WHERE criteria selecting the events that belong to ``scope``/``scope_id``.

    ``task``/``workstream``/``repo`` scopes match a UUID foreign-key column;
    ``commit``/``release``/``session`` scopes match ``(ref_type, ref_id)``.

    Raises:
        HTTPException: 422 when ``scope`` is unknown, or when a UUID-keyed
            scope receives an ``scope_id`` that is not a valid UUID.
    """
    uuid_columns = {
        "task": TokenEvent.task_id,
        "workstream": TokenEvent.workstream_id,
        "repo": TokenEvent.repo_id,
    }
    if scope in uuid_columns:
        try:
            uid = uuid.UUID(scope_id)
        except ValueError:
            raise HTTPException(
                status_code=422,
                detail=f"id must be a valid UUID for scope={scope}",
            ) from None
        return (uuid_columns[scope] == uid,)
    if scope in ("commit", "release", "session"):
        return (TokenEvent.ref_type == scope, TokenEvent.ref_id == scope_id)
    raise HTTPException(status_code=422, detail=f"Unknown scope: {scope!r}")


@router.get("/summary/", response_model=TokenSummary)
async def get_token_summary(
    scope: str = Query(..., description="task|workstream|repo|commit|release|session"),
    # "id" is the public query-parameter name, so it is kept despite shadowing
    # the builtin.
    id: str = Query(..., description="FK value or ref_id depending on scope"),  # noqa: A002
    session: AsyncSession = Depends(get_session),
) -> TokenSummary:
    """Total the tokens of every event in one scope, broken down by model and agent.

    Events without a ``model`` (or ``agent``) still count toward the totals but
    are left out of the corresponding breakdown.

    Raises:
        HTTPException: 422 for an unknown scope or a malformed UUID id.
    """
    q = select(TokenEvent).where(*_summary_filters(scope, id))
    result = await session.execute(q)
    events = list(result.scalars().all())

    tokens_in = 0
    tokens_out = 0
    by_model: dict[str, int] = defaultdict(int)
    by_agent: dict[str, int] = defaultdict(int)
    for e in events:
        tokens_in += e.tokens_in
        tokens_out += e.tokens_out
        event_total = e.tokens_in + e.tokens_out
        if e.model:
            by_model[e.model] += event_total
        if e.agent:
            by_agent[e.agent] += event_total

    return TokenSummary(
        scope=scope,
        scope_id=id,
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        tokens_total=tokens_in + tokens_out,
        event_count=len(events),
        by_model=dict(by_model),
        by_agent=dict(by_agent),
    )


@router.get("/by-repo/", response_model=list[RepoTokenSummary])
async def get_tokens_by_repo(
    session: AsyncSession = Depends(get_session),
) -> list[RepoTokenSummary]:
    """Aggregate token consumption per repo, resolving via the full graph.

    Resolution order for each event (first hit wins):
      1. token_events.repo_id (direct)
      2. → workstreams.repo_id (via the event's workstream_id)
      3. → task.workstream_id → workstreams.repo_id (via task_id)

    A later step is tried whenever the earlier one does not yield a repo
    (unset id, dangling id, or a workstream without a repo). Only events that
    resolve to an existing repo are included. Results are ordered by total
    tokens, largest first.
    """
    # Load events, workstreams, tasks and repos up front in four queries so
    # resolution below is pure dict lookups (avoids N+1 queries).
    events_result = await session.execute(select(TokenEvent))
    events = list(events_result.scalars().all())

    ws_result = await session.execute(select(Workstream))
    ws_map: dict[uuid.UUID, Workstream] = {w.id: w for w in ws_result.scalars().all()}

    task_result = await session.execute(select(Task))
    task_map: dict[uuid.UUID, Task] = {t.id: t for t in task_result.scalars().all()}

    repo_result = await session.execute(select(ManagedRepo))
    repo_map: dict[uuid.UUID, ManagedRepo] = {r.id: r for r in repo_result.scalars().all()}

    def resolve_repo_id(e: TokenEvent) -> uuid.UUID | None:
        """Return the repo an event belongs to, or None if it cannot be resolved."""
        if e.repo_id:
            return e.repo_id
        # Step 2 candidate first, then step 3; .get(None) is simply None.
        candidate_ws_ids = [e.workstream_id]
        task = task_map.get(e.task_id)
        if task is not None:
            candidate_ws_ids.append(task.workstream_id)
        for ws_id in candidate_ws_ids:
            ws = ws_map.get(ws_id)
            if ws is not None and ws.repo_id:
                return ws.repo_id
        return None

    groups: dict[uuid.UUID, dict] = {}
    for e in events:
        rid = resolve_repo_id(e)
        if rid is None or rid not in repo_map:
            continue
        g = groups.get(rid)
        if g is None:
            g = groups[rid] = {
                "repo_id": rid,
                "repo_slug": repo_map[rid].slug,
                "tokens_in": 0,
                "tokens_out": 0,
                "event_count": 0,
                "by_model": defaultdict(int),
                "by_note": defaultdict(int),
            }
        event_total = e.tokens_in + e.tokens_out
        g["tokens_in"] += e.tokens_in
        g["tokens_out"] += e.tokens_out
        g["event_count"] += 1
        if e.model:
            g["by_model"][e.model] += event_total
        # Unlike by_model, events without a note are bucketed under "unknown".
        g["by_note"][e.note or "unknown"] += event_total

    # reverse=True keeps sort stability, so equal totals stay in first-seen order.
    ranked = sorted(
        groups.values(),
        key=lambda grp: grp["tokens_in"] + grp["tokens_out"],
        reverse=True,
    )
    return [
        RepoTokenSummary(
            repo_id=g["repo_id"],
            repo_slug=g["repo_slug"],
            tokens_in=g["tokens_in"],
            tokens_out=g["tokens_out"],
            tokens_total=g["tokens_in"] + g["tokens_out"],
            event_count=g["event_count"],
            by_model=dict(g["by_model"]),
            by_note=dict(g["by_note"]),
        )
        for g in ranked
    ]


@router.patch("/{event_id}", response_model=TokenEventRead)
async def patch_token_event(
    event_id: uuid.UUID,
    body: TokenEventPatch,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Apply the non-None fields of ``body`` to an existing event.

    Because the patch is dumped with ``exclude_none=True``, a field sent as an
    explicit null is ignored; this endpoint cannot clear a field.

    Raises:
        HTTPException: 404 if no event has ``event_id``.
    """
    event = await session.get(TokenEvent, event_id)
    if event is None:
        raise HTTPException(status_code=404, detail="Token event not found")
    for field, value in body.model_dump(exclude_none=True).items():
        setattr(event, field, value)
    await session.commit()
    await session.refresh(event)
    return event


@router.get("/{event_id}", response_model=TokenEventRead)
async def get_token_event(
    event_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Return a single event by primary key.

    Raises:
        HTTPException: 404 if no event has ``event_id``.
    """
    event = await session.get(TokenEvent, event_id)
    if event is None:
        raise HTTPException(status_code=404, detail="Token event not found")
    return event


@router.get("/", response_model=list[TokenEventRead])
async def list_token_events(
    task_id: uuid.UUID | None = None,
    workstream_id: uuid.UUID | None = None,
    repo_id: uuid.UUID | None = None,
    ref_type: str | None = None,
    ref_id: str | None = None,
    model: str | None = None,
    agent: str | None = None,
    note: str | None = None,
    # ge=0 rejects negative limits with a 422 instead of passing them to SQL LIMIT.
    limit: int = Query(100, ge=0, le=1000),
    session: AsyncSession = Depends(get_session),
) -> list[TokenEvent]:
    """List events newest-first (by ``created_at``), capped at ``limit``.

    Each filter is an exact-match equality and is applied only when its value
    is truthy; an omitted or empty-string filter is ignored.
    """
    filters = (
        (TokenEvent.task_id, task_id),
        (TokenEvent.workstream_id, workstream_id),
        (TokenEvent.repo_id, repo_id),
        (TokenEvent.ref_type, ref_type),
        (TokenEvent.ref_id, ref_id),
        (TokenEvent.model, model),
        (TokenEvent.agent, agent),
        (TokenEvent.note, note),
    )
    q = select(TokenEvent)
    for column, value in filters:
        if value:
            q = q.where(column == value)
    q = q.order_by(TokenEvent.created_at.desc()).limit(limit)
    result = await session.execute(q)
    return list(result.scalars().all())