"""Factory for a configurable FastAPI router exposing domain management endpoints."""

from collections.abc import Callable
from typing import Any, Awaitable

from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import noload

from hub_core.models.domain import Domain
from hub_core.models.managed_repo import ManagedRepo
from hub_core.schemas.domain import (
    DomainCreate,
    DomainDetail,
    DomainRead,
    DomainRename,
    DomainUpdate,
    RepoStub,
)

# Async hook that turns a loaded domain row into the payload returned by GET /{slug}.
DomainDetailBuilder = Callable[[Any, AsyncSession], Awaitable[Any]]
# Async hook awaited before a domain is archived; it may raise to veto the operation.
DomainArchiveValidator = Callable[[Any, AsyncSession], Awaitable[None]]


def create_domains_router(
    get_session: Callable[..., AsyncSession],
    *,
    domain_model: type[Domain] = Domain,
    repo_model: type[ManagedRepo] = ManagedRepo,
    domain_create_schema: type[DomainCreate] = DomainCreate,
    domain_detail_schema: type[DomainDetail] = DomainDetail,
    domain_read_schema: type[DomainRead] = DomainRead,
    domain_rename_schema: type[DomainRename] = DomainRename,
    domain_update_schema: type[DomainUpdate] = DomainUpdate,
    repo_stub_schema: type[RepoStub] = RepoStub,
    detail_builder: DomainDetailBuilder | None = None,
    before_archive: DomainArchiveValidator | None = None,
    list_noload_fields: tuple[str, ...] = ("repos",),
    include_update_route: bool = True,
) -> APIRouter:
    """Build an ``APIRouter`` mounted at ``/domains``.

    Args:
        get_session: FastAPI dependency that provides the ``AsyncSession``.
        domain_model: ORM class queried and instantiated for domains.
        repo_model: ORM class queried by the default detail builder.
        domain_create_schema: Request body type for ``POST /``.
        domain_detail_schema: Response model for ``GET /{slug}``.
        domain_read_schema: Response model for list/create/update/rename/archive.
        domain_rename_schema: Request body type for ``PATCH /{slug}/rename``.
        domain_update_schema: Request body type for ``PATCH /{slug}``.
        repo_stub_schema: Schema used for each repo in the default detail payload.
        detail_builder: Optional replacement for the default detail construction.
        before_archive: Optional hook awaited before a domain is marked archived.
        list_noload_fields: Attribute names on ``domain_model`` to ``noload`` in
            the list query; names the model does not have are skipped.
        include_update_route: When false, ``PATCH /{slug}`` is not registered.

    Returns:
        The configured router.
    """
    router = APIRouter(prefix="/domains", tags=["domains"])
    list_response_model = list[domain_read_schema]

    # NOTE: route handlers are documented with comments rather than docstrings
    # because FastAPI publishes handler docstrings in the OpenAPI description.

    # GET / -- list domains ordered by name. Without ``status`` only "active"
    # rows are returned; ``status=all`` disables the filter; any other non-empty
    # value is matched against the status column verbatim.
    @router.get("/", response_model=list_response_model)
    async def list_domains(
        response: Response,
        status_filter: str | None = Query(
            None, alias="status", description="active | archived | all"
        ),
        session: AsyncSession = Depends(get_session),
    ) -> list[Any]:
        response.headers["Cache-Control"] = "max-age=60, stale-while-revalidate=30"
        q = select(domain_model).options(
            *[
                noload(getattr(domain_model, field))
                for field in list_noload_fields
                if hasattr(domain_model, field)
            ]
        ).order_by(domain_model.name)
        if status_filter and status_filter != "all":
            q = q.where(domain_model.status == status_filter)
        elif status_filter is None:
            q = q.where(domain_model.status == "active")
        result = await session.execute(q)
        return list(result.scalars().all())

    # POST / -- create a domain; responds 409 when the slug is already in use.
    @router.post("/", response_model=domain_read_schema, status_code=status.HTTP_201_CREATED)
    async def create_domain(
        body: domain_create_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        conflict_detail = f"Domain slug '{body.slug}' already exists"
        existing = await session.execute(
            select(domain_model).where(domain_model.slug == body.slug)
        )
        if existing.scalar_one_or_none():
            raise HTTPException(status_code=409, detail=conflict_detail)
        domain = domain_model(slug=body.slug, name=body.name, description=body.description)
        session.add(domain)
        # The pre-check above cannot stop a concurrent request from inserting the
        # same slug first, so a constraint violation at commit is also a conflict.
        await _commit_or_conflict(session, conflict_detail)
        await session.refresh(domain)
        return domain

    # GET /{slug} -- return the detail payload, built by ``detail_builder`` when
    # one was supplied and by the default builder otherwise.
    @router.get("/{slug}", response_model=domain_detail_schema)
    async def get_domain(
        slug: str,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        domain = await _get_domain_by_slug(slug, session, domain_model)
        if detail_builder is not None:
            return await detail_builder(domain, session)
        return await _build_default_domain_detail(
            domain,
            session,
            repo_model=repo_model,
            repo_stub_schema=repo_stub_schema,
            domain_detail_schema=domain_detail_schema,
        )

    if include_update_route:

        # PATCH /{slug} -- apply only the fields the client actually sent.
        @router.patch("/{slug}", response_model=domain_read_schema)
        async def update_domain(
            slug: str,
            body: domain_update_schema,
            session: AsyncSession = Depends(get_session),
        ) -> Any:
            domain = await _get_domain_by_slug(slug, session, domain_model)
            for field, value in body.model_dump(exclude_unset=True).items():
                setattr(domain, field, value)
            await session.commit()
            await session.refresh(domain)
            return domain

    # PATCH /{slug}/rename -- set a new name and, if different, a new slug;
    # responds 409 when the new slug belongs to another domain.
    @router.patch("/{slug}/rename", response_model=domain_read_schema)
    async def rename_domain(
        slug: str,
        body: domain_rename_schema,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        domain = await _get_domain_by_slug(slug, session, domain_model)
        slug_changed = body.new_slug != slug
        conflict_detail = f"Slug '{body.new_slug}' already taken"
        if slug_changed:
            conflict = await session.execute(
                select(domain_model).where(domain_model.slug == body.new_slug)
            )
            if conflict.scalar_one_or_none():
                raise HTTPException(status_code=409, detail=conflict_detail)
            domain.slug = body.new_slug
        domain.name = body.new_name
        if slug_changed:
            # Same race as in create_domain: another request may claim the slug
            # between the pre-check and this commit.
            await _commit_or_conflict(session, conflict_detail)
        else:
            await session.commit()
        await session.refresh(domain)
        return domain

    # PATCH /{slug}/archive -- run the optional ``before_archive`` hook, then set
    # the status to "archived".
    @router.patch("/{slug}/archive", response_model=domain_read_schema)
    async def archive_domain(
        slug: str,
        session: AsyncSession = Depends(get_session),
    ) -> Any:
        domain = await _get_domain_by_slug(slug, session, domain_model)
        if before_archive is not None:
            await before_archive(domain, session)
        domain.status = "archived"
        await session.commit()
        await session.refresh(domain)
        return domain

    return router


async def _commit_or_conflict(session: AsyncSession, detail: str) -> None:
    """Commit the session, reporting a constraint violation as HTTP 409.

    Args:
        session: Session holding the pending changes.
        detail: Message placed in the 409 response.

    Raises:
        HTTPException: 409 when the commit raises ``IntegrityError``. The session
            is rolled back first so it is left in a usable state.
    """
    # NOTE(review): presumably the slug column carries a unique constraint --
    # confirm against the model; without one this handler never triggers.
    try:
        await session.commit()
    except IntegrityError as exc:
        await session.rollback()
        raise HTTPException(status_code=409, detail=detail) from exc


async def _get_domain_by_slug(
    slug: str,
    session: AsyncSession,
    domain_model: type[Domain],
) -> Any:
    """Load the domain whose slug equals ``slug``.

    Raises:
        HTTPException: 404 when no row matches.
    """
    result = await session.execute(select(domain_model).where(domain_model.slug == slug))
    domain = result.scalar_one_or_none()
    if domain is None:
        raise HTTPException(status_code=404, detail=f"Domain '{slug}' not found")
    return domain


async def _build_default_domain_detail(
    domain: Any,
    session: AsyncSession,
    *,
    repo_model: type[ManagedRepo],
    repo_stub_schema: type[RepoStub],
    domain_detail_schema: type[DomainDetail],
) -> Any:
    """Assemble the default detail payload for ``domain``.

    Copies the domain's own columns and attaches its repos with status
    "active" (ordered by name) together with their count under
    ``extension_counts["repos"]``.
    """
    # NOTE(review): this count uses the same filter as the list query below, so
    # it should equal len(repos); kept as a separate query to leave the query
    # pattern unchanged.
    repo_count_row = await session.execute(
        select(func.count())
        .select_from(repo_model)
        .where(repo_model.domain_id == domain.id)
        .where(repo_model.status == "active")
    )
    repos_row = await session.execute(
        select(repo_model)
        .where(repo_model.domain_id == domain.id)
        .where(repo_model.status == "active")
        .order_by(repo_model.name)
    )
    repos = list(repos_row.scalars().all())
    return domain_detail_schema(
        id=domain.id,
        slug=domain.slug,
        name=domain.name,
        description=domain.description,
        status=domain.status,
        created_at=domain.created_at,
        updated_at=domain.updated_at,
        repos=[repo_stub_schema.model_validate(repo) for repo in repos],
        extension_counts={"repos": repo_count_row.scalar_one()},
    )