import logging
from datetime import datetime, timezone
from typing import Any

from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from sqlalchemy import func, select, text
from sqlalchemy.ext.asyncio import AsyncSession

from api.database import get_session, engine
from api.models.decision import Decision, DecisionStatus, DecisionType
from api.models.progress_event import ProgressEvent
from api.models.task import Task, TaskStatus
from api.models.topic import Topic, TopicStatus
from api.models.workstream import Workstream, WorkstreamStatus
from api.schemas.decision import DecisionRead
from api.schemas.progress_event import ProgressEventRead
from api.schemas.state import (
    DecisionTotals,
    StateSummary,
    TaskTotals,
    Totals,
    TopicTotals,
    WorkstreamTotals,
)
from api.schemas.task import TaskRead
from api.schemas.topic import TopicWithWorkstreams
from api.schemas.workstream import WorkstreamRead

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/state", tags=["state"])

# Maximum number of most-recent progress events included in the summary.
RECENT_PROGRESS_LIMIT = 20


async def _fetch_all(session: AsyncSession, stmt: Any) -> list:
    """Execute *stmt* on *session* and return the first column of every row as a list."""
    result = await session.execute(stmt)
    return list(result.scalars().all())


async def _count_by_status(session: AsyncSession, status_column: Any) -> dict:
    """Return a ``{status: row_count}`` mapping from one GROUP BY over *status_column*.

    Statuses that have no rows are absent from the mapping; callers default
    them to 0 with ``dict.get``.
    """
    result = await session.execute(
        select(status_column, func.count()).group_by(status_column)
    )
    return {status: count for status, count in result}


@router.get("/summary", response_model=StateSummary)
async def get_summary(session: AsyncSession = Depends(get_session)) -> StateSummary:
    """Return a snapshot of current state: open items, recent activity and totals.

    All queries run sequentially on the single injected session, because
    AsyncSession does not support concurrent operations (no asyncio.gather
    on the same session).
    """
    # Non-archived topics, oldest first.
    # NOTE(review): TopicWithWorkstreams presumably reads a workstreams
    # relationship on Topic; under AsyncSession that relationship must be
    # eager-loaded (e.g. lazy="selectin" on the model or selectinload here),
    # otherwise model_validate triggers a lazy load that fails — confirm in
    # api.models.topic.
    topics = await _fetch_all(
        session,
        select(Topic)
        .where(Topic.status != TopicStatus.archived)
        .order_by(Topic.created_at),
    )

    # Pending decisions that are still open or escalated; earliest deadline
    # first, decisions without a deadline last.
    blocking = await _fetch_all(
        session,
        select(Decision)
        .where(Decision.decision_type == DecisionType.pending)
        .where(Decision.status.in_([DecisionStatus.open, DecisionStatus.escalated]))
        .order_by(Decision.deadline.asc().nullslast(), Decision.created_at),
    )

    blocked = await _fetch_all(
        session,
        select(Task).where(Task.status == TaskStatus.blocked).order_by(Task.created_at),
    )

    # Newest events first, capped at RECENT_PROGRESS_LIMIT.
    recent = await _fetch_all(
        session,
        select(ProgressEvent)
        .order_by(ProgressEvent.created_at.desc())
        .limit(RECENT_PROGRESS_LIMIT),
    )

    # Active or blocked workstreams; earliest due date first, undated last.
    open_ws = await _fetch_all(
        session,
        select(Workstream)
        .where(Workstream.status.in_([WorkstreamStatus.active, WorkstreamStatus.blocked]))
        .order_by(Workstream.due_date.asc().nullslast(), Workstream.created_at),
    )

    # Totals — one GROUP BY per table.
    topic_counts = await _count_by_status(session, Topic.status)
    ws_counts = await _count_by_status(session, Workstream.status)
    task_counts = await _count_by_status(session, Task.status)
    dec_counts = await _count_by_status(session, Decision.status)

    totals = Totals(
        topics=TopicTotals(
            active=topic_counts.get(TopicStatus.active, 0),
            paused=topic_counts.get(TopicStatus.paused, 0),
            archived=topic_counts.get(TopicStatus.archived, 0),
            total=sum(topic_counts.values()),
        ),
        workstreams=WorkstreamTotals(
            active=ws_counts.get(WorkstreamStatus.active, 0),
            blocked=ws_counts.get(WorkstreamStatus.blocked, 0),
            completed=ws_counts.get(WorkstreamStatus.completed, 0),
            archived=ws_counts.get(WorkstreamStatus.archived, 0),
            total=sum(ws_counts.values()),
        ),
        tasks=TaskTotals(
            todo=task_counts.get(TaskStatus.todo, 0),
            in_progress=task_counts.get(TaskStatus.in_progress, 0),
            blocked=task_counts.get(TaskStatus.blocked, 0),
            done=task_counts.get(TaskStatus.done, 0),
            cancelled=task_counts.get(TaskStatus.cancelled, 0),
            total=sum(task_counts.values()),
        ),
        decisions=DecisionTotals(
            open=dec_counts.get(DecisionStatus.open, 0),
            resolved=dec_counts.get(DecisionStatus.resolved, 0),
            escalated=dec_counts.get(DecisionStatus.escalated, 0),
            superseded=dec_counts.get(DecisionStatus.superseded, 0),
            total=sum(dec_counts.values()),
        ),
    )

    return StateSummary(
        generated_at=datetime.now(tz=timezone.utc),
        totals=totals,
        topics=[TopicWithWorkstreams.model_validate(t) for t in topics],
        blocking_decisions=[DecisionRead.model_validate(d) for d in blocking],
        blocked_tasks=[TaskRead.model_validate(t) for t in blocked],
        recent_progress=[ProgressEventRead.model_validate(e) for e in recent],
        open_workstreams=[WorkstreamRead.model_validate(w) for w in open_ws],
    )


@router.get("/health")
async def health_check() -> dict:
    """Report whether the database answers a trivial ``SELECT 1``.

    Returns ``{"status": "ok", "db": "connected"}`` on success. On any
    failure the exception is logged server-side and a 503 response is
    returned with a fixed message, so connection details (hosts, driver
    errors) are not exposed to unauthenticated callers.
    """
    try:
        async with engine.connect() as conn:
            await conn.execute(text("SELECT 1"))
    except Exception:
        # Top-level boundary: any failure means "unhealthy"; keep the details
        # in the server log rather than in the response body.
        logger.exception("Database health check failed")
        # FastAPI returns Response instances as-is, bypassing the dict model.
        return JSONResponse(
            status_code=503,
            content={"status": "error", "db": "disconnected"},
        )
    return {"status": "ok", "db": "connected"}