Files
markitect-tool/src/markitect_tool/extension/registry.py

199 lines
6.6 KiB
Python

"""Extension descriptors and registries."""
from __future__ import annotations
from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Iterable
from markitect_tool.extension.processing import ProcessingCapability
# Zero-argument callable that builds an extension implementation; it is
# invoked by ExtensionDescriptor.instantiate().
ExtensionFactory = Callable[[], Any]
class ExtensionRegistryError(ValueError):
    """Signal an invalid extension descriptor or registry operation.

    Derives from ``ValueError`` so code that already catches ``ValueError``
    also handles registry problems.
    """
@dataclass(frozen=True)
class OptionalDependency:
    """Runtime dependency declared by an extension.

    Despite the class name, a dependency with ``required=True`` is treated
    as a hard requirement: when it is unavailable the registry reports it
    under ``missing`` and the extension is considered incompatible.
    """

    # Identifier the registry probes for availability.
    name: str
    # NOTE(review): presumably the installable distribution name — confirm.
    package: str | None = None
    # NOTE(review): presumably a packaging "extra" that pulls this in — confirm.
    extra: str | None = None
    # False => an unavailable dependency is only reported as optional_missing.
    required: bool = False
    # Free-text explanation (presumably why the extension needs it).
    purpose: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a dict, omitting None and empty values.

        ``required`` is always kept because ``False`` is not treated as empty.
        """
        as_mapping = asdict(self)
        return _drop_empty(as_mapping)
@dataclass(frozen=True)
class ExtensionDescriptor:
    """Inspectable descriptor for one internal extension.

    The dataclass is frozen, but several fields hold mutable containers
    (lists and dicts). ``to_dict`` therefore returns shallow copies of those
    containers, so editing the serialized form cannot silently mutate the
    descriptor itself.
    """

    id: str
    kind: str
    version: str = "1"
    summary: str | None = None
    # Zero-argument callable used by instantiate(); excluded from equality
    # and repr (compare=False, repr=False) and never serialized.
    factory: ExtensionFactory | None = field(default=None, compare=False, repr=False)
    # Each capability's ``id`` is indexed by ExtensionRegistry.register.
    capabilities: list[ProcessingCapability] = field(default_factory=list)
    optional_dependencies: list[OptionalDependency] = field(default_factory=list)
    safety: dict[str, Any] = field(default_factory=dict)
    input_contract: str | None = None
    output_contract: str | None = None
    diagnostics_namespace: str | None = None
    provenance_prefix: str | None = None
    cli: dict[str, Any] = field(default_factory=dict)
    docs: list[str] = field(default_factory=list)
    examples: list[str] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Reject descriptors whose ``id`` or ``kind`` is not a non-blank string.

        Raises:
            ExtensionRegistryError: if ``id`` or ``kind`` is not a ``str``, or
                is empty/whitespace-only.
        """
        for label, value in (("id", self.id), ("kind", self.kind)):
            # Check the type first: a non-string would otherwise fail with a
            # bare AttributeError on ``.strip()`` instead of a registry error.
            if not isinstance(value, str):
                raise ExtensionRegistryError(
                    f"Extension {label} must be a string, got {type(value).__name__}"
                )
            if not value.strip():
                raise ExtensionRegistryError(f"Extension {label} cannot be empty")

    def to_dict(self) -> dict[str, Any]:
        """Serialize the descriptor, omitting unset or empty fields.

        ``factory`` is never included. Container fields are shallow-copied so
        the returned dict does not alias this descriptor's internal state.
        Nested values inside those containers are still shared.
        """
        data = {
            "id": self.id,
            "kind": self.kind,
            "version": self.version,
            "summary": self.summary,
            "capabilities": [capability.to_dict() for capability in self.capabilities],
            "optional_dependencies": [
                dependency.to_dict() for dependency in self.optional_dependencies
            ],
            "safety": dict(self.safety),
            "input_contract": self.input_contract,
            "output_contract": self.output_contract,
            "diagnostics_namespace": self.diagnostics_namespace,
            "provenance_prefix": self.provenance_prefix,
            "cli": dict(self.cli),
            "docs": list(self.docs),
            "examples": list(self.examples),
            "metadata": dict(self.metadata),
        }
        return _drop_empty(data)

    def instantiate(self) -> Any:
        """Call the factory and return whatever it produces.

        Raises:
            ExtensionRegistryError: if the descriptor has no factory.
        """
        if self.factory is None:
            raise ExtensionRegistryError(f"Extension `{self.id}` has no factory")
        return self.factory()
@dataclass(frozen=True)
class ExtensionDependencyCheck:
    """Result of checking an extension's declared dependencies.

    ``missing`` holds unavailable dependencies marked ``required``;
    ``optional_missing`` holds unavailable non-required ones.
    """

    extension_id: str
    missing: list[str] = field(default_factory=list)
    optional_missing: list[str] = field(default_factory=list)

    @property
    def compatible(self) -> bool:
        """True when no required dependency is missing.

        Missing optional dependencies do not affect compatibility.
        """
        return not self.missing

    def to_dict(self) -> dict[str, Any]:
        """Serialize the check, including the derived ``compatible`` flag.

        The lists are copied so that mutating the result cannot alter this
        frozen instance.
        """
        return {
            "extension_id": self.extension_id,
            "compatible": self.compatible,
            "missing": list(self.missing),
            "optional_missing": list(self.optional_missing),
        }
class ExtensionRegistry:
    """Index of extension descriptors keyed by id, kind, and capability.

    Every listing method returns descriptors in ascending id order so that
    output is deterministic regardless of registration order.
    """

    def __init__(self, descriptors: Iterable[ExtensionDescriptor] | None = None) -> None:
        # Primary store plus two reverse indexes: kind -> ids, capability id -> ids.
        self._descriptors: dict[str, ExtensionDescriptor] = {}
        self._by_kind: dict[str, set[str]] = {}
        self._by_capability: dict[str, set[str]] = {}
        initial = descriptors or ()
        for item in initial:
            self.register(item)

    def register(self, descriptor: ExtensionDescriptor) -> None:
        """Add *descriptor* and index it by kind and by each capability id.

        Raises:
            ExtensionRegistryError: if the id is already registered; the
                duplicate check happens before any index is modified.
        """
        ext_id = descriptor.id
        if ext_id in self._descriptors:
            raise ExtensionRegistryError(f"Duplicate extension id `{descriptor.id}`")
        self._descriptors[ext_id] = descriptor
        kind_ids = self._by_kind.setdefault(descriptor.kind, set())
        kind_ids.add(ext_id)
        for capability in descriptor.capabilities:
            self._by_capability.setdefault(capability.id, set()).add(ext_id)

    def get(self, extension_id: str) -> ExtensionDescriptor:
        """Return the descriptor registered under *extension_id*.

        Raises:
            ExtensionRegistryError: if the id is unknown (the underlying
                ``KeyError`` is chained as the cause).
        """
        try:
            found = self._descriptors[extension_id]
        except KeyError as lookup_error:
            raise ExtensionRegistryError(f"Unknown extension `{extension_id}`") from lookup_error
        return found

    def list(self, *, kind: str | None = None) -> list[ExtensionDescriptor]:
        """Return every descriptor, or only those of *kind*, sorted by id."""
        if kind is not None:
            return self._sorted_descriptors(self._by_kind.get(kind, ()))
        return self._sorted_descriptors(self._descriptors)

    def require_capability(self, capability_id: str) -> list[ExtensionDescriptor]:
        """Return descriptors declaring *capability_id*, sorted by id.

        An unknown capability yields an empty list rather than an error.
        """
        return self._sorted_descriptors(self._by_capability.get(capability_id, ()))

    def check_dependencies(
        self,
        extension_id: str,
        *,
        available_modules: set[str] | None = None,
    ) -> ExtensionDependencyCheck:
        """Report which declared dependencies of *extension_id* are unavailable.

        When *available_modules* is None, availability is probed for each
        dependency name via ``_available_modules``; otherwise the given
        collection is trusted as-is. Unavailable ``required`` dependencies go
        to ``missing``; the rest go to ``optional_missing``.

        Raises:
            ExtensionRegistryError: if *extension_id* is not registered.
        """
        descriptor = self.get(extension_id)
        dependencies = descriptor.optional_dependencies
        if available_modules is None:
            available = _available_modules(dep.name for dep in dependencies)
        else:
            available = available_modules

        missing: list[str] = []
        optional_missing: list[str] = []
        for dep in dependencies:
            if dep.name in available:
                continue
            bucket = missing if dep.required else optional_missing
            bucket.append(dep.name)

        return ExtensionDependencyCheck(
            extension_id=extension_id,
            missing=missing,
            optional_missing=optional_missing,
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialize as a count plus each descriptor's dict, ordered by id."""
        total = len(self._descriptors)
        serialized = [entry.to_dict() for entry in self.list()]
        return {"count": total, "extensions": serialized}

    def _sorted_descriptors(self, ids: Iterable[str]) -> list[ExtensionDescriptor]:
        # Resolve registered ids to their descriptors in ascending id order.
        return [self._descriptors[ext_id] for ext_id in sorted(ids)]
def _available_modules(module_names: Iterable[str]) -> set[str]:
import importlib.util
return {
module_name
for module_name in module_names
if importlib.util.find_spec(module_name) is not None
}
def _drop_empty(data: dict[str, Any]) -> dict[str, Any]:
return {
key: value
for key, value in data.items()
if value not in (None, [], {}, "")
}