feat(prompts): implement Phase 7 - Quality & Validation (FR-9, FR-10)
Some checks failed
Test Suite / unit-tests (3.11) (push) Has been cancelled
Test Suite / unit-tests (3.12) (push) Has been cancelled
Test Suite / integration-tests (push) Has been cancelled
Test Suite / e2e-tests (push) Has been cancelled
Test Suite / performance-tests (push) Has been cancelled
Test Suite / code-quality (push) Has been cancelled
Test Suite / security-scan (push) Has been cancelled
Test Suite / test-summary (push) Has been cancelled
Some checks failed
Test Suite / unit-tests (3.11) (push) Has been cancelled
Test Suite / unit-tests (3.12) (push) Has been cancelled
Test Suite / integration-tests (push) Has been cancelled
Test Suite / e2e-tests (push) Has been cancelled
Test Suite / performance-tests (push) Has been cancelled
Test Suite / code-quality (push) Has been cancelled
Test Suite / security-scan (push) Has been cancelled
Test Suite / test-summary (push) Has been cancelled
Add quality gate framework with schema validation (JSON Schema via jsonschema library), pattern validation (regex-based), multi-gate QualityValidator with SQLite persistence, HaltingPolicyEngine with budget/iteration/improvement checks, and RefinementLoop for iterative execute-validate-halt cycles. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
53
markitect/prompts/quality/__init__.py
Normal file
53
markitect/prompts/quality/__init__.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
Quality validation and halting policies for prompt artifacts.
|
||||
|
||||
Implements FR-9: QualityGate Validation
|
||||
Implements FR-10: Halting and Refinement Policy
|
||||
|
||||
- FR-9.1: Schema validation against generated artifacts
|
||||
- FR-9.2: Multiple QualityGates per artifact
|
||||
- FR-9.3: Record pass/fail results and diagnostics
|
||||
- FR-9.4: Halting policies based on QualityGate results
|
||||
- FR-10.1: Configurable QualityPolicies
|
||||
- FR-10.2: Halting decisions (quality, improvement, iterations, budget)
|
||||
- FR-10.3: Record halting decisions in RunManifest
|
||||
"""
|
||||
|
||||
from markitect.prompts.quality.models import (
|
||||
GateType,
|
||||
ValidationStatus,
|
||||
HaltDecision,
|
||||
ValidationDiagnostic,
|
||||
ValidationResult,
|
||||
QualityGate,
|
||||
QualityPolicy,
|
||||
HaltingRecord,
|
||||
RefinementResult,
|
||||
)
|
||||
from markitect.prompts.quality.gates.schema_gate import SchemaValidationGate
|
||||
from markitect.prompts.quality.gates.pattern_gate import PatternValidationGate
|
||||
from markitect.prompts.quality.validator import QualityValidator
|
||||
from markitect.prompts.quality.policy import HaltingPolicyEngine
|
||||
from markitect.prompts.quality.refinement import RefinementLoop
|
||||
|
||||
__all__ = [
|
||||
# Models
|
||||
"GateType",
|
||||
"ValidationStatus",
|
||||
"HaltDecision",
|
||||
"ValidationDiagnostic",
|
||||
"ValidationResult",
|
||||
"QualityGate",
|
||||
"QualityPolicy",
|
||||
"HaltingRecord",
|
||||
"RefinementResult",
|
||||
# Gates
|
||||
"SchemaValidationGate",
|
||||
"PatternValidationGate",
|
||||
# Validator
|
||||
"QualityValidator",
|
||||
# Policy
|
||||
"HaltingPolicyEngine",
|
||||
# Refinement
|
||||
"RefinementLoop",
|
||||
]
|
||||
13
markitect/prompts/quality/gates/__init__.py
Normal file
13
markitect/prompts/quality/gates/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Quality gate implementations.
|
||||
|
||||
Provides SchemaValidationGate and PatternValidationGate.
|
||||
"""
|
||||
|
||||
from markitect.prompts.quality.gates.schema_gate import SchemaValidationGate
|
||||
from markitect.prompts.quality.gates.pattern_gate import PatternValidationGate
|
||||
|
||||
__all__ = [
|
||||
"SchemaValidationGate",
|
||||
"PatternValidationGate",
|
||||
]
|
||||
109
markitect/prompts/quality/gates/pattern_gate.py
Normal file
109
markitect/prompts/quality/gates/pattern_gate.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
Pattern validation quality gate.
|
||||
|
||||
Validates content against required and forbidden regex patterns.
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from markitect.prompts.quality.models import (
|
||||
GateType,
|
||||
QualityGate,
|
||||
ValidationDiagnostic,
|
||||
ValidationResult,
|
||||
ValidationStatus,
|
||||
)
|
||||
|
||||
|
||||
class PatternValidationGate(QualityGate):
    """
    Quality gate that checks artifact text with regular expressions.

    The gate passes when every required pattern matches somewhere in the
    content and none of the forbidden patterns match.
    """

    def __init__(
        self,
        required_patterns: Optional[List[str]] = None,
        forbidden_patterns: Optional[List[str]] = None,
        gate_id: Optional[str] = None,
        name: str = "pattern",
    ):
        """
        Set up the gate with its pattern lists.

        Args:
            required_patterns: Regex patterns that must each match the content
            forbidden_patterns: Regex patterns that must not match the content
            gate_id: Gate identifier; a random UUID string is used when omitted
            name: Human-readable gate name
        """
        identifier = gate_id or str(uuid.uuid4())
        super().__init__(
            gate_id=identifier,
            name=name,
            gate_type=GateType.PATTERN,
        )
        self.required_patterns = required_patterns or []
        self.forbidden_patterns = forbidden_patterns or []

    def validate(self, content: str, artifact_id: str) -> ValidationResult:
        """
        Run every configured pattern against the content.

        Each pattern is applied with ``re.search``. One error diagnostic is
        produced per missing required pattern and per matching forbidden
        pattern. With no failures the result is PASS with score 1.0;
        otherwise it is FAIL and the score is the fraction of checks that
        succeeded.

        Args:
            content: Content string to validate
            artifact_id: ID of the artifact being validated

        Returns:
            ValidationResult with status, score and diagnostics
        """
        # Required patterns are evaluated first, then forbidden ones.
        problems = [
            ValidationDiagnostic(
                code="MISSING_PATTERN",
                message=f"Required pattern not found: {pattern}",
                severity="error",
            )
            for pattern in self.required_patterns
            if not re.search(pattern, content)
        ]

        for pattern in self.forbidden_patterns:
            hit = re.search(pattern, content)
            if not hit:
                continue
            problems.append(
                ValidationDiagnostic(
                    code="FORBIDDEN_PATTERN",
                    message=f"Forbidden pattern found: {pattern} (matched: '{hit.group()}')",
                    severity="error",
                )
            )

        failure_count = len(problems)
        if failure_count:
            # A failure implies at least one configured pattern, so the
            # divisor below is never zero.
            check_count = len(self.required_patterns) + len(self.forbidden_patterns)
            outcome = ValidationStatus.FAIL
            quality = max(0.0, 1.0 - failure_count / check_count)
        else:
            # Also covers a gate with no patterns configured at all.
            outcome = ValidationStatus.PASS
            quality = 1.0

        return ValidationResult.create(
            gate_id=self.id,
            gate_type=self.gate_type,
            artifact_id=artifact_id,
            status=outcome,
            score=quality,
            diagnostics=problems,
        )
|
||||
123
markitect/prompts/quality/gates/schema_gate.py
Normal file
123
markitect/prompts/quality/gates/schema_gate.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Schema validation quality gate.
|
||||
|
||||
Implements FR-9.1: Validate generated artifacts against JSON schemas.
|
||||
Uses the jsonschema library for validation.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import jsonschema
|
||||
|
||||
from markitect.prompts.quality.models import (
|
||||
GateType,
|
||||
QualityGate,
|
||||
ValidationDiagnostic,
|
||||
ValidationResult,
|
||||
ValidationStatus,
|
||||
)
|
||||
|
||||
|
||||
class SchemaValidationGate(QualityGate):
    """
    Quality gate that checks JSON content against a JSON Schema.

    The content is decoded with ``json.loads`` and then checked with
    ``jsonschema.Draft7Validator``.
    """

    def __init__(
        self,
        schema: Dict[str, Any],
        gate_id: Optional[str] = None,
        name: str = "schema",
    ):
        """
        Set up the gate with the schema to enforce.

        Args:
            schema: JSON Schema dictionary
            gate_id: Gate identifier; a random UUID string is used when omitted
            name: Human-readable gate name
        """
        identifier = gate_id or str(uuid.uuid4())
        super().__init__(
            gate_id=identifier,
            name=name,
            gate_type=GateType.SCHEMA,
        )
        self.schema = schema

    def _result(self, artifact_id, status, score, diagnostics):
        """Build a ValidationResult stamped with this gate's id and type."""
        return ValidationResult.create(
            gate_id=self.id,
            gate_type=self.gate_type,
            artifact_id=artifact_id,
            status=status,
            score=score,
            diagnostics=diagnostics,
        )

    def validate(self, content: str, artifact_id: str) -> ValidationResult:
        """
        Decode the content as JSON and check it against the schema.

        Content that cannot be decoded yields FAIL with score 0.0 and a
        single INVALID_JSON diagnostic. Otherwise every schema error becomes
        a SCHEMA_VIOLATION diagnostic; no errors means PASS with score 1.0.

        Args:
            content: JSON content string to validate
            artifact_id: ID of the artifact being validated

        Returns:
            ValidationResult with status, score and diagnostics
        """
        try:
            document = json.loads(content)
        except (json.JSONDecodeError, TypeError) as e:
            parse_failure = ValidationDiagnostic(
                code="INVALID_JSON",
                message=f"Content is not valid JSON: {e}",
                severity="error",
            )
            return self._result(
                artifact_id, ValidationStatus.FAIL, 0.0, [parse_failure]
            )

        violations = list(jsonschema.Draft7Validator(self.schema).iter_errors(document))
        if not violations:
            return self._result(artifact_id, ValidationStatus.PASS, 1.0, [])

        findings = []
        for violation in violations:
            # Errors located at the document root have an empty path.
            location = ".".join(str(part) for part in violation.absolute_path) or "(root)"
            findings.append(
                ValidationDiagnostic(
                    code="SCHEMA_VIOLATION",
                    message=f"At '{location}': {violation.message}",
                    severity="error",
                )
            )

        # The number of checks is not known, so it is approximated as
        # "errors + 1"; the score therefore shrinks as errors accumulate but
        # never reaches zero.
        violation_count = len(violations)
        approx_checks = violation_count + 1
        quality = max(0.0, 1.0 - violation_count / approx_checks)

        return self._result(artifact_id, ValidationStatus.FAIL, quality, findings)
|
||||
283
markitect/prompts/quality/models.py
Normal file
283
markitect/prompts/quality/models.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Data models for quality validation and halting policies.
|
||||
|
||||
Implements FR-9: QualityGate Validation
|
||||
Implements FR-10: Halting and Refinement Policy
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class GateType(Enum):
    """Kinds of quality gate known to the framework."""

    # Structural validation against a schema.
    SCHEMA = "schema"
    # Regex-based content checks.
    PATTERN = "pattern"
    # Any other gate implementation.
    CUSTOM = "custom"
|
||||
|
||||
|
||||
class ValidationStatus(Enum):
    """Possible outcomes of running one quality gate."""

    PASS = "pass"
    FAIL = "fail"
    WARNING = "warning"
    SKIPPED = "skipped"
|
||||
|
||||
|
||||
class HaltDecision(Enum):
    """Verdicts a halting policy can return for a refinement iteration."""

    # Keep refining.
    CONTINUE = "continue"
    # Stop: the quality requirements are satisfied.
    HALT_QUALITY_MET = "halted_quality_met"
    # Stop: the maximum number of iterations was reached.
    HALT_ITERATION_LIMIT = "halted_iteration_limit"
    # Stop: the run budget is used up.
    HALT_BUDGET_EXHAUSTED = "halted_budget_exhausted"
    # Stop: the score is no longer improving enough.
    HALT_NO_IMPROVEMENT = "halted_no_improvement"
|
||||
|
||||
|
||||
@dataclass
class ValidationDiagnostic:
    """
    One message emitted by a quality gate.

    Attributes:
        code: Machine-readable diagnostic code
        message: Human-readable description
        severity: Severity level (error, warning, info); defaults to "error"
    """

    code: str
    message: str
    severity: str = "error"

    def to_dict(self) -> Dict[str, Any]:
        """Return the diagnostic as a plain dictionary."""
        return dict(code=self.code, message=self.message, severity=self.severity)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ValidationDiagnostic":
        """Build a diagnostic from a dictionary; a missing severity means "error"."""
        code = data["code"]
        message = data["message"]
        return cls(code, message, data.get("severity", "error"))
|
||||
|
||||
|
||||
@dataclass
class ValidationResult:
    """
    Outcome of running one quality gate against one artifact.

    Implements FR-9.3: Record pass/fail results and diagnostics.

    Attributes:
        id: Unique result identifier
        gate_id: ID of the quality gate that produced this result
        gate_type: Type of the quality gate
        artifact_id: ID of the validated artifact
        status: Outcome status of the check
        score: Optional quality score (0.0-1.0)
        diagnostics: Diagnostic messages attached to the result
        validated_at: Timestamp of the validation; defaults to
            ``datetime.utcnow()`` at construction
    """

    id: str
    gate_id: str
    gate_type: GateType
    artifact_id: str
    status: ValidationStatus
    score: Optional[float] = None
    diagnostics: List[ValidationDiagnostic] = field(default_factory=list)
    validated_at: datetime = field(default_factory=datetime.utcnow)

    @classmethod
    def create(
        cls,
        gate_id: str,
        gate_type: GateType,
        artifact_id: str,
        status: ValidationStatus,
        score: Optional[float] = None,
        diagnostics: Optional[List[ValidationDiagnostic]] = None,
    ) -> "ValidationResult":
        """Build a result with a freshly generated UUID as its ID."""
        new_id = str(uuid.uuid4())
        attached = diagnostics or []
        return cls(
            id=new_id,
            gate_id=gate_id,
            gate_type=gate_type,
            artifact_id=artifact_id,
            status=status,
            score=score,
            diagnostics=attached,
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a dictionary: enums as values, timestamp as ISO text."""
        payload: Dict[str, Any] = {}
        payload["id"] = self.id
        payload["gate_id"] = self.gate_id
        payload["gate_type"] = self.gate_type.value
        payload["artifact_id"] = self.artifact_id
        payload["status"] = self.status.value
        payload["score"] = self.score
        payload["diagnostics"] = [item.to_dict() for item in self.diagnostics]
        payload["validated_at"] = self.validated_at.isoformat()
        return payload

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ValidationResult":
        """Rebuild a result from a dictionary shaped like ``to_dict`` output."""
        return cls(
            id=data["id"],
            gate_id=data["gate_id"],
            gate_type=GateType(data["gate_type"]),
            artifact_id=data["artifact_id"],
            status=ValidationStatus(data["status"]),
            score=data.get("score"),
            diagnostics=list(
                map(ValidationDiagnostic.from_dict, data.get("diagnostics", []))
            ),
            validated_at=datetime.fromisoformat(data["validated_at"]),
        )
|
||||
|
||||
|
||||
class QualityGate(ABC):
    """
    Base class every quality gate derives from.

    Implements FR-9.1/FR-9.2: a pluggable validation framework in which
    several gates can be applied to the same artifact.
    """

    def __init__(self, gate_id: str, name: str, gate_type: GateType):
        """Store the gate's identifier, display name and type."""
        self.id, self.name, self.gate_type = gate_id, name, gate_type

    @abstractmethod
    def validate(self, content: str, artifact_id: str) -> ValidationResult:
        """
        Check the content against this gate.

        Subclasses must override this method.

        Args:
            content: Content to validate
            artifact_id: ID of the artifact being validated

        Returns:
            ValidationResult with status and diagnostics
        """
|
||||
|
||||
|
||||
@dataclass
class QualityPolicy:
    """
    Configuration for halting and refinement policy.

    Implements FR-10.1: Configurable QualityPolicies.

    Attributes:
        max_iterations: Maximum refinement iterations
        min_improvement: Minimum score improvement to continue
        fail_on_gate_failure: Whether any gate failure halts execution
        resource_budget: Maximum total runs allowed
        required_gate_ids: Gate IDs that must pass for quality to be met
    """

    max_iterations: int = 3
    min_improvement: float = 0.05
    fail_on_gate_failure: bool = True
    resource_budget: int = 10
    required_gate_ids: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """
        Convert to dictionary.

        The gate-ID list is copied so that editing the returned dictionary
        cannot silently change this policy.
        """
        return {
            "max_iterations": self.max_iterations,
            "min_improvement": self.min_improvement,
            "fail_on_gate_failure": self.fail_on_gate_failure,
            "resource_budget": self.resource_budget,
            "required_gate_ids": list(self.required_gate_ids or []),
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "QualityPolicy":
        """
        Create from dictionary.

        Missing keys fall back to the field defaults. The gate-ID list is
        copied (and an explicit ``None`` is treated as empty) so the policy
        never shares a mutable list with the input dictionary.
        """
        return cls(
            max_iterations=data.get("max_iterations", 3),
            min_improvement=data.get("min_improvement", 0.05),
            fail_on_gate_failure=data.get("fail_on_gate_failure", True),
            resource_budget=data.get("resource_budget", 10),
            required_gate_ids=list(data.get("required_gate_ids") or []),
        )
|
||||
|
||||
|
||||
@dataclass
class HaltingRecord:
    """
    Snapshot of one halting-policy decision.

    Implements FR-10.3: Record halting decisions in the RunManifest.

    Attributes:
        decision: The halting decision
        iteration: Iteration number the decision was made for
        max_iterations: Maximum allowed iterations
        scores: Score history across iterations
        reason: Human-readable reason for the decision
        recorded_at: Timestamp of the decision; defaults to
            ``datetime.utcnow()`` at construction
    """

    decision: HaltDecision
    iteration: int
    max_iterations: int
    scores: List[float] = field(default_factory=list)
    reason: str = ""
    recorded_at: datetime = field(default_factory=datetime.utcnow)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a dictionary: decision as its value, timestamp as ISO text."""
        return dict(
            decision=self.decision.value,
            iteration=self.iteration,
            max_iterations=self.max_iterations,
            scores=self.scores,
            reason=self.reason,
            recorded_at=self.recorded_at.isoformat(),
        )
|
||||
|
||||
|
||||
@dataclass
class RefinementResult:
    """
    Summary of a complete refinement loop run.

    Attributes:
        iterations_run: Number of iterations executed
        final_results: Validation results from the last iteration
        halting_record: Record of the halting decision, if any
        all_results: Validation results of every iteration, in order
        run_ids: Run IDs produced during refinement
    """

    iterations_run: int
    final_results: List[ValidationResult] = field(default_factory=list)
    halting_record: Optional[HaltingRecord] = None
    all_results: List[List[ValidationResult]] = field(default_factory=list)
    run_ids: List[str] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a dictionary, converting nested results and the halting record."""
        final_payload = [entry.to_dict() for entry in self.final_results]
        if self.halting_record:
            halting_payload = self.halting_record.to_dict()
        else:
            halting_payload = None
        history_payload = [
            [entry.to_dict() for entry in per_iteration]
            for per_iteration in self.all_results
        ]
        return {
            "iterations_run": self.iterations_run,
            "final_results": final_payload,
            "halting_record": halting_payload,
            "all_results": history_payload,
            "run_ids": self.run_ids,
        }
|
||||
153
markitect/prompts/quality/policy.py
Normal file
153
markitect/prompts/quality/policy.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
Halting policy engine for refinement control.
|
||||
|
||||
Implements FR-10: Halting and Refinement Policy
|
||||
Evaluates halting conditions based on quality gate results,
|
||||
marginal improvement, iteration limits, and resource budgets.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from markitect.prompts.quality.models import (
|
||||
HaltDecision,
|
||||
HaltingRecord,
|
||||
QualityPolicy,
|
||||
ValidationResult,
|
||||
ValidationStatus,
|
||||
)
|
||||
|
||||
|
||||
class HaltingPolicyEngine:
    """
    Evaluates halting decisions based on configurable policies.

    Implements FR-10.2: Evaluate halting based on quality gate results,
    marginal improvement metrics, iteration limits, and resource budgets.

    Implements FR-10.3: Record halting decisions.
    """

    def __init__(self, policy: QualityPolicy):
        """
        Initialize with a quality policy.

        Args:
            policy: Quality policy configuration
        """
        self.policy = policy

    def evaluate(
        self,
        results: List[ValidationResult],
        iteration: int,
        score_history: Optional[List[float]] = None,
        total_runs: int = 0,
    ) -> HaltingRecord:
        """
        Evaluate whether to halt or continue refinement.

        Checks in order:
        1. Resource budget exhaustion
        2. Iteration limit
        3. Quality met (all required gates pass)
        4. Marginal improvement below threshold

        Because the iteration limit is checked before quality, a final
        iteration whose gates all pass is still recorded as
        HALT_ITERATION_LIMIT.

        Args:
            results: Validation results from current iteration
            iteration: Current iteration number (1-based)
            score_history: Aggregate scores from previous iterations
            total_runs: Total number of runs consumed

        Returns:
            HaltingRecord with the decision; its ``scores`` are the previous
            scores followed by the current aggregate score
        """
        policy = self.policy
        current_score = self._aggregate_score(results)
        all_scores = (score_history or []) + [current_score]

        def record(decision: HaltDecision, reason: str) -> HaltingRecord:
            # Every outcome shares the same iteration context and score list.
            return HaltingRecord(
                decision=decision,
                iteration=iteration,
                max_iterations=policy.max_iterations,
                scores=all_scores,
                reason=reason,
            )

        # Check resource budget
        if total_runs >= policy.resource_budget:
            return record(
                HaltDecision.HALT_BUDGET_EXHAUSTED,
                f"Resource budget exhausted: {total_runs}/{policy.resource_budget} runs used",
            )

        # Check iteration limit
        if iteration >= policy.max_iterations:
            return record(
                HaltDecision.HALT_ITERATION_LIMIT,
                f"Iteration limit reached: {iteration}/{policy.max_iterations}",
            )

        # Check if quality is met
        if self._quality_met(results):
            return record(HaltDecision.HALT_QUALITY_MET, "All quality gates passed")

        # Check marginal improvement (needs at least one earlier score)
        if len(all_scores) >= 2:
            improvement = all_scores[-1] - all_scores[-2]
            if improvement < policy.min_improvement:
                return record(
                    HaltDecision.HALT_NO_IMPROVEMENT,
                    (
                        f"Marginal improvement {improvement:.4f} below "
                        f"threshold {policy.min_improvement}"
                    ),
                )

        # Continue refinement
        return record(HaltDecision.CONTINUE, "Continuing refinement")

    def _quality_met(self, results: List[ValidationResult]) -> bool:
        """
        Check if quality requirements are met.

        If required_gate_ids is set, only those gates must pass: each one
        needs at least one result, and every result it produced must be
        PASS. Otherwise, all gates must pass.

        Both branches use the same definition of "pass" (status PASS), so a
        required gate that reports WARNING, SKIPPED or FAIL does not satisfy
        the policy.

        Args:
            results: Validation results to check

        Returns:
            True if quality requirements are met
        """
        required_ids = self.policy.required_gate_ids
        if not required_ids:
            return all(r.status == ValidationStatus.PASS for r in results)

        for gate_id in required_ids:
            statuses = [r.status for r in results if r.gate_id == gate_id]
            # A required gate that produced no result cannot count as passed.
            if not statuses:
                return False
            if any(status != ValidationStatus.PASS for status in statuses):
                return False
        return True

    def _aggregate_score(self, results: List[ValidationResult]) -> float:
        """
        Calculate aggregate score from results.

        Returns the mean of all non-None scores, or 0.0 when there are no
        results or none of them carries a score.
        """
        if not results:
            return 0.0
        scores = [r.score for r in results if r.score is not None]
        if not scores:
            return 0.0
        return sum(scores) / len(scores)
|
||||
108
markitect/prompts/quality/refinement.py
Normal file
108
markitect/prompts/quality/refinement.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Refinement loop for iterative quality improvement.
|
||||
|
||||
Implements FR-10: Halting and Refinement Policy
|
||||
Execute → Validate → Halt or Refine cycle.
|
||||
"""
|
||||
|
||||
from typing import Callable, List, Optional, Tuple
|
||||
|
||||
from markitect.prompts.quality.models import (
|
||||
HaltDecision,
|
||||
QualityPolicy,
|
||||
RefinementResult,
|
||||
ValidationResult,
|
||||
)
|
||||
from markitect.prompts.quality.policy import HaltingPolicyEngine
|
||||
from markitect.prompts.quality.validator import QualityValidator
|
||||
|
||||
|
||||
class RefinementLoop:
    """
    Drives repeated execute -> validate -> halt-check rounds.

    Each round runs the caller's callback, validates what it produced and
    asks the halting policy whether to stop; the loop ends on the first
    decision other than CONTINUE.
    """

    def __init__(
        self,
        validator: QualityValidator,
        policy: QualityPolicy,
    ):
        """
        Store the collaborators and build the policy engine.

        Args:
            validator: Quality validator with configured gates
            policy: Halting policy configuration
        """
        self.policy = policy
        self.validator = validator
        self.policy_engine = HaltingPolicyEngine(policy)

    def run(
        self,
        execution_callback: Callable[[int, List[ValidationResult]], Tuple[str, str, str]],
        artifact_id: str,
    ) -> RefinementResult:
        """
        Run refinement rounds until the policy says to stop.

        The execution_callback receives:
        - the iteration number (1-based)
        - the previous round's validation results (empty list on the first round)

        and must return a tuple of (run_id, content, artifact_id). The
        artifact ID returned by the callback is the one passed to the
        validator; the ``artifact_id`` argument of this method is not read.

        Args:
            execution_callback: Callable that executes/refines and returns
                (run_id, content, artifact_id)
            artifact_id: ID of the artifact being refined

        Returns:
            RefinementResult with complete iteration history
        """
        outcome = RefinementResult(iterations_run=0)
        past_scores: List[float] = []
        last_results: List[ValidationResult] = []

        for round_number in range(1, self.policy.max_iterations + 1):
            # Execute / refine
            run_id, content, art_id = execution_callback(round_number, last_results)
            outcome.run_ids.append(run_id)

            # Validate; the run ID is only forwarded when a database is configured
            persisted_run_id = run_id if self.validator.db_path else None
            round_results = self.validator.validate_artifact(
                content, art_id, run_id=persisted_run_id,
            )
            outcome.all_results.append(round_results)
            outcome.iterations_run = round_number

            # Evaluate halting against the scores of the earlier rounds only
            verdict = self.policy_engine.evaluate(
                results=round_results,
                iteration=round_number,
                score_history=past_scores,
                total_runs=len(outcome.run_ids),
            )
            past_scores.append(self.policy_engine._aggregate_score(round_results))

            if verdict.decision == HaltDecision.CONTINUE:
                last_results = round_results
                continue

            outcome.final_results = round_results
            outcome.halting_record = verdict
            return outcome

        # The loop finished without any round returning a halting record
        outcome.final_results = last_results
        outcome.halting_record = self.policy_engine.evaluate(
            results=last_results,
            iteration=self.policy.max_iterations,
            score_history=past_scores,
            total_runs=len(outcome.run_ids),
        )
        return outcome
|
||||
294
markitect/prompts/quality/validator.py
Normal file
294
markitect/prompts/quality/validator.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""
|
||||
Quality validator for applying multiple gates to artifacts.
|
||||
|
||||
Implements FR-9.2: Multiple QualityGates per artifact.
|
||||
Implements FR-9.3: Record pass/fail results and diagnostics in RunManifest.
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from markitect.prompts.quality.models import (
|
||||
GateType,
|
||||
QualityGate,
|
||||
ValidationResult,
|
||||
ValidationStatus,
|
||||
)
|
||||
|
||||
|
||||
# SQL schema for quality tables
|
||||
QUALITY_TABLES_SQL = """
|
||||
CREATE TABLE IF NOT EXISTS quality_gates (
|
||||
id TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
gate_type TEXT NOT NULL,
|
||||
config JSON NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS validation_results (
|
||||
id TEXT PRIMARY KEY,
|
||||
run_id TEXT NOT NULL,
|
||||
gate_id TEXT NOT NULL,
|
||||
artifact_id TEXT,
|
||||
status TEXT NOT NULL,
|
||||
score REAL,
|
||||
diagnostics JSON,
|
||||
validated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_validations_run ON validation_results(run_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_validations_artifact ON validation_results(artifact_id);
|
||||
"""
|
||||
|
||||
|
||||
class QualityValidator:
|
||||
"""
|
||||
Applies multiple quality gates to artifacts and records results.
|
||||
|
||||
Implements FR-9.2 and FR-9.3.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    gates: Optional[List[QualityGate]] = None,
    db_path: Optional[str] = None,
):
    """
    Set up the validator.

    When a database path is given, the quality tables are created
    immediately.

    Args:
        gates: List of quality gates to apply
        db_path: Optional database path for persisting results
    """
    self.db_path = db_path
    self.gates: List[QualityGate] = gates or []
    if db_path:
        self._initialize_tables()
|
||||
|
||||
def _initialize_tables(self) -> None:
    """Create the database's parent directory if needed, then run the quality-table DDL."""
    parent_dir = Path(self.db_path).parent
    # A Path object is always truthy, so only the existence test matters.
    if not parent_dir.exists():
        parent_dir.mkdir(parents=True, exist_ok=True)

    connection = sqlite3.connect(self.db_path)
    try:
        connection.executescript(QUALITY_TABLES_SQL)
        connection.commit()
    finally:
        # Close the connection even if the script fails.
        connection.close()
|
||||
|
||||
def _get_connection(self) -> sqlite3.Connection:
|
||||
"""Get a database connection."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
def add_gate(self, gate: QualityGate) -> None:
    """
    Append a gate to the validator's gate list.

    Args:
        gate: Quality gate to add
    """
    self.gates.append(gate)
|
||||
|
||||
def validate_artifact(
|
||||
self,
|
||||
content: str,
|
||||
artifact_id: str,
|
||||
run_id: Optional[str] = None,
|
||||
) -> List[ValidationResult]:
|
||||
"""
|
||||
Apply all quality gates to an artifact.
|
||||
|
||||
Args:
|
||||
content: Artifact content to validate
|
||||
artifact_id: ID of the artifact
|
||||
run_id: Optional run ID for persistence
|
||||
|
||||
Returns:
|
||||
List of ValidationResult from all gates
|
||||
"""
|
||||
results = []
|
||||
for gate in self.gates:
|
||||
result = gate.validate(content, artifact_id)
|
||||
results.append(result)
|
||||
|
||||
if run_id and self.db_path:
|
||||
self._persist_result(result, run_id)
|
||||
|
||||
return results
|
||||
|
||||
def all_passed(self, results: List[ValidationResult]) -> bool:
|
||||
"""
|
||||
Check if all validation results passed.
|
||||
|
||||
Args:
|
||||
results: List of validation results
|
||||
|
||||
Returns:
|
||||
True if all results have PASS status
|
||||
"""
|
||||
return all(r.status == ValidationStatus.PASS for r in results)
|
||||
|
||||
def aggregate_score(self, results: List[ValidationResult]) -> float:
|
||||
"""
|
||||
Calculate aggregate score across all results.
|
||||
|
||||
Args:
|
||||
results: List of validation results
|
||||
|
||||
Returns:
|
||||
Average score (0.0-1.0), or 1.0 if no results
|
||||
"""
|
||||
if not results:
|
||||
return 1.0
|
||||
scores = [r.score for r in results if r.score is not None]
|
||||
if not scores:
|
||||
return 1.0
|
||||
return sum(scores) / len(scores)
|
||||
|
||||
def get_failed_gates(
|
||||
self,
|
||||
results: List[ValidationResult],
|
||||
) -> List[ValidationResult]:
|
||||
"""
|
||||
Get only failed validation results.
|
||||
|
||||
Args:
|
||||
results: List of validation results
|
||||
|
||||
Returns:
|
||||
List of results with FAIL status
|
||||
"""
|
||||
return [r for r in results if r.status == ValidationStatus.FAIL]
|
||||
|
||||
def results_to_manifest_dict(
|
||||
self,
|
||||
results: List[ValidationResult],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert validation results to RunManifest-compatible dict.
|
||||
|
||||
Implements FR-9.3: Results for RunManifest.
|
||||
|
||||
Args:
|
||||
results: List of validation results
|
||||
|
||||
Returns:
|
||||
Dictionary suitable for RunManifest.validation_results
|
||||
"""
|
||||
return {
|
||||
"quality_gates": [r.to_dict() for r in results],
|
||||
"all_passed": self.all_passed(results),
|
||||
"aggregate_score": self.aggregate_score(results),
|
||||
}
|
||||
|
||||
def _persist_result(
|
||||
self,
|
||||
result: ValidationResult,
|
||||
run_id: str,
|
||||
) -> None:
|
||||
"""Persist a validation result to the database."""
|
||||
import json
|
||||
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO validation_results (
|
||||
id, run_id, gate_id, artifact_id,
|
||||
status, score, diagnostics, validated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
result.id,
|
||||
run_id,
|
||||
result.gate_id,
|
||||
result.artifact_id,
|
||||
result.status.value,
|
||||
result.score,
|
||||
json.dumps([d.to_dict() for d in result.diagnostics]),
|
||||
result.validated_at.isoformat(),
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_results_for_run(self, run_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve persisted validation results for a run.
|
||||
|
||||
Args:
|
||||
run_id: Run identifier
|
||||
|
||||
Returns:
|
||||
List of result dictionaries
|
||||
"""
|
||||
if not self.db_path:
|
||||
return []
|
||||
|
||||
import json
|
||||
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
cursor = conn.execute(
|
||||
"SELECT * FROM validation_results WHERE run_id = ?",
|
||||
(run_id,),
|
||||
)
|
||||
results = []
|
||||
for row in cursor.fetchall():
|
||||
results.append({
|
||||
"id": row["id"],
|
||||
"run_id": run_id,
|
||||
"gate_id": row["gate_id"],
|
||||
"artifact_id": row["artifact_id"],
|
||||
"status": row["status"],
|
||||
"score": row["score"],
|
||||
"diagnostics": json.loads(row["diagnostics"]) if row["diagnostics"] else [],
|
||||
"validated_at": row["validated_at"],
|
||||
})
|
||||
return results
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def get_results_for_artifact(self, artifact_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Retrieve persisted validation results for an artifact.
|
||||
|
||||
Args:
|
||||
artifact_id: Artifact identifier
|
||||
|
||||
Returns:
|
||||
List of result dictionaries
|
||||
"""
|
||||
if not self.db_path:
|
||||
return []
|
||||
|
||||
import json
|
||||
|
||||
conn = self._get_connection()
|
||||
try:
|
||||
cursor = conn.execute(
|
||||
"SELECT * FROM validation_results WHERE artifact_id = ?",
|
||||
(artifact_id,),
|
||||
)
|
||||
results = []
|
||||
for row in cursor.fetchall():
|
||||
results.append({
|
||||
"id": row["id"],
|
||||
"run_id": row["run_id"],
|
||||
"gate_id": row["gate_id"],
|
||||
"artifact_id": artifact_id,
|
||||
"status": row["status"],
|
||||
"score": row["score"],
|
||||
"diagnostics": json.loads(row["diagnostics"]) if row["diagnostics"] else [],
|
||||
"validated_at": row["validated_at"],
|
||||
})
|
||||
return results
|
||||
finally:
|
||||
conn.close()
|
||||
25
migrations/prompts/006_create_quality_tables.sql
Normal file
25
migrations/prompts/006_create_quality_tables.sql
Normal file
@@ -0,0 +1,25 @@
|
||||
-- Phase 7: Quality & Validation tables
--   quality_gates      : registered quality gate configurations
--   validation_results : recorded pass/fail results from gate checks

CREATE TABLE IF NOT EXISTS quality_gates (
    id         TEXT PRIMARY KEY,
    name       TEXT NOT NULL,
    gate_type  TEXT NOT NULL,
    config     JSON NOT NULL,
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

CREATE TABLE IF NOT EXISTS validation_results (
    id           TEXT PRIMARY KEY,
    run_id       TEXT NOT NULL,
    gate_id      TEXT NOT NULL,
    artifact_id  TEXT,
    status       TEXT NOT NULL,
    score        REAL,
    diagnostics  JSON,
    validated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

-- Lookup indexes for the two access paths: by run and by artifact.
CREATE INDEX IF NOT EXISTS idx_validations_run ON validation_results(run_id);
CREATE INDEX IF NOT EXISTS idx_validations_artifact ON validation_results(artifact_id);
|
||||
239
tests/integration/prompts/test_halting_execution.py
Normal file
239
tests/integration/prompts/test_halting_execution.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""
|
||||
Integration tests for halting execution with refinement loop.
|
||||
|
||||
Tests the full execute → validate → halt or refine cycle with
|
||||
real quality gates and persistence.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from markitect.prompts.models import Artifact, ArtifactType
|
||||
from markitect.prompts.repositories.sqlite import SQLiteArtifactRepository
|
||||
from markitect.prompts.quality.models import (
|
||||
HaltDecision,
|
||||
QualityPolicy,
|
||||
ValidationStatus,
|
||||
)
|
||||
from markitect.prompts.quality.gates.pattern_gate import PatternValidationGate
|
||||
from markitect.prompts.quality.gates.schema_gate import SchemaValidationGate
|
||||
from markitect.prompts.quality.validator import QualityValidator
|
||||
from markitect.prompts.quality.refinement import RefinementLoop
|
||||
|
||||
|
||||
@pytest.fixture
def temp_db():
    """Yield the path of a throwaway ``.db`` file and delete it afterwards."""
    handle = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
    handle.close()  # only the path is needed by the tests
    yield handle.name
    Path(handle.name).unlink(missing_ok=True)
|
||||
|
||||
|
||||
@pytest.fixture
def artifact_repo(temp_db):
    """Provide a SQLiteArtifactRepository bound to the temporary database."""
    repository = SQLiteArtifactRepository(temp_db)
    return repository
|
||||
|
||||
|
||||
class TestImmediateQualityMet:
    """Refinement runs whose first output already satisfies the gate."""

    def test_single_iteration_success(self, temp_db):
        """Expect a single iteration and HALT_QUALITY_MET when content passes."""
        passing_content = "## Summary\nOverview.\n## Conclusion\nDone."
        validator = QualityValidator(
            gates=[
                PatternValidationGate(
                    required_patterns=[r"## Summary", r"## Conclusion"],
                    gate_id="gate-1",
                )
            ],
            db_path=temp_db,
        )
        loop = RefinementLoop(validator, QualityPolicy(max_iterations=5))

        def execute(iteration, prev_results):
            return (f"run-{iteration}", passing_content, "art-1")

        outcome = loop.run(execute, "art-1")

        assert outcome.iterations_run == 1
        assert outcome.halting_record.decision == HaltDecision.HALT_QUALITY_MET
        assert len(outcome.final_results) == 1
        assert outcome.final_results[0].status == ValidationStatus.PASS

        # The single gate result must have been written for run-1.
        stored = validator.get_results_for_run("run-1")
        assert len(stored) == 1
|
||||
|
||||
|
||||
class TestIterativeRefinement:
    """Refinement runs whose content gets better with each iteration."""

    def test_progressive_improvement(self, temp_db):
        """Expect the loop to stop at the third version, which passes all patterns."""
        drafts = [
            "## Summary\nBasic.",  # iter 1: missing 2 patterns
            "## Summary\n## Details\nBetter.",  # iter 2: missing 1 pattern
            "## Summary\n## Details\n## Conclusion\nComplete.",  # iter 3: all pass
        ]
        validator = QualityValidator(
            gates=[
                PatternValidationGate(
                    required_patterns=[r"## Summary", r"## Details", r"## Conclusion"],
                    gate_id="gate-1",
                )
            ],
            db_path=temp_db,
        )
        loop = RefinementLoop(validator, QualityPolicy(max_iterations=5))

        def execute(iteration, prev_results):
            # Clamp so iterations past the last draft reuse the final one.
            index = min(iteration - 1, len(drafts) - 1)
            return (f"run-{iteration}", drafts[index], "art-1")

        outcome = loop.run(execute, "art-1")

        assert outcome.iterations_run == 3
        assert outcome.halting_record.decision == HaltDecision.HALT_QUALITY_MET
        assert len(outcome.all_results) == 3

        # Every iteration must have persisted exactly one gate result.
        for iteration in (1, 2, 3):
            stored = validator.get_results_for_run(f"run-{iteration}")
            assert len(stored) == 1
|
||||
|
||||
|
||||
class TestIterationLimit:
    """Refinement runs that end because max_iterations is reached."""

    def test_never_meets_quality(self, temp_db):
        """Expect HALT_ITERATION_LIMIT after 3 runs when the pattern never matches."""
        unreachable_gate = PatternValidationGate(
            required_patterns=[r"NEVER_MATCHES_XYZ123"],
            gate_id="gate-1",
        )
        validator = QualityValidator(gates=[unreachable_gate], db_path=temp_db)
        loop = RefinementLoop(
            validator,
            QualityPolicy(max_iterations=3, min_improvement=0.0),
        )

        def execute(iteration, prev_results):
            return (f"run-{iteration}", "always insufficient", "art-1")

        outcome = loop.run(execute, "art-1")

        assert outcome.iterations_run == 3
        assert outcome.halting_record.decision == HaltDecision.HALT_ITERATION_LIMIT
        assert len(outcome.run_ids) == 3
|
||||
|
||||
|
||||
class TestBudgetExhaustion:
    """Refinement runs that end because the resource budget is used up."""

    def test_budget_limits_iterations(self, temp_db):
        """Expect HALT_BUDGET_EXHAUSTED after 2 runs with resource_budget=2."""
        unreachable_gate = PatternValidationGate(
            required_patterns=[r"UNREACHABLE"],
            gate_id="gate-1",
        )
        validator = QualityValidator(gates=[unreachable_gate], db_path=temp_db)
        loop = RefinementLoop(
            validator,
            QualityPolicy(max_iterations=10, resource_budget=2),
        )

        def execute(iteration, prev_results):
            return (f"run-{iteration}", "content", "art-1")

        outcome = loop.run(execute, "art-1")

        assert outcome.iterations_run == 2
        assert outcome.halting_record.decision == HaltDecision.HALT_BUDGET_EXHAUSTED
|
||||
|
||||
|
||||
class TestMultiGateRefinement:
    """Refinement runs evaluated by more than one quality gate."""

    def test_all_gates_must_pass(self, temp_db):
        """Expect the loop to continue until both gates pass on iteration 2."""
        drafts = [
            "## Summary\nTODO: finish this",  # gate-a pass, gate-b fail
            "## Summary\nAll clean content.",  # both pass
        ]
        requires_summary = PatternValidationGate(
            required_patterns=[r"## Summary"],
            gate_id="gate-a",
        )
        forbids_todo = PatternValidationGate(
            forbidden_patterns=[r"TODO"],
            gate_id="gate-b",
        )
        validator = QualityValidator(
            gates=[requires_summary, forbids_todo],
            db_path=temp_db,
        )
        loop = RefinementLoop(validator, QualityPolicy(max_iterations=5))

        def execute(iteration, prev_results):
            index = min(iteration - 1, len(drafts) - 1)
            return (f"run-{iteration}", drafts[index], "art-1")

        outcome = loop.run(execute, "art-1")

        assert outcome.iterations_run == 2
        assert outcome.halting_record.decision == HaltDecision.HALT_QUALITY_MET
|
||||
|
||||
|
||||
class TestRefinementWithSchemaGate:
    """Refinement runs evaluated by a JSON schema gate."""

    def test_json_refinement(self, temp_db):
        """Expect the loop to stop once the JSON document has all required keys."""
        document_schema = {
            "type": "object",
            "required": ["title", "version", "sections"],
            "properties": {
                "title": {"type": "string"},
                "version": {"type": "integer"},
                "sections": {"type": "array"},
            },
        }
        payloads = [
            {"title": "Doc"},  # missing version & sections
            {"title": "Doc", "version": 1},  # missing sections
            {"title": "Doc", "version": 1, "sections": []},  # complete
        ]
        drafts = [json.dumps(payload) for payload in payloads]

        validator = QualityValidator(
            gates=[SchemaValidationGate(schema=document_schema, gate_id="schema-1")],
            db_path=temp_db,
        )
        loop = RefinementLoop(validator, QualityPolicy(max_iterations=5))

        def execute(iteration, prev_results):
            index = min(iteration - 1, len(drafts) - 1)
            return (f"run-{iteration}", drafts[index], "art-1")

        outcome = loop.run(execute, "art-1")

        assert outcome.iterations_run == 3
        assert outcome.halting_record.decision == HaltDecision.HALT_QUALITY_MET
|
||||
|
||||
|
||||
class TestResultSerialization:
    """Serialization of the value returned by RefinementLoop.run."""

    def test_result_to_dict(self, temp_db):
        """Expect to_dict() to return a dict exposing the core bookkeeping keys."""
        validator = QualityValidator(
            gates=[PatternValidationGate(gate_id="gate-1")],
            db_path=temp_db,
        )
        loop = RefinementLoop(validator, QualityPolicy(max_iterations=1))

        def execute(iteration, prev_results):
            return ("run-1", "content", "art-1")

        serialized = loop.run(execute, "art-1").to_dict()

        assert isinstance(serialized, dict)
        for expected_key in ("iterations_run", "halting_record", "run_ids"):
            assert expected_key in serialized
|
||||
208
tests/integration/prompts/test_quality_validation.py
Normal file
208
tests/integration/prompts/test_quality_validation.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""
|
||||
Integration tests for full quality validation workflow.
|
||||
|
||||
Tests applying quality gates to artifacts with real DB persistence,
|
||||
manifest integration, and multi-gate validation.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from markitect.prompts.models import Artifact, ArtifactType
|
||||
from markitect.prompts.repositories.sqlite import SQLiteArtifactRepository
|
||||
from markitect.prompts.quality.models import (
|
||||
GateType,
|
||||
ValidationStatus,
|
||||
)
|
||||
from markitect.prompts.quality.gates.schema_gate import SchemaValidationGate
|
||||
from markitect.prompts.quality.gates.pattern_gate import PatternValidationGate
|
||||
from markitect.prompts.quality.validator import QualityValidator
|
||||
|
||||
|
||||
@pytest.fixture
def temp_db():
    """Yield the path of a throwaway ``.db`` file and delete it afterwards."""
    handle = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
    handle.close()  # only the path is needed by the tests
    yield handle.name
    Path(handle.name).unlink(missing_ok=True)
|
||||
|
||||
|
||||
@pytest.fixture
def artifact_repo(temp_db):
    """Provide a SQLiteArtifactRepository bound to the temporary database."""
    repository = SQLiteArtifactRepository(temp_db)
    return repository
|
||||
|
||||
|
||||
def _create_artifact(repo, name, content, art_type=ArtifactType.GENERATED):
    """Build an artifact in space "space-1" and store it through ``repo``.

    Returns whatever ``repo.create`` returns for the new artifact.
    """
    new_artifact = Artifact.create(
        space_id="space-1",
        name=name,
        content=content,
        artifact_type=art_type,
    )
    stored = repo.create(new_artifact)
    return stored
|
||||
|
||||
|
||||
class TestSchemaValidationWorkflow:
    """Schema gate validation with results written to a real database."""

    def test_validate_json_artifact_passes(self, temp_db, artifact_repo):
        """Expect a PASS result, also after reading it back from the database."""
        api_schema = {
            "type": "object",
            "required": ["name", "version", "endpoints"],
            "properties": {
                "name": {"type": "string"},
                "version": {"type": "integer"},
                "endpoints": {"type": "array", "items": {"type": "string"}},
            },
        }
        content = json.dumps(
            {
                "name": "API Spec",
                "version": 1,
                "endpoints": ["/users", "/auth"],
            }
        )
        artifact = _create_artifact(artifact_repo, "api-spec", content)
        validator = QualityValidator(
            gates=[SchemaValidationGate(schema=api_schema, gate_id="schema-api")],
            db_path=temp_db,
        )

        results = validator.validate_artifact(content, artifact.id, run_id="run-1")

        assert len(results) == 1
        assert results[0].status == ValidationStatus.PASS

        # Read the row back to confirm persistence.
        stored = validator.get_results_for_run("run-1")
        assert len(stored) == 1
        assert stored[0]["status"] == "pass"

    def test_validate_json_artifact_fails(self, temp_db, artifact_repo):
        """Expect a FAIL result with diagnostics when a required key is missing."""
        strict_schema = {
            "type": "object",
            "required": ["name", "version"],
        }
        content = json.dumps({"name": "Incomplete"})
        artifact = _create_artifact(artifact_repo, "bad-spec", content)
        validator = QualityValidator(
            gates=[SchemaValidationGate(schema=strict_schema, gate_id="schema-strict")],
            db_path=temp_db,
        )

        results = validator.validate_artifact(content, artifact.id, run_id="run-2")

        first = results[0]
        assert first.status == ValidationStatus.FAIL
        assert len(first.diagnostics) > 0

        stored = validator.get_results_for_run("run-2")
        assert stored[0]["status"] == "fail"
|
||||
|
||||
|
||||
class TestPatternValidationWorkflow:
    """Pattern gate validation with results written to a real database."""

    def test_validate_markdown_artifact(self, temp_db, artifact_repo):
        """Expect PASS when required headings exist and no forbidden marker does."""
        content = "# API Documentation\n## Endpoints\n### Authentication\nOAuth2 flow."
        artifact = _create_artifact(artifact_repo, "api-docs", content)
        validator = QualityValidator(
            gates=[
                PatternValidationGate(
                    required_patterns=[r"## Endpoints", r"### Authentication"],
                    forbidden_patterns=[r"TODO", r"FIXME"],
                    gate_id="pattern-api",
                )
            ],
            db_path=temp_db,
        )

        results = validator.validate_artifact(content, artifact.id, run_id="run-3")

        assert results[0].status == ValidationStatus.PASS

    def test_forbidden_pattern_detected(self, temp_db, artifact_repo):
        """Expect FAIL when the content contains a forbidden pattern."""
        content = "# Draft\n## Endpoints\nTODO: Add authentication."
        artifact = _create_artifact(artifact_repo, "draft-docs", content)
        validator = QualityValidator(
            gates=[
                PatternValidationGate(
                    required_patterns=[r"## Endpoints"],
                    forbidden_patterns=[r"TODO"],
                    gate_id="pattern-clean",
                )
            ],
            db_path=temp_db,
        )

        results = validator.validate_artifact(content, artifact.id, run_id="run-4")

        assert results[0].status == ValidationStatus.FAIL
|
||||
|
||||
|
||||
class TestMultiGateWorkflow:
    """Validation runs that apply several gates or span several runs."""

    def test_multi_gate_validation(self, temp_db, artifact_repo):
        """Expect both the schema gate and the pattern gate to pass and persist."""
        content = json.dumps(
            {
                "title": "Design Doc",
                "sections": ["## Overview", "## Details"],
            }
        )
        artifact = _create_artifact(artifact_repo, "design-doc", content)

        gates = [
            SchemaValidationGate(
                schema={
                    "type": "object",
                    "required": ["title", "sections"],
                },
                gate_id="schema-doc",
            ),
            PatternValidationGate(
                forbidden_patterns=[r"FIXME"],
                gate_id="pattern-clean",
            ),
        ]
        validator = QualityValidator(gates=gates, db_path=temp_db)

        results = validator.validate_artifact(content, artifact.id, run_id="run-5")

        assert len(results) == 2
        for gate_result in results:
            assert gate_result.status == ValidationStatus.PASS

        # The manifest summary must agree with the individual results.
        manifest = validator.results_to_manifest_dict(results)
        assert manifest["all_passed"] is True
        assert manifest["aggregate_score"] == 1.0

        # One persisted row per gate.
        stored = validator.get_results_for_run("run-5")
        assert len(stored) == 2

    def test_retrieve_by_artifact(self, temp_db, artifact_repo):
        """Expect one stored result per run when querying by artifact id."""
        content = json.dumps({"name": "test"})
        artifact = _create_artifact(artifact_repo, "test-art", content)
        validator = QualityValidator(
            gates=[
                SchemaValidationGate(
                    schema={"type": "object", "required": ["name"]},
                    gate_id="schema-1",
                )
            ],
            db_path=temp_db,
        )

        # Same artifact validated under two different run ids.
        for run_id in ("run-a", "run-b"):
            validator.validate_artifact(content, artifact.id, run_id=run_id)

        stored = validator.get_results_for_artifact(artifact.id)
        assert len(stored) == 2
|
||||
221
tests/unit/prompts/test_halting_policy.py
Normal file
221
tests/unit/prompts/test_halting_policy.py
Normal file
@@ -0,0 +1,221 @@
|
||||
"""
|
||||
Unit tests for HaltingPolicyEngine.
|
||||
|
||||
Tests halting decisions based on quality results, iteration limits,
|
||||
marginal improvement, and resource budgets.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from markitect.prompts.quality.models import (
|
||||
GateType,
|
||||
HaltDecision,
|
||||
QualityPolicy,
|
||||
ValidationResult,
|
||||
ValidationStatus,
|
||||
)
|
||||
from markitect.prompts.quality.policy import HaltingPolicyEngine
|
||||
|
||||
|
||||
def _make_result(status=ValidationStatus.PASS, score=1.0, gate_id="gate-1"):
    """Build a PATTERN-type ValidationResult for artifact "art-1"."""
    fields = {
        "gate_id": gate_id,
        "gate_type": GateType.PATTERN,
        "artifact_id": "art-1",
        "status": status,
        "score": score,
    }
    return ValidationResult.create(**fields)
|
||||
|
||||
|
||||
class TestQualityMetDecision:
    """Halting decisions driven by gate pass/fail status."""

    def test_all_pass_halts_quality_met(self):
        """Expect HALT_QUALITY_MET when the only gate passes."""
        engine = HaltingPolicyEngine(QualityPolicy(max_iterations=5))

        record = engine.evaluate(
            [_make_result(ValidationStatus.PASS, 1.0)],
            iteration=1,
        )

        assert record.decision == HaltDecision.HALT_QUALITY_MET
        assert "quality gates passed" in record.reason.lower()

    def test_required_gates_all_pass(self):
        """Expect HALT_QUALITY_MET when the required gate passes and another fails."""
        engine = HaltingPolicyEngine(
            QualityPolicy(
                max_iterations=5,
                required_gate_ids=["required-gate"],
            )
        )
        gate_results = [
            _make_result(ValidationStatus.PASS, 1.0, gate_id="required-gate"),
            _make_result(ValidationStatus.FAIL, 0.5, gate_id="optional-gate"),
        ]

        record = engine.evaluate(gate_results, iteration=1)

        assert record.decision == HaltDecision.HALT_QUALITY_MET

    def test_required_gate_fails_continues(self):
        """Expect CONTINUE when the required gate fails."""
        engine = HaltingPolicyEngine(
            QualityPolicy(
                max_iterations=5,
                required_gate_ids=["required-gate"],
            )
        )
        gate_results = [
            _make_result(ValidationStatus.FAIL, 0.5, gate_id="required-gate"),
        ]

        record = engine.evaluate(gate_results, iteration=1)

        assert record.decision == HaltDecision.CONTINUE
|
||||
|
||||
|
||||
class TestIterationLimitDecision:
    """Halting decisions driven by max_iterations."""

    def test_at_iteration_limit(self):
        """Expect HALT_ITERATION_LIMIT when iteration equals max_iterations."""
        engine = HaltingPolicyEngine(QualityPolicy(max_iterations=3))
        failing = [_make_result(ValidationStatus.FAIL, 0.5)]

        record = engine.evaluate(failing, iteration=3)

        assert record.decision == HaltDecision.HALT_ITERATION_LIMIT
        assert record.iteration == 3

    def test_before_iteration_limit(self):
        """Expect CONTINUE while iteration is below max_iterations."""
        engine = HaltingPolicyEngine(QualityPolicy(max_iterations=5))
        failing = [_make_result(ValidationStatus.FAIL, 0.5)]

        record = engine.evaluate(failing, iteration=2)

        assert record.decision == HaltDecision.CONTINUE
|
||||
|
||||
|
||||
class TestBudgetExhaustedDecision:
    """Halting decisions driven by resource_budget."""

    def test_budget_exhausted(self):
        """Expect HALT_BUDGET_EXHAUSTED when total_runs equals the budget."""
        engine = HaltingPolicyEngine(
            QualityPolicy(max_iterations=10, resource_budget=5)
        )
        failing = [_make_result(ValidationStatus.FAIL, 0.5)]

        record = engine.evaluate(failing, iteration=1, total_runs=5)

        assert record.decision == HaltDecision.HALT_BUDGET_EXHAUSTED

    def test_budget_not_exhausted(self):
        """Expect CONTINUE when total_runs is below the budget."""
        engine = HaltingPolicyEngine(
            QualityPolicy(max_iterations=10, resource_budget=10)
        )
        failing = [_make_result(ValidationStatus.FAIL, 0.5)]

        record = engine.evaluate(failing, iteration=1, total_runs=3)

        assert record.decision == HaltDecision.CONTINUE
|
||||
|
||||
|
||||
class TestMarginalImprovementDecision:
    """Halting decisions driven by min_improvement between iterations."""

    def test_no_improvement_halts(self):
        """Expect HALT_NO_IMPROVEMENT when the score gain is under the threshold."""
        engine = HaltingPolicyEngine(
            QualityPolicy(max_iterations=10, min_improvement=0.05)
        )
        current = [_make_result(ValidationStatus.FAIL, 0.52)]

        # improvement: 0.02 < 0.05
        record = engine.evaluate(current, iteration=2, score_history=[0.50])

        assert record.decision == HaltDecision.HALT_NO_IMPROVEMENT

    def test_sufficient_improvement_continues(self):
        """Expect CONTINUE when the score gain reaches the threshold."""
        engine = HaltingPolicyEngine(
            QualityPolicy(max_iterations=10, min_improvement=0.05)
        )
        current = [_make_result(ValidationStatus.FAIL, 0.60)]

        # improvement: 0.10 >= 0.05
        record = engine.evaluate(current, iteration=2, score_history=[0.50])

        assert record.decision == HaltDecision.CONTINUE

    def test_first_iteration_no_history(self):
        """Expect CONTINUE on the first iteration, where no history is passed."""
        engine = HaltingPolicyEngine(
            QualityPolicy(max_iterations=10, min_improvement=0.05)
        )
        current = [_make_result(ValidationStatus.FAIL, 0.50)]

        record = engine.evaluate(current, iteration=1)

        assert record.decision == HaltDecision.CONTINUE
|
||||
|
||||
|
||||
class TestPriorityOrder:
    """Which halting reason is reported when several conditions hold at once."""

    def test_budget_checked_before_iteration(self):
        """Expect HALT_BUDGET_EXHAUSTED when budget and iteration limit are both hit."""
        engine = HaltingPolicyEngine(
            QualityPolicy(max_iterations=3, resource_budget=2)
        )
        failing = [_make_result(ValidationStatus.FAIL, 0.5)]

        record = engine.evaluate(failing, iteration=3, total_runs=2)

        assert record.decision == HaltDecision.HALT_BUDGET_EXHAUSTED

    def test_iteration_checked_before_quality(self):
        """Expect HALT_ITERATION_LIMIT when the limit is hit and the gate passes."""
        engine = HaltingPolicyEngine(QualityPolicy(max_iterations=2))
        passing = [_make_result(ValidationStatus.PASS, 1.0)]

        # Both conditions hold here: iteration 2 reaches max_iterations=2 and
        # the only gate passes. This test pins that the iteration limit is the
        # reason reported in that situation.
        record = engine.evaluate(passing, iteration=2)

        assert record.decision == HaltDecision.HALT_ITERATION_LIMIT
|
||||
|
||||
|
||||
class TestHaltingRecord:
    """Contents and serialization of the record returned by evaluate()."""

    def test_record_has_scores(self):
        """Expect the record's scores to be the history plus the current score."""
        engine = HaltingPolicyEngine(QualityPolicy(max_iterations=5))
        current = [_make_result(ValidationStatus.PASS, 0.9)]

        record = engine.evaluate(current, iteration=2, score_history=[0.5])

        assert record.scores == [0.5, 0.9]

    def test_record_to_dict(self):
        """Expect to_dict() to expose decision, iteration and max_iterations."""
        engine = HaltingPolicyEngine(QualityPolicy(max_iterations=3))
        current = [_make_result(ValidationStatus.PASS, 1.0)]

        serialized = engine.evaluate(current, iteration=1).to_dict()

        assert serialized["decision"] == "halted_quality_met"
        assert serialized["iteration"] == 1
        assert serialized["max_iterations"] == 3
|
||||
303
tests/unit/prompts/test_quality_gates.py
Normal file
303
tests/unit/prompts/test_quality_gates.py
Normal file
@@ -0,0 +1,303 @@
|
||||
"""
|
||||
Unit tests for quality gate models and individual gate implementations.
|
||||
|
||||
Tests QualityGate models, SchemaValidationGate, and PatternValidationGate.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
|
||||
from markitect.prompts.quality.models import (
|
||||
GateType,
|
||||
ValidationStatus,
|
||||
ValidationDiagnostic,
|
||||
ValidationResult,
|
||||
QualityPolicy,
|
||||
HaltDecision,
|
||||
HaltingRecord,
|
||||
)
|
||||
from markitect.prompts.quality.gates.schema_gate import SchemaValidationGate
|
||||
from markitect.prompts.quality.gates.pattern_gate import PatternValidationGate
|
||||
|
||||
|
||||
class TestValidationDiagnostic:
    """Construction and (de)serialization of ValidationDiagnostic."""

    def test_create_diagnostic(self):
        """Expect constructor arguments to be readable as attributes."""
        diagnostic = ValidationDiagnostic(
            code="TEST_ERROR",
            message="Something went wrong",
            severity="error",
        )

        assert diagnostic.code == "TEST_ERROR"
        assert diagnostic.severity == "error"

    def test_to_dict(self):
        """Expect to_dict() to carry code and severity."""
        serialized = ValidationDiagnostic(
            code="E1", message="msg", severity="warning"
        ).to_dict()

        assert serialized["code"] == "E1"
        assert serialized["severity"] == "warning"

    def test_from_dict(self):
        """Expect from_dict() to restore code and severity."""
        payload = {"code": "E1", "message": "msg", "severity": "info"}

        diagnostic = ValidationDiagnostic.from_dict(payload)

        assert diagnostic.code == "E1"
        assert diagnostic.severity == "info"
|
||||
|
||||
|
||||
class TestValidationResult:
    """Exercise the ValidationResult factory and its serialization."""

    def test_create_result(self):
        """create() stores the gate id, status and score it is given."""
        created = ValidationResult.create(
            gate_id="gate-1",
            gate_type=GateType.SCHEMA,
            artifact_id="art-1",
            status=ValidationStatus.PASS,
            score=1.0,
        )

        assert created.gate_id == "gate-1"
        assert created.status == ValidationStatus.PASS
        assert created.score == 1.0

    def test_result_unique_ids(self):
        """Two results built from identical arguments get different ids."""
        same_args = {
            "gate_id": "g",
            "gate_type": GateType.PATTERN,
            "artifact_id": "a",
            "status": ValidationStatus.PASS,
        }
        first = ValidationResult.create(**same_args)
        second = ValidationResult.create(**same_args)

        assert first.id != second.id

    def test_to_dict_from_dict(self):
        """to_dict() followed by from_dict() preserves the result."""
        error = ValidationDiagnostic(code="E1", message="err", severity="error")
        original = ValidationResult.create(
            gate_id="g1",
            gate_type=GateType.SCHEMA,
            artifact_id="a1",
            status=ValidationStatus.FAIL,
            score=0.5,
            diagnostics=[error],
        )

        round_tripped = ValidationResult.from_dict(original.to_dict())

        assert round_tripped.id == original.id
        assert round_tripped.status == ValidationStatus.FAIL
        assert round_tripped.score == 0.5
        assert len(round_tripped.diagnostics) == 1
|
||||
|
||||
|
||||
class TestQualityPolicy:
    """Exercise QualityPolicy defaults and serialization."""

    def test_default_policy(self):
        """A policy built without arguments carries the expected defaults."""
        defaults = QualityPolicy()

        assert defaults.max_iterations == 3
        assert defaults.min_improvement == 0.05
        assert defaults.fail_on_gate_failure is True
        assert defaults.resource_budget == 10

    def test_to_dict_from_dict(self):
        """Customized fields survive a to_dict()/from_dict() round trip."""
        customized = QualityPolicy(
            max_iterations=5,
            min_improvement=0.1,
            required_gate_ids=["g1", "g2"],
        )

        round_tripped = QualityPolicy.from_dict(customized.to_dict())

        assert round_tripped.max_iterations == 5
        assert round_tripped.min_improvement == 0.1
        assert round_tripped.required_gate_ids == ["g1", "g2"]
|
||||
|
||||
|
||||
class TestSchemaValidationGate:
    """Exercise SchemaValidationGate against passing and failing JSON."""

    @staticmethod
    def _name_version_schema():
        """Build a fresh schema requiring string ``name``, integer ``version``."""
        return {
            "type": "object",
            "required": ["name", "version"],
            "properties": {
                "name": {"type": "string"},
                "version": {"type": "integer"},
            },
        }

    def test_valid_json_passes(self):
        """A document matching the schema passes with a perfect score."""
        gate = SchemaValidationGate(
            schema=self._name_version_schema(),
            name="test-schema",
        )
        document = json.dumps({"name": "test", "version": 1})

        outcome = gate.validate(document, "art-1")

        assert outcome.status == ValidationStatus.PASS
        assert outcome.score == 1.0
        assert len(outcome.diagnostics) == 0

    def test_missing_required_field_fails(self):
        """Omitting a required field fails with a SCHEMA_VIOLATION diagnostic."""
        gate = SchemaValidationGate(schema=self._name_version_schema())
        document = json.dumps({"name": "test"})

        outcome = gate.validate(document, "art-1")

        assert outcome.status == ValidationStatus.FAIL
        codes = [diag.code for diag in outcome.diagnostics]
        assert "SCHEMA_VIOLATION" in codes

    def test_wrong_type_fails(self):
        """A string where the schema wants an integer fails validation."""
        integer_count_schema = {
            "type": "object",
            "properties": {
                "count": {"type": "integer"},
            },
        }
        gate = SchemaValidationGate(schema=integer_count_schema)
        document = json.dumps({"count": "not-a-number"})

        outcome = gate.validate(document, "art-1")

        assert outcome.status == ValidationStatus.FAIL

    def test_invalid_json_fails(self):
        """Unparseable content fails with score 0.0 and INVALID_JSON."""
        gate = SchemaValidationGate(schema={"type": "object"})

        outcome = gate.validate("not json {{{", "art-1")

        assert outcome.status == ValidationStatus.FAIL
        assert outcome.score == 0.0
        codes = [diag.code for diag in outcome.diagnostics]
        assert "INVALID_JSON" in codes

    def test_gate_has_correct_type(self):
        """The gate reports GateType.SCHEMA as its type."""
        assert (
            SchemaValidationGate(schema={"type": "object"}).gate_type
            == GateType.SCHEMA
        )

    def test_empty_schema_passes_any_object(self):
        """An empty schema accepts an arbitrary JSON object."""
        permissive_gate = SchemaValidationGate(schema={})
        document = json.dumps({"any": "thing"})

        outcome = permissive_gate.validate(document, "art-1")

        assert outcome.status == ValidationStatus.PASS

    def test_score_reflects_error_count(self):
        """An object missing all three required fields scores below 1.0."""
        three_strings_schema = {
            "type": "object",
            "required": ["a", "b", "c"],
            "properties": {
                field: {"type": "string"} for field in ("a", "b", "c")
            },
        }
        gate = SchemaValidationGate(schema=three_strings_schema)

        # The empty object is missing every one of the required fields.
        outcome = gate.validate(json.dumps({}), "art-1")

        assert outcome.status == ValidationStatus.FAIL
        assert outcome.score < 1.0
|
||||
|
||||
|
||||
class TestPatternValidationGate:
    """Exercise PatternValidationGate with required and forbidden patterns."""

    def test_required_pattern_present(self):
        """Content containing every required pattern passes with score 1.0."""
        required = [r"## Endpoints", r"### Authentication"]
        gate = PatternValidationGate(required_patterns=required)
        text = "# API\n## Endpoints\n### Authentication\nDetails here."

        outcome = gate.validate(text, "art-1")

        assert outcome.status == ValidationStatus.PASS
        assert outcome.score == 1.0

    def test_required_pattern_missing(self):
        """A missing required pattern fails with a MISSING_PATTERN diagnostic."""
        required = [r"## Endpoints", r"### Authentication"]
        gate = PatternValidationGate(required_patterns=required)
        text = "# API\n## Endpoints\nNo auth section."

        outcome = gate.validate(text, "art-1")

        assert outcome.status == ValidationStatus.FAIL
        codes = [diag.code for diag in outcome.diagnostics]
        assert "MISSING_PATTERN" in codes

    def test_forbidden_pattern_absent(self):
        """Content free of forbidden patterns passes."""
        forbidden = [r"TODO", r"FIXME"]
        gate = PatternValidationGate(forbidden_patterns=forbidden)

        outcome = gate.validate("This is a clean document.", "art-1")

        assert outcome.status == ValidationStatus.PASS

    def test_forbidden_pattern_present(self):
        """A forbidden pattern fails with a FORBIDDEN_PATTERN diagnostic."""
        forbidden = [r"TODO", r"FIXME"]
        gate = PatternValidationGate(forbidden_patterns=forbidden)

        outcome = gate.validate("This needs work. TODO: fix this.", "art-1")

        assert outcome.status == ValidationStatus.FAIL
        codes = [diag.code for diag in outcome.diagnostics]
        assert "FORBIDDEN_PATTERN" in codes

    def test_combined_required_and_forbidden(self):
        """Required and forbidden patterns are both enforced by one gate."""
        gate = PatternValidationGate(
            required_patterns=[r"## Summary"],
            forbidden_patterns=[r"FIXME"],
        )

        # Required heading present, nothing forbidden.
        clean = gate.validate("## Summary\nAll good.", "art-1")
        assert clean.status == ValidationStatus.PASS

        # Required heading present, but a forbidden marker as well.
        tainted = gate.validate("## Summary\nFIXME: broken.", "art-1")
        assert tainted.status == ValidationStatus.FAIL

    def test_no_patterns_passes(self):
        """A gate configured without any patterns passes arbitrary content."""
        outcome = PatternValidationGate().validate("anything", "art-1")

        assert outcome.status == ValidationStatus.PASS

    def test_gate_has_correct_type(self):
        """The gate reports GateType.PATTERN as its type."""
        assert PatternValidationGate().gate_type == GateType.PATTERN

    def test_score_proportional_to_failures(self):
        """A partially matching document scores strictly between 0 and 1."""
        gate = PatternValidationGate(
            required_patterns=[r"A", r"B", r"C", r"D"],
        )

        # Only "A" occurs in the content, so three of the four checks fail.
        outcome = gate.validate("A", "art-1")

        assert outcome.status == ValidationStatus.FAIL
        assert outcome.score > 0.0
        assert outcome.score < 1.0

    def test_regex_pattern_matching(self):
        """Required patterns are treated as regular expressions."""
        phone_gate = PatternValidationGate(required_patterns=[r"\d{3}-\d{4}"])

        outcome = phone_gate.validate("Call 555-1234 for info", "art-1")

        assert outcome.status == ValidationStatus.PASS
|
||||
264
tests/unit/prompts/test_quality_validator.py
Normal file
264
tests/unit/prompts/test_quality_validator.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
Unit tests for QualityValidator.
|
||||
|
||||
Tests applying multiple gates, aggregating results, and persistence.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from markitect.prompts.quality.models import (
|
||||
GateType,
|
||||
ValidationStatus,
|
||||
ValidationResult,
|
||||
)
|
||||
from markitect.prompts.quality.gates.schema_gate import SchemaValidationGate
|
||||
from markitect.prompts.quality.gates.pattern_gate import PatternValidationGate
|
||||
from markitect.prompts.quality.validator import QualityValidator
|
||||
|
||||
|
||||
@pytest.fixture
def temp_db():
    """Yield the path of a throwaway ``.db`` file and delete it afterwards."""
    # delete=False keeps the file on disk after the handle is closed, so the
    # test can reopen it by path; removal happens explicitly after the yield.
    handle = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
    handle.close()
    database_path = handle.name

    yield database_path

    Path(database_path).unlink(missing_ok=True)
|
||||
|
||||
|
||||
@pytest.fixture
def schema_gate():
    """Provide a schema gate that requires a string ``name`` property."""
    name_schema = {
        "type": "object",
        "required": ["name"],
        "properties": {"name": {"type": "string"}},
    }
    return SchemaValidationGate(
        schema=name_schema,
        gate_id="schema-gate-1",
        name="test-schema",
    )
|
||||
|
||||
|
||||
@pytest.fixture
def pattern_gate():
    """Provide a pattern gate requiring ``## Summary`` and forbidding ``TODO``."""
    gate_options = {
        "required_patterns": [r"## Summary"],
        "forbidden_patterns": [r"TODO"],
        "gate_id": "pattern-gate-1",
        "name": "test-pattern",
    }
    return PatternValidationGate(**gate_options)
|
||||
|
||||
|
||||
class TestValidateArtifact:
    """Tests for validating artifacts with one or more gates."""

    def test_all_gates_pass(self, schema_gate, pattern_gate):
        """Test that a single artifact can pass every configured gate."""
        validator = QualityValidator(gates=[schema_gate, pattern_gate])

        # One artifact has to satisfy both gates: it is valid JSON with the
        # required "name" field for the schema gate, and its raw text contains
        # the "## Summary" marker (and no "TODO") for the pattern gate.
        content = json.dumps({"name": "test", "summary": "## Summary"})

        results = validator.validate_artifact(content, "art-1")

        # One result per gate, and every one of them passes.
        assert len(results) == 2
        assert all(r.status == ValidationStatus.PASS for r in results)

    def test_pattern_gate_validates(self, pattern_gate):
        """Test pattern gate validation through the validator."""
        validator = QualityValidator(gates=[pattern_gate])

        results = validator.validate_artifact(
            "## Summary\nAll good here.", "art-1"
        )

        assert len(results) == 1
        assert results[0].status == ValidationStatus.PASS

    def test_multiple_gates_mixed_results(self, pattern_gate):
        """Test multiple gates with mixed pass/fail on the same content."""
        # NOTE: the pattern_gate fixture is requested but not used here; the
        # test builds its own two gates so each one has a known gate_id.
        gate_a = PatternValidationGate(
            required_patterns=[r"## Summary"],
            gate_id="gate-a",
        )
        gate_b = PatternValidationGate(
            required_patterns=[r"## Missing Section"],
            gate_id="gate-b",
        )
        validator = QualityValidator(gates=[gate_a, gate_b])

        results = validator.validate_artifact("## Summary\nContent.", "art-1")

        assert len(results) == 2
        # Index by gate id so the assertions do not depend on result order.
        statuses = {r.gate_id: r.status for r in results}
        assert statuses["gate-a"] == ValidationStatus.PASS
        assert statuses["gate-b"] == ValidationStatus.FAIL

    def test_no_gates_returns_empty(self):
        """Test validator with no gates returns an empty list."""
        validator = QualityValidator()

        results = validator.validate_artifact("content", "art-1")

        assert results == []
|
||||
|
||||
|
||||
class TestAllPassed:
    """Exercise QualityValidator.all_passed()."""

    @staticmethod
    def _pattern_result(gate_id, status):
        """Create a pattern-gate result for artifact "a" with *status*."""
        return ValidationResult.create(
            gate_id=gate_id,
            gate_type=GateType.PATTERN,
            artifact_id="a",
            status=status,
        )

    def test_all_pass(self):
        """all_passed() is True when every result has PASS status."""
        outcomes = [
            self._pattern_result("g1", ValidationStatus.PASS),
            self._pattern_result("g2", ValidationStatus.PASS),
        ]

        assert QualityValidator().all_passed(outcomes) is True

    def test_one_fails(self):
        """all_passed() is False as soon as one result has FAIL status."""
        outcomes = [
            self._pattern_result("g1", ValidationStatus.PASS),
            self._pattern_result("g2", ValidationStatus.FAIL),
        ]

        assert QualityValidator().all_passed(outcomes) is False

    def test_empty_results(self):
        """all_passed() is True for an empty result list."""
        assert QualityValidator().all_passed([]) is True
|
||||
|
||||
|
||||
class TestAggregateScore:
    """Exercise QualityValidator.aggregate_score()."""

    @staticmethod
    def _scored_result(gate_id, status, score):
        """Create a pattern-gate result for artifact "a" with an explicit score."""
        return ValidationResult.create(
            gate_id=gate_id,
            gate_type=GateType.PATTERN,
            artifact_id="a",
            status=status,
            score=score,
        )

    def test_average_scores(self):
        """Scores 1.0 and 0.5 aggregate to 0.75."""
        outcomes = [
            self._scored_result("g1", ValidationStatus.PASS, 1.0),
            self._scored_result("g2", ValidationStatus.FAIL, 0.5),
        ]

        assert QualityValidator().aggregate_score(outcomes) == 0.75

    def test_no_results(self):
        """An empty result list aggregates to 1.0."""
        assert QualityValidator().aggregate_score([]) == 1.0

    def test_none_scores_ignored(self):
        """A lone result whose score is None aggregates to 1.0."""
        unscored = self._scored_result("g1", ValidationStatus.PASS, None)

        assert QualityValidator().aggregate_score([unscored]) == 1.0
|
||||
|
||||
|
||||
class TestGetFailedGates:
    """Exercise QualityValidator.get_failed_gates()."""

    def test_get_failed(self):
        """Only the result with FAIL status is returned."""
        common = {"gate_type": GateType.PATTERN, "artifact_id": "a"}
        passing = ValidationResult.create(
            gate_id="g1", status=ValidationStatus.PASS, **common
        )
        failing = ValidationResult.create(
            gate_id="g2", status=ValidationStatus.FAIL, **common
        )

        failed = QualityValidator().get_failed_gates([passing, failing])

        assert len(failed) == 1
        assert failed[0].gate_id == "g2"
|
||||
|
||||
|
||||
class TestResultsToManifest:
    """Exercise QualityValidator.results_to_manifest_dict()."""

    def test_manifest_dict_format(self):
        """The manifest exposes the gate list, pass flag and aggregate score."""
        passing = ValidationResult.create(
            gate_id="g1",
            gate_type=GateType.PATTERN,
            artifact_id="a",
            status=ValidationStatus.PASS,
            score=1.0,
        )

        manifest = QualityValidator().results_to_manifest_dict([passing])

        assert "quality_gates" in manifest
        assert manifest["all_passed"] is True
        assert manifest["aggregate_score"] == 1.0
|
||||
|
||||
|
||||
class TestPersistence:
    """Exercise result persistence and gate registration on QualityValidator."""

    def test_persist_and_retrieve_by_run(self, temp_db, pattern_gate):
        """Results validated under a run id can be fetched by that run id."""
        store = QualityValidator(gates=[pattern_gate], db_path=temp_db)
        store.validate_artifact(
            "## Summary\nClean content.", "art-1", run_id="run-1"
        )

        rows = store.get_results_for_run("run-1")

        assert len(rows) == 1
        assert rows[0]["status"] == "pass"

    def test_persist_and_retrieve_by_artifact(self, temp_db, pattern_gate):
        """Results validated for an artifact can be fetched by artifact id."""
        store = QualityValidator(gates=[pattern_gate], db_path=temp_db)
        store.validate_artifact(
            "## Summary\nClean content.", "art-1", run_id="run-1"
        )

        rows = store.get_results_for_artifact("art-1")

        assert len(rows) == 1
        assert rows[0]["artifact_id"] == "art-1"

    def test_no_persistence_without_db(self, pattern_gate):
        """Without a db_path, validation still works but nothing is stored."""
        in_memory = QualityValidator(gates=[pattern_gate])

        outcomes = in_memory.validate_artifact(
            "## Summary\nContent.", "art-1", run_id="run-1"
        )

        assert len(outcomes) == 1
        # With no database configured, the run lookup returns an empty list.
        assert in_memory.get_results_for_run("run-1") == []

    def test_add_gate(self):
        """add_gate() appends a gate to an initially empty validator."""
        validator = QualityValidator()
        assert len(validator.gates) == 0

        validator.add_gate(PatternValidationGate(required_patterns=[r"test"]))

        assert len(validator.gates) == 1
|
||||
219
tests/unit/prompts/test_refinement_loop.py
Normal file
219
tests/unit/prompts/test_refinement_loop.py
Normal file
@@ -0,0 +1,219 @@
|
||||
"""
|
||||
Unit tests for RefinementLoop.
|
||||
|
||||
Tests the execute → validate → halt or refine cycle.
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from typing import List, Tuple
|
||||
|
||||
from markitect.prompts.quality.models import (
|
||||
HaltDecision,
|
||||
QualityPolicy,
|
||||
ValidationResult,
|
||||
ValidationStatus,
|
||||
)
|
||||
from markitect.prompts.quality.gates.pattern_gate import PatternValidationGate
|
||||
from markitect.prompts.quality.validator import QualityValidator
|
||||
from markitect.prompts.quality.refinement import RefinementLoop
|
||||
|
||||
|
||||
class TestRefinementLoopBasic:
    """Exercise the four ways a RefinementLoop run is asserted to halt."""

    def test_immediate_quality_met(self):
        """Passing content on the first iteration halts after one run."""
        summary_gate = PatternValidationGate(
            required_patterns=[r"## Summary"],
            gate_id="gate-1",
        )
        loop = RefinementLoop(
            QualityValidator(gates=[summary_gate]),
            QualityPolicy(max_iterations=5),
        )

        def produce(iteration, prev_results):
            run_id = f"run-{iteration}"
            return (run_id, "## Summary\nGood content.", "art-1")

        outcome = loop.run(produce, "art-1")

        assert outcome.iterations_run == 1
        assert outcome.halting_record.decision == HaltDecision.HALT_QUALITY_MET
        assert len(outcome.run_ids) == 1

    def test_iterates_until_quality_met(self):
        """The loop keeps iterating until the content satisfies the gate."""
        two_section_gate = PatternValidationGate(
            required_patterns=[r"## Summary", r"## Details"],
            gate_id="gate-1",
        )
        loop = RefinementLoop(
            QualityValidator(gates=[two_section_gate]),
            QualityPolicy(max_iterations=5, min_improvement=0.0),
        )

        drafts = (
            "## Summary\nOnly summary.",  # Details heading absent
            "## Summary\nStill only summary.",  # Details heading still absent
            "## Summary\n## Details\nBoth sections.",  # both headings present
        )

        def produce(iteration, prev_results):
            # Iterations are 1-based; past the last draft, reuse the last one.
            index = min(iteration, len(drafts)) - 1
            return (f"run-{iteration}", drafts[index], "art-1")

        outcome = loop.run(produce, "art-1")

        assert outcome.iterations_run == 3
        assert outcome.halting_record.decision == HaltDecision.HALT_QUALITY_MET
        assert len(outcome.run_ids) == 3

    def test_hits_iteration_limit(self):
        """Content that never passes stops the loop at max_iterations."""
        unmatchable_gate = PatternValidationGate(
            required_patterns=[r"IMPOSSIBLE_PATTERN_12345"],
            gate_id="gate-1",
        )
        loop = RefinementLoop(
            QualityValidator(gates=[unmatchable_gate]),
            QualityPolicy(max_iterations=3, min_improvement=0.0),
        )

        def produce(iteration, prev_results):
            run_id = f"run-{iteration}"
            return (run_id, "never matches", "art-1")

        outcome = loop.run(produce, "art-1")

        assert outcome.iterations_run == 3
        assert (
            outcome.halting_record.decision == HaltDecision.HALT_ITERATION_LIMIT
        )

    def test_budget_exhaustion(self):
        """A resource budget of 2 stops the loop after two iterations."""
        unmatchable_gate = PatternValidationGate(
            required_patterns=[r"IMPOSSIBLE"],
            gate_id="gate-1",
        )
        loop = RefinementLoop(
            QualityValidator(gates=[unmatchable_gate]),
            QualityPolicy(max_iterations=10, resource_budget=2),
        )

        def produce(iteration, prev_results):
            run_id = f"run-{iteration}"
            return (run_id, "content", "art-1")

        outcome = loop.run(produce, "art-1")

        assert outcome.iterations_run == 2
        assert (
            outcome.halting_record.decision == HaltDecision.HALT_BUDGET_EXHAUSTED
        )
|
||||
|
||||
|
||||
class TestRefinementLoopHistory:
    """Tests for the history a refinement loop records across iterations."""

    def test_all_results_tracked(self):
        """Test that every iteration contributes one results entry and run ID."""
        gate = PatternValidationGate(
            required_patterns=[r"## Summary"],
            gate_id="gate-1",
        )
        validator = QualityValidator(gates=[gate])
        # min_improvement=0.0 keeps the loop going while the score stays flat,
        # as in the other multi-iteration tests in this file; the default
        # threshold would presumably halt on the second identical failing
        # iteration, before the passing content is ever produced.
        policy = QualityPolicy(max_iterations=3, min_improvement=0.0)
        loop = RefinementLoop(validator, policy)

        def callback(iteration, prev_results):
            # Fail on iterations 1 and 2, pass on iteration 3.
            if iteration < 3:
                return (f"run-{iteration}", "no match", "art-1")
            return (f"run-{iteration}", "## Summary\nDone.", "art-1")

        result = loop.run(callback, "art-1")

        assert result.iterations_run == 3
        assert len(result.all_results) == result.iterations_run
        assert len(result.run_ids) == result.iterations_run

    def test_run_ids_collected(self):
        """Test that the run ID returned by the callback is recorded."""
        gate = PatternValidationGate(gate_id="gate-1")
        validator = QualityValidator(gates=[gate])
        policy = QualityPolicy(max_iterations=3)
        loop = RefinementLoop(validator, policy)

        def callback(iteration, prev_results):
            return (f"run-{iteration}", "content", "art-1")

        result = loop.run(callback, "art-1")

        assert result.run_ids[0] == "run-1"

    def test_previous_results_passed_to_callback(self):
        """Test that the callback receives the previous iteration's results."""
        gate = PatternValidationGate(
            required_patterns=[r"## Final"],
            gate_id="gate-1",
        )
        validator = QualityValidator(gates=[gate])
        policy = QualityPolicy(max_iterations=3)
        loop = RefinementLoop(validator, policy)

        # Records the prev_results argument seen on each callback invocation.
        received_prev: List[List[ValidationResult]] = []

        def callback(iteration, prev_results):
            received_prev.append(prev_results)
            if iteration >= 2:
                return (f"run-{iteration}", "## Final\nDone.", "art-1")
            return (f"run-{iteration}", "incomplete", "art-1")

        loop.run(callback, "art-1")

        # First iteration gets empty prev_results
        assert len(received_prev[0]) == 0
        # Second iteration gets the single gate result from the first
        assert len(received_prev[1]) == 1
|
||||
|
||||
|
||||
class TestRefinementLoopNoImprovement:
    """Exercise the no-improvement halting path of RefinementLoop."""

    def test_halts_on_no_improvement(self):
        """Identical failing content twice in a row halts the loop."""
        required_gate = PatternValidationGate(
            required_patterns=[r"## Required"],
            gate_id="gate-1",
        )
        stagnation_policy = QualityPolicy(max_iterations=10, min_improvement=0.1)
        loop = RefinementLoop(
            QualityValidator(gates=[required_gate]),
            stagnation_policy,
        )

        def produce(iteration, prev_results):
            # Every iteration yields the same failing content, so the score
            # cannot improve from one iteration to the next.
            return (f"run-{iteration}", "no match at all", "art-1")

        outcome = loop.run(produce, "art-1")

        # The second iteration is where the lack of improvement is detected.
        assert outcome.iterations_run == 2
        assert (
            outcome.halting_record.decision == HaltDecision.HALT_NO_IMPROVEMENT
        )
|
||||
|
||||
|
||||
class TestRefinementResultSerialization:
    """Exercise RefinementResult.to_dict()."""

    def test_to_dict(self):
        """The serialized result contains the four expected top-level keys."""
        loop = RefinementLoop(
            QualityValidator(gates=[PatternValidationGate(gate_id="gate-1")]),
            QualityPolicy(max_iterations=1),
        )

        def produce(iteration, prev_results):
            return ("run-1", "content", "art-1")

        serialized = loop.run(produce, "art-1").to_dict()

        expected_keys = (
            "iterations_run",
            "final_results",
            "halting_record",
            "run_ids",
        )
        for key in expected_keys:
            assert key in serialized
|
||||
Reference in New Issue
Block a user