Files
markitect-tool/src/markitect_tool/backend/planning.py

426 lines
15 KiB
Python

"""Refresh planning for optional snapshot and index backends."""
from __future__ import annotations

import hashlib
import json
from collections import deque
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any

import yaml

from markitect_tool.backend.engine import (
    DependencyEdge,
    EMPTY_PARSE_OPTIONS_HASH,
    PARSER_ID,
    PARSER_VERSION,
)
from markitect_tool.cache import scan_markdown_files
@dataclass(frozen=True)
class SnapshotState:
    """Source-file state recorded by an earlier snapshot/index run.

    Instances are immutable records. The parser identity fields default to
    the constants imported by this module, so a mapping that carries only the
    file metadata, content hash, and snapshot id can still be loaded.
    """

    # Key under which the planner looks this state up.
    path: str
    # File metadata captured when the snapshot was recorded.
    size: int
    mtime_ns: int
    content_hash: str
    snapshot_id: str
    # Parameters the planner compares to decide whether a snapshot is stale.
    parser: str = PARSER_ID
    parser_version: str = PARSER_VERSION
    parse_options_hash: str = EMPTY_PARSE_OPTIONS_HASH
    contract_hash: str | None = None
    # False makes the planner schedule an "index" action for the file.
    indexed: bool = True
    # Outgoing dependency edges used for transitive invalidation.
    dependencies: list[DependencyEdge] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict form of this state.

        Fields whose value is ``None`` are omitted, and each dependency is
        serialized through its own ``to_dict`` method.
        """
        serialized: dict[str, Any] = {}
        for name, value in asdict(self).items():
            if name == "dependencies":
                # Prefer the edge's own serializer over the asdict() output.
                value = [edge.to_dict() for edge in self.dependencies]
            if value is not None:
                serialized[name] = value
        return serialized

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SnapshotState":
        """Build a state from a mapping such as one produced by ``to_dict``.

        ``path``, ``size``, ``mtime_ns``, ``content_hash`` and ``snapshot_id``
        are required (a missing key raises ``KeyError``); every other field
        falls back to its default. Dependency items that are not dicts are
        skipped.
        """
        # Fields are converted in declaration order so the first failure on
        # malformed input is raised for the earliest offending field.
        path = str(data["path"])
        size = int(data["size"])
        mtime_ns = int(data["mtime_ns"])
        content_hash = str(data["content_hash"])
        snapshot_id = str(data["snapshot_id"])
        parser = str(data.get("parser", PARSER_ID))
        parser_version = str(data.get("parser_version", PARSER_VERSION))
        parse_options_hash = str(
            data.get("parse_options_hash", EMPTY_PARSE_OPTIONS_HASH)
        )

        raw_contract = data.get("contract_hash")
        contract_hash = None if raw_contract is None else str(raw_contract)

        indexed = bool(data.get("indexed", True))

        dependencies = []
        for raw_edge in data.get("dependencies", []):
            if isinstance(raw_edge, dict):
                dependencies.append(_dependency_edge_from_dict(raw_edge))

        return cls(
            path=path,
            size=size,
            mtime_ns=mtime_ns,
            content_hash=content_hash,
            snapshot_id=snapshot_id,
            parser=parser,
            parser_version=parser_version,
            parse_options_hash=parse_options_hash,
            contract_hash=contract_hash,
            indexed=indexed,
            dependencies=dependencies,
        )
@dataclass(frozen=True)
class SnapshotPlanEntry:
    """The verdict for a single source path within a refresh plan."""

    path: str
    # Work items for this path; an empty list means nothing to do.
    actions: list[str]
    # Machine-readable explanation of why the actions were chosen.
    reason: str
    size: int | None = None
    mtime_ns: int | None = None
    previous_snapshot_id: str | None = None
    content_hash: str | None = None
    # Paths whose change caused this entry to be invalidated.
    invalidated_by: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return a compact dict, leaving out ``None`` and empty list/dict values."""
        compact: dict[str, Any] = {}
        for name, value in asdict(self).items():
            if value in (None, [], {}):
                continue
            compact[name] = value
        return compact
@dataclass(frozen=True)
class SnapshotRefreshPlan:
    """A cheap-first plan for refreshing snapshots and derived indexes.

    The plan stores the parameters it was computed with plus one
    ``SnapshotPlanEntry`` per path. The list-valued properties are views over
    ``entries`` grouped by action name; each returns a sorted list of paths.
    """

    root: str
    parser: str
    parser_version: str
    parse_options_hash: str
    contract_hash: str | None
    verify_hashes: bool
    entries: list[SnapshotPlanEntry]

    @property
    def unchanged(self) -> list[str]:
        """Paths whose entry carries no actions at all."""
        return _paths_without_actions(self.entries)

    @property
    def needs_hash(self) -> list[str]:
        """Paths whose entry includes the ``hash`` action."""
        return _paths_with_action(self.entries, "hash")

    @property
    def needs_parse(self) -> list[str]:
        """Paths whose entry includes the ``parse`` action."""
        return _paths_with_action(self.entries, "parse")

    @property
    def needs_index(self) -> list[str]:
        """Paths whose entry includes the ``index`` action."""
        return _paths_with_action(self.entries, "index")

    @property
    def needs_metadata_update(self) -> list[str]:
        """Paths whose entry includes the ``metadata`` action."""
        return _paths_with_action(self.entries, "metadata")

    @property
    def deleted(self) -> list[str]:
        """Paths whose entry includes the ``delete`` action."""
        return _paths_with_action(self.entries, "delete")

    @property
    def invalidated(self) -> list[str]:
        """Paths whose entry includes the ``invalidate`` action."""
        # Delegates to the shared helper like every other bucket property.
        return _paths_with_action(self.entries, "invalidate")

    @property
    def dirty(self) -> bool:
        """True when at least one entry has work to do."""
        return any(entry.actions for entry in self.entries)

    def to_dict(self) -> dict[str, Any]:
        """Return a serializable summary: parameters, counts, buckets, entries."""
        # Compute each bucket once; every property scans and sorts the whole
        # entry list, and both the counts and the listings need the result.
        buckets: dict[str, list[str]] = {
            "unchanged": self.unchanged,
            "needs_hash": self.needs_hash,
            "needs_parse": self.needs_parse,
            "needs_index": self.needs_index,
            "needs_metadata_update": self.needs_metadata_update,
            "deleted": self.deleted,
            "invalidated": self.invalidated,
        }
        return {
            "dirty": self.dirty,
            "root": self.root,
            "parser": self.parser,
            "parser_version": self.parser_version,
            "parse_options_hash": self.parse_options_hash,
            "contract_hash": self.contract_hash,
            "verify_hashes": self.verify_hashes,
            "counts": {name: len(bucket) for name, bucket in buckets.items()},
            **buckets,
            "entries": [entry.to_dict() for entry in self.entries],
        }
def plan_snapshot_refresh(
    paths: list[str | Path],
    *,
    previous: list[SnapshotState] | dict[str, SnapshotState] | None = None,
    root: str | Path = ".",
    recursive: bool = True,
    parse_options: dict[str, Any] | None = None,
    contract_hash: str | None = None,
    verify_hashes: bool = False,
) -> SnapshotRefreshPlan:
    """Plan snapshot/index refresh work, preferring cheap metadata to hashing.

    Every Markdown file found under ``paths`` is compared with its previously
    recorded state (keyed by path relative to ``root``). Previously known
    paths that are no longer found are planned for deletion, and entries that
    depend on a changed or deleted path are marked as invalidated.

    With ``verify_hashes`` false, a size/mtime difference schedules hash,
    parse, and index. With it true, only those metadata-changed files are
    hashed during planning, so a file whose content is unchanged is planned
    for a metadata update instead of a parse.
    """
    root_path = Path(root).resolve()
    known_states = _previous_by_path(previous)
    options_hash = _hash_mapping(parse_options or {})

    discovered: dict[str, Path] = {}
    for found in scan_markdown_files(paths, recursive=recursive):
        discovered[_relative(found, root_path)] = found

    plan_entries: list[SnapshotPlanEntry] = []
    # Paths whose own content changed or vanished; seeds for invalidation.
    touched: set[str] = set()

    for rel_path in sorted(discovered):
        source = discovered[rel_path]
        info = source.stat()
        prior = known_states.get(rel_path)

        # Start from the most common verdict (redo everything, counts as a
        # change) and let the branches below narrow it down.
        actions = ["hash", "parse", "index"]
        prior_id = None
        digest = None
        counts_as_change = True

        if prior is None:
            reason = "new_file"
        else:
            prior_id = prior.snapshot_id
            identity_matches = (
                prior.parser == PARSER_ID
                and prior.parser_version == PARSER_VERSION
                and prior.parse_options_hash == options_hash
                and prior.contract_hash == contract_hash
            )
            if not identity_matches:
                reason = "snapshot_identity_parameters_changed"
            elif prior.size == info.st_size and prior.mtime_ns == info.st_mtime_ns:
                # Metadata matches: trust the recorded hash without reading.
                actions = [] if prior.indexed else ["index"]
                reason = "snapshot_not_indexed" if actions else "unchanged"
                digest = prior.content_hash
                counts_as_change = False
            elif not verify_hashes:
                reason = "file_metadata_changed"
            else:
                # Only metadata-changed files are hashed during planning.
                digest = _hash_file(source)
                if digest == prior.content_hash:
                    actions = ["hash", "metadata"]
                    if not prior.indexed:
                        actions.append("index")
                    reason = "file_metadata_changed_content_same"
                    counts_as_change = False
                else:
                    reason = "content_hash_changed"

        plan_entries.append(
            SnapshotPlanEntry(
                path=rel_path,
                actions=actions,
                reason=reason,
                size=info.st_size,
                mtime_ns=info.st_mtime_ns,
                previous_snapshot_id=prior_id,
                content_hash=digest,
            )
        )
        if counts_as_change:
            touched.add(rel_path)

    # Anything known before but not found now is planned for deletion.
    for rel_path in sorted(known_states):
        if rel_path in discovered:
            continue
        gone = known_states[rel_path]
        plan_entries.append(
            SnapshotPlanEntry(
                path=rel_path,
                actions=["delete"],
                reason="source_missing",
                previous_snapshot_id=gone.snapshot_id,
                content_hash=gone.content_hash,
            )
        )
        touched.add(rel_path)

    invalidated = _transitive_dependents(touched, known_states)
    if invalidated:
        plan_entries = _apply_invalidations(plan_entries, invalidated, touched)

    plan_entries.sort(key=lambda entry: entry.path)
    return SnapshotRefreshPlan(
        root=str(root_path),
        parser=PARSER_ID,
        parser_version=PARSER_VERSION,
        parse_options_hash=options_hash,
        contract_hash=contract_hash,
        verify_hashes=verify_hashes,
        entries=plan_entries,
    )
def load_snapshot_state_file(path: str | Path) -> list[SnapshotState]:
    """Load a portable snapshot-state fixture from JSON or YAML.

    The document may be any of:

    * a mapping with a ``snapshots`` (or ``states``) key holding a list of
      state mappings, or a mapping whose values are state mappings;
    * a mapping whose values are themselves state mappings;
    * a top-level list of state mappings.

    Items that are not mappings are skipped. An empty document yields an
    empty list.

    Raises:
        ValueError: if no list of snapshots can be located in the document.
    """
    state_path = Path(path)
    # JSON fixtures go through the same safe YAML loader.
    data = yaml.safe_load(state_path.read_text(encoding="utf-8")) or {}

    if isinstance(data, dict):
        raw_snapshots = data.get("snapshots", data.get("states", data))
    else:
        # A top-level list (or scalar) has no ``.get``; pass it through so
        # the shape checks below accept the list or reject the scalar.
        raw_snapshots = data

    if isinstance(raw_snapshots, dict):
        raw_snapshots = list(raw_snapshots.values())
    if not isinstance(raw_snapshots, list):
        raise ValueError("Snapshot state file must contain a `snapshots` list")

    return [
        SnapshotState.from_dict(item)
        for item in raw_snapshots
        if isinstance(item, dict)
    ]
def _previous_by_path(
previous: list[SnapshotState] | dict[str, SnapshotState] | None,
) -> dict[str, SnapshotState]:
if previous is None:
return {}
if isinstance(previous, dict):
return dict(previous)
return {state.path: state for state in previous}
def _dependency_edge_from_dict(data: dict[str, Any]) -> DependencyEdge:
    """Rebuild a ``DependencyEdge`` from its mapping form.

    ``source_id``, ``target`` and ``kind`` are required. A missing or falsy
    ``target_snapshot_id`` becomes ``None``, and missing or falsy ``metadata``
    becomes an empty dict.
    """
    # Required keys are read first so a missing one raises KeyError before
    # any optional field is touched.
    source_id = str(data["source_id"])
    target = str(data["target"])
    kind = str(data["kind"])

    pinned = data.get("target_snapshot_id")
    target_snapshot_id = str(pinned) if pinned else None

    metadata = dict(data.get("metadata") or {})

    return DependencyEdge(
        source_id=source_id,
        target=target,
        kind=kind,
        target_snapshot_id=target_snapshot_id,
        metadata=metadata,
    )
def _transitive_dependents(
changed_paths: set[str],
previous_by_path: dict[str, SnapshotState],
) -> dict[str, list[str]]:
reverse: dict[str, set[str]] = {}
for state in previous_by_path.values():
for edge in state.dependencies:
reverse.setdefault(edge.target, set()).add(state.path)
if edge.target_snapshot_id:
reverse.setdefault(edge.target_snapshot_id, set()).add(state.path)
invalidates: dict[str, list[str]] = {}
queue = list(changed_paths)
visited = set(changed_paths)
while queue:
changed = queue.pop(0)
dependents = sorted(reverse.get(changed, set()))
if dependents:
invalidates[changed] = dependents
for dependent in dependents:
if dependent in visited:
continue
visited.add(dependent)
queue.append(dependent)
return invalidates
def _apply_invalidations(
    entries: list[SnapshotPlanEntry],
    invalidates: dict[str, list[str]],
    changed_or_deleted: set[str],
) -> list[SnapshotPlanEntry]:
    """Return ``entries`` with the ``invalidate`` action added to dependents.

    ``invalidates`` maps a cause path to its dependents. Dependents that are
    themselves in ``changed_or_deleted`` are left alone. Every other dependent
    gets ``invalidate`` merged into its (sorted) actions and ``invalidated_by``
    set to its sorted causes; a dependent with no existing entry gets a fresh
    one. A reason of ``unchanged`` is replaced by ``dependency_changed``; any
    other existing reason is kept.
    """
    # Invert the mapping: dependent -> every path that invalidates it.
    causes_by_dependent: dict[str, list[str]] = {}
    for cause, dependents in invalidates.items():
        for dependent in dependents:
            causes_by_dependent.setdefault(dependent, []).append(cause)

    by_path = {entry.path: entry for entry in entries}
    for dependent, causes in causes_by_dependent.items():
        if dependent in changed_or_deleted:
            continue

        current = by_path.get(dependent)
        if current is None:
            prior_actions: list[str] = []
            reason = "dependency_changed"
        else:
            prior_actions = current.actions
            reason = (
                "dependency_changed"
                if current.reason == "unchanged"
                else current.reason
            )

        by_path[dependent] = SnapshotPlanEntry(
            path=dependent,
            actions=sorted({*prior_actions, "invalidate"}),
            reason=reason,
            size=None if current is None else current.size,
            mtime_ns=None if current is None else current.mtime_ns,
            previous_snapshot_id=(
                None if current is None else current.previous_snapshot_id
            ),
            content_hash=None if current is None else current.content_hash,
            invalidated_by=sorted(set(causes)),
        )
    return list(by_path.values())
def _paths_with_action(entries: list[SnapshotPlanEntry], action: str) -> list[str]:
return sorted(entry.path for entry in entries if action in entry.actions)
def _paths_without_actions(entries: list[SnapshotPlanEntry]) -> list[str]:
return sorted(entry.path for entry in entries if not entry.actions)
def _relative(path: Path, root: Path) -> str:
resolved = path.resolve()
try:
return resolved.relative_to(root).as_posix()
except ValueError:
return resolved.as_posix()
def _hash_file(path: Path) -> str:
return "sha256:" + hashlib.sha256(path.read_bytes()).hexdigest()
def _hash_mapping(mapping: dict[str, Any]) -> str:
payload = json.dumps(mapping, sort_keys=True, ensure_ascii=False)
return "sha256:" + hashlib.sha256(payload.encode("utf-8")).hexdigest()