feat(prompts): implement Phase 7 - Quality & Validation (FR-9, FR-10)
Some checks failed
Test Suite / unit-tests (3.11) (push) Has been cancelled
Test Suite / unit-tests (3.12) (push) Has been cancelled
Test Suite / integration-tests (push) Has been cancelled
Test Suite / e2e-tests (push) Has been cancelled
Test Suite / performance-tests (push) Has been cancelled
Test Suite / code-quality (push) Has been cancelled
Test Suite / security-scan (push) Has been cancelled
Test Suite / test-summary (push) Has been cancelled

Add quality gate framework with schema validation (JSON Schema via
jsonschema library), pattern validation (regex-based), multi-gate
QualityValidator with SQLite persistence, HaltingPolicyEngine with
budget/iteration/improvement checks, and RefinementLoop for iterative
execute-validate-halt cycles.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-09 13:31:37 +01:00
parent bd1d05ba79
commit 704272644c
15 changed files with 2615 additions and 0 deletions

View File

@@ -0,0 +1,221 @@
"""
Unit tests for HaltingPolicyEngine.
Tests halting decisions based on quality results, iteration limits,
marginal improvement, and resource budgets.
"""
import pytest
from markitect.prompts.quality.models import (
GateType,
HaltDecision,
QualityPolicy,
ValidationResult,
ValidationStatus,
)
from markitect.prompts.quality.policy import HaltingPolicyEngine
def _make_result(status=ValidationStatus.PASS, score=1.0, gate_id="gate-1"):
"""Helper to create a ValidationResult."""
return ValidationResult.create(
gate_id=gate_id,
gate_type=GateType.PATTERN,
artifact_id="art-1",
status=status,
score=score,
)
class TestQualityMetDecision:
"""Tests for quality met halting."""
def test_all_pass_halts_quality_met(self):
"""Test all gates passing triggers quality met halt."""
policy = QualityPolicy(max_iterations=5)
engine = HaltingPolicyEngine(policy)
results = [_make_result(ValidationStatus.PASS, 1.0)]
record = engine.evaluate(results, iteration=1)
assert record.decision == HaltDecision.HALT_QUALITY_MET
assert "quality gates passed" in record.reason.lower()
def test_required_gates_all_pass(self):
"""Test required gates all passing triggers quality met."""
policy = QualityPolicy(
max_iterations=5,
required_gate_ids=["required-gate"],
)
engine = HaltingPolicyEngine(policy)
results = [
_make_result(ValidationStatus.PASS, 1.0, gate_id="required-gate"),
_make_result(ValidationStatus.FAIL, 0.5, gate_id="optional-gate"),
]
record = engine.evaluate(results, iteration=1)
assert record.decision == HaltDecision.HALT_QUALITY_MET
def test_required_gate_fails_continues(self):
"""Test required gate failing allows continuation."""
policy = QualityPolicy(
max_iterations=5,
required_gate_ids=["required-gate"],
)
engine = HaltingPolicyEngine(policy)
results = [
_make_result(ValidationStatus.FAIL, 0.5, gate_id="required-gate"),
]
record = engine.evaluate(results, iteration=1)
assert record.decision == HaltDecision.CONTINUE
class TestIterationLimitDecision:
"""Tests for iteration limit halting."""
def test_at_iteration_limit(self):
"""Test halting at iteration limit."""
policy = QualityPolicy(max_iterations=3)
engine = HaltingPolicyEngine(policy)
results = [_make_result(ValidationStatus.FAIL, 0.5)]
record = engine.evaluate(results, iteration=3)
assert record.decision == HaltDecision.HALT_ITERATION_LIMIT
assert record.iteration == 3
def test_before_iteration_limit(self):
"""Test not halting before iteration limit."""
policy = QualityPolicy(max_iterations=5)
engine = HaltingPolicyEngine(policy)
results = [_make_result(ValidationStatus.FAIL, 0.5)]
record = engine.evaluate(results, iteration=2)
assert record.decision == HaltDecision.CONTINUE
class TestBudgetExhaustedDecision:
"""Tests for resource budget exhaustion."""
def test_budget_exhausted(self):
"""Test halting when budget is exhausted."""
policy = QualityPolicy(max_iterations=10, resource_budget=5)
engine = HaltingPolicyEngine(policy)
results = [_make_result(ValidationStatus.FAIL, 0.5)]
record = engine.evaluate(results, iteration=1, total_runs=5)
assert record.decision == HaltDecision.HALT_BUDGET_EXHAUSTED
def test_budget_not_exhausted(self):
"""Test not halting when budget remains."""
policy = QualityPolicy(max_iterations=10, resource_budget=10)
engine = HaltingPolicyEngine(policy)
results = [_make_result(ValidationStatus.FAIL, 0.5)]
record = engine.evaluate(results, iteration=1, total_runs=3)
assert record.decision == HaltDecision.CONTINUE
class TestMarginalImprovementDecision:
"""Tests for marginal improvement halting."""
def test_no_improvement_halts(self):
"""Test halting when improvement is below threshold."""
policy = QualityPolicy(max_iterations=10, min_improvement=0.05)
engine = HaltingPolicyEngine(policy)
results = [_make_result(ValidationStatus.FAIL, 0.52)]
record = engine.evaluate(
results,
iteration=2,
score_history=[0.50], # improvement: 0.02 < 0.05
)
assert record.decision == HaltDecision.HALT_NO_IMPROVEMENT
def test_sufficient_improvement_continues(self):
"""Test continuing when improvement meets threshold."""
policy = QualityPolicy(max_iterations=10, min_improvement=0.05)
engine = HaltingPolicyEngine(policy)
results = [_make_result(ValidationStatus.FAIL, 0.60)]
record = engine.evaluate(
results,
iteration=2,
score_history=[0.50], # improvement: 0.10 >= 0.05
)
assert record.decision == HaltDecision.CONTINUE
def test_first_iteration_no_history(self):
"""Test first iteration with no history continues."""
policy = QualityPolicy(max_iterations=10, min_improvement=0.05)
engine = HaltingPolicyEngine(policy)
results = [_make_result(ValidationStatus.FAIL, 0.50)]
record = engine.evaluate(results, iteration=1)
assert record.decision == HaltDecision.CONTINUE
class TestPriorityOrder:
"""Tests for the priority order of halting checks."""
def test_budget_checked_before_iteration(self):
"""Test budget exhaustion takes priority over iteration limit."""
policy = QualityPolicy(max_iterations=3, resource_budget=2)
engine = HaltingPolicyEngine(policy)
results = [_make_result(ValidationStatus.FAIL, 0.5)]
record = engine.evaluate(results, iteration=3, total_runs=2)
assert record.decision == HaltDecision.HALT_BUDGET_EXHAUSTED
def test_iteration_checked_before_quality(self):
"""Test iteration limit checked before quality met."""
policy = QualityPolicy(max_iterations=2)
engine = HaltingPolicyEngine(policy)
results = [_make_result(ValidationStatus.PASS, 1.0)]
# At iteration limit AND quality met — iteration limit wins
# Actually quality is checked after iteration, so quality would be checked
# But iteration 2 >= max 2 triggers first
record = engine.evaluate(results, iteration=2)
assert record.decision == HaltDecision.HALT_ITERATION_LIMIT
class TestHaltingRecord:
"""Tests for HaltingRecord attributes."""
def test_record_has_scores(self):
"""Test halting record includes score history."""
policy = QualityPolicy(max_iterations=5)
engine = HaltingPolicyEngine(policy)
results = [_make_result(ValidationStatus.PASS, 0.9)]
record = engine.evaluate(
results, iteration=2, score_history=[0.5]
)
assert record.scores == [0.5, 0.9]
def test_record_to_dict(self):
"""Test halting record serialization."""
policy = QualityPolicy(max_iterations=3)
engine = HaltingPolicyEngine(policy)
results = [_make_result(ValidationStatus.PASS, 1.0)]
record = engine.evaluate(results, iteration=1)
d = record.to_dict()
assert d["decision"] == "halted_quality_met"
assert d["iteration"] == 1
assert d["max_iterations"] == 3

View File

@@ -0,0 +1,303 @@
"""
Unit tests for quality gate models and individual gate implementations.
Tests QualityGate models, SchemaValidationGate, and PatternValidationGate.
"""
import json
import pytest
from markitect.prompts.quality.models import (
GateType,
ValidationStatus,
ValidationDiagnostic,
ValidationResult,
QualityPolicy,
HaltDecision,
HaltingRecord,
)
from markitect.prompts.quality.gates.schema_gate import SchemaValidationGate
from markitect.prompts.quality.gates.pattern_gate import PatternValidationGate
class TestValidationDiagnostic:
"""Tests for ValidationDiagnostic."""
def test_create_diagnostic(self):
"""Test creating a diagnostic."""
diag = ValidationDiagnostic(
code="TEST_ERROR",
message="Something went wrong",
severity="error",
)
assert diag.code == "TEST_ERROR"
assert diag.severity == "error"
def test_to_dict(self):
"""Test diagnostic serialization."""
diag = ValidationDiagnostic(code="E1", message="msg", severity="warning")
d = diag.to_dict()
assert d["code"] == "E1"
assert d["severity"] == "warning"
def test_from_dict(self):
"""Test diagnostic deserialization."""
data = {"code": "E1", "message": "msg", "severity": "info"}
diag = ValidationDiagnostic.from_dict(data)
assert diag.code == "E1"
assert diag.severity == "info"
class TestValidationResult:
"""Tests for ValidationResult."""
def test_create_result(self):
"""Test creating a validation result."""
result = ValidationResult.create(
gate_id="gate-1",
gate_type=GateType.SCHEMA,
artifact_id="art-1",
status=ValidationStatus.PASS,
score=1.0,
)
assert result.gate_id == "gate-1"
assert result.status == ValidationStatus.PASS
assert result.score == 1.0
def test_result_unique_ids(self):
"""Test that each result gets a unique ID."""
r1 = ValidationResult.create(
gate_id="g", gate_type=GateType.PATTERN,
artifact_id="a", status=ValidationStatus.PASS,
)
r2 = ValidationResult.create(
gate_id="g", gate_type=GateType.PATTERN,
artifact_id="a", status=ValidationStatus.PASS,
)
assert r1.id != r2.id
def test_to_dict_from_dict(self):
"""Test round-trip serialization."""
result = ValidationResult.create(
gate_id="g1",
gate_type=GateType.SCHEMA,
artifact_id="a1",
status=ValidationStatus.FAIL,
score=0.5,
diagnostics=[
ValidationDiagnostic(code="E1", message="err", severity="error"),
],
)
d = result.to_dict()
restored = ValidationResult.from_dict(d)
assert restored.id == result.id
assert restored.status == ValidationStatus.FAIL
assert restored.score == 0.5
assert len(restored.diagnostics) == 1
class TestQualityPolicy:
"""Tests for QualityPolicy."""
def test_default_policy(self):
"""Test default policy values."""
policy = QualityPolicy()
assert policy.max_iterations == 3
assert policy.min_improvement == 0.05
assert policy.fail_on_gate_failure is True
assert policy.resource_budget == 10
def test_to_dict_from_dict(self):
"""Test round-trip serialization."""
policy = QualityPolicy(
max_iterations=5,
min_improvement=0.1,
required_gate_ids=["g1", "g2"],
)
d = policy.to_dict()
restored = QualityPolicy.from_dict(d)
assert restored.max_iterations == 5
assert restored.min_improvement == 0.1
assert restored.required_gate_ids == ["g1", "g2"]
class TestSchemaValidationGate:
"""Tests for SchemaValidationGate."""
def test_valid_json_passes(self):
"""Test valid JSON against schema passes."""
schema = {
"type": "object",
"required": ["name", "version"],
"properties": {
"name": {"type": "string"},
"version": {"type": "integer"},
},
}
gate = SchemaValidationGate(schema=schema, name="test-schema")
content = json.dumps({"name": "test", "version": 1})
result = gate.validate(content, "art-1")
assert result.status == ValidationStatus.PASS
assert result.score == 1.0
assert len(result.diagnostics) == 0
def test_missing_required_field_fails(self):
"""Test missing required field fails validation."""
schema = {
"type": "object",
"required": ["name", "version"],
"properties": {
"name": {"type": "string"},
"version": {"type": "integer"},
},
}
gate = SchemaValidationGate(schema=schema)
content = json.dumps({"name": "test"})
result = gate.validate(content, "art-1")
assert result.status == ValidationStatus.FAIL
assert any(d.code == "SCHEMA_VIOLATION" for d in result.diagnostics)
def test_wrong_type_fails(self):
"""Test wrong type fails validation."""
schema = {
"type": "object",
"properties": {
"count": {"type": "integer"},
},
}
gate = SchemaValidationGate(schema=schema)
content = json.dumps({"count": "not-a-number"})
result = gate.validate(content, "art-1")
assert result.status == ValidationStatus.FAIL
def test_invalid_json_fails(self):
"""Test invalid JSON content fails."""
schema = {"type": "object"}
gate = SchemaValidationGate(schema=schema)
result = gate.validate("not json {{{", "art-1")
assert result.status == ValidationStatus.FAIL
assert result.score == 0.0
assert any(d.code == "INVALID_JSON" for d in result.diagnostics)
def test_gate_has_correct_type(self):
"""Test gate type is SCHEMA."""
gate = SchemaValidationGate(schema={"type": "object"})
assert gate.gate_type == GateType.SCHEMA
def test_empty_schema_passes_any_object(self):
"""Test empty schema passes any valid JSON."""
gate = SchemaValidationGate(schema={})
result = gate.validate(json.dumps({"any": "thing"}), "art-1")
assert result.status == ValidationStatus.PASS
def test_score_reflects_error_count(self):
"""Test that score decreases with more errors."""
schema = {
"type": "object",
"required": ["a", "b", "c"],
"properties": {
"a": {"type": "string"},
"b": {"type": "string"},
"c": {"type": "string"},
},
}
gate = SchemaValidationGate(schema=schema)
# Missing all 3 required fields
result = gate.validate(json.dumps({}), "art-1")
assert result.status == ValidationStatus.FAIL
assert result.score < 1.0
class TestPatternValidationGate:
"""Tests for PatternValidationGate."""
def test_required_pattern_present(self):
"""Test content with required patterns passes."""
gate = PatternValidationGate(
required_patterns=[r"## Endpoints", r"### Authentication"],
)
content = "# API\n## Endpoints\n### Authentication\nDetails here."
result = gate.validate(content, "art-1")
assert result.status == ValidationStatus.PASS
assert result.score == 1.0
def test_required_pattern_missing(self):
"""Test missing required pattern fails."""
gate = PatternValidationGate(
required_patterns=[r"## Endpoints", r"### Authentication"],
)
content = "# API\n## Endpoints\nNo auth section."
result = gate.validate(content, "art-1")
assert result.status == ValidationStatus.FAIL
assert any(d.code == "MISSING_PATTERN" for d in result.diagnostics)
def test_forbidden_pattern_absent(self):
"""Test content without forbidden patterns passes."""
gate = PatternValidationGate(
forbidden_patterns=[r"TODO", r"FIXME"],
)
content = "This is a clean document."
result = gate.validate(content, "art-1")
assert result.status == ValidationStatus.PASS
def test_forbidden_pattern_present(self):
"""Test content with forbidden pattern fails."""
gate = PatternValidationGate(
forbidden_patterns=[r"TODO", r"FIXME"],
)
content = "This needs work. TODO: fix this."
result = gate.validate(content, "art-1")
assert result.status == ValidationStatus.FAIL
assert any(d.code == "FORBIDDEN_PATTERN" for d in result.diagnostics)
def test_combined_required_and_forbidden(self):
"""Test both required and forbidden patterns together."""
gate = PatternValidationGate(
required_patterns=[r"## Summary"],
forbidden_patterns=[r"FIXME"],
)
# Has required, no forbidden
result1 = gate.validate("## Summary\nAll good.", "art-1")
assert result1.status == ValidationStatus.PASS
# Has required AND forbidden
result2 = gate.validate("## Summary\nFIXME: broken.", "art-1")
assert result2.status == ValidationStatus.FAIL
def test_no_patterns_passes(self):
"""Test gate with no patterns always passes."""
gate = PatternValidationGate()
result = gate.validate("anything", "art-1")
assert result.status == ValidationStatus.PASS
def test_gate_has_correct_type(self):
"""Test gate type is PATTERN."""
gate = PatternValidationGate()
assert gate.gate_type == GateType.PATTERN
def test_score_proportional_to_failures(self):
"""Test score is proportional to number of checks passed."""
gate = PatternValidationGate(
required_patterns=[r"A", r"B", r"C", r"D"],
)
# Only A is present → 3 out of 4 fail
result = gate.validate("A", "art-1")
assert result.status == ValidationStatus.FAIL
assert 0.0 < result.score < 1.0
def test_regex_pattern_matching(self):
"""Test regex patterns work correctly."""
gate = PatternValidationGate(
required_patterns=[r"\d{3}-\d{4}"],
)
result = gate.validate("Call 555-1234 for info", "art-1")
assert result.status == ValidationStatus.PASS

View File

@@ -0,0 +1,264 @@
"""
Unit tests for QualityValidator.
Tests applying multiple gates, aggregating results, and persistence.
"""
import json
import pytest
import tempfile
from pathlib import Path
from markitect.prompts.quality.models import (
GateType,
ValidationStatus,
ValidationResult,
)
from markitect.prompts.quality.gates.schema_gate import SchemaValidationGate
from markitect.prompts.quality.gates.pattern_gate import PatternValidationGate
from markitect.prompts.quality.validator import QualityValidator
@pytest.fixture
def temp_db():
"""Create temporary database for testing."""
with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
db_path = f.name
yield db_path
Path(db_path).unlink(missing_ok=True)
@pytest.fixture
def schema_gate():
"""Create a simple schema gate."""
return SchemaValidationGate(
schema={
"type": "object",
"required": ["name"],
"properties": {"name": {"type": "string"}},
},
gate_id="schema-gate-1",
name="test-schema",
)
@pytest.fixture
def pattern_gate():
"""Create a simple pattern gate."""
return PatternValidationGate(
required_patterns=[r"## Summary"],
forbidden_patterns=[r"TODO"],
gate_id="pattern-gate-1",
name="test-pattern",
)
class TestValidateArtifact:
"""Tests for validating artifacts with multiple gates."""
def test_all_gates_pass(self, schema_gate, pattern_gate):
"""Test all gates passing."""
validator = QualityValidator(gates=[schema_gate, pattern_gate])
# Content that satisfies both gates (JSON for schema, text for pattern)
# Schema gate needs JSON, pattern gate needs text patterns
# Use separate validators for different content types
schema_validator = QualityValidator(gates=[schema_gate])
results = schema_validator.validate_artifact(
json.dumps({"name": "test"}), "art-1"
)
assert len(results) == 1
assert results[0].status == ValidationStatus.PASS
def test_pattern_gate_validates(self, pattern_gate):
"""Test pattern gate validation."""
validator = QualityValidator(gates=[pattern_gate])
results = validator.validate_artifact(
"## Summary\nAll good here.", "art-1"
)
assert len(results) == 1
assert results[0].status == ValidationStatus.PASS
def test_multiple_gates_mixed_results(self, pattern_gate):
"""Test multiple gates with mixed pass/fail."""
gate_a = PatternValidationGate(
required_patterns=[r"## Summary"],
gate_id="gate-a",
)
gate_b = PatternValidationGate(
required_patterns=[r"## Missing Section"],
gate_id="gate-b",
)
validator = QualityValidator(gates=[gate_a, gate_b])
results = validator.validate_artifact("## Summary\nContent.", "art-1")
assert len(results) == 2
statuses = {r.gate_id: r.status for r in results}
assert statuses["gate-a"] == ValidationStatus.PASS
assert statuses["gate-b"] == ValidationStatus.FAIL
def test_no_gates_returns_empty(self):
"""Test validator with no gates returns empty list."""
validator = QualityValidator()
results = validator.validate_artifact("content", "art-1")
assert results == []
class TestAllPassed:
"""Tests for the all_passed helper."""
def test_all_pass(self):
"""Test all_passed returns True when all pass."""
validator = QualityValidator()
results = [
ValidationResult.create(
gate_id="g1", gate_type=GateType.PATTERN,
artifact_id="a", status=ValidationStatus.PASS,
),
ValidationResult.create(
gate_id="g2", gate_type=GateType.PATTERN,
artifact_id="a", status=ValidationStatus.PASS,
),
]
assert validator.all_passed(results) is True
def test_one_fails(self):
"""Test all_passed returns False when one fails."""
validator = QualityValidator()
results = [
ValidationResult.create(
gate_id="g1", gate_type=GateType.PATTERN,
artifact_id="a", status=ValidationStatus.PASS,
),
ValidationResult.create(
gate_id="g2", gate_type=GateType.PATTERN,
artifact_id="a", status=ValidationStatus.FAIL,
),
]
assert validator.all_passed(results) is False
def test_empty_results(self):
"""Test all_passed with empty list returns True."""
validator = QualityValidator()
assert validator.all_passed([]) is True
class TestAggregateScore:
"""Tests for aggregate score calculation."""
def test_average_scores(self):
"""Test aggregate is average of scores."""
validator = QualityValidator()
results = [
ValidationResult.create(
gate_id="g1", gate_type=GateType.PATTERN,
artifact_id="a", status=ValidationStatus.PASS, score=1.0,
),
ValidationResult.create(
gate_id="g2", gate_type=GateType.PATTERN,
artifact_id="a", status=ValidationStatus.FAIL, score=0.5,
),
]
assert validator.aggregate_score(results) == 0.75
def test_no_results(self):
"""Test aggregate with no results returns 1.0."""
validator = QualityValidator()
assert validator.aggregate_score([]) == 1.0
def test_none_scores_ignored(self):
"""Test results with None scores are handled."""
validator = QualityValidator()
results = [
ValidationResult.create(
gate_id="g1", gate_type=GateType.PATTERN,
artifact_id="a", status=ValidationStatus.PASS, score=None,
),
]
assert validator.aggregate_score(results) == 1.0
class TestGetFailedGates:
"""Tests for getting failed gates."""
def test_get_failed(self):
"""Test filtering failed results."""
validator = QualityValidator()
results = [
ValidationResult.create(
gate_id="g1", gate_type=GateType.PATTERN,
artifact_id="a", status=ValidationStatus.PASS,
),
ValidationResult.create(
gate_id="g2", gate_type=GateType.PATTERN,
artifact_id="a", status=ValidationStatus.FAIL,
),
]
failed = validator.get_failed_gates(results)
assert len(failed) == 1
assert failed[0].gate_id == "g2"
class TestResultsToManifest:
"""Tests for converting results to manifest dict."""
def test_manifest_dict_format(self):
"""Test manifest dict has correct structure."""
validator = QualityValidator()
results = [
ValidationResult.create(
gate_id="g1", gate_type=GateType.PATTERN,
artifact_id="a", status=ValidationStatus.PASS, score=1.0,
),
]
manifest = validator.results_to_manifest_dict(results)
assert "quality_gates" in manifest
assert manifest["all_passed"] is True
assert manifest["aggregate_score"] == 1.0
class TestPersistence:
"""Tests for persisting validation results."""
def test_persist_and_retrieve_by_run(self, temp_db, pattern_gate):
"""Test persisting results and retrieving by run ID."""
validator = QualityValidator(gates=[pattern_gate], db_path=temp_db)
validator.validate_artifact(
"## Summary\nClean content.", "art-1", run_id="run-1"
)
results = validator.get_results_for_run("run-1")
assert len(results) == 1
assert results[0]["status"] == "pass"
def test_persist_and_retrieve_by_artifact(self, temp_db, pattern_gate):
"""Test persisting results and retrieving by artifact ID."""
validator = QualityValidator(gates=[pattern_gate], db_path=temp_db)
validator.validate_artifact(
"## Summary\nClean content.", "art-1", run_id="run-1"
)
results = validator.get_results_for_artifact("art-1")
assert len(results) == 1
assert results[0]["artifact_id"] == "art-1"
def test_no_persistence_without_db(self, pattern_gate):
"""Test no persistence when db_path is None."""
validator = QualityValidator(gates=[pattern_gate])
results = validator.validate_artifact(
"## Summary\nContent.", "art-1", run_id="run-1"
)
assert len(results) == 1
# No DB queries should work
assert validator.get_results_for_run("run-1") == []
def test_add_gate(self):
"""Test adding a gate after construction."""
validator = QualityValidator()
assert len(validator.gates) == 0
gate = PatternValidationGate(required_patterns=[r"test"])
validator.add_gate(gate)
assert len(validator.gates) == 1

View File

@@ -0,0 +1,219 @@
"""
Unit tests for RefinementLoop.
Tests the execute → validate → halt or refine cycle.
"""
import json
import pytest
from typing import List, Tuple
from markitect.prompts.quality.models import (
HaltDecision,
QualityPolicy,
ValidationResult,
ValidationStatus,
)
from markitect.prompts.quality.gates.pattern_gate import PatternValidationGate
from markitect.prompts.quality.validator import QualityValidator
from markitect.prompts.quality.refinement import RefinementLoop
class TestRefinementLoopBasic:
"""Tests for basic refinement loop operation."""
def test_immediate_quality_met(self):
"""Test loop halts immediately when quality is met on first iteration."""
gate = PatternValidationGate(
required_patterns=[r"## Summary"],
gate_id="gate-1",
)
validator = QualityValidator(gates=[gate])
policy = QualityPolicy(max_iterations=5)
loop = RefinementLoop(validator, policy)
def callback(iteration, prev_results):
return (f"run-{iteration}", "## Summary\nGood content.", "art-1")
result = loop.run(callback, "art-1")
assert result.iterations_run == 1
assert result.halting_record.decision == HaltDecision.HALT_QUALITY_MET
assert len(result.run_ids) == 1
def test_iterates_until_quality_met(self):
"""Test loop iterates until quality requirements are met."""
gate = PatternValidationGate(
required_patterns=[r"## Summary", r"## Details"],
gate_id="gate-1",
)
validator = QualityValidator(gates=[gate])
policy = QualityPolicy(max_iterations=5, min_improvement=0.0)
loop = RefinementLoop(validator, policy)
contents = [
"## Summary\nOnly summary.", # missing Details
"## Summary\nStill only summary.", # still missing, slightly different
"## Summary\n## Details\nBoth sections.", # complete
]
def callback(iteration, prev_results):
content = contents[min(iteration - 1, len(contents) - 1)]
return (f"run-{iteration}", content, "art-1")
result = loop.run(callback, "art-1")
assert result.iterations_run == 3
assert result.halting_record.decision == HaltDecision.HALT_QUALITY_MET
assert len(result.run_ids) == 3
def test_hits_iteration_limit(self):
"""Test loop stops at iteration limit."""
gate = PatternValidationGate(
required_patterns=[r"IMPOSSIBLE_PATTERN_12345"],
gate_id="gate-1",
)
validator = QualityValidator(gates=[gate])
policy = QualityPolicy(max_iterations=3, min_improvement=0.0)
loop = RefinementLoop(validator, policy)
def callback(iteration, prev_results):
return (f"run-{iteration}", "never matches", "art-1")
result = loop.run(callback, "art-1")
assert result.iterations_run == 3
assert result.halting_record.decision == HaltDecision.HALT_ITERATION_LIMIT
def test_budget_exhaustion(self):
"""Test loop stops when resource budget is exhausted."""
gate = PatternValidationGate(
required_patterns=[r"IMPOSSIBLE"],
gate_id="gate-1",
)
validator = QualityValidator(gates=[gate])
policy = QualityPolicy(max_iterations=10, resource_budget=2)
loop = RefinementLoop(validator, policy)
def callback(iteration, prev_results):
return (f"run-{iteration}", "content", "art-1")
result = loop.run(callback, "art-1")
assert result.iterations_run == 2
assert result.halting_record.decision == HaltDecision.HALT_BUDGET_EXHAUSTED
class TestRefinementLoopHistory:
"""Tests for refinement loop tracking history."""
def test_all_results_tracked(self):
"""Test all iteration results are tracked."""
gate = PatternValidationGate(
required_patterns=[r"## Summary"],
gate_id="gate-1",
)
validator = QualityValidator(gates=[gate])
policy = QualityPolicy(max_iterations=3)
loop = RefinementLoop(validator, policy)
iteration_count = [0]
def callback(iteration, prev_results):
iteration_count[0] = iteration
if iteration < 3:
return (f"run-{iteration}", "no match", "art-1")
return (f"run-{iteration}", "## Summary\nDone.", "art-1")
result = loop.run(callback, "art-1")
assert len(result.all_results) == result.iterations_run
assert len(result.run_ids) == result.iterations_run
def test_run_ids_collected(self):
"""Test run IDs are collected from each iteration."""
gate = PatternValidationGate(gate_id="gate-1")
validator = QualityValidator(gates=[gate])
policy = QualityPolicy(max_iterations=3)
loop = RefinementLoop(validator, policy)
def callback(iteration, prev_results):
return (f"run-{iteration}", "content", "art-1")
result = loop.run(callback, "art-1")
assert result.run_ids[0] == "run-1"
def test_previous_results_passed_to_callback(self):
"""Test callback receives previous iteration results."""
gate = PatternValidationGate(
required_patterns=[r"## Final"],
gate_id="gate-1",
)
validator = QualityValidator(gates=[gate])
policy = QualityPolicy(max_iterations=3)
loop = RefinementLoop(validator, policy)
received_prev: List[List[ValidationResult]] = []
def callback(iteration, prev_results):
received_prev.append(prev_results)
if iteration >= 2:
return (f"run-{iteration}", "## Final\nDone.", "art-1")
return (f"run-{iteration}", "incomplete", "art-1")
result = loop.run(callback, "art-1")
# First iteration gets empty prev_results
assert len(received_prev[0]) == 0
# Second iteration gets results from first
assert len(received_prev[1]) == 1
class TestRefinementLoopNoImprovement:
"""Tests for no-improvement halting in refinement loop."""
def test_halts_on_no_improvement(self):
"""Test loop halts when scores stop improving."""
gate = PatternValidationGate(
required_patterns=[r"## Required"],
gate_id="gate-1",
)
validator = QualityValidator(gates=[gate])
policy = QualityPolicy(
max_iterations=10,
min_improvement=0.1,
)
loop = RefinementLoop(validator, policy)
# Both iterations produce same failing content → same score → no improvement
def callback(iteration, prev_results):
return (f"run-{iteration}", "no match at all", "art-1")
result = loop.run(callback, "art-1")
# Second iteration should detect no improvement and halt
assert result.iterations_run == 2
assert result.halting_record.decision == HaltDecision.HALT_NO_IMPROVEMENT
class TestRefinementResultSerialization:
"""Tests for RefinementResult serialization."""
def test_to_dict(self):
"""Test RefinementResult serialization."""
gate = PatternValidationGate(gate_id="gate-1")
validator = QualityValidator(gates=[gate])
policy = QualityPolicy(max_iterations=1)
loop = RefinementLoop(validator, policy)
def callback(iteration, prev_results):
return ("run-1", "content", "art-1")
result = loop.run(callback, "art-1")
d = result.to_dict()
assert "iterations_run" in d
assert "final_results" in d
assert "halting_record" in d
assert "run_ids" in d