Task flow engine implementation

This commit is contained in:
2026-05-02 00:21:14 +02:00
parent 5502d1d535
commit a00f1b615b
15 changed files with 517 additions and 86 deletions

View File

@@ -7,6 +7,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session
from api.flow_defs import assertion_result_to_dict, evaluate_transition, flow_result_to_dict
from api.models.agent_message import AgentMessage
from api.models.capability_catalog import CapabilityCatalog
from api.models.capability_request import CapabilityRequest
@@ -28,22 +29,6 @@ from api.schemas.capability_request import (
router = APIRouter(tags=["capability-requests"])
# ---------------------------------------------------------------------------
# Lifecycle guard
# ---------------------------------------------------------------------------
_VALID_TRANSITIONS: dict[str, set[str]] = {
"requested": {"accepted", "rejected", "withdrawn", "routing_disputed"},
"routing_disputed": {"requested", "withdrawn"},
"accepted": {"in_progress", "rejected", "withdrawn"},
"in_progress": {"ready_for_review", "rejected", "withdrawn"},
"ready_for_review": {"completed", "in_progress", "withdrawn"},
"completed": set(),
"rejected": set(),
"withdrawn": set(),
}
# ---------------------------------------------------------------------------
# Capability Catalog endpoints
# ---------------------------------------------------------------------------
@@ -602,12 +587,21 @@ async def _get_request_or_404(request_id: uuid.UUID, session: AsyncSession) -> C
def _check_transition(current: str, target: str) -> None:
allowed = _VALID_TRANSITIONS.get(current, set())
if target not in allowed:
can_reach, failures, flow_result = evaluate_transition(
"capability_request",
current,
target,
)
if not can_reach:
raise HTTPException(
status_code=422,
detail=(
f"Cannot transition from '{current}' to '{target}'. "
f"Allowed: {sorted(allowed) or 'none (terminal state)'}"
),
detail={
"message": f"Cannot transition from '{current}' to '{target}'.",
"current_workstation": current,
"target_workstation": target,
"blocking_assertions": [
assertion_result_to_dict(item) for item in failures
],
"flow_result": flow_result_to_dict(flow_result),
},
)

View File

@@ -6,37 +6,12 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session
from api.flow_defs import assertion_result_to_dict, evaluate_transition, flow_result_to_dict
from api.models.contribution import Contribution, ContributionStatus, ContributionType
from api.schemas.contribution import ContributionCreate, ContributionRead, ContributionStatusPatch
router = APIRouter(prefix="/contributions", tags=["contributions"])
# Valid forward transitions in the lifecycle
_VALID_TRANSITIONS: dict[ContributionStatus, set[ContributionStatus]] = {
ContributionStatus.draft: {
ContributionStatus.submitted,
ContributionStatus.withdrawn,
},
ContributionStatus.submitted: {
ContributionStatus.acknowledged,
ContributionStatus.rejected,
ContributionStatus.withdrawn,
},
ContributionStatus.acknowledged: {
ContributionStatus.accepted,
ContributionStatus.rejected,
ContributionStatus.withdrawn,
},
ContributionStatus.accepted: {
ContributionStatus.merged,
ContributionStatus.withdrawn,
},
ContributionStatus.rejected: set(),
ContributionStatus.merged: set(),
ContributionStatus.withdrawn: set(),
}
@router.get("/", response_model=list[ContributionRead])
async def list_contributions(
type: ContributionType | None = Query(None),
@@ -93,14 +68,25 @@ async def patch_contribution_status(
session: AsyncSession = Depends(get_session),
) -> Contribution:
contrib = await _get_or_404(contribution_id, session)
allowed = _VALID_TRANSITIONS.get(contrib.status, set())
if body.status not in allowed:
current = _status_value(contrib.status)
target = _status_value(body.status)
can_reach, failures, flow_result = evaluate_transition(
"contribution",
current,
target,
)
if not can_reach:
raise HTTPException(
status_code=422,
detail=(
f"Cannot transition from '{contrib.status}' to '{body.status}'. "
f"Allowed: {[s.value for s in allowed] or 'none (terminal state)'}"
),
detail={
"message": f"Cannot transition from '{current}' to '{target}'.",
"current_workstation": current,
"target_workstation": target,
"blocking_assertions": [
assertion_result_to_dict(item) for item in failures
],
"flow_result": flow_result_to_dict(flow_result),
},
)
contrib.status = body.status
if body.notes:
@@ -145,3 +131,7 @@ async def _get_or_404(contribution_id: uuid.UUID, session: AsyncSession) -> Cont
if contrib is None:
raise HTTPException(status_code=404, detail=f"Contribution '{contribution_id}' not found")
return contrib
def _status_value(status: ContributionStatus | str) -> str:
return status.value if isinstance(status, ContributionStatus) else str(status)

View File

@@ -10,7 +10,7 @@ from api.models.extension_point import ExtensionPoint
from api.models.managed_repo import ManagedRepo
from api.models.technical_debt import TechnicalDebt
from api.models.topic import Topic
from api.models.workstream import Workstream, WorkstreamStatus
from api.models.workstream import Workstream
from api.schemas.domain import DomainCreate, DomainDetail, DomainRead, DomainRename, DomainUpdate, RepoStub
router = APIRouter(prefix="/domains", tags=["domains"])
@@ -69,7 +69,7 @@ async def get_domain(
ws_count_row = await session.execute(
select(func.count()).select_from(Workstream)
.where(Workstream.topic_id.in_(topic_ids))
.where(Workstream.status == WorkstreamStatus.active)
.where(Workstream.status == "active")
)
ws_count = ws_count_row.scalar_one()

167
api/routers/flows.py Normal file
View File

@@ -0,0 +1,167 @@
from __future__ import annotations
import uuid
from typing import Any
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session
from api.flow_defs import (
assertion_result_to_dict,
create_flow_engine,
flow_result_to_dict,
load_flow,
)
from api.models.capability_request import CapabilityRequest
from api.models.contribution import Contribution
from api.models.task import Task
from api.models.workstream import Workstream
from api.models.workstream_dependency import WorkstreamDependency
router = APIRouter(prefix="/flows", tags=["flows"])
@router.get("/definitions")
async def list_flow_definitions() -> list[dict[str, Any]]:
flows = [
load_flow(entity_type)
for entity_type in (
"workstream",
"task",
"contribution",
"capability_request",
)
]
return [
{
"id": flow.id,
"entity_type": flow.entity_type,
"workstations": [
{
"name": workstation.name,
"description": workstation.description,
"entry_assertion_count": len(workstation.entry_assertions),
"exit_assertion_count": len(workstation.exit_assertions),
}
for workstation in flow.workstations
],
}
for flow in flows
]
@router.get("/{entity_type}/{entity_id}")
async def get_flow_state(
entity_type: str,
entity_id: uuid.UUID,
session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
obj = await _flow_object(entity_type, entity_id, session)
flow = load_flow(entity_type)
result = create_flow_engine().evaluate(obj, flow)
return flow_result_to_dict(result)
@router.post("/{entity_type}/{entity_id}/advance/{target_workstation}")
async def advance_workstation(
entity_type: str,
entity_id: uuid.UUID,
target_workstation: str,
session: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
obj = await _flow_object(entity_type, entity_id, session)
flow = load_flow(entity_type)
engine = create_flow_engine()
can_reach, failures = engine.can_reach(obj, flow, target_workstation)
if not can_reach:
raise HTTPException(
status_code=409,
detail={
"message": (
f"Cannot advance {entity_type} '{entity_id}' "
f"to '{target_workstation}'."
),
"blocking_assertions": [
assertion_result_to_dict(item) for item in failures
],
"flow_result": flow_result_to_dict(engine.evaluate(obj, flow)),
},
)
entity = await _entity(entity_type, entity_id, session)
entity.status = target_workstation
await session.commit()
await session.refresh(entity)
return await get_flow_state(entity_type, entity_id, session)
async def _flow_object(
entity_type: str,
entity_id: uuid.UUID,
session: AsyncSession,
) -> dict[str, Any]:
entity = await _entity(entity_type, entity_id, session)
status = _value(entity.status)
obj: dict[str, Any] = {
"id": str(entity.id),
"status": status,
"workstation": status,
"previous_workstation": status,
}
if entity_type == "workstream":
tasks = list((await session.execute(
select(Task).where(Task.workstream_id == entity_id)
)).scalars().all())
deps = list((await session.execute(
select(WorkstreamDependency).where(
WorkstreamDependency.from_workstream_id == entity_id
)
)).scalars().all())
dependency_ids = [dep.to_workstream_id for dep in deps]
dependency_workstations: list[dict[str, Any]] = []
if dependency_ids:
dep_ws = list((await session.execute(
select(Workstream).where(Workstream.id.in_(dependency_ids))
)).scalars().all())
dependency_workstations = [
{"id": str(ws.id), "workstation": ws.status}
for ws in dep_ws
]
obj.update({
"tasks": [{"id": str(task.id), "status": _value(task.status)} for task in tasks],
"dependencies": dependency_workstations,
})
elif entity_type == "task":
obj.update({
"needs_human": entity.needs_human,
"blocking_reason": entity.blocking_reason,
})
return obj
async def _entity(
entity_type: str,
entity_id: uuid.UUID,
session: AsyncSession,
):
model_by_type = {
"workstream": Workstream,
"task": Task,
"contribution": Contribution,
"capability_request": CapabilityRequest,
}
model = model_by_type.get(entity_type)
if model is None:
raise HTTPException(status_code=404, detail=f"Unknown flow entity type '{entity_type}'")
entity = await session.get(model, entity_id)
if entity is None:
raise HTTPException(status_code=404, detail=f"{entity_type} '{entity_id}' not found")
return entity
def _value(item):
return item.value if hasattr(item, "value") else item

View File

@@ -6,6 +6,7 @@ from sqlalchemy import func, select, text
from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session, engine
from api.flow_defs import assertion_result_to_dict, load_flow
from api.models.capability_request import CapabilityRequest
from api.models.contribution import Contribution, ContributionStatus, ContributionType
from api.models.decision import Decision, DecisionStatus, DecisionType
@@ -17,7 +18,7 @@ from api.models.sbom_entry import SBOMEntry
from api.models.task import Task, TaskPriority, TaskStatus
from api.models.technical_debt import TechnicalDebt
from api.models.topic import Topic, TopicStatus
from api.models.workstream import Workstream, WorkstreamStatus
from api.models.workstream import Workstream
from api.models.workstream_dependency import WorkstreamDependency
from api.schemas.decision import DecisionRead
from api.schemas.domain import DomainSummary
@@ -35,6 +36,7 @@ from api.schemas.task import TaskRead
from api.schemas.topic import TopicWithWorkstreams
from api.schemas.workstream import WorkstreamRead, WorkstreamWithTaskCounts, WorkstreamWithDeps
from api.schemas.workstream_dependency import WorkstreamDepStub
from task_flow_engine import FlowEngine
router = APIRouter(prefix="/state", tags=["state"])
@@ -69,7 +71,7 @@ async def get_summary(session: AsyncSession = Depends(get_session)) -> StateSumm
open_ws_rows = await session.execute(
select(Workstream)
.where(Workstream.status.in_([WorkstreamStatus.active, WorkstreamStatus.blocked]))
.where(Workstream.status.in_(["active", "blocked"]))
.order_by(Workstream.due_date.asc().nullslast(), Workstream.created_at)
)
open_ws = list(open_ws_rows.scalars().all())
@@ -128,6 +130,27 @@ async def get_summary(session: AsyncSession = Depends(get_session)) -> StateSumm
description=d.description,
))
workstream_flow = load_flow("workstream")
flow_engine = FlowEngine()
effective_status: dict = {}
blocked_reasons: dict = {}
for w in open_ws:
flow_obj = {
"status": w.status,
"workstation": w.status,
"tasks": [{"status": _value(t.status)} for t in w.tasks],
"dependencies": [
{"workstation": ws_lookup[d.to_workstream_id].status}
for d in dep_rows
if d.from_workstream_id == w.id and d.to_workstream_id in ws_lookup
],
}
flow_result = flow_engine.evaluate(flow_obj, workstream_flow)
effective_status[w.id] = "blocked" if flow_result.exit_blocked else w.status
blocked_reasons[w.id] = [
assertion_result_to_dict(item) for item in flow_result.blocking_assertions
]
# Totals — one GROUP BY per table
topic_counts = {r[0]: r[1] for r in await session.execute(
select(Topic.status, func.count()).group_by(Topic.status)
@@ -150,10 +173,10 @@ async def get_summary(session: AsyncSession = Depends(get_session)) -> StateSumm
total=sum(topic_counts.values()),
),
workstreams=WorkstreamTotals(
active=ws_counts.get(WorkstreamStatus.active, 0),
blocked=ws_counts.get(WorkstreamStatus.blocked, 0),
completed=ws_counts.get(WorkstreamStatus.completed, 0),
archived=ws_counts.get(WorkstreamStatus.archived, 0),
active=sum(1 for status in effective_status.values() if status == "active"),
blocked=sum(1 for status in effective_status.values() if status == "blocked"),
completed=ws_counts.get("completed", 0),
archived=ws_counts.get("archived", 0),
total=sum(ws_counts.values()),
),
tasks=TaskTotals(
@@ -226,7 +249,10 @@ async def get_summary(session: AsyncSession = Depends(get_session)) -> StateSumm
open_capability_requests=open_cap_req_count,
open_workstreams=[
WorkstreamWithDeps(
**WorkstreamRead.model_validate(w).model_dump(),
**{
**WorkstreamRead.model_validate(w).model_dump(),
"status": effective_status.get(w.id, w.status),
},
tasks_total=sum(task_per_ws.get(w.id, {}).values()),
tasks_todo=task_per_ws.get(w.id, {}).get(TaskStatus.todo, 0),
tasks_in_progress=task_per_ws.get(w.id, {}).get(TaskStatus.in_progress, 0),
@@ -234,6 +260,7 @@ async def get_summary(session: AsyncSession = Depends(get_session)) -> StateSumm
tasks_done=task_per_ws.get(w.id, {}).get(TaskStatus.done, 0),
depends_on=dep_index.get(w.id, {}).get("depends_on", []),
blocks=dep_index.get(w.id, {}).get("blocks", []),
blocked_reasons=blocked_reasons.get(w.id, []),
)
for w in open_ws
],
@@ -259,7 +286,7 @@ async def _build_domain_summaries(session: AsyncSession) -> list[DomainSummary]:
for domain_id, cnt in await session.execute(
select(Topic.domain_id, func.count(Workstream.id))
.join(Workstream, Workstream.topic_id == Topic.id)
.where(Workstream.status == WorkstreamStatus.active)
.where(Workstream.status == "active")
.group_by(Topic.domain_id)
):
ws_per_domain[domain_id] = cnt
@@ -357,14 +384,14 @@ async def _derive_next_steps(session: AsyncSession) -> list[NextStep]:
all_done = True
for to_id in to_ws_ids:
to_ws = await session.get(Workstream, to_id)
if to_ws is None or to_ws.status != WorkstreamStatus.completed:
if to_ws is None or to_ws.status != "completed":
all_done = False
break
if not all_done:
continue
from_ws = await session.get(Workstream, from_ws_id)
if from_ws is None or from_ws.status not in (WorkstreamStatus.active, WorkstreamStatus.blocked):
if from_ws is None or from_ws.status not in ("active", "blocked"):
continue
todo_rows = await session.execute(
@@ -414,6 +441,10 @@ async def _get_domain_slug_for_workstream(ws: Workstream | None, session: AsyncS
return domain.slug if domain else None
def _value(item):
return item.value if hasattr(item, "value") else item
@router.get("/next_steps", response_model=list[NextStep])
async def get_next_steps(session: AsyncSession = Depends(get_session)) -> list[NextStep]:
"""Derive contextual next-action suggestions from current hub state.

View File

@@ -5,8 +5,13 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session
from api.models.workstream import Workstream, WorkstreamStatus
from api.schemas.workstream import WorkstreamCreate, WorkstreamRead, WorkstreamUpdate
from api.models.workstream import Workstream
from api.schemas.workstream import (
WorkstreamCreate,
WorkstreamRead,
WorkstreamStatus,
WorkstreamUpdate,
)
router = APIRouter(prefix="/workstreams", tags=["workstreams"])
@@ -86,7 +91,7 @@ async def archive_workstream(
ws = await session.get(Workstream, workstream_id)
if ws is None:
raise HTTPException(status_code=404, detail="Workstream not found")
ws.status = WorkstreamStatus.archived
ws.status = "archived"
await session.commit()
await session.refresh(ws)
return ws