from __future__ import annotations from functools import lru_cache from pathlib import Path from typing import Any import yaml from task_flow_engine import AssertionDef, AssertionResult, FlowDef, FlowEngine, FlowResult FLOW_DIR = Path(__file__).resolve().parents[1] / "flows" @lru_cache def load_flow(entity_type: str) -> FlowDef: path = FLOW_DIR / f"{entity_type}.yaml" data = yaml.safe_load(path.read_text(encoding="utf-8")) return FlowDef.from_dict(data) def evaluate_transition( entity_type: str, current_workstation: str, target_workstation: str, extra: dict[str, Any] | None = None, ) -> tuple[bool, list[AssertionResult], FlowResult]: flow = load_flow(entity_type) obj = { "status": current_workstation, "workstation": current_workstation, "previous_workstation": current_workstation, **(extra or {}), } engine = create_flow_engine() result = engine.evaluate(obj, flow) can_reach, failures = engine.can_reach(obj, flow, target_workstation) return can_reach, failures, result def create_flow_engine() -> FlowEngine: return FlowEngine( custom_ops={ "dependencies.any_incomplete": _dependencies_any_incomplete, } ) def _dependencies_any_incomplete( assertion: AssertionDef, obj: dict[str, Any], values: list[Any], ) -> bool: expected = assertion.value if isinstance(expected, list): return bool(values) and any(value not in expected for value in values) return bool(values) and any(value != expected for value in values) def assertion_result_to_dict(result: AssertionResult) -> dict[str, Any]: return { "id": result.id, "passed": result.passed, "target": result.target, "op": result.op, "expected": result.expected, "actual": result.actual, "description": result.description, "reason": result.reason, } def flow_result_to_dict(result: FlowResult) -> dict[str, Any]: return { "current_workstation": result.current_workstation, "exit_blocked": result.exit_blocked, "blocking_assertions": [ assertion_result_to_dict(item) for item in result.blocking_assertions ], "reachable": result.reachable, "unreachable": [ { "workstation": item.workstation, "blocking": assertion_result_to_dict(item.blocking), } for item in result.unreachable ], }