# generated from coulomb/repo-seed
# 141 lines, 4.6 KiB, Python
from __future__ import annotations
|
|
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from task_flow_engine import builtins
|
|
from task_flow_engine.models import AssertionDef, AssertionResult
|
|
|
|
|
|
# Signature of a callable registered in ``AssertionEvaluator.custom_ops``:
# it receives the assertion definition, the object under evaluation, and the
# values resolved from ``assertion.target``, and returns either a bare
# pass/fail value or a ``(passed, reason)`` tuple.
CustomOp = Callable[
    [AssertionDef, dict[str, Any], list[Any]],
    bool | tuple[bool, str],
]
|
|
|
|
|
|
@dataclass
|
|
class AssertionEvaluator:
|
|
custom_ops: dict[str, CustomOp] | None = None
|
|
max_nodes: int = 1_000
|
|
|
|
def evaluate(self, assertion: AssertionDef, obj: dict[str, Any]) -> AssertionResult:
|
|
values = resolve_target(obj, assertion.target, max_nodes=self.max_nodes)
|
|
passed, reason = self._evaluate(assertion, obj, values)
|
|
if not reason:
|
|
reason = _default_reason(assertion, values, passed)
|
|
return AssertionResult(
|
|
id=assertion.id,
|
|
passed=passed,
|
|
target=assertion.target,
|
|
op=assertion.op,
|
|
expected=assertion.value,
|
|
actual=values,
|
|
description=assertion.description,
|
|
reason=reason,
|
|
)
|
|
|
|
def _evaluate(
|
|
self,
|
|
assertion: AssertionDef,
|
|
obj: dict[str, Any],
|
|
values: list[Any],
|
|
) -> tuple[bool, str]:
|
|
if assertion.op == "all_eq":
|
|
return builtins.all_eq(values, assertion.value), ""
|
|
if assertion.op == "any_eq":
|
|
return builtins.any_eq(values, assertion.value), ""
|
|
if assertion.op == "none_eq":
|
|
return builtins.none_eq(values, assertion.value), ""
|
|
if assertion.op == "exists":
|
|
return builtins.exists(values, assertion.value), ""
|
|
if assertion.op == "count_gte":
|
|
return builtins.count_gte(values, assertion.value), ""
|
|
if assertion.op == "custom":
|
|
return self._evaluate_custom(assertion, obj, values)
|
|
return False, f"Unknown assertion op '{assertion.op}'."
|
|
|
|
def _evaluate_custom(
|
|
self,
|
|
assertion: AssertionDef,
|
|
obj: dict[str, Any],
|
|
values: list[Any],
|
|
) -> tuple[bool, str]:
|
|
if not self.custom_ops or assertion.id not in self.custom_ops:
|
|
return False, f"No custom op registered for assertion '{assertion.id}'."
|
|
result = self.custom_ops[assertion.id](assertion, obj, values)
|
|
if isinstance(result, tuple):
|
|
passed, reason = result
|
|
return bool(passed), reason
|
|
return bool(result), ""
|
|
|
|
|
|
def resolve_target(obj: Any, target: str, max_nodes: int = 1_000) -> list[Any]:
    """Resolve the dotted path ``target`` against ``obj`` and return all matches.

    Segments are separated by ``.``; a ``*`` segment fans out over the values
    of a dict or the elements of a list/tuple/set. Every matched value is
    passed through ``_scalarize``. An empty ``target`` returns ``[obj]`` as-is
    (not scalarized).
    """
    if target:
        matches = _resolve(obj, target.split("."), set(), max_nodes)
        return list(map(_scalarize, matches))
    return [obj]
|
|
|
|
|
|
def _resolve(current: Any, parts: list[str], seen: set[int], max_nodes: int) -> list[Any]:
    """Collect every value reached by following ``parts`` from ``current``.

    A ``*`` segment fans out over ``_iter_items(current)``; any other segment
    steps to ``_get_child(current, segment)`` and contributes nothing when
    that child is missing. Containers (dict/list/tuple/set) and objects with a
    ``__dict__`` are recorded in ``seen`` by ``id`` (``seen`` is mutated in
    place). Revisiting a recorded object returns ``[]``, as does entering a
    call when more than ``max_nodes`` ids are already recorded.

    NOTE(review): ``seen`` is path-local (each ``*`` branch gets its own
    copy), and every recursive call consumes one segment, so recursion depth
    is already bounded by ``len(parts)``. In effect ``max_nodes`` limits path
    depth rather than total work, and an object legitimately revisited along
    one path (e.g. through a back-reference) is dropped. Confirm this is
    intended.
    """
    if len(seen) > max_nodes:
        return []

    trackable = isinstance(current, (dict, list, tuple, set)) or hasattr(current, "__dict__")
    if trackable:
        node_id = id(current)
        if node_id in seen:
            return []
        seen.add(node_id)

    if not parts:
        return [current]

    head, *tail = parts
    if head != "*":
        child = _get_child(current, head)
        return [] if child is _MISSING else _resolve(child, tail, seen, max_nodes)

    # Each wildcard branch gets its own copy of ``seen`` so siblings do not
    # see each other's visits.
    return [
        found
        for item in _iter_items(current)
        for found in _resolve(item, tail, set(seen), max_nodes)
    ]
|
|
|
|
|
|
def _iter_items(current: Any) -> list[Any]:
|
|
if isinstance(current, dict):
|
|
return list(current.values())
|
|
if isinstance(current, (list, tuple, set)):
|
|
return list(current)
|
|
return []
|
|
|
|
|
|
# Private sentinel meaning "this path segment could not be resolved".
# Callers compare against it by identity, so a stored ``None`` (or any other
# real value) is never mistaken for a missing child.
_MISSING: object = object()
|
|
|
|
|
|
def _get_child(current: Any, part: str) -> Any:
    """Return the child of ``current`` addressed by the path segment ``part``.

    - dict: the value stored under the key ``part``;
    - list/tuple with a non-negative decimal ``part``: the element at
      ``int(part)``;
    - anything else (including a list/tuple addressed by a non-numeric
      segment): the attribute named ``part``.

    Returns ``_MISSING`` when the key, index, or attribute does not exist.
    Exceptions other than ``AttributeError`` raised by attribute access
    (e.g. from a property) propagate.
    """
    if isinstance(current, dict):
        return current.get(part, _MISSING)
    # ``isdecimal`` rather than ``isdigit``: strings such as "²" are digits but
    # not decimal, so ``int()`` would raise ValueError on them instead of the
    # segment being treated as missing.
    if isinstance(current, (list, tuple)) and part.isdecimal():
        index = int(part)
        return current[index] if index < len(current) else _MISSING
    return getattr(current, part, _MISSING)
|
|
|
|
|
|
def _scalarize(value: Any) -> Any:
    """Unwrap ``value`` to its ``.value`` attribute when it has one.

    Presumably aimed at Enum-like wrappers (confirm against callers).
    ``str``/``bytes``/``bytearray`` instances are always returned unchanged,
    as is anything without a ``value`` attribute.
    """
    if not hasattr(value, "value") or isinstance(value, (str, bytes, bytearray)):
        return value
    return value.value
|
|
|
|
|
|
def _default_reason(assertion: AssertionDef, values: list[Any], passed: bool) -> str:
|
|
if passed:
|
|
return f"Assertion '{assertion.id}' passed."
|
|
if assertion.op == "exists":
|
|
return f"Expected at least one non-empty value at {assertion.target}; got {values}."
|
|
if assertion.op == "count_gte":
|
|
return f"Expected at least {assertion.value} values at {assertion.target}; got {len(values)}."
|
|
return (
|
|
f"Expected {assertion.op} at {assertion.target} with {assertion.value!r}; "
|
|
f"got {values!r}."
|
|
)
|