"""Evaluate assertion definitions against an object under test.

An ``AssertionEvaluator`` resolves each assertion's dotted ``target`` against
the object (see ``resolve_target``), applies the assertion's ``op`` (a
built-in op from ``task_flow_engine.builtins`` or a registered custom op), and
packages the outcome as an ``AssertionResult``.
"""

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass
from typing import Any

from task_flow_engine import builtins
from task_flow_engine.models import AssertionDef, AssertionResult

# A custom op receives the assertion, the whole object under test, and the
# values resolved from the assertion's target. It returns either a bare
# pass/fail value or a ``(passed, reason)`` pair.
CustomOp = Callable[[AssertionDef, dict[str, Any], list[Any]], bool | tuple[bool, str]]

# Ops implemented by same-named functions in ``task_flow_engine.builtins``,
# each called as ``fn(values, expected)``. The function is fetched from the
# module at call time rather than bound here.
_BUILTIN_OPS: tuple[str, ...] = ("all_eq", "any_eq", "none_eq", "exists", "count_gte")

# Sentinel returned by ``_get_child`` when a path segment cannot be resolved
# (distinct from a legitimately stored ``None``).
_MISSING = object()


@dataclass
class AssertionEvaluator:
    """Evaluate ``AssertionDef`` objects against a target object.

    Attributes:
        custom_ops: Callables keyed by assertion *id*; consulted only for
            assertions whose ``op`` is ``"custom"``.
        max_nodes: Forwarded to ``resolve_target`` as its ``max_nodes`` limit.
    """

    custom_ops: dict[str, CustomOp] | None = None
    max_nodes: int = 1_000

    def evaluate(self, assertion: AssertionDef, obj: dict[str, Any]) -> AssertionResult:
        """Evaluate one assertion against *obj* and return its result.

        The target is resolved first, then the op is applied to the resolved
        values. If the op supplies no reason, a generic one is generated.
        """
        values = resolve_target(obj, assertion.target, max_nodes=self.max_nodes)
        passed, reason = self._evaluate(assertion, obj, values)
        if not reason:
            reason = _default_reason(assertion, values, passed)
        return AssertionResult(
            id=assertion.id,
            passed=passed,
            target=assertion.target,
            op=assertion.op,
            expected=assertion.value,
            actual=values,
            description=assertion.description,
            reason=reason,
        )

    def _evaluate(
        self,
        assertion: AssertionDef,
        obj: dict[str, Any],
        values: list[Any],
    ) -> tuple[bool, str]:
        """Dispatch *assertion* to its op and return ``(passed, reason)``.

        An empty reason means "no op-specific reason". Unknown ops fail with
        an explanatory reason instead of raising.
        """
        op = assertion.op
        if op in _BUILTIN_OPS:
            return getattr(builtins, op)(values, assertion.value), ""
        if op == "custom":
            return self._evaluate_custom(assertion, obj, values)
        return False, f"Unknown assertion op '{assertion.op}'."

    def _evaluate_custom(
        self,
        assertion: AssertionDef,
        obj: dict[str, Any],
        values: list[Any],
    ) -> tuple[bool, str]:
        """Run the custom op registered under ``assertion.id``.

        Fails (without raising) when no op is registered for that id. A tuple
        return is unpacked as ``(passed, reason)``; any other return value is
        coerced with ``bool`` and paired with an empty reason. Exceptions raised
        by the op itself propagate to the caller.
        """
        if not self.custom_ops or assertion.id not in self.custom_ops:
            return False, f"No custom op registered for assertion '{assertion.id}'."
        result = self.custom_ops[assertion.id](assertion, obj, values)
        if isinstance(result, tuple):
            passed, reason = result
            return bool(passed), reason
        return bool(result), ""


def resolve_target(obj: Any, target: str, max_nodes: int = 1_000) -> list[Any]:
    """Resolve the dotted path *target* against *obj* and return all matches.

    Each ``.``-separated segment selects a dict key, a list/tuple index
    (decimal digits), or an attribute. A ``*`` segment fans out over a dict's
    values or a list/tuple/set's items. Unresolvable segments contribute no
    values. An empty *target* returns ``[obj]`` as-is; otherwise every
    resolved value is passed through ``_scalarize``.
    """
    if not target:
        return [obj]
    parts = target.split(".")
    seen: set[int] = set()
    values = _resolve(obj, parts, seen, max_nodes)
    return [_scalarize(value) for value in values]


def _resolve(current: Any, parts: list[str], seen: set[int], max_nodes: int) -> list[Any]:
    """Follow *parts* from *current* recursively; return the values reached.

    *seen* holds the ``id`` of each container/object already on the path from
    the root to *current*; reaching one of them again yields ``[]``.
    """
    # NOTE(review): ``seen`` is copied for every ``*`` branch, so it only holds
    # ids from the current root-to-node path and its size is bounded by the
    # path length. This guard therefore caps path length, not total nodes
    # visited; a wide ``*.*`` target is not limited by it. Termination does not
    # depend on it, because each call consumes one segment.
    if len(seen) > max_nodes:
        return []
    current_id = id(current)
    if isinstance(current, (dict, list, tuple, set)) or hasattr(current, "__dict__"):
        # NOTE(review): this also drops legitimate revisits on one path (e.g.
        # ``child.parent.name`` over back-references) — confirm that is intended.
        if current_id in seen:
            return []
        seen.add(current_id)
    if not parts:
        return [current]
    part, rest = parts[0], parts[1:]
    if part == "*":
        values: list[Any] = []
        for item in _iter_items(current):
            # Copy so sibling branches do not see each other's visited ids.
            values.extend(_resolve(item, rest, seen.copy(), max_nodes))
        return values
    next_value = _get_child(current, part)
    if next_value is _MISSING:
        return []
    return _resolve(next_value, rest, seen, max_nodes)


def _iter_items(current: Any) -> list[Any]:
    """Return the children a ``*`` segment fans out over (``[]`` for others)."""
    if isinstance(current, dict):
        return list(current.values())
    if isinstance(current, (list, tuple, set)):
        return list(current)
    return []


def _get_child(current: Any, part: str) -> Any:
    """Return the child of *current* selected by *part*, or ``_MISSING``."""
    if isinstance(current, dict):
        return current.get(part, _MISSING)
    # ``isdecimal`` rather than ``isdigit``: ``isdigit`` also accepts characters
    # such as "²" or "①", which ``int()`` rejects with ValueError.
    if isinstance(current, (list, tuple)) and part.isdecimal():
        index = int(part)
        return current[index] if index < len(current) else _MISSING
    # Everything else (including named-tuple fields) is read as an attribute.
    return getattr(current, part, _MISSING)


def _scalarize(value: Any) -> Any:
    """Unwrap *value* to ``value.value`` when it has such an attribute.

    Objects such as enum members are unwrapped this way; ``str``/``bytes``/
    ``bytearray`` values are always returned unchanged.
    """
    if hasattr(value, "value") and not isinstance(value, (str, bytes, bytearray)):
        return value.value
    return value


def _default_reason(assertion: AssertionDef, values: list[Any], passed: bool) -> str:
    """Build a human-readable reason for *assertion*'s outcome."""
    if passed:
        return f"Assertion '{assertion.id}' passed."
    if assertion.op == "exists":
        return f"Expected at least one non-empty value at {assertion.target}; got {values}."
    if assertion.op == "count_gte":
        return f"Expected at least {assertion.value} values at {assertion.target}; got {len(values)}."
    return (
        f"Expected {assertion.op} at {assertion.target} with {assertion.value!r}; "
        f"got {values!r}."
    )