"""HTTP routes for recording and querying token-usage events."""

import uuid
from collections import defaultdict

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from api.database import get_session
from api.models.task import Task
from api.models.token_event import TokenEvent
from api.schemas.token_event import TokenEventCreate, TokenEventRead, TokenSummary

router = APIRouter(prefix="/token-events", tags=["token-events"])

# Scopes whose ``id`` must parse as a UUID, mapped to the column it filters on.
_UUID_SCOPE_COLUMNS = {
    "task": TokenEvent.task_id,
    "workstream": TokenEvent.workstream_id,
    "repo": TokenEvent.repo_id,
}

# Scopes matched through the (ref_type, ref_id) pair; ``id`` is used as a raw string.
_REF_SCOPES = ("commit", "release", "session")


def _parse_scope_uuid(raw_id: str, scope: str) -> uuid.UUID:
    """Parse *raw_id* as a UUID, or raise a 422 ``HTTPException`` naming *scope*."""
    try:
        return uuid.UUID(raw_id)
    except ValueError as exc:
        raise HTTPException(
            status_code=422,
            detail=f"id must be a valid UUID for scope={scope}",
        ) from exc


@router.post("/", response_model=TokenEventRead, status_code=status.HTTP_201_CREATED)
async def create_token_event(
    body: TokenEventCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Persist a new token event and return the refreshed row.

    When the payload has a ``task_id`` but no ``workstream_id``, the
    workstream is copied from the referenced task. If that task cannot be
    loaded, ``workstream_id`` is left as provided.
    """
    data = body.model_dump()

    # Auto-populate workstream_id from task if not provided
    if data.get("task_id") and not data.get("workstream_id"):
        task = await session.get(Task, data["task_id"])
        if task:
            data["workstream_id"] = task.workstream_id

    event = TokenEvent(**data)
    session.add(event)
    await session.commit()
    await session.refresh(event)
    return event


@router.get("/summary/", response_model=TokenSummary)
async def get_token_summary(
    scope: str = Query(..., description="task|workstream|repo|commit|release|session"),
    # NOTE: ``id`` shadows the builtin, but it is the public query-parameter
    # name, so it has to stay as is.
    id: str = Query(..., description="FK value or ref_id depending on scope"),
    session: AsyncSession = Depends(get_session),
) -> TokenSummary:
    """Aggregate token counts for every event matching ``scope`` and ``id``.

    For ``task``, ``workstream`` and ``repo`` the ``id`` must be a UUID and is
    compared against the matching ``*_id`` column. For ``commit``, ``release``
    and ``session`` the events are matched on ``ref_type == scope`` and
    ``ref_id == id``.

    Raises:
        HTTPException: 422 when ``scope`` is unknown, or when ``id`` is not a
            valid UUID for a UUID-based scope.
    """
    if scope in _UUID_SCOPE_COLUMNS:
        column = _UUID_SCOPE_COLUMNS[scope]
        q = select(TokenEvent).where(column == _parse_scope_uuid(id, scope))
    elif scope in _REF_SCOPES:
        q = select(TokenEvent).where(
            TokenEvent.ref_type == scope,
            TokenEvent.ref_id == id,
        )
    else:
        raise HTTPException(status_code=422, detail=f"Unknown scope: {scope!r}")

    result = await session.execute(q)
    events = list(result.scalars().all())

    # Single pass: overall totals plus per-model / per-agent combined totals.
    tokens_in = 0
    tokens_out = 0
    by_model: dict[str, int] = defaultdict(int)
    by_agent: dict[str, int] = defaultdict(int)
    for event in events:
        tokens_in += event.tokens_in
        tokens_out += event.tokens_out
        combined = event.tokens_in + event.tokens_out
        # Events without a model / agent count toward the totals only.
        if event.model:
            by_model[event.model] += combined
        if event.agent:
            by_agent[event.agent] += combined

    return TokenSummary(
        scope=scope,
        scope_id=id,
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        tokens_total=tokens_in + tokens_out,
        event_count=len(events),
        by_model=dict(by_model),
        by_agent=dict(by_agent),
    )


@router.get("/", response_model=list[TokenEventRead])
async def list_token_events(
    task_id: uuid.UUID | None = None,
    workstream_id: uuid.UUID | None = None,
    repo_id: uuid.UUID | None = None,
    ref_type: str | None = None,
    ref_id: str | None = None,
    model: str | None = None,
    agent: str | None = None,
    # ge=0 keeps a negative value from ever reaching the SQL LIMIT clause.
    limit: int = Query(100, ge=0, le=1000),
    session: AsyncSession = Depends(get_session),
) -> list[TokenEvent]:
    """List token events, newest first, narrowed by any supplied filters.

    Each filter that is given (and truthy) adds an equality condition on the
    column of the same name. Results are ordered by ``created_at`` descending
    and capped at ``limit`` rows.
    """
    filters = (
        (TokenEvent.task_id, task_id),
        (TokenEvent.workstream_id, workstream_id),
        (TokenEvent.repo_id, repo_id),
        (TokenEvent.ref_type, ref_type),
        (TokenEvent.ref_id, ref_id),
        (TokenEvent.model, model),
        (TokenEvent.agent, agent),
    )

    q = select(TokenEvent)
    for column, value in filters:
        # Truthiness check: an empty-string filter is ignored, same as None.
        if value:
            q = q.where(column == value)
    q = q.order_by(TokenEvent.created_at.desc()).limit(limit)

    result = await session.execute(q)
    return list(result.scalars().all())