"""Tests for markitect.prompts.execution.batch.""" import pytest from markitect.prompts.execution.batch import ( BatchEvaluator, BatchItem, BatchResult, BatchSummary, ) from markitect.prompts.execution.llm_adapter import MockLLMAdapter, ErrorLLMAdapter from markitect.prompts.execution.models import RunConfig, LLMResponse # ── Helpers ────────────────────────────────────────────────────────── def _items(n=3, digest_prefix="d"): return [ BatchItem( key=f"entity-{i}", prompt=f"Evaluate entity {i}", content_digest=f"{digest_prefix}{i}", metadata={"index": i}, ) for i in range(n) ] # ── BatchItem / BatchResult / BatchSummary ─────────────────────────── class TestBatchModels: def test_batch_item_defaults(self): item = BatchItem(key="slug", prompt="text") assert item.content_digest == "" assert item.metadata == {} def test_batch_result_defaults(self): result = BatchResult(key="slug", status="success") assert result.response is None assert result.error is None def test_summary_total_tokens(self): s = BatchSummary(total_prompt_tokens=100, total_completion_tokens=50) assert s.total_tokens == 150 def test_summary_success_rate_all_success(self): s = BatchSummary(total=3, succeeded=3) assert s.success_rate() == 1.0 def test_summary_success_rate_with_failures(self): s = BatchSummary(total=4, succeeded=2, failed=2) assert s.success_rate() == pytest.approx(0.5) def test_summary_success_rate_all_skipped(self): s = BatchSummary(total=3, skipped=3) assert s.success_rate() == 1.0 def test_summary_success_rate_mixed(self): s = BatchSummary(total=5, succeeded=2, failed=1, skipped=2) # 3 attempted, 2 succeeded assert s.success_rate() == pytest.approx(2 / 3) # ── BatchEvaluator ────────────────────────────────────────────────── class TestBatchEvaluator: def test_evaluate_all_items(self): adapter = MockLLMAdapter("result") evaluator = BatchEvaluator(adapter) summary = evaluator.evaluate(_items(3)) assert summary.total == 3 assert summary.succeeded == 3 assert summary.failed == 0 assert summary.skipped == 0 assert len(summary.results) == 3 assert adapter.call_count == 3 def test_results_preserve_keys(self): adapter = MockLLMAdapter("ok") evaluator = BatchEvaluator(adapter) items = _items(2) summary = evaluator.evaluate(items) keys = [r.key for r in summary.results] assert keys == ["entity-0", "entity-1"] def test_results_preserve_metadata(self): adapter = MockLLMAdapter("ok") evaluator = BatchEvaluator(adapter) items = _items(1) summary = evaluator.evaluate(items) assert summary.results[0].metadata == {"index": 0} def test_response_content_available(self): adapter = MockLLMAdapter("evaluated text") evaluator = BatchEvaluator(adapter) summary = evaluator.evaluate(_items(1)) assert summary.results[0].response.content == "evaluated text" def test_token_usage_aggregated(self): adapter = MockLLMAdapter("result") evaluator = BatchEvaluator(adapter) summary = evaluator.evaluate(_items(3)) assert summary.total_prompt_tokens > 0 assert summary.total_completion_tokens > 0 assert summary.total_tokens == summary.total_prompt_tokens + summary.total_completion_tokens def test_config_passed_to_adapter(self): adapter = MockLLMAdapter("ok") config = RunConfig(temperature=0.1, max_tokens=500) evaluator = BatchEvaluator(adapter, config=config) evaluator.evaluate(_items(1)) assert adapter.last_config.temperature == 0.1 assert adapter.last_config.max_tokens == 500 # ── Incremental evaluation ────────────────────────────────────────── class TestIncrementalEvaluation: def test_skip_unchanged_items(self): adapter = MockLLMAdapter("result") previous = {"entity-0": "d0", "entity-1": "d1", "entity-2": "d2"} evaluator = BatchEvaluator(adapter, previous_digests=previous) summary = evaluator.evaluate(_items(3)) assert summary.skipped == 3 assert summary.succeeded == 0 assert adapter.call_count == 0 def test_evaluate_changed_items(self): adapter = MockLLMAdapter("result") # Only entity-0 has matching digest previous = {"entity-0": "d0"} evaluator = BatchEvaluator(adapter, previous_digests=previous) summary = evaluator.evaluate(_items(3)) assert summary.skipped == 1 assert summary.succeeded == 2 assert adapter.call_count == 2 def test_evaluate_new_items(self): adapter = MockLLMAdapter("result") # Previous has different keys previous = {"old-entity": "old-digest"} evaluator = BatchEvaluator(adapter, previous_digests=previous) summary = evaluator.evaluate(_items(2)) assert summary.skipped == 0 assert summary.succeeded == 2 def test_changed_digest_not_skipped(self): adapter = MockLLMAdapter("result") # Same key but different digest previous = {"entity-0": "old-digest"} evaluator = BatchEvaluator(adapter, previous_digests=previous) summary = evaluator.evaluate(_items(1)) assert summary.skipped == 0 assert summary.succeeded == 1 def test_empty_digest_not_skipped(self): adapter = MockLLMAdapter("result") previous = {"entity-0": "d0"} evaluator = BatchEvaluator(adapter, previous_digests=previous) item = BatchItem(key="entity-0", prompt="eval", content_digest="") summary = evaluator.evaluate([item]) assert summary.skipped == 0 assert summary.succeeded == 1 def test_skipped_status_in_result(self): adapter = MockLLMAdapter("result") previous = {"entity-0": "d0"} evaluator = BatchEvaluator(adapter, previous_digests=previous) summary = evaluator.evaluate(_items(1)) assert summary.results[0].status == "skipped" assert summary.results[0].response is None # ── Error handling ────────────────────────────────────────────────── class TestBatchErrorHandling: def test_error_captured_not_raised(self): adapter = ErrorLLMAdapter("kaboom") evaluator = BatchEvaluator(adapter) summary = evaluator.evaluate(_items(2)) assert summary.failed == 2 assert summary.succeeded == 0 def test_error_message_in_result(self): adapter = ErrorLLMAdapter("something went wrong") evaluator = BatchEvaluator(adapter) summary = evaluator.evaluate(_items(1)) assert summary.results[0].status == "error" assert "something went wrong" in summary.results[0].error def test_error_does_not_stop_batch(self): """One failing item doesn't prevent others from running.""" call_count = 0 class FailOnFirstAdapter(MockLLMAdapter): def execute_prompt(self, prompt, config): nonlocal call_count call_count += 1 if call_count == 1: raise RuntimeError("first fails") return super().execute_prompt(prompt, config) adapter = FailOnFirstAdapter("ok") evaluator = BatchEvaluator(adapter) summary = evaluator.evaluate(_items(3)) assert summary.failed == 1 assert summary.succeeded == 2 assert summary.results[0].status == "error" assert summary.results[1].status == "success" assert summary.results[2].status == "success" # ── Progress callback ─────────────────────────────────────────────── class TestProgressCallback: def test_callback_called_for_each_item(self): calls = [] adapter = MockLLMAdapter("ok") evaluator = BatchEvaluator( adapter, progress_callback=lambda done, total, result: calls.append( (done, total, result.key) ), ) evaluator.evaluate(_items(3)) assert len(calls) == 3 assert calls[0] == (1, 3, "entity-0") assert calls[1] == (2, 3, "entity-1") assert calls[2] == (3, 3, "entity-2") def test_callback_receives_result(self): results = [] adapter = MockLLMAdapter("ok") evaluator = BatchEvaluator( adapter, progress_callback=lambda done, total, result: results.append(result), ) evaluator.evaluate(_items(2)) assert all(isinstance(r, BatchResult) for r in results) assert results[0].status == "success" def test_no_callback_no_error(self): adapter = MockLLMAdapter("ok") evaluator = BatchEvaluator(adapter) # Should work fine without callback summary = evaluator.evaluate(_items(1)) assert summary.succeeded == 1 # ── Empty batch ───────────────────────────────────────────────────── class TestEmptyBatch: def test_empty_items(self): adapter = MockLLMAdapter("ok") evaluator = BatchEvaluator(adapter) summary = evaluator.evaluate([]) assert summary.total == 0 assert summary.succeeded == 0 assert summary.results == [] assert adapter.call_count == 0