Files
state-hub/api/routers/flows.py

170 lines
5.4 KiB
Python

from __future__ import annotations
import uuid
from typing import Any
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session
from api.flow_defs import (
assertion_result_to_dict,
create_flow_engine,
flow_result_to_dict,
load_flow,
)
from api.models.capability_request import CapabilityRequest
from api.models.contribution import Contribution
from api.models.task import Task
from api.models.workstream import Workstream
from api.models.workstream_dependency import WorkstreamDependency
from api.workplan_status import normalize_workstream_status
# Router for the flow endpoints defined in this module; all routes registered
# on it share the "/flows" URL prefix and the "flows" tag.
router = APIRouter(prefix="/flows", tags=["flows"])
@router.get("/definitions")
async def list_flow_definitions() -> list[dict[str, Any]]:
    """Return a summary of the flow definition for each entity type.

    One flow is loaded per entity type (workstream, task, contribution,
    capability_request). Each summary carries the flow's ``id`` and
    ``entity_type`` plus, for every workstation, its name, description and
    the number of entry and exit assertions attached to it.
    """
    entity_types = ("workstream", "task", "contribution", "capability_request")
    # Load every definition up front so a loading failure surfaces before
    # any summary is assembled.
    definitions = list(map(load_flow, entity_types))

    def summarize(station) -> dict[str, Any]:
        # Report only the assertion counts, not the assertions themselves.
        return {
            "name": station.name,
            "description": station.description,
            "entry_assertion_count": len(station.entry_assertions),
            "exit_assertion_count": len(station.exit_assertions),
        }

    summaries: list[dict[str, Any]] = []
    for definition in definitions:
        summaries.append(
            {
                "id": definition.id,
                "entity_type": definition.entity_type,
                "workstations": [
                    summarize(station) for station in definition.workstations
                ],
            }
        )
    return summaries
@router.get("/{entity_type}/{entity_id}")
async def get_flow_state(
    entity_type: str,
    entity_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """Evaluate an entity against its flow and return the result as a dict.

    The entity is first converted to its flow-object dict (which also
    validates the entity type and the entity's existence via ``_entity``),
    then evaluated by a freshly created flow engine against the flow loaded
    for ``entity_type``.
    """
    flow_object = await _flow_object(entity_type, entity_id, session)
    definition = load_flow(entity_type)
    engine = create_flow_engine()
    evaluation = engine.evaluate(flow_object, definition)
    return flow_result_to_dict(evaluation)
@router.post("/{entity_type}/{entity_id}/advance/{target_workstation}")
async def advance_workstation(
    entity_type: str,
    entity_id: uuid.UUID,
    target_workstation: str,
    session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
    """Move an entity to ``target_workstation`` if the flow engine allows it.

    When the engine reports the target as unreachable, a 409 response is
    raised whose detail carries a message, the blocking assertions and the
    entity's current flow evaluation. Otherwise the entity's ``status`` is
    set to the target, committed, and the refreshed flow state is returned.

    Raises:
        HTTPException: 409 when the target workstation cannot be reached;
            404 (from ``_entity``) for an unknown entity type or missing row.
    """
    flow_object = await _flow_object(entity_type, entity_id, session)
    definition = load_flow(entity_type)
    engine = create_flow_engine()

    reachable, blockers = engine.can_reach(
        flow_object, definition, target_workstation
    )
    if not reachable:
        message = (
            f"Cannot advance {entity_type} '{entity_id}' "
            f"to '{target_workstation}'."
        )
        blocking = [assertion_result_to_dict(blocker) for blocker in blockers]
        # Include the current evaluation so the caller can see where the
        # entity stands without a second request.
        current_state = flow_result_to_dict(
            engine.evaluate(flow_object, definition)
        )
        raise HTTPException(
            status_code=409,
            detail={
                "message": message,
                "blocking_assertions": blocking,
                "flow_result": current_state,
            },
        )

    record = await _entity(entity_type, entity_id, session)
    record.status = target_workstation
    await session.commit()
    await session.refresh(record)
    # Re-evaluate after the commit so the response reflects the new status.
    return await get_flow_state(entity_type, entity_id, session)
async def _flow_object(
    entity_type: str,
    entity_id: uuid.UUID,
    session: AsyncSession,
) -> dict[str, Any]:
    """Build the plain-dict view of an entity that is handed to the flow engine.

    Every entity gets ``id``, ``status``, ``workstation`` and
    ``previous_workstation`` (the last three all hold the current status).
    Workstreams additionally get their ``tasks`` (id + status) and the
    ``dependencies`` they point at (id + normalized workstation); tasks get
    ``needs_human`` and ``blocking_reason``.

    Args:
        entity_type: One of the flow entity types accepted by ``_entity``.
        entity_id: Primary key of the entity to load.
        session: Async database session used for all lookups.

    Returns:
        The dict describing the entity for flow evaluation.

    Raises:
        HTTPException: 404, propagated from ``_entity``, when the entity type
            is unknown or the entity does not exist.
    """
    entity = await _entity(entity_type, entity_id, session)
    status = _value(entity.status)
    # Only workstream statuses go through normalization; other entity types
    # use their status value as-is.
    current_status = (
        normalize_workstream_status(status)
        if entity_type == "workstream"
        else status
    )
    obj: dict[str, Any] = {
        "id": str(entity.id),
        "status": current_status,
        "workstation": current_status,
        "previous_workstation": current_status,
    }

    if entity_type == "workstream":
        task_rows = await session.execute(
            select(Task).where(Task.workstream_id == entity_id)
        )
        tasks = task_rows.scalars().all()

        dep_rows = await session.execute(
            select(WorkstreamDependency).where(
                WorkstreamDependency.from_workstream_id == entity_id
            )
        )
        dependency_ids = [dep.to_workstream_id for dep in dep_rows.scalars().all()]

        dependency_workstations: list[dict[str, Any]] = []
        if dependency_ids:  # skip the lookup entirely when there are no edges
            ws_rows = await session.execute(
                select(Workstream).where(Workstream.id.in_(dependency_ids))
            )
            dependency_workstations = [
                {
                    "id": str(ws.id),
                    # Unwrap with _value first, exactly as done for the
                    # entity's own status above, so the normalizer always
                    # receives the same kind of value.
                    "workstation": normalize_workstream_status(_value(ws.status)),
                }
                for ws in ws_rows.scalars().all()
            ]

        obj.update({
            "tasks": [
                {"id": str(task.id), "status": _value(task.status)}
                for task in tasks
            ],
            "dependencies": dependency_workstations,
        })
    elif entity_type == "task":
        obj.update({
            "needs_human": entity.needs_human,
            "blocking_reason": entity.blocking_reason,
        })
    return obj
async def _entity(
    entity_type: str,
    entity_id: uuid.UUID,
    session: AsyncSession,
):
    """Load the ORM row for ``entity_type`` / ``entity_id``.

    Raises:
        HTTPException: 404 when ``entity_type`` is not a known flow entity
            type, or when no row with ``entity_id`` exists for that type.
    """
    models = {
        "workstream": Workstream,
        "task": Task,
        "contribution": Contribution,
        "capability_request": CapabilityRequest,
    }
    if entity_type not in models:
        raise HTTPException(
            status_code=404,
            detail=f"Unknown flow entity type '{entity_type}'",
        )

    record = await session.get(models[entity_type], entity_id)
    if record is None:
        raise HTTPException(
            status_code=404,
            detail=f"{entity_type} '{entity_id}' not found",
        )
    return record
def _value(item):
    """Return ``item.value`` when ``item`` has a ``value`` attribute, else ``item``."""
    if hasattr(item, "value"):
        return item.value
    return item