"""FastAPI router factory for managed-repo endpoints."""

import inspect
from collections.abc import Callable
from typing import Any, Awaitable

from fastapi import APIRouter, Depends, HTTPException, Response, status
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import noload

from hub_core.models.domain import Domain
from hub_core.models.managed_repo import ManagedRepo
from hub_core.schemas.managed_repo import RepoCreate, RepoPathRegister, RepoRead, RepoUpdate

# Hook invoked as hook(repo, body, domain) after a repo is registered and committed.
# It may be a plain function returning None or an async function (awaitable result).
RepoRegisteredHook = Callable[[Any, Any, Any], Awaitable[None] | None]


def create_repos_router(
    get_session: Callable[..., AsyncSession],
    *,
    prefix: str = "/repos",
    domain_model: type[Domain] = Domain,
    repo_model: type[ManagedRepo] = ManagedRepo,
    repo_create_schema: type[RepoCreate] = RepoCreate,
    repo_update_schema: type[RepoUpdate] = RepoUpdate,
    repo_read_schema: type[RepoRead] = RepoRead,
    repo_path_register_schema: type[RepoPathRegister] = RepoPathRegister,
    list_noload_fields: tuple[str, ...] = ("domain",),
    create_extension_fields: tuple[str, ...] = (),
    after_register: RepoRegisteredHook | None = None,
    include_collection_routes: bool = True,
    include_lookup_routes: bool = True,
    include_slug_routes: bool = True,
) -> APIRouter:
    """Build an ``APIRouter`` exposing managed-repo endpoints.

    Args:
        get_session: FastAPI dependency (passed to ``Depends``) providing the
            ``AsyncSession`` used by every endpoint.
        prefix: URL prefix for all routes.
        domain_model: ORM model used to resolve domains by ``slug``.
        repo_model: ORM model for repos.
        repo_create_schema: Request body schema for ``POST /``.
        repo_update_schema: Request body schema for ``PATCH /{slug}``.
        repo_read_schema: Response schema for single-repo endpoints.
        repo_path_register_schema: Request body schema for ``POST /{slug}/paths``.
        list_noload_fields: Attribute names on ``repo_model`` that ``GET /``
            loads with ``noload()``. Names the model lacks are skipped.
        create_extension_fields: Extra body fields that ``POST /`` copies
            onto the new repo when both the body and the model have them.
        after_register: Optional hook called with ``(repo, body, domain)``
            after a repo is committed. An awaitable result is awaited.
        include_collection_routes: Register ``GET /`` and ``POST /``.
        include_lookup_routes: Register ``GET /by-fingerprint`` and ``GET /by-remote``.
        include_slug_routes: Register ``GET/PATCH /{slug}`` and ``POST /{slug}/paths``.

    Returns:
        The configured router.
    """
    router = APIRouter(prefix=prefix, tags=["repos"])
    list_response_model = list[repo_read_schema]

    if include_collection_routes:

        @router.get("/", response_model=list_response_model)
        async def list_repos(
            response: Response,
            domain: str | None = None,
            session: AsyncSession = Depends(get_session),
        ) -> list[Any]:
            # List repos ordered by name. A non-empty `domain` filters by domain
            # slug (404 if unknown). An empty string is treated as "no filter".
            response.headers["Cache-Control"] = "max-age=60, stale-while-revalidate=30"
            noload_opts = [
                noload(getattr(repo_model, field))
                for field in list_noload_fields
                if hasattr(repo_model, field)
            ]
            q = select(repo_model).options(*noload_opts).order_by(repo_model.name)
            if domain:
                domain_obj = await _get_domain_by_slug(domain, session, domain_model)
                q = q.where(repo_model.domain_id == domain_obj.id)
            result = await session.execute(q)
            return list(result.scalars().all())

        @router.post("/", response_model=repo_read_schema, status_code=status.HTTP_201_CREATED)
        async def register_repo(
            body: repo_create_schema,
            session: AsyncSession = Depends(get_session),
        ) -> Any:
            # Create a repo under the domain named by body.domain_slug.
            # 404 if that domain is unknown; 409 if the slug is already taken.
            domain_obj = await _get_domain_by_slug(body.domain_slug, session, domain_model)
            existing = await session.execute(
                select(repo_model).where(repo_model.slug == body.slug)
            )
            if existing.scalar_one_or_none() is not None:
                raise HTTPException(
                    status_code=409, detail=f"Repo slug '{body.slug}' already exists"
                )
            repo_attrs = {
                "domain_id": domain_obj.id,
                "slug": body.slug,
                "name": body.name,
                "local_path": body.local_path,
                "host_paths": body.host_paths,
                "remote_url": body.remote_url,
                "git_fingerprint": body.git_fingerprint,
                "description": body.description,
            }
            for field in create_extension_fields:
                if hasattr(body, field) and hasattr(repo_model, field):
                    repo_attrs[field] = getattr(body, field)
            repo = repo_model(**repo_attrs)
            session.add(repo)
            # The existence check above is not atomic with this insert, so a
            # concurrent request can claim the slug in between. If the DB enforces
            # a constraint (presumably a unique slug; confirm against the schema),
            # report it as 409 rather than an unhandled 500.
            await _commit_or_conflict(
                session, f"Repo '{body.slug}' conflicts with an existing repo"
            )
            await session.refresh(repo)
            if after_register is not None:
                # The hook runs after commit: the repo is persisted even if the hook fails.
                hook_result = after_register(repo, body, domain_obj)
                if inspect.isawaitable(hook_result):
                    await hook_result
            return repo

    # Lookup routes are registered before the "/{slug}" routes so that
    # "/by-fingerprint" and "/by-remote" are not captured as slugs.
    if include_lookup_routes:

        @router.get("/by-fingerprint", response_model=list_response_model)
        async def get_repo_by_fingerprint(
            hash: str,  # public query-parameter name; intentionally shadows the builtin
            remote_url: str | None = None,
            session: AsyncSession = Depends(get_session),
        ) -> list[Any]:
            # Return every repo with this git fingerprint, optionally narrowed by remote_url.
            q = select(repo_model).where(repo_model.git_fingerprint == hash)
            if remote_url:
                q = q.where(repo_model.remote_url == remote_url)
            result = await session.execute(q)
            return list(result.scalars().all())

        @router.get("/by-remote", response_model=repo_read_schema)
        async def get_repo_by_remote_url(
            url: str,
            session: AsyncSession = Depends(get_session),
        ) -> Any:
            # Return the single repo with this remote_url, or 404.
            result = await session.execute(select(repo_model).where(repo_model.remote_url == url))
            repo = result.scalar_one_or_none()
            if repo is None:
                raise HTTPException(
                    status_code=404, detail=f"No repo with remote_url '{url}' found"
                )
            return repo

    if include_slug_routes:

        @router.get("/{slug}", response_model=repo_read_schema)
        async def get_repo(
            slug: str,
            session: AsyncSession = Depends(get_session),
        ) -> Any:
            # Return the repo with this slug, or 404.
            return await _get_repo_by_slug(slug, session, repo_model)

        @router.patch("/{slug}", response_model=repo_read_schema)
        async def update_repo(
            slug: str,
            body: repo_update_schema,
            session: AsyncSession = Depends(get_session),
        ) -> Any:
            # Apply only the fields the client explicitly sent (exclude_unset).
            repo = await _get_repo_by_slug(slug, session, repo_model)
            for field, value in body.model_dump(exclude_unset=True).items():
                setattr(repo, field, value)
            # An update (e.g. a new slug) may collide with an existing row.
            await _commit_or_conflict(
                session, f"Update of repo '{slug}' conflicts with an existing repo"
            )
            await session.refresh(repo)
            return repo

        @router.post("/{slug}/paths", response_model=repo_read_schema)
        async def register_repo_path(
            slug: str,
            body: repo_path_register_schema,
            session: AsyncSession = Depends(get_session),
        ) -> Any:
            # Set host_paths[body.host] = body.path, keeping other hosts' entries.
            repo = await _get_repo_by_slug(slug, session, repo_model)
            # Assign a new dict instead of mutating in place, so the ORM sees the
            # attribute as changed.
            host_paths = dict(repo.host_paths or {})
            host_paths[body.host] = body.path
            repo.host_paths = host_paths
            await session.commit()
            await session.refresh(repo)
            return repo

    return router


async def _commit_or_conflict(session: AsyncSession, detail: str) -> None:
    """Commit ``session``; on ``IntegrityError`` roll back and raise HTTP 409.

    Raises:
        HTTPException: 409 with ``detail`` when the commit violates a constraint.
    """
    try:
        await session.commit()
    except IntegrityError as exc:
        await session.rollback()
        raise HTTPException(status_code=409, detail=detail) from exc


async def _get_domain_by_slug(
    slug: str,
    session: AsyncSession,
    domain_model: type[Domain],
) -> Any:
    """Return the ``domain_model`` row whose ``slug`` matches.

    Raises:
        HTTPException: 404 if no such domain exists.
    """
    result = await session.execute(select(domain_model).where(domain_model.slug == slug))
    domain = result.scalar_one_or_none()
    if domain is None:
        raise HTTPException(status_code=404, detail=f"Domain '{slug}' not found")
    return domain


async def _get_repo_by_slug(
    slug: str,
    session: AsyncSession,
    repo_model: type[ManagedRepo],
) -> Any:
    """Return the ``repo_model`` row whose ``slug`` matches.

    Raises:
        HTTPException: 404 if no such repo exists.
    """
    result = await session.execute(select(repo_model).where(repo_model.slug == slug))
    repo = result.scalar_one_or_none()
    if repo is None:
        raise HTTPException(status_code=404, detail=f"Repo '{slug}' not found")
    return repo