Files
hub-core/hub_core/routers/capabilities.py
tegwick af28282861 feat(capabilities): add write router factory and MCP composition (HUB-WP-0002)
Add create_capability_request_write_router with host workflow callbacks,
CapabilityRequestReroute schema, HubCoreMCPServer.attach_to() with CORE_TOOL_NAMES
exclude filtering, tests, and mark HUB-WP-0002 finished.
2026-06-22 19:52:22 +02:00

499 lines
19 KiB
Python

import uuid
from collections.abc import Callable
from datetime import datetime, timezone
from typing import Any, Awaitable

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from hub_core.models.capability_catalog import CapabilityCatalog
from hub_core.models.capability_request import CapabilityRequest
from hub_core.models.domain import Domain
from hub_core.models.managed_repo import ManagedRepo
from hub_core.schemas.capability import (
    CapabilityRequestAccept,
    CapabilityRequestCreate,
    CapabilityRequestDispute,
    CapabilityRequestPatch,
    CapabilityRequestReroute,
    CapabilityRequestRead,
    CapabilityRequestStatusPatch,
    CatalogCreate,
    CatalogPatch,
    CatalogRead,
)
# Signatures of the optional host hooks accepted by
# create_capability_request_write_router. Arguments typed ``Any`` are the
# request ORM row and/or the parsed request body; their concrete classes are
# whatever the host passes in as request_model / request_*_schema.

# (session, create_body) -> (fulfilling_domain_id, catalog_entry_id, routing_note)
RouteRequestCallback = Callable[
    [AsyncSession, Any],
    Awaitable[
        tuple[
            uuid.UUID | None,
            uuid.UUID | None,
            str | None,
        ]
    ],
]

# (create_body, requesting_domain, fulfilling_domain_id, catalog_entry_id,
#  routing_note) -> request row to add to the session (synchronous)
BuildRequestCallback = Callable[
    [
        Any,
        Any,
        uuid.UUID | None,
        uuid.UUID | None,
        str | None,
    ],
    Any,
]

# (session, request_row, body); awaited before the commit.
RequestLifecycleCallback = Callable[
    [AsyncSession, Any, Any],
    Awaitable[None],
]

# (current_status, target_status); the router ignores the return value, so a
# rejected transition is presumably signalled by raising.
CheckTransitionCallback = Callable[[str, str], None]

# (request_row, accept_body); synchronous field setter.
ApplyAcceptFieldsCallback = Callable[[Any, Any], None]

# (session, request_row, status_body, timestamp); awaited before the commit.
AfterStatusChangeCallback = Callable[
    [AsyncSession, Any, Any, datetime],
    Awaitable[None],
]

# (session, request_row, patch_body) -> whether anything changed.
ApplyPatchCallback = Callable[
    [AsyncSession, Any, Any],
    Awaitable[bool],
]

# (session, request_row, dispute_body, timestamp); awaited before the commit.
AfterDisputeCallback = Callable[
    [AsyncSession, Any, Any, datetime],
    Awaitable[None],
]

# (session, request_row, reroute_body, new_domain_slug); awaited before commit.
AfterRerouteCallback = Callable[
    [AsyncSession, Any, Any, str],
    Awaitable[None],
]
def create_capability_catalog_router(
    get_session: Callable[..., AsyncSession],
    *,
    domain_model: type[Domain] = Domain,
    repo_model: type[ManagedRepo] = ManagedRepo,
    catalog_model: type[CapabilityCatalog] = CapabilityCatalog,
    catalog_create_schema: type[CatalogCreate] = CatalogCreate,
    catalog_patch_schema: type[CatalogPatch] = CatalogPatch,
    catalog_read_schema: type[CatalogRead] = CatalogRead,
) -> APIRouter:
    """Build the router exposing the ``/capability-catalog/`` endpoints.

    Args:
        get_session: FastAPI dependency that provides the ``AsyncSession``.
        domain_model: ORM class used to resolve domain slugs.
        repo_model: ORM class used to resolve repo slugs.
        catalog_model: ORM class for catalog entries.
        catalog_create_schema: Body schema for ``POST``.
        catalog_patch_schema: Body schema for ``PATCH``.
        catalog_read_schema: Response schema for all routes.

    Returns:
        An ``APIRouter`` with create (POST), list (GET) and patch (PATCH)
        routes for catalog entries.
    """
    router = APIRouter(tags=["capability-requests"])
    list_response_model = list[catalog_read_schema]

    @router.post("/capability-catalog/", response_model=catalog_read_schema, status_code=status.HTTP_201_CREATED)
    async def create_catalog_entry(
        body: catalog_create_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Create a catalog entry; 404 on unknown domain/repo, 409 on conflict."""
        domain = await _resolve_domain(body.domain, session, domain_model)
        repo_id = None
        if body.repo_slug:
            repo = await _resolve_repo(body.repo_slug, session, repo_model)
            repo_id = repo.id
        entry = catalog_model(
            domain_id=domain.id,
            repo_id=repo_id,
            capability_type=body.capability_type,
            title=body.title,
            description=body.description,
            keywords=body.keywords,
        )
        session.add(entry)
        try:
            await session.commit()
        except IntegrityError as err:
            # Only constraint violations are reported as a 409 conflict
            # (presumably a uniqueness constraint on domain/type/title —
            # confirm against the model). Connection errors, timeouts and
            # other failures propagate instead of being misreported as
            # duplicates.
            await session.rollback()
            raise HTTPException(
                status_code=409,
                detail=(
                    f"Catalog entry '{body.title}' for type '{body.capability_type}' "
                    f"already exists in domain '{body.domain}'"
                ),
            ) from err
        await session.refresh(entry)
        return entry

    @router.get("/capability-catalog/", response_model=list_response_model)
    async def list_catalog(
        domain: str | None = Query(None),
        capability_type: str | None = Query(None),
        status_filter: str | None = Query(None, alias="status"),
        session: AsyncSession = Depends(get_session),
    ) -> list[Any]:
        """List catalog entries, newest first.

        With no ``status`` query parameter, only ``"active"`` entries are
        returned; ``status=all`` disables status filtering entirely.
        """
        q = select(catalog_model).order_by(catalog_model.created_at.desc())
        if domain:
            domain_obj = await _resolve_domain(domain, session, domain_model)
            q = q.where(catalog_model.domain_id == domain_obj.id)
        if capability_type:
            q = q.where(catalog_model.capability_type == capability_type)
        if status_filter and status_filter != "all":
            q = q.where(catalog_model.status == status_filter)
        elif not status_filter:
            q = q.where(catalog_model.status == "active")
        result = await session.execute(q)
        return list(result.scalars().all())

    @router.patch("/capability-catalog/{entry_id}", response_model=catalog_read_schema)
    async def patch_catalog_entry(
        entry_id: uuid.UUID,
        body: catalog_patch_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Update the non-``None`` fields of a catalog entry; 404 if missing."""
        entry = await session.get(catalog_model, entry_id)
        if entry is None:
            raise HTTPException(status_code=404, detail=f"Catalog entry '{entry_id}' not found")
        if body.repo_slug is not None:
            repo = await _resolve_repo(body.repo_slug, session, repo_model)
            entry.repo_id = repo.id
        if body.description is not None:
            entry.description = body.description
        if body.keywords is not None:
            entry.keywords = body.keywords
        if body.status is not None:
            entry.status = body.status
        await session.commit()
        await session.refresh(entry)
        return entry

    return router
def create_capability_request_read_router(
    get_session: Callable[..., AsyncSession],
    *,
    domain_model: type[Domain] = Domain,
    request_model: type[CapabilityRequest] = CapabilityRequest,
    request_read_schema: type[CapabilityRequestRead] = CapabilityRequestRead,
) -> APIRouter:
    """Build the read-only ``/capability-requests/`` routes.

    Args:
        get_session: FastAPI dependency that provides the ``AsyncSession``.
        domain_model: ORM class used to resolve the ``domain`` filter slug.
        request_model: ORM class for capability requests.
        request_read_schema: Response schema for both routes.

    Returns:
        An ``APIRouter`` with list and get-by-id routes.
    """
    router = APIRouter(tags=["capability-requests"])
    list_response_model = list[request_read_schema]

    @router.get("/capability-requests/", response_model=list_response_model)
    async def list_requests(
        domain: str | None = Query(
            None,
            description="Filter by requesting or fulfilling domain slug",
        ),
        status_filter: str | None = Query(None, alias="status"),
        capability_type: str | None = Query(None),
        session: AsyncSession = Depends(get_session),
    ) -> list[Any]:
        """List requests newest first, optionally filtered.

        ``domain`` matches requests where the domain is either the requester
        or the fulfiller.
        """
        q = select(request_model).order_by(request_model.created_at.desc())
        if domain:
            domain_obj = await _resolve_domain(domain, session, domain_model)
            q = q.where(
                (request_model.requesting_domain_id == domain_obj.id)
                | (request_model.fulfilling_domain_id == domain_obj.id)
            )
        if status_filter:
            q = q.where(request_model.status == status_filter)
        if capability_type:
            q = q.where(request_model.capability_type == capability_type)
        result = await session.execute(q)
        return list(result.scalars().all())

    @router.get("/capability-requests/{request_id}", response_model=request_read_schema)
    async def get_request(
        request_id: uuid.UUID,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Return a single request; 404 if it does not exist."""
        # Shared lookup helper keeps the 404 message identical to the
        # write routes.
        return await _get_request_or_404(request_id, session, request_model)

    return router
def create_capability_request_write_router(
    get_session: Callable[..., AsyncSession],
    *,
    domain_model: type[Domain] = Domain,
    catalog_model: type[CapabilityCatalog] = CapabilityCatalog,
    request_model: type[CapabilityRequest] = CapabilityRequest,
    request_create_schema: type[CapabilityRequestCreate] = CapabilityRequestCreate,
    request_accept_schema: type[CapabilityRequestAccept] = CapabilityRequestAccept,
    request_patch_schema: type[CapabilityRequestPatch] = CapabilityRequestPatch,
    request_status_patch_schema: type[CapabilityRequestStatusPatch] = CapabilityRequestStatusPatch,
    request_dispute_schema: type[CapabilityRequestDispute] = CapabilityRequestDispute,
    request_reroute_schema: type[CapabilityRequestReroute] = CapabilityRequestReroute,
    request_read_schema: type[CapabilityRequestRead] = CapabilityRequestRead,
    route_request: RouteRequestCallback | None = None,
    build_request: BuildRequestCallback | None = None,
    on_request_persisted: RequestLifecycleCallback | None = None,
    check_transition: CheckTransitionCallback | None = None,
    apply_accept_fields: ApplyAcceptFieldsCallback | None = None,
    after_accept: RequestLifecycleCallback | None = None,
    after_status_change: AfterStatusChangeCallback | None = None,
    apply_patch: ApplyPatchCallback | None = None,
    after_dispute: AfterDisputeCallback | None = None,
    after_reroute: AfterRerouteCallback | None = None,
    include_reroute: bool = True,
) -> APIRouter:
    """Build the mutating ``/capability-requests/`` routes.

    The router implements a default workflow: create, accept, status change,
    patch, dispute and (optionally) reroute. Each keyword callback lets a host
    override or extend one step without re-implementing the route.

    Args:
        get_session: FastAPI dependency that provides the ``AsyncSession``.
        domain_model, catalog_model, request_model: ORM classes to use.
        request_*_schema: Body/response schema classes for each route.
        route_request: Replaces the default catalog-based routing on create.
        build_request: Constructs the new request row instead of calling
            ``request_model(...)`` with the default fields.
        on_request_persisted: Awaited after the new row is added to the
            session, before the commit.
        check_transition: Called as ``(current_status, target_status)``
            before accept, status change and dispute (not reroute). Its
            return value is ignored, so it presumably rejects by raising.
        apply_accept_fields: Sets extra fields on the row during accept.
        after_accept: Awaited after accept fields are applied, before commit.
        after_status_change: Awaited with the change timestamp, before commit.
        apply_patch: Replaces the default patch logic. Returning ``False``
            skips the commit and returns the row unchanged.
        after_dispute: Awaited with the dispute timestamp, before commit.
        after_reroute: Awaited with the new fulfilling domain slug, before
            commit.
        include_reroute: When ``False``, the ``/reroute`` route is not
            registered.

    Returns:
        The configured ``APIRouter``.
    """
    router = APIRouter(tags=["capability-requests"])

    @router.post(
        "/capability-requests/",
        response_model=request_read_schema,
        status_code=status.HTTP_201_CREATED,
    )
    async def create_request(
        body: request_create_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Create a request, routing it to a fulfilling domain if possible."""
        requesting_domain = await _resolve_domain(body.requesting_domain, session, domain_model)
        if route_request is not None:
            fulfilling_domain_id, catalog_entry_id, routing_note = await route_request(session, body)
        else:
            fulfilling_domain_id, catalog_entry_id, routing_note = await _default_route_request(
                session,
                body,
                catalog_model,
            )
        if build_request is not None:
            req = build_request(
                body,
                requesting_domain,
                fulfilling_domain_id,
                catalog_entry_id,
                routing_note,
            )
        else:
            req = request_model(
                title=body.title,
                description=body.description,
                capability_type=body.capability_type,
                priority=body.priority,
                requesting_domain_id=requesting_domain.id,
                requesting_agent=body.requesting_agent,
                request_context=body.request_context,
                fulfilling_domain_id=fulfilling_domain_id,
                catalog_entry_id=catalog_entry_id,
                routing_note=routing_note,
            )
        session.add(req)
        if on_request_persisted is not None:
            await on_request_persisted(session, req, body)
        await session.commit()
        await session.refresh(req)
        return req

    @router.post("/capability-requests/{request_id}/accept", response_model=request_read_schema)
    async def accept_request(
        request_id: uuid.UUID,
        body: request_accept_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Mark a request ``accepted`` and record the fulfilling agent."""
        req = await _get_request_or_404(request_id, session, request_model)
        if check_transition is not None:
            check_transition(req.status, "accepted")
        now = datetime.now(tz=timezone.utc)
        req.status = "accepted"
        req.fulfilling_agent = body.fulfilling_agent
        # Custom accept schemas may omit this field.
        if hasattr(body, "fulfillment_context"):
            req.fulfillment_context = body.fulfillment_context
        req.accepted_at = now
        if apply_accept_fields is not None:
            apply_accept_fields(req, body)
        if after_accept is not None:
            await after_accept(session, req, body)
        await session.commit()
        await session.refresh(req)
        return req

    @router.patch("/capability-requests/{request_id}/status", response_model=request_read_schema)
    async def patch_request_status(
        request_id: uuid.UUID,
        body: request_status_patch_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Set a new status; stamps ``completed_at`` when it is ``completed``."""
        req = await _get_request_or_404(request_id, session, request_model)
        if check_transition is not None:
            check_transition(req.status, body.status)
        req.status = body.status
        if body.note:
            req.resolution_note = body.note
        now = datetime.now(tz=timezone.utc)
        if body.status == "completed":
            req.completed_at = now
        if after_status_change is not None:
            await after_status_change(session, req, body, now)
        await session.commit()
        await session.refresh(req)
        return req

    @router.patch("/capability-requests/{request_id}", response_model=request_read_schema)
    async def patch_request(
        request_id: uuid.UUID,
        body: request_patch_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Update editable fields; a new catalog entry also re-targets the domain."""
        req = await _get_request_or_404(request_id, session, request_model)
        if apply_patch is not None:
            changed = await apply_patch(session, req, body)
            if not changed:
                # Nothing to persist: skip the commit/refresh round trip.
                return req
        else:
            if body.catalog_entry_id is not None:
                catalog_entry = await _resolve_catalog_entry(body.catalog_entry_id, session, catalog_model)
                req.catalog_entry_id = catalog_entry.id
                req.fulfilling_domain_id = catalog_entry.domain_id
            if body.priority is not None:
                req.priority = body.priority
            if body.request_context is not None:
                req.request_context = body.request_context
            if body.fulfillment_context is not None:
                req.fulfillment_context = body.fulfillment_context
        await session.commit()
        await session.refresh(req)
        return req

    @router.post("/capability-requests/{request_id}/dispute", response_model=request_read_schema)
    async def dispute_request(
        request_id: uuid.UUID,
        body: request_dispute_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        """Flag the current routing as wrong and record who disputed it and why."""
        req = await _get_request_or_404(request_id, session, request_model)
        if check_transition is not None:
            check_transition(req.status, "routing_disputed")
        now = datetime.now(tz=timezone.utc)
        req.status = "routing_disputed"
        req.dispute_reason = body.reason
        req.disputed_by = body.disputed_by
        req.dispute_suggested_domain = body.suggested_domain
        req.disputed_at = now
        if after_dispute is not None:
            await after_dispute(session, req, body, now)
        await session.commit()
        await session.refresh(req)
        return req

    if include_reroute:
        @router.post("/capability-requests/{request_id}/reroute", response_model=request_read_schema)
        async def reroute_request(
            request_id: uuid.UUID,
            body: request_reroute_schema,
            session: AsyncSession = Depends(get_session),
        ) -> Any:
            """Resolve a routing dispute by pointing the request at a new domain.

            Only ``routing_disputed`` requests qualify (422 otherwise); this
            check is hard-coded and does not go through ``check_transition``.
            The dispute fields are cleared, status returns to ``requested``
            and a line is appended to ``routing_note``.
            """
            req = await _get_request_or_404(request_id, session, request_model)
            if req.status != "routing_disputed":
                raise HTTPException(
                    status_code=422,
                    detail=(
                        f"Cannot reroute from status '{req.status}'. "
                        "Only 'routing_disputed' requests can be rerouted."
                    ),
                )
            if body.catalog_entry_id is None and body.domain is None:
                raise HTTPException(
                    status_code=422,
                    detail="Either catalog_entry_id or domain must be provided.",
                )
            if body.catalog_entry_id is not None:
                # A catalog entry wins over a bare domain and fixes both fields.
                entry = await _resolve_catalog_entry(body.catalog_entry_id, session, catalog_model)
                req.catalog_entry_id = entry.id
                req.fulfilling_domain_id = entry.domain_id
                domain_obj = await session.get(domain_model, entry.domain_id)
                new_domain_slug = domain_obj.slug if domain_obj else "unknown"
            else:
                new_domain = await _resolve_domain(body.domain, session, domain_model)
                req.fulfilling_domain_id = new_domain.id
                new_domain_slug = new_domain.slug
            req.dispute_reason = None
            req.disputed_by = None
            req.dispute_suggested_domain = None
            req.disputed_at = None
            req.status = "requested"
            # Previously the agent and the domain slug were concatenated with
            # no separator ("aliceplatform"); keep them visibly distinct.
            reroute_entry = f"re-routed by {body.rerouted_by} → {new_domain_slug}: {body.note}"
            # Append rather than overwrite so the routing history is kept.
            req.routing_note = (
                (req.routing_note + "\n" + reroute_entry) if req.routing_note else reroute_entry
            )
            if after_reroute is not None:
                await after_reroute(session, req, body, new_domain_slug)
            await session.commit()
            await session.refresh(req)
            return req

    return router
def create_capabilities_router(get_session: Callable[..., AsyncSession]) -> APIRouter:
    """Compose the catalog, read and write routers with default settings.

    Args:
        get_session: FastAPI dependency that provides the ``AsyncSession``;
            it is forwarded to every sub-router factory.

    Returns:
        A single ``APIRouter`` that includes all capability routes, registered
        in catalog → read → write order.
    """
    combined = APIRouter(tags=["capability-requests"])
    factories = (
        create_capability_catalog_router,
        create_capability_request_read_router,
        create_capability_request_write_router,
    )
    for factory in factories:
        combined.include_router(factory(get_session))
    return combined
async def _resolve_domain(
    slug: str,
    session: AsyncSession,
    domain_model: type[Domain],
) -> Any:
    """Look up a domain row by its slug.

    Raises:
        HTTPException: 404 when no domain has that slug.
    """
    stmt = select(domain_model).where(domain_model.slug == slug)
    found = (await session.execute(stmt)).scalar_one_or_none()
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail=f"Domain '{slug}' not found")
async def _resolve_repo(
    slug: str,
    session: AsyncSession,
    repo_model: type[ManagedRepo],
) -> Any:
    """Look up a managed repo row by its slug.

    Raises:
        HTTPException: 404 when no repo has that slug.
    """
    stmt = select(repo_model).where(repo_model.slug == slug)
    found = (await session.execute(stmt)).scalar_one_or_none()
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail=f"Repo '{slug}' not found")
async def _default_route_request(
    session: AsyncSession,
    body: CapabilityRequestCreate,
    catalog_model: type[CapabilityCatalog],
) -> tuple[uuid.UUID | None, uuid.UUID | None, str | None]:
    """Pick a fulfilling domain for a new request using the catalog.

    An explicit ``body.catalog_entry_id`` takes priority; it 404s if the entry
    is missing. Otherwise the entry chosen by ``_find_catalog_route`` for the
    request's ``capability_type`` is used.

    Returns:
        ``(fulfilling_domain_id, catalog_entry_id, routing_note)``, or
        ``(None, None, None)`` when nothing matches.
    """
    explicit_id = body.catalog_entry_id
    if explicit_id:
        pinned = await _resolve_catalog_entry(explicit_id, session, catalog_model)
        return pinned.domain_id, pinned.id, "Routed by explicit catalog entry."

    match = await _find_catalog_route(body.capability_type, session, catalog_model)
    if not match:
        return None, None, None
    return (
        match.domain_id,
        match.id,
        "Routed by first active catalog match for capability_type.",
    )
async def _resolve_catalog_entry(
    entry_id: uuid.UUID,
    session: AsyncSession,
    catalog_model: type[CapabilityCatalog] = CapabilityCatalog,
) -> Any:
    """Fetch a catalog entry by primary key.

    Raises:
        HTTPException: 404 when the entry does not exist.
    """
    found = await session.get(catalog_model, entry_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail=f"Catalog entry '{entry_id}' not found")
async def _find_catalog_route(
    capability_type: str,
    session: AsyncSession,
    catalog_model: type[CapabilityCatalog] = CapabilityCatalog,
) -> Any | None:
    """Return the most recently created active catalog entry for a type.

    Args:
        capability_type: Value matched exactly against
            ``catalog_model.capability_type``.
        session: Active async session.
        catalog_model: ORM class for catalog entries.

    Returns:
        The newest entry with ``status == "active"``, or ``None`` if none match.
    """
    result = await session.execute(
        select(catalog_model)
        .where(catalog_model.capability_type == capability_type)
        .where(catalog_model.status == "active")
        .order_by(catalog_model.created_at.desc())
        # Only the first row is used, so let the database stop after one
        # instead of sending every matching row over the wire.
        .limit(1)
    )
    return result.scalars().first()
async def _get_request_or_404(
    request_id: uuid.UUID,
    session: AsyncSession,
    request_model: type[CapabilityRequest] = CapabilityRequest,
) -> Any:
    """Fetch a capability request by primary key.

    Raises:
        HTTPException: 404 when the request does not exist.
    """
    found = await session.get(request_model, request_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail=f"Capability request '{request_id}' not found")