feat: add State Hub bulk status skill

This commit is contained in:
2026-06-07 20:11:07 +02:00
parent 8f17bc1f50
commit 55e36bdf2d
9 changed files with 496 additions and 5 deletions

View File

@@ -6,10 +6,18 @@ from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session
from api.models.progress_event import ProgressEvent
from api.models.task import Task, TaskStatus
from api.models.token_event import TokenEvent
from api.models.workstream import Workstream
from api.schemas.task import TaskCountRead, TaskCreate, TaskRead, TaskUpdate
from api.schemas.task import (
TaskCountRead,
TaskCreate,
TaskRead,
TaskStatusBulkSync,
TaskStatusBulkSyncRead,
TaskUpdate,
)
from api.services.lifecycle import status_value, transition_task_status
from api.task_status import normalize_task_status
@@ -88,6 +96,84 @@ async def create_task(
return task
@router.post("/bulk-status-sync", response_model=TaskStatusBulkSyncRead)
async def bulk_status_sync(
body: TaskStatusBulkSync,
session: AsyncSession = Depends(get_session),
) -> TaskStatusBulkSyncRead:
seen: set[uuid.UUID] = set()
duplicate_ids: list[str] = []
tasks_by_id: dict[uuid.UUID, Task] = {}
missing_ids: list[str] = []
for update in body.updates:
if update.task_id in seen:
duplicate_ids.append(str(update.task_id))
continue
seen.add(update.task_id)
task = await session.get(Task, update.task_id)
if task is None:
missing_ids.append(str(update.task_id))
else:
tasks_by_id[update.task_id] = task
if duplicate_ids:
raise HTTPException(
status_code=400,
detail={"message": "duplicate task_id values are not allowed", "task_ids": duplicate_ids},
)
if missing_ids:
raise HTTPException(
status_code=404,
detail={"message": "one or more tasks were not found", "task_ids": missing_ids},
)
updated: list[Task] = []
events: list[ProgressEvent] = []
author = body.author or "custodian"
for update in body.updates:
task = tasks_by_id[update.task_id]
previous_status = status_value(task.status)
target_status = status_value(update.status)
if update.blocking_reason is not None:
task.blocking_reason = update.blocking_reason
ws = await session.get(Workstream, task.workstream_id)
transition_task_status(
task,
update.status,
parent_workstream=ws,
previous_task_status=previous_status,
)
event = ProgressEvent(
task_id=task.id,
workstream_id=task.workstream_id,
event_type="task_status_changed",
summary=f"Task status -> {target_status}: {task.title}",
author=author,
session_id=body.session_id,
detail={
"bulk_status_sync": True,
"previous_status": previous_status,
"status": target_status,
"blocking_reason": update.blocking_reason,
},
)
session.add(event)
updated.append(task)
events.append(event)
await session.commit()
for task in updated:
await session.refresh(task)
for event in events:
await session.refresh(event)
return TaskStatusBulkSyncRead(
updated=updated,
progress_event_ids=[event.id for event in events],
)
@router.get("/{task_id}", response_model=TaskRead)
async def get_task(
task_id: uuid.UUID,

View File

@@ -77,6 +77,25 @@ class TaskUpdate(TaskStatusMixin):
return self
class TaskStatusBulkUpdate(TaskStatusMixin):
task_id: uuid.UUID
status: TaskStatus
blocking_reason: str | None = None
class TaskStatusBulkSync(BaseModel):
updates: list[TaskStatusBulkUpdate]
author: str | None = "custodian"
session_id: str | None = None
@field_validator("updates")
@classmethod
def updates_required(cls, value: list[TaskStatusBulkUpdate]):
if not value:
raise ValueError("at least one task status update is required")
return value
class TaskRead(TaskStatusMixin):
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
@@ -99,3 +118,8 @@ class TaskCountRead(TaskStatusMixin):
workstream_id: uuid.UUID
status: TaskStatus
count: int
class TaskStatusBulkSyncRead(BaseModel):
updated: list[TaskRead]
progress_event_ids: list[uuid.UUID]