Files
state-hub/api/routers/tasks.py

205 lines
7.1 KiB
Python

import uuid
from datetime import date
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session
from api.models.task import Task, TaskStatus
from api.models.token_event import TokenEvent
from api.models.workstream import Workstream
from api.schemas.task import TaskCreate, TaskRead, TaskUpdate
from api.services.lifecycle import status_value, transition_task_status
# Every endpoint in this module is mounted under /tasks and tagged "tasks".
router = APIRouter(tags=["tasks"], prefix="/tasks")
@router.get("/", response_model=list[TaskRead])
async def list_tasks(
    workstream_id: uuid.UUID | None = None,
    status: TaskStatus | None = None,
    assignee: str | None = None,
    needs_human: bool | None = Query(None),
    priority: str | None = None,
    due_date_before: date | None = None,
    session: AsyncSession = Depends(get_session),
) -> list[Task]:
    """List tasks ordered by creation time, narrowed by the supplied filters.

    ``workstream_id``, ``status``, ``assignee`` and ``priority`` only filter
    when truthy (an empty string is ignored). ``needs_human`` and
    ``due_date_before`` filter whenever they are not None; the due-date bound
    is inclusive (``due_date <= due_date_before``).
    """
    # NOTE: the ``status`` query parameter shadows the fastapi ``status``
    # module inside this function; the module is not needed here.
    criteria = []
    if workstream_id:
        criteria.append(Task.workstream_id == workstream_id)
    if status:
        criteria.append(Task.status == status)
    if assignee:
        criteria.append(Task.assignee == assignee)
    if needs_human is not None:
        criteria.append(Task.needs_human == needs_human)
    if priority:
        criteria.append(Task.priority == priority)
    if due_date_before is not None:
        criteria.append(Task.due_date <= due_date_before)

    stmt = select(Task)
    if criteria:
        # Multiple criteria passed to one where() are AND-ed, same as chaining.
        stmt = stmt.where(*criteria)
    rows = await session.execute(stmt.order_by(Task.created_at))
    return list(rows.scalars().all())
@router.post("/", response_model=TaskRead, status_code=status.HTTP_201_CREATED)
async def create_task(
    body: TaskCreate,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Create a task from *body*, commit it, and return the refreshed row.

    A task created directly in ``in_progress`` is passed through
    ``transition_task_status`` together with its parent workstream, with the
    previous status reported as ``todo``.
    """
    new_task = Task(**body.model_dump())
    session.add(new_task)

    starts_in_progress = status_value(new_task.status) == "in_progress"
    if starts_in_progress:
        parent = await session.get(Workstream, new_task.workstream_id)
        # Presumably lets the lifecycle service apply its todo -> in_progress
        # side effects for tasks that skip the todo state — confirm in
        # api.services.lifecycle.
        transition_task_status(
            new_task,
            new_task.status,
            parent_workstream=parent,
            previous_task_status="todo",
        )

    await session.commit()
    await session.refresh(new_task)
    return new_task
@router.get("/{task_id}", response_model=TaskRead)
async def get_task(
    task_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Fetch a single task by primary key.

    Raises:
        HTTPException: 404 ("Task not found") when no task has id *task_id*.
    """
    found = await session.get(Task, task_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Task not found")
# Request fields that describe token accounting rather than Task columns;
# update_task strips them from the patch before applying it to the task.
_TOKEN_FIELD_NAMES = frozenset(
    {
        "tokens_in",
        "tokens_out",
        "workplan_tokens_in",
        "workplan_tokens_out",
        "token_note",
        "model",
        "agent",
        "session_id",
        "suppress_token_event",
    }
)


def _has_values(data: dict, *keys: str) -> bool:
    """Return True when every key in *keys* is in *data* with a non-None value."""
    return all(data.get(key) is not None for key in keys)


async def _build_done_token_event(
    session: AsyncSession,
    task: Task,
    task_id: uuid.UUID,
    token_data: dict,
    ws: Workstream | None,
) -> TokenEvent:
    """Build the TokenEvent recorded when *task* is moved into ``done``.

    The tier is chosen from the token fields the caller supplied. A field sent
    as an explicit ``null`` counts as absent, so the request falls through to
    the next tier instead of failing on arithmetic with ``None``:

    1. ``tokens_in`` + ``tokens_out``: exact counts (``measured``, confidence
       1.0). ``token_note`` overrides the default note ``"measured"``.
    2. ``workplan_tokens_in`` + ``workplan_tokens_out``: the workplan totals
       floor-divided by the number of tasks in the task's workstream
       (``allocated``, confidence 0.7).
    3. Otherwise a fixed 1000 in / 500 out heuristic (``estimated``,
       confidence 0.35).

    The event is returned unsaved; the caller adds it to the session.
    """
    if _has_values(token_data, "tokens_in", "tokens_out"):
        tokens_in = token_data["tokens_in"]
        tokens_out = token_data["tokens_out"]
        note = token_data.get("token_note") or "measured"
        measurement_kind = "measured"
        source_provider = "manual"
        confidence = 1.0
        source_id = f"task:{task_id}:manual"
        raw_metadata = {"input_source": "task_status_patch"}
    elif _has_values(token_data, "workplan_tokens_in", "workplan_tokens_out"):
        count_result = await session.execute(
            select(func.count(Task.id)).where(Task.workstream_id == task.workstream_id)
        )
        # Never divide by zero, even if the count comes back empty.
        task_count = max(count_result.scalar() or 1, 1)
        tokens_in = token_data["workplan_tokens_in"] // task_count
        tokens_out = token_data["workplan_tokens_out"] // task_count
        note = "workplan"
        measurement_kind = "allocated"
        source_provider = "manual"
        confidence = 0.7
        source_id = f"task:{task_id}:workplan-allocation"
        raw_metadata = {
            "allocation_method": "workplan_prorated",
            "workplan_tokens_in": token_data["workplan_tokens_in"],
            "workplan_tokens_out": token_data["workplan_tokens_out"],
            "task_count": task_count,
        }
    else:
        tokens_in, tokens_out, note = 1000, 500, "heuristic"
        measurement_kind = "estimated"
        source_provider = "task_fallback"
        confidence = 0.35
        source_id = f"task:{task_id}:heuristic"
        raw_metadata = {"estimation_method": "fixed_task_done_fallback"}

    return TokenEvent(
        task_id=task_id,
        workstream_id=task.workstream_id,
        repo_id=ws.repo_id if ws else None,
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        model=token_data.get("model"),
        agent=token_data.get("agent"),
        session_id=token_data.get("session_id"),
        ref_type="task",
        ref_id=str(task_id),
        note=note,
        measurement_kind=measurement_kind,
        source_provider=source_provider,
        source_id=source_id,
        confidence=confidence,
        raw_total_tokens=tokens_in + tokens_out,
        raw_metadata=raw_metadata,
    )


@router.patch("/{task_id}", response_model=TaskRead)
async def update_task(
    task_id: uuid.UUID,
    body: TaskUpdate,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Apply a partial update (only fields set in *body*) to a task.

    Token-accounting fields in *body* are never written to the task. When the
    patch moves the task into ``done`` from any other status and
    ``suppress_token_event`` is not set, they are used to record a TokenEvent.
    The task update and its token event are committed in one transaction, so
    they succeed or fail together.

    Raises:
        HTTPException: 404 ("Task not found") when no task has id *task_id*.
    """
    task = await session.get(Task, task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")
    previous_status = status_value(task.status)

    update_data = body.model_dump(exclude_unset=True)
    # Split the payload: token fields feed the TokenEvent, the rest patches Task.
    token_data = {
        key: update_data.pop(key)
        for key in list(update_data)
        if key in _TOKEN_FIELD_NAMES
    }
    suppress_token_event = bool(token_data.pop("suppress_token_event", False))
    status_update = update_data.pop("status", None)
    new_status = status_value(status_update) if status_update is not None else None

    for field, value in update_data.items():
        setattr(task, field, value)

    ws = None
    if new_status is not None:
        ws = await session.get(Workstream, task.workstream_id)
        transition_task_status(
            task,
            status_update,
            parent_workstream=ws,
            previous_task_status=previous_status,
        )

    # Only an intentional transition *into* done records token usage.
    if new_status == "done" and previous_status != "done" and not suppress_token_event:
        event = await _build_done_token_event(session, task, task_id, token_data, ws)
        session.add(event)

    # One commit keeps the status change and its token event atomic, and the
    # refresh after it ensures the returned task is loaded, not expired, when
    # FastAPI serializes the response.
    await session.commit()
    await session.refresh(task)
    return task
@router.delete("/{task_id}", response_model=TaskRead)
async def cancel_task(
    task_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Cancel a task and return it; the row itself is not deleted.

    The parent workstream and the task's prior status are passed to
    ``transition_task_status`` exactly as ``PATCH /tasks/{task_id}`` does. A
    DELETE and a PATCH setting ``status=cancelled`` therefore run the same
    lifecycle bookkeeping.

    Raises:
        HTTPException: 404 ("Task not found") when no task has id *task_id*.
    """
    task = await session.get(Task, task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")
    previous_status = status_value(task.status)
    ws = await session.get(Workstream, task.workstream_id)
    transition_task_status(
        task,
        TaskStatus.cancelled,
        parent_workstream=ws,
        previous_task_status=previous_status,
    )
    await session.commit()
    await session.refresh(task)
    return task