""" Unit tests for RefinementLoop. Tests the execute → validate → halt or refine cycle. """ import json import pytest from typing import List, Tuple from markitect.prompts.quality.models import ( HaltDecision, QualityPolicy, ValidationResult, ValidationStatus, ) from markitect.prompts.quality.gates.pattern_gate import PatternValidationGate from markitect.prompts.quality.validator import QualityValidator from markitect.prompts.quality.refinement import RefinementLoop class TestRefinementLoopBasic: """Tests for basic refinement loop operation.""" def test_immediate_quality_met(self): """Test loop halts immediately when quality is met on first iteration.""" gate = PatternValidationGate( required_patterns=[r"## Summary"], gate_id="gate-1", ) validator = QualityValidator(gates=[gate]) policy = QualityPolicy(max_iterations=5) loop = RefinementLoop(validator, policy) def callback(iteration, prev_results): return (f"run-{iteration}", "## Summary\nGood content.", "art-1") result = loop.run(callback, "art-1") assert result.iterations_run == 1 assert result.halting_record.decision == HaltDecision.HALT_QUALITY_MET assert len(result.run_ids) == 1 def test_iterates_until_quality_met(self): """Test loop iterates until quality requirements are met.""" gate = PatternValidationGate( required_patterns=[r"## Summary", r"## Details"], gate_id="gate-1", ) validator = QualityValidator(gates=[gate]) policy = QualityPolicy(max_iterations=5, min_improvement=0.0) loop = RefinementLoop(validator, policy) contents = [ "## Summary\nOnly summary.", # missing Details "## Summary\nStill only summary.", # still missing, slightly different "## Summary\n## Details\nBoth sections.", # complete ] def callback(iteration, prev_results): content = contents[min(iteration - 1, len(contents) - 1)] return (f"run-{iteration}", content, "art-1") result = loop.run(callback, "art-1") assert result.iterations_run == 3 assert result.halting_record.decision == HaltDecision.HALT_QUALITY_MET assert len(result.run_ids) == 3 def test_hits_iteration_limit(self): """Test loop stops at iteration limit.""" gate = PatternValidationGate( required_patterns=[r"IMPOSSIBLE_PATTERN_12345"], gate_id="gate-1", ) validator = QualityValidator(gates=[gate]) policy = QualityPolicy(max_iterations=3, min_improvement=0.0) loop = RefinementLoop(validator, policy) def callback(iteration, prev_results): return (f"run-{iteration}", "never matches", "art-1") result = loop.run(callback, "art-1") assert result.iterations_run == 3 assert result.halting_record.decision == HaltDecision.HALT_ITERATION_LIMIT def test_budget_exhaustion(self): """Test loop stops when resource budget is exhausted.""" gate = PatternValidationGate( required_patterns=[r"IMPOSSIBLE"], gate_id="gate-1", ) validator = QualityValidator(gates=[gate]) policy = QualityPolicy(max_iterations=10, resource_budget=2) loop = RefinementLoop(validator, policy) def callback(iteration, prev_results): return (f"run-{iteration}", "content", "art-1") result = loop.run(callback, "art-1") assert result.iterations_run == 2 assert result.halting_record.decision == HaltDecision.HALT_BUDGET_EXHAUSTED class TestRefinementLoopHistory: """Tests for refinement loop tracking history.""" def test_all_results_tracked(self): """Test all iteration results are tracked.""" gate = PatternValidationGate( required_patterns=[r"## Summary"], gate_id="gate-1", ) validator = QualityValidator(gates=[gate]) policy = QualityPolicy(max_iterations=3) loop = RefinementLoop(validator, policy) iteration_count = [0] def callback(iteration, prev_results): iteration_count[0] = iteration if iteration < 3: return (f"run-{iteration}", "no match", "art-1") return (f"run-{iteration}", "## Summary\nDone.", "art-1") result = loop.run(callback, "art-1") assert len(result.all_results) == result.iterations_run assert len(result.run_ids) == result.iterations_run def test_run_ids_collected(self): """Test run IDs are collected from each iteration.""" gate = PatternValidationGate(gate_id="gate-1") validator = QualityValidator(gates=[gate]) policy = QualityPolicy(max_iterations=3) loop = RefinementLoop(validator, policy) def callback(iteration, prev_results): return (f"run-{iteration}", "content", "art-1") result = loop.run(callback, "art-1") assert result.run_ids[0] == "run-1" def test_previous_results_passed_to_callback(self): """Test callback receives previous iteration results.""" gate = PatternValidationGate( required_patterns=[r"## Final"], gate_id="gate-1", ) validator = QualityValidator(gates=[gate]) policy = QualityPolicy(max_iterations=3) loop = RefinementLoop(validator, policy) received_prev: List[List[ValidationResult]] = [] def callback(iteration, prev_results): received_prev.append(prev_results) if iteration >= 2: return (f"run-{iteration}", "## Final\nDone.", "art-1") return (f"run-{iteration}", "incomplete", "art-1") result = loop.run(callback, "art-1") # First iteration gets empty prev_results assert len(received_prev[0]) == 0 # Second iteration gets results from first assert len(received_prev[1]) == 1 class TestRefinementLoopNoImprovement: """Tests for no-improvement halting in refinement loop.""" def test_halts_on_no_improvement(self): """Test loop halts when scores stop improving.""" gate = PatternValidationGate( required_patterns=[r"## Required"], gate_id="gate-1", ) validator = QualityValidator(gates=[gate]) policy = QualityPolicy( max_iterations=10, min_improvement=0.1, ) loop = RefinementLoop(validator, policy) # Both iterations produce same failing content → same score → no improvement def callback(iteration, prev_results): return (f"run-{iteration}", "no match at all", "art-1") result = loop.run(callback, "art-1") # Second iteration should detect no improvement and halt assert result.iterations_run == 2 assert result.halting_record.decision == HaltDecision.HALT_NO_IMPROVEMENT class TestRefinementResultSerialization: """Tests for RefinementResult serialization.""" def test_to_dict(self): """Test RefinementResult serialization.""" gate = PatternValidationGate(gate_id="gate-1") validator = QualityValidator(gates=[gate]) policy = QualityPolicy(max_iterations=1) loop = RefinementLoop(validator, policy) def callback(iteration, prev_results): return ("run-1", "content", "art-1") result = loop.run(callback, "art-1") d = result.to_dict() assert "iterations_run" in d assert "final_results" in d assert "halting_record" in d assert "run_ids" in d