"""FastAPI router factories for the capability catalog and capability requests.

Each ``create_*_router`` factory takes a ``get_session`` dependency plus optional
model/schema overrides (and, for the write router, lifecycle hooks) and returns an
``APIRouter``. ``create_capabilities_router`` composes the three routers using
their defaults.
"""

import uuid
from collections.abc import Callable
from datetime import datetime, timezone
from typing import Any, Awaitable

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from hub_core.models.capability_catalog import CapabilityCatalog
from hub_core.models.capability_request import CapabilityRequest
from hub_core.models.domain import Domain
from hub_core.models.managed_repo import ManagedRepo
from hub_core.schemas.capability import (
    CapabilityRequestAccept,
    CapabilityRequestCreate,
    CapabilityRequestDispute,
    CapabilityRequestPatch,
    CapabilityRequestReroute,
    CapabilityRequestRead,
    CapabilityRequestStatusPatch,
    CatalogCreate,
    CatalogPatch,
    CatalogRead,
)

# (session, create_body) -> (fulfilling_domain_id, catalog_entry_id, routing_note)
RouteRequestCallback = Callable[
    [AsyncSession, Any],
    Awaitable[tuple[uuid.UUID | None, uuid.UUID | None, str | None]],
]
# (body, requesting_domain, fulfilling_domain_id, catalog_entry_id, routing_note)
#   -> request ORM object to be added to the session
BuildRequestCallback = Callable[
    [Any, Any, uuid.UUID | None, uuid.UUID | None, str | None],
    Any,
]
# (session, request, body)
RequestLifecycleCallback = Callable[[AsyncSession, Any, Any], Awaitable[None]]
# (current_status, target_status); expected to raise to reject a transition
CheckTransitionCallback = Callable[[str, str], None]
# (request, accept_body)
ApplyAcceptFieldsCallback = Callable[[Any, Any], None]
# (session, request, status_body, now)
AfterStatusChangeCallback = Callable[[AsyncSession, Any, Any, datetime], Awaitable[None]]
# (session, request, patch_body) -> True if anything changed (False skips the commit)
ApplyPatchCallback = Callable[[AsyncSession, Any, Any], Awaitable[bool]]
# (session, request, dispute_body, now)
AfterDisputeCallback = Callable[[AsyncSession, Any, Any, datetime], Awaitable[None]]
# (session, request, reroute_body, new_domain_slug)
AfterRerouteCallback = Callable[[AsyncSession, Any, Any, str], Awaitable[None]]


def create_capability_catalog_router(
    get_session: Callable[..., AsyncSession],
    *,
    domain_model: type[Domain] = Domain,
    repo_model: type[ManagedRepo] = ManagedRepo,
    catalog_model: type[CapabilityCatalog] = CapabilityCatalog,
    catalog_create_schema: type[CatalogCreate] = CatalogCreate,
    catalog_patch_schema: type[CatalogPatch] = CatalogPatch,
    catalog_read_schema: type[CatalogRead] = CatalogRead,
) -> APIRouter:
    """Build the router for the ``/capability-catalog/`` endpoints.

    Args:
        get_session: FastAPI dependency providing the ``AsyncSession``.
        domain_model: ORM class used to resolve domain slugs.
        repo_model: ORM class used to resolve repo slugs.
        catalog_model: ORM class for catalog entries.
        catalog_create_schema: request body schema for POST.
        catalog_patch_schema: request body schema for PATCH.
        catalog_read_schema: response schema.

    Returns:
        An ``APIRouter`` (tag ``capability-requests``) with create, list and
        patch endpoints.
    """
    router = APIRouter(tags=["capability-requests"])
    list_response_model = list[catalog_read_schema]

    @router.post(
        "/capability-catalog/",
        response_model=catalog_read_schema,
        status_code=status.HTTP_201_CREATED,
    )
    async def create_catalog_entry(
        body: catalog_create_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Create a catalog entry.

        Raises:
            HTTPException: 404 if the domain or repo slug is unknown; 409 if the
                commit violates an integrity constraint.
        """
        domain = await _resolve_domain(body.domain, session, domain_model)
        repo_id = None
        if body.repo_slug:
            repo = await _resolve_repo(body.repo_slug, session, repo_model)
            repo_id = repo.id
        entry = catalog_model(
            domain_id=domain.id,
            repo_id=repo_id,
            capability_type=body.capability_type,
            title=body.title,
            description=body.description,
            keywords=body.keywords,
        )
        session.add(entry)
        try:
            await session.commit()
        except IntegrityError as err:
            # Only constraint violations map to 409. The message assumes a
            # uniqueness constraint over (domain, type, title) — TODO confirm.
            await session.rollback()
            raise HTTPException(
                status_code=409,
                detail=(
                    f"Catalog entry '{body.title}' for type '{body.capability_type}' "
                    f"already exists in domain '{body.domain}'"
                ),
            ) from err
        except Exception:
            # Any other failure is a genuine error: reset the session and let it
            # propagate instead of misreporting it as a duplicate.
            await session.rollback()
            raise
        await session.refresh(entry)
        return entry

    @router.get("/capability-catalog/", response_model=list_response_model)
    async def list_catalog(
        domain: str | None = Query(None),
        capability_type: str | None = Query(None),
        status_filter: str | None = Query(None, alias="status"),
        session: AsyncSession = Depends(get_session),
    ) -> list[Any]:
        """List catalog entries, newest first.

        ``status`` defaults to ``"active"`` when omitted; ``"all"`` disables the
        status filter; any other value filters on that exact status.
        """
        q = select(catalog_model).order_by(catalog_model.created_at.desc())
        if domain:
            domain_obj = await _resolve_domain(domain, session, domain_model)
            q = q.where(catalog_model.domain_id == domain_obj.id)
        if capability_type:
            q = q.where(catalog_model.capability_type == capability_type)
        if status_filter and status_filter != "all":
            q = q.where(catalog_model.status == status_filter)
        elif not status_filter:
            q = q.where(catalog_model.status == "active")
        result = await session.execute(q)
        return list(result.scalars().all())

    @router.patch("/capability-catalog/{entry_id}", response_model=catalog_read_schema)
    async def patch_catalog_entry(
        entry_id: uuid.UUID,
        body: catalog_patch_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Update the non-``None`` fields of a catalog entry (404 if missing)."""
        entry = await _resolve_catalog_entry(entry_id, session, catalog_model)
        if body.repo_slug is not None:
            repo = await _resolve_repo(body.repo_slug, session, repo_model)
            entry.repo_id = repo.id
        if body.description is not None:
            entry.description = body.description
        if body.keywords is not None:
            entry.keywords = body.keywords
        if body.status is not None:
            entry.status = body.status
        await session.commit()
        await session.refresh(entry)
        return entry

    return router


def create_capability_request_read_router(
    get_session: Callable[..., AsyncSession],
    *,
    domain_model: type[Domain] = Domain,
    request_model: type[CapabilityRequest] = CapabilityRequest,
    request_read_schema: type[CapabilityRequestRead] = CapabilityRequestRead,
) -> APIRouter:
    """Build the router for read-only ``/capability-requests/`` endpoints.

    Args:
        get_session: FastAPI dependency providing the ``AsyncSession``.
        domain_model: ORM class used to resolve domain slugs.
        request_model: ORM class for capability requests.
        request_read_schema: response schema.

    Returns:
        An ``APIRouter`` with list and get-by-id endpoints.
    """
    router = APIRouter(tags=["capability-requests"])
    list_response_model = list[request_read_schema]

    @router.get("/capability-requests/", response_model=list_response_model)
    async def list_requests(
        domain: str | None = Query(
            None,
            description="Filter by requesting or fulfilling domain slug",
        ),
        status_filter: str | None = Query(None, alias="status"),
        capability_type: str | None = Query(None),
        session: AsyncSession = Depends(get_session),
    ) -> list[Any]:
        """List requests newest first, optionally filtered by domain, status, type."""
        q = select(request_model).order_by(request_model.created_at.desc())
        if domain:
            domain_obj = await _resolve_domain(domain, session, domain_model)
            # Match the domain on either side of the request.
            q = q.where(
                (request_model.requesting_domain_id == domain_obj.id)
                | (request_model.fulfilling_domain_id == domain_obj.id)
            )
        if status_filter:
            q = q.where(request_model.status == status_filter)
        if capability_type:
            q = q.where(request_model.capability_type == capability_type)
        result = await session.execute(q)
        return list(result.scalars().all())

    @router.get("/capability-requests/{request_id}", response_model=request_read_schema)
    async def get_request(
        request_id: uuid.UUID,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Return a single request by id (404 if missing)."""
        return await _get_request_or_404(request_id, session, request_model)

    return router


def create_capability_request_write_router(
    get_session: Callable[..., AsyncSession],
    *,
    domain_model: type[Domain] = Domain,
    catalog_model: type[CapabilityCatalog] = CapabilityCatalog,
    request_model: type[CapabilityRequest] = CapabilityRequest,
    request_create_schema: type[CapabilityRequestCreate] = CapabilityRequestCreate,
    request_accept_schema: type[CapabilityRequestAccept] = CapabilityRequestAccept,
    request_patch_schema: type[CapabilityRequestPatch] = CapabilityRequestPatch,
    request_status_patch_schema: type[CapabilityRequestStatusPatch] = CapabilityRequestStatusPatch,
    request_dispute_schema: type[CapabilityRequestDispute] = CapabilityRequestDispute,
    request_reroute_schema: type[CapabilityRequestReroute] = CapabilityRequestReroute,
    request_read_schema: type[CapabilityRequestRead] = CapabilityRequestRead,
    route_request: RouteRequestCallback | None = None,
    build_request: BuildRequestCallback | None = None,
    on_request_persisted: RequestLifecycleCallback | None = None,
    check_transition: CheckTransitionCallback | None = None,
    apply_accept_fields: ApplyAcceptFieldsCallback | None = None,
    after_accept: RequestLifecycleCallback | None = None,
    after_status_change: AfterStatusChangeCallback | None = None,
    apply_patch: ApplyPatchCallback | None = None,
    after_dispute: AfterDisputeCallback | None = None,
    after_reroute: AfterRerouteCallback | None = None,
    include_reroute: bool = True,
) -> APIRouter:
    """Build the router for capability-request mutations.

    Endpoints: create, accept, status patch, field patch, dispute and, when
    ``include_reroute`` is true, reroute. Every hook is optional; a ``None``
    hook means the built-in behavior shown in each endpoint is used.

    Hooks:
        route_request: replaces ``_default_route_request`` on create.
        build_request: replaces the default ``request_model(...)`` construction.
        on_request_persisted: awaited after ``session.add`` and before commit.
        check_transition: called as ``(current_status, target_status)`` before
            accept/status/dispute; no transition check is made when ``None``.
            (Reroute uses its own explicit status check instead.)
        apply_accept_fields, after_accept: extra accept work, before commit.
        after_status_change: awaited after a status patch, before commit.
        apply_patch: replaces the default field patch; returning False returns
            the request without committing.
        after_dispute, after_reroute: awaited before commit in those endpoints.

    Returns:
        The configured ``APIRouter``.
    """
    router = APIRouter(tags=["capability-requests"])

    @router.post(
        "/capability-requests/",
        response_model=request_read_schema,
        status_code=status.HTTP_201_CREATED,
    )
    async def create_request(
        body: request_create_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Create a request, routing it to a fulfilling domain when possible."""
        requesting_domain = await _resolve_domain(body.requesting_domain, session, domain_model)
        if route_request is not None:
            fulfilling_domain_id, catalog_entry_id, routing_note = await route_request(session, body)
        else:
            fulfilling_domain_id, catalog_entry_id, routing_note = await _default_route_request(
                session,
                body,
                catalog_model,
            )
        if build_request is not None:
            req = build_request(
                body,
                requesting_domain,
                fulfilling_domain_id,
                catalog_entry_id,
                routing_note,
            )
        else:
            req = request_model(
                title=body.title,
                description=body.description,
                capability_type=body.capability_type,
                priority=body.priority,
                requesting_domain_id=requesting_domain.id,
                requesting_agent=body.requesting_agent,
                request_context=body.request_context,
                fulfilling_domain_id=fulfilling_domain_id,
                catalog_entry_id=catalog_entry_id,
                routing_note=routing_note,
            )
        session.add(req)
        if on_request_persisted is not None:
            await on_request_persisted(session, req, body)
        await session.commit()
        await session.refresh(req)
        return req

    @router.post("/capability-requests/{request_id}/accept", response_model=request_read_schema)
    async def accept_request(
        request_id: uuid.UUID,
        body: request_accept_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Mark a request accepted and record the fulfilling agent."""
        req = await _get_request_or_404(request_id, session, request_model)
        if check_transition is not None:
            check_transition(req.status, "accepted")
        now = datetime.now(tz=timezone.utc)
        req.status = "accepted"
        req.fulfilling_agent = body.fulfilling_agent
        # Custom accept schemas may omit this field.
        if hasattr(body, "fulfillment_context"):
            req.fulfillment_context = body.fulfillment_context
        req.accepted_at = now
        if apply_accept_fields is not None:
            apply_accept_fields(req, body)
        if after_accept is not None:
            await after_accept(session, req, body)
        await session.commit()
        await session.refresh(req)
        return req

    @router.patch("/capability-requests/{request_id}/status", response_model=request_read_schema)
    async def patch_request_status(
        request_id: uuid.UUID,
        body: request_status_patch_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Set the request status; stamps ``completed_at`` on ``"completed"``."""
        req = await _get_request_or_404(request_id, session, request_model)
        if check_transition is not None:
            check_transition(req.status, body.status)
        req.status = body.status
        if body.note:
            req.resolution_note = body.note
        now = datetime.now(tz=timezone.utc)
        if body.status == "completed":
            req.completed_at = now
        if after_status_change is not None:
            await after_status_change(session, req, body, now)
        await session.commit()
        await session.refresh(req)
        return req

    @router.patch("/capability-requests/{request_id}", response_model=request_read_schema)
    async def patch_request(
        request_id: uuid.UUID,
        body: request_patch_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Update the non-``None`` request fields (or delegate to ``apply_patch``)."""
        req = await _get_request_or_404(request_id, session, request_model)
        if apply_patch is not None:
            changed = await apply_patch(session, req, body)
            if not changed:
                return req
        else:
            if body.catalog_entry_id is not None:
                catalog_entry = await _resolve_catalog_entry(body.catalog_entry_id, session, catalog_model)
                # Re-pointing at a catalog entry also re-routes to its domain.
                req.catalog_entry_id = catalog_entry.id
                req.fulfilling_domain_id = catalog_entry.domain_id
            if body.priority is not None:
                req.priority = body.priority
            if body.request_context is not None:
                req.request_context = body.request_context
            if body.fulfillment_context is not None:
                req.fulfillment_context = body.fulfillment_context
        await session.commit()
        await session.refresh(req)
        return req

    @router.post("/capability-requests/{request_id}/dispute", response_model=request_read_schema)
    async def dispute_request(
        request_id: uuid.UUID,
        body: request_dispute_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Move a request to ``routing_disputed`` and record the dispute details."""
        req = await _get_request_or_404(request_id, session, request_model)
        if check_transition is not None:
            check_transition(req.status, "routing_disputed")
        now = datetime.now(tz=timezone.utc)
        req.status = "routing_disputed"
        req.dispute_reason = body.reason
        req.disputed_by = body.disputed_by
        req.dispute_suggested_domain = body.suggested_domain
        req.disputed_at = now
        if after_dispute is not None:
            await after_dispute(session, req, body, now)
        await session.commit()
        await session.refresh(req)
        return req

    if include_reroute:

        @router.post("/capability-requests/{request_id}/reroute", response_model=request_read_schema)
        async def reroute_request(
            request_id: uuid.UUID,
            body: request_reroute_schema,
            session: AsyncSession = Depends(get_session),
        ) -> Any:
            """Resolve a dispute by routing to a new catalog entry or domain.

            Clears the dispute fields, resets status to ``"requested"`` and
            appends a line to ``routing_note``.

            Raises:
                HTTPException: 422 if the request is not ``routing_disputed`` or
                    neither target is given; 404 if the target is unknown.
            """
            req = await _get_request_or_404(request_id, session, request_model)
            if req.status != "routing_disputed":
                raise HTTPException(
                    status_code=422,
                    detail=(
                        f"Cannot reroute from status '{req.status}'. "
                        "Only 'routing_disputed' requests can be rerouted."
                    ),
                )
            if body.catalog_entry_id is None and body.domain is None:
                raise HTTPException(
                    status_code=422,
                    detail="Either catalog_entry_id or domain must be provided.",
                )
            if body.catalog_entry_id is not None:
                # An explicit catalog entry takes precedence over a bare domain.
                entry = await _resolve_catalog_entry(body.catalog_entry_id, session, catalog_model)
                req.catalog_entry_id = entry.id
                req.fulfilling_domain_id = entry.domain_id
                domain_obj = await session.get(domain_model, entry.domain_id)
                new_domain_slug = domain_obj.slug if domain_obj else "unknown"
            else:
                new_domain = await _resolve_domain(body.domain, session, domain_model)
                req.fulfilling_domain_id = new_domain.id
                new_domain_slug = new_domain.slug
            req.dispute_reason = None
            req.disputed_by = None
            req.dispute_suggested_domain = None
            req.disputed_at = None
            req.status = "requested"
            reroute_entry = f"re-routed by {body.rerouted_by} → {new_domain_slug}: {body.note}"
            req.routing_note = (
                (req.routing_note + "\n" + reroute_entry) if req.routing_note else reroute_entry
            )
            if after_reroute is not None:
                await after_reroute(session, req, body, new_domain_slug)
            await session.commit()
            await session.refresh(req)
            return req

    return router


def create_capabilities_router(get_session: Callable[..., AsyncSession]) -> APIRouter:
    """Compose the catalog, read and write routers with their default settings."""
    router = APIRouter(tags=["capability-requests"])
    router.include_router(create_capability_catalog_router(get_session))
    router.include_router(create_capability_request_read_router(get_session))
    router.include_router(create_capability_request_write_router(get_session))
    return router


async def _resolve_domain(
    slug: str,
    session: AsyncSession,
    domain_model: type[Domain],
) -> Any:
    """Return the domain with ``slug`` or raise HTTP 404."""
    result = await session.execute(select(domain_model).where(domain_model.slug == slug))
    domain = result.scalar_one_or_none()
    if domain is None:
        raise HTTPException(status_code=404, detail=f"Domain '{slug}' not found")
    return domain


async def _resolve_repo(
    slug: str,
    session: AsyncSession,
    repo_model: type[ManagedRepo],
) -> Any:
    """Return the managed repo with ``slug`` or raise HTTP 404."""
    result = await session.execute(select(repo_model).where(repo_model.slug == slug))
    repo = result.scalar_one_or_none()
    if repo is None:
        raise HTTPException(status_code=404, detail=f"Repo '{slug}' not found")
    return repo


async def _default_route_request(
    session: AsyncSession,
    body: CapabilityRequestCreate,
    catalog_model: type[CapabilityCatalog],
) -> tuple[uuid.UUID | None, uuid.UUID | None, str | None]:
    """Pick ``(fulfilling_domain_id, catalog_entry_id, routing_note)`` for a new request.

    An explicit ``body.catalog_entry_id`` wins (404 if unknown). Otherwise the
    newest active catalog entry with the same ``capability_type`` is used.
    If nothing matches, all three values are ``None``.
    """
    catalog_entry_id = body.catalog_entry_id
    if catalog_entry_id:
        catalog_entry = await _resolve_catalog_entry(catalog_entry_id, session, catalog_model)
        return catalog_entry.domain_id, catalog_entry.id, "Routed by explicit catalog entry."
    catalog_entry = await _find_catalog_route(body.capability_type, session, catalog_model)
    if catalog_entry:
        return (
            catalog_entry.domain_id,
            catalog_entry.id,
            "Routed by first active catalog match for capability_type.",
        )
    return None, None, None


async def _resolve_catalog_entry(
    entry_id: uuid.UUID,
    session: AsyncSession,
    catalog_model: type[CapabilityCatalog] = CapabilityCatalog,
) -> Any:
    """Return the catalog entry with primary key ``entry_id`` or raise HTTP 404."""
    entry = await session.get(catalog_model, entry_id)
    if entry is None:
        raise HTTPException(status_code=404, detail=f"Catalog entry '{entry_id}' not found")
    return entry


async def _find_catalog_route(
    capability_type: str,
    session: AsyncSession,
    catalog_model: type[CapabilityCatalog] = CapabilityCatalog,
) -> Any | None:
    """Return the most recently created active entry for ``capability_type``, or None."""
    result = await session.execute(
        select(catalog_model)
        .where(catalog_model.capability_type == capability_type)
        .where(catalog_model.status == "active")
        .order_by(catalog_model.created_at.desc())
    )
    return result.scalars().first()


async def _get_request_or_404(
    request_id: uuid.UUID,
    session: AsyncSession,
    request_model: type[CapabilityRequest] = CapabilityRequest,
) -> Any:
    """Return the capability request with primary key ``request_id`` or raise HTTP 404."""
    req = await session.get(request_model, request_id)
    if req is None:
        raise HTTPException(status_code=404, detail=f"Capability request '{request_id}' not found")
    return req