"""Extension descriptors and registries.""" from __future__ import annotations from dataclasses import asdict, dataclass, field from typing import Any, Callable, Iterable from markitect_tool.extension.processing import ProcessingCapability ExtensionFactory = Callable[[], Any] class ExtensionRegistryError(ValueError): """Raised when extension descriptors or registries are invalid.""" @dataclass(frozen=True) class OptionalDependency: """An optional runtime dependency declared by an extension.""" name: str package: str | None = None extra: str | None = None required: bool = False purpose: str | None = None def to_dict(self) -> dict[str, Any]: return _drop_empty(asdict(self)) @dataclass(frozen=True) class ExtensionDescriptor: """Inspectable descriptor for one internal extension.""" id: str kind: str version: str = "1" summary: str | None = None factory: ExtensionFactory | None = field(default=None, compare=False, repr=False) capabilities: list[ProcessingCapability] = field(default_factory=list) optional_dependencies: list[OptionalDependency] = field(default_factory=list) safety: dict[str, Any] = field(default_factory=dict) input_contract: str | None = None output_contract: str | None = None diagnostics_namespace: str | None = None provenance_prefix: str | None = None cli: dict[str, Any] = field(default_factory=dict) docs: list[str] = field(default_factory=list) examples: list[str] = field(default_factory=list) metadata: dict[str, Any] = field(default_factory=dict) def __post_init__(self) -> None: if not self.id.strip(): raise ExtensionRegistryError("Extension id cannot be empty") if not self.kind.strip(): raise ExtensionRegistryError("Extension kind cannot be empty") def to_dict(self) -> dict[str, Any]: data = { "id": self.id, "kind": self.kind, "version": self.version, "summary": self.summary, "capabilities": [capability.to_dict() for capability in self.capabilities], "optional_dependencies": [ dependency.to_dict() for dependency in self.optional_dependencies ], "safety": self.safety, "input_contract": self.input_contract, "output_contract": self.output_contract, "diagnostics_namespace": self.diagnostics_namespace, "provenance_prefix": self.provenance_prefix, "cli": self.cli, "docs": self.docs, "examples": self.examples, "metadata": self.metadata, } return _drop_empty(data) def instantiate(self) -> Any: """Create or return the implementation for this descriptor.""" if self.factory is None: raise ExtensionRegistryError(f"Extension `{self.id}` has no factory") return self.factory() @dataclass(frozen=True) class ExtensionDependencyCheck: """Result of checking required extension dependencies.""" extension_id: str missing: list[str] = field(default_factory=list) optional_missing: list[str] = field(default_factory=list) @property def compatible(self) -> bool: return not self.missing def to_dict(self) -> dict[str, Any]: return { "extension_id": self.extension_id, "compatible": self.compatible, "missing": self.missing, "optional_missing": self.optional_missing, } class ExtensionRegistry: """Registry of internal extension descriptors.""" def __init__(self, descriptors: Iterable[ExtensionDescriptor] | None = None) -> None: self._descriptors: dict[str, ExtensionDescriptor] = {} self._by_kind: dict[str, set[str]] = {} self._by_capability: dict[str, set[str]] = {} for descriptor in descriptors or []: self.register(descriptor) def register(self, descriptor: ExtensionDescriptor) -> None: if descriptor.id in self._descriptors: raise ExtensionRegistryError(f"Duplicate extension id `{descriptor.id}`") self._descriptors[descriptor.id] = descriptor self._by_kind.setdefault(descriptor.kind, set()).add(descriptor.id) for capability in descriptor.capabilities: self._by_capability.setdefault(capability.id, set()).add(descriptor.id) def get(self, extension_id: str) -> ExtensionDescriptor: try: return self._descriptors[extension_id] except KeyError as exc: raise ExtensionRegistryError(f"Unknown extension `{extension_id}`") from exc def list(self, *, kind: str | None = None) -> list[ExtensionDescriptor]: if kind is None: ids = sorted(self._descriptors) else: ids = sorted(self._by_kind.get(kind, set())) return [self._descriptors[key] for key in ids] def require_capability(self, capability_id: str) -> list[ExtensionDescriptor]: return [ self._descriptors[extension_id] for extension_id in sorted(self._by_capability.get(capability_id, set())) ] def check_dependencies( self, extension_id: str, *, available_modules: set[str] | None = None, ) -> ExtensionDependencyCheck: descriptor = self.get(extension_id) available = ( available_modules if available_modules is not None else _available_modules( dependency.name for dependency in descriptor.optional_dependencies ) ) missing: list[str] = [] optional_missing: list[str] = [] for dependency in descriptor.optional_dependencies: if dependency.name in available: continue if dependency.required: missing.append(dependency.name) else: optional_missing.append(dependency.name) return ExtensionDependencyCheck( extension_id=extension_id, missing=missing, optional_missing=optional_missing, ) def to_dict(self) -> dict[str, Any]: return { "count": len(self._descriptors), "extensions": [descriptor.to_dict() for descriptor in self.list()], } def _available_modules(module_names: Iterable[str]) -> set[str]: import importlib.util return { module_name for module_name in module_names if importlib.util.find_spec(module_name) is not None } def _drop_empty(data: dict[str, Any]) -> dict[str, Any]: return { key: value for key, value in data.items() if value not in (None, [], {}, "") }