"""HTTP endpoints for listing, creating, reading, updating and cancelling tasks."""

import uuid
from datetime import date

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from api.database import get_session
from api.models.task import Task, TaskStatus
from api.models.token_event import TokenEvent
from api.models.workstream import Workstream
from api.schemas.task import TaskCreate, TaskRead, TaskUpdate

router = APIRouter(prefix="/tasks", tags=["tasks"])

# Keys of the PATCH payload that describe token usage; they are routed to a
# TokenEvent instead of being written onto the Task row.
_TOKEN_FIELD_NAMES = frozenset(
    {
        "tokens_in",
        "tokens_out",
        "workplan_tokens_in",
        "workplan_tokens_out",
        "token_note",
        "model",
        "agent",
        "session_id",
    }
)

# Tier-3 fallback counts recorded when the caller supplies no usable numbers.
_HEURISTIC_TOKENS_IN = 1000
_HEURISTIC_TOKENS_OUT = 500


async def _get_task_or_404(session: AsyncSession, task_id: uuid.UUID) -> Task:
    """Return the task identified by *task_id* or raise a 404 ``HTTPException``."""
    task = await session.get(Task, task_id)
    if task is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Task not found",
        )
    return task


async def _resolve_token_usage(
    session: AsyncSession,
    task: Task,
    token_data: dict,
) -> tuple[int, int, str]:
    """Return ``(tokens_in, tokens_out, note)`` using the three-tier rule.

    Tier 1 uses exact ``tokens_in``/``tokens_out``; the note defaults to
    "measured" unless ``token_note`` overrides it. Tier 2 divides the
    ``workplan_tokens_*`` totals by the number of tasks in the task's
    workstream. Tier 3 falls back to fixed heuristic counts.

    A field that was sent as an explicit ``null`` is treated as absent, so it
    can neither crash the integer division nor be stored as a ``None`` count.
    """
    exact_in = token_data.get("tokens_in")
    exact_out = token_data.get("tokens_out")
    if exact_in is not None and exact_out is not None:
        return exact_in, exact_out, token_data.get("token_note") or "measured"

    plan_in = token_data.get("workplan_tokens_in")
    plan_out = token_data.get("workplan_tokens_out")
    if plan_in is not None and plan_out is not None:
        count_result = await session.execute(
            select(func.count(Task.id)).where(
                Task.workstream_id == task.workstream_id
            )
        )
        # Clamp to 1 so an empty count can never cause a division by zero.
        task_count = max(count_result.scalar() or 0, 1)
        return plan_in // task_count, plan_out // task_count, "workplan"

    return _HEURISTIC_TOKENS_IN, _HEURISTIC_TOKENS_OUT, "heuristic"


@router.get("/", response_model=list[TaskRead])
async def list_tasks(
    workstream_id: uuid.UUID | None = None,
    status: TaskStatus | None = None,
    assignee: str | None = None,
    needs_human: bool | None = Query(None),
    priority: str | None = None,
    due_date_before: date | None = None,
    session: AsyncSession = Depends(get_session),
) -> list[Task]:
    """List tasks ordered by creation time, narrowed by the given filters.

    Every filter is optional. ``due_date_before`` is inclusive: tasks due on
    that date are returned as well.
    """
    # NOTE: the ``status`` parameter shadows fastapi's ``status`` module inside
    # this function only; the name is part of the public query-string API.
    q = select(Task)
    if workstream_id:
        q = q.where(Task.workstream_id == workstream_id)
    if status:
        q = q.where(Task.status == status)
    if assignee:
        q = q.where(Task.assignee == assignee)
    # ``False`` is a meaningful filter value here, so test against None.
    if needs_human is not None:
        q = q.where(Task.needs_human == needs_human)
    if priority:
        q = q.where(Task.priority == priority)
    if due_date_before is not None:
        q = q.where(Task.due_date <= due_date_before)
    q = q.order_by(Task.created_at)

    result = await session.execute(q)
    return list(result.scalars().all())


@router.post("/", response_model=TaskRead, status_code=status.HTTP_201_CREATED)
async def create_task(
    body: TaskCreate,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Create a task from the request body and return the stored row."""
    task = Task(**body.model_dump())
    session.add(task)
    await session.commit()
    # Reload so database-populated columns are present on the returned object.
    await session.refresh(task)
    return task


@router.get("/{task_id}", response_model=TaskRead)
async def get_task(
    task_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Return a single task; responds with 404 when it does not exist."""
    return await _get_task_or_404(session, task_id)


@router.patch("/{task_id}", response_model=TaskRead)
async def update_task(
    task_id: uuid.UUID,
    body: TaskUpdate,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Apply a partial update to a task and record token usage when it is done.

    Only fields present in the request are applied. Token-related fields are
    never written to the task; when the request sets ``status`` to "done" they
    feed a ``TokenEvent`` (see ``_resolve_token_usage`` for the tier rules).

    The task change and the token event are written in a single transaction,
    so a failure while recording the event cannot leave the task marked done
    without its event.

    Raises:
        HTTPException: 404 when the task does not exist.
    """
    task = await _get_task_or_404(session, task_id)

    # Separate token fields from task fields.
    payload = body.model_dump(exclude_unset=True)
    token_data = {k: v for k, v in payload.items() if k in _TOKEN_FIELD_NAMES}
    task_fields = {k: v for k, v in payload.items() if k not in _TOKEN_FIELD_NAMES}

    for field, value in task_fields.items():
        setattr(task, field, value)

    # Token event: only when this request marks the task done.
    # NOTE(review): compares against the raw string, which relies on the dumped
    # status value comparing equal to "done" - confirm against TaskStatus.
    if task_fields.get("status") == "done":
        # Push pending task changes so the workstream task count below sees them.
        await session.flush()
        tokens_in, tokens_out, note = await _resolve_token_usage(
            session, task, token_data
        )

        # Resolve repo_id via workstream.
        workstream = await session.get(Workstream, task.workstream_id)
        session.add(
            TokenEvent(
                task_id=task_id,
                workstream_id=task.workstream_id,
                repo_id=workstream.repo_id if workstream else None,
                tokens_in=tokens_in,
                tokens_out=tokens_out,
                model=token_data.get("model"),
                agent=token_data.get("agent"),
                session_id=token_data.get("session_id"),
                ref_type="task",
                ref_id=str(task_id),
                note=note,
            )
        )

    await session.commit()
    # Refresh after the final commit so the returned instance is fully loaded
    # and serialising it does not trigger a lazy load.
    await session.refresh(task)
    return task


@router.delete("/{task_id}", response_model=TaskRead)
async def cancel_task(
    task_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Soft-delete a task by setting its status to cancelled; the row is kept.

    Raises:
        HTTPException: 404 when the task does not exist.
    """
    task = await _get_task_or_404(session, task_id)
    task.status = TaskStatus.cancelled
    await session.commit()
    await session.refresh(task)
    return task