feat(prompts): add batch LLM evaluation orchestrator (S1.6)
BatchEvaluator runs evaluation prompts across item batches with incremental evaluation (skipping items whose content digest is unchanged), per-item error isolation, progress callbacks, and aggregate token usage tracking.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
168
markitect/prompts/execution/batch.py
Normal file
168
markitect/prompts/execution/batch.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""
|
||||
Batch LLM evaluation orchestrator.
|
||||
|
||||
Runs an evaluation prompt against a batch of items (entities, pairs,
|
||||
etc.), collecting structured results. Handles:
|
||||
|
||||
- Incremental evaluation (skip items whose content hasn't changed)
|
||||
- Progress reporting via callback
|
||||
- Graceful error handling per item (one failure doesn't stop the batch)
|
||||
- Aggregate token usage tracking
|
||||
|
||||
This is the mechanism by which infospace tooling delegates LLM work
|
||||
to the platform. The adapter's own retry logic handles transient
|
||||
API errors (rate limits, 5xx).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from markitect.prompts.execution.llm_adapter import LLMAdapter
|
||||
from markitect.prompts.execution.models import LLMResponse, RunConfig
|
||||
|
||||
|
||||
@dataclass
class BatchItem:
    """One unit of work for a batch evaluation run.

    Attributes:
        key: Unique identifier for the item (e.g. an entity slug).
        prompt: Fully compiled prompt text that is sent to the LLM.
        content_digest: Hash of the source content. Drives incremental
            evaluation: an item whose digest is unchanged can be skipped.
            Defaults to the empty string.
        metadata: Free-form data carried through to the result untouched.
    """

    # Required identification and payload.
    key: str
    prompt: str

    # Optional change-detection digest.
    content_digest: str = ""

    # default_factory gives every instance its own dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class BatchResult:
    """Outcome of evaluating one batch item.

    Attributes:
        key: Same value as the originating :attr:`BatchItem.key`.
        status: ``"success"``, ``"error"`` or ``"skipped"``.
        response: LLM response for a successful item; ``None`` when the
            item was skipped or failed.
        error: Error message for a failed item; ``None`` when the item
            succeeded or was skipped.
        metadata: Metadata passed through from the input item.
    """

    # Always present.
    key: str
    status: str

    # Populated depending on ``status``.
    response: Optional[LLMResponse] = None
    error: Optional[str] = None

    # default_factory gives every instance its own dict.
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class BatchSummary:
    """Roll-up of a single batch evaluation run.

    Holds the per-status counters, the individual results, and the
    accumulated prompt/completion token counts.
    """

    # Per-status counters.
    total: int = 0
    succeeded: int = 0
    failed: int = 0
    skipped: int = 0

    # One entry per processed item.
    results: List[BatchResult] = field(default_factory=list)

    # Accumulated token usage.
    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0

    @property
    def total_tokens(self) -> int:
        """Combined prompt and completion token count."""
        prompt_side = self.total_prompt_tokens
        completion_side = self.total_completion_tokens
        return prompt_side + completion_side

    def success_rate(self) -> float:
        """Return the share of attempted (non-skipped) items that succeeded.

        When nothing was attempted the rate is reported as ``1.0``.
        """
        attempted = self.total - self.skipped
        return self.succeeded / attempted if attempted else 1.0
|
||||
|
||||
|
||||
class BatchEvaluator:
    """Orchestrates LLM evaluation across a batch of items.

    Items are processed one at a time, in input order. Each item is
    either skipped (digest unchanged), evaluated successfully, or
    recorded as an error; an exception raised by the adapter for one
    item never aborts the rest of the batch.

    Args:
        adapter: The LLM adapter to use for evaluation.
        config: Run configuration (model, temperature, etc.) passed to
            the adapter on every call. A default ``RunConfig()`` is
            used when omitted.
        progress_callback: Optional ``fn(completed, total, result)``
            called after each item is processed, whatever its status.
        previous_digests: Optional ``{key: digest}`` mapping from a
            previous run. Items whose digest matches are skipped.
    """

    def __init__(
        self,
        adapter: LLMAdapter,
        config: Optional[RunConfig] = None,
        progress_callback: Optional[Callable[[int, int, BatchResult], None]] = None,
        previous_digests: Optional[Dict[str, str]] = None,
    ):
        self._adapter = adapter
        self._config = config or RunConfig()
        self._progress_callback = progress_callback
        self._previous_digests = previous_digests or {}

    def evaluate(self, items: List[BatchItem]) -> BatchSummary:
        """Run evaluation for all items and return aggregate results.

        Items whose :attr:`~BatchItem.content_digest` matches an entry
        in *previous_digests* are skipped. All other items are sent to
        the LLM adapter. Errors on individual items are captured
        without aborting the batch.

        Args:
            items: The batch to process, in order.

        Returns:
            A :class:`BatchSummary` with one result per input item, the
            per-status counters, and the token totals of the successful
            responses.
        """
        total = len(items)
        summary = BatchSummary(total=total)

        for completed, item in enumerate(items, start=1):
            result = self._evaluate_one(item)
            summary.results.append(result)

            if result.status == "success":
                summary.succeeded += 1
                # Token accounting runs outside the per-item try/except, so a
                # response without usage data (or with a None count) must not
                # be allowed to raise here and abort the whole batch.
                # NOTE(review): assumes ``LLMResponse.usage`` is dict-like with
                # "prompt_tokens" / "completion_tokens" keys -- confirm.
                usage = (result.response.usage if result.response else None) or {}
                summary.total_prompt_tokens += usage.get("prompt_tokens") or 0
                summary.total_completion_tokens += usage.get("completion_tokens") or 0
            elif result.status == "skipped":
                summary.skipped += 1
            else:
                summary.failed += 1

            if self._progress_callback is not None:
                self._progress_callback(completed, total, result)

        return summary

    def _evaluate_one(self, item: BatchItem) -> BatchResult:
        """Evaluate a single item, handling skip logic and errors.

        Returns a ``"skipped"`` result when the item carries a non-empty
        digest equal to the one recorded for its key, a ``"success"``
        result wrapping the adapter response, or an ``"error"`` result
        carrying the exception message.
        """
        # Incremental: skip if digest unchanged. An empty digest never
        # matches, so items without a digest are always evaluated.
        digest = item.content_digest
        if digest and self._previous_digests.get(item.key) == digest:
            return BatchResult(
                key=item.key,
                status="skipped",
                metadata=item.metadata,
            )

        try:
            response = self._adapter.execute_prompt(item.prompt, self._config)
        except Exception as exc:
            # Deliberately broad: this is the per-item isolation boundary.
            # str(exc) is empty for exceptions raised without a message;
            # fall back to the exception type so the error is never blank.
            return BatchResult(
                key=item.key,
                status="error",
                error=str(exc) or type(exc).__name__,
                metadata=item.metadata,
            )

        return BatchResult(
            key=item.key,
            status="success",
            response=response,
            metadata=item.metadata,
        )
|
||||
281
tests/unit/prompts/test_batch_evaluator.py
Normal file
281
tests/unit/prompts/test_batch_evaluator.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""Tests for markitect.prompts.execution.batch."""
|
||||
|
||||
import pytest
|
||||
|
||||
from markitect.prompts.execution.batch import (
|
||||
BatchEvaluator,
|
||||
BatchItem,
|
||||
BatchResult,
|
||||
BatchSummary,
|
||||
)
|
||||
from markitect.prompts.execution.llm_adapter import MockLLMAdapter, ErrorLLMAdapter
|
||||
from markitect.prompts.execution.models import RunConfig, LLMResponse
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _items(n=3, digest_prefix="d"):
    """Build ``n`` sample batch items keyed ``entity-0`` .. ``entity-<n-1>``.

    Each item's digest is ``digest_prefix`` followed by its index, and
    its metadata records that index under ``"index"``.
    """
    batch = []
    for position in range(n):
        batch.append(
            BatchItem(
                key=f"entity-{position}",
                prompt=f"Evaluate entity {position}",
                content_digest=f"{digest_prefix}{position}",
                metadata={"index": position},
            )
        )
    return batch
|
||||
|
||||
|
||||
# ── BatchItem / BatchResult / BatchSummary ───────────────────────────
|
||||
|
||||
|
||||
class TestBatchModels:
    """Defaults and derived values of the batch data models."""

    def test_batch_item_defaults(self):
        # Optional fields fall back to an empty digest and empty metadata.
        bare_item = BatchItem(key="slug", prompt="text")
        assert bare_item.content_digest == ""
        assert bare_item.metadata == {}

    def test_batch_result_defaults(self):
        # Neither a response nor an error is attached by default.
        bare_result = BatchResult(key="slug", status="success")
        assert bare_result.response is None
        assert bare_result.error is None

    def test_summary_total_tokens(self):
        summary = BatchSummary(total_prompt_tokens=100, total_completion_tokens=50)
        assert summary.total_tokens == 150

    def test_summary_success_rate_all_success(self):
        summary = BatchSummary(total=3, succeeded=3)
        assert summary.success_rate() == 1.0

    def test_summary_success_rate_with_failures(self):
        summary = BatchSummary(total=4, succeeded=2, failed=2)
        assert summary.success_rate() == pytest.approx(0.5)

    def test_summary_success_rate_all_skipped(self):
        # Nothing attempted counts as a perfect rate.
        summary = BatchSummary(total=3, skipped=3)
        assert summary.success_rate() == 1.0

    def test_summary_success_rate_mixed(self):
        summary = BatchSummary(total=5, succeeded=2, failed=1, skipped=2)
        # 3 attempted, 2 succeeded
        expected_rate = 2 / 3
        assert summary.success_rate() == pytest.approx(expected_rate)
|
||||
|
||||
|
||||
# ── BatchEvaluator ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBatchEvaluator:
    """Happy-path behaviour of ``BatchEvaluator.evaluate``."""

    def test_evaluate_all_items(self):
        # Every item is sent to the adapter and counted as a success.
        mock_adapter = MockLLMAdapter("result")
        outcome = BatchEvaluator(mock_adapter).evaluate(_items(3))

        assert outcome.total == 3
        assert outcome.succeeded == 3
        assert outcome.failed == 0
        assert outcome.skipped == 0
        assert len(outcome.results) == 3
        assert mock_adapter.call_count == 3

    def test_results_preserve_keys(self):
        # Results come back in input order with their original keys.
        mock_adapter = MockLLMAdapter("ok")
        batch = _items(2)
        outcome = BatchEvaluator(mock_adapter).evaluate(batch)

        assert [entry.key for entry in outcome.results] == ["entity-0", "entity-1"]

    def test_results_preserve_metadata(self):
        mock_adapter = MockLLMAdapter("ok")
        batch = _items(1)
        outcome = BatchEvaluator(mock_adapter).evaluate(batch)
        first = outcome.results[0]
        assert first.metadata == {"index": 0}

    def test_response_content_available(self):
        mock_adapter = MockLLMAdapter("evaluated text")
        outcome = BatchEvaluator(mock_adapter).evaluate(_items(1))
        first = outcome.results[0]
        assert first.response.content == "evaluated text"

    def test_token_usage_aggregated(self):
        mock_adapter = MockLLMAdapter("result")
        outcome = BatchEvaluator(mock_adapter).evaluate(_items(3))
        assert outcome.total_prompt_tokens > 0
        assert outcome.total_completion_tokens > 0
        expected_total = outcome.total_prompt_tokens + outcome.total_completion_tokens
        assert outcome.total_tokens == expected_total

    def test_config_passed_to_adapter(self):
        # The evaluator forwards the caller's RunConfig to the adapter.
        mock_adapter = MockLLMAdapter("ok")
        run_config = RunConfig(temperature=0.1, max_tokens=500)
        BatchEvaluator(mock_adapter, config=run_config).evaluate(_items(1))

        seen_config = mock_adapter.last_config
        assert seen_config.temperature == 0.1
        assert seen_config.max_tokens == 500
|
||||
|
||||
|
||||
# ── Incremental evaluation ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestIncrementalEvaluation:
    """Digest-based skipping of items that have not changed."""

    def test_skip_unchanged_items(self):
        # Every item's digest matches the previous run, so nothing is evaluated.
        mock_adapter = MockLLMAdapter("result")
        batch = _items(3)
        known_digests = {entry.key: entry.content_digest for entry in batch}
        runner = BatchEvaluator(mock_adapter, previous_digests=known_digests)

        outcome = runner.evaluate(batch)
        assert outcome.skipped == 3
        assert outcome.succeeded == 0
        assert mock_adapter.call_count == 0

    def test_evaluate_changed_items(self):
        mock_adapter = MockLLMAdapter("result")
        # Only entity-0 has matching digest
        known_digests = {"entity-0": "d0"}
        runner = BatchEvaluator(mock_adapter, previous_digests=known_digests)

        outcome = runner.evaluate(_items(3))
        assert outcome.skipped == 1
        assert outcome.succeeded == 2
        assert mock_adapter.call_count == 2

    def test_evaluate_new_items(self):
        mock_adapter = MockLLMAdapter("result")
        # Previous has different keys
        known_digests = {"old-entity": "old-digest"}
        runner = BatchEvaluator(mock_adapter, previous_digests=known_digests)

        outcome = runner.evaluate(_items(2))
        assert outcome.skipped == 0
        assert outcome.succeeded == 2

    def test_changed_digest_not_skipped(self):
        mock_adapter = MockLLMAdapter("result")
        # Same key but different digest
        known_digests = {"entity-0": "old-digest"}
        runner = BatchEvaluator(mock_adapter, previous_digests=known_digests)

        outcome = runner.evaluate(_items(1))
        assert outcome.skipped == 0
        assert outcome.succeeded == 1

    def test_empty_digest_not_skipped(self):
        # An item without a digest is always evaluated, even if its key is known.
        mock_adapter = MockLLMAdapter("result")
        known_digests = {"entity-0": "d0"}
        runner = BatchEvaluator(mock_adapter, previous_digests=known_digests)

        undigested = BatchItem(key="entity-0", prompt="eval", content_digest="")
        outcome = runner.evaluate([undigested])
        assert outcome.skipped == 0
        assert outcome.succeeded == 1

    def test_skipped_status_in_result(self):
        mock_adapter = MockLLMAdapter("result")
        known_digests = {"entity-0": "d0"}
        runner = BatchEvaluator(mock_adapter, previous_digests=known_digests)

        outcome = runner.evaluate(_items(1))
        only_result = outcome.results[0]
        assert only_result.status == "skipped"
        assert only_result.response is None
|
||||
|
||||
|
||||
# ── Error handling ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBatchErrorHandling:
    """Adapter failures are recorded per item instead of propagating."""

    def test_error_captured_not_raised(self):
        failing_adapter = ErrorLLMAdapter("kaboom")
        outcome = BatchEvaluator(failing_adapter).evaluate(_items(2))

        assert outcome.failed == 2
        assert outcome.succeeded == 0

    def test_error_message_in_result(self):
        failing_adapter = ErrorLLMAdapter("something went wrong")
        outcome = BatchEvaluator(failing_adapter).evaluate(_items(1))

        only_result = outcome.results[0]
        assert only_result.status == "error"
        assert "something went wrong" in only_result.error

    def test_error_does_not_stop_batch(self):
        """One failing item doesn't prevent others from running."""
        seen_prompts = []

        class FailOnFirstAdapter(MockLLMAdapter):
            def execute_prompt(self, prompt, config):
                # Record the call first; only the very first one blows up.
                seen_prompts.append(prompt)
                if len(seen_prompts) == 1:
                    raise RuntimeError("first fails")
                return super().execute_prompt(prompt, config)

        flaky_adapter = FailOnFirstAdapter("ok")
        outcome = BatchEvaluator(flaky_adapter).evaluate(_items(3))

        assert outcome.failed == 1
        assert outcome.succeeded == 2
        statuses = [entry.status for entry in outcome.results]
        assert statuses[0] == "error"
        assert statuses[1] == "success"
        assert statuses[2] == "success"
|
||||
|
||||
|
||||
# ── Progress callback ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestProgressCallback:
    """The optional progress callback fires once per processed item."""

    def test_callback_called_for_each_item(self):
        observed = []

        def record(done, total, result):
            observed.append((done, total, result.key))

        mock_adapter = MockLLMAdapter("ok")
        runner = BatchEvaluator(mock_adapter, progress_callback=record)
        runner.evaluate(_items(3))

        # Completed count is 1-based; total stays fixed at the batch size.
        assert observed == [
            (1, 3, "entity-0"),
            (2, 3, "entity-1"),
            (3, 3, "entity-2"),
        ]

    def test_callback_receives_result(self):
        delivered = []

        def collect(done, total, result):
            delivered.append(result)

        mock_adapter = MockLLMAdapter("ok")
        runner = BatchEvaluator(mock_adapter, progress_callback=collect)
        runner.evaluate(_items(2))

        for entry in delivered:
            assert isinstance(entry, BatchResult)
        assert delivered[0].status == "success"

    def test_no_callback_no_error(self):
        mock_adapter = MockLLMAdapter("ok")
        runner = BatchEvaluator(mock_adapter)
        # Should work fine without callback
        outcome = runner.evaluate(_items(1))
        assert outcome.succeeded == 1
|
||||
|
||||
|
||||
# ── Empty batch ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEmptyBatch:
    """Evaluating an empty batch is a no-op."""

    def test_empty_items(self):
        mock_adapter = MockLLMAdapter("ok")
        outcome = BatchEvaluator(mock_adapter).evaluate([])

        # No items means zeroed counters, no results, and no adapter calls.
        assert outcome.total == 0
        assert outcome.succeeded == 0
        assert outcome.results == []
        assert mock_adapter.call_count == 0
|
||||
Reference in New Issue
Block a user