generated from coulomb/repo-seed
Add hub-core package, docs, and State Hub integration scaffold
Extract the first reusable slice (models, schemas, routers, MCP, migrations) from state-hub with INTENT/SCOPE, agent instructions, workplan, and aligned inter_hub capability registry index.
This commit is contained in:
23
hub_core/routers/__init__.py
Normal file
23
hub_core/routers/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from hub_core.routers.capabilities import (
|
||||
create_capabilities_router,
|
||||
create_capability_catalog_router,
|
||||
create_capability_request_read_router,
|
||||
)
|
||||
from hub_core.routers.domains import create_domains_router
|
||||
from hub_core.routers.messages import create_messages_router
|
||||
from hub_core.routers.policy import create_policy_router
|
||||
from hub_core.routers.progress import create_progress_router
|
||||
from hub_core.routers.repos import create_repos_router
|
||||
from hub_core.routers.tpsc import create_tpsc_router
|
||||
|
||||
__all__ = [
|
||||
"create_capabilities_router",
|
||||
"create_capability_catalog_router",
|
||||
"create_capability_request_read_router",
|
||||
"create_domains_router",
|
||||
"create_messages_router",
|
||||
"create_policy_router",
|
||||
"create_progress_router",
|
||||
"create_repos_router",
|
||||
"create_tpsc_router",
|
||||
]
|
||||
427
hub_core/routers/capabilities.py
Normal file
427
hub_core/routers/capabilities.py
Normal file
@@ -0,0 +1,427 @@
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from hub_core.models.capability_catalog import CapabilityCatalog
|
||||
from hub_core.models.capability_request import CapabilityRequest
|
||||
from hub_core.models.domain import Domain
|
||||
from hub_core.models.managed_repo import ManagedRepo
|
||||
from hub_core.schemas.capability import (
|
||||
CapabilityRequestAccept,
|
||||
CapabilityRequestCreate,
|
||||
CapabilityRequestDispute,
|
||||
CapabilityRequestPatch,
|
||||
CapabilityRequestRead,
|
||||
CapabilityRequestStatusPatch,
|
||||
CatalogCreate,
|
||||
CatalogPatch,
|
||||
CatalogRead,
|
||||
)
|
||||
|
||||
|
||||
def create_capability_catalog_router(
|
||||
get_session: Callable[..., AsyncSession],
|
||||
*,
|
||||
domain_model: type[Domain] = Domain,
|
||||
repo_model: type[ManagedRepo] = ManagedRepo,
|
||||
catalog_model: type[CapabilityCatalog] = CapabilityCatalog,
|
||||
catalog_create_schema: type[CatalogCreate] = CatalogCreate,
|
||||
catalog_patch_schema: type[CatalogPatch] = CatalogPatch,
|
||||
catalog_read_schema: type[CatalogRead] = CatalogRead,
|
||||
) -> APIRouter:
|
||||
router = APIRouter(tags=["capability-requests"])
|
||||
list_response_model = list[catalog_read_schema]
|
||||
|
||||
@router.post("/capability-catalog/", response_model=catalog_read_schema, status_code=status.HTTP_201_CREATED)
|
||||
async def create_catalog_entry(
|
||||
body: catalog_create_schema,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
domain = await _resolve_domain(body.domain, session, domain_model)
|
||||
repo_id = None
|
||||
if body.repo_slug:
|
||||
repo = await _resolve_repo(body.repo_slug, session, repo_model)
|
||||
repo_id = repo.id
|
||||
entry = catalog_model(
|
||||
domain_id=domain.id,
|
||||
repo_id=repo_id,
|
||||
capability_type=body.capability_type,
|
||||
title=body.title,
|
||||
description=body.description,
|
||||
keywords=body.keywords,
|
||||
)
|
||||
session.add(entry)
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=(
|
||||
f"Catalog entry '{body.title}' for type '{body.capability_type}' "
|
||||
f"already exists in domain '{body.domain}'"
|
||||
),
|
||||
)
|
||||
await session.refresh(entry)
|
||||
return entry
|
||||
|
||||
@router.get("/capability-catalog/", response_model=list_response_model)
|
||||
async def list_catalog(
|
||||
domain: str | None = Query(None),
|
||||
capability_type: str | None = Query(None),
|
||||
status_filter: str | None = Query(None, alias="status"),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[Any]:
|
||||
q = select(catalog_model).order_by(catalog_model.created_at.desc())
|
||||
if domain:
|
||||
domain_obj = await _resolve_domain(domain, session, domain_model)
|
||||
q = q.where(catalog_model.domain_id == domain_obj.id)
|
||||
if capability_type:
|
||||
q = q.where(catalog_model.capability_type == capability_type)
|
||||
if status_filter and status_filter != "all":
|
||||
q = q.where(catalog_model.status == status_filter)
|
||||
elif not status_filter:
|
||||
q = q.where(catalog_model.status == "active")
|
||||
result = await session.execute(q)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@router.patch("/capability-catalog/{entry_id}", response_model=catalog_read_schema)
|
||||
async def patch_catalog_entry(
|
||||
entry_id: uuid.UUID,
|
||||
body: catalog_patch_schema,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
entry = await session.get(catalog_model, entry_id)
|
||||
if entry is None:
|
||||
raise HTTPException(status_code=404, detail=f"Catalog entry '{entry_id}' not found")
|
||||
if body.repo_slug is not None:
|
||||
repo = await _resolve_repo(body.repo_slug, session, repo_model)
|
||||
entry.repo_id = repo.id
|
||||
if body.description is not None:
|
||||
entry.description = body.description
|
||||
if body.keywords is not None:
|
||||
entry.keywords = body.keywords
|
||||
if body.status is not None:
|
||||
entry.status = body.status
|
||||
await session.commit()
|
||||
await session.refresh(entry)
|
||||
return entry
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def create_capability_request_read_router(
|
||||
get_session: Callable[..., AsyncSession],
|
||||
*,
|
||||
domain_model: type[Domain] = Domain,
|
||||
request_model: type[CapabilityRequest] = CapabilityRequest,
|
||||
request_read_schema: type[CapabilityRequestRead] = CapabilityRequestRead,
|
||||
) -> APIRouter:
|
||||
router = APIRouter(tags=["capability-requests"])
|
||||
list_response_model = list[request_read_schema]
|
||||
|
||||
@router.get("/capability-requests/", response_model=list_response_model)
|
||||
async def list_requests(
|
||||
domain: str | None = Query(
|
||||
None,
|
||||
description="Filter by requesting or fulfilling domain slug",
|
||||
),
|
||||
status_filter: str | None = Query(None, alias="status"),
|
||||
capability_type: str | None = Query(None),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[Any]:
|
||||
q = select(request_model).order_by(request_model.created_at.desc())
|
||||
if domain:
|
||||
domain_obj = await _resolve_domain(domain, session, domain_model)
|
||||
q = q.where(
|
||||
(request_model.requesting_domain_id == domain_obj.id)
|
||||
| (request_model.fulfilling_domain_id == domain_obj.id)
|
||||
)
|
||||
if status_filter:
|
||||
q = q.where(request_model.status == status_filter)
|
||||
if capability_type:
|
||||
q = q.where(request_model.capability_type == capability_type)
|
||||
result = await session.execute(q)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@router.get("/capability-requests/{request_id}", response_model=request_read_schema)
|
||||
async def get_request(
|
||||
request_id: uuid.UUID,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
req = await session.get(request_model, request_id)
|
||||
if req is None:
|
||||
raise HTTPException(status_code=404, detail=f"Capability request '{request_id}' not found")
|
||||
return req
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def create_capabilities_router(get_session: Callable[..., AsyncSession]) -> APIRouter:
|
||||
router = APIRouter(tags=["capability-requests"])
|
||||
|
||||
@router.post("/capability-catalog/", response_model=CatalogRead, status_code=status.HTTP_201_CREATED)
|
||||
async def create_catalog_entry(
|
||||
body: CatalogCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> CapabilityCatalog:
|
||||
domain = await _resolve_domain(body.domain, session, Domain)
|
||||
repo_id = None
|
||||
if body.repo_slug:
|
||||
repo = await _resolve_repo(body.repo_slug, session, ManagedRepo)
|
||||
repo_id = repo.id
|
||||
entry = CapabilityCatalog(
|
||||
domain_id=domain.id,
|
||||
repo_id=repo_id,
|
||||
capability_type=body.capability_type,
|
||||
title=body.title,
|
||||
description=body.description,
|
||||
keywords=body.keywords,
|
||||
)
|
||||
session.add(entry)
|
||||
try:
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=(
|
||||
f"Catalog entry '{body.title}' for type '{body.capability_type}' "
|
||||
f"already exists in domain '{body.domain}'"
|
||||
),
|
||||
)
|
||||
await session.refresh(entry)
|
||||
return entry
|
||||
|
||||
@router.get("/capability-catalog/", response_model=list[CatalogRead])
|
||||
async def list_catalog(
|
||||
domain: str | None = Query(None),
|
||||
capability_type: str | None = Query(None),
|
||||
status_filter: str | None = Query(None, alias="status"),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[CapabilityCatalog]:
|
||||
q = select(CapabilityCatalog).order_by(CapabilityCatalog.created_at.desc())
|
||||
if domain:
|
||||
domain_obj = await _resolve_domain(domain, session, Domain)
|
||||
q = q.where(CapabilityCatalog.domain_id == domain_obj.id)
|
||||
if capability_type:
|
||||
q = q.where(CapabilityCatalog.capability_type == capability_type)
|
||||
if status_filter and status_filter != "all":
|
||||
q = q.where(CapabilityCatalog.status == status_filter)
|
||||
elif not status_filter:
|
||||
q = q.where(CapabilityCatalog.status == "active")
|
||||
result = await session.execute(q)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@router.patch("/capability-catalog/{entry_id}", response_model=CatalogRead)
|
||||
async def patch_catalog_entry(
|
||||
entry_id: uuid.UUID,
|
||||
body: CatalogPatch,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> CapabilityCatalog:
|
||||
entry = await session.get(CapabilityCatalog, entry_id)
|
||||
if entry is None:
|
||||
raise HTTPException(status_code=404, detail=f"Catalog entry '{entry_id}' not found")
|
||||
if body.repo_slug is not None:
|
||||
repo = await _resolve_repo(body.repo_slug, session, ManagedRepo)
|
||||
entry.repo_id = repo.id
|
||||
if body.description is not None:
|
||||
entry.description = body.description
|
||||
if body.keywords is not None:
|
||||
entry.keywords = body.keywords
|
||||
if body.status is not None:
|
||||
entry.status = body.status
|
||||
await session.commit()
|
||||
await session.refresh(entry)
|
||||
return entry
|
||||
|
||||
@router.post("/capability-requests/", response_model=CapabilityRequestRead, status_code=status.HTTP_201_CREATED)
|
||||
async def create_request(
|
||||
body: CapabilityRequestCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> CapabilityRequest:
|
||||
requesting_domain = await _resolve_domain(body.requesting_domain, session, Domain)
|
||||
fulfilling_domain_id = None
|
||||
catalog_entry_id = body.catalog_entry_id
|
||||
routing_note = None
|
||||
if catalog_entry_id:
|
||||
catalog_entry = await _resolve_catalog_entry(catalog_entry_id, session)
|
||||
fulfilling_domain_id = catalog_entry.domain_id
|
||||
routing_note = "Routed by explicit catalog entry."
|
||||
else:
|
||||
catalog_entry = await _find_catalog_route(body.capability_type, session)
|
||||
if catalog_entry:
|
||||
catalog_entry_id = catalog_entry.id
|
||||
fulfilling_domain_id = catalog_entry.domain_id
|
||||
routing_note = "Routed by first active catalog match for capability_type."
|
||||
|
||||
req = CapabilityRequest(
|
||||
title=body.title,
|
||||
description=body.description,
|
||||
capability_type=body.capability_type,
|
||||
priority=body.priority,
|
||||
requesting_domain_id=requesting_domain.id,
|
||||
requesting_agent=body.requesting_agent,
|
||||
request_context=body.request_context,
|
||||
fulfilling_domain_id=fulfilling_domain_id,
|
||||
catalog_entry_id=catalog_entry_id,
|
||||
routing_note=routing_note,
|
||||
)
|
||||
session.add(req)
|
||||
await session.commit()
|
||||
await session.refresh(req)
|
||||
return req
|
||||
|
||||
@router.get("/capability-requests/", response_model=list[CapabilityRequestRead])
|
||||
async def list_requests(
|
||||
domain: str | None = Query(None, description="Filter by requesting or fulfilling domain slug"),
|
||||
status_filter: str | None = Query(None, alias="status"),
|
||||
capability_type: str | None = Query(None),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[CapabilityRequest]:
|
||||
q = select(CapabilityRequest).order_by(CapabilityRequest.created_at.desc())
|
||||
if domain:
|
||||
domain_obj = await _resolve_domain(domain, session, Domain)
|
||||
q = q.where(
|
||||
(CapabilityRequest.requesting_domain_id == domain_obj.id)
|
||||
| (CapabilityRequest.fulfilling_domain_id == domain_obj.id)
|
||||
)
|
||||
if status_filter:
|
||||
q = q.where(CapabilityRequest.status == status_filter)
|
||||
if capability_type:
|
||||
q = q.where(CapabilityRequest.capability_type == capability_type)
|
||||
result = await session.execute(q)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@router.get("/capability-requests/{request_id}", response_model=CapabilityRequestRead)
|
||||
async def get_request(
|
||||
request_id: uuid.UUID,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> CapabilityRequest:
|
||||
return await _get_request_or_404(request_id, session)
|
||||
|
||||
@router.post("/capability-requests/{request_id}/accept", response_model=CapabilityRequestRead)
|
||||
async def accept_request(
|
||||
request_id: uuid.UUID,
|
||||
body: CapabilityRequestAccept,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> CapabilityRequest:
|
||||
req = await _get_request_or_404(request_id, session)
|
||||
req.status = "accepted"
|
||||
req.fulfilling_agent = body.fulfilling_agent
|
||||
req.fulfillment_context = body.fulfillment_context
|
||||
req.accepted_at = datetime.now(tz=timezone.utc)
|
||||
await session.commit()
|
||||
await session.refresh(req)
|
||||
return req
|
||||
|
||||
@router.patch("/capability-requests/{request_id}", response_model=CapabilityRequestRead)
|
||||
async def patch_request(
|
||||
request_id: uuid.UUID,
|
||||
body: CapabilityRequestPatch,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> CapabilityRequest:
|
||||
req = await _get_request_or_404(request_id, session)
|
||||
if body.catalog_entry_id is not None:
|
||||
catalog_entry = await _resolve_catalog_entry(body.catalog_entry_id, session)
|
||||
req.catalog_entry_id = catalog_entry.id
|
||||
req.fulfilling_domain_id = catalog_entry.domain_id
|
||||
if body.priority is not None:
|
||||
req.priority = body.priority
|
||||
if body.request_context is not None:
|
||||
req.request_context = body.request_context
|
||||
if body.fulfillment_context is not None:
|
||||
req.fulfillment_context = body.fulfillment_context
|
||||
await session.commit()
|
||||
await session.refresh(req)
|
||||
return req
|
||||
|
||||
@router.patch("/capability-requests/{request_id}/status", response_model=CapabilityRequestRead)
|
||||
async def patch_request_status(
|
||||
request_id: uuid.UUID,
|
||||
body: CapabilityRequestStatusPatch,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> CapabilityRequest:
|
||||
req = await _get_request_or_404(request_id, session)
|
||||
req.status = body.status
|
||||
if body.note:
|
||||
req.resolution_note = body.note
|
||||
if body.status == "completed":
|
||||
req.completed_at = datetime.now(tz=timezone.utc)
|
||||
await session.commit()
|
||||
await session.refresh(req)
|
||||
return req
|
||||
|
||||
@router.post("/capability-requests/{request_id}/dispute", response_model=CapabilityRequestRead)
|
||||
async def dispute_request(
|
||||
request_id: uuid.UUID,
|
||||
body: CapabilityRequestDispute,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> CapabilityRequest:
|
||||
req = await _get_request_or_404(request_id, session)
|
||||
req.status = "routing_disputed"
|
||||
req.dispute_reason = body.reason
|
||||
req.disputed_by = body.disputed_by
|
||||
req.dispute_suggested_domain = body.suggested_domain
|
||||
req.disputed_at = datetime.now(tz=timezone.utc)
|
||||
await session.commit()
|
||||
await session.refresh(req)
|
||||
return req
|
||||
|
||||
return router
|
||||
|
||||
|
||||
async def _resolve_domain(
|
||||
slug: str,
|
||||
session: AsyncSession,
|
||||
domain_model: type[Domain],
|
||||
) -> Any:
|
||||
result = await session.execute(select(domain_model).where(domain_model.slug == slug))
|
||||
domain = result.scalar_one_or_none()
|
||||
if domain is None:
|
||||
raise HTTPException(status_code=404, detail=f"Domain '{slug}' not found")
|
||||
return domain
|
||||
|
||||
|
||||
async def _resolve_repo(
|
||||
slug: str,
|
||||
session: AsyncSession,
|
||||
repo_model: type[ManagedRepo],
|
||||
) -> Any:
|
||||
result = await session.execute(select(repo_model).where(repo_model.slug == slug))
|
||||
repo = result.scalar_one_or_none()
|
||||
if repo is None:
|
||||
raise HTTPException(status_code=404, detail=f"Repo '{slug}' not found")
|
||||
return repo
|
||||
|
||||
|
||||
async def _resolve_catalog_entry(entry_id: uuid.UUID, session: AsyncSession) -> CapabilityCatalog:
|
||||
entry = await session.get(CapabilityCatalog, entry_id)
|
||||
if entry is None:
|
||||
raise HTTPException(status_code=404, detail=f"Catalog entry '{entry_id}' not found")
|
||||
return entry
|
||||
|
||||
|
||||
async def _find_catalog_route(
|
||||
capability_type: str,
|
||||
session: AsyncSession,
|
||||
) -> CapabilityCatalog | None:
|
||||
result = await session.execute(
|
||||
select(CapabilityCatalog)
|
||||
.where(CapabilityCatalog.capability_type == capability_type)
|
||||
.where(CapabilityCatalog.status == "active")
|
||||
.order_by(CapabilityCatalog.created_at.desc())
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
async def _get_request_or_404(request_id: uuid.UUID, session: AsyncSession) -> CapabilityRequest:
|
||||
req = await session.get(CapabilityRequest, request_id)
|
||||
if req is None:
|
||||
raise HTTPException(status_code=404, detail=f"Capability request '{request_id}' not found")
|
||||
return req
|
||||
178
hub_core/routers/domains.py
Normal file
178
hub_core/routers/domains.py
Normal file
@@ -0,0 +1,178 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Awaitable
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import noload
|
||||
|
||||
from hub_core.models.domain import Domain
|
||||
from hub_core.models.managed_repo import ManagedRepo
|
||||
from hub_core.schemas.domain import DomainCreate, DomainDetail, DomainRead, DomainRename, DomainUpdate, RepoStub
|
||||
|
||||
|
||||
DomainDetailBuilder = Callable[[Any, AsyncSession], Awaitable[Any]]
|
||||
DomainArchiveValidator = Callable[[Any, AsyncSession], Awaitable[None]]
|
||||
|
||||
|
||||
def create_domains_router(
|
||||
get_session: Callable[..., AsyncSession],
|
||||
*,
|
||||
domain_model: type[Domain] = Domain,
|
||||
repo_model: type[ManagedRepo] = ManagedRepo,
|
||||
domain_create_schema: type[DomainCreate] = DomainCreate,
|
||||
domain_detail_schema: type[DomainDetail] = DomainDetail,
|
||||
domain_read_schema: type[DomainRead] = DomainRead,
|
||||
domain_rename_schema: type[DomainRename] = DomainRename,
|
||||
domain_update_schema: type[DomainUpdate] = DomainUpdate,
|
||||
repo_stub_schema: type[RepoStub] = RepoStub,
|
||||
detail_builder: DomainDetailBuilder | None = None,
|
||||
before_archive: DomainArchiveValidator | None = None,
|
||||
list_noload_fields: tuple[str, ...] = ("repos",),
|
||||
include_update_route: bool = True,
|
||||
) -> APIRouter:
|
||||
router = APIRouter(prefix="/domains", tags=["domains"])
|
||||
list_response_model = list[domain_read_schema]
|
||||
|
||||
@router.get("/", response_model=list_response_model)
|
||||
async def list_domains(
|
||||
response: Response,
|
||||
status_filter: str | None = Query(None, alias="status", description="active | archived | all"),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[Any]:
|
||||
response.headers["Cache-Control"] = "max-age=60, stale-while-revalidate=30"
|
||||
q = select(domain_model).options(
|
||||
*[
|
||||
noload(getattr(domain_model, field))
|
||||
for field in list_noload_fields
|
||||
if hasattr(domain_model, field)
|
||||
]
|
||||
).order_by(domain_model.name)
|
||||
if status_filter and status_filter != "all":
|
||||
q = q.where(domain_model.status == status_filter)
|
||||
elif status_filter is None:
|
||||
q = q.where(domain_model.status == "active")
|
||||
result = await session.execute(q)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@router.post("/", response_model=domain_read_schema, status_code=status.HTTP_201_CREATED)
|
||||
async def create_domain(
|
||||
body: domain_create_schema,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
existing = await session.execute(select(domain_model).where(domain_model.slug == body.slug))
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=409, detail=f"Domain slug '{body.slug}' already exists")
|
||||
domain = domain_model(slug=body.slug, name=body.name, description=body.description)
|
||||
session.add(domain)
|
||||
await session.commit()
|
||||
await session.refresh(domain)
|
||||
return domain
|
||||
|
||||
@router.get("/{slug}", response_model=domain_detail_schema)
|
||||
async def get_domain(
|
||||
slug: str,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
domain = await _get_domain_by_slug(slug, session, domain_model)
|
||||
if detail_builder is not None:
|
||||
return await detail_builder(domain, session)
|
||||
return await _build_default_domain_detail(
|
||||
domain,
|
||||
session,
|
||||
repo_model=repo_model,
|
||||
repo_stub_schema=repo_stub_schema,
|
||||
domain_detail_schema=domain_detail_schema,
|
||||
)
|
||||
|
||||
if include_update_route:
|
||||
@router.patch("/{slug}", response_model=domain_read_schema)
|
||||
async def update_domain(
|
||||
slug: str,
|
||||
body: domain_update_schema,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
domain = await _get_domain_by_slug(slug, session, domain_model)
|
||||
for field, value in body.model_dump(exclude_unset=True).items():
|
||||
setattr(domain, field, value)
|
||||
await session.commit()
|
||||
await session.refresh(domain)
|
||||
return domain
|
||||
|
||||
@router.patch("/{slug}/rename", response_model=domain_read_schema)
|
||||
async def rename_domain(
|
||||
slug: str,
|
||||
body: domain_rename_schema,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
domain = await _get_domain_by_slug(slug, session, domain_model)
|
||||
if body.new_slug != slug:
|
||||
conflict = await session.execute(select(domain_model).where(domain_model.slug == body.new_slug))
|
||||
if conflict.scalar_one_or_none():
|
||||
raise HTTPException(status_code=409, detail=f"Slug '{body.new_slug}' already taken")
|
||||
domain.slug = body.new_slug
|
||||
domain.name = body.new_name
|
||||
await session.commit()
|
||||
await session.refresh(domain)
|
||||
return domain
|
||||
|
||||
@router.patch("/{slug}/archive", response_model=domain_read_schema)
|
||||
async def archive_domain(
|
||||
slug: str,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
domain = await _get_domain_by_slug(slug, session, domain_model)
|
||||
if before_archive is not None:
|
||||
await before_archive(domain, session)
|
||||
domain.status = "archived"
|
||||
await session.commit()
|
||||
await session.refresh(domain)
|
||||
return domain
|
||||
|
||||
return router
|
||||
|
||||
|
||||
async def _get_domain_by_slug(
|
||||
slug: str,
|
||||
session: AsyncSession,
|
||||
domain_model: type[Domain],
|
||||
) -> Any:
|
||||
result = await session.execute(select(domain_model).where(domain_model.slug == slug))
|
||||
domain = result.scalar_one_or_none()
|
||||
if domain is None:
|
||||
raise HTTPException(status_code=404, detail=f"Domain '{slug}' not found")
|
||||
return domain
|
||||
|
||||
|
||||
async def _build_default_domain_detail(
|
||||
domain: Any,
|
||||
session: AsyncSession,
|
||||
*,
|
||||
repo_model: type[ManagedRepo],
|
||||
repo_stub_schema: type[RepoStub],
|
||||
domain_detail_schema: type[DomainDetail],
|
||||
) -> Any:
|
||||
repo_count_row = await session.execute(
|
||||
select(func.count()).select_from(repo_model)
|
||||
.where(repo_model.domain_id == domain.id)
|
||||
.where(repo_model.status == "active")
|
||||
)
|
||||
repos_row = await session.execute(
|
||||
select(repo_model)
|
||||
.where(repo_model.domain_id == domain.id)
|
||||
.where(repo_model.status == "active")
|
||||
.order_by(repo_model.name)
|
||||
)
|
||||
repos = list(repos_row.scalars().all())
|
||||
|
||||
return domain_detail_schema(
|
||||
id=domain.id,
|
||||
slug=domain.slug,
|
||||
name=domain.name,
|
||||
description=domain.description,
|
||||
status=domain.status,
|
||||
created_at=domain.created_at,
|
||||
updated_at=domain.updated_at,
|
||||
repos=[repo_stub_schema.model_validate(repo) for repo in repos],
|
||||
extension_counts={"repos": repo_count_row.scalar_one()},
|
||||
)
|
||||
121
hub_core/routers/messages.py
Normal file
121
hub_core/routers/messages.py
Normal file
@@ -0,0 +1,121 @@
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from hub_core.models.agent_message import AgentMessage
|
||||
from hub_core.schemas.agent_message import MessageCreate, MessageRead, MessageReply
|
||||
|
||||
|
||||
def create_messages_router(
|
||||
get_session: Callable[..., AsyncSession],
|
||||
*,
|
||||
message_model: type[AgentMessage] = AgentMessage,
|
||||
) -> APIRouter:
|
||||
router = APIRouter(prefix="/messages", tags=["messages"])
|
||||
|
||||
async def _get_message(message_id: uuid.UUID, session: AsyncSession) -> Any:
|
||||
msg = await session.get(message_model, message_id)
|
||||
if msg is None:
|
||||
raise HTTPException(status_code=404, detail=f"Message {message_id} not found")
|
||||
return msg
|
||||
|
||||
@router.post("/", response_model=MessageRead, status_code=status.HTTP_201_CREATED)
|
||||
async def send_message(
|
||||
body: MessageCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
if body.thread_id:
|
||||
root = await session.get(message_model, body.thread_id)
|
||||
if root is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread root {body.thread_id} not found")
|
||||
msg = message_model(**body.model_dump())
|
||||
session.add(msg)
|
||||
await session.commit()
|
||||
await session.refresh(msg)
|
||||
return msg
|
||||
|
||||
@router.get("/", response_model=list[MessageRead])
|
||||
async def list_messages(
|
||||
to_agent: str | None = None,
|
||||
from_agent: str | None = None,
|
||||
unread_only: bool = False,
|
||||
limit: int = 50,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[Any]:
|
||||
q = select(message_model).where(message_model.archived_at.is_(None))
|
||||
if to_agent:
|
||||
q = q.where(
|
||||
(message_model.to_agent == to_agent) | (message_model.to_agent == "broadcast")
|
||||
)
|
||||
if from_agent:
|
||||
q = q.where(message_model.from_agent == from_agent)
|
||||
if unread_only:
|
||||
q = q.where(message_model.read_at.is_(None))
|
||||
q = q.order_by(message_model.created_at.desc()).limit(limit)
|
||||
result = await session.execute(q)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@router.get("/thread/{thread_id}", response_model=list[MessageRead])
|
||||
async def get_thread(
|
||||
thread_id: uuid.UUID,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[Any]:
|
||||
q = select(message_model).where(
|
||||
(message_model.id == thread_id) | (message_model.thread_id == thread_id)
|
||||
).order_by(message_model.created_at)
|
||||
result = await session.execute(q)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@router.patch("/{message_id}/read", response_model=MessageRead)
|
||||
async def mark_read(
|
||||
message_id: uuid.UUID,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
msg = await _get_message(message_id, session)
|
||||
if msg.read_at is None:
|
||||
msg.read_at = datetime.now(timezone.utc)
|
||||
await session.commit()
|
||||
await session.refresh(msg)
|
||||
return msg
|
||||
|
||||
@router.patch("/{message_id}/archive", response_model=MessageRead)
|
||||
async def archive_message(
|
||||
message_id: uuid.UUID,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
msg = await _get_message(message_id, session)
|
||||
msg.archived_at = datetime.now(timezone.utc)
|
||||
if msg.read_at is None:
|
||||
msg.read_at = msg.archived_at
|
||||
await session.commit()
|
||||
await session.refresh(msg)
|
||||
return msg
|
||||
|
||||
@router.post("/{message_id}/reply", response_model=MessageRead, status_code=status.HTTP_201_CREATED)
|
||||
async def reply_to_message(
|
||||
message_id: uuid.UUID,
|
||||
body: MessageReply,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
original = await _get_message(message_id, session)
|
||||
if original.read_at is None:
|
||||
original.read_at = datetime.now(timezone.utc)
|
||||
thread_root = original.thread_id or original.id
|
||||
reply = message_model(
|
||||
from_agent=body.from_agent,
|
||||
to_agent=original.from_agent,
|
||||
subject=f"Re: {original.subject}",
|
||||
body=body.body,
|
||||
thread_id=thread_root,
|
||||
)
|
||||
session.add(reply)
|
||||
await session.commit()
|
||||
await session.refresh(reply)
|
||||
return reply
|
||||
|
||||
return router
|
||||
30
hub_core/routers/policy.py
Normal file
30
hub_core/routers/policy.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from hub_core.schemas.policy import PolicyRead, PolicyUpdate
|
||||
|
||||
PolicyLoader = Callable[[str], PolicyRead | None]
|
||||
PolicyUpdater = Callable[[str, str], PolicyRead]
|
||||
|
||||
|
||||
def create_policy_router(
|
||||
load_policy: PolicyLoader,
|
||||
update_policy: PolicyUpdater | None = None,
|
||||
) -> APIRouter:
|
||||
router = APIRouter(prefix="/policy", tags=["policy"])
|
||||
|
||||
@router.get("/{name}", response_model=PolicyRead)
|
||||
def get_policy(name: str) -> PolicyRead:
|
||||
policy = load_policy(name)
|
||||
if policy is None:
|
||||
raise HTTPException(status_code=404, detail=f"Policy '{name}' not found")
|
||||
return policy
|
||||
|
||||
if update_policy is not None:
|
||||
|
||||
@router.put("/{name}", response_model=PolicyRead)
|
||||
def put_policy(name: str, body: PolicyUpdate) -> PolicyRead:
|
||||
return update_policy(name, body.content)
|
||||
|
||||
return router
|
||||
130
hub_core/routers/progress.py
Normal file
130
hub_core/routers/progress.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import uuid
|
||||
from collections.abc import Callable, Collection
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from hub_core.events import ALERT_EVENT_TYPES, RISK_EVENT_TYPES
|
||||
from hub_core.models.progress_event import ProgressEvent
|
||||
from hub_core.schemas.progress_event import ProgressEventCreate, ProgressEventRead
|
||||
from hub_core.utils.pagination import PageParams, apply_pagination
|
||||
|
||||
|
||||
def create_progress_router(
|
||||
get_session: Callable[..., AsyncSession],
|
||||
*,
|
||||
progress_model: type[ProgressEvent] = ProgressEvent,
|
||||
progress_create_schema: type[ProgressEventCreate] = ProgressEventCreate,
|
||||
progress_read_schema: type[ProgressEventRead] = ProgressEventRead,
|
||||
) -> APIRouter:
|
||||
router = APIRouter(prefix="/progress", tags=["progress"])
|
||||
list_response_model = list[progress_read_schema]
|
||||
|
||||
async def _list_events(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
topic_id: uuid.UUID | None = None,
|
||||
workstream_id: uuid.UUID | None = None,
|
||||
task_id: uuid.UUID | None = None,
|
||||
decision_id: uuid.UUID | None = None,
|
||||
event_type: str | None = None,
|
||||
event_types: Collection[str] | None = None,
|
||||
since: datetime | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[Any]:
|
||||
q = select(progress_model)
|
||||
for field, value in (
|
||||
("topic_id", topic_id),
|
||||
("workstream_id", workstream_id),
|
||||
("task_id", task_id),
|
||||
("decision_id", decision_id),
|
||||
):
|
||||
if value is not None:
|
||||
column = getattr(progress_model, field, None)
|
||||
if column is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Progress events do not support filtering by {field}",
|
||||
)
|
||||
q = q.where(column == value)
|
||||
if event_type:
|
||||
q = q.where(progress_model.event_type == event_type)
|
||||
if event_types is not None:
|
||||
q = q.where(progress_model.event_type.in_(sorted(event_types)))
|
||||
if since:
|
||||
q = q.where(progress_model.created_at >= since)
|
||||
q = q.order_by(progress_model.created_at.desc())
|
||||
q = apply_pagination(q, PageParams(limit=limit, offset=offset))
|
||||
result = await session.execute(q)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@router.get("/", response_model=list_response_model)
|
||||
async def list_progress(
|
||||
topic_id: uuid.UUID | None = None,
|
||||
workstream_id: uuid.UUID | None = None,
|
||||
task_id: uuid.UUID | None = None,
|
||||
decision_id: uuid.UUID | None = None,
|
||||
event_type: str | None = None,
|
||||
since: datetime | None = None,
|
||||
limit: int = Query(100, le=1000),
|
||||
offset: int = Query(0, ge=0),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[Any]:
|
||||
return await _list_events(
|
||||
session,
|
||||
topic_id=topic_id,
|
||||
workstream_id=workstream_id,
|
||||
task_id=task_id,
|
||||
decision_id=decision_id,
|
||||
event_type=event_type,
|
||||
since=since,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
@router.get("/risks", response_model=list_response_model)
|
||||
async def get_risks(
|
||||
since: datetime | None = None,
|
||||
limit: int = Query(100, le=1000),
|
||||
offset: int = Query(0, ge=0),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[Any]:
|
||||
return await _list_events(
|
||||
session,
|
||||
event_types=RISK_EVENT_TYPES,
|
||||
since=since,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
@router.get("/alerts", response_model=list_response_model)
|
||||
async def get_alerts(
|
||||
since: datetime | None = None,
|
||||
limit: int = Query(100, le=1000),
|
||||
offset: int = Query(0, ge=0),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[Any]:
|
||||
return await _list_events(
|
||||
session,
|
||||
event_types=ALERT_EVENT_TYPES,
|
||||
since=since,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
@router.post("/", response_model=progress_read_schema, status_code=status.HTTP_201_CREATED)
|
||||
async def append_progress(
|
||||
body: progress_create_schema,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
event = progress_model(**body.model_dump())
|
||||
session.add(event)
|
||||
await session.commit()
|
||||
await session.refresh(event)
|
||||
return event
|
||||
|
||||
return router
|
||||
173
hub_core/routers/repos.py
Normal file
173
hub_core/routers/repos.py
Normal file
@@ -0,0 +1,173 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Awaitable
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import noload
|
||||
|
||||
from hub_core.models.domain import Domain
|
||||
from hub_core.models.managed_repo import ManagedRepo
|
||||
from hub_core.schemas.managed_repo import RepoCreate, RepoPathRegister, RepoRead, RepoUpdate
|
||||
|
||||
|
||||
RepoRegisteredHook = Callable[[Any, Any, Any], Awaitable[None] | None]
|
||||
|
||||
|
||||
def create_repos_router(
|
||||
get_session: Callable[..., AsyncSession],
|
||||
*,
|
||||
prefix: str = "/repos",
|
||||
domain_model: type[Domain] = Domain,
|
||||
repo_model: type[ManagedRepo] = ManagedRepo,
|
||||
repo_create_schema: type[RepoCreate] = RepoCreate,
|
||||
repo_update_schema: type[RepoUpdate] = RepoUpdate,
|
||||
repo_read_schema: type[RepoRead] = RepoRead,
|
||||
repo_path_register_schema: type[RepoPathRegister] = RepoPathRegister,
|
||||
list_noload_fields: tuple[str, ...] = ("domain",),
|
||||
create_extension_fields: tuple[str, ...] = (),
|
||||
after_register: RepoRegisteredHook | None = None,
|
||||
include_collection_routes: bool = True,
|
||||
include_lookup_routes: bool = True,
|
||||
include_slug_routes: bool = True,
|
||||
) -> APIRouter:
|
||||
router = APIRouter(prefix=prefix, tags=["repos"])
|
||||
list_response_model = list[repo_read_schema]
|
||||
|
||||
if include_collection_routes:
|
||||
@router.get("/", response_model=list_response_model)
|
||||
async def list_repos(
|
||||
response: Response,
|
||||
domain: str | None = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[Any]:
|
||||
response.headers["Cache-Control"] = "max-age=60, stale-while-revalidate=30"
|
||||
q = select(repo_model).options(
|
||||
*[
|
||||
noload(getattr(repo_model, field))
|
||||
for field in list_noload_fields
|
||||
if hasattr(repo_model, field)
|
||||
]
|
||||
).order_by(repo_model.name)
|
||||
if domain:
|
||||
domain_obj = await _get_domain_by_slug(domain, session, domain_model)
|
||||
q = q.where(repo_model.domain_id == domain_obj.id)
|
||||
result = await session.execute(q)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@router.post("/", response_model=repo_read_schema, status_code=status.HTTP_201_CREATED)
|
||||
async def register_repo(
|
||||
body: repo_create_schema,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
domain_obj = await _get_domain_by_slug(body.domain_slug, session, domain_model)
|
||||
existing = await session.execute(select(repo_model).where(repo_model.slug == body.slug))
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=409, detail=f"Repo slug '{body.slug}' already exists")
|
||||
repo_attrs = {
|
||||
"domain_id": domain_obj.id,
|
||||
"slug": body.slug,
|
||||
"name": body.name,
|
||||
"local_path": body.local_path,
|
||||
"host_paths": body.host_paths,
|
||||
"remote_url": body.remote_url,
|
||||
"git_fingerprint": body.git_fingerprint,
|
||||
"description": body.description,
|
||||
}
|
||||
for field in create_extension_fields:
|
||||
if hasattr(body, field) and hasattr(repo_model, field):
|
||||
repo_attrs[field] = getattr(body, field)
|
||||
repo = repo_model(**repo_attrs)
|
||||
session.add(repo)
|
||||
await session.commit()
|
||||
await session.refresh(repo)
|
||||
if after_register is not None:
|
||||
hook_result = after_register(repo, body, domain_obj)
|
||||
if hook_result is not None:
|
||||
await hook_result
|
||||
return repo
|
||||
|
||||
if include_lookup_routes:
|
||||
@router.get("/by-fingerprint", response_model=list_response_model)
|
||||
async def get_repo_by_fingerprint(
|
||||
hash: str,
|
||||
remote_url: str | None = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[Any]:
|
||||
q = select(repo_model).where(repo_model.git_fingerprint == hash)
|
||||
if remote_url:
|
||||
q = q.where(repo_model.remote_url == remote_url)
|
||||
result = await session.execute(q)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@router.get("/by-remote", response_model=repo_read_schema)
|
||||
async def get_repo_by_remote_url(
|
||||
url: str,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
result = await session.execute(select(repo_model).where(repo_model.remote_url == url))
|
||||
repo = result.scalar_one_or_none()
|
||||
if repo is None:
|
||||
raise HTTPException(status_code=404, detail=f"No repo with remote_url '{url}' found")
|
||||
return repo
|
||||
|
||||
if include_slug_routes:
|
||||
@router.get("/{slug}", response_model=repo_read_schema)
|
||||
async def get_repo(
|
||||
slug: str,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
return await _get_repo_by_slug(slug, session, repo_model)
|
||||
|
||||
@router.patch("/{slug}", response_model=repo_read_schema)
|
||||
async def update_repo(
|
||||
slug: str,
|
||||
body: repo_update_schema,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
repo = await _get_repo_by_slug(slug, session, repo_model)
|
||||
for field, value in body.model_dump(exclude_unset=True).items():
|
||||
setattr(repo, field, value)
|
||||
await session.commit()
|
||||
await session.refresh(repo)
|
||||
return repo
|
||||
|
||||
@router.post("/{slug}/paths", response_model=repo_read_schema)
|
||||
async def register_repo_path(
|
||||
slug: str,
|
||||
body: repo_path_register_schema,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
repo = await _get_repo_by_slug(slug, session, repo_model)
|
||||
host_paths = dict(repo.host_paths or {})
|
||||
host_paths[body.host] = body.path
|
||||
repo.host_paths = host_paths
|
||||
await session.commit()
|
||||
await session.refresh(repo)
|
||||
return repo
|
||||
|
||||
return router
|
||||
|
||||
|
||||
async def _get_domain_by_slug(
|
||||
slug: str,
|
||||
session: AsyncSession,
|
||||
domain_model: type[Domain],
|
||||
) -> Any:
|
||||
result = await session.execute(select(domain_model).where(domain_model.slug == slug))
|
||||
domain = result.scalar_one_or_none()
|
||||
if domain is None:
|
||||
raise HTTPException(status_code=404, detail=f"Domain '{slug}' not found")
|
||||
return domain
|
||||
|
||||
|
||||
async def _get_repo_by_slug(
|
||||
slug: str,
|
||||
session: AsyncSession,
|
||||
repo_model: type[ManagedRepo],
|
||||
) -> Any:
|
||||
result = await session.execute(select(repo_model).where(repo_model.slug == slug))
|
||||
repo = result.scalar_one_or_none()
|
||||
if repo is None:
|
||||
raise HTTPException(status_code=404, detail=f"Repo '{slug}' not found")
|
||||
return repo
|
||||
240
hub_core/routers/tpsc.py
Normal file
240
hub_core/routers/tpsc.py
Normal file
@@ -0,0 +1,240 @@
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from hub_core.models.managed_repo import ManagedRepo
|
||||
from hub_core.models.tpsc import TPSCCatalog, TPSCEntry, TPSCSnapshot
|
||||
from hub_core.schemas.tpsc import (
|
||||
GDPR_WARNING_LEVELS,
|
||||
TPSCCatalogCreate,
|
||||
TPSCCatalogRead,
|
||||
TPSCEntryRead,
|
||||
TPSCGDPRReport,
|
||||
TPSCGDPRWarning,
|
||||
TPSCIngestRequest,
|
||||
TPSCSnapshotRead,
|
||||
)
|
||||
|
||||
|
||||
def create_tpsc_router(
|
||||
get_session: Callable[..., AsyncSession],
|
||||
*,
|
||||
repo_model: type[ManagedRepo] = ManagedRepo,
|
||||
catalog_model: type[TPSCCatalog] = TPSCCatalog,
|
||||
snapshot_model: type[TPSCSnapshot] = TPSCSnapshot,
|
||||
entry_model: type[TPSCEntry] = TPSCEntry,
|
||||
) -> APIRouter:
|
||||
router = APIRouter(prefix="/tpsc", tags=["tpsc"])
|
||||
|
||||
@router.get("/catalog/", response_model=list[TPSCCatalogRead])
|
||||
async def list_catalog(
|
||||
gdpr_maturity: str | None = None,
|
||||
category: str | None = None,
|
||||
pricing_model: str | None = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[Any]:
|
||||
q = select(catalog_model).where(catalog_model.status != "deprecated")
|
||||
if gdpr_maturity:
|
||||
q = q.where(catalog_model.gdpr_maturity == gdpr_maturity)
|
||||
if category:
|
||||
q = q.where(catalog_model.category == category)
|
||||
if pricing_model:
|
||||
q = q.where(catalog_model.pricing_model == pricing_model)
|
||||
q = q.order_by(catalog_model.name)
|
||||
result = await session.execute(q)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@router.get("/catalog/{slug}", response_model=TPSCCatalogRead)
|
||||
async def get_catalog_entry(
|
||||
slug: str,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
row = (
|
||||
await session.execute(select(catalog_model).where(catalog_model.slug == slug))
|
||||
).scalar_one_or_none()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=404, detail=f"Service '{slug}' not found in catalog")
|
||||
return row
|
||||
|
||||
@router.post("/catalog/", response_model=TPSCCatalogRead, status_code=status.HTTP_201_CREATED)
|
||||
async def register_service(
|
||||
body: TPSCCatalogCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> Any:
|
||||
existing = (
|
||||
await session.execute(select(catalog_model).where(catalog_model.slug == body.slug))
|
||||
).scalar_one_or_none()
|
||||
if existing:
|
||||
for field, value in body.model_dump(exclude_unset=True).items():
|
||||
setattr(existing, field, value)
|
||||
existing.updated_at = datetime.now(tz=timezone.utc)
|
||||
await session.commit()
|
||||
await session.refresh(existing)
|
||||
return existing
|
||||
entry = catalog_model(**body.model_dump())
|
||||
session.add(entry)
|
||||
await session.commit()
|
||||
await session.refresh(entry)
|
||||
return entry
|
||||
|
||||
@router.post("/ingest/", response_model=TPSCSnapshotRead, status_code=status.HTTP_201_CREATED)
|
||||
async def ingest_tpsc(
|
||||
body: TPSCIngestRequest,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> TPSCSnapshotRead:
|
||||
repo = (
|
||||
await session.execute(select(repo_model).where(repo_model.slug == body.repo_slug))
|
||||
).scalar_one_or_none()
|
||||
repo_id = repo.id if repo else None
|
||||
slugs = {entry.service_slug for entry in body.entries}
|
||||
catalog_rows = []
|
||||
if slugs:
|
||||
catalog_rows = (
|
||||
await session.execute(select(catalog_model).where(catalog_model.slug.in_(slugs)))
|
||||
).scalars().all()
|
||||
catalog_map = {row.slug: row for row in catalog_rows}
|
||||
|
||||
snapshot = snapshot_model(
|
||||
repo_id=repo_id,
|
||||
source_file=body.source_file,
|
||||
entry_count=len(body.entries),
|
||||
)
|
||||
session.add(snapshot)
|
||||
await session.flush()
|
||||
|
||||
entries_with_catalogs = []
|
||||
for body_entry in body.entries:
|
||||
catalog_entry = catalog_map.get(body_entry.service_slug)
|
||||
entry = entry_model(
|
||||
snapshot_id=snapshot.id,
|
||||
catalog_id=catalog_entry.id if catalog_entry else None,
|
||||
**body_entry.model_dump(),
|
||||
)
|
||||
session.add(entry)
|
||||
entries_with_catalogs.append((entry, catalog_entry))
|
||||
|
||||
await session.flush()
|
||||
await session.commit()
|
||||
await session.refresh(snapshot)
|
||||
|
||||
return TPSCSnapshotRead(
|
||||
id=snapshot.id,
|
||||
repo_id=snapshot.repo_id,
|
||||
snapshot_at=snapshot.snapshot_at,
|
||||
source_file=snapshot.source_file,
|
||||
entry_count=snapshot.entry_count,
|
||||
entries=[
|
||||
_entry_read(entry, catalog_entry)
|
||||
for entry, catalog_entry in entries_with_catalogs
|
||||
],
|
||||
)
|
||||
|
||||
@router.get("/snapshots/", response_model=list[TPSCSnapshotRead])
|
||||
async def list_snapshots(
|
||||
repo_slug: str | None = None,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[TPSCSnapshotRead]:
|
||||
q = select(snapshot_model).options(
|
||||
selectinload(snapshot_model.entries).selectinload(entry_model.catalog_entry)
|
||||
)
|
||||
if repo_slug:
|
||||
repo = (
|
||||
await session.execute(select(repo_model).where(repo_model.slug == repo_slug))
|
||||
).scalar_one_or_none()
|
||||
if repo is None:
|
||||
raise HTTPException(status_code=404, detail=f"Repo '{repo_slug}' not found")
|
||||
q = q.where(snapshot_model.repo_id == repo.id)
|
||||
q = q.order_by(snapshot_model.snapshot_at.desc())
|
||||
rows = (await session.execute(q)).scalars().all()
|
||||
return [_snapshot_read(row) for row in rows]
|
||||
|
||||
@router.get("/report/gdpr", response_model=TPSCGDPRReport)
|
||||
async def gdpr_report(
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> TPSCGDPRReport:
|
||||
latest_sub = (
|
||||
select(snapshot_model.repo_id, func.max(snapshot_model.snapshot_at).label("max_at"))
|
||||
.group_by(snapshot_model.repo_id)
|
||||
.subquery()
|
||||
)
|
||||
latest_snaps = (
|
||||
await session.execute(
|
||||
select(snapshot_model)
|
||||
.join(
|
||||
latest_sub,
|
||||
(snapshot_model.repo_id == latest_sub.c.repo_id)
|
||||
& (snapshot_model.snapshot_at == latest_sub.c.max_at),
|
||||
)
|
||||
.options(selectinload(snapshot_model.entries).selectinload(entry_model.catalog_entry))
|
||||
)
|
||||
).scalars().all()
|
||||
all_repos = (await session.execute(select(repo_model))).scalars().all()
|
||||
repo_map = {repo.id: repo.slug for repo in all_repos}
|
||||
all_services = (await session.execute(select(catalog_model))).scalars().all()
|
||||
by_maturity: dict[str, int] = {}
|
||||
for service in all_services:
|
||||
by_maturity[service.gdpr_maturity] = by_maturity.get(service.gdpr_maturity, 0) + 1
|
||||
|
||||
warnings = []
|
||||
seen = set()
|
||||
for snap in latest_snaps:
|
||||
repo_slug = repo_map.get(snap.repo_id) if snap.repo_id else None
|
||||
for entry in snap.entries:
|
||||
catalog_entry = entry.catalog_entry
|
||||
maturity = catalog_entry.gdpr_maturity if catalog_entry else "unknown"
|
||||
if maturity not in GDPR_WARNING_LEVELS:
|
||||
continue
|
||||
key = (repo_slug, entry.service_slug)
|
||||
if key in seen:
|
||||
continue
|
||||
seen.add(key)
|
||||
warnings.append(
|
||||
TPSCGDPRWarning(
|
||||
repo_slug=repo_slug,
|
||||
service_slug=entry.service_slug,
|
||||
gdpr_maturity=maturity,
|
||||
purpose=entry.purpose,
|
||||
pricing_model=catalog_entry.pricing_model if catalog_entry else None,
|
||||
)
|
||||
)
|
||||
return TPSCGDPRReport(
|
||||
generated_at=datetime.now(tz=timezone.utc),
|
||||
total_services=len(all_services),
|
||||
warning_count=len(warnings),
|
||||
warnings=warnings,
|
||||
by_maturity=by_maturity,
|
||||
)
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def _entry_read(entry: TPSCEntry, catalog_entry: TPSCCatalog | None) -> TPSCEntryRead:
|
||||
return TPSCEntryRead(
|
||||
id=entry.id,
|
||||
snapshot_id=entry.snapshot_id,
|
||||
catalog_id=entry.catalog_id,
|
||||
service_slug=entry.service_slug,
|
||||
purpose=entry.purpose,
|
||||
auth_type=entry.auth_type,
|
||||
endpoint_override=entry.endpoint_override,
|
||||
notes=entry.notes,
|
||||
gdpr_maturity=catalog_entry.gdpr_maturity if catalog_entry else None,
|
||||
gdpr_warning=(catalog_entry.gdpr_maturity in GDPR_WARNING_LEVELS) if catalog_entry else True,
|
||||
pricing_model=catalog_entry.pricing_model if catalog_entry else None,
|
||||
)
|
||||
|
||||
|
||||
def _snapshot_read(snapshot: TPSCSnapshot) -> TPSCSnapshotRead:
|
||||
return TPSCSnapshotRead(
|
||||
id=snapshot.id,
|
||||
repo_id=snapshot.repo_id,
|
||||
snapshot_at=snapshot.snapshot_at,
|
||||
source_file=snapshot.source_file,
|
||||
entry_count=snapshot.entry_count,
|
||||
entries=[_entry_read(entry, entry.catalog_entry) for entry in snapshot.entries],
|
||||
)
|
||||
Reference in New Issue
Block a user