# generated from coulomb/repo-seed
# Add create_capability_request_write_router with host workflow callbacks, CapabilityRequestReroute schema, HubCoreMCPServer.attach_to() with CORE_TOOL_NAMES exclude filtering, tests, and mark HUB-WP-0002 finished.
# 499 lines, 19 KiB, Python
import uuid
|
|
from collections.abc import Callable
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Awaitable
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from hub_core.models.capability_catalog import CapabilityCatalog
|
|
from hub_core.models.capability_request import CapabilityRequest
|
|
from hub_core.models.domain import Domain
|
|
from hub_core.models.managed_repo import ManagedRepo
|
|
from hub_core.schemas.capability import (
|
|
CapabilityRequestAccept,
|
|
CapabilityRequestCreate,
|
|
CapabilityRequestDispute,
|
|
CapabilityRequestPatch,
|
|
CapabilityRequestReroute,
|
|
CapabilityRequestRead,
|
|
CapabilityRequestStatusPatch,
|
|
CatalogCreate,
|
|
CatalogPatch,
|
|
CatalogRead,
|
|
)
|
|
|
|
# Callback signatures accepted by create_capability_request_write_router.
# Request, body and domain objects are typed ``Any`` because the host may
# substitute its own model and schema classes.

# Called as ``route_request(session, body)`` when a request is created; the
# awaited result is unpacked into
# ``(fulfilling_domain_id, catalog_entry_id, routing_note)``.
RouteRequestCallback = Callable[
    [AsyncSession, Any],
    Awaitable[tuple[uuid.UUID | None, uuid.UUID | None, str | None]],
]
# Called synchronously as ``build_request(body, requesting_domain,
# fulfilling_domain_id, catalog_entry_id, routing_note)``; the return value is
# the object added to the session.
BuildRequestCallback = Callable[
    [Any, Any, uuid.UUID | None, uuid.UUID | None, str | None],
    Any,
]
# Awaited as ``callback(session, request, body)``; used for both the
# ``on_request_persisted`` and ``after_accept`` hooks.
RequestLifecycleCallback = Callable[[AsyncSession, Any, Any], Awaitable[None]]
# Called as ``check_transition(current_status, target_status)`` before a status
# change; its return value is ignored (presumably it raises to reject the
# transition -- confirm against the host implementation).
CheckTransitionCallback = Callable[[str, str], None]
# Called synchronously as ``apply_accept_fields(request, body)`` after the
# default accept fields have been set on the request.
ApplyAcceptFieldsCallback = Callable[[Any, Any], None]
# Awaited as ``after_status_change(session, request, body, now)`` before the
# status patch is committed; ``now`` is the timezone-aware UTC timestamp taken
# by the endpoint.
AfterStatusChangeCallback = Callable[[AsyncSession, Any, Any, datetime], Awaitable[None]]
# Awaited as ``apply_patch(session, request, body)``; a falsy result makes the
# PATCH endpoint return the request without committing.
ApplyPatchCallback = Callable[[AsyncSession, Any, Any], Awaitable[bool]]
# Awaited as ``after_dispute(session, request, body, now)`` before the dispute
# is committed; ``now`` is the value stored in ``disputed_at``.
AfterDisputeCallback = Callable[[AsyncSession, Any, Any, datetime], Awaitable[None]]
# Awaited as ``after_reroute(session, request, body, new_domain_slug)`` before
# the reroute is committed.
AfterRerouteCallback = Callable[[AsyncSession, Any, Any, str], Awaitable[None]]
|
|
|
|
|
|
def create_capability_catalog_router(
    get_session: Callable[..., AsyncSession],
    *,
    domain_model: type[Domain] = Domain,
    repo_model: type[ManagedRepo] = ManagedRepo,
    catalog_model: type[CapabilityCatalog] = CapabilityCatalog,
    catalog_create_schema: type[CatalogCreate] = CatalogCreate,
    catalog_patch_schema: type[CatalogPatch] = CatalogPatch,
    catalog_read_schema: type[CatalogRead] = CatalogRead,
) -> APIRouter:
    """Build the router exposing the capability-catalog endpoints.

    Registers ``POST /capability-catalog/``, ``GET /capability-catalog/`` and
    ``PATCH /capability-catalog/{entry_id}``.

    Args:
        get_session: FastAPI dependency that yields the ``AsyncSession`` used
            by every endpoint.
        domain_model: ORM class queried by slug to resolve domains.
        repo_model: ORM class queried by slug to resolve repositories.
        catalog_model: ORM class for catalog entries.
        catalog_create_schema: Request-body schema for creation.
        catalog_patch_schema: Request-body schema for partial updates.
        catalog_read_schema: Response schema for catalog entries.

    Returns:
        An ``APIRouter`` with the three catalog endpoints attached.
    """
    # Imported here so the narrowed exception handling below does not depend
    # on a change to the module-level import block.
    from sqlalchemy.exc import IntegrityError

    router = APIRouter(tags=["capability-requests"])
    list_response_model = list[catalog_read_schema]

    # NOTE: endpoint functions are documented with comments, not docstrings,
    # because FastAPI publishes endpoint docstrings in the OpenAPI schema.

    # Create a catalog entry; 404 if the domain or repo slug is unknown,
    # 409 if the database rejects the row with an integrity error.
    @router.post("/capability-catalog/", response_model=catalog_read_schema, status_code=status.HTTP_201_CREATED)
    async def create_catalog_entry(
        body: catalog_create_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        domain = await _resolve_domain(body.domain, session, domain_model)
        repo_id = None
        if body.repo_slug:
            repo = await _resolve_repo(body.repo_slug, session, repo_model)
            repo_id = repo.id
        entry = catalog_model(
            domain_id=domain.id,
            repo_id=repo_id,
            capability_type=body.capability_type,
            title=body.title,
            description=body.description,
            keywords=body.keywords,
        )
        session.add(entry)
        try:
            await session.commit()
        except IntegrityError as exc:
            # Only a constraint violation means "already exists"; keep the
            # original error as the cause for debugging.
            await session.rollback()
            raise HTTPException(
                status_code=409,
                detail=(
                    f"Catalog entry '{body.title}' for type '{body.capability_type}' "
                    f"already exists in domain '{body.domain}'"
                ),
            ) from exc
        except Exception:
            # Any other failure (connection loss, driver error, ...) is not a
            # conflict: leave the session clean and let the error surface.
            await session.rollback()
            raise
        await session.refresh(entry)
        return entry

    # List catalog entries, newest first. Without a ``status`` query parameter
    # only "active" entries are returned; ``status=all`` disables the filter.
    @router.get("/capability-catalog/", response_model=list_response_model)
    async def list_catalog(
        domain: str | None = Query(None),
        capability_type: str | None = Query(None),
        status_filter: str | None = Query(None, alias="status"),
        session: AsyncSession = Depends(get_session),
    ) -> list[Any]:
        q = select(catalog_model).order_by(catalog_model.created_at.desc())
        if domain:
            domain_obj = await _resolve_domain(domain, session, domain_model)
            q = q.where(catalog_model.domain_id == domain_obj.id)
        if capability_type:
            q = q.where(catalog_model.capability_type == capability_type)
        if status_filter and status_filter != "all":
            q = q.where(catalog_model.status == status_filter)
        elif not status_filter:
            q = q.where(catalog_model.status == "active")
        result = await session.execute(q)
        return list(result.scalars().all())

    # Partially update an entry; only fields that are not ``None`` in the body
    # are applied. 404 if the entry or the referenced repo does not exist.
    @router.patch("/capability-catalog/{entry_id}", response_model=catalog_read_schema)
    async def patch_catalog_entry(
        entry_id: uuid.UUID,
        body: catalog_patch_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        entry = await session.get(catalog_model, entry_id)
        if entry is None:
            raise HTTPException(status_code=404, detail=f"Catalog entry '{entry_id}' not found")
        if body.repo_slug is not None:
            repo = await _resolve_repo(body.repo_slug, session, repo_model)
            entry.repo_id = repo.id
        if body.description is not None:
            entry.description = body.description
        if body.keywords is not None:
            entry.keywords = body.keywords
        if body.status is not None:
            entry.status = body.status
        await session.commit()
        await session.refresh(entry)
        return entry

    return router
|
|
|
|
|
|
def create_capability_request_read_router(
    get_session: Callable[..., AsyncSession],
    *,
    domain_model: type[Domain] = Domain,
    request_model: type[CapabilityRequest] = CapabilityRequest,
    request_read_schema: type[CapabilityRequestRead] = CapabilityRequestRead,
) -> APIRouter:
    """Build the read-only router for capability requests.

    Registers ``GET /capability-requests/`` (filterable listing, newest first)
    and ``GET /capability-requests/{request_id}`` (single lookup, 404 when the
    id is unknown).

    Args:
        get_session: FastAPI dependency that yields the ``AsyncSession``.
        domain_model: ORM class used to resolve the ``domain`` filter slug.
        request_model: ORM class for capability requests.
        request_read_schema: Response schema for a capability request.

    Returns:
        An ``APIRouter`` with both read endpoints attached.
    """
    api = APIRouter(tags=["capability-requests"])

    # Endpoint functions carry comments rather than docstrings so the
    # generated OpenAPI descriptions stay as they were.

    # List requests; ``domain`` matches either side of the request.
    @api.get("/capability-requests/", response_model=list[request_read_schema])
    async def list_requests(
        domain: str | None = Query(
            None,
            description="Filter by requesting or fulfilling domain slug",
        ),
        status_filter: str | None = Query(None, alias="status"),
        capability_type: str | None = Query(None),
        session: AsyncSession = Depends(get_session),
    ) -> list[Any]:
        stmt = select(request_model).order_by(request_model.created_at.desc())
        if domain:
            owner = await _resolve_domain(domain, session, domain_model)
            is_requester = request_model.requesting_domain_id == owner.id
            is_fulfiller = request_model.fulfilling_domain_id == owner.id
            stmt = stmt.where(is_requester | is_fulfiller)
        if status_filter:
            stmt = stmt.where(request_model.status == status_filter)
        if capability_type:
            stmt = stmt.where(request_model.capability_type == capability_type)
        rows = await session.execute(stmt)
        return list(rows.scalars().all())

    # Fetch a single request by primary key.
    @api.get("/capability-requests/{request_id}", response_model=request_read_schema)
    async def get_request(
        request_id: uuid.UUID,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        # The shared helper performs the same lookup and raises the same 404.
        return await _get_request_or_404(request_id, session, request_model)

    return api
|
|
|
|
|
|
def create_capability_request_write_router(
    get_session: Callable[..., AsyncSession],
    *,
    domain_model: type[Domain] = Domain,
    catalog_model: type[CapabilityCatalog] = CapabilityCatalog,
    request_model: type[CapabilityRequest] = CapabilityRequest,
    request_create_schema: type[CapabilityRequestCreate] = CapabilityRequestCreate,
    request_accept_schema: type[CapabilityRequestAccept] = CapabilityRequestAccept,
    request_patch_schema: type[CapabilityRequestPatch] = CapabilityRequestPatch,
    request_status_patch_schema: type[CapabilityRequestStatusPatch] = CapabilityRequestStatusPatch,
    request_dispute_schema: type[CapabilityRequestDispute] = CapabilityRequestDispute,
    request_reroute_schema: type[CapabilityRequestReroute] = CapabilityRequestReroute,
    request_read_schema: type[CapabilityRequestRead] = CapabilityRequestRead,
    route_request: RouteRequestCallback | None = None,
    build_request: BuildRequestCallback | None = None,
    on_request_persisted: RequestLifecycleCallback | None = None,
    check_transition: CheckTransitionCallback | None = None,
    apply_accept_fields: ApplyAcceptFieldsCallback | None = None,
    after_accept: RequestLifecycleCallback | None = None,
    after_status_change: AfterStatusChangeCallback | None = None,
    apply_patch: ApplyPatchCallback | None = None,
    after_dispute: AfterDisputeCallback | None = None,
    after_reroute: AfterRerouteCallback | None = None,
    include_reroute: bool = True,
) -> APIRouter:
    """Build the router with the mutating capability-request endpoints.

    Endpoints: create, accept, status patch, field patch, dispute and
    (when ``include_reroute`` is true) reroute. Every optional callback is a
    host hook; when it is ``None`` the built-in behaviour is used or the step
    is skipped.

    Args:
        get_session: FastAPI dependency that yields the ``AsyncSession``.
        domain_model: ORM class used to resolve domains.
        catalog_model: ORM class used to resolve catalog entries.
        request_model: ORM class instantiated for new requests and loaded by id.
        request_create_schema: Body schema for ``POST /capability-requests/``.
        request_accept_schema: Body schema for the accept endpoint.
        request_patch_schema: Body schema for the field PATCH endpoint.
        request_status_patch_schema: Body schema for the status PATCH endpoint.
        request_dispute_schema: Body schema for the dispute endpoint.
        request_reroute_schema: Body schema for the reroute endpoint.
        request_read_schema: Response schema for all endpoints.
        route_request: Replaces the default catalog-based routing on create.
        build_request: Replaces the default ``request_model(...)`` construction.
        on_request_persisted: Awaited after the new request is added to the
            session and before the commit.
        check_transition: Called with (current status, target status) before
            accept, status patch and dispute.
        apply_accept_fields: Called after the default accept fields are set.
        after_accept: Awaited before the accept commit.
        after_status_change: Awaited before the status-patch commit.
        apply_patch: Replaces the default field patching; a falsy result
            returns the request without committing.
        after_dispute: Awaited before the dispute commit.
        after_reroute: Awaited before the reroute commit.
        include_reroute: Whether to register the reroute endpoint.

    Returns:
        An ``APIRouter`` with the write endpoints attached.
    """
    router = APIRouter(tags=["capability-requests"])

    async def _commit_and_reload(session: AsyncSession, record: Any) -> Any:
        """Commit the session, refresh ``record`` from the database, return it."""
        await session.commit()
        await session.refresh(record)
        return record

    # Endpoint functions are documented with comments, not docstrings, so the
    # OpenAPI descriptions FastAPI derives from them are unchanged.

    # Create a request: resolve the requesting domain, route it, build the
    # row, then persist.
    @router.post(
        "/capability-requests/",
        response_model=request_read_schema,
        status_code=status.HTTP_201_CREATED,
    )
    async def create_request(
        body: request_create_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        requester = await _resolve_domain(body.requesting_domain, session, domain_model)

        if route_request is None:
            routing = await _default_route_request(session, body, catalog_model)
        else:
            routing = await route_request(session, body)
        target_domain_id, entry_id, note = routing

        if build_request is None:
            record = request_model(
                title=body.title,
                description=body.description,
                capability_type=body.capability_type,
                priority=body.priority,
                requesting_domain_id=requester.id,
                requesting_agent=body.requesting_agent,
                request_context=body.request_context,
                fulfilling_domain_id=target_domain_id,
                catalog_entry_id=entry_id,
                routing_note=note,
            )
        else:
            record = build_request(body, requester, target_domain_id, entry_id, note)

        session.add(record)
        # The hook runs while the row is pending, i.e. before the commit.
        if on_request_persisted is not None:
            await on_request_persisted(session, record, body)
        return await _commit_and_reload(session, record)

    # Mark a request as accepted by a fulfilling agent.
    @router.post("/capability-requests/{request_id}/accept", response_model=request_read_schema)
    async def accept_request(
        request_id: uuid.UUID,
        body: request_accept_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        record = await _get_request_or_404(request_id, session, request_model)
        if check_transition is not None:
            check_transition(record.status, "accepted")

        accepted_at = datetime.now(tz=timezone.utc)
        record.status = "accepted"
        record.fulfilling_agent = body.fulfilling_agent
        # Host-supplied accept schemas are not required to carry this field.
        if hasattr(body, "fulfillment_context"):
            record.fulfillment_context = body.fulfillment_context
        record.accepted_at = accepted_at

        if apply_accept_fields is not None:
            apply_accept_fields(record, body)
        if after_accept is not None:
            await after_accept(session, record, body)

        return await _commit_and_reload(session, record)

    # Move a request to the status given in the body.
    @router.patch("/capability-requests/{request_id}/status", response_model=request_read_schema)
    async def patch_request_status(
        request_id: uuid.UUID,
        body: request_status_patch_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        record = await _get_request_or_404(request_id, session, request_model)
        if check_transition is not None:
            check_transition(record.status, body.status)

        record.status = body.status
        if body.note:
            record.resolution_note = body.note

        changed_at = datetime.now(tz=timezone.utc)
        if body.status == "completed":
            record.completed_at = changed_at

        if after_status_change is not None:
            await after_status_change(session, record, body, changed_at)

        return await _commit_and_reload(session, record)

    # Patch mutable request fields, either through the host hook or with the
    # built-in field-by-field update.
    @router.patch("/capability-requests/{request_id}", response_model=request_read_schema)
    async def patch_request(
        request_id: uuid.UUID,
        body: request_patch_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        record = await _get_request_or_404(request_id, session, request_model)

        if apply_patch is not None:
            # The hook owns the update; nothing is committed if it reports
            # that nothing changed.
            if not await apply_patch(session, record, body):
                return record
            return await _commit_and_reload(session, record)

        if body.catalog_entry_id is not None:
            # Re-pointing the catalog entry also moves the fulfilling domain.
            target_entry = await _resolve_catalog_entry(body.catalog_entry_id, session, catalog_model)
            record.catalog_entry_id = target_entry.id
            record.fulfilling_domain_id = target_entry.domain_id
        if body.priority is not None:
            record.priority = body.priority
        if body.request_context is not None:
            record.request_context = body.request_context
        if body.fulfillment_context is not None:
            record.fulfillment_context = body.fulfillment_context

        return await _commit_and_reload(session, record)

    # Flag the routing of a request as disputed and record who disputed it.
    @router.post("/capability-requests/{request_id}/dispute", response_model=request_read_schema)
    async def dispute_request(
        request_id: uuid.UUID,
        body: request_dispute_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        record = await _get_request_or_404(request_id, session, request_model)
        if check_transition is not None:
            check_transition(record.status, "routing_disputed")

        disputed_at = datetime.now(tz=timezone.utc)
        record.status = "routing_disputed"
        record.dispute_reason = body.reason
        record.disputed_by = body.disputed_by
        record.dispute_suggested_domain = body.suggested_domain
        record.disputed_at = disputed_at

        if after_dispute is not None:
            await after_dispute(session, record, body, disputed_at)

        return await _commit_and_reload(session, record)

    if include_reroute:

        # Resolve a routing dispute by assigning a new catalog entry or
        # domain; the request goes back to "requested".
        @router.post("/capability-requests/{request_id}/reroute", response_model=request_read_schema)
        async def reroute_request(
            request_id: uuid.UUID,
            body: request_reroute_schema,
            session: AsyncSession = Depends(get_session),
        ) -> Any:
            record = await _get_request_or_404(request_id, session, request_model)
            if record.status != "routing_disputed":
                raise HTTPException(
                    status_code=422,
                    detail=(
                        f"Cannot reroute from status '{record.status}'. "
                        "Only 'routing_disputed' requests can be rerouted."
                    ),
                )
            if body.catalog_entry_id is None and body.domain is None:
                raise HTTPException(
                    status_code=422,
                    detail="Either catalog_entry_id or domain must be provided.",
                )

            if body.catalog_entry_id is None:
                # Domain-only reroute: the catalog entry link is left as is.
                target_domain = await _resolve_domain(body.domain, session, domain_model)
                record.fulfilling_domain_id = target_domain.id
                target_slug = target_domain.slug
            else:
                # A catalog entry takes precedence and dictates the domain.
                target_entry = await _resolve_catalog_entry(body.catalog_entry_id, session, catalog_model)
                record.catalog_entry_id = target_entry.id
                record.fulfilling_domain_id = target_entry.domain_id
                entry_domain = await session.get(domain_model, target_entry.domain_id)
                target_slug = entry_domain.slug if entry_domain else "unknown"

            # Clear the dispute bookkeeping and reopen the request.
            record.dispute_reason = None
            record.disputed_by = None
            record.dispute_suggested_domain = None
            record.disputed_at = None
            record.status = "requested"

            # Append to the routing note so earlier routing history survives.
            history_line = f"re-routed by {body.rerouted_by} → {target_slug}: {body.note}"
            earlier_note = record.routing_note
            if earlier_note:
                record.routing_note = "\n".join((earlier_note, history_line))
            else:
                record.routing_note = history_line

            if after_reroute is not None:
                await after_reroute(session, record, body, target_slug)

            return await _commit_and_reload(session, record)

    return router
|
|
|
|
|
|
def create_capabilities_router(get_session: Callable[..., AsyncSession]) -> APIRouter:
    """Return one router combining the catalog, read and write routers.

    Each sub-router is built with its default models, schemas and hooks and
    shares the given ``get_session`` dependency.
    """
    combined = APIRouter(tags=["capability-requests"])
    factories = (
        create_capability_catalog_router,
        create_capability_request_read_router,
        create_capability_request_write_router,
    )
    for make_router in factories:
        combined.include_router(make_router(get_session))
    return combined
|
|
|
|
|
|
async def _resolve_domain(
    slug: str,
    session: AsyncSession,
    domain_model: type[Domain],
) -> Any:
    """Return the domain row whose ``slug`` matches.

    Raises:
        HTTPException: 404 when no domain has that slug.
    """
    stmt = select(domain_model).where(domain_model.slug == slug)
    found = (await session.execute(stmt)).scalar_one_or_none()
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail=f"Domain '{slug}' not found")
|
|
|
|
|
|
async def _resolve_repo(
    slug: str,
    session: AsyncSession,
    repo_model: type[ManagedRepo],
) -> Any:
    """Return the repository row whose ``slug`` matches.

    Raises:
        HTTPException: 404 when no repository has that slug.
    """
    stmt = select(repo_model).where(repo_model.slug == slug)
    found = (await session.execute(stmt)).scalar_one_or_none()
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail=f"Repo '{slug}' not found")
|
|
|
|
|
|
async def _default_route_request(
    session: AsyncSession,
    body: CapabilityRequestCreate,
    catalog_model: type[CapabilityCatalog],
) -> tuple[uuid.UUID | None, uuid.UUID | None, str | None]:
    """Pick a fulfilling domain for a new request from the catalog.

    An explicit ``body.catalog_entry_id`` wins (404 if it does not exist).
    Otherwise the entry returned by ``_find_catalog_route`` for
    ``body.capability_type`` is used, if there is one.

    Returns:
        ``(fulfilling_domain_id, catalog_entry_id, routing_note)``; all three
        are ``None`` when no route was found.
    """
    explicit_id = body.catalog_entry_id
    if explicit_id:
        chosen = await _resolve_catalog_entry(explicit_id, session, catalog_model)
        return chosen.domain_id, chosen.id, "Routed by explicit catalog entry."

    chosen = await _find_catalog_route(body.capability_type, session, catalog_model)
    if not chosen:
        return None, None, None

    note = "Routed by first active catalog match for capability_type."
    return chosen.domain_id, chosen.id, note
|
|
|
|
|
|
async def _resolve_catalog_entry(
    entry_id: uuid.UUID,
    session: AsyncSession,
    catalog_model: type[CapabilityCatalog] = CapabilityCatalog,
) -> Any:
    """Load a catalog entry by primary key.

    Raises:
        HTTPException: 404 when the entry does not exist.
    """
    found = await session.get(catalog_model, entry_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail=f"Catalog entry '{entry_id}' not found")
|
|
|
|
|
|
async def _find_catalog_route(
    capability_type: str,
    session: AsyncSession,
    catalog_model: type[CapabilityCatalog] = CapabilityCatalog,
) -> Any | None:
    """Return the newest active catalog entry for ``capability_type``.

    Entries are filtered to ``status == "active"`` and ordered by
    ``created_at`` descending.

    Returns:
        The matching catalog entry, or ``None`` when there is none.
    """
    stmt = (
        select(catalog_model)
        .where(catalog_model.capability_type == capability_type)
        .where(catalog_model.status == "active")
        .order_by(catalog_model.created_at.desc())
        # Only the first row is consumed, so do not ask the database for more.
        .limit(1)
    )
    result = await session.execute(stmt)
    return result.scalars().first()
|
|
|
|
|
|
async def _get_request_or_404(
    request_id: uuid.UUID,
    session: AsyncSession,
    request_model: type[CapabilityRequest] = CapabilityRequest,
) -> Any:
    """Load a capability request by primary key.

    Raises:
        HTTPException: 404 when the request does not exist.
    """
    found = await session.get(request_model, request_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail=f"Capability request '{request_id}' not found")
|