Files
state-hub/api/routers/tasks.py
tegwick acb30978cd feat(token-tracking): repo aggregation via graph walk (task→workstream→repo)
By Repo now resolves via the full chain rather than requiring repo_id
directly on the token event:
  1. token_events.repo_id (direct)
  2. → workstreams.repo_id (via workstream_id)
  3. → task.workstream_id → workstreams.repo_id (via task_id)

Changes:
- Auto-populate repo_id on token events at creation time (both the
  token_events router and the tasks router)
- New GET /token-events/by-repo/ endpoint with RepoTokenSummary schema;
  returns tokens_in/out/total, event_count, by_model, by_note per repo
- Dashboard By Repo section uses /by-repo/ directly and shows repo_slug
  instead of a truncated UUID
- Backfilled the three existing events (userbased) with repo_id via SQL

185 tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-29 19:05:23 +02:00

143 lines
4.9 KiB
Python

import uuid
from datetime import date
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session
from api.models.task import Task, TaskStatus
from api.models.token_event import TokenEvent
from api.models.workstream import Workstream
from api.schemas.task import TaskCreate, TaskRead, TaskUpdate
# Router for the task endpoints in this module; every route registered on it
# is served under the "/tasks" prefix and tagged "tasks".
router = APIRouter(prefix="/tasks", tags=["tasks"])
@router.get("/", response_model=list[TaskRead])
async def list_tasks(
    workstream_id: uuid.UUID | None = None,
    status: TaskStatus | None = None,
    assignee: str | None = None,
    needs_human: bool | None = Query(None),
    priority: str | None = None,
    due_date_before: date | None = None,
    session: AsyncSession = Depends(get_session),
) -> list[Task]:
    """Return tasks matching the optional filters, ordered by creation time.

    Each filter is applied only when supplied. ``workstream_id``, ``status``,
    ``assignee`` and ``priority`` are skipped when falsy (so an empty string
    does not filter); ``needs_human`` and ``due_date_before`` are skipped only
    when ``None``. ``due_date_before`` is inclusive: tasks due on that date
    are returned as well.
    """
    criteria = []
    if workstream_id:
        criteria.append(Task.workstream_id == workstream_id)
    if status:
        criteria.append(Task.status == status)
    if assignee:
        criteria.append(Task.assignee == assignee)
    if needs_human is not None:
        criteria.append(Task.needs_human == needs_human)
    if priority:
        criteria.append(Task.priority == priority)
    if due_date_before is not None:
        criteria.append(Task.due_date <= due_date_before)

    # Apply the collected criteria one by one, in the order they were checked.
    stmt = select(Task)
    for criterion in criteria:
        stmt = stmt.where(criterion)
    stmt = stmt.order_by(Task.created_at)

    rows = await session.execute(stmt)
    return list(rows.scalars().all())
@router.post("/", response_model=TaskRead, status_code=status.HTTP_201_CREATED)
async def create_task(
    body: TaskCreate,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Create a task from the request payload and return the stored row."""
    payload = body.model_dump()
    new_task = Task(**payload)
    session.add(new_task)
    await session.commit()
    # Reload the instance from the database before handing it back.
    await session.refresh(new_task)
    return new_task
@router.get("/{task_id}", response_model=TaskRead)
async def get_task(
    task_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Fetch a single task by primary key.

    Raises:
        HTTPException: 404 when no task has the given id.
    """
    found = await session.get(Task, task_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Task not found")
# Request fields that describe token usage rather than the task itself; they
# are stripped from the payload before it is applied to the Task row.
_TOKEN_FIELD_NAMES = frozenset(
    {
        "tokens_in",
        "tokens_out",
        "workplan_tokens_in",
        "workplan_tokens_out",
        "token_note",
        "model",
        "agent",
        "session_id",
    }
)

# Fallback counts recorded when the caller supplies no usable token figures.
_HEURISTIC_TOKENS_IN = 1000
_HEURISTIC_TOKENS_OUT = 500


async def _resolve_token_counts(
    session: AsyncSession,
    task: Task,
    token_data: dict,
) -> tuple[int, int, str]:
    """Pick (tokens_in, tokens_out, note) for a completed task.

    Three tiers, first match wins:
      1. Exact counts (``tokens_in`` + ``tokens_out``); note defaults to
         "measured" unless ``token_note`` is given.
      2. Workplan totals (``workplan_tokens_in`` + ``workplan_tokens_out``)
         divided evenly across the number of tasks sharing the task's
         workstream; note "workplan".
      3. Fixed heuristic counts; note "heuristic".

    A tier is used only when both of its values are present and not None, so
    an explicit JSON ``null`` falls through to the next tier instead of
    producing a None count or failing on integer division.
    """
    exact_in = token_data.get("tokens_in")
    exact_out = token_data.get("tokens_out")
    if exact_in is not None and exact_out is not None:
        return exact_in, exact_out, token_data.get("token_note") or "measured"

    plan_in = token_data.get("workplan_tokens_in")
    plan_out = token_data.get("workplan_tokens_out")
    if plan_in is not None and plan_out is not None:
        count_result = await session.execute(
            select(func.count(Task.id)).where(Task.workstream_id == task.workstream_id)
        )
        # Clamp to at least 1 so the division below can never divide by zero.
        task_count = max(count_result.scalar() or 1, 1)
        return plan_in // task_count, plan_out // task_count, "workplan"

    return _HEURISTIC_TOKENS_IN, _HEURISTIC_TOKENS_OUT, "heuristic"


@router.patch("/{task_id}", response_model=TaskRead)
async def update_task(
    task_id: uuid.UUID,
    body: TaskUpdate,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Apply a partial update to a task and record token usage on completion.

    Only fields explicitly present in the request are applied. Token-related
    fields are not written to the task; when the update sets ``status`` to
    "done" they feed a new ``TokenEvent`` linked to the task, its workstream
    and (when the workstream resolves) that workstream's repo.

    The task changes and the token event are committed in one transaction, so
    a failure while building or inserting the event does not leave the task
    marked done without its event. The task is refreshed after that commit so
    the returned instance is loaded even if the session expires objects on
    commit (session configuration is not visible here).

    Raises:
        HTTPException: 404 when no task has the given id.
    """
    task = await session.get(Task, task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")

    # Split the payload: token fields go to the event, the rest to the task.
    update_data = body.model_dump(exclude_unset=True)
    token_data = {
        name: update_data.pop(name)
        for name in list(update_data)
        if name in _TOKEN_FIELD_NAMES
    }
    for field, value in update_data.items():
        setattr(task, field, value)

    # NOTE(review): presumably TaskStatus is a str-based enum so comparing the
    # dumped value with the literal "done" holds — confirm.
    # NOTE(review): an event is recorded on every PATCH that sets status to
    # "done", including repeats on an already-done task — confirm intended.
    if update_data.get("status") == "done":
        # Push the pending task changes to the database (still uncommitted) so
        # the workstream task count below sees them regardless of autoflush.
        await session.flush()
        tokens_in, tokens_out, note = await _resolve_token_counts(session, task, token_data)

        # Resolve repo_id via the workstream; skip the lookup when the task
        # has no workstream rather than querying for a NULL primary key.
        workstream = None
        if task.workstream_id is not None:
            workstream = await session.get(Workstream, task.workstream_id)
        repo_id = workstream.repo_id if workstream else None

        session.add(
            TokenEvent(
                task_id=task_id,
                workstream_id=task.workstream_id,
                repo_id=repo_id,
                tokens_in=tokens_in,
                tokens_out=tokens_out,
                model=token_data.get("model"),
                agent=token_data.get("agent"),
                session_id=token_data.get("session_id"),
                ref_type="task",
                ref_id=str(task_id),
                note=note,
            )
        )

    # Single commit covers both the task update and the optional token event.
    await session.commit()
    await session.refresh(task)
    return task
@router.delete("/{task_id}", response_model=TaskRead)
async def cancel_task(
    task_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Cancel a task instead of removing it.

    The row is kept; its status is set to ``TaskStatus.cancelled`` and the
    updated task is returned.

    Raises:
        HTTPException: 404 when no task has the given id.
    """
    target = await session.get(Task, task_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Task not found")

    target.status = TaskStatus.cancelled
    await session.commit()
    # Reload the instance from the database before returning it.
    await session.refresh(target)
    return target