#!/usr/bin/env python3
"""Idempotent registration from committed ``.repo-classification.yaml`` (STATE-WP-0065 P3).

Reads classification from a repo checkout, validates against the canon
allowed-values, and upserts the ``managed_repos`` row (create or update
classification + market domain).

Usage:
    python scripts/register_from_classification.py --repo-path /path/to/repo [--dry-run]
    python scripts/register_from_classification.py --slug state-hub [--dry-run]
    python scripts/register_from_classification.py --bulk [--dry-run]
    python scripts/register_from_classification.py --help
"""

from __future__ import annotations

import argparse
import asyncio
import json
import re
import socket
import subprocess
import sys
from dataclasses import dataclass, field
from datetime import date
from pathlib import Path
from typing import Any, Literal

_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

from sqlalchemy import select  # noqa: E402

from api.classification import (  # noqa: E402
    CLASSIFICATION_FILENAME,
    ClassificationData,
    load_classification_file,
)
from api.config import settings  # noqa: E402
from api.database import async_session_factory, engine  # noqa: E402
from api.models.domain import Domain  # noqa: E402
from api.models.managed_repo import ManagedRepo  # noqa: E402

try:
    import httpx

    _HAS_HTTPX = True
except ImportError:
    # The REST path (--api) reports every request as unreachable without httpx;
    # the direct-DB path does not need it.
    _HAS_HTTPX = False

Outcome = Literal["registered", "updated", "skipped", "invalid"]


@dataclass
class RowResult:
    """Outcome of processing one repo."""

    slug: str
    path: str  # checkout path the row refers to ("" when none was resolved)
    outcome: Outcome
    detail: str = ""
    warnings: list[str] = field(default_factory=list)


@dataclass
class RegistrationReport:
    """Accumulates per-repo results and renders them as text or a JSON-ready dict."""

    results: list[RowResult] = field(default_factory=list)

    def add(self, result: RowResult) -> None:
        """Append one row result to the report."""
        self.results.append(result)

    def counts(self) -> dict[str, int]:
        """Return the number of rows per outcome; all four outcome keys are always present."""
        totals = {"registered": 0, "updated": 0, "skipped": 0, "invalid": 0}
        for row in self.results:
            totals[row.outcome] = totals.get(row.outcome, 0) + 1
        return totals

    def render_text(self) -> str:
        """Return a human-readable report: one line per row, its warnings, then a summary."""
        lines = ["register-from-classification report", ""]
        for row in self.results:
            lines.append(f" [{row.outcome:10}] {row.slug:30} {row.detail}")
            for warning in row.warnings:
                lines.append(f" warn: {warning}")
        counts = self.counts()
        lines.append("")
        lines.append(
            "Summary: "
            f"registered={counts['registered']} "
            f"updated={counts['updated']} "
            f"skipped={counts['skipped']} "
            f"invalid={counts['invalid']}"
        )
        return "\n".join(lines)

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serialisable dict with ``summary`` counts and per-row ``results``."""
        return {
            "summary": self.counts(),
            "results": [
                {
                    "slug": r.slug,
                    "path": r.path,
                    "outcome": r.outcome,
                    "detail": r.detail,
                    "warnings": r.warnings,
                }
                for r in self.results
            ],
        }


def _slugify(name: str) -> str:
    """Lower-case ``name`` and collapse non-alphanumeric runs to ``-``; never returns ''."""
    slug = re.sub(r"[^a-z0-9]+", "-", name.lower()).strip("-")
    return slug or "repo"


def _parse_classified_at(value: str | None) -> date | None:
    """Parse the leading ``YYYY-MM-DD`` of ``value``; return None if empty or unparseable."""
    if not value:
        return None
    try:
        return date.fromisoformat(str(value)[:10])
    except ValueError:
        return None


def _git_value(repo_path: Path, args: list[str]) -> str | None:
    """Run ``git <args>`` in ``repo_path`` and return stripped stdout.

    Returns None when git fails, is not installed, the directory is unusable,
    or the output is empty.
    """
    try:
        return subprocess.check_output(
            ["git", *args],
            cwd=repo_path,
            stderr=subprocess.DEVNULL,
            text=True,
        ).strip() or None
    except (subprocess.CalledProcessError, FileNotFoundError, OSError):
        return None


def _git_root(path: Path) -> Path:
    """Return the git top-level directory containing ``path`` (or ``path`` resolved)."""
    root = _git_value(path, ["rev-parse", "--show-toplevel"])
    return Path(root) if root else path.resolve()


def _resolve_repo_path_for_host(repo: ManagedRepo) -> str | None:
    """Return the first recorded checkout path of ``repo`` that exists on this machine.

    Candidates are tried in order: this host's ``host_paths`` entry, then
    ``local_path``, then every other ``host_paths`` entry.  ``local_path`` is
    still tried when this host's entry is present but stale.
    """
    hostname = socket.gethostname()
    host_paths = repo.host_paths or {}
    candidates = [host_paths.get(hostname), repo.local_path, *host_paths.values()]
    for candidate in candidates:
        if candidate and Path(candidate).is_dir():
            return candidate
    return None


def _classification_changed(repo: ManagedRepo, data: ClassificationData, domain_id: Any) -> bool:
    """Return True if any classification column on ``repo`` differs from ``data``.

    Values are normalised the same way ``_apply_classification`` writes them
    (empty -> None, ``classified_at`` parsed to a date) so an unchanged file
    compares equal.
    """
    if repo.domain_id != domain_id:
        return True
    fields = (
        ("category", data.category),
        ("secondary_domains", data.secondary_domains or None),
        ("capability_tags", data.capability_tags or None),
        ("business_stake", data.business_stake or None),
        ("business_mechanics", data.business_mechanics or None),
        ("classified_at", _parse_classified_at(data.classified_at)),
        ("classified_by", data.classified_by),
        ("standard_version", data.standard_version),
    )
    return any(getattr(repo, attr) != new_val for attr, new_val in fields)


def _apply_classification(repo: ManagedRepo, data: ClassificationData, domain_id: Any) -> None:
    """Copy classification fields from ``data`` onto ``repo`` (in place, not committed)."""
    repo.domain_id = domain_id
    repo.category = data.category
    repo.secondary_domains = data.secondary_domains or None
    repo.capability_tags = data.capability_tags or None
    repo.business_stake = data.business_stake or None
    repo.business_mechanics = data.business_mechanics or None
    repo.classified_at = _parse_classified_at(data.classified_at)
    repo.classified_by = data.classified_by
    repo.standard_version = data.standard_version


def _refresh_paths(
    repo: ManagedRepo,
    *,
    git_root: Path,
    hostname: str,
    remote_url: str | None,
    git_fingerprint: str | None,
) -> None:
    """Record ``git_root`` as this host's checkout and refresh git identity fields.

    ``remote_url`` / ``git_fingerprint`` are only overwritten when a value was
    actually read from git, so a checkout without an origin keeps the stored one.
    """
    repo.local_path = str(git_root)
    # Build a new dict rather than mutating in place -- presumably so the ORM
    # registers the change on this (JSON-like) column; verify against the model.
    host_paths = dict(repo.host_paths or {})
    host_paths[hostname] = str(git_root)
    repo.host_paths = host_paths
    if remote_url:
        repo.remote_url = remote_url
    if git_fingerprint:
        repo.git_fingerprint = git_fingerprint


async def _get_domain_id(session, market_slug: str):
    """Return the id of the ``Domain`` with slug ``market_slug``.

    Raises:
        ValueError: if no such domain exists.
    """
    result = await session.execute(select(Domain).where(Domain.slug == market_slug))
    domain = result.scalar_one_or_none()
    if domain is None:
        raise ValueError(f"Market domain '{market_slug}' not found in domains table")
    return domain.id


async def _get_repo_by_slug(session, slug: str) -> ManagedRepo | None:
    """Return the ``ManagedRepo`` with ``slug``, or None."""
    result = await session.execute(select(ManagedRepo).where(ManagedRepo.slug == slug))
    return result.scalar_one_or_none()


def _api_request(
    method: str,
    path: str,
    *,
    api_base: str,
    body: dict | None = None,
) -> tuple[int, Any]:
    """Send one JSON request to the State Hub API.

    Returns ``(status_code, payload)``. ``payload`` is the decoded JSON body,
    None for 204, or ``{"_raw": text}`` for a non-JSON body. Transport
    failures (and a missing httpx) are reported as status ``0`` with
    ``{"_error": message}`` instead of raising.
    """
    if not _HAS_HTTPX:
        return (0, {"_error": "httpx not installed"})
    url = api_base.rstrip("/") + path
    try:
        with httpx.Client(timeout=30.0) as client:
            response = client.request(method, url, json=body)
            if response.status_code == 204:
                return response.status_code, None
            try:
                payload = response.json()
            except ValueError:
                # JSONDecodeError / UnicodeDecodeError: e.g. an HTML error page.
                payload = {"_raw": response.text}
            return response.status_code, payload
    except httpx.HTTPError as exc:
        return (0, {"_error": str(exc)})


def _error_detail(payload: Any) -> Any:
    """Extract the ``detail`` field from an API error payload, falling back to the payload."""
    return payload.get("detail", payload) if isinstance(payload, dict) else payload


async def _upsert_via_db(
    *,
    slug: str,
    repo_path: Path,
    data: ClassificationData,
    dry_run: bool,
    report: RegistrationReport,
) -> None:
    """Create or update the ``managed_repos`` row for ``slug`` through a direct DB session.

    Adds exactly one row to ``report``:
    - ``registered``: a new repo is created.
    - ``updated``: the classification changed.
    - ``skipped``: nothing (or only paths) changed.
    - ``invalid``: the market domain is unknown.

    In dry-run mode nothing is committed; an unknown domain is reported as
    ``skipped`` rather than ``invalid``.
    """
    git_root = _git_root(repo_path)
    path_str = str(git_root)
    remote_url = _git_value(git_root, ["remote", "get-url", "origin"])
    git_fingerprint = _git_value(git_root, ["rev-list", "--max-parents=0", "HEAD"])
    hostname = socket.gethostname()
    display_name = git_root.name.replace("-", " ").replace("_", " ").title()

    async with async_session_factory() as session:
        try:
            domain_id = await _get_domain_id(session, data.domain)
        except ValueError as exc:
            if dry_run:
                report.add(RowResult(slug, path_str, "skipped", f"dry-run: {exc}"))
                return
            report.add(RowResult(slug, path_str, "invalid", str(exc)))
            return

        repo = await _get_repo_by_slug(session, slug)

        if repo is None:
            if dry_run:
                report.add(
                    RowResult(
                        slug,
                        path_str,
                        "registered",
                        f"would create repo under domain '{data.domain}' (dry-run)",
                    )
                )
                return
            repo = ManagedRepo(
                domain_id=domain_id,
                slug=slug,
                name=display_name,
                local_path=path_str,
                host_paths={hostname: path_str},
                remote_url=remote_url,
                git_fingerprint=git_fingerprint,
                status="active",
            )
            _apply_classification(repo, data, domain_id)
            session.add(repo)
            await session.commit()
            report.add(RowResult(slug, path_str, "registered", f"domain={data.domain}"))
            return

        if not _classification_changed(repo, data, domain_id):
            if repo.local_path != path_str:
                if dry_run:
                    report.add(
                        RowResult(
                            slug,
                            path_str,
                            "skipped",
                            "classification unchanged; would refresh local_path (dry-run)",
                        )
                    )
                    return
                _refresh_paths(
                    repo,
                    git_root=git_root,
                    hostname=hostname,
                    remote_url=remote_url,
                    git_fingerprint=git_fingerprint,
                )
                await session.commit()
                report.add(RowResult(slug, path_str, "skipped", "paths refreshed only"))
                return
            report.add(RowResult(slug, path_str, "skipped", "classification already current"))
            return

        if dry_run:
            report.add(
                RowResult(
                    slug,
                    path_str,
                    "updated",
                    f"would update classification (domain={data.domain}) (dry-run)",
                )
            )
            return

        _apply_classification(repo, data, domain_id)
        _refresh_paths(
            repo,
            git_root=git_root,
            hostname=hostname,
            remote_url=remote_url,
            git_fingerprint=git_fingerprint,
        )
        await session.commit()
        report.add(RowResult(slug, path_str, "updated", f"domain={data.domain}"))


async def _upsert_via_api(
    *,
    slug: str,
    repo_path: Path,
    data: ClassificationData,
    dry_run: bool,
    api_base: str,
    report: RegistrationReport,
) -> None:
    """Create or update repo ``slug`` through the State Hub REST API.

    Flow:
    1. ``GET /repos/{slug}``: a 404 means the repo is new, a 200 means it
       exists, and any other status is reported as ``invalid``.
    2. A new repo is created with ``POST /repos/`` and then classified with
       ``PATCH``.
    3. An existing repo is ``PATCH``-ed, and this host's path is then
       registered on a best-effort basis. A failure there becomes a row warning.

    Adds exactly one row to ``report``.
    """
    git_root = _git_root(repo_path)
    path_str = str(git_root)
    remote_url = _git_value(git_root, ["remote", "get-url", "origin"])
    git_fingerprint = _git_value(git_root, ["rev-list", "--max-parents=0", "HEAD"])
    hostname = socket.gethostname()
    display_name = git_root.name.replace("-", " ").replace("_", " ").title()

    status, existing = _api_request("GET", f"/repos/{slug}", api_base=api_base)
    if status == 0:
        report.add(
            RowResult(
                slug,
                path_str,
                "invalid",
                f"API unreachable: {existing.get('_error', existing)}",
            )
        )
        return
    if status == 404:
        existing = None
    elif status != 200:
        # Auth/server errors must not be mistaken for "repo missing" (which
        # would POST a duplicate) or "repo exists" (which would PATCH blindly).
        report.add(
            RowResult(
                slug,
                path_str,
                "invalid",
                f"GET failed ({status}): {_error_detail(existing)}",
            )
        )
        return

    patch_body = {
        "category": data.category,
        "secondary_domains": data.secondary_domains,
        "capability_tags": data.capability_tags,
        "business_stake": data.business_stake,
        "business_mechanics": data.business_mechanics,
        "classified_at": data.classified_at,
        "classified_by": data.classified_by,
        "standard_version": data.standard_version,
        "domain_slug": data.domain,
        "local_path": path_str,
        "remote_url": remote_url,
        "git_fingerprint": git_fingerprint,
    }

    if existing is None:
        if dry_run:
            report.add(
                RowResult(
                    slug,
                    path_str,
                    "registered",
                    f"would POST /repos/ domain={data.domain} (dry-run)",
                )
            )
            return
        post_body = {
            "domain_slug": data.domain,
            "slug": slug,
            "name": display_name,
            "local_path": path_str,
            "host_paths": {hostname: path_str},
            "remote_url": remote_url,
            "git_fingerprint": git_fingerprint,
        }
        code, created = _api_request("POST", "/repos/", api_base=api_base, body=post_body)
        if code not in (200, 201):
            report.add(
                RowResult(slug, path_str, "invalid", f"POST failed: {_error_detail(created)}")
            )
            return
        code, updated = _api_request(
            "PATCH", f"/repos/{slug}", api_base=api_base, body=patch_body
        )
        if code != 200:
            report.add(
                RowResult(
                    slug,
                    path_str,
                    "invalid",
                    f"created repo but classification PATCH failed: {_error_detail(updated)}",
                )
            )
            return
        report.add(RowResult(slug, path_str, "registered", f"domain={data.domain}"))
        return

    if dry_run:
        report.add(
            RowResult(
                slug,
                path_str,
                "updated",
                f"would PATCH /repos/{slug} domain={data.domain} (dry-run)",
            )
        )
        return

    code, updated = _api_request(
        "PATCH", f"/repos/{slug}", api_base=api_base, body=patch_body
    )
    if code != 200:
        report.add(
            RowResult(slug, path_str, "invalid", f"PATCH failed: {_error_detail(updated)}")
        )
        return

    result = RowResult(slug, path_str, "updated", f"domain={data.domain}")
    code, paths_payload = _api_request(
        "POST",
        f"/repos/{slug}/paths",
        api_base=api_base,
        body={"host": hostname, "path": path_str},
    )
    if not 200 <= code < 300:
        # Classification already succeeded; keep the row "updated" but make
        # the failed host-path registration visible instead of dropping it.
        result.warnings.append(
            f"host path registration failed ({code}): {_error_detail(paths_payload)}"
        )
    report.add(result)


async def register_one(
    *,
    slug: str,
    repo_path: Path,
    dry_run: bool = False,
    use_api: bool = False,
    api_base: str | None = None,
    report: RegistrationReport | None = None,
) -> RowResult:
    """Register or update a single repo from its classification file.

    Loads ``.repo-classification.yaml`` from the git root of ``repo_path``.

    - An invalid file is reported as ``invalid``.
    - A valid file is upserted via the REST API (``use_api``) or a direct DB
      session.
    - Non-fatal classification warnings are attached to the returned row in
      both cases.

    Returns the ``RowResult`` appended to ``report``; a fresh report is used
    when none is given.
    """
    if report is None:
        report = RegistrationReport()
    git_root = _git_root(repo_path)
    data, errors, warnings = load_classification_file(git_root)
    if data is None:
        result = RowResult(
            slug,
            str(git_root),
            "invalid",
            "; ".join(errors) or "classification invalid",
            warnings=warnings,
        )
        report.add(result)
        return result

    if use_api:
        await _upsert_via_api(
            slug=slug,
            repo_path=git_root,
            data=data,
            dry_run=dry_run,
            api_base=api_base or settings.api_base,
            report=report,
        )
    else:
        await _upsert_via_db(
            slug=slug,
            repo_path=git_root,
            data=data,
            dry_run=dry_run,
            report=report,
        )

    # Each upsert helper appends exactly one row before returning.
    result = report.results[-1]
    if warnings:
        result.warnings.extend(warnings)
    return result


async def _bulk_targets(session) -> list[tuple[str, str]]:
    """Return ``(slug, path)`` for every active repo with an accessible checkout, by slug."""
    result = await session.execute(
        select(ManagedRepo).where(ManagedRepo.status == "active").order_by(ManagedRepo.slug)
    )
    targets: list[tuple[str, str]] = []
    for repo in result.scalars().all():
        path = _resolve_repo_path_for_host(repo)
        if path:
            targets.append((repo.slug, path))
    return targets


async def run_registration(args: argparse.Namespace) -> RegistrationReport:
    """Dispatch on ``--bulk`` / ``--repo-path`` / ``--slug`` and return the filled report.

    Raises:
        SystemExit: when none of the three selectors is given.
    """
    report = RegistrationReport()
    use_api = args.api and not args.db  # --db wins over --api

    if args.bulk:
        async with async_session_factory() as session:
            targets = await _bulk_targets(session)
        if not targets:
            report.add(
                RowResult("(bulk)", "", "skipped", "no active repos with accessible local paths")
            )
            return report
        for slug, path in targets:
            await register_one(
                slug=slug,
                repo_path=Path(path),
                dry_run=args.dry_run,
                use_api=use_api,
                api_base=args.api_base,
                report=report,
            )
        return report

    if args.repo_path:
        repo_path = Path(args.repo_path).expanduser().resolve()
        # An explicit --slug overrides the slug derived from the checkout name.
        slug = args.slug or _slugify(_git_root(repo_path).name)
        await register_one(
            slug=slug,
            repo_path=repo_path,
            dry_run=args.dry_run,
            use_api=use_api,
            api_base=args.api_base,
            report=report,
        )
        return report

    if args.slug:
        async with async_session_factory() as session:
            repo = await _get_repo_by_slug(session, args.slug)
        if repo is None:
            report.add(RowResult(args.slug, "", "invalid", "repo slug not found in DB"))
            return report
        path = _resolve_repo_path_for_host(repo)
        if not path:
            report.add(
                RowResult(
                    args.slug,
                    "",
                    "invalid",
                    "no accessible local path (local_path / host_paths)",
                )
            )
            return report
        await register_one(
            slug=args.slug,
            repo_path=Path(path),
            dry_run=args.dry_run,
            use_api=use_api,
            api_base=args.api_base,
            report=report,
        )
        return report

    raise SystemExit("Specify --repo-path PATH, --slug SLUG, or --bulk")


def build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Register or update managed_repos from .repo-classification.yaml",
    )
    parser.add_argument("--repo-path", metavar="PATH", help="Local git checkout path")
    parser.add_argument(
        "--slug",
        metavar="SLUG",
        help="Registered repo slug (required with --bulk omitted unless --repo-path given)",
    )
    parser.add_argument(
        "--bulk",
        action="store_true",
        help="All active registered repos with accessible local paths",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Report actions without writing to DB/API",
    )
    parser.add_argument(
        "--api",
        action="store_true",
        help="Upsert via REST API (default: direct DB session)",
    )
    parser.add_argument(
        "--db",
        action="store_true",
        help="Force direct DB session (overrides --api)",
    )
    parser.add_argument(
        "--api-base",
        default=settings.api_base,
        help=f"State Hub API base URL (default: {settings.api_base})",
    )
    parser.add_argument("--json", action="store_true", help="Emit JSON report")
    return parser


def main(argv: list[str] | None = None) -> int:
    """CLI entry point: returns 1 if any row is ``invalid``, else 0."""
    parser = build_parser()
    args = parser.parse_args(argv)

    if args.bulk and args.repo_path:
        parser.error("--bulk cannot be combined with --repo-path")
    if not (args.bulk or args.repo_path or args.slug):
        parser.error("Specify one of --repo-path PATH, --slug SLUG, or --bulk")

    report = asyncio.run(run_registration(args))
    if args.json:
        print(json.dumps(report.to_dict(), indent=2))
    else:
        print(report.render_text())
    counts = report.counts()
    return 1 if counts["invalid"] else 0


if __name__ == "__main__":
    raise SystemExit(main())