diff --git a/api/routers/capability_requests.py b/api/routers/capability_requests.py index 31b80ac..4f8f025 100644 --- a/api/routers/capability_requests.py +++ b/api/routers/capability_requests.py @@ -2,7 +2,7 @@ import re import uuid from datetime import datetime, timezone -from fastapi import APIRouter, Depends, HTTPException, Query, status +from fastapi import Depends, HTTPException, Query, status from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -15,9 +15,6 @@ from api.models.domain import Domain from api.models.managed_repo import ManagedRepo from api.models.task import Task from api.schemas.capability_request import ( - CatalogCreate, - CatalogPatch, - CatalogRead, CapabilityRequestAccept, CapabilityRequestCreate, CapabilityRequestDispute, @@ -26,90 +23,15 @@ from api.schemas.capability_request import ( CapabilityRequestReroute, CapabilityRequestStatusPatch, ) - -router = APIRouter(tags=["capability-requests"]) - -# --------------------------------------------------------------------------- -# Capability Catalog endpoints -# --------------------------------------------------------------------------- - -@router.post("/capability-catalog/", response_model=CatalogRead, status_code=status.HTTP_201_CREATED) -async def create_catalog_entry( - body: CatalogCreate, - session: AsyncSession = Depends(get_session), -) -> CapabilityCatalog: - domain = await _resolve_domain(body.domain, session) - - repo_id = None - if body.repo_slug: - repo = await _resolve_repo(body.repo_slug, session) - repo_id = repo.id - - entry = CapabilityCatalog( - domain_id=domain.id, - repo_id=repo_id, - capability_type=body.capability_type, - title=body.title, - description=body.description, - keywords=body.keywords, - ) - session.add(entry) - try: - await session.commit() - except Exception: - await session.rollback() - raise HTTPException( - status_code=409, - detail=f"Catalog entry '{body.title}' for type '{body.capability_type}' already exists in domain '{body.domain}'", - ) - await session.refresh(entry) - return entry +from hub_core.routers.capabilities import create_capability_catalog_router -@router.patch("/capability-catalog/{entry_id}", response_model=CatalogRead) -async def patch_catalog_entry( - entry_id: uuid.UUID, - body: CatalogPatch, - session: AsyncSession = Depends(get_session), -) -> CapabilityCatalog: - entry = await session.get(CapabilityCatalog, entry_id) - if entry is None: - raise HTTPException(status_code=404, detail=f"Catalog entry '{entry_id}' not found") - - if body.repo_slug is not None: - repo = await _resolve_repo(body.repo_slug, session) - entry.repo_id = repo.id - if body.description is not None: - entry.description = body.description - if body.keywords is not None: - entry.keywords = body.keywords - if body.status is not None: - entry.status = body.status - - await session.commit() - await session.refresh(entry) - return entry - - -@router.get("/capability-catalog/", response_model=list[CatalogRead]) -async def list_catalog( - domain: str | None = Query(None), - capability_type: str | None = Query(None), - status_filter: str | None = Query(None, alias="status"), - session: AsyncSession = Depends(get_session), -) -> list[CapabilityCatalog]: - q = select(CapabilityCatalog).order_by(CapabilityCatalog.created_at.desc()) - if domain: - d = await _resolve_domain(domain, session) - q = q.where(CapabilityCatalog.domain_id == d.id) - if capability_type: - q = q.where(CapabilityCatalog.capability_type == capability_type) - if status_filter and status_filter != "all": - q = q.where(CapabilityCatalog.status == status_filter) - elif not status_filter: - q = q.where(CapabilityCatalog.status == "active") - result = await session.execute(q) - return list(result.scalars().all()) +router = create_capability_catalog_router( + get_session, + domain_model=Domain, + repo_model=ManagedRepo, + catalog_model=CapabilityCatalog, +) # --------------------------------------------------------------------------- @@ -571,14 +493,6 @@ async def _resolve_domain(slug: str, session: AsyncSession) -> Domain: return domain -async def _resolve_repo(slug: str, session: AsyncSession) -> ManagedRepo: - result = await session.execute(select(ManagedRepo).where(ManagedRepo.slug == slug)) - repo = result.scalar_one_or_none() - if repo is None: - raise HTTPException(status_code=404, detail=f"Repo '{slug}' not found") - return repo - - async def _get_request_or_404(request_id: uuid.UUID, session: AsyncSession) -> CapabilityRequest: req = await session.get(CapabilityRequest, request_id) if req is None: