import uuid
from collections.abc import Callable, Collection
from datetime import datetime
from typing import Any

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import inspect, select
from sqlalchemy.ext.asyncio import AsyncSession

from hub_core.events import ALERT_EVENT_TYPES, RISK_EVENT_TYPES
from hub_core.models.progress_event import ProgressEvent
from hub_core.schemas.progress_event import ProgressEventCreate, ProgressEventRead
from hub_core.utils.pagination import PageParams, apply_pagination

# Largest page size any list endpoint of this router accepts.
_MAX_PAGE_SIZE = 1000


def create_progress_router(
    get_session: Callable[..., AsyncSession],
    *,
    progress_model: type[ProgressEvent] = ProgressEvent,
    progress_create_schema: type[ProgressEventCreate] = ProgressEventCreate,
    progress_read_schema: type[ProgressEventRead] = ProgressEventRead,
) -> APIRouter:
    """Build an ``APIRouter`` serving progress events under ``/progress``.

    Routes:

    * ``GET /progress/`` - list events, optionally filtered by topic,
      workstream, task, decision, a single ``event_type`` and/or ``since``.
    * ``GET /progress/risks`` - events whose type is in ``RISK_EVENT_TYPES``.
    * ``GET /progress/alerts`` - events whose type is in ``ALERT_EVENT_TYPES``.
    * ``POST /progress/`` - append a new event (responds 201).

    All list routes return newest events first and are paginated with
    ``limit`` (1..1000) and ``offset`` (>= 0).

    Args:
        get_session: FastAPI dependency providing an ``AsyncSession``.
        progress_model: ORM model that is queried and instantiated. It is
            overridable so that a variant model can reuse these routes.
        progress_create_schema: schema validating ``POST`` bodies. Its
            ``model_dump()`` output is passed as keyword arguments to
            ``progress_model``.
        progress_read_schema: schema used as the response model.

    Returns:
        The configured router. The caller is responsible for mounting it.
    """
    router = APIRouter(prefix="/progress", tags=["progress"])
    list_response_model = list[progress_read_schema]

    # ``created_at`` is not unique, so ordering by it alone makes offset
    # pagination non-deterministic: rows sharing a timestamp may be skipped
    # or repeated between pages. The mapped primary key is added as a
    # tiebreaker. It is computed once here rather than on every request.
    ordering = (
        progress_model.created_at.desc(),
        *(column.desc() for column in inspect(progress_model).primary_key),
    )

    async def _list_events(
        session: AsyncSession,
        *,
        topic_id: uuid.UUID | None = None,
        workstream_id: uuid.UUID | None = None,
        task_id: uuid.UUID | None = None,
        decision_id: uuid.UUID | None = None,
        event_type: str | None = None,
        event_types: Collection[str] | None = None,
        since: datetime | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[Any]:
        """Return one page of events matching every supplied filter, newest first.

        Filters are combined with AND. ``event_type`` matches a single type.
        ``event_types`` matches any type in the collection.

        Raises:
            HTTPException: 400 when a scoping filter is given for a field that
                ``progress_model`` does not define.
        """
        query = select(progress_model)

        for field, value in (
            ("topic_id", topic_id),
            ("workstream_id", workstream_id),
            ("task_id", task_id),
            ("decision_id", decision_id),
        ):
            if value is None:
                continue
            column = getattr(progress_model, field, None)
            if column is None:
                # A custom progress_model may lack some scoping columns. Reject
                # the filter explicitly instead of silently ignoring it.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Progress events do not support filtering by {field}",
                )
            query = query.where(column == value)

        # Truthiness on purpose: an empty ``?event_type=`` means "no filter".
        if event_type:
            query = query.where(progress_model.event_type == event_type)
        if event_types is not None:
            # Sorted so the generated IN list (and its statement-cache key) is
            # stable even when ``event_types`` is an unordered set.
            query = query.where(progress_model.event_type.in_(sorted(event_types)))
        if since is not None:
            query = query.where(progress_model.created_at >= since)

        query = query.order_by(*ordering)
        query = apply_pagination(query, PageParams(limit=limit, offset=offset))
        result = await session.execute(query)
        return list(result.scalars().all())

    @router.get("/", response_model=list_response_model)
    async def list_progress(
        topic_id: uuid.UUID | None = None,
        workstream_id: uuid.UUID | None = None,
        task_id: uuid.UUID | None = None,
        decision_id: uuid.UUID | None = None,
        event_type: str | None = None,
        since: datetime | None = None,
        limit: int = Query(100, ge=1, le=_MAX_PAGE_SIZE),
        offset: int = Query(0, ge=0),
        session: AsyncSession = Depends(get_session),
    ) -> list[Any]:
        """List progress events matching the optional query filters."""
        return await _list_events(
            session,
            topic_id=topic_id,
            workstream_id=workstream_id,
            task_id=task_id,
            decision_id=decision_id,
            event_type=event_type,
            since=since,
            limit=limit,
            offset=offset,
        )

    @router.get("/risks", response_model=list_response_model)
    async def get_risks(
        since: datetime | None = None,
        limit: int = Query(100, ge=1, le=_MAX_PAGE_SIZE),
        offset: int = Query(0, ge=0),
        session: AsyncSession = Depends(get_session),
    ) -> list[Any]:
        """List events whose ``event_type`` is one of ``RISK_EVENT_TYPES``."""
        return await _list_events(
            session,
            event_types=RISK_EVENT_TYPES,
            since=since,
            limit=limit,
            offset=offset,
        )

    @router.get("/alerts", response_model=list_response_model)
    async def get_alerts(
        since: datetime | None = None,
        limit: int = Query(100, ge=1, le=_MAX_PAGE_SIZE),
        offset: int = Query(0, ge=0),
        session: AsyncSession = Depends(get_session),
    ) -> list[Any]:
        """List events whose ``event_type`` is one of ``ALERT_EVENT_TYPES``."""
        return await _list_events(
            session,
            event_types=ALERT_EVENT_TYPES,
            since=since,
            limit=limit,
            offset=offset,
        )

    @router.post(
        "/", response_model=progress_read_schema, status_code=status.HTTP_201_CREATED
    )
    async def append_progress(
        body: progress_create_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Persist a new progress event and return it as stored.

        The refresh after commit reloads the row, so database-populated
        fields are included in the response.
        """
        event = progress_model(**body.model_dump())
        session.add(event)
        await session.commit()
        await session.refresh(event)
        return event

    return router