Files
infospace-bench/src/infospace_bench/workflow.py

722 lines
25 KiB
Python

from __future__ import annotations
import uuid
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Protocol
import yaml
from .errors import InfospaceError
from .generation import write_entity_bundle_artifacts
from .lifecycle import load_infospace, register_artifact
from .markdown_adapter import render_markdown_template
from .models import KnowledgeArtifact
@dataclass(frozen=True)
class WorkflowInputSpec:
    """Selector describing which infospace artifacts feed a workflow input.

    An empty ``kind`` or an empty ``artifact_ids`` list means that criterion
    is not applied when artifacts are matched against this spec.
    """

    kind: str = ""
    artifact_ids: list[str] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "WorkflowInputSpec":
        """Build a spec from a config mapping.

        ``artifact_ids`` may be omitted, ``null``, a list, or a single scalar.
        A scalar is treated as a one-element list instead of being iterated
        character by character; ``null`` yields an empty list.
        """
        raw_ids = data.get("artifact_ids") or []
        if not isinstance(raw_ids, (list, tuple)):
            raw_ids = [raw_ids]
        return cls(
            kind=str(data.get("kind") or ""),
            artifact_ids=[str(item) for item in raw_ids],
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a mapping, omitting unset criteria.

        The ``artifact_ids`` list is copied so callers cannot mutate the
        (frozen) spec through the returned mapping.
        """
        data: dict[str, Any] = {}
        if self.kind:
            data["kind"] = self.kind
        if self.artifact_ids:
            data["artifact_ids"] = list(self.artifact_ids)
        return data
@dataclass(frozen=True)
class WorkflowOutputSpec:
    """Declares where and as what a workflow stage writes its artifact.

    ``path``, ``artifact_id`` and ``title`` are inline templates that the
    workflow runner renders against the stage's template data.
    """

    path: str
    kind: str = "generated"
    artifact_id: str = ""
    title: str = ""

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "WorkflowOutputSpec":
        """Parse a config mapping; ``path`` is mandatory (KeyError if absent)."""
        path = str(data["path"])
        kind = data.get("kind") or "generated"
        artifact_id = data.get("artifact_id") or ""
        title = data.get("title") or ""
        return cls(
            path=path,
            kind=str(kind),
            artifact_id=str(artifact_id),
            title=str(title),
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a mapping, dropping fields that are empty strings."""
        serialised: dict[str, Any] = {}
        for name, value in asdict(self).items():
            if value == "" or value == []:
                continue
            serialised[name] = value
        return serialised
@dataclass(frozen=True)
class WorkflowStage:
    """One workflow step, applied to every input record named ``input``.

    ``kind`` selects how the runner executes the stage (``template``,
    ``assisted`` or ``split_entities`` are the kinds it handles).
    """

    id: str
    kind: str
    input: str
    template: str = ""
    output: WorkflowOutputSpec | None = None
    static_macros: dict[str, Any] = field(default_factory=dict)
    provider_hint: str | None = None
    optional: bool = False

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "WorkflowStage":
        """Parse a config mapping; ``id`` is mandatory, ``kind`` defaults to template."""
        stage_id = str(data["id"])
        stage_kind = str(data.get("kind") or "template")
        input_name = str(data.get("input") or "")
        template_path = str(data.get("template") or "")
        raw_output = data.get("output")
        parsed_output = (
            WorkflowOutputSpec.from_dict(raw_output)
            if isinstance(raw_output, dict)
            else None
        )
        macros = dict(data.get("static_macros") or {})
        hint = data.get("provider_hint")
        return cls(
            id=stage_id,
            kind=stage_kind,
            input=input_name,
            template=template_path,
            output=parsed_output,
            static_macros=macros,
            provider_hint=str(hint) if hint else None,
            optional=bool(data.get("optional", False)),
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a mapping, dropping empty strings, lists and dicts."""
        entries: list[tuple[str, Any]] = [
            ("id", self.id),
            ("kind", self.kind),
            ("input", self.input),
            ("template", self.template),
            ("static_macros", self.static_macros),
            ("optional", self.optional),
        ]
        if self.output is not None:
            entries.append(("output", self.output.to_dict()))
        if self.provider_hint:
            entries.append(("provider_hint", self.provider_hint))
        return {
            key: value
            for key, value in entries
            if value != "" and value != [] and value != {}
        }
@dataclass(frozen=True)
class WorkflowDefinition:
    """A workflow declared in the infospace config: inputs plus ordered stages."""

    id: str
    description: str = ""
    inputs: dict[str, WorkflowInputSpec] = field(default_factory=dict)
    stages: list[WorkflowStage] = field(default_factory=list)
    static_macros: dict[str, Any] = field(default_factory=dict)
    expected_evaluations: list[str] = field(default_factory=list)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "WorkflowDefinition":
        """Parse a workflow mapping from the infospace config.

        ``inputs``, ``stages``, ``static_macros`` and ``expected_evaluations``
        may each be omitted or ``null``. Non-mapping input specs and stages
        are skipped. A single scalar ``expected_evaluations`` value is treated
        as a one-element list.

        Raises:
            KeyError: if ``id`` is missing.
        """
        raw_inputs = data.get("inputs") or {}
        raw_stages = data.get("stages") or []
        raw_evaluations = data.get("expected_evaluations") or []
        if isinstance(raw_evaluations, str):
            # Avoid splitting a lone evaluation name into characters.
            raw_evaluations = [raw_evaluations]
        return cls(
            id=str(data["id"]),
            description=str(data.get("description") or ""),
            inputs={
                str(name): WorkflowInputSpec.from_dict(spec)
                for name, spec in raw_inputs.items()
                if isinstance(spec, dict)
            },
            stages=[
                WorkflowStage.from_dict(item)
                for item in raw_stages
                if isinstance(item, dict)
            ],
            static_macros=dict(data.get("static_macros") or {}),
            expected_evaluations=[str(item) for item in raw_evaluations],
        )

    def to_dict(self) -> dict[str, Any]:
        """Serialise to a mapping.

        Top-level containers are copied so callers (e.g. template data
        consumers) cannot mutate this frozen definition through the result.
        """
        return {
            "id": self.id,
            "description": self.description,
            "inputs": {
                name: spec.to_dict() for name, spec in self.inputs.items()
            },
            "stages": [stage.to_dict() for stage in self.stages],
            "static_macros": dict(self.static_macros),
            "expected_evaluations": list(self.expected_evaluations),
        }
@dataclass(frozen=True)
class WorkflowInputRecord:
    """One artifact resolved for a named workflow input, with its file text."""

    name: str  # input name from the workflow's ``inputs`` mapping
    artifact_id: str
    kind: str
    title: str  # artifact title, or the file stem when the title is empty
    path: str  # artifact path relative to the infospace root
    slug: str  # file stem of ``path``
    content: str  # full UTF-8 text of the artifact file

    def to_template_data(self) -> dict[str, Any]:
        """Return every field as a fresh dict for template rendering."""
        return asdict(self)

    def to_dict(self) -> dict[str, Any]:
        """Return every field, including ``content``, as a fresh dict."""
        # The previous re-assignment of ``content`` was a no-op: asdict()
        # already includes it.
        return asdict(self)
@dataclass(frozen=True)
class WorkflowOutputRecord:
    """Summary of one artifact produced (or only planned) by a workflow stage."""

    stage_id: str  # stage that produced the artifact
    artifact_id: str
    path: str  # artifact path as reported by the producing stage
    kind: str
    title: str
    input_artifact_id: str  # input artifact the stage was processing
    written: bool  # False for dry runs, where nothing is written

    def to_dict(self) -> dict[str, Any]:
        """Return all fields as a plain dictionary."""
        serialised = asdict(self)
        return serialised
@dataclass(frozen=True)
class AssistedGenerationRequest:
    """Prompt handed to an assisted-generation adapter for one stage/input pair."""

    stage_id: str
    workflow_id: str
    input_artifact_id: str
    prompt: str  # the stage template rendered against the input's data
    provider_hint: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Serialise every field, keeping ``None`` and empty values."""
        return dict(
            stage_id=self.stage_id,
            workflow_id=self.workflow_id,
            input_artifact_id=self.input_artifact_id,
            prompt=self.prompt,
            provider_hint=self.provider_hint,
            metadata=self.metadata,
        )
@dataclass(frozen=True)
class AssistedGenerationResult:
    """Markdown produced by an assisted-generation adapter."""

    # Written as the stage's output file when the workflow is not dry-running.
    markdown: str
    # Added to the artifact's provenance when non-empty.
    provider: str = ""
    # Exposed to later stages' templates alongside the generated content.
    metadata: dict[str, Any] = field(default_factory=dict)
class AssistedGenerationAdapter(Protocol):
    """Structural interface for anything that can serve assisted workflow stages."""

    def generate(
        self,
        request: AssistedGenerationRequest,
    ) -> AssistedGenerationResult:
        """Return generated Markdown for ``request``."""
class FixtureAssistedGenerationAdapter:
    """Assisted-generation adapter that replays canned responses.

    Responses are keyed by ``(stage_id, input_artifact_id)``. An
    ``input_artifact_id`` of ``"*"`` is a per-stage wildcard, used only when
    no exact entry exists.
    """

    def __init__(
        self,
        responses: dict[tuple[str, str], AssistedGenerationResult],
    ) -> None:
        self.responses = responses

    @classmethod
    def from_file(cls, path: str | Path) -> "FixtureAssistedGenerationAdapter":
        """Load responses from a YAML fixture file.

        The file must hold a mapping with an optional ``responses`` list.
        Each mapping item needs ``stage_id`` and may set ``input_artifact_id``
        (default ``"*"``), ``markdown``, ``markdown_path`` (relative to the
        fixture file; overrides ``markdown``), ``provider`` (default
        ``"fixture"``) and ``metadata``. Non-mapping items are skipped.

        Raises:
            InfospaceError: if the file does not contain a mapping.
            KeyError: if a response item lacks ``stage_id``.
        """
        source = Path(path)
        data = yaml.safe_load(source.read_text(encoding="utf-8")) or {}
        if not isinstance(data, dict):
            raise InfospaceError(
                "invalid_assisted_fixture",
                f"Expected mapping in assisted fixture file: {source}",
                {"path": str(source)},
            )
        responses: dict[tuple[str, str], AssistedGenerationResult] = {}
        # ``responses: null`` means "no responses" rather than a TypeError.
        for item in data.get("responses") or []:
            if not isinstance(item, dict):
                continue
            stage_id = str(item["stage_id"])
            input_artifact_id = str(item.get("input_artifact_id") or "*")
            markdown = str(item.get("markdown") or "")
            markdown_path = item.get("markdown_path")
            if markdown_path:
                markdown = (source.parent / str(markdown_path)).read_text(
                    encoding="utf-8"
                )
            responses[(stage_id, input_artifact_id)] = AssistedGenerationResult(
                markdown=markdown,
                provider=str(item.get("provider") or "fixture"),
                metadata=dict(item.get("metadata") or {}),
            )
        return cls(responses)

    def generate(
        self,
        request: AssistedGenerationRequest,
    ) -> AssistedGenerationResult:
        """Return the canned response for ``request``.

        Raises:
            InfospaceError: if neither an exact nor a wildcard response exists.
        """
        result = self.responses.get((request.stage_id, request.input_artifact_id))
        if result is None:
            # Only fall back to the wildcard when no exact entry exists; the
            # previous ``or`` also fell back for falsy entries.
            result = self.responses.get((request.stage_id, "*"))
        if result is None:
            raise InfospaceError(
                "missing_assisted_fixture_response",
                "No fixture response for assisted workflow request",
                {
                    "stage_id": request.stage_id,
                    "input_artifact_id": request.input_artifact_id,
                },
            )
        return result
@dataclass(frozen=True)
class WorkflowStageRecord:
    """Outcome of one stage applied to one input.

    ``status`` is one of the values the runner emits: ``"planned"``,
    ``"completed"``, ``"requires_adapter"`` or ``"waiting_for_assisted_output"``.
    """

    stage_id: str
    kind: str
    status: str
    input_artifact_id: str
    # Comma-joined ids when a stage yields several artifacts.
    output_artifact_id: str = ""
    message: str = ""

    def to_dict(self) -> dict[str, Any]:
        """Return the fields as a dict, omitting any equal to ``""``."""
        fields_out = asdict(self)
        blank_keys = [key for key, value in fields_out.items() if value == ""]
        for key in blank_keys:
            del fields_out[key]
        return fields_out
@dataclass(frozen=True)
class WorkflowRunResult:
    """Everything a workflow plan or run produced, as returned to callers."""

    run_id: str
    workflow_id: str
    status: str
    dry_run: bool
    inputs: list[WorkflowInputRecord] = field(default_factory=list)
    stages: list[WorkflowStageRecord] = field(default_factory=list)
    outputs: list[WorkflowOutputRecord] = field(default_factory=list)
    assisted_requests: list[AssistedGenerationRequest] = field(default_factory=list)
    # Set only after a non-dry run has written its run record.
    run_record_path: str = ""

    def to_dict(self) -> dict[str, Any]:
        """Serialise the result, dropping empty strings and empty lists."""
        payload: dict[str, Any] = {
            "run_id": self.run_id,
            "workflow_id": self.workflow_id,
            "status": self.status,
            "dry_run": self.dry_run,
        }
        payload["inputs"] = [record.to_dict() for record in self.inputs]
        payload["stages"] = [record.to_dict() for record in self.stages]
        payload["outputs"] = [record.to_dict() for record in self.outputs]
        payload["assisted_requests"] = [
            request.to_dict() for request in self.assisted_requests
        ]
        payload["run_record_path"] = self.run_record_path
        return {
            key: value
            for key, value in payload.items()
            if not (value == "" or value == [])
        }
def load_workflows(root: str | Path) -> list[WorkflowDefinition]:
    """Return every workflow declared in the infospace config at ``root``.

    Config entries that are not mappings are ignored.
    """
    declared = load_infospace(root).config.workflows
    workflows: list[WorkflowDefinition] = []
    for entry in declared:
        if isinstance(entry, dict):
            workflows.append(WorkflowDefinition.from_dict(entry))
    return workflows
def get_workflow(root: str | Path, workflow_id: str) -> WorkflowDefinition:
    """Return the first declared workflow whose id equals ``workflow_id``.

    Raises:
        InfospaceError: ``missing_workflow`` if no such workflow is declared.
    """
    match = next(
        (candidate for candidate in load_workflows(root) if candidate.id == workflow_id),
        None,
    )
    if match is not None:
        return match
    raise InfospaceError(
        "missing_workflow",
        f"Workflow is not declared: {workflow_id}",
        {"workflow_id": workflow_id},
    )
def plan_workflow(root: str | Path, workflow_id: str) -> WorkflowRunResult:
    """Dry-run ``workflow_id``: report planned stages/outputs without writing files."""
    return _execute_workflow(
        root,
        workflow_id,
        dry_run=True,
        assisted_adapter=None,
    )
def run_workflow(
    root: str | Path,
    workflow_id: str,
    *,
    assisted_adapter: AssistedGenerationAdapter | None = None,
) -> WorkflowRunResult:
    """Execute ``workflow_id`` for real, writing outputs and a run record.

    ``assisted_adapter`` is required if the workflow contains assisted stages;
    without it such a stage raises ``assisted_stage_requires_adapter``.
    """
    return _execute_workflow(
        root,
        workflow_id,
        assisted_adapter=assisted_adapter,
        dry_run=False,
    )
def _execute_workflow(
    root: str | Path,
    workflow_id: str,
    *,
    dry_run: bool,
    assisted_adapter: AssistedGenerationAdapter | None = None,
) -> WorkflowRunResult:
    """Plan (``dry_run=True``) or run a declared workflow.

    Stages run in declaration order. For each stage, every collected input
    record whose ``name`` equals ``stage.input`` is processed in turn, and
    ``stage.kind`` selects the behaviour:

    * ``template``: render ``stage.template`` and resolve/write the output.
    * ``assisted``: render ``stage.template`` into a prompt and record an
      ``AssistedGenerationRequest``. When not dry-running, ask
      ``assisted_adapter`` for the Markdown to write.
    * ``split_entities``: split the content produced by the stage named in
      ``static_macros.bundle_stage`` into entity artifacts.

    Templates see earlier stages' outputs under ``stages``. Outputs are
    tracked per input artifact, so while processing input X a stage sees the
    results earlier stages produced for X. For a stage id with no result for
    X (for example a stage over a different input), the most recent result of
    that stage is used instead.

    When not dry-running, a run record is written and its path is set on the
    returned result.

    Raises:
        InfospaceError: if the workflow is not declared, a stage has no
            matching inputs, an assisted stage runs without an adapter, a
            split stage lacks its bundle, or the stage kind is unsupported
            (helpers may raise further InfospaceErrors).
    """
    infospace = load_infospace(root)
    workflow = get_workflow(infospace.root, workflow_id)
    run_id = uuid.uuid4().hex[:12]
    inputs = _collect_inputs(infospace.root, infospace.artifacts, workflow)
    stages: list[WorkflowStageRecord] = []
    outputs: list[WorkflowOutputRecord] = []
    assisted_requests: list[AssistedGenerationRequest] = []
    # Most recent output of each stage id, whichever input produced it.
    latest_outputs: dict[str, dict[str, Any]] = {}
    # Outputs keyed by input artifact id, then stage id. Keying only by stage
    # id made stage N for input X see the output stage N-1 produced for the
    # last-processed input instead of X.
    outputs_by_input: dict[str, dict[str, dict[str, Any]]] = {}
    for stage in workflow.stages:
        selected_inputs = [item for item in inputs if item.name == stage.input]
        if not selected_inputs:
            raise InfospaceError(
                "workflow_stage_has_no_inputs",
                f"Workflow stage has no matching inputs: {stage.id}",
                {"workflow_id": workflow.id, "stage_id": stage.id},
            )
        for input_record in selected_inputs:
            own_outputs = outputs_by_input.setdefault(input_record.artifact_id, {})
            # This input's own results take precedence over other inputs'.
            stage_outputs = {**latest_outputs, **own_outputs}
            data = _template_data(workflow, stage, input_record, stage_outputs)
            produced: dict[str, Any]
            if stage.kind == "template":
                template_text = _read_template(infospace.root, stage.template)
                rendered = render_markdown_template(template_text, data)
                output = _resolve_output(
                    workflow,
                    stage,
                    input_record,
                    rendered.markdown,
                    data,
                    infospace.root,
                    dry_run=dry_run,
                )
                outputs.append(output)
                produced = {
                    "content": rendered.markdown,
                    "artifact_id": output.artifact_id,
                    "path": output.path,
                    "provider": "",
                }
                stages.append(
                    WorkflowStageRecord(
                        stage_id=stage.id,
                        kind=stage.kind,
                        status="planned" if dry_run else "completed",
                        input_artifact_id=input_record.artifact_id,
                        output_artifact_id=output.artifact_id,
                    )
                )
            elif stage.kind == "assisted":
                template_text = _read_template(infospace.root, stage.template)
                rendered = render_markdown_template(template_text, data)
                request = AssistedGenerationRequest(
                    stage_id=stage.id,
                    workflow_id=workflow.id,
                    input_artifact_id=input_record.artifact_id,
                    prompt=rendered.markdown,
                    provider_hint=stage.provider_hint,
                    metadata={"output": stage.output.to_dict() if stage.output else {}},
                )
                assisted_requests.append(request)
                if dry_run:
                    # Planning never calls the adapter, so nothing is produced.
                    stages.append(
                        WorkflowStageRecord(
                            stage_id=stage.id,
                            kind=stage.kind,
                            status="requires_adapter",
                            input_artifact_id=input_record.artifact_id,
                        )
                    )
                    continue
                if assisted_adapter is None:
                    raise InfospaceError(
                        "assisted_stage_requires_adapter",
                        "Assisted workflow stages require an explicit adapter",
                        {
                            "workflow_id": workflow.id,
                            "stage_id": stage.id,
                            "input_artifact_id": input_record.artifact_id,
                        },
                    )
                generated = assisted_adapter.generate(request)
                output = _resolve_output(
                    workflow,
                    stage,
                    input_record,
                    generated.markdown,
                    data,
                    infospace.root,
                    dry_run=False,
                    provider=generated.provider,
                )
                outputs.append(output)
                produced = {
                    "content": generated.markdown,
                    "artifact_id": output.artifact_id,
                    "path": output.path,
                    "provider": generated.provider,
                    "metadata": generated.metadata,
                }
                stages.append(
                    WorkflowStageRecord(
                        stage_id=stage.id,
                        kind=stage.kind,
                        status="completed",
                        input_artifact_id=input_record.artifact_id,
                        output_artifact_id=output.artifact_id,
                    )
                )
            elif stage.kind == "split_entities":
                bundle_stage = str(stage.static_macros.get("bundle_stage") or "")
                if not bundle_stage:
                    raise InfospaceError(
                        "missing_split_bundle_stage",
                        "split_entities stage requires static_macros.bundle_stage",
                        {"workflow_id": workflow.id, "stage_id": stage.id},
                    )
                bundle_output = stage_outputs.get(bundle_stage)
                if bundle_output is None:
                    if dry_run:
                        # The bundle typically comes from an assisted stage,
                        # which produces nothing while planning.
                        stages.append(
                            WorkflowStageRecord(
                                stage_id=stage.id,
                                kind=stage.kind,
                                status="waiting_for_assisted_output",
                                input_artifact_id=input_record.artifact_id,
                            )
                        )
                        continue
                    raise InfospaceError(
                        "missing_split_bundle_output",
                        "split_entities stage could not find the source bundle output",
                        {
                            "workflow_id": workflow.id,
                            "stage_id": stage.id,
                            "bundle_stage": bundle_stage,
                        },
                    )
                items = write_entity_bundle_artifacts(
                    infospace.root,
                    str(bundle_output.get("content") or ""),
                    workflow_id=workflow.id,
                    stage_id=stage.id,
                    input_artifact_id=input_record.artifact_id,
                    source_bundle_artifact_id=str(
                        bundle_output.get("artifact_id") or ""
                    ),
                    provider=str(bundle_output.get("provider") or ""),
                    dry_run=dry_run,
                )
                for item in items:
                    outputs.append(
                        WorkflowOutputRecord(
                            stage_id=stage.id,
                            artifact_id=item.artifact_id,
                            path=item.path,
                            kind="entity",
                            title=item.title,
                            input_artifact_id=input_record.artifact_id,
                            written=not dry_run,
                        )
                    )
                joined_ids = ",".join(item.artifact_id for item in items)
                produced = {
                    "content": "\n".join(item.markdown for item in items),
                    "artifact_id": joined_ids,
                    "path": ",".join(item.path for item in items),
                    "provider": str(bundle_output.get("provider") or ""),
                }
                stages.append(
                    WorkflowStageRecord(
                        stage_id=stage.id,
                        kind=stage.kind,
                        status="planned" if dry_run else "completed",
                        input_artifact_id=input_record.artifact_id,
                        output_artifact_id=joined_ids,
                        message=f"split {len(items)} entities",
                    )
                )
            else:
                raise InfospaceError(
                    "unsupported_workflow_stage",
                    f"Unsupported workflow stage kind: {stage.kind}",
                    {
                        "workflow_id": workflow.id,
                        "stage_id": stage.id,
                        "kind": stage.kind,
                    },
                )
            own_outputs[stage.id] = produced
            latest_outputs[stage.id] = produced
    status = "planned" if dry_run else "completed"
    result = WorkflowRunResult(
        run_id=run_id,
        workflow_id=workflow.id,
        status=status,
        dry_run=dry_run,
        inputs=inputs,
        stages=stages,
        outputs=outputs,
        assisted_requests=assisted_requests,
    )
    if dry_run:
        return result
    # The record is written from the result without a path; the returned copy
    # carries the path of the record that was just written.
    run_record_path = _write_run_record(infospace.root, result)
    return WorkflowRunResult(
        run_id=run_id,
        workflow_id=workflow.id,
        status=status,
        dry_run=dry_run,
        inputs=inputs,
        stages=stages,
        outputs=outputs,
        assisted_requests=assisted_requests,
        run_record_path=run_record_path,
    )
def _collect_inputs(
    root: Path,
    artifacts: list[KnowledgeArtifact],
    workflow: WorkflowDefinition,
) -> list[WorkflowInputRecord]:
    """Resolve each declared workflow input into concrete input records.

    Records are grouped by input in declaration order. Within an input,
    matching artifacts keep their order in ``artifacts``. Each matching
    artifact's file under ``root`` is read as UTF-8 text.
    """
    records: list[WorkflowInputRecord] = []
    for input_name, spec in workflow.inputs.items():
        for artifact in artifacts:
            if not _matches_input_spec(artifact, spec):
                continue
            stem = Path(artifact.path).stem
            records.append(
                WorkflowInputRecord(
                    name=input_name,
                    artifact_id=artifact.id,
                    kind=artifact.kind,
                    title=artifact.title or stem,
                    path=artifact.path,
                    slug=stem,
                    content=(root / artifact.path).read_text(encoding="utf-8"),
                )
            )
    return records
def _matches_input_spec(
artifact: KnowledgeArtifact,
spec: WorkflowInputSpec,
) -> bool:
if spec.artifact_ids and artifact.id not in spec.artifact_ids:
return False
if spec.kind and artifact.kind != spec.kind:
return False
return True
def _template_data(
workflow: WorkflowDefinition,
stage: WorkflowStage,
input_record: WorkflowInputRecord,
stage_outputs: dict[str, dict[str, Any]],
) -> dict[str, Any]:
return {
"workflow": workflow.to_dict(),
"stage": stage.to_dict(),
"input": input_record.to_template_data(),
"macros": {**workflow.static_macros, **stage.static_macros},
"stages": stage_outputs,
}
def _read_template(root: Path, relative_path: str) -> str:
path = root / relative_path
if not path.is_file():
raise InfospaceError(
"missing_workflow_template",
f"Workflow template does not exist: {relative_path}",
{"template": relative_path},
)
return path.read_text(encoding="utf-8")
def _resolve_output(
    workflow: WorkflowDefinition,
    stage: WorkflowStage,
    input_record: WorkflowInputRecord,
    markdown: str,
    data: dict[str, Any],
    root: Path,
    *,
    dry_run: bool,
    provider: str = "",
) -> WorkflowOutputRecord:
    """Resolve a stage's declared output and, unless dry-running, persist it.

    The output spec's ``path``, ``artifact_id`` and ``title`` are rendered as
    inline templates against ``data``. An empty rendered ``artifact_id``
    defaults to ``"<kind>/<file name>"``. Outside a dry run, ``markdown`` is
    written to the target file. For every kind except ``"evaluation"``, the
    artifact is also registered with workflow provenance and a
    ``generated_from`` relationship to the input artifact.

    Args:
        provider: Added to the provenance when non-empty.

    Returns:
        A record of the output; ``written`` is False for dry runs.

    Raises:
        InfospaceError: if the stage declares no output, the rendered path is
            blank, or the path escapes the infospace root.
    """
    output_spec = stage.output
    if output_spec is None:
        raise InfospaceError(
            "missing_workflow_output",
            f"Workflow stage has no output declaration: {stage.id}",
            {"workflow_id": workflow.id, "stage_id": stage.id},
        )
    output_path = _render_inline(output_spec.path, data)
    if not output_path.strip():
        # A blank path resolves to the infospace root itself (or a
        # whitespace-named entry). It passes the escape check but is not a
        # usable file target, and a dry run would plan an output with no path.
        raise InfospaceError(
            "empty_workflow_output_path",
            f"Workflow stage output path rendered empty: {stage.id}",
            {
                "workflow_id": workflow.id,
                "stage_id": stage.id,
                "path_template": output_spec.path,
            },
        )
    artifact_id = _render_inline(output_spec.artifact_id, data)
    if not artifact_id:
        artifact_id = f"{output_spec.kind}/{Path(output_path).name}"
    title = _render_inline(output_spec.title, data)
    target = _safe_target(root, output_path)
    if not dry_run:
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(markdown, encoding="utf-8")
        # Evaluation outputs are written to disk but not registered.
        if output_spec.kind != "evaluation":
            register_artifact(
                root,
                artifact_id=artifact_id,
                path=output_path,
                kind=output_spec.kind,
                title=title,
                provenance={
                    "workflow_id": workflow.id,
                    "stage_id": stage.id,
                    "input_artifact_id": input_record.artifact_id,
                    **({"provider": provider} if provider else {}),
                },
                relationships=[
                    {
                        "type": "generated_from",
                        "target": input_record.artifact_id,
                    }
                ],
            )
    return WorkflowOutputRecord(
        stage_id=stage.id,
        artifact_id=artifact_id,
        path=output_path,
        kind=output_spec.kind,
        title=title,
        input_artifact_id=input_record.artifact_id,
        written=not dry_run,
    )
def _render_inline(template_text: str, data: dict[str, Any]) -> str:
if not template_text:
return ""
return render_markdown_template(template_text, data).markdown
def _safe_target(root: Path, relative_path: str) -> Path:
target = (root / relative_path).resolve()
root_resolved = root.resolve()
try:
target.relative_to(root_resolved)
except ValueError as exc:
raise InfospaceError(
"workflow_output_escapes_infospace",
f"Workflow output path escapes infospace: {relative_path}",
{"root": str(root), "path": relative_path},
) from exc
return target
def _write_run_record(root: Path, result: WorkflowRunResult) -> str:
    """Persist ``result`` as YAML at ``output/workflows/runs/<run_id>.yaml``.

    A UTC ISO-8601 ``recorded_at`` timestamp is appended to the serialised
    result. Parent directories are created as needed.

    Returns:
        The written file's path, as a string.
    """
    runs_dir = root / "output" / "workflows" / "runs"
    record_file = runs_dir / f"{result.run_id}.yaml"
    record = result.to_dict()
    record["recorded_at"] = datetime.now(timezone.utc).isoformat()
    runs_dir.mkdir(parents=True, exist_ok=True)
    record_file.write_text(
        yaml.safe_dump(record, sort_keys=False),
        encoding="utf-8",
    )
    return str(record_file)