Files
state-hub/api/routers/tasks.py
tegwick af3fdfde80 feat(token-tracking): introduce token note taxonomy (measured/userbased/workplan/heuristic)
Tier 1 (exact counts) now defaults to note="measured" instead of null,
signalling the counts were read from the Claude Code status bar.
Callers can pass note="userbased" when a human provided the numbers.

  measured  — agent read exact counts from the Claude Code status bar
  userbased — counts provided by a human
  workplan  — prorated from workplan total across task count
  heuristic — server fallback (1000 tokens in / 500 tokens out) used when the agent supplies no counts

Added token_note field to TaskUpdate schema and exposed note param on
update_task_status and record_interactive_task MCP tools.
TOOLS.md documents the full taxonomy. 185 tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-29 18:47:40 +02:00

137 lines
4.7 KiB
Python

import uuid
from datetime import date
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session
from api.models.task import Task, TaskStatus
from api.models.token_event import TokenEvent
from api.schemas.task import TaskCreate, TaskRead, TaskUpdate
# Router for the task endpoints; routes registered on it share the "/tasks" prefix and "tasks" tag.
router = APIRouter(prefix="/tasks", tags=["tasks"])
@router.get("/", response_model=list[TaskRead])
async def list_tasks(
    workstream_id: uuid.UUID | None = None,
    status: TaskStatus | None = None,
    assignee: str | None = None,
    needs_human: bool | None = Query(None),
    priority: str | None = None,
    due_date_before: date | None = None,
    session: AsyncSession = Depends(get_session),
) -> list[Task]:
    """List tasks, optionally narrowed by query filters, ordered by creation time.

    Every filter is optional and the active ones are combined (each is applied
    as an additional ``where`` clause):

    * ``workstream_id``, ``status``, ``assignee``, ``priority`` -- equality
      filters, applied only when the value is truthy (so an empty string is
      treated the same as "not provided").
    * ``needs_human`` -- equality filter applied whenever it is not ``None``,
      which lets callers filter explicitly on ``False``.
    * ``due_date_before`` -- keeps tasks whose ``due_date`` is on or before the
      given date (the comparison is inclusive).

    Returns:
        The matching tasks sorted by ``Task.created_at``.
    """
    # Gather the active filter clauses first, in a fixed order, then apply them.
    clauses = []
    if workstream_id:
        clauses.append(Task.workstream_id == workstream_id)
    if status:
        clauses.append(Task.status == status)
    if assignee:
        clauses.append(Task.assignee == assignee)
    if needs_human is not None:
        clauses.append(Task.needs_human == needs_human)
    if priority:
        clauses.append(Task.priority == priority)
    if due_date_before is not None:
        clauses.append(Task.due_date <= due_date_before)

    stmt = select(Task)
    for clause in clauses:
        stmt = stmt.where(clause)
    stmt = stmt.order_by(Task.created_at)

    rows = await session.execute(stmt)
    return list(rows.scalars().all())
@router.post("/", response_model=TaskRead, status_code=status.HTTP_201_CREATED)
async def create_task(
    body: TaskCreate,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Create a task from the request body and return the persisted row.

    All fields dumped from ``body`` are passed to the ``Task`` constructor as
    keyword arguments. The route responds with HTTP 201.
    """
    fields = body.model_dump()
    new_task = Task(**fields)
    session.add(new_task)
    await session.commit()
    # Reload the instance from the database before handing it back.
    await session.refresh(new_task)
    return new_task
@router.get("/{task_id}", response_model=TaskRead)
async def get_task(
    task_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Fetch a single task by its primary key.

    Raises:
        HTTPException: 404 with detail "Task not found" when no task has
            the given ``task_id``.
    """
    found = await session.get(Task, task_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Task not found")
@router.patch("/{task_id}", response_model=TaskRead)
async def update_task(
    task_id: uuid.UUID,
    body: TaskUpdate,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Partially update a task and, when it is marked done, record a token event.

    Only fields explicitly present in the request body are applied. Token
    accounting fields (``tokens_in``, ``tokens_out``, ``workplan_tokens_in``,
    ``workplan_tokens_out``, ``token_note``, ``model``, ``agent``,
    ``session_id``) are not set on the task; they feed the ``TokenEvent``
    that is created when the update sets ``status`` to ``"done"``.

    Token counts are resolved in three tiers:

    1. Exact counts: both ``tokens_in`` and ``tokens_out`` supplied. The note
       is ``token_note`` if given, otherwise ``"measured"``.
    2. Workplan: both ``workplan_tokens_in`` and ``workplan_tokens_out``
       supplied. Each total is integer-divided by the number of tasks in the
       task's workstream; the note is ``"workplan"``.
    3. Heuristic: fallback of 1000 in / 500 out with note ``"heuristic"``.

    A field sent as an explicit ``null`` counts as not supplied, so the request
    falls through to the next tier instead of failing on a ``None`` count.

    The task change and the token event are written in a single transaction,
    and the task is refreshed after that commit so the returned instance is
    never left expired.

    Raises:
        HTTPException: 404 with detail "Task not found" when no task has
            the given ``task_id``.
    """
    task = await session.get(Task, task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")

    # Separate token fields from task fields
    token_field_names = {
        "tokens_in",
        "tokens_out",
        "workplan_tokens_in",
        "workplan_tokens_out",
        "token_note",
        "model",
        "agent",
        "session_id",
    }
    update_data = body.model_dump(exclude_unset=True)
    token_data = {k: update_data.pop(k) for k in list(update_data) if k in token_field_names}

    for field, value in update_data.items():
        setattr(task, field, value)

    # Token event — three-tier logic, only when marking done.
    # NOTE(review): this compares against the raw string "done"; it only matches
    # if the dumped status value compares equal to that string (e.g. a str-based
    # enum) — confirm against TaskStatus / TaskUpdate.
    if update_data.get("status") == "done":
        exact_in = token_data.get("tokens_in")
        exact_out = token_data.get("tokens_out")
        plan_in = token_data.get("workplan_tokens_in")
        plan_out = token_data.get("workplan_tokens_out")

        if exact_in is not None and exact_out is not None:
            # Tier 1: exact counts — default note "measured"; caller may override with token_note
            tin, tout = exact_in, exact_out
            tnote = token_data.get("token_note") or "measured"
        elif plan_in is not None and plan_out is not None:
            # Tier 2: prorate workplan total across task count.
            # Flush first so the count sees this request's pending task changes
            # (e.g. a changed workstream_id) even if autoflush is disabled.
            await session.flush()
            count_result = await session.execute(
                select(func.count(Task.id)).where(Task.workstream_id == task.workstream_id)
            )
            # Clamp to at least 1 to avoid dividing by zero.
            task_count = max(count_result.scalar() or 1, 1)
            tin = plan_in // task_count
            tout = plan_out // task_count
            tnote = "workplan"
        else:
            # Tier 3: heuristic fallback
            tin, tout, tnote = 1000, 500, "heuristic"

        session.add(
            TokenEvent(
                task_id=task_id,
                workstream_id=task.workstream_id,
                tokens_in=tin,
                tokens_out=tout,
                model=token_data.get("model"),
                agent=token_data.get("agent"),
                session_id=token_data.get("session_id"),
                ref_type="task",
                ref_id=str(task_id),
                note=tnote,
            )
        )

    # One commit covers the task change and (if any) the token event, so they
    # succeed or fail together; refresh afterwards so the returned task is loaded.
    await session.commit()
    await session.refresh(task)
    return task
@router.delete("/{task_id}", response_model=TaskRead)
async def cancel_task(
    task_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Cancel a task in response to DELETE without removing its row.

    The task's ``status`` is set to ``TaskStatus.cancelled`` and the updated
    task is returned.

    Raises:
        HTTPException: 404 with detail "Task not found" when no task has
            the given ``task_id``.
    """
    if (target := await session.get(Task, task_id)) is None:
        raise HTTPException(status_code=404, detail="Task not found")

    target.status = TaskStatus.cancelled
    await session.commit()
    # Reload the instance from the database before returning it.
    await session.refresh(target)
    return target