"""FastAPI routes for listing, creating, reading, updating and cancelling tasks."""

import uuid
from datetime import date

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from api.database import get_session
from api.models.task import Task, TaskStatus
from api.models.token_event import TokenEvent
from api.models.workstream import Workstream
from api.schemas.task import TaskCreate, TaskRead, TaskUpdate
from api.services.lifecycle import status_value, transition_task_status
from api.task_status import normalize_task_status

router = APIRouter(prefix="/tasks", tags=["tasks"])

# TaskUpdate fields that describe token usage for a TokenEvent. They are split
# out of the update payload in update_task and never written onto the Task row.
_TOKEN_FIELD_NAMES = frozenset(
    {
        "tokens_in",
        "tokens_out",
        "workplan_tokens_in",
        "workplan_tokens_out",
        "token_note",
        "model",
        "agent",
        "session_id",
        "suppress_token_event",
    }
)


@router.get("/", response_model=list[TaskRead])
async def list_tasks(
    workstream_id: uuid.UUID | None = None,
    status: str | None = None,
    assignee: str | None = None,
    needs_human: bool | None = Query(None),
    priority: str | None = None,
    due_date_before: date | None = None,
    session: AsyncSession = Depends(get_session),
) -> list[Task]:
    """List tasks matching every supplied filter, ordered by ``created_at``.

    Filters are combined with AND; empty-string ``status``/``assignee``/
    ``priority`` values are ignored. ``status`` is passed through
    ``normalize_task_status`` before being matched against ``TaskStatus``.
    ``due_date_before`` is inclusive (``due_date <= due_date_before``).

    Raises:
        HTTPException: 422 if ``status`` does not map to a ``TaskStatus``.
    """
    # NOTE: the ``status`` parameter shadows the ``fastapi.status`` module in
    # this function (it is the public query-parameter name), so status codes
    # below are written as integer literals.
    q = select(Task)
    if workstream_id is not None:
        q = q.where(Task.workstream_id == workstream_id)
    if status:
        try:
            wanted_status = TaskStatus(normalize_task_status(status))
        except ValueError as err:
            # An unknown status is a client error, not a server crash.
            raise HTTPException(
                status_code=422, detail=f"Unknown task status: {status!r}"
            ) from err
        q = q.where(Task.status == wanted_status)
    if assignee:
        q = q.where(Task.assignee == assignee)
    if needs_human is not None:
        q = q.where(Task.needs_human == needs_human)
    if priority:
        q = q.where(Task.priority == priority)
    if due_date_before is not None:
        q = q.where(Task.due_date <= due_date_before)
    q = q.order_by(Task.created_at)
    result = await session.execute(q)
    return list(result.scalars().all())


@router.post("/", response_model=TaskRead, status_code=status.HTTP_201_CREATED)
async def create_task(
    body: TaskCreate,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Create a task from ``body`` and return it with HTTP 201.

    A task created directly in ``progress`` is run through
    ``transition_task_status`` as if it moved from ``todo``, with its parent
    workstream, so the same lifecycle hooks apply as for a PATCH.
    """
    task = Task(**body.model_dump())
    session.add(task)
    if status_value(task.status) == "progress":
        ws = await session.get(Workstream, task.workstream_id)
        transition_task_status(
            task,
            task.status,
            parent_workstream=ws,
            previous_task_status="todo",
        )
    await session.commit()
    await session.refresh(task)
    return task


@router.get("/{task_id}", response_model=TaskRead)
async def get_task(
    task_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Return a single task by id.

    Raises:
        HTTPException: 404 if the task does not exist.
    """
    task = await session.get(Task, task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")
    return task


async def _build_completion_token_event(
    session: AsyncSession,
    task: Task,
    task_id: uuid.UUID,
    token_data: dict,
    ws: Workstream | None,
) -> TokenEvent:
    """Build the TokenEvent recorded when ``task`` transitions into ``done``.

    The tier is chosen by which token fields the caller supplied. A tier
    applies only when both of its values are present and non-null:

    1. ``tokens_in`` + ``tokens_out``: exact counts ("measured").
    2. ``workplan_tokens_in`` + ``workplan_tokens_out``: workplan totals
       divided (floor) evenly across every task in the task's workstream
       ("allocated").
    3. Otherwise: a fixed 1000 in / 500 out estimate ("estimated").

    ``ws`` supplies ``repo_id``; ``None`` yields ``repo_id=None``.
    """
    if token_data.get("tokens_in") is not None and token_data.get("tokens_out") is not None:
        # Tier 1: exact counts — default note "measured"; caller may override with token_note
        tin = token_data["tokens_in"]
        tout = token_data["tokens_out"]
        tnote = token_data.get("token_note") or "measured"
        measurement_kind = "measured"
        source_provider = "manual"
        confidence = 1.0
        source_id = f"task:{task_id}:manual"
        raw_metadata = {"input_source": "task_status_patch"}
    elif (
        token_data.get("workplan_tokens_in") is not None
        and token_data.get("workplan_tokens_out") is not None
    ):
        # Tier 2: prorate workplan total across task count
        count_result = await session.execute(
            select(func.count(Task.id)).where(Task.workstream_id == task.workstream_id)
        )
        # Guard against division by zero if the count comes back empty/0.
        task_count = max(count_result.scalar() or 1, 1)
        tin = token_data["workplan_tokens_in"] // task_count
        tout = token_data["workplan_tokens_out"] // task_count
        tnote = "workplan"
        measurement_kind = "allocated"
        source_provider = "manual"
        confidence = 0.7
        source_id = f"task:{task_id}:workplan-allocation"
        raw_metadata = {
            "allocation_method": "workplan_prorated",
            "workplan_tokens_in": token_data["workplan_tokens_in"],
            "workplan_tokens_out": token_data["workplan_tokens_out"],
            "task_count": task_count,
        }
    else:
        # Tier 3: heuristic fallback
        tin, tout, tnote = 1000, 500, "heuristic"
        measurement_kind = "estimated"
        source_provider = "task_fallback"
        confidence = 0.35
        source_id = f"task:{task_id}:heuristic"
        raw_metadata = {"estimation_method": "fixed_task_done_fallback"}

    return TokenEvent(
        task_id=task_id,
        workstream_id=task.workstream_id,
        repo_id=ws.repo_id if ws is not None else None,
        tokens_in=tin,
        tokens_out=tout,
        model=token_data.get("model"),
        agent=token_data.get("agent"),
        session_id=token_data.get("session_id"),
        ref_type="task",
        ref_id=str(task_id),
        note=tnote,
        measurement_kind=measurement_kind,
        source_provider=source_provider,
        source_id=source_id,
        confidence=confidence,
        raw_total_tokens=tin + tout,
        raw_metadata=raw_metadata,
    )


@router.patch("/{task_id}", response_model=TaskRead)
async def update_task(
    task_id: uuid.UUID,
    body: TaskUpdate,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Apply a partial update to a task, recording token usage on completion.

    Only fields explicitly set in ``body`` are applied. A ``status`` change
    goes through ``transition_task_status`` with the parent workstream and the
    previous status. Token fields (``_TOKEN_FIELD_NAMES``) are not written to
    the task. When the update moves the task into ``done`` from another
    status and ``suppress_token_event`` is not truthy, a ``TokenEvent`` is
    added. The task change and the event are committed in one transaction.

    Raises:
        HTTPException: 404 if the task does not exist.
    """
    task = await session.get(Task, task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")
    previous_status = status_value(task.status)

    update_data = body.model_dump(exclude_unset=True)
    # Separate token fields from task fields. The key intersection is a new
    # set, so popping from update_data while iterating it is safe.
    token_data = {k: update_data.pop(k) for k in update_data.keys() & _TOKEN_FIELD_NAMES}
    suppress_token_event = bool(token_data.pop("suppress_token_event", False))
    status_update = update_data.pop("status", None)
    new_status = status_value(status_update) if status_update is not None else None

    for field, value in update_data.items():
        setattr(task, field, value)

    ws: Workstream | None = None
    if new_status is not None:
        ws = await session.get(Workstream, task.workstream_id)
        transition_task_status(
            task,
            status_update,
            parent_workstream=ws,
            previous_task_status=previous_status,
        )

    # Token event — three-tier logic, only for an intentional transition to done.
    # ws is always populated here because new_status == "done" implies the
    # branch above ran.
    if new_status == "done" and previous_status != "done" and not suppress_token_event:
        event = await _build_completion_token_event(session, task, task_id, token_data, ws)
        session.add(event)

    # Single commit keeps the task update and its token event atomic. Refresh
    # afterwards so the returned task is loaded for response serialization
    # even if the session expires instances on commit.
    await session.commit()
    await session.refresh(task)
    return task


@router.delete("/{task_id}", response_model=TaskRead)
async def cancel_task(
    task_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Cancel a task and return it.

    The row is not deleted: its status is transitioned to
    ``TaskStatus.cancel`` via ``transition_task_status``.

    Raises:
        HTTPException: 404 if the task does not exist.
    """
    task = await session.get(Task, task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")
    # NOTE(review): unlike update_task, this does not pass parent_workstream /
    # previous_task_status, so any workstream-level lifecycle handling done
    # by transition_task_status may differ from PATCH status=cancel — confirm
    # whether that is intended.
    transition_task_status(task, TaskStatus.cancel)
    await session.commit()
    await session.refresh(task)
    return task