Files
state-hub/api/routers/tasks.py

316 lines
11 KiB
Python

import uuid
from datetime import date
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session
from api.models.progress_event import ProgressEvent
from api.models.task import Task, TaskStatus
from api.models.token_event import TokenEvent
from api.models.workstream import Workstream
from api.schemas.task import (
TaskCountRead,
TaskCreate,
TaskRead,
TaskStatusBulkSync,
TaskStatusBulkSyncRead,
TaskUpdate,
)
from api.services.lifecycle import status_value, transition_task_status
from api.task_status import normalize_task_status
# Every endpoint below is mounted under the "/tasks" prefix and tagged "tasks".
router = APIRouter(tags=["tasks"], prefix="/tasks")
@router.get("/", response_model=list[TaskRead])
async def list_tasks(
    workstream_id: uuid.UUID | None = None,
    status: str | None = None,
    assignee: str | None = None,
    needs_human: bool | None = Query(None),
    priority: str | None = None,
    due_date_before: date | None = None,
    limit: int | None = Query(None, ge=1, le=5000),
    offset: int = Query(0, ge=0),
    session: AsyncSession = Depends(get_session),
) -> list[Task]:
    """List tasks, oldest first, optionally filtered.

    All filters are optional and combine with AND. ``status`` is passed through
    ``normalize_task_status`` and must then name a ``TaskStatus`` member.
    ``due_date_before`` is inclusive (``due_date <= due_date_before``).
    ``offset`` and ``limit`` page through the ordered result.

    Raises:
        HTTPException: 422 when ``status`` is not a recognised task status.
    """
    q = select(Task)
    if workstream_id is not None:
        q = q.where(Task.workstream_id == workstream_id)
    if status:
        # The ``status`` parameter shadows the ``fastapi.status`` module inside
        # this function, hence the literal 422 instead of status.HTTP_422_*.
        try:
            wanted_status = TaskStatus(normalize_task_status(status))
        except ValueError as exc:
            # Without this, an unknown status surfaced as a 500.
            raise HTTPException(
                status_code=422, detail=f"Unknown task status: {status!r}"
            ) from exc
        q = q.where(Task.status == wanted_status)
    if assignee:
        q = q.where(Task.assignee == assignee)
    if needs_human is not None:
        q = q.where(Task.needs_human == needs_human)
    if priority:
        q = q.where(Task.priority == priority)
    if due_date_before is not None:
        q = q.where(Task.due_date <= due_date_before)
    # Task.id breaks ties between equal created_at values so offset/limit
    # pagination is deterministic across requests.
    q = q.order_by(Task.created_at, Task.id)
    if offset:
        q = q.offset(offset)
    if limit is not None:
        q = q.limit(limit)
    result = await session.execute(q)
    return list(result.scalars().all())
@router.get("/counts", response_model=list[TaskCountRead])
async def count_tasks(
    workstream_id: uuid.UUID | None = None,
    status: str | None = None,
    session: AsyncSession = Depends(get_session),
) -> list[TaskCountRead]:
    """Count tasks grouped by (workstream_id, status).

    Only combinations with at least one matching task appear in the result.
    ``workstream_id`` and ``status`` optionally narrow the tasks counted.

    Raises:
        HTTPException: 422 when ``status`` is not a recognised task status.
    """
    q = select(Task.workstream_id, Task.status, func.count()).group_by(
        Task.workstream_id, Task.status
    )
    if workstream_id is not None:
        q = q.where(Task.workstream_id == workstream_id)
    if status:
        # ``status`` shadows the ``fastapi.status`` module here, so use a literal code.
        try:
            wanted_status = TaskStatus(normalize_task_status(status))
        except ValueError as exc:
            # Without this, an unknown status surfaced as a 500.
            raise HTTPException(
                status_code=422, detail=f"Unknown task status: {status!r}"
            ) from exc
        q = q.where(Task.status == wanted_status)
    rows = await session.execute(q)
    return [
        TaskCountRead(workstream_id=ws_id, status=task_status, count=count)
        for ws_id, task_status, count in rows
    ]
@router.post("/", response_model=TaskRead, status_code=status.HTTP_201_CREATED)
async def create_task(
    body: TaskCreate,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Create a task from ``body`` and return the persisted, refreshed row.

    A task created directly in the "progress" status is run through
    ``transition_task_status`` as though it had moved from "todo", with its
    parent workstream supplied.
    """
    new_task = Task(**body.model_dump())
    session.add(new_task)

    starts_in_progress = status_value(new_task.status) == "progress"
    if starts_in_progress:
        parent = await session.get(Workstream, new_task.workstream_id)
        transition_task_status(
            new_task,
            new_task.status,
            previous_task_status="todo",
            parent_workstream=parent,
        )

    await session.commit()
    await session.refresh(new_task)
    return new_task
@router.post("/bulk-status-sync", response_model=TaskStatusBulkSyncRead)
async def bulk_status_sync(
    body: TaskStatusBulkSync,
    session: AsyncSession = Depends(get_session),
) -> TaskStatusBulkSyncRead:
    """Apply several task status changes and commit them together.

    The whole request is validated before any change is applied:
      * a task_id repeated in ``body.updates`` -> 400 listing every repeat;
      * a task_id with no matching task -> 404 listing the missing ids.
    Each update then sets ``blocking_reason`` (when given), runs
    ``transition_task_status`` with the parent workstream, and records a
    ``task_status_changed`` ProgressEvent authored by ``body.author``
    (default "custodian").

    Returns:
        The refreshed tasks and the ids of the ProgressEvents created, both in
        request order.
    """
    # Duplicates are visible in the payload itself, so reject them before
    # touching the database.
    seen: set[uuid.UUID] = set()
    duplicate_ids: list[str] = []
    unique_ids: list[uuid.UUID] = []
    for update in body.updates:
        if update.task_id in seen:
            duplicate_ids.append(str(update.task_id))
        else:
            seen.add(update.task_id)
            unique_ids.append(update.task_id)
    if duplicate_ids:
        raise HTTPException(
            status_code=400,
            detail={"message": "duplicate task_id values are not allowed", "task_ids": duplicate_ids},
        )

    # One IN query instead of a session.get() round trip per task.
    tasks_by_id: dict[uuid.UUID, Task] = {}
    if unique_ids:
        result = await session.execute(select(Task).where(Task.id.in_(unique_ids)))
        tasks_by_id = {task.id: task for task in result.scalars().all()}
    missing_ids = [str(task_id) for task_id in unique_ids if task_id not in tasks_by_id]
    if missing_ids:
        raise HTTPException(
            status_code=404,
            detail={"message": "one or more tasks were not found", "task_ids": missing_ids},
        )

    updated: list[Task] = []
    events: list[ProgressEvent] = []
    author = body.author or "custodian"
    for update in body.updates:
        task = tasks_by_id[update.task_id]
        previous_status = status_value(task.status)
        target_status = status_value(update.status)
        if update.blocking_reason is not None:
            task.blocking_reason = update.blocking_reason
        ws = await session.get(Workstream, task.workstream_id)
        transition_task_status(
            task,
            update.status,
            parent_workstream=ws,
            previous_task_status=previous_status,
        )
        event = ProgressEvent(
            task_id=task.id,
            workstream_id=task.workstream_id,
            event_type="task_status_changed",
            summary=f"Task status -> {target_status}: {task.title}",
            author=author,
            session_id=body.session_id,
            detail={
                "bulk_status_sync": True,
                "previous_status": previous_status,
                "status": target_status,
                "blocking_reason": update.blocking_reason,
            },
        )
        session.add(event)
        updated.append(task)
        events.append(event)

    await session.commit()
    # Refresh after commit so server-generated values (e.g. event ids) are loaded.
    for obj in (*updated, *events):
        await session.refresh(obj)
    return TaskStatusBulkSyncRead(
        updated=updated,
        progress_event_ids=[event.id for event in events],
    )
@router.get("/{task_id}", response_model=TaskRead)
async def get_task(
    task_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Fetch one task by primary key.

    Raises:
        HTTPException: 404 when no task has ``task_id``.
    """
    found = await session.get(Task, task_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Task not found")
# PATCH fields that describe token usage rather than Task columns. They are
# stripped from the update and only feed the TokenEvent written on completion.
_TOKEN_FIELD_NAMES = frozenset(
    {
        "tokens_in",
        "tokens_out",
        "workplan_tokens_in",
        "workplan_tokens_out",
        "token_note",
        "model",
        "agent",
        "session_id",
        "suppress_token_event",
    }
)


async def _build_done_token_event(
    session: AsyncSession,
    task: Task,
    task_id: uuid.UUID,
    token_data: dict,
) -> TokenEvent:
    """Build (but do not add) the TokenEvent for ``task`` becoming done.

    Tiers in priority order. A field sent explicitly as null counts as absent.
      1. ``tokens_in`` and ``tokens_out`` given -> "measured", confidence 1.0.
         ``token_note`` overrides the default note "measured".
      2. ``workplan_tokens_in`` and ``workplan_tokens_out`` given -> the totals
         are floor-divided by the number of tasks in the task's workstream
         (at least 1) -> "allocated", confidence 0.7.
      3. Otherwise a fixed 1000 in / 500 out estimate -> "estimated",
         confidence 0.35.
    """

    def given(key: str) -> bool:
        # An explicit null would otherwise reach ``+`` / ``//`` and raise TypeError.
        return token_data.get(key) is not None

    if given("tokens_in") and given("tokens_out"):
        tin = token_data["tokens_in"]
        tout = token_data["tokens_out"]
        tnote = token_data.get("token_note") or "measured"
        measurement_kind = "measured"
        source_provider = "manual"
        confidence = 1.0
        source_id = f"task:{task_id}:manual"
        raw_metadata = {"input_source": "task_status_patch"}
    elif given("workplan_tokens_in") and given("workplan_tokens_out"):
        count_result = await session.execute(
            select(func.count(Task.id)).where(Task.workstream_id == task.workstream_id)
        )
        task_count = max(count_result.scalar() or 1, 1)
        tin = token_data["workplan_tokens_in"] // task_count
        tout = token_data["workplan_tokens_out"] // task_count
        tnote = "workplan"
        measurement_kind = "allocated"
        source_provider = "manual"
        confidence = 0.7
        source_id = f"task:{task_id}:workplan-allocation"
        raw_metadata = {
            "allocation_method": "workplan_prorated",
            "workplan_tokens_in": token_data["workplan_tokens_in"],
            "workplan_tokens_out": token_data["workplan_tokens_out"],
            "task_count": task_count,
        }
    else:
        tin, tout, tnote = 1000, 500, "heuristic"
        measurement_kind = "estimated"
        source_provider = "task_fallback"
        confidence = 0.35
        source_id = f"task:{task_id}:heuristic"
        raw_metadata = {"estimation_method": "fixed_task_done_fallback"}

    # repo_id is taken from the parent workstream, if it exists.
    ws = await session.get(Workstream, task.workstream_id)
    repo_id = ws.repo_id if ws else None
    return TokenEvent(
        task_id=task_id,
        workstream_id=task.workstream_id,
        repo_id=repo_id,
        tokens_in=tin,
        tokens_out=tout,
        model=token_data.get("model"),
        agent=token_data.get("agent"),
        session_id=token_data.get("session_id"),
        ref_type="task",
        ref_id=str(task_id),
        note=tnote,
        measurement_kind=measurement_kind,
        source_provider=source_provider,
        source_id=source_id,
        confidence=confidence,
        raw_total_tokens=tin + tout,
        raw_metadata=raw_metadata,
    )


@router.patch("/{task_id}", response_model=TaskRead)
async def update_task(
    task_id: uuid.UUID,
    body: TaskUpdate,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Partially update a task. Only fields present in the request are applied.

    A ``status`` change goes through ``transition_task_status`` with the parent
    workstream. Fields in ``_TOKEN_FIELD_NAMES`` are never written to the task.
    When this PATCH moves the task into "done" from another status and
    ``suppress_token_event`` is not set, they are used to record one TokenEvent
    (see ``_build_done_token_event``). The task change and the TokenEvent are
    committed in a single transaction.

    Raises:
        HTTPException: 404 when no task has ``task_id``.
    """
    task = await session.get(Task, task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")
    previous_status = status_value(task.status)

    update_data = body.model_dump(exclude_unset=True)
    token_data = {
        name: update_data.pop(name) for name in list(update_data) if name in _TOKEN_FIELD_NAMES
    }
    suppress_token_event = bool(token_data.pop("suppress_token_event", False))
    status_update = update_data.pop("status", None)
    new_status = status_value(status_update) if status_update is not None else None

    for field, value in update_data.items():
        setattr(task, field, value)
    if new_status is not None:
        ws = await session.get(Workstream, task.workstream_id)
        transition_task_status(
            task,
            status_update,
            parent_workstream=ws,
            previous_task_status=previous_status,
        )

    # Token event: only for an intentional transition into done.
    if new_status == "done" and previous_status != "done" and not suppress_token_event:
        # Flush so the tier-2 task count sees this request's own changes
        # (e.g. a workstream_id move). Then stage the event in the same
        # transaction, so a failure cannot leave the task done with no TokenEvent.
        await session.flush()
        session.add(await _build_done_token_event(session, task, task_id, token_data))

    await session.commit()
    # Refresh after the single commit so the returned task is fully loaded.
    await session.refresh(task)
    return task
@router.delete("/{task_id}", response_model=TaskRead)
async def cancel_task(
    task_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Cancel a task by moving it to ``TaskStatus.cancel``. The row is not deleted.

    Raises:
        HTTPException: 404 when no task has ``task_id``.
    """
    task = await session.get(Task, task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")
    # Pass the parent workstream and prior status, as every other
    # transition_task_status call in this router does, so cancellation gets the
    # same workstream-aware handling as other status changes.
    # NOTE(review): confirm transition_task_status's workstream handling is wanted on cancel.
    ws = await session.get(Workstream, task.workstream_id)
    transition_task_status(
        task,
        TaskStatus.cancel,
        parent_workstream=ws,
        previous_task_status=status_value(task.status),
    )
    await session.commit()
    await session.refresh(task)
    return task