"""Generic-source generation pipeline.

Creates an infospace from a source, installs a generation profile and its
workflows, plans and runs the ``generic-source-*`` workflows, and reports
generation status and metrics.
"""

from __future__ import annotations

import hashlib
import shutil
from collections import Counter
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import yaml

from .checks import run_collection_checks
from .errors import InfospaceError
from .evaluation_io import read_entity_evaluations
from .history import get_history, read_metrics_file, record_check_results
from .lifecycle import create_infospace, load_infospace, register_artifact
from .openrouter import OpenRouterAssistedGenerationAdapter
from .source_intake import SourceChunk, normalize_source
from .workflow import (
    AssistedGenerationAdapter,
    FixtureAssistedGenerationAdapter,
    WorkflowRunResult,
    plan_workflow,
    run_workflow,
)

# Location of the generation state file, relative to the infospace root.
STATE_PATH = Path("output/workflows/generation-state.yaml")
DEFAULT_PROFILE = "general-knowledge"

# Maps each user-facing stage name (including aliases) to the workflow ids it runs.
WORKFLOW_BY_STAGE = {
    "summary": ["generic-source-summary"],
    "summarize": ["generic-source-summary"],
    "extract": ["generic-source-entities"],
    "entities": ["generic-source-entities"],
    "relations": ["generic-source-relations"],
    "evaluate": ["generic-source-evaluations"],
    "evaluation": ["generic-source-evaluations"],
    "all": [
        "generic-source-summary",
        "generic-source-entities",
        "generic-source-relations",
        "generic-source-evaluations",
    ],
}

# Stages that are valid but run no workflows.
_STAGES_WITHOUT_WORKFLOWS = ("intake", "metrics")
# Stages after which metrics are recorded and the run is marked completed.
_METRICS_STAGES = frozenset({"all", "metrics"})


@dataclass(frozen=True)
class GenerationRunResult:
    """Outcome of a :func:`run_generation` call."""

    root: str
    status: str
    stage: str
    skipped: bool = False
    stale: bool = False
    workflows: list[dict[str, Any]] = field(default_factory=list)
    metrics: dict[str, Any] = field(default_factory=dict)
    history_snapshot_id: str = ""

    def to_dict(self) -> dict[str, Any]:
        """Return the result as a dict, dropping empty-string, empty-list and empty-dict fields."""
        data = asdict(self)
        return {key: value for key, value in data.items() if value not in ("", [], {})}


def init_generation_infospace(
    workspace: str | Path,
    source: str | Path,
    slug: str,
    *,
    name: str,
    profile: str = DEFAULT_PROFILE,
    max_chunks: int | None = None,
) -> Any:
    """Create an infospace for ``source`` and prepare it for generation.

    The source is normalized into chunks, a new infospace is created under
    ``workspace``, and the named profile is installed. The profile's
    workflows are written into ``infospace.yaml``, every chunk is registered
    as a ``source`` artifact, and an initial generation state file is written.

    Returns:
        The freshly reloaded infospace (as returned by ``load_infospace``).

    Raises:
        InfospaceError: if the profile does not exist (from ``_install_profile``).
    """
    chunks = normalize_source(source, max_chunks=max_chunks)
    infospace = create_infospace(Path(workspace), slug, name=name)
    _install_profile(infospace.root, profile)
    _write_workflows(infospace.root, profile)
    _register_source_chunks(infospace.root, chunks)
    # One timestamp so created_at and updated_at agree on a fresh state.
    now = _now()
    _write_state(
        infospace.root,
        {
            "profile": profile,
            "source": str(Path(source)),
            "source_chunks": _source_state(infospace.root),
            "profile_digest": _profile_digest(infospace.root, profile),
            "stage_status": {},
            "completed": False,
            "created_at": now,
            "updated_at": now,
        },
    )
    return load_infospace(infospace.root)


WORDS_PER_TOKEN_DEFAULT = 0.75
ENTITIES_PER_CHUNK_ESTIMATE = 2
# Rough word count of the entity text appended to each evaluation prompt.
_ENTITY_WORDS_ESTIMATE = 80
_CALLS_PER_CHUNK_BY_WORKFLOW = {
    "generic-source-summary": 1,
    "generic-source-entities": 1,
    "generic-source-relations": 1,
}
_TEMPLATE_BY_WORKFLOW = {
    "generic-source-summary": "summarize-source",
    "generic-source-entities": "extract-entities",
    "generic-source-relations": "extract-relations",
    "generic-source-evaluations": "evaluate-entity",
}


def plan_generation(
    root: str | Path,
    *,
    stage: str = "all",
    chapter_filter: list[str] | None = None,
    chunk_filter: list[str] | None = None,
    from_chapter: int | None = None,
    to_chapter: int | None = None,
    max_calls: int | None = None,
    cost_cap: float | None = None,
    cost_per_1k_tokens: float = 0.0,
    words_per_token: float = WORDS_PER_TOKEN_DEFAULT,
    entities_per_chunk: int = ENTITIES_PER_CHUNK_ESTIMATE,
    full: bool = False,
) -> dict[str, Any]:
    """Plan a generation run without calling any provider.

    Returns the estimate from :func:`plan_generation_summary`, extended with
    ``root``, the current ``stale`` flag and ``status="planned"``. When
    ``full`` is true, a ``workflows`` list is also attached with one
    ``plan_workflow`` result per workflow. A workflow whose planning raises
    ``InfospaceError`` is reported with ``status="blocked"`` and the error,
    instead of aborting the plan.
    """
    root_path = Path(root)
    status = status_generation(root_path)
    summary = plan_generation_summary(
        root_path,
        stage=stage,
        chapter_filter=chapter_filter,
        chunk_filter=chunk_filter,
        from_chapter=from_chapter,
        to_chapter=to_chapter,
        max_calls=max_calls,
        cost_cap=cost_cap,
        cost_per_1k_tokens=cost_per_1k_tokens,
        words_per_token=words_per_token,
        entities_per_chunk=entities_per_chunk,
    )
    summary["root"] = str(root_path)
    summary["stale"] = status["stale"]
    summary["status"] = "planned"
    if not full:
        return summary
    workflow_ids = _workflow_ids_for_stage(stage)
    plans: list[dict[str, Any]] = []
    for workflow_id in workflow_ids:
        try:
            plans.append(plan_workflow(root_path, workflow_id).to_dict())
        except InfospaceError as exc:
            plans.append(
                {
                    "workflow_id": workflow_id,
                    "status": "blocked",
                    "error": exc.to_dict(),
                }
            )
    summary["workflows"] = plans
    return summary


def plan_generation_summary(
    root: str | Path,
    *,
    stage: str = "all",
    chapter_filter: list[str] | None = None,
    chunk_filter: list[str] | None = None,
    from_chapter: int | None = None,
    to_chapter: int | None = None,
    max_calls: int | None = None,
    cost_cap: float | None = None,
    cost_per_1k_tokens: float = 0.0,
    words_per_token: float = WORDS_PER_TOKEN_DEFAULT,
    entities_per_chunk: int = ENTITIES_PER_CHUNK_ESTIMATE,
) -> dict[str, Any]:
    """Estimate provider calls, prompt size and cost for ``stage``.

    The source chunks are first narrowed by the chapter/chunk filters.

    Estimate model:
      * evaluation workflow: ``calls = selected chunks * entities_per_chunk``.
        Each call costs the ``evaluate-entity`` template words plus a fixed
        entity-text estimate.
      * other workflows: ``calls = selected chunks * per-chunk call count``.
        Words are ``calls * template words``, plus the total selected chunk
        words when there is at least one call.

    Tokens are ``words / words_per_token`` (0 when ``words_per_token <= 0``).
    Cost is only computed when ``cost_per_1k_tokens > 0``; otherwise it is
    ``None``.

    Raises:
        InfospaceError: for an unsupported ``stage``.
    """
    root_path = Path(root)
    infospace = load_infospace(root_path)
    sources = [item for item in infospace.artifacts if item.kind == "source"]
    selected = _select_source_chunks(
        sources,
        chapter_filter=chapter_filter,
        chunk_filter=chunk_filter,
        from_chapter=from_chapter,
        to_chapter=to_chapter,
    )
    workflow_ids = _workflow_ids_for_stage(stage)
    profile_name = _read_profile_name(root_path)
    template_words = _profile_template_words(root_path, profile_name)
    chunk_word_total = sum(_source_word_count(root_path, item) for item in selected)

    per_stage: list[dict[str, Any]] = []
    total_calls = 0
    total_prompt_words = 0
    for workflow_id in workflow_ids:
        template_label = _template_for_workflow(workflow_id)
        if workflow_id == "generic-source-evaluations":
            # Evaluations run per entity, not per chunk.
            calls = len(selected) * max(0, entities_per_chunk)
            prompt_words = calls * (
                template_words.get(template_label, 0) + _ENTITY_WORDS_ESTIMATE
            )
        else:
            calls = len(selected) * _CALLS_PER_CHUNK_BY_WORKFLOW.get(workflow_id, 0)
            # Each selected chunk's text is sent once, alongside its own template copy.
            prompt_words = calls * template_words.get(template_label, 0) + chunk_word_total * (
                1 if calls else 0
            )
        per_stage.append(
            {
                "workflow_id": workflow_id,
                "calls": calls,
                "prompt_words_estimate": prompt_words,
            }
        )
        total_calls += calls
        total_prompt_words += prompt_words

    total_tokens = int(round(total_prompt_words / words_per_token)) if words_per_token > 0 else 0
    cost: float | None = None
    if cost_per_1k_tokens > 0:
        cost = round((total_tokens / 1000.0) * cost_per_1k_tokens, 4)
    chapter_numbers = sorted(
        {
            int(item.provenance.get("chapter_number"))
            for item in selected
            if isinstance(item.provenance.get("chapter_number"), int)
        }
    )
    return {
        "stage": stage,
        "source_chunk_count": len(sources),
        "selected_chunk_count": len(selected),
        # Same id the chunk filter matches on, so reported ids can be fed back as filters.
        "selected_chunk_ids": [_chunk_id_for(item) for item in selected],
        "selected_chapter_numbers": chapter_numbers,
        "per_workflow": per_stage,
        "total_provider_calls_estimate": total_calls,
        "total_prompt_words_estimate": total_prompt_words,
        "total_prompt_tokens_estimate": total_tokens,
        "estimated_cost_usd": cost,
        "cost_per_1k_tokens": cost_per_1k_tokens or None,
        "words_per_token": words_per_token,
        "entities_per_chunk_estimate": entities_per_chunk,
        "max_calls": max_calls,
        "cost_cap": cost_cap,
        "exceeds_max_calls": bool(max_calls is not None and total_calls > max_calls),
        "exceeds_cost_cap": bool(cost_cap is not None and cost is not None and cost > cost_cap),
    }


def _chunk_id_for(artifact: Any) -> str:
    """Return the provenance ``chunk_id``, else the artifact id's basename without ``.md``."""
    return artifact.provenance.get("chunk_id") or artifact.id.split("/", 1)[-1].rsplit(".md", 1)[0]


def _select_source_chunks(
    sources: list[Any],
    *,
    chapter_filter: list[str] | None,
    chunk_filter: list[str] | None,
    from_chapter: int | None,
    to_chapter: int | None,
) -> list[Any]:
    """Return the source artifacts that pass every supplied filter, in input order.

    Filter rules:
      * ``chunk_filter``: exact chunk-id match (whitespace-stripped).
      * ``chapter_filter``: case-insensitive match on the chapter label, or on
        the chapter number rendered as a string.
      * ``from_chapter`` / ``to_chapter``: inclusive range on an integer
        chapter number. Chunks without one are excluded when a bound is set.

    Blank filter values are ignored.
    """
    chunk_set = {value.strip() for value in (chunk_filter or []) if value.strip()}
    label_set = {value.strip().lower() for value in (chapter_filter or []) if value.strip()}
    out: list[Any] = []
    for item in sources:
        if chunk_set and _chunk_id_for(item) not in chunk_set:
            continue
        chapter_number = item.provenance.get("chapter_number")
        chapter_label = (item.provenance.get("chapter_label") or "").strip().lower()
        if label_set:
            number_match = isinstance(chapter_number, int) and str(chapter_number) in label_set
            label_match = chapter_label in label_set if chapter_label else False
            if not (number_match or label_match):
                continue
        if from_chapter is not None or to_chapter is not None:
            if not isinstance(chapter_number, int):
                continue
            if from_chapter is not None and chapter_number < from_chapter:
                continue
            if to_chapter is not None and chapter_number > to_chapter:
                continue
        out.append(item)
    return out


def _template_for_workflow(workflow_id: str) -> str:
    """Return the template stem used by ``workflow_id``, or ``""`` if unknown."""
    return _TEMPLATE_BY_WORKFLOW.get(workflow_id, "")


def _profile_template_words(root: Path, profile: str) -> dict[str, int]:
    """Map each ``*.md`` template stem in the installed profile to its whitespace word count.

    Unreadable or non-UTF-8 templates are skipped. An empty dict is returned
    when the template directory does not exist.
    """
    template_dir = Path(root) / "profiles" / profile / "templates"
    counts: dict[str, int] = {}
    if not template_dir.is_dir():
        return counts
    for path in template_dir.glob("*.md"):
        try:
            text = path.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError):
            continue
        counts[path.stem] = len(text.split())
    return counts


def _source_word_count(root: Path, artifact: Any) -> int:
    """Return the whitespace word count of a source artifact's file (0 if unreadable)."""
    path = Path(root) / artifact.path
    try:
        return len(path.read_text(encoding="utf-8").split())
    except (OSError, UnicodeDecodeError):
        return 0


def _read_profile_name(root: Path) -> str:
    """Return the profile recorded in the generation state, or ``DEFAULT_PROFILE``."""
    state = _read_state(root)
    return str(state.get("profile") or DEFAULT_PROFILE)


def run_generation(
    root: str | Path,
    *,
    stage: str = "all",
    provider: str = "fixture",
    model: str = "",
    fixture_responses: str | Path | None = None,
    resume: bool = False,
    force: bool = False,
) -> GenerationRunResult:
    """Run the workflows for ``stage`` and update the generation state.

    The run is skipped (``status="skipped"``) when all of these hold:
    ``resume`` is set, ``force`` is not, the state says the run completed,
    and the infospace is not stale.

    Otherwise each workflow for the stage runs in order with a shared
    assisted-generation adapter. For the ``all`` and ``metrics`` stages,
    metrics are then recorded and a generation report is written. Only those
    stages mark the state ``completed``.

    Raises:
        InfospaceError: for an unsupported ``stage``, or when no adapter can
            be built for ``provider`` / ``fixture_responses``.
    """
    root_path = Path(root)
    stage_key = stage.strip().lower()
    state = _read_state(root_path)
    status = status_generation(root_path)
    workflow_ids = _workflow_ids_for_stage(stage_key)
    if resume and not force and state.get("completed") is True and not status["stale"]:
        return GenerationRunResult(
            root=str(root_path),
            status="skipped",
            stage=stage,
            skipped=True,
            stale=False,
            workflows=[],
            metrics=status.get("metrics", {}),
        )

    # Only build an adapter when some workflow will use it.
    adapter = (
        _adapter_for(provider, model=model, fixture_responses=fixture_responses)
        if workflow_ids
        else None
    )
    workflow_results: list[dict[str, Any]] = []
    for workflow_id in workflow_ids:
        result = run_workflow(root_path, workflow_id, assisted_adapter=adapter)
        workflow_results.append(result.to_dict())
        state = _mark_workflow_completed(state, result)

    metrics: dict[str, Any] = {}
    snapshot_id = ""
    records_metrics = stage_key in _METRICS_STAGES
    if records_metrics:
        check_result = _record_metrics(root_path)
        metrics = check_result.metrics
        snapshot_id = check_result.snapshot.snapshot_id
        _write_generation_report(root_path, metrics, snapshot_id)

    # One timestamp so updated_at and last_run.completed_at agree.
    now = _now()
    state.update(
        {
            "source_chunks": _source_state(root_path),
            "profile_digest": _profile_digest(root_path, str(state.get("profile") or DEFAULT_PROFILE)),
            "completed": records_metrics,
            "updated_at": now,
            "last_run": {
                "stage": stage,
                "provider": provider,
                "model": model,
                "workflow_count": len(workflow_results),
                "snapshot_id": snapshot_id,
                "completed_at": now,
            },
        }
    )
    _write_state(root_path, state)
    # NOTE(review): stale is reported as False without re-checking source digests after the run.
    return GenerationRunResult(
        root=str(root_path),
        status="completed",
        stage=stage,
        skipped=False,
        stale=False,
        workflows=workflow_results,
        metrics=metrics,
        history_snapshot_id=snapshot_id,
    )


def status_generation(root: str | Path) -> dict[str, Any]:
    """Summarize the infospace's generation progress.

    Includes artifact counts by kind, the evaluation count, current metrics
    and history snapshot info. It also reports staleness: a source file that
    is missing or no longer matches its recorded digest, or a profile digest
    that differs from the one in the state.
    """
    root_path = Path(root)
    infospace = load_infospace(root_path)
    state = _read_state(root_path)
    stale_sources = _stale_source_ids(infospace.root)
    profile = str(state.get("profile") or DEFAULT_PROFILE)
    stale_profile = bool(
        state.get("profile_digest")
        and state.get("profile_digest") != _profile_digest(infospace.root, profile)
    )
    evaluations = read_entity_evaluations(infospace.root / "output" / "evaluations")
    history = get_history(infospace.root)
    kind_counts = Counter(item.kind for item in infospace.artifacts)
    return {
        "root": str(infospace.root),
        "slug": infospace.config.slug,
        "profile": profile,
        "source_chunk_count": kind_counts["source"],
        "entity_count": kind_counts["entity"],
        "relation_count": kind_counts["relation"],
        "evaluation_count": len(evaluations),
        "generated_count": kind_counts["generated"],
        "metrics": read_metrics_file(infospace.root / "output" / "metrics" / "metrics.yaml"),
        "history_snapshot_count": len(history),
        "latest_snapshot_id": history[-1].snapshot_id if history else "",
        "stale": bool(stale_sources or stale_profile),
        "stale_sources": stale_sources,
        "stale_profile": stale_profile,
        "completed": bool(state.get("completed", False)),
        "stage_status": state.get("stage_status", {}),
    }


def _adapter_for(
    provider: str,
    *,
    model: str,
    fixture_responses: str | Path | None,
) -> AssistedGenerationAdapter:
    """Build the assisted-generation adapter; fixture responses take precedence over ``provider``.

    Raises:
        InfospaceError: when no fixture file is given and ``provider`` is not ``"openrouter"``.
    """
    if fixture_responses:
        return FixtureAssistedGenerationAdapter.from_file(Path(fixture_responses))
    if provider == "openrouter":
        return OpenRouterAssistedGenerationAdapter(model=model)
    raise InfospaceError(
        "missing_assisted_generation_adapter",
        "Assisted generation requires --fixture-responses or --provider openrouter",
        {"provider": provider},
    )


def _register_source_chunks(root: Path, chunks: list[SourceChunk]) -> None:
    """Write each chunk to ``artifacts/sources/<chunk_id>.md`` and register it as a ``source`` artifact."""
    for chunk in chunks:
        path = root / "artifacts" / "sources" / f"{chunk.chunk_id}.md"
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_text(chunk.markdown, encoding="utf-8")
        register_artifact(
            root,
            artifact_id=f"source/{chunk.chunk_id}.md",
            path=path,
            kind="source",
            title=chunk.title,
            provenance={
                "original_path": chunk.original_path,
                "source_type": chunk.source_type,
                "digest": chunk.digest,
                "chunk_id": chunk.chunk_id,
                "chunk_index": chunk.chunk_index,
                "chunk_count": chunk.chunk_count,
                "imported_at": chunk.imported_at,
                "extractor_version": chunk.extractor_version,
                "section_role": chunk.section_role,
                "spine_index": chunk.spine_index,
                "book_metadata": dict(chunk.book_metadata),
                "chapter_label": chunk.chapter_label,
                "chapter_number": chunk.chapter_number,
                "page_anchors": list(chunk.page_anchors),
            },
        )


def _install_profile(root: Path, profile: str) -> None:
    """Copy the bundled profile into ``profiles/<profile>`` and its templates into ``workflows/templates/<profile>``.

    Raises:
        InfospaceError: if the bundled profile directory does not exist.
    """
    source = Path(__file__).parent / "profiles" / profile
    if not source.is_dir():
        raise InfospaceError(
            "missing_generation_profile",
            f"Generation profile does not exist: {profile}",
            {"profile": profile, "path": str(source)},
        )
    profile_target = root / "profiles" / profile
    template_target = root / "workflows" / "templates" / profile
    shutil.copytree(source, profile_target, dirs_exist_ok=True)
    shutil.copytree(source / "templates", template_target, dirs_exist_ok=True)


def _write_workflows(root: Path, profile: str) -> None:
    """Point ``infospace.yaml`` schemas at the profile contracts and replace its workflow list."""
    config_path = root / "infospace.yaml"
    config = yaml.safe_load(config_path.read_text(encoding="utf-8")) or {}
    config["schemas"] = {
        **dict(config.get("schemas") or {}),
        "entity": f"profiles/{profile}/contracts/entity.contract.md",
        "relation": f"profiles/{profile}/contracts/relation.contract.md",
        "evaluation": f"profiles/{profile}/contracts/evaluation.contract.md",
    }
    config["workflows"] = _profile_workflows(profile)
    config_path.write_text(yaml.safe_dump(config, sort_keys=False), encoding="utf-8")


def _profile_workflows(profile: str) -> list[dict[str, Any]]:
    """Return the four generic-source workflow definitions wired to ``profile``'s templates.

    ``{{ input.* }}`` placeholders are literal text left for the workflow engine to render.
    """
    base = f"workflows/templates/{profile}"
    return [
        {
            "id": "generic-source-summary",
            "description": "Summarize normalized source chunks.",
            "inputs": {"source": {"kind": "source"}},
            "static_macros": {"profile": profile},
            "stages": [
                {
                    "id": "summarize-source",
                    "kind": "assisted",
                    "input": "source",
                    "template": f"{base}/summarize-source.md",
                    "provider_hint": "openrouter",
                    "output": {
                        "path": "artifacts/generated/{{ input.slug }}-summary.md",
                        "artifact_id": "generated/{{ input.slug }}-summary.md",
                        "kind": "generated",
                        "title": "{{ input.title }} Summary",
                    },
                }
            ],
        },
        {
            "id": "generic-source-entities",
            "description": "Extract reusable entity artifacts from source chunks.",
            "inputs": {"source": {"kind": "source"}},
            "static_macros": {"profile": profile},
            "stages": [
                {
                    "id": "extract-entities",
                    "kind": "assisted",
                    "input": "source",
                    "template": f"{base}/extract-entities.md",
                    "provider_hint": "openrouter",
                    "output": {
                        "path": "artifacts/generated/{{ input.slug }}-entities.md",
                        "artifact_id": "generated/{{ input.slug }}-entities.md",
                        "kind": "generated",
                        "title": "{{ input.title }} Entity Bundle",
                    },
                },
                {
                    "id": "split-entities",
                    "kind": "split_entities",
                    "input": "source",
                    "template": "",
                    "static_macros": {"bundle_stage": "extract-entities"},
                },
            ],
        },
        {
            "id": "generic-source-relations",
            "description": "Extract relation artifacts from source chunks.",
            "inputs": {"source": {"kind": "source"}},
            "static_macros": {"profile": profile},
            "stages": [
                {
                    "id": "extract-relations",
                    "kind": "assisted",
                    "input": "source",
                    "template": f"{base}/extract-relations.md",
                    "provider_hint": "openrouter",
                    "output": {
                        "path": "artifacts/relations/{{ input.slug }}-relations.md",
                        "artifact_id": "relation/{{ input.slug }}-relations.md",
                        "kind": "relation",
                        "title": "{{ input.title }} Relations",
                    },
                }
            ],
        },
        {
            "id": "generic-source-evaluations",
            "description": "Evaluate generated entities with the profile rubric.",
            "inputs": {"entity": {"kind": "entity"}},
            "static_macros": {"profile": profile},
            "stages": [
                {
                    "id": "evaluate-entity",
                    "kind": "assisted",
                    "input": "entity",
                    "template": f"{base}/evaluate-entity.md",
                    "provider_hint": "openrouter",
                    "output": {
                        "path": "output/evaluations/{{ input.slug }}.md",
                        "artifact_id": "generated/evaluation-{{ input.slug }}.md",
                        "kind": "generated",
                        "title": "{{ input.title }} Evaluation",
                    },
                }
            ],
        },
    ]


def _record_metrics(root: Path) -> Any:
    """Run collection checks over all artifacts and record them, with evaluations, as a history snapshot."""
    infospace = load_infospace(root)
    return record_check_results(
        infospace.root,
        run_collection_checks(infospace.artifacts),
        artifact_evaluations=read_entity_evaluations(infospace.root / "output" / "evaluations"),
        schema_name="generic-source",
        metadata={"generator": "generic-source"},
    )


def _write_generation_report(root: Path, metrics: dict[str, Any], snapshot_id: str) -> None:
    """Write ``reports/generation-summary.md`` with counts and metrics, and register it as an artifact."""
    status = status_generation(root)
    text = "\n".join(
        [
            "# Generation Report",
            "",
            f"Snapshot: {snapshot_id}",
            f"Sources: {status['source_chunk_count']}",
            f"Entities: {status['entity_count']}",
            f"Relations: {status['relation_count']}",
            f"Evaluations: {status['evaluation_count']}",
            "",
            "## Metrics",
            "",
            *[f"- {name}: {value}" for name, value in sorted(metrics.items())],
            "",
        ]
    )
    path = root / "reports" / "generation-summary.md"
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(text, encoding="utf-8")
    register_artifact(
        root,
        artifact_id="generated/generation-summary.md",
        path=path,
        kind="generated",
        title="Generation Summary",
        provenance={"workflow_id": "generic-source-generator", "snapshot_id": snapshot_id},
    )


def _workflow_ids_for_stage(stage: str) -> list[str]:
    """Return a fresh list of workflow ids for ``stage`` (case/whitespace-insensitive).

    ``intake`` and ``metrics`` are valid stages that run no workflows.

    Raises:
        InfospaceError: for an unknown stage, listing the valid stage names.
    """
    normalized = stage.strip().lower()
    if normalized in _STAGES_WITHOUT_WORKFLOWS:
        return []
    if normalized not in WORKFLOW_BY_STAGE:
        raise InfospaceError(
            "invalid_generation_stage",
            f"Unsupported generation stage: {stage}",
            {
                "stage": stage,
                "valid_stages": sorted([*WORKFLOW_BY_STAGE, "intake", "metrics"]),
            },
        )
    # Copy so callers cannot mutate the shared module-level table.
    return list(WORKFLOW_BY_STAGE[normalized])


def _source_state(root: Path) -> dict[str, Any]:
    """Snapshot path, digest, title, type and chunk id of every ``source`` artifact, keyed by artifact id."""
    infospace = load_infospace(root)
    return {
        item.id: {
            "path": item.path,
            "digest": item.provenance.get("digest", ""),
            "title": item.title,
            "source_type": item.provenance.get("source_type", ""),
            "chunk_id": item.provenance.get("chunk_id", ""),
        }
        for item in infospace.artifacts
        if item.kind == "source"
    }


def _stale_source_ids(root: Path) -> list[str]:
    """Return ids of source artifacts that are missing, unreadable, or whose text no longer matches their digest.

    A source with no recorded digest is only considered stale if its file is missing.
    """
    infospace = load_infospace(root)
    stale: list[str] = []
    for item in infospace.artifacts:
        if item.kind != "source":
            continue
        path = infospace.root / item.path
        if not path.is_file():
            stale.append(item.id)
            continue
        expected = str(item.provenance.get("digest") or "")
        if not expected:
            continue
        try:
            actual = _digest_text(path.read_text(encoding="utf-8"))
        except (OSError, UnicodeDecodeError):
            # An unreadable source cannot be verified; treat it like a missing one.
            stale.append(item.id)
            continue
        if actual != expected:
            stale.append(item.id)
    return stale


def _mark_workflow_completed(
    state: dict[str, Any],
    result: WorkflowRunResult,
) -> dict[str, Any]:
    """Return a copy of ``state`` with ``stage_status[result.workflow_id]`` set from ``result``."""
    stage_status = dict(state.get("stage_status") or {})
    stage_status[result.workflow_id] = {
        "status": result.status,
        "run_id": result.run_id,
        "output_artifact_ids": [output.artifact_id for output in result.outputs],
        "updated_at": _now(),
    }
    return {**state, "stage_status": stage_status}


def _profile_digest(root: Path, profile: str) -> str:
    """SHA-256 over the relative path and bytes of every file in the installed profile and its workflow templates.

    Files are visited in sorted order so the digest is deterministic.
    """
    files: list[Path] = []
    for base in (
        root / "profiles" / profile,
        root / "workflows" / "templates" / profile,
    ):
        if base.is_dir():
            files.extend(path for path in sorted(base.rglob("*")) if path.is_file())
    hasher = hashlib.sha256()
    for path in files:
        hasher.update(str(path.relative_to(root)).encode("utf-8"))
        hasher.update(path.read_bytes())
    return hasher.hexdigest()


def _read_state(root: Path) -> dict[str, Any]:
    """Load the generation state file; return ``{}`` if it is missing or not a mapping."""
    path = root / STATE_PATH
    if not path.is_file():
        return {}
    data = yaml.safe_load(path.read_text(encoding="utf-8"))
    return data if isinstance(data, dict) else {}


def _write_state(root: Path, state: dict[str, Any]) -> None:
    """Write ``state`` as YAML to the generation state file, creating parent directories."""
    path = root / STATE_PATH
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(yaml.safe_dump(state, sort_keys=False), encoding="utf-8")


def _digest_text(text: str) -> str:
    """Return the hex SHA-256 of ``text`` encoded as UTF-8."""
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


def _now() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()