generated from coulomb/repo-seed
feat: add State Hub bulk status skill
This commit is contained in:
@@ -6,10 +6,18 @@ from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from api.database import get_session
|
||||
from api.models.progress_event import ProgressEvent
|
||||
from api.models.task import Task, TaskStatus
|
||||
from api.models.token_event import TokenEvent
|
||||
from api.models.workstream import Workstream
|
||||
from api.schemas.task import TaskCountRead, TaskCreate, TaskRead, TaskUpdate
|
||||
from api.schemas.task import (
|
||||
TaskCountRead,
|
||||
TaskCreate,
|
||||
TaskRead,
|
||||
TaskStatusBulkSync,
|
||||
TaskStatusBulkSyncRead,
|
||||
TaskUpdate,
|
||||
)
|
||||
from api.services.lifecycle import status_value, transition_task_status
|
||||
from api.task_status import normalize_task_status
|
||||
|
||||
@@ -88,6 +96,84 @@ async def create_task(
|
||||
return task
|
||||
|
||||
|
||||
@router.post("/bulk-status-sync", response_model=TaskStatusBulkSyncRead)
|
||||
async def bulk_status_sync(
|
||||
body: TaskStatusBulkSync,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> TaskStatusBulkSyncRead:
|
||||
seen: set[uuid.UUID] = set()
|
||||
duplicate_ids: list[str] = []
|
||||
tasks_by_id: dict[uuid.UUID, Task] = {}
|
||||
missing_ids: list[str] = []
|
||||
|
||||
for update in body.updates:
|
||||
if update.task_id in seen:
|
||||
duplicate_ids.append(str(update.task_id))
|
||||
continue
|
||||
seen.add(update.task_id)
|
||||
task = await session.get(Task, update.task_id)
|
||||
if task is None:
|
||||
missing_ids.append(str(update.task_id))
|
||||
else:
|
||||
tasks_by_id[update.task_id] = task
|
||||
|
||||
if duplicate_ids:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"message": "duplicate task_id values are not allowed", "task_ids": duplicate_ids},
|
||||
)
|
||||
if missing_ids:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={"message": "one or more tasks were not found", "task_ids": missing_ids},
|
||||
)
|
||||
|
||||
updated: list[Task] = []
|
||||
events: list[ProgressEvent] = []
|
||||
author = body.author or "custodian"
|
||||
for update in body.updates:
|
||||
task = tasks_by_id[update.task_id]
|
||||
previous_status = status_value(task.status)
|
||||
target_status = status_value(update.status)
|
||||
if update.blocking_reason is not None:
|
||||
task.blocking_reason = update.blocking_reason
|
||||
ws = await session.get(Workstream, task.workstream_id)
|
||||
transition_task_status(
|
||||
task,
|
||||
update.status,
|
||||
parent_workstream=ws,
|
||||
previous_task_status=previous_status,
|
||||
)
|
||||
event = ProgressEvent(
|
||||
task_id=task.id,
|
||||
workstream_id=task.workstream_id,
|
||||
event_type="task_status_changed",
|
||||
summary=f"Task status -> {target_status}: {task.title}",
|
||||
author=author,
|
||||
session_id=body.session_id,
|
||||
detail={
|
||||
"bulk_status_sync": True,
|
||||
"previous_status": previous_status,
|
||||
"status": target_status,
|
||||
"blocking_reason": update.blocking_reason,
|
||||
},
|
||||
)
|
||||
session.add(event)
|
||||
updated.append(task)
|
||||
events.append(event)
|
||||
|
||||
await session.commit()
|
||||
for task in updated:
|
||||
await session.refresh(task)
|
||||
for event in events:
|
||||
await session.refresh(event)
|
||||
|
||||
return TaskStatusBulkSyncRead(
|
||||
updated=updated,
|
||||
progress_event_ids=[event.id for event in events],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{task_id}", response_model=TaskRead)
|
||||
async def get_task(
|
||||
task_id: uuid.UUID,
|
||||
|
||||
@@ -77,6 +77,25 @@ class TaskUpdate(TaskStatusMixin):
|
||||
return self
|
||||
|
||||
|
||||
class TaskStatusBulkUpdate(TaskStatusMixin):
|
||||
task_id: uuid.UUID
|
||||
status: TaskStatus
|
||||
blocking_reason: str | None = None
|
||||
|
||||
|
||||
class TaskStatusBulkSync(BaseModel):
|
||||
updates: list[TaskStatusBulkUpdate]
|
||||
author: str | None = "custodian"
|
||||
session_id: str | None = None
|
||||
|
||||
@field_validator("updates")
|
||||
@classmethod
|
||||
def updates_required(cls, value: list[TaskStatusBulkUpdate]):
|
||||
if not value:
|
||||
raise ValueError("at least one task status update is required")
|
||||
return value
|
||||
|
||||
|
||||
class TaskRead(TaskStatusMixin):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
id: uuid.UUID
|
||||
@@ -99,3 +118,8 @@ class TaskCountRead(TaskStatusMixin):
|
||||
workstream_id: uuid.UUID
|
||||
status: TaskStatus
|
||||
count: int
|
||||
|
||||
|
||||
class TaskStatusBulkSyncRead(BaseModel):
|
||||
updated: list[TaskRead]
|
||||
progress_event_ids: list[uuid.UUID]
|
||||
|
||||
Reference in New Issue
Block a user