Files
state-hub/api/routers/execution.py
2026-05-23 19:11:30 +02:00

197 lines
7.2 KiB
Python

import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session
from api.models.task import Task, TaskStatus
from api.models.workplan_launch_request import WorkplanLaunchRequest
from api.models.workstream import Workstream
from api.models.workstream_dependency import WorkstreamDependency
from api.schemas.execution import (
ExecutionIntentRead,
ExecutionIntentUpdate,
ExecutionSemantics,
LaunchRequestCreate,
LaunchRequestRead,
WorkplanQueueItem,
)
from api.services.execution_queue import (
ACTIVITY_CORE_RESPONSIBILITIES,
CONCURRENCY_MODES,
EXECUTION_STATES,
LAUNCH_MODES,
STATE_HUB_RESPONSIBILITIES,
execution_state_for_launch,
queue_sort_key,
workstream_blockers,
)
from api.workplan_status import CLOSED_WORKSTREAM_STATUSES, normalize_workstream_status
# Router for all execution-queue endpoints below: every route is mounted
# under the "/execution" prefix and grouped under the "execution" OpenAPI tag.
router = APIRouter(prefix="/execution", tags=["execution"])
@router.get("/semantics", response_model=ExecutionSemantics)
async def execution_semantics() -> ExecutionSemantics:
    """Describe the vocabulary used by the execution queue.

    Returns the execution states, launch modes and concurrency modes, plus
    the split of responsibilities between the state hub and activity core,
    exactly as exported by ``api.services.execution_queue``.
    """
    catalogue = {
        "execution_states": EXECUTION_STATES,
        "launch_modes": LAUNCH_MODES,
        "concurrency_modes": CONCURRENCY_MODES,
        "state_hub_responsibility": STATE_HUB_RESPONSIBILITIES,
        "activity_core_responsibility": ACTIVITY_CORE_RESPONSIBILITIES,
    }
    return ExecutionSemantics(**catalogue)
@router.patch("/workstreams/{workstream_id}/intent", response_model=ExecutionIntentRead)
async def update_execution_intent(
    workstream_id: uuid.UUID,
    body: ExecutionIntentUpdate,
    session: AsyncSession = Depends(get_session),
) -> ExecutionIntentRead:
    """Partially update a workstream's execution intent.

    Only fields explicitly present in the request body are written
    (``model_dump(exclude_unset=True)``); omitted fields keep their stored
    values.

    Raises:
        HTTPException: 404 when no workstream has ``workstream_id``.
    """
    workstream = await session.get(Workstream, workstream_id)
    if workstream is None:
        raise HTTPException(status_code=404, detail="Workstream not found")

    changes = body.model_dump(exclude_unset=True)
    for attr_name in changes:
        setattr(workstream, attr_name, changes[attr_name])

    await session.commit()
    # Reload the row after commit so the response reflects the persisted state.
    await session.refresh(workstream)
    return _intent_read(workstream)
@router.get("/workplan-stack", response_model=list[WorkplanQueueItem])
async def workplan_stack(
    include_manual: bool = Query(True),
    include_blocked: bool = Query(True),
    session: AsyncSession = Depends(get_session),
) -> list[WorkplanQueueItem]:
    """Return open workstreams as an ordered launch queue.

    Workstreams whose normalized status is in ``CLOSED_WORKSTREAM_STATUSES``
    are excluded. A workstream is ``eligible`` when its lifecycle status is
    not ``"blocked"``, none of the blockers reported by
    ``workstream_blockers`` is still open, and every task it depends on has
    status ``"done"`` or ``"cancelled"``.

    Args:
        include_manual: When false, skip workstreams whose
            ``execution_state`` is ``"manual"``.
        include_blocked: When false, skip workstreams that are not eligible.
        session: Database session (injected).

    Returns:
        Queue items sorted by the ``sort_key`` from ``queue_sort_key``.
    """
    result = await session.execute(select(Workstream))
    # Normalize each status exactly once; the keys of ws_status double as the
    # set of open workstream ids.
    ws_status: dict[uuid.UUID, str] = {}
    workstreams: list[Workstream] = []
    for ws in result.scalars().all():
        normalized = normalize_workstream_status(ws.status)
        if normalized not in CLOSED_WORKSTREAM_STATUSES:
            workstreams.append(ws)
            ws_status[ws.id] = normalized

    ws_deps, task_deps = await _load_dependencies(session, set(ws_status))
    task_status = await _load_task_statuses(session, task_deps)

    items: list[WorkplanQueueItem] = []
    for ws in workstreams:
        if not include_manual and ws.execution_state == "manual":
            continue
        lifecycle_status = ws_status[ws.id]
        # Only blockers that are themselves still open hold this one back;
        # closed or unknown workstreams are ignored.
        blocked_ws = [
            blocker
            for blocker in workstream_blockers(ws.id, ws_deps, ws_status)
            if blocker in ws_status
        ]
        # A task dependency blocks until the task is done or cancelled. A task
        # id with no matching row (status unknown) also keeps blocking.
        blocked_tasks = [
            task_id
            for task_id in task_deps.get(ws.id, [])
            if task_status.get(task_id) not in {"done", "cancelled"}
        ]
        eligible = lifecycle_status != "blocked" and not blocked_ws and not blocked_tasks
        if not include_blocked and not eligible:
            continue
        sort_key = queue_sort_key(ws, eligible=eligible)
        items.append(WorkplanQueueItem(
            workstream_id=ws.id,
            slug=ws.slug,
            title=ws.title,
            status=lifecycle_status,
            repo_id=ws.repo_id,
            planning_priority=ws.planning_priority,
            planning_order=ws.planning_order,
            execution_state=ws.execution_state,
            launch_mode=ws.launch_mode,
            concurrency_mode=ws.concurrency_mode,
            queue_rank=ws.queue_rank,
            execution_group=ws.execution_group,
            scheduled_for=ws.scheduled_for,
            eligible=eligible,
            blocked_by_workstream_ids=blocked_ws,
            blocked_by_task_ids=blocked_tasks,
            sort_key=sort_key,
        ))
    return sorted(items, key=lambda item: item.sort_key)


async def _load_dependencies(
    session: AsyncSession,
    open_ids: set[uuid.UUID],
) -> tuple[dict[uuid.UUID, list[uuid.UUID]], dict[uuid.UUID, list[uuid.UUID]]]:
    """Group dependency rows into workstream->workstream and workstream->task maps.

    The workstream map keeps edges from every workstream (open or closed)
    because it is handed unchanged to ``workstream_blockers``, whose
    traversal is not visible here. The task map keeps only edges from
    ``open_ids``, since only open workstreams are evaluated by the caller.
    """
    dep_result = await session.execute(select(WorkstreamDependency))
    ws_deps: dict[uuid.UUID, list[uuid.UUID]] = {}
    task_deps: dict[uuid.UUID, list[uuid.UUID]] = {}
    for dep in dep_result.scalars().all():
        if dep.to_workstream_id is not None:
            ws_deps.setdefault(dep.from_workstream_id, []).append(dep.to_workstream_id)
        if dep.to_task_id is not None and dep.from_workstream_id in open_ids:
            task_deps.setdefault(dep.from_workstream_id, []).append(dep.to_task_id)
    return ws_deps, task_deps


async def _load_task_statuses(
    session: AsyncSession,
    task_deps: dict[uuid.UUID, list[uuid.UUID]],
) -> dict[uuid.UUID, str]:
    """Fetch the normalized status of every task referenced in ``task_deps``.

    Returns an empty mapping (without querying) when there are no task
    dependencies. Task ids are de-duplicated, preserving first-seen order.
    """
    task_ids = list(dict.fromkeys(
        task_id for ids in task_deps.values() for task_id in ids
    ))
    if not task_ids:
        return {}
    task_result = await session.execute(select(Task).where(Task.id.in_(task_ids)))
    return {task.id: _task_status(task.status) for task in task_result.scalars().all()}
@router.post(
    "/launch-requests",
    response_model=LaunchRequestRead,
    status_code=status.HTTP_201_CREATED,
)
async def create_launch_request(
    body: LaunchRequestCreate,
    session: AsyncSession = Depends(get_session),
) -> WorkplanLaunchRequest:
    """Record a launch request and apply its settings to the workstream.

    ``priority`` falls back to the workstream's ``planning_priority`` only
    when the request omits it (``None``); an explicit falsy value such as a
    numeric priority of ``0`` is kept as given. ``repo_id`` falls back to
    the workstream's ``repo_id`` whenever the request value is falsy.

    The workstream's ``launch_mode``, ``concurrency_mode`` and
    ``execution_state`` are updated in the same commit as the new request.

    Raises:
        HTTPException: 404 when ``body.workstream_id`` does not exist.
    """
    ws = await session.get(Workstream, body.workstream_id)
    if ws is None:
        raise HTTPException(status_code=404, detail="Workstream not found")
    # `or` would replace a legitimate falsy priority (e.g. 0) with the default.
    priority = body.priority if body.priority is not None else ws.planning_priority
    launch_request = WorkplanLaunchRequest(
        workstream_id=ws.id,
        requested_by=body.requested_by,
        requested_actor=body.requested_actor,
        launch_mode=body.launch_mode,
        concurrency_mode=body.concurrency_mode,
        priority=priority,
        # Omitted or blank repo ids fall back to the workstream's repo.
        repo_id=body.repo_id or ws.repo_id,
        branch_preference=body.branch_preference,
        immediate_pickup=body.immediate_pickup,
        notes=body.notes,
        request_metadata=body.request_metadata,
    )
    # Mirror the requested launch settings onto the workstream itself; these
    # are the fields the /workplan-stack view reports.
    ws.launch_mode = body.launch_mode
    ws.concurrency_mode = body.concurrency_mode
    ws.execution_state = execution_state_for_launch(body.launch_mode, body.immediate_pickup)
    session.add(launch_request)
    await session.commit()
    await session.refresh(launch_request)
    return launch_request
@router.get("/launch-requests", response_model=list[LaunchRequestRead])
async def list_launch_requests(
    workstream_id: uuid.UUID | None = None,
    request_status: str | None = None,
    session: AsyncSession = Depends(get_session),
) -> list[WorkplanLaunchRequest]:
    """List launch requests, newest ``created_at`` first.

    Args:
        workstream_id: If given, only requests for this workstream.
        request_status: If non-empty, only requests with this status; an
            empty string applies no filter.
        session: Database session (injected).
    """
    criteria = []
    if workstream_id:
        criteria.append(WorkplanLaunchRequest.workstream_id == workstream_id)
    if request_status:
        criteria.append(WorkplanLaunchRequest.status == request_status)

    query = select(WorkplanLaunchRequest)
    if criteria:
        query = query.where(*criteria)
    query = query.order_by(WorkplanLaunchRequest.created_at.desc())

    rows = await session.execute(query)
    return list(rows.scalars().all())
def _intent_read(ws: Workstream) -> ExecutionIntentRead:
    """Copy a workstream's id and execution-intent attributes into the read schema."""
    intent_fields = (
        "execution_state",
        "launch_mode",
        "concurrency_mode",
        "queue_rank",
        "execution_group",
        "scheduled_for",
    )
    return ExecutionIntentRead(
        workstream_id=ws.id,
        **{name: getattr(ws, name) for name in intent_fields},
    )
def _task_status(status_value: TaskStatus | str | None) -> str:
    """Normalize a task status to a trimmed, lowercase string.

    Accepts a ``TaskStatus`` member (or any object exposing ``.value``), a
    raw string, or ``None``/empty, which yields ``""``. Enum values go
    through the same strip/lowercase normalization as plain strings, so
    callers comparing against ``"done"``/``"cancelled"`` get the same answer
    regardless of how the status was stored.
    """
    raw = status_value.value if hasattr(status_value, "value") else status_value
    return str(raw or "").strip().lower()