generated from coulomb/repo-seed
refactor(hub-core): mount capability write router and compose MCP tools
Use create_capability_request_write_router with dev-hub callbacks and attach generic HubCoreMCPServer tools while keeping enriched local overrides.
This commit is contained in:
@@ -1,8 +1,7 @@
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
@@ -26,47 +25,40 @@ from api.schemas.capability_request import (
|
||||
from hub_core.routers.capabilities import (
|
||||
create_capability_catalog_router,
|
||||
create_capability_request_read_router,
|
||||
)
|
||||
|
||||
|
||||
router = create_capability_catalog_router(
|
||||
get_session,
|
||||
domain_model=Domain,
|
||||
repo_model=ManagedRepo,
|
||||
catalog_model=CapabilityCatalog,
|
||||
)
|
||||
router.include_router(
|
||||
create_capability_request_read_router(
|
||||
get_session,
|
||||
domain_model=Domain,
|
||||
request_model=CapabilityRequest,
|
||||
request_read_schema=CapabilityRequestRead,
|
||||
)
|
||||
create_capability_request_write_router,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Capability Request endpoints
|
||||
# Write-router callbacks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/capability-requests/", response_model=CapabilityRequestRead, status_code=status.HTTP_201_CREATED)
|
||||
async def create_request(
|
||||
async def _route_capability(
|
||||
session: AsyncSession,
|
||||
body: CapabilityRequestCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> CapabilityRequest:
|
||||
req_domain = await _resolve_domain(body.requesting_domain, session)
|
||||
|
||||
# Route to provider
|
||||
fulfilling_domain_id, catalog_entry_id, routing_note = await _route_capability(
|
||||
session, body.capability_type, body.title, body.description or ""
|
||||
) -> tuple[uuid.UUID | None, uuid.UUID | None, str | None]:
|
||||
fulfilling_domain_id, catalog_entry_id, routing_note = await _route_capability_match(
|
||||
session,
|
||||
body.capability_type,
|
||||
body.title,
|
||||
body.description or "",
|
||||
)
|
||||
return fulfilling_domain_id, catalog_entry_id, routing_note
|
||||
|
||||
req = CapabilityRequest(
|
||||
|
||||
def _build_capability_request(
|
||||
body: CapabilityRequestCreate,
|
||||
requesting_domain: Domain,
|
||||
fulfilling_domain_id: uuid.UUID | None,
|
||||
catalog_entry_id: uuid.UUID | None,
|
||||
routing_note: str | None,
|
||||
) -> CapabilityRequest:
|
||||
return CapabilityRequest(
|
||||
title=body.title,
|
||||
description=body.description,
|
||||
capability_type=body.capability_type,
|
||||
priority=body.priority,
|
||||
requesting_domain_id=req_domain.id,
|
||||
requesting_domain_id=requesting_domain.id,
|
||||
requesting_agent=body.requesting_agent,
|
||||
requesting_workplan_id=body.requesting_workplan_id,
|
||||
blocking_task_id=body.blocking_task_id,
|
||||
@@ -74,12 +66,17 @@ async def create_request(
|
||||
catalog_entry_id=catalog_entry_id,
|
||||
routing_note=routing_note,
|
||||
)
|
||||
session.add(req)
|
||||
await session.flush() # get req.id before creating notification
|
||||
|
||||
# Auto-notify
|
||||
if fulfilling_domain_id:
|
||||
ful_domain = await session.get(Domain, fulfilling_domain_id)
|
||||
|
||||
async def _notify_on_create(
|
||||
session: AsyncSession,
|
||||
req: CapabilityRequest,
|
||||
body: CapabilityRequestCreate,
|
||||
) -> None:
|
||||
await session.flush()
|
||||
|
||||
if req.fulfilling_domain_id:
|
||||
ful_domain = await session.get(Domain, req.fulfilling_domain_id)
|
||||
to_agent = ful_domain.slug if ful_domain else "broadcast"
|
||||
else:
|
||||
to_agent = "broadcast"
|
||||
@@ -98,29 +95,16 @@ async def create_request(
|
||||
),
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(req)
|
||||
return req
|
||||
|
||||
|
||||
@router.post("/capability-requests/{request_id}/accept", response_model=CapabilityRequestRead)
|
||||
async def accept_request(
|
||||
request_id: uuid.UUID,
|
||||
body: CapabilityRequestAccept,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> CapabilityRequest:
|
||||
req = await _get_request_or_404(request_id, session)
|
||||
_check_transition(req.status, "accepted")
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
req.status = "accepted"
|
||||
req.fulfilling_agent = body.fulfilling_agent
|
||||
def _apply_accept_fields(req: CapabilityRequest, body: CapabilityRequestAccept) -> None:
|
||||
req.fulfilling_workplan_id = body.fulfilling_workplan_id
|
||||
req.accepted_at = now
|
||||
|
||||
# If no fulfilling domain was set by routing, infer from the accepting agent's context
|
||||
# (The agent can also PATCH it later if needed)
|
||||
|
||||
async def _notify_on_accept(
|
||||
session: AsyncSession,
|
||||
req: CapabilityRequest,
|
||||
body: CapabilityRequestAccept,
|
||||
) -> None:
|
||||
_add_notification(
|
||||
session,
|
||||
from_agent=body.fulfilling_agent,
|
||||
@@ -129,30 +113,14 @@ async def accept_request(
|
||||
body=f"Your capability request **{req.title}** has been accepted by **{body.fulfilling_agent}**.",
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(req)
|
||||
return req
|
||||
|
||||
|
||||
@router.patch("/capability-requests/{request_id}/status", response_model=CapabilityRequestRead)
|
||||
async def patch_request_status(
|
||||
request_id: uuid.UUID,
|
||||
async def _on_status_change(
|
||||
session: AsyncSession,
|
||||
req: CapabilityRequest,
|
||||
body: CapabilityRequestStatusPatch,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> CapabilityRequest:
|
||||
req = await _get_request_or_404(request_id, session)
|
||||
_check_transition(req.status, body.status)
|
||||
|
||||
req.status = body.status
|
||||
if body.note:
|
||||
req.resolution_note = body.note
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
|
||||
# Status-specific side effects
|
||||
now: datetime,
|
||||
) -> None:
|
||||
if body.status == "completed":
|
||||
req.completed_at = now
|
||||
# Auto-unblock the blocking task
|
||||
if req.blocking_task_id:
|
||||
task = await session.get(Task, req.blocking_task_id)
|
||||
if task and task.status == "wait":
|
||||
@@ -200,23 +168,12 @@ async def patch_request_status(
|
||||
body=f"Work on capability **{req.title}** is now in progress.",
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(req)
|
||||
return req
|
||||
|
||||
|
||||
@router.patch("/capability-requests/{request_id}", response_model=CapabilityRequestRead)
|
||||
async def patch_request(
|
||||
request_id: uuid.UUID,
|
||||
async def _apply_capability_patch(
|
||||
session: AsyncSession,
|
||||
req: CapabilityRequest,
|
||||
body: CapabilityRequestPatch,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> CapabilityRequest:
|
||||
"""Correct mutable metadata: catalog_entry_id (re-derives fulfilling domain),
|
||||
priority, blocking_task_id, fulfilling_workplan_id.
|
||||
Only fields present in the request body (non-None) are updated.
|
||||
"""
|
||||
req = await _get_request_or_404(request_id, session)
|
||||
|
||||
) -> bool:
|
||||
corrections: list[str] = []
|
||||
|
||||
if body.catalog_entry_id is not None:
|
||||
@@ -225,8 +182,6 @@ async def patch_request(
|
||||
if entry is None:
|
||||
raise HTTPException(status_code=404, detail=f"Catalog entry '{body.catalog_entry_id}' not found")
|
||||
req.catalog_entry_id = entry.id
|
||||
# Re-derive fulfilling domain from catalog entry
|
||||
old_domain_id = req.fulfilling_domain_id
|
||||
req.fulfilling_domain_id = entry.domain_id
|
||||
corrections.append(
|
||||
f"catalog_entry: {old_entry_id} → {entry.id} ({entry.title}); "
|
||||
@@ -246,44 +201,25 @@ async def patch_request(
|
||||
corrections.append(f"fulfilling_workplan_id → {body.fulfilling_workplan_id}")
|
||||
|
||||
if not corrections:
|
||||
return req # no-op
|
||||
return False
|
||||
|
||||
correction_note = "hub correction: " + "; ".join(corrections)
|
||||
req.routing_note = (req.routing_note + "\n" + correction_note) if req.routing_note else correction_note
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(req)
|
||||
return req
|
||||
return True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dispute endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@router.post("/capability-requests/{request_id}/dispute", response_model=CapabilityRequestRead)
|
||||
async def dispute_request(
|
||||
request_id: uuid.UUID,
|
||||
async def _notify_on_dispute(
|
||||
session: AsyncSession,
|
||||
req: CapabilityRequest,
|
||||
body: CapabilityRequestDispute,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> CapabilityRequest:
|
||||
"""Flag a routing decision as incorrect. Transitions to routing_disputed."""
|
||||
req = await _get_request_or_404(request_id, session)
|
||||
_check_transition(req.status, "routing_disputed")
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
req.status = "routing_disputed"
|
||||
req.dispute_reason = body.reason
|
||||
req.disputed_by = body.disputed_by
|
||||
req.dispute_suggested_domain = body.suggested_domain
|
||||
req.disputed_at = now
|
||||
|
||||
now: datetime,
|
||||
) -> None:
|
||||
dispute_entry = (
|
||||
f"disputed by {body.disputed_by}: {body.reason}"
|
||||
+ (f" (suggested: {body.suggested_domain})" if body.suggested_domain else "")
|
||||
)
|
||||
req.routing_note = (req.routing_note + "\n" + dispute_entry) if req.routing_note else dispute_entry
|
||||
|
||||
# Notify custodian
|
||||
_add_notification(
|
||||
session,
|
||||
from_agent=body.disputed_by,
|
||||
@@ -297,7 +233,6 @@ async def dispute_request(
|
||||
+ f"\nCurrently routed to: {req.fulfilling_domain_slug or 'unrouted'}"
|
||||
),
|
||||
)
|
||||
# Notify current fulfilling domain
|
||||
if req.fulfilling_domain_slug:
|
||||
_add_notification(
|
||||
session,
|
||||
@@ -312,52 +247,13 @@ async def dispute_request(
|
||||
),
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(req)
|
||||
return req
|
||||
|
||||
|
||||
@router.post("/capability-requests/{request_id}/reroute", response_model=CapabilityRequestRead)
|
||||
async def reroute_request(
|
||||
request_id: uuid.UUID,
|
||||
async def _notify_on_reroute(
|
||||
session: AsyncSession,
|
||||
req: CapabilityRequest,
|
||||
body: CapabilityRequestReroute,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> CapabilityRequest:
|
||||
"""Re-route a disputed request to a new domain. Resets to requested."""
|
||||
req = await _get_request_or_404(request_id, session)
|
||||
if req.status != "routing_disputed":
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail=f"Cannot reroute from status '{req.status}'. Only 'routing_disputed' requests can be rerouted.",
|
||||
)
|
||||
if body.catalog_entry_id is None and body.domain is None:
|
||||
raise HTTPException(status_code=422, detail="Either catalog_entry_id or domain must be provided.")
|
||||
|
||||
if body.catalog_entry_id is not None:
|
||||
entry = await session.get(CapabilityCatalog, body.catalog_entry_id)
|
||||
if entry is None:
|
||||
raise HTTPException(status_code=404, detail=f"Catalog entry '{body.catalog_entry_id}' not found")
|
||||
req.catalog_entry_id = entry.id
|
||||
req.fulfilling_domain_id = entry.domain_id
|
||||
new_domain_slug = (await session.get(Domain, entry.domain_id)).slug if entry.domain_id else "unknown"
|
||||
else:
|
||||
new_domain = await _resolve_domain(body.domain, session)
|
||||
req.fulfilling_domain_id = new_domain.id
|
||||
new_domain_slug = new_domain.slug
|
||||
|
||||
old_domain = req.dispute_suggested_domain or "unknown"
|
||||
|
||||
# Clear dispute fields
|
||||
req.dispute_reason = None
|
||||
req.disputed_by = None
|
||||
req.dispute_suggested_domain = None
|
||||
req.disputed_at = None
|
||||
req.status = "requested"
|
||||
|
||||
reroute_entry = f"re-routed by {body.rerouted_by} → {new_domain_slug}: {body.note}"
|
||||
req.routing_note = (req.routing_note + "\n" + reroute_entry) if req.routing_note else reroute_entry
|
||||
|
||||
# Notify requester
|
||||
new_domain_slug: str,
|
||||
) -> None:
|
||||
_add_notification(
|
||||
session,
|
||||
from_agent=body.rerouted_by,
|
||||
@@ -368,7 +264,6 @@ async def reroute_request(
|
||||
f"**Note:** {body.note}"
|
||||
),
|
||||
)
|
||||
# Notify new fulfilling domain
|
||||
_add_notification(
|
||||
session,
|
||||
from_agent=body.rerouted_by,
|
||||
@@ -383,24 +278,20 @@ async def reroute_request(
|
||||
),
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(req)
|
||||
return req
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Routing algorithm
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _route_capability(
|
||||
session: AsyncSession, capability_type: str, title: str, description: str
|
||||
async def _route_capability_match(
|
||||
session: AsyncSession,
|
||||
capability_type: str,
|
||||
title: str,
|
||||
description: str,
|
||||
) -> tuple[uuid.UUID | None, uuid.UUID | None, str]:
|
||||
"""Find the best-matching catalog entry for a capability request.
|
||||
|
||||
Returns (domain_id, catalog_entry_id, routing_note).
|
||||
Uses word-boundary matching on (title + description) combined to avoid
|
||||
false positives from substring matches (e.g. 'postgres' inside 'postgresql',
|
||||
'ha' inside 'has').
|
||||
"""
|
||||
q = select(CapabilityCatalog).where(
|
||||
CapabilityCatalog.capability_type == capability_type,
|
||||
@@ -412,20 +303,19 @@ async def _route_capability(
|
||||
return None, None, f"no active catalog entries for type '{capability_type}' — broadcast"
|
||||
|
||||
if len(entries) == 1:
|
||||
e = entries[0]
|
||||
return e.domain_id, e.id, f"single match: '{e.title}' (domain={e.domain_id})"
|
||||
entry = entries[0]
|
||||
return entry.domain_id, entry.id, f"single match: '{entry.title}' (domain={entry.domain_id})"
|
||||
|
||||
# Score by word-boundary keyword overlap against title + description combined
|
||||
combined = f"{title} {description or ''}".lower()
|
||||
scored: list[tuple[int, CapabilityCatalog]] = []
|
||||
for entry in entries:
|
||||
keywords = [kw for kw in (entry.keywords or []) if len(kw) >= 3]
|
||||
score = sum(
|
||||
1 for kw in keywords
|
||||
if re.search(r'\b' + re.escape(kw.lower()) + r'\b', combined)
|
||||
if re.search(r"\b" + re.escape(kw.lower()) + r"\b", combined)
|
||||
)
|
||||
scored.append((score, entry))
|
||||
scored.sort(key=lambda x: -x[0])
|
||||
scored.sort(key=lambda item: -item[0])
|
||||
|
||||
best_score, best = scored[0]
|
||||
if best_score == 0:
|
||||
@@ -456,7 +346,6 @@ def _add_notification(
|
||||
subject: str,
|
||||
body: str,
|
||||
) -> None:
|
||||
"""Create an AgentMessage notification in the current session (no commit)."""
|
||||
msg = AgentMessage(
|
||||
from_agent=from_agent,
|
||||
to_agent=to_agent,
|
||||
@@ -466,21 +355,6 @@ def _add_notification(
|
||||
session.add(msg)
|
||||
|
||||
|
||||
async def _resolve_domain(slug: str, session: AsyncSession) -> Domain:
|
||||
result = await session.execute(select(Domain).where(Domain.slug == slug))
|
||||
domain = result.scalar_one_or_none()
|
||||
if domain is None:
|
||||
raise HTTPException(status_code=404, detail=f"Domain '{slug}' not found")
|
||||
return domain
|
||||
|
||||
|
||||
async def _get_request_or_404(request_id: uuid.UUID, session: AsyncSession) -> CapabilityRequest:
|
||||
req = await session.get(CapabilityRequest, request_id)
|
||||
if req is None:
|
||||
raise HTTPException(status_code=404, detail=f"Capability request '{request_id}' not found")
|
||||
return req
|
||||
|
||||
|
||||
def _check_transition(current: str, target: str) -> None:
|
||||
can_reach, failures, flow_result = evaluate_transition(
|
||||
"capability_request",
|
||||
@@ -500,3 +374,44 @@ def _check_transition(current: str, target: str) -> None:
|
||||
"flow_result": flow_result_to_dict(flow_result),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
router = create_capability_catalog_router(
|
||||
get_session,
|
||||
domain_model=Domain,
|
||||
repo_model=ManagedRepo,
|
||||
catalog_model=CapabilityCatalog,
|
||||
)
|
||||
router.include_router(
|
||||
create_capability_request_read_router(
|
||||
get_session,
|
||||
domain_model=Domain,
|
||||
request_model=CapabilityRequest,
|
||||
request_read_schema=CapabilityRequestRead,
|
||||
)
|
||||
)
|
||||
router.include_router(
|
||||
create_capability_request_write_router(
|
||||
get_session,
|
||||
domain_model=Domain,
|
||||
catalog_model=CapabilityCatalog,
|
||||
request_model=CapabilityRequest,
|
||||
request_create_schema=CapabilityRequestCreate,
|
||||
request_accept_schema=CapabilityRequestAccept,
|
||||
request_patch_schema=CapabilityRequestPatch,
|
||||
request_status_patch_schema=CapabilityRequestStatusPatch,
|
||||
request_dispute_schema=CapabilityRequestDispute,
|
||||
request_reroute_schema=CapabilityRequestReroute,
|
||||
request_read_schema=CapabilityRequestRead,
|
||||
route_request=_route_capability,
|
||||
build_request=_build_capability_request,
|
||||
on_request_persisted=_notify_on_create,
|
||||
check_transition=_check_transition,
|
||||
apply_accept_fields=_apply_accept_fields,
|
||||
after_accept=_notify_on_accept,
|
||||
after_status_change=_on_status_change,
|
||||
apply_patch=_apply_capability_patch,
|
||||
after_dispute=_notify_on_dispute,
|
||||
after_reroute=_notify_on_reroute,
|
||||
)
|
||||
)
|
||||
Reference in New Issue
Block a user