Files
infospace-bench/src/infospace_bench/generator.py

526 lines
18 KiB
Python

from __future__ import annotations
import hashlib
import shutil
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import yaml
from .checks import run_collection_checks
from .errors import InfospaceError
from .evaluation_io import read_entity_evaluations
from .history import get_history, read_metrics_file, record_check_results
from .lifecycle import create_infospace, load_infospace, register_artifact
from .openrouter import OpenRouterAssistedGenerationAdapter
from .source_intake import SourceChunk, normalize_source
from .workflow import (
AssistedGenerationAdapter,
FixtureAssistedGenerationAdapter,
WorkflowRunResult,
plan_workflow,
run_workflow,
)
# Generation state lives inside the infospace, relative to its root.
STATE_PATH = Path("output", "workflows", "generation-state.yaml")
DEFAULT_PROFILE = "general-knowledge"

# Workflow ids installed by ``_profile_workflows``.
_SUMMARY_WORKFLOW_ID = "generic-source-summary"
_ENTITIES_WORKFLOW_ID = "generic-source-entities"
_RELATIONS_WORKFLOW_ID = "generic-source-relations"
_EVALUATIONS_WORKFLOW_ID = "generic-source-evaluations"

# User-facing stage name (aliases included) -> workflow ids it runs, in order.
# Each entry gets its own list object so aliases never share mutable state.
WORKFLOW_BY_STAGE = {
    "summary": [_SUMMARY_WORKFLOW_ID],
    "summarize": [_SUMMARY_WORKFLOW_ID],
    "extract": [_ENTITIES_WORKFLOW_ID],
    "entities": [_ENTITIES_WORKFLOW_ID],
    "relations": [_RELATIONS_WORKFLOW_ID],
    "evaluate": [_EVALUATIONS_WORKFLOW_ID],
    "evaluation": [_EVALUATIONS_WORKFLOW_ID],
    "all": [
        _SUMMARY_WORKFLOW_ID,
        _ENTITIES_WORKFLOW_ID,
        _RELATIONS_WORKFLOW_ID,
        _EVALUATIONS_WORKFLOW_ID,
    ],
}
@dataclass(frozen=True)
class GenerationRunResult:
    """Immutable outcome of one ``run_generation`` call."""

    root: str
    status: str
    stage: str
    skipped: bool = False
    stale: bool = False
    workflows: list[dict[str, Any]] = field(default_factory=list)
    metrics: dict[str, Any] = field(default_factory=dict)
    history_snapshot_id: str = ""

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a plain dict, omitting empty values.

        A field is dropped when it equals ``""``, ``[]`` or ``{}``. Boolean
        flags are always kept, because ``False`` compares unequal to all three.
        """
        empty_markers = ("", [], {})
        compact: dict[str, Any] = {}
        for name, value in asdict(self).items():
            if value in empty_markers:
                continue
            compact[name] = value
        return compact
def init_generation_infospace(
    workspace: str | Path,
    source: str | Path,
    slug: str,
    *,
    name: str,
    profile: str = DEFAULT_PROFILE,
    max_chunks: int | None = None,
) -> Any:
    """Create a new infospace seeded from *source* and wired to *profile*.

    Steps:
    1. Normalize *source* into chunks.
    2. Create the infospace under *workspace*.
    3. Copy the bundled profile into it.
    4. Rewrite ``infospace.yaml`` with the profile's schemas and workflows.
    5. Register every chunk as a ``source`` artifact.
    6. Write the initial generation state file.

    Args:
        workspace: Directory in which the infospace is created.
        source: Source to import; passed to ``normalize_source``.
        slug: Infospace slug passed to ``create_infospace``.
        name: Human-readable infospace name.
        profile: Name of a profile directory bundled next to this module.
        max_chunks: Optional cap forwarded to ``normalize_source``.

    Returns:
        The freshly loaded infospace (result of ``load_infospace``).

    Raises:
        InfospaceError: ``missing_generation_profile`` when *profile* is not
            bundled. Errors from ``normalize_source`` / ``create_infospace``
            propagate unchanged.
    """
    # Validate everything checkable up front, before anything is created on
    # disk, so a bad source or profile name does not leave a half-initialised
    # infospace behind.
    chunks = normalize_source(source, max_chunks=max_chunks)
    profile_source = Path(__file__).parent / "profiles" / profile
    if not profile_source.is_dir():
        raise InfospaceError(
            "missing_generation_profile",
            f"Generation profile does not exist: {profile}",
            {"profile": profile, "path": str(profile_source)},
        )
    infospace = create_infospace(Path(workspace), slug, name=name)
    _install_profile(infospace.root, profile)
    _write_workflows(infospace.root, profile)
    _register_source_chunks(infospace.root, chunks)
    # One timestamp, so created_at == updated_at on a brand-new state file.
    timestamp = _now()
    _write_state(
        infospace.root,
        {
            "profile": profile,
            "source": str(Path(source)),
            "source_chunks": _source_state(infospace.root),
            "profile_digest": _profile_digest(infospace.root, profile),
            "stage_status": {},
            "completed": False,
            "created_at": timestamp,
            "updated_at": timestamp,
        },
    )
    return load_infospace(infospace.root)
def plan_generation(root: str | Path, *, stage: str = "all") -> dict[str, Any]:
    """Plan, without executing, every workflow belonging to *stage*.

    A workflow whose planning raises ``InfospaceError`` is not fatal. It is
    reported with ``status: "blocked"`` and the serialized error. The result
    also carries the infospace's staleness flag and source chunk count, taken
    from ``status_generation``.
    """
    root_path = Path(root)

    def plan_or_block(workflow_id: str) -> dict[str, Any]:
        try:
            return plan_workflow(root_path, workflow_id).to_dict()
        except InfospaceError as exc:
            return {
                "workflow_id": workflow_id,
                "status": "blocked",
                "error": exc.to_dict(),
            }

    plans = [plan_or_block(workflow_id) for workflow_id in _workflow_ids_for_stage(stage)]
    current = status_generation(root_path)
    return {
        "root": str(root_path),
        "stage": stage,
        "status": "planned",
        "stale": current["stale"],
        "source_chunk_count": current["source_chunk_count"],
        "workflows": plans,
    }
def run_generation(
    root: str | Path,
    *,
    stage: str = "all",
    provider: str = "fixture",
    model: str = "",
    fixture_responses: str | Path | None = None,
    resume: bool = False,
    force: bool = False,
) -> GenerationRunResult:
    """Execute the workflows for *stage* and persist the generation state.

    Args:
        root: Infospace root directory.
        stage: Stage name or alias from ``WORKFLOW_BY_STAGE``, or
            ``"intake"`` / ``"metrics"``. Case and surrounding whitespace are
            ignored.
        provider: Assisted-generation provider, used when no fixture
            responses are given (see ``_adapter_for``).
        model: Model name forwarded to the OpenRouter adapter.
        fixture_responses: Optional fixture file. When given, it takes
            precedence over *provider*.
        resume: Skip the run entirely if a previous run completed and the
            infospace is not stale.
        force: Run even when *resume* would otherwise skip.

    Returns:
        A ``GenerationRunResult``. For ``"all"`` and ``"metrics"`` it
        includes freshly recorded metrics and the history snapshot id.

    Raises:
        InfospaceError: For an unknown stage, a missing adapter, or errors
            raised by the workflow and history helpers.
    """
    root_path = Path(root)
    stage_key = stage.strip().lower()
    state = _read_state(root_path)
    status = status_generation(root_path)
    workflow_ids = _workflow_ids_for_stage(stage_key)
    if resume and not force and state.get("completed") is True and not status["stale"]:
        return GenerationRunResult(
            root=str(root_path),
            status="skipped",
            stage=stage,
            skipped=True,
            stale=False,
            workflows=[],
            metrics=status.get("metrics", {}),
        )
    # Only build an adapter when there is something to run: "intake" and
    # "metrics" must work without provider configuration.
    adapter = (
        _adapter_for(provider, model=model, fixture_responses=fixture_responses)
        if workflow_ids
        else None
    )
    workflow_results: list[dict[str, Any]] = []
    for workflow_id in workflow_ids:
        result = run_workflow(root_path, workflow_id, assisted_adapter=adapter)
        workflow_results.append(result.to_dict())
        state = _mark_workflow_completed(state, result)
        # Persist progress after every workflow, so a failure in a later
        # workflow does not discard the stage_status of the ones that finished.
        _write_state(root_path, state)
    metrics: dict[str, Any] = {}
    snapshot_id = ""
    if stage_key in {"all", "metrics"}:
        check_result = _record_metrics(root_path)
        metrics = check_result.metrics
        snapshot_id = check_result.snapshot.snapshot_id
        _write_generation_report(root_path, metrics, snapshot_id)
    finished_at = _now()
    state.update(
        {
            "source_chunks": _source_state(root_path),
            # Refreshing the digest acknowledges the current profile contents,
            # so status_generation no longer reports the profile as stale.
            "profile_digest": _profile_digest(root_path, str(state.get("profile") or DEFAULT_PROFILE)),
            "completed": stage_key in {"all", "metrics"},
            "updated_at": finished_at,
            "last_run": {
                "stage": stage,
                "provider": provider,
                "model": model,
                "workflow_count": len(workflow_results),
                "snapshot_id": snapshot_id,
                "completed_at": finished_at,
            },
        }
    )
    _write_state(root_path, state)
    # The refreshed state clears profile staleness. Source files whose content
    # no longer matches their recorded digest are still stale, so report that
    # rather than claiming the infospace is fresh.
    return GenerationRunResult(
        root=str(root_path),
        status="completed",
        stage=stage,
        skipped=False,
        stale=bool(_stale_source_ids(root_path)),
        workflows=workflow_results,
        metrics=metrics,
        history_snapshot_id=snapshot_id,
    )
def status_generation(root: str | Path) -> dict[str, Any]:
    """Summarise an infospace's generation progress and freshness.

    The infospace counts as stale when any registered source file is missing
    or changed (see ``_stale_source_ids``), or when the state file records a
    profile digest that no longer matches the installed profile.
    """
    root_path = Path(root)
    infospace = load_infospace(root_path)
    state = _read_state(root_path)
    space_root = infospace.root
    stale_sources = _stale_source_ids(space_root)
    profile = str(state.get("profile") or DEFAULT_PROFILE)
    # The digest is only recomputed when one was recorded. Without a
    # baseline, the profile is never considered stale.
    recorded_digest = state.get("profile_digest")
    stale_profile = False
    if recorded_digest:
        stale_profile = bool(recorded_digest != _profile_digest(space_root, profile))
    evaluations = read_entity_evaluations(space_root / "output" / "evaluations")
    history = get_history(space_root)
    latest_snapshot = history[-1].snapshot_id if history else ""

    def count_kind(kind: str) -> int:
        return sum(1 for artifact in infospace.artifacts if artifact.kind == kind)

    return {
        "root": str(space_root),
        "slug": infospace.config.slug,
        "profile": profile,
        "source_chunk_count": count_kind("source"),
        "entity_count": count_kind("entity"),
        "relation_count": count_kind("relation"),
        "evaluation_count": len(evaluations),
        "generated_count": count_kind("generated"),
        "metrics": read_metrics_file(space_root / "output" / "metrics" / "metrics.yaml"),
        "history_snapshot_count": len(history),
        "latest_snapshot_id": latest_snapshot,
        "stale": bool(stale_sources or stale_profile),
        "stale_sources": stale_sources,
        "stale_profile": stale_profile,
        "completed": bool(state.get("completed", False)),
        "stage_status": state.get("stage_status", {}),
    }
def _adapter_for(
    provider: str,
    *,
    model: str,
    fixture_responses: str | Path | None,
) -> AssistedGenerationAdapter:
    """Choose the assisted-generation adapter for a run.

    A non-empty *fixture_responses* always wins, whatever the *provider*.
    Otherwise only ``"openrouter"`` is supported, and any other provider
    raises ``InfospaceError`` (``missing_assisted_generation_adapter``).
    """
    if not fixture_responses:
        if provider != "openrouter":
            raise InfospaceError(
                "missing_assisted_generation_adapter",
                "Assisted generation requires --fixture-responses or --provider openrouter",
                {"provider": provider},
            )
        return OpenRouterAssistedGenerationAdapter(model=model)
    return FixtureAssistedGenerationAdapter.from_file(Path(fixture_responses))
def _register_source_chunks(root: Path, chunks: list[SourceChunk]) -> None:
    """Write each chunk to ``artifacts/sources/<chunk_id>.md`` and register it.

    Each chunk becomes a ``source`` artifact whose provenance records its
    origin, digest and position. ``_stale_source_ids`` later uses that
    digest to detect edited chunk files.
    """
    sources_dir = root / "artifacts" / "sources"
    for chunk in chunks:
        chunk_file = sources_dir / f"{chunk.chunk_id}.md"
        # Create the parent of this specific file, in case a chunk id
        # contains path separators.
        chunk_file.parent.mkdir(parents=True, exist_ok=True)
        chunk_file.write_text(chunk.markdown, encoding="utf-8")
        provenance = {
            "original_path": chunk.original_path,
            "source_type": chunk.source_type,
            "digest": chunk.digest,
            "chunk_id": chunk.chunk_id,
            "chunk_index": chunk.chunk_index,
            "chunk_count": chunk.chunk_count,
            "imported_at": chunk.imported_at,
            "extractor_version": chunk.extractor_version,
        }
        register_artifact(
            root,
            artifact_id=f"source/{chunk.chunk_id}.md",
            path=chunk_file,
            kind="source",
            title=chunk.title,
            provenance=provenance,
        )
def _install_profile(root: Path, profile: str) -> None:
    """Copy the bundled *profile* directory into the infospace at *root*.

    The whole profile goes to ``profiles/<profile>``, and its ``templates``
    subdirectory is also copied to ``workflows/templates/<profile>``, where
    the workflow definitions reference it. Existing files are overwritten.

    Raises:
        InfospaceError: ``missing_generation_profile`` when no profile with
            that name is bundled next to this module.
    """
    bundled = Path(__file__).parent / "profiles" / profile
    if bundled.is_dir():
        copies = (
            (bundled, root / "profiles" / profile),
            (bundled / "templates", root / "workflows" / "templates" / profile),
        )
        for origin, destination in copies:
            shutil.copytree(origin, destination, dirs_exist_ok=True)
        return
    raise InfospaceError(
        "missing_generation_profile",
        f"Generation profile does not exist: {profile}",
        {"profile": profile, "path": str(bundled)},
    )
def _write_workflows(root: Path, profile: str) -> None:
    """Point ``infospace.yaml`` at *profile*'s contracts and workflows.

    The existing ``schemas`` mapping is preserved, except that the
    ``entity``, ``relation`` and ``evaluation`` entries are set to the
    profile's contract files. ``workflows`` is replaced outright. Keys are
    written back in insertion order.
    """
    config_path = root / "infospace.yaml"
    config = yaml.safe_load(config_path.read_text(encoding="utf-8")) or {}
    contracts_dir = f"profiles/{profile}/contracts"
    schemas = dict(config.get("schemas") or {})
    schemas.update(
        {
            "entity": f"{contracts_dir}/entity.contract.md",
            "relation": f"{contracts_dir}/relation.contract.md",
            "evaluation": f"{contracts_dir}/evaluation.contract.md",
        }
    )
    config["schemas"] = schemas
    config["workflows"] = _profile_workflows(profile)
    serialized = yaml.safe_dump(config, sort_keys=False)
    config_path.write_text(serialized, encoding="utf-8")
def _profile_workflows(profile: str) -> list[dict[str, Any]]:
    """Build the four generic-source workflow definitions for *profile*.

    Every workflow carries ``static_macros: {profile: <profile>}``, and each
    assisted stage reads its prompt template from
    ``workflows/templates/<profile>``. The ``{{ input.* }}`` placeholders
    are plain strings, left for the workflow engine to render.

    Dict key order matters, because the result is dumped to YAML with
    ``sort_keys=False``.
    """
    template_root = f"workflows/templates/{profile}"

    def assisted_stage(
        stage_id: str,
        input_name: str,
        template_file: str,
        output: tuple[str, str, str, str],
    ) -> dict[str, Any]:
        # Fresh dicts on every call, so no two stages share mutable state.
        out_path, artifact_id, artifact_kind, title = output
        return {
            "id": stage_id,
            "kind": "assisted",
            "input": input_name,
            "template": f"{template_root}/{template_file}",
            "provider_hint": "openrouter",
            "output": {
                "path": out_path,
                "artifact_id": artifact_id,
                "kind": artifact_kind,
                "title": title,
            },
        }

    def workflow(
        workflow_id: str,
        description: str,
        input_name: str,
        stages: list[dict[str, Any]],
    ) -> dict[str, Any]:
        # In every generic-source workflow the input name equals its kind.
        return {
            "id": workflow_id,
            "description": description,
            "inputs": {input_name: {"kind": input_name}},
            "static_macros": {"profile": profile},
            "stages": stages,
        }

    summary = workflow(
        "generic-source-summary",
        "Summarize normalized source chunks.",
        "source",
        [
            assisted_stage(
                "summarize-source",
                "source",
                "summarize-source.md",
                (
                    "artifacts/generated/{{ input.slug }}-summary.md",
                    "generated/{{ input.slug }}-summary.md",
                    "generated",
                    "{{ input.title }} Summary",
                ),
            )
        ],
    )
    entities = workflow(
        "generic-source-entities",
        "Extract reusable entity artifacts from source chunks.",
        "source",
        [
            assisted_stage(
                "extract-entities",
                "source",
                "extract-entities.md",
                (
                    "artifacts/generated/{{ input.slug }}-entities.md",
                    "generated/{{ input.slug }}-entities.md",
                    "generated",
                    "{{ input.title }} Entity Bundle",
                ),
            ),
            # Non-assisted stage that splits the bundle produced above.
            {
                "id": "split-entities",
                "kind": "split_entities",
                "input": "source",
                "template": "",
                "static_macros": {"bundle_stage": "extract-entities"},
            },
        ],
    )
    relations = workflow(
        "generic-source-relations",
        "Extract relation artifacts from source chunks.",
        "source",
        [
            assisted_stage(
                "extract-relations",
                "source",
                "extract-relations.md",
                (
                    "artifacts/relations/{{ input.slug }}-relations.md",
                    "relation/{{ input.slug }}-relations.md",
                    "relation",
                    "{{ input.title }} Relations",
                ),
            )
        ],
    )
    evaluations = workflow(
        "generic-source-evaluations",
        "Evaluate generated entities with the profile rubric.",
        "entity",
        [
            assisted_stage(
                "evaluate-entity",
                "entity",
                "evaluate-entity.md",
                (
                    "output/evaluations/{{ input.slug }}.md",
                    "generated/evaluation-{{ input.slug }}.md",
                    "generated",
                    "{{ input.title }} Evaluation",
                ),
            )
        ],
    )
    return [summary, entities, relations, evaluations]
def _record_metrics(root: Path) -> Any:
    """Run collection checks over the infospace and record them in history.

    Returns whatever ``record_check_results`` returns. ``run_generation``
    reads its ``metrics`` and ``snapshot.snapshot_id`` attributes.
    """
    infospace = load_infospace(root)
    space_root = infospace.root
    check_results = run_collection_checks(infospace.artifacts)
    evaluations = read_entity_evaluations(space_root / "output" / "evaluations")
    return record_check_results(
        space_root,
        check_results,
        artifact_evaluations=evaluations,
        schema_name="generic-source",
        metadata={"generator": "generic-source"},
    )
def _write_generation_report(root: Path, metrics: dict[str, Any], snapshot_id: str) -> None:
    """Write ``reports/generation-summary.md`` and register it as an artifact.

    The Markdown report lists the snapshot id, artifact counts taken from
    ``status_generation``, and every metric sorted by name.
    """
    current = status_generation(root)
    count_rows = (
        ("Sources", "source_chunk_count"),
        ("Entities", "entity_count"),
        ("Relations", "relation_count"),
        ("Evaluations", "evaluation_count"),
    )
    lines = ["# Generation Report", "", f"Snapshot: {snapshot_id}"]
    lines += [f"{label}: {current[key]}" for label, key in count_rows]
    lines += ["", "## Metrics", ""]
    lines += [f"- {name}: {value}" for name, value in sorted(metrics.items())]
    # Trailing empty entry, so the joined report ends with a newline.
    lines.append("")
    report_path = root / "reports" / "generation-summary.md"
    report_path.parent.mkdir(parents=True, exist_ok=True)
    report_path.write_text("\n".join(lines), encoding="utf-8")
    register_artifact(
        root,
        artifact_id="generated/generation-summary.md",
        path=report_path,
        kind="generated",
        title="Generation Summary",
        provenance={"workflow_id": "generic-source-generator", "snapshot_id": snapshot_id},
    )
def _workflow_ids_for_stage(stage: str) -> list[str]:
    """Resolve a stage name to the workflow ids it runs.

    Matching ignores case and surrounding whitespace. ``"intake"`` and
    ``"metrics"`` map to no workflows: source import happens in
    ``init_generation_infospace``, and metrics are recorded directly by
    ``run_generation``.

    Returns:
        A new list the caller may freely modify.

    Raises:
        InfospaceError: ``invalid_generation_stage`` for an unknown stage.
            The error details list every valid stage.
    """
    normalized = stage.strip().lower()
    if normalized in {"intake", "metrics"}:
        return []
    workflow_ids = WORKFLOW_BY_STAGE.get(normalized)
    if workflow_ids is None:
        raise InfospaceError(
            "invalid_generation_stage",
            f"Unsupported generation stage: {stage}",
            {
                "stage": stage,
                "valid_stages": sorted([*WORKFLOW_BY_STAGE, "intake", "metrics"]),
            },
        )
    # Copy, so callers cannot mutate the shared module-level mapping.
    return list(workflow_ids)
def _source_state(root: Path) -> dict[str, Any]:
    """Snapshot every registered ``source`` artifact, keyed by artifact id.

    Each entry records the artifact's path and title, plus the digest,
    source type and chunk id from its provenance. Missing provenance values
    become ``""``.
    """
    snapshot: dict[str, Any] = {}
    for artifact in load_infospace(root).artifacts:
        if artifact.kind != "source":
            continue
        meta = artifact.provenance
        snapshot[artifact.id] = {
            "path": artifact.path,
            "digest": meta.get("digest", ""),
            "title": artifact.title,
            "source_type": meta.get("source_type", ""),
            "chunk_id": meta.get("chunk_id", ""),
        }
    return snapshot
def _stale_source_ids(root: Path) -> list[str]:
    """Return the ids of source artifacts whose file is missing or changed.

    A source is stale when its file no longer exists. It is also stale when
    its provenance records a digest and the file's current UTF-8 text does
    not hash to that digest. A file that can no longer be read or decoded
    is treated as stale instead of aborting the whole status check. Sources
    without a recorded digest are only checked for existence.
    """
    infospace = load_infospace(root)
    stale: list[str] = []
    for item in infospace.artifacts:
        if item.kind != "source":
            continue
        path = infospace.root / item.path
        if not path.is_file():
            stale.append(item.id)
            continue
        expected = str(item.provenance.get("digest") or "")
        if not expected:
            continue
        try:
            actual = _digest_text(path.read_text(encoding="utf-8"))
        except (OSError, UnicodeDecodeError):
            # Unreadable or no longer UTF-8: it cannot match the recorded digest.
            stale.append(item.id)
            continue
        if actual != expected:
            stale.append(item.id)
    return stale
def _mark_workflow_completed(
    state: dict[str, Any],
    result: WorkflowRunResult,
) -> dict[str, Any]:
    """Return a copy of *state* with *result* recorded under ``stage_status``.

    The entry is keyed by workflow id and holds the run's status, run id,
    output artifact ids and a UTC timestamp. Neither *state* nor its
    existing ``stage_status`` mapping is mutated.
    """
    entry = {
        "status": result.status,
        "run_id": result.run_id,
        "output_artifact_ids": [produced.artifact_id for produced in result.outputs],
        "updated_at": _now(),
    }
    merged = dict(state.get("stage_status") or {})
    merged[result.workflow_id] = entry
    updated = dict(state)
    updated["stage_status"] = merged
    return updated
def _profile_digest(root: Path, profile: str) -> str:
    """Fingerprint the installed copy of *profile* inside the infospace.

    Every file under ``profiles/<profile>`` and then
    ``workflows/templates/<profile>`` is visited in sorted order. For each
    file, its path relative to *root* (with POSIX separators) and its raw
    bytes are fed into one SHA-256. Missing directories contribute nothing.

    Returns:
        The hex digest.
    """
    hasher = hashlib.sha256()
    for base in (
        root / "profiles" / profile,
        root / "workflows" / "templates" / profile,
    ):
        if not base.is_dir():
            continue
        for path in sorted(base.rglob("*")):
            if not path.is_file():
                continue
            # as_posix() keeps the digest identical across operating systems.
            # str() would use "\\" on Windows and make the profile look stale.
            hasher.update(path.relative_to(root).as_posix().encode("utf-8"))
            hasher.update(path.read_bytes())
    return hasher.hexdigest()
def _read_state(root: Path) -> dict[str, Any]:
    """Load the generation state file.

    Returns ``{}`` when the file is absent or its YAML top level is not a
    mapping.
    """
    state_file = root / STATE_PATH
    if state_file.is_file():
        loaded = yaml.safe_load(state_file.read_text(encoding="utf-8"))
        if isinstance(loaded, dict):
            return loaded
    return {}
def _write_state(root: Path, state: dict[str, Any]) -> None:
    """Persist *state* as YAML at ``STATE_PATH`` under *root*.

    The YAML is written to a sibling temporary file, which then atomically
    replaces the real state file. An interrupted write therefore never
    leaves a truncated ``generation-state.yaml`` behind.
    """
    path = root / STATE_PATH
    path.parent.mkdir(parents=True, exist_ok=True)
    payload = yaml.safe_dump(state, sort_keys=False)
    tmp_path = path.with_name(f"{path.name}.tmp")
    try:
        tmp_path.write_text(payload, encoding="utf-8")
        tmp_path.replace(path)
    except BaseException:
        # Do not leave a stray temp file behind if the write or rename failed.
        tmp_path.unlink(missing_ok=True)
        raise
def _digest_text(text: str) -> str:
    """Return the hex SHA-256 of *text* encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(text.encode("utf-8"))
    return hasher.hexdigest()
def _now() -> str:
    """Return the current UTC time as an ISO-8601 string with an offset."""
    moment = datetime.now(tz=timezone.utc)
    return moment.isoformat()