generated from coulomb/repo-seed
316 lines
11 KiB
Python
316 lines
11 KiB
Python
import uuid
|
|
from datetime import date
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from api.database import get_session
|
|
from api.models.progress_event import ProgressEvent
|
|
from api.models.task import Task, TaskStatus
|
|
from api.models.token_event import TokenEvent
|
|
from api.models.workstream import Workstream
|
|
from api.schemas.task import (
|
|
TaskCountRead,
|
|
TaskCreate,
|
|
TaskRead,
|
|
TaskStatusBulkSync,
|
|
TaskStatusBulkSyncRead,
|
|
TaskUpdate,
|
|
)
|
|
from api.services.lifecycle import status_value, transition_task_status
|
|
from api.task_status import normalize_task_status
|
|
|
|
# Router that groups every endpoint in this module under the "/tasks" prefix and "tasks" tag.
router = APIRouter(prefix="/tasks", tags=["tasks"])
|
|
|
|
|
|
@router.get("/", response_model=list[TaskRead])
async def list_tasks(
    workstream_id: uuid.UUID | None = None,
    status: str | None = None,
    assignee: str | None = None,
    needs_human: bool | None = Query(None),
    priority: str | None = None,
    due_date_before: date | None = None,
    limit: int | None = Query(None, ge=1, le=5000),
    offset: int = Query(0, ge=0),
    session: AsyncSession = Depends(get_session),
) -> list[Task]:
    """Return tasks ordered by ``created_at``, optionally filtered and paginated.

    Args:
        workstream_id: Keep only tasks belonging to this workstream.
        status: Keep only tasks with this status; the raw value is passed
            through ``normalize_task_status`` before being matched.
        assignee: Keep only tasks with exactly this assignee.
        needs_human: Keep only tasks whose ``needs_human`` flag equals this value.
        priority: Keep only tasks with exactly this priority.
        due_date_before: Keep only tasks whose ``due_date`` is on or before this date.
        limit: Maximum number of rows to return (1-5000); no limit when omitted.
        offset: Number of rows to skip before the first returned row.
        session: Database session injected by FastAPI.

    Raises:
        HTTPException: 422 when ``status`` cannot be resolved to a ``TaskStatus``.
    """
    q = select(Task)
    if workstream_id:
        q = q.where(Task.workstream_id == workstream_id)
    if status:
        # An unknown status is a client error; without this guard the
        # ValueError would escape the handler and surface as a 500.
        try:
            status_filter = TaskStatus(normalize_task_status(status))
        except ValueError as exc:
            raise HTTPException(
                status_code=422,
                detail=f"Invalid task status: {status!r}",
            ) from exc
        q = q.where(Task.status == status_filter)
    if assignee:
        q = q.where(Task.assignee == assignee)
    if needs_human is not None:
        q = q.where(Task.needs_human == needs_human)
    if priority:
        q = q.where(Task.priority == priority)
    if due_date_before is not None:
        q = q.where(Task.due_date <= due_date_before)
    q = q.order_by(Task.created_at)
    if offset:
        q = q.offset(offset)
    if limit is not None:
        q = q.limit(limit)
    result = await session.execute(q)
    return list(result.scalars().all())
|
|
|
|
|
|
@router.get("/counts", response_model=list[TaskCountRead])
async def count_tasks(
    workstream_id: uuid.UUID | None = None,
    status: str | None = None,
    session: AsyncSession = Depends(get_session),
) -> list[TaskCountRead]:
    """Return task counts grouped by workstream and status.

    Args:
        workstream_id: Count only tasks belonging to this workstream.
        status: Count only tasks with this status; the raw value is passed
            through ``normalize_task_status`` before being matched.
        session: Database session injected by FastAPI.

    Returns:
        One ``TaskCountRead`` per (workstream_id, status) group produced by the query.

    Raises:
        HTTPException: 422 when ``status`` cannot be resolved to a ``TaskStatus``.
    """
    q = select(Task.workstream_id, Task.status, func.count()).group_by(Task.workstream_id, Task.status)
    if workstream_id:
        q = q.where(Task.workstream_id == workstream_id)
    if status:
        # An unknown status is a client error; without this guard the
        # ValueError would escape the handler and surface as a 500.
        try:
            status_filter = TaskStatus(normalize_task_status(status))
        except ValueError as exc:
            raise HTTPException(
                status_code=422,
                detail=f"Invalid task status: {status!r}",
            ) from exc
        q = q.where(Task.status == status_filter)
    rows = await session.execute(q)
    return [
        TaskCountRead(workstream_id=ws_id, status=task_status, count=count)
        for ws_id, task_status, count in rows
    ]
|
|
|
|
|
|
@router.post("/", response_model=TaskRead, status_code=status.HTTP_201_CREATED)
async def create_task(
    body: TaskCreate,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Create a task from the request body and return the persisted row.

    When the new task's status value is "progress", the lifecycle transition
    helper is invoked as if the task had moved there from "todo", with the
    parent workstream passed along.
    """
    payload = body.model_dump()
    new_task = Task(**payload)
    session.add(new_task)

    starts_in_progress = status_value(new_task.status) == "progress"
    if starts_in_progress:
        parent = await session.get(Workstream, new_task.workstream_id)
        transition_task_status(
            new_task,
            new_task.status,
            parent_workstream=parent,
            previous_task_status="todo",
        )

    await session.commit()
    await session.refresh(new_task)
    return new_task
|
|
|
|
|
|
@router.post("/bulk-status-sync", response_model=TaskStatusBulkSyncRead)
async def bulk_status_sync(
    body: TaskStatusBulkSync,
    session: AsyncSession = Depends(get_session),
) -> TaskStatusBulkSyncRead:
    """Apply a batch of task status changes and record one progress event per change.

    The whole batch is validated before anything is modified: a repeated
    ``task_id`` yields a 400 and an unknown ``task_id`` yields a 404 (duplicates
    are reported first). All changes are then committed together.

    Raises:
        HTTPException: 400 on duplicate task ids, 404 on unknown task ids.
    """
    # Pass 1: load every referenced task, collecting duplicates and misses.
    visited: set[uuid.UUID] = set()
    loaded: dict[uuid.UUID, Task] = {}
    dupes: list[str] = []
    not_found: list[str] = []

    for item in body.updates:
        tid = item.task_id
        if tid in visited:
            dupes.append(str(tid))
            continue
        visited.add(tid)
        row = await session.get(Task, tid)
        if row is not None:
            loaded[tid] = row
        else:
            not_found.append(str(tid))

    if dupes:
        raise HTTPException(
            status_code=400,
            detail={"message": "duplicate task_id values are not allowed", "task_ids": dupes},
        )
    if not_found:
        raise HTTPException(
            status_code=404,
            detail={"message": "one or more tasks were not found", "task_ids": not_found},
        )

    # Pass 2: apply each transition and stage its progress event.
    event_author = body.author or "custodian"
    changes: list[tuple[Task, ProgressEvent]] = []
    for item in body.updates:
        row = loaded[item.task_id]
        status_before = status_value(row.status)
        status_after = status_value(item.status)
        reason = item.blocking_reason
        if reason is not None:
            row.blocking_reason = reason

        parent = await session.get(Workstream, row.workstream_id)
        transition_task_status(
            row,
            item.status,
            parent_workstream=parent,
            previous_task_status=status_before,
        )

        detail = {
            "bulk_status_sync": True,
            "previous_status": status_before,
            "status": status_after,
            "blocking_reason": reason,
        }
        progress = ProgressEvent(
            task_id=row.id,
            workstream_id=row.workstream_id,
            event_type="task_status_changed",
            summary=f"Task status -> {status_after}: {row.title}",
            author=event_author,
            session_id=body.session_id,
            detail=detail,
        )
        session.add(progress)
        changes.append((row, progress))

    await session.commit()

    # Reload tasks first, then events, so the response reflects persisted state.
    for row, _ in changes:
        await session.refresh(row)
    for _, progress in changes:
        await session.refresh(progress)

    return TaskStatusBulkSyncRead(
        updated=[row for row, _ in changes],
        progress_event_ids=[progress.id for _, progress in changes],
    )
|
|
|
|
|
|
@router.get("/{task_id}", response_model=TaskRead)
async def get_task(
    task_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Fetch a single task by primary key.

    Raises:
        HTTPException: 404 when no task has the given id.
    """
    found = await session.get(Task, task_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Task not found")
|
|
|
|
|
|
@router.patch("/{task_id}", response_model=TaskRead)
async def update_task(
    task_id: uuid.UUID,
    body: TaskUpdate,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Partially update a task and, on a transition to "done", record a token event.

    Only fields the client actually sent are applied. Token-related fields are
    split off from the task fields and are used solely to build the token
    event; ``status`` is routed through the lifecycle transition helper rather
    than being set directly.

    A token event is written only when the status changes to "done" from some
    other status and ``suppress_token_event`` is not set. Its counts come from
    the first tier that applies:

    1. ``tokens_in`` and ``tokens_out`` both supplied: exact counts.
    2. ``workplan_tokens_in`` and ``workplan_tokens_out`` both supplied: the
       workplan totals divided evenly (integer division) across the tasks in
       the workstream.
    3. Otherwise: a fixed heuristic of 1000 in / 500 out.

    Raises:
        HTTPException: 404 when no task has the given id.
    """
    task = await session.get(Task, task_id)
    if task is None:
        raise HTTPException(status_code=404, detail="Task not found")

    previous_status = status_value(task.status)

    # Separate token fields from task fields
    token_field_names = {
        "tokens_in",
        "tokens_out",
        "workplan_tokens_in",
        "workplan_tokens_out",
        "token_note",
        "model",
        "agent",
        "session_id",
        "suppress_token_event",
    }
    update_data = body.model_dump(exclude_unset=True)
    token_data = {k: update_data.pop(k) for k in list(update_data.keys()) if k in token_field_names}
    suppress_token_event = bool(token_data.pop("suppress_token_event", False))
    status_update = update_data.pop("status", None)
    new_status = status_value(status_update) if status_update is not None else None

    for field, value in update_data.items():
        setattr(task, field, value)
    if new_status is not None:
        ws = await session.get(Workstream, task.workstream_id)
        transition_task_status(
            task,
            status_update,
            parent_workstream=ws,
            previous_task_status=previous_status,
        )
    await session.commit()
    await session.refresh(task)

    # Token event — three-tier logic, only for an intentional transition to done.
    completes_task = new_status == "done" and previous_status != "done"
    if not completes_task or suppress_token_event:
        return task

    # exclude_unset keeps fields the client explicitly sent as null, so a tier
    # is selected on non-None values rather than on key presence; otherwise a
    # null count would raise TypeError in the arithmetic below, after the task
    # change has already been committed.
    measured_in = token_data.get("tokens_in")
    measured_out = token_data.get("tokens_out")
    workplan_in = token_data.get("workplan_tokens_in")
    workplan_out = token_data.get("workplan_tokens_out")

    if measured_in is not None and measured_out is not None:
        # Tier 1: exact counts — default note "measured"; caller may override with token_note
        tin = measured_in
        tout = measured_out
        tnote = token_data.get("token_note") or "measured"
        measurement_kind = "measured"
        source_provider = "manual"
        confidence = 1.0
        source_id = f"task:{task_id}:manual"
        raw_metadata = {"input_source": "task_status_patch"}
    elif workplan_in is not None and workplan_out is not None:
        # Tier 2: prorate workplan total across task count
        count_result = await session.execute(
            select(func.count(Task.id)).where(Task.workstream_id == task.workstream_id)
        )
        # Clamp to at least 1 so the division below can never divide by zero.
        task_count = max(count_result.scalar() or 1, 1)
        tin = workplan_in // task_count
        tout = workplan_out // task_count
        tnote = "workplan"
        measurement_kind = "allocated"
        source_provider = "manual"
        confidence = 0.7
        source_id = f"task:{task_id}:workplan-allocation"
        raw_metadata = {
            "allocation_method": "workplan_prorated",
            "workplan_tokens_in": workplan_in,
            "workplan_tokens_out": workplan_out,
            "task_count": task_count,
        }
    else:
        # Tier 3: heuristic fallback
        tin, tout, tnote = 1000, 500, "heuristic"
        measurement_kind = "estimated"
        source_provider = "task_fallback"
        confidence = 0.35
        source_id = f"task:{task_id}:heuristic"
        raw_metadata = {"estimation_method": "fixed_task_done_fallback"}

    # Resolve repo_id via workstream
    ws = await session.get(Workstream, task.workstream_id)
    repo_id = ws.repo_id if ws else None

    event = TokenEvent(
        task_id=task_id,
        workstream_id=task.workstream_id,
        repo_id=repo_id,
        tokens_in=tin,
        tokens_out=tout,
        model=token_data.get("model"),
        agent=token_data.get("agent"),
        session_id=token_data.get("session_id"),
        ref_type="task",
        ref_id=str(task_id),
        note=tnote,
        measurement_kind=measurement_kind,
        source_provider=source_provider,
        source_id=source_id,
        confidence=confidence,
        raw_total_tokens=tin + tout,
        raw_metadata=raw_metadata,
    )
    session.add(event)
    await session.commit()
    # Reload after this second commit so the returned task is not left expired
    # when the session expires objects on commit (SQLAlchemy's default).
    # NOTE(review): a no-op cost if the session factory sets expire_on_commit=False.
    await session.refresh(task)

    return task
|
|
|
|
|
|
@router.delete("/{task_id}", response_model=TaskRead)
async def cancel_task(
    task_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> Task:
    """Cancel a task: the row is kept and transitioned to ``TaskStatus.cancel``.

    Raises:
        HTTPException: 404 when no task has the given id.
    """
    target = await session.get(Task, task_id)
    if target is None:
        raise HTTPException(status_code=404, detail="Task not found")

    transition_task_status(target, TaskStatus.cancel)

    await session.commit()
    await session.refresh(target)
    return target
|