Files
railiance-fabric/railiance_fabric/registry.py

483 lines
19 KiB
Python

from __future__ import annotations

import json
import sqlite3
from contextlib import closing
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import jsonschema

from .loader import load_yaml, repo_root
class RegistryError(Exception):
    """Error raised by registry operations.

    Attributes:
        message: Human-readable description; also used as the exception text.
        status_code: Numeric status attached to the error. Defaults to 400;
            lookups in this module pass 404 when something is not found.
    """

    def __init__(self, message: str, status_code: int = 400) -> None:
        super().__init__(message)
        self.status_code = status_code
        self.message = message
@dataclass(frozen=True)
class RegistryStore:
    """SQLite-backed registry of repositories and their graph snapshots.

    Each operation opens its own connection to ``path`` and closes it before
    returning. Write operations additionally run inside a transaction that is
    committed on success and rolled back on error.
    """

    # Location of the SQLite database.
    # NOTE(review): because every call opens a new connection, a ``:memory:``
    # path would give each call its own empty database -- confirm whether
    # in-memory stores are meant to be supported.
    path: Path

    def init_schema(self) -> None:
        """Create the ``repositories`` and ``snapshots`` tables if missing.

        The parent directory of ``path`` is created first unless the path is
        the literal ``:memory:``.
        """
        if str(self.path) != ":memory:":
            self.path.parent.mkdir(parents=True, exist_ok=True)
        # ``closing`` releases the connection; the inner ``db`` context only
        # commits or rolls back and would otherwise leave it open.
        with closing(self._connect()) as db, db:
            db.executescript(
                """
                create table if not exists repositories (
                slug text primary key,
                name text not null,
                remote_url text,
                default_branch text,
                state_hub_repo_id text,
                created_at text not null,
                updated_at text not null
                );
                create table if not exists snapshots (
                id integer primary key autoincrement,
                repo_slug text not null references repositories(slug),
                commit_sha text not null,
                generated_at text not null,
                graph_json text not null,
                created_at text not null
                );
                create index if not exists idx_snapshots_repo_latest
                on snapshots(repo_slug, id desc);
                """
            )

    def upsert_repository(self, payload: dict[str, Any]) -> dict[str, Any]:
        """Insert a repository or update the existing row with the same slug.

        Args:
            payload: Mapping with a required ``slug`` and optional ``name``
                (defaults to the slug), ``remote_url``, ``default_branch``
                (defaults to ``"main"``) and ``state_hub_repo_id``.

        Returns:
            The stored repository row as a dict.

        Raises:
            RegistryError: If ``slug`` is missing or blank, or an optional
                text field is not a string.
        """
        slug = _required_text(payload, "slug")
        now = _utc_now()
        name = str(payload.get("name") or slug)
        remote_url = _optional_text(payload, "remote_url")
        default_branch = str(payload.get("default_branch") or "main")
        state_hub_repo_id = _optional_text(payload, "state_hub_repo_id")
        with closing(self._connect()) as db, db:
            # ``created_at`` is deliberately absent from the update clause so
            # the original creation time survives later upserts.
            db.execute(
                """
                insert into repositories (
                slug, name, remote_url, default_branch, state_hub_repo_id,
                created_at, updated_at
                )
                values (?, ?, ?, ?, ?, ?, ?)
                on conflict(slug) do update set
                name = excluded.name,
                remote_url = excluded.remote_url,
                default_branch = excluded.default_branch,
                state_hub_repo_id = excluded.state_hub_repo_id,
                updated_at = excluded.updated_at
                """,
                (slug, name, remote_url, default_branch, state_hub_repo_id, now, now),
            )
        # Read back only after the transaction above has been committed.
        return self.get_repository(slug)

    def list_repositories(self) -> list[dict[str, Any]]:
        """Return every repository row as a dict, ordered by slug."""
        with closing(self._connect()) as db:
            rows = db.execute(
                """
                select slug, name, remote_url, default_branch, state_hub_repo_id,
                created_at, updated_at
                from repositories
                order by slug
                """
            ).fetchall()
        return [_row_dict(row) for row in rows]

    def get_repository(self, slug: str) -> dict[str, Any]:
        """Return the repository with the given slug.

        Raises:
            RegistryError: With status 404 if no such repository exists.
        """
        with closing(self._connect()) as db:
            row = db.execute(
                """
                select slug, name, remote_url, default_branch, state_hub_repo_id,
                created_at, updated_at
                from repositories
                where slug = ?
                """,
                (slug,),
            ).fetchone()
        if row is None:
            raise RegistryError(f"repository not found: {slug}", 404)
        return _row_dict(row)

    def add_snapshot(self, repo_slug: str, payload: dict[str, Any]) -> dict[str, Any]:
        """Validate and store a graph snapshot for an existing repository.

        Args:
            repo_slug: Slug of the repository the snapshot belongs to.
            payload: Mapping with a required ``commit``, a required object
                ``graph`` and an optional ``generated_at`` (defaults to now).

        Returns:
            The stored snapshot as a dict.

        Raises:
            RegistryError: If the repository is unknown (404), ``commit`` is
                missing, ``graph`` is not an object, or the graph fails
                validation.
        """
        # Raises 404 early when the repository does not exist.
        self.get_repository(repo_slug)
        commit = _required_text(payload, "commit")
        generated_at = str(payload.get("generated_at") or _utc_now())
        graph = payload.get("graph")
        if not isinstance(graph, dict):
            raise RegistryError("snapshot payload requires object field 'graph'")
        graph = _with_source(graph, repo_slug, commit, generated_at)
        validate_graph_export(graph)
        now = _utc_now()
        with closing(self._connect()) as db, db:
            cursor = db.execute(
                """
                insert into snapshots (repo_slug, commit_sha, generated_at, graph_json, created_at)
                values (?, ?, ?, ?, ?)
                """,
                (repo_slug, commit, generated_at, json.dumps(graph, sort_keys=True), now),
            )
            snapshot_id = int(cursor.lastrowid)
        # Read back only after the transaction above has been committed.
        return self.get_snapshot(snapshot_id)

    def get_snapshot(self, snapshot_id: int) -> dict[str, Any]:
        """Return the snapshot with the given id.

        Raises:
            RegistryError: With status 404 if no such snapshot exists.
        """
        with closing(self._connect()) as db:
            row = db.execute(
                """
                select id, repo_slug, commit_sha, generated_at, graph_json, created_at
                from snapshots
                where id = ?
                """,
                (snapshot_id,),
            ).fetchone()
        if row is None:
            raise RegistryError(f"snapshot not found: {snapshot_id}", 404)
        return _snapshot_dict(row)

    def latest_snapshot(self, repo_slug: str) -> dict[str, Any]:
        """Return the snapshot with the highest id for a repository.

        Raises:
            RegistryError: With status 404 if the repository is unknown or
                has no snapshots.
        """
        self.get_repository(repo_slug)
        with closing(self._connect()) as db:
            row = db.execute(
                """
                select id, repo_slug, commit_sha, generated_at, graph_json, created_at
                from snapshots
                where repo_slug = ?
                order by id desc
                limit 1
                """,
                (repo_slug,),
            ).fetchone()
        if row is None:
            raise RegistryError(f"no snapshots for repository: {repo_slug}", 404)
        return _snapshot_dict(row)

    def latest_snapshots(self) -> list[dict[str, Any]]:
        """Return the highest-id snapshot of every repository, ordered by slug."""
        with closing(self._connect()) as db:
            rows = db.execute(
                """
                select s.id, s.repo_slug, s.commit_sha, s.generated_at, s.graph_json, s.created_at
                from snapshots s
                join (
                select repo_slug, max(id) as latest_id
                from snapshots
                group by repo_slug
                ) latest on latest.latest_id = s.id
                order by s.repo_slug
                """
            ).fetchall()
        return [_snapshot_dict(row) for row in rows]

    def combined_graph(self) -> dict[str, Any]:
        """Merge the latest snapshot of every repository into one graph export.

        Nodes are keyed by their ``id``, so a later snapshot's node replaces
        an earlier one with the same id. Edges are reduced to their ``from``,
        ``to`` and ``type`` fields and are not de-duplicated. Both lists are
        returned in sorted order.
        """
        nodes: dict[str, dict[str, Any]] = {}
        edges: list[dict[str, str]] = []
        for snapshot in self.latest_snapshots():
            graph = snapshot["graph"]
            for node in graph.get("nodes", []):
                if isinstance(node, dict):
                    nodes[str(node.get("id", ""))] = node
            for edge in graph.get("edges", []):
                if isinstance(edge, dict):
                    edges.append(
                        {
                            "from": str(edge.get("from", "")),
                            "to": str(edge.get("to", "")),
                            "type": str(edge.get("type", "")),
                        }
                    )
        return {
            "apiVersion": "railiance.fabric/v1alpha1",
            "kind": "FabricGraphExport",
            "generated_at": _utc_now(),
            "source": {"repo": "registry", "commit": "", "path": ""},
            "nodes": [nodes[key] for key in sorted(nodes)],
            "edges": sorted(edges, key=lambda edge: (edge["from"], edge["to"], edge["type"])),
        }

    def _connect(self) -> sqlite3.Connection:
        """Open a new connection whose rows are ``sqlite3.Row`` objects.

        The caller is responsible for closing the returned connection.
        """
        db = sqlite3.connect(self.path)
        db.row_factory = sqlite3.Row
        return db
def validate_graph_export(graph: dict[str, Any]) -> None:
    """Validate ``graph`` against the ``state-hub-export`` schema.

    All ``*.schema.yaml`` files in the repository's ``schemas`` directory are
    loaded and handed to the resolver keyed by file URI, so references between
    schema files can be resolved from that store.

    Raises:
        RegistryError: Describing the first error (ordered by instance path)
            when the graph does not conform.
    """
    schema_root = repo_root() / "schemas"
    export_schema_path = schema_root / "state-hub-export.schema.yaml"

    known_schemas: dict[str, Any] = {}
    for candidate in sorted(schema_root.glob("*.schema.yaml")):
        known_schemas[candidate.resolve().as_uri()] = load_yaml(candidate)

    export_schema = load_yaml(export_schema_path)
    validator = jsonschema.Draft202012Validator(
        export_schema,
        resolver=jsonschema.RefResolver(
            base_uri=export_schema_path.resolve().as_uri(),
            referrer=export_schema,
            store=known_schemas,
        ),
    )

    problems = sorted(validator.iter_errors(graph), key=lambda problem: list(problem.path))
    if not problems:
        return
    first = problems[0]
    # An empty path means the error applies to the document as a whole.
    where = ".".join(str(part) for part in first.path) or "<root>"
    raise RegistryError(f"invalid FabricGraphExport at {where}: {first.message}")
def providers(graph: dict[str, Any], capability: str) -> list[dict[str, Any]]:
    """List capability declarations that match ``capability``.

    A ``CapabilityDeclaration`` node matches when either its ``id`` or its
    ``capability_type`` attribute equals ``capability``.

    Returns:
        One summary dict per matching node, sorted by ``provider_id``.
    """
    matches: list[dict[str, Any]] = []
    for node in _nodes(graph):
        if node.get("kind") != "CapabilityDeclaration":
            continue
        attrs = _attrs(node)
        matches_id = node.get("id") == capability
        if not (matches_id or attrs.get("capability_type") == capability):
            continue
        summary = {
            "provider_id": node.get("id", ""),
            "name": node.get("name", ""),
            "service_id": attrs.get("service_id", ""),
            "capability_type": attrs.get("capability_type", ""),
            "lifecycle": node.get("lifecycle", ""),
            "interfaces": attrs.get("interface_ids", []),
            "repo": node.get("repo", ""),
            "domain": node.get("domain", ""),
        }
        matches.append(summary)
    matches.sort(key=lambda entry: entry["provider_id"])
    return matches
def consumers(graph: dict[str, Any], target: str) -> list[dict[str, Any]]:
    """List dependencies (and their bindings) that consume ``target``.

    A dependency matches when ``target`` equals its id, its required
    capability id or type, or its interface type. If ``target`` is the id of
    an ``InterfaceDeclaration`` node, a dependency also matches when it has
    that node's interface type. Independently, a binding matches when
    ``target`` equals its provider capability id or provider interface id.

    Returns:
        One entry per relevant binding, or a single binding-less entry for a
        matching dependency without bindings, sorted by consumer service id
        and dependency id.
    """
    candidate = _nodes_by_id(graph).get(target)
    interface_type_of_target = ""
    if candidate and candidate.get("kind") == "InterfaceDeclaration":
        interface_type_of_target = str(_attrs(candidate).get("interface_type", ""))

    binding_index = _bindings_by_dependency(graph)
    found: list[dict[str, Any]] = []
    for dependency in _dependency_nodes(graph):
        attrs = _attrs(dependency)
        dependency_id = str(dependency.get("id", ""))
        direct_keys = {
            dependency_id,
            str(attrs.get("requires_capability_id", "")),
            str(attrs.get("requires_capability_type", "")),
            str(attrs.get("interface_type", "")),
        }
        dependency_matches = target in direct_keys or bool(
            interface_type_of_target and interface_type_of_target == attrs.get("interface_type")
        )

        bindings = binding_index.get(dependency_id, [])
        if not bindings:
            # Unbound dependencies are reported once, with empty binding fields.
            if dependency_matches:
                found.append(_consumer_match(attrs, dependency_id, {}))
            continue

        for binding in bindings:
            binding_matches = target in {
                binding["provider_capability_id"],
                binding["provider_interface_id"],
            }
            if dependency_matches or binding_matches:
                found.append(_consumer_match(attrs, dependency_id, binding))

    return sorted(found, key=lambda item: (item["consumer_service_id"], item["dependency_id"]))
def unresolved_dependencies(graph: dict[str, Any]) -> list[dict[str, Any]]:
    """List dependencies that lack a provider or have a problematic binding.

    A dependency is reported with reason ``missing_provider`` when no provider
    matches its required capability id (or, if that is empty, its required
    capability type). Otherwise it is reported with reason
    ``binding_not_exact`` when any of its bindings has status ``missing`` or
    ``disputed``.

    Returns:
        One entry per unresolved dependency, sorted by ``dependency_id``.
    """
    problem_statuses = {"missing", "disputed"}
    binding_index = _bindings_by_dependency(graph)
    unresolved: list[dict[str, Any]] = []
    for dependency in _dependency_nodes(graph):
        attrs = _attrs(dependency)
        dependency_id = str(dependency.get("id", ""))
        required_id = str(attrs.get("requires_capability_id", ""))
        required_type = str(attrs.get("requires_capability_type", ""))

        has_provider = bool(providers(graph, required_id or required_type))
        has_problem_binding = any(
            binding.get("status") in problem_statuses
            for binding in binding_index.get(dependency_id, [])
        )
        if has_provider and not has_problem_binding:
            continue

        reason = "binding_not_exact" if has_provider else "missing_provider"
        unresolved.append(
            {
                "dependency_id": dependency_id,
                "consumer_service_id": attrs.get("consumer_service_id", ""),
                "requires_capability_id": required_id,
                "requires_capability_type": required_type,
                "interface_type": attrs.get("interface_type", ""),
                "reason": reason,
            }
        )
    unresolved.sort(key=lambda entry: entry["dependency_id"])
    return unresolved
def blast_radius(graph: dict[str, Any], interface: str) -> list[dict[str, Any]]:
    """List the consumers affected by a change to ``interface``.

    When ``interface`` is the id of an ``InterfaceDeclaration`` node, only
    consumer matches bound to that provider interface are returned. Otherwise
    ``interface`` is treated as an interface type and only matches whose
    dependency declares that ``interface_type`` are returned.

    Returns:
        The filtered consumer matches, in the order ``consumers`` yields them.
    """
    # Build the id index once and reuse it for every match instead of
    # rebuilding it per dependency lookup.
    nodes = _nodes_by_id(graph)
    target_node = nodes.get(interface)
    matches = consumers(graph, interface)
    if target_node and target_node.get("kind") == "InterfaceDeclaration":
        return [match for match in matches if match.get("provider_interface_id") == interface]
    return [
        match
        for match in matches
        if _attrs(nodes.get(match["dependency_id"], {})).get("interface_type") == interface
    ]
def dependency_path_lines(graph: dict[str, Any], service_id: str) -> list[str]:
    """Render the dependency tree of a service as text lines.

    Starting at ``service_id``, each service line is followed by its declared
    dependencies. Bound dependencies list their bindings and recurse into the
    providing service; unbound ones list candidate providers or are marked
    unresolved. A service already on the current path is marked as a cycle
    and not expanded again.

    Returns:
        The rendered lines, or a single ``unknown service`` line when
        ``service_id`` is not a ``ServiceDeclaration`` node.
    """
    root = _nodes_by_id(graph).get(service_id)
    if root is None or root.get("kind") != "ServiceDeclaration":
        return [f"unknown service: {service_id}"]

    # Dependencies grouped by the service that consumes them.
    dependencies_of: dict[str, list[dict[str, Any]]] = {}
    for dependency in _dependency_nodes(graph):
        consumer = str(_attrs(dependency).get("consumer_service_id", ""))
        dependencies_of.setdefault(consumer, []).append(dependency)

    # Capability id -> id of the service that provides it.
    service_of_capability: dict[str, str] = {}
    for node in _nodes(graph):
        if node.get("kind") == "CapabilityDeclaration":
            service_of_capability[str(node.get("id", ""))] = str(_attrs(node).get("service_id", ""))

    binding_index = _bindings_by_dependency(graph)
    lines: list[str] = []

    def walk(current: str, indent: int, stack: list[str]) -> None:
        pad = " " * indent
        if current in stack:
            lines.append(f"{pad}{current} (cycle)")
            return
        lines.append(f"{pad}{current}")

        declared = sorted(
            dependencies_of.get(current, []),
            key=lambda item: str(item.get("id", "")),
        )
        if not declared:
            lines.append(f"{pad} no declared dependencies")
            return

        for dependency in declared:
            dependency_id = str(dependency.get("id", ""))
            required = _attrs(dependency).get("requires_capability_type", "")
            lines.append(f"{pad} requires {required}: {dependency_id}")

            bindings = binding_index.get(dependency_id, [])
            if bindings:
                for binding in bindings:
                    provider_id = binding.get("provider_capability_id", "")
                    provider_service = service_of_capability.get(provider_id, "")
                    status = binding.get("status", "")
                    lines.append(f"{pad} {status} -> {provider_id}")
                    # Descend into the providing service unless it is the
                    # service currently being rendered.
                    if provider_service and provider_service != current:
                        walk(provider_service, indent + 3, [*stack, current])
            else:
                candidates = providers(graph, str(required))
                if not candidates:
                    lines.append(f"{pad} unresolved")
                for provider in candidates:
                    lines.append(f"{pad} candidate {provider['provider_id']}")

    walk(service_id, 0, [])
    return lines
def graph_node(graph: dict[str, Any], graph_id: str) -> dict[str, Any]:
    """Return the node of ``graph`` whose id is ``graph_id``.

    Raises:
        RegistryError: With status 404 when no node has that id.
    """
    index = _nodes_by_id(graph)
    if graph_id in index:
        return index[graph_id]
    raise RegistryError(f"graph node not found: {graph_id}", 404)
def _with_source(graph: dict[str, Any], repo_slug: str, commit: str, generated_at: str) -> dict[str, Any]:
copy = json.loads(json.dumps(graph))
copy.setdefault("generated_at", generated_at)
copy.setdefault("source", {})
copy["source"].setdefault("repo", repo_slug)
copy["source"].setdefault("commit", commit)
copy["source"].setdefault("path", "")
return copy
def _snapshot_dict(row: sqlite3.Row) -> dict[str, Any]:
return {
"id": row["id"],
"repo_slug": row["repo_slug"],
"commit": row["commit_sha"],
"generated_at": row["generated_at"],
"graph": json.loads(row["graph_json"]),
"created_at": row["created_at"],
}
def _row_dict(row: sqlite3.Row) -> dict[str, Any]:
return {key: row[key] for key in row.keys()}
def _required_text(payload: dict[str, Any], key: str) -> str:
value = payload.get(key)
if not isinstance(value, str) or not value.strip():
raise RegistryError(f"field '{key}' is required")
return value.strip()
def _optional_text(payload: dict[str, Any], key: str) -> str | None:
value = payload.get(key)
if value is None:
return None
if not isinstance(value, str):
raise RegistryError(f"field '{key}' must be a string")
return value
def _utc_now() -> str:
return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")
def _nodes(graph: dict[str, Any]) -> list[dict[str, Any]]:
return [node for node in graph.get("nodes", []) if isinstance(node, dict)]
def _nodes_by_id(graph: dict[str, Any]) -> dict[str, dict[str, Any]]:
    """Index the graph's nodes by their stringified ``id``.

    Nodes without an id are keyed by the empty string; when ids repeat, the
    last node wins.
    """
    index: dict[str, dict[str, Any]] = {}
    for node in _nodes(graph):
        index[str(node.get("id", ""))] = node
    return index
def _attrs(node: dict[str, Any]) -> dict[str, Any]:
attrs = node.get("attributes", {})
return attrs if isinstance(attrs, dict) else {}
def _dependency_nodes(graph: dict[str, Any]) -> list[dict[str, Any]]:
    """Return the graph's nodes whose ``kind`` is ``DependencyDeclaration``."""
    selected: list[dict[str, Any]] = []
    for node in _nodes(graph):
        if node.get("kind") == "DependencyDeclaration":
            selected.append(node)
    return selected
def _dependency_attrs(graph: dict[str, Any], dependency_id: str) -> dict[str, Any]:
    """Return the attributes of the node with id ``dependency_id``.

    An unknown id yields an empty dict.
    """
    return _attrs(_nodes_by_id(graph).get(dependency_id, {}))
def _bindings_by_dependency(graph: dict[str, Any]) -> dict[str, list[dict[str, str]]]:
    """Group ``BindingAssertion`` nodes by the dependency they refer to.

    Bindings without a ``dependency_id`` attribute are skipped. Each group is
    a list of string-valued summaries sorted by ``binding_id``.
    """
    grouped: dict[str, list[dict[str, str]]] = {}
    for node in _nodes(graph):
        if node.get("kind") == "BindingAssertion":
            attrs = _attrs(node)
            dependency_id = str(attrs.get("dependency_id", ""))
            if dependency_id:
                summary = {
                    "binding_id": str(node.get("id", "")),
                    "provider_capability_id": str(attrs.get("provider_capability_id", "")),
                    "provider_interface_id": str(attrs.get("provider_interface_id", "")),
                    "status": str(attrs.get("status", "")),
                }
                grouped.setdefault(dependency_id, []).append(summary)
    return {
        dependency_id: sorted(summaries, key=lambda item: item["binding_id"])
        for dependency_id, summaries in grouped.items()
    }
def _consumer_match(attrs: dict[str, Any], dependency_id: str, binding: dict[str, str]) -> dict[str, Any]:
return {
"consumer_service_id": attrs.get("consumer_service_id", ""),
"dependency_id": dependency_id,
"required_capability_type": attrs.get("requires_capability_type", ""),
"provider_capability_id": binding.get("provider_capability_id", ""),
"provider_interface_id": binding.get("provider_interface_id", ""),
"status": binding.get("status", ""),
}