From 704272644c9d94b247a37909024125a45b0a6a27 Mon Sep 17 00:00:00 2001 From: tegwick Date: Mon, 9 Feb 2026 13:31:37 +0100 Subject: [PATCH] feat(prompts): implement Phase 7 - Quality & Validation (FR-9, FR-10) Add quality gate framework with schema validation (JSON Schema via jsonschema library), pattern validation (regex-based), multi-gate QualityValidator with SQLite persistence, HaltingPolicyEngine with budget/iteration/improvement checks, and RefinementLoop for iterative execute-validate-halt cycles. Co-Authored-By: Claude Opus 4.6 --- markitect/prompts/quality/__init__.py | 53 +++ markitect/prompts/quality/gates/__init__.py | 13 + .../prompts/quality/gates/pattern_gate.py | 109 +++++++ .../prompts/quality/gates/schema_gate.py | 123 +++++++ markitect/prompts/quality/models.py | 283 ++++++++++++++++ markitect/prompts/quality/policy.py | 153 +++++++++ markitect/prompts/quality/refinement.py | 108 +++++++ markitect/prompts/quality/validator.py | 294 +++++++++++++++++ .../prompts/006_create_quality_tables.sql | 25 ++ .../prompts/test_halting_execution.py | 239 ++++++++++++++ .../prompts/test_quality_validation.py | 208 ++++++++++++ tests/unit/prompts/test_halting_policy.py | 221 +++++++++++++ tests/unit/prompts/test_quality_gates.py | 303 ++++++++++++++++++ tests/unit/prompts/test_quality_validator.py | 264 +++++++++++++++ tests/unit/prompts/test_refinement_loop.py | 219 +++++++++++++ 15 files changed, 2615 insertions(+) create mode 100644 markitect/prompts/quality/__init__.py create mode 100644 markitect/prompts/quality/gates/__init__.py create mode 100644 markitect/prompts/quality/gates/pattern_gate.py create mode 100644 markitect/prompts/quality/gates/schema_gate.py create mode 100644 markitect/prompts/quality/models.py create mode 100644 markitect/prompts/quality/policy.py create mode 100644 markitect/prompts/quality/refinement.py create mode 100644 markitect/prompts/quality/validator.py create mode 100644 migrations/prompts/006_create_quality_tables.sql create mode 100644 tests/integration/prompts/test_halting_execution.py create mode 100644 tests/integration/prompts/test_quality_validation.py create mode 100644 tests/unit/prompts/test_halting_policy.py create mode 100644 tests/unit/prompts/test_quality_gates.py create mode 100644 tests/unit/prompts/test_quality_validator.py create mode 100644 tests/unit/prompts/test_refinement_loop.py diff --git a/markitect/prompts/quality/__init__.py b/markitect/prompts/quality/__init__.py new file mode 100644 index 00000000..f9a892ff --- /dev/null +++ b/markitect/prompts/quality/__init__.py @@ -0,0 +1,53 @@ +""" +Quality validation and halting policies for prompt artifacts. + +Implements FR-9: QualityGate Validation +Implements FR-10: Halting and Refinement Policy + +- FR-9.1: Schema validation against generated artifacts +- FR-9.2: Multiple QualityGates per artifact +- FR-9.3: Record pass/fail results and diagnostics +- FR-9.4: Halting policies based on QualityGate results +- FR-10.1: Configurable QualityPolicies +- FR-10.2: Halting decisions (quality, improvement, iterations, budget) +- FR-10.3: Record halting decisions in RunManifest +""" + +from markitect.prompts.quality.models import ( + GateType, + ValidationStatus, + HaltDecision, + ValidationDiagnostic, + ValidationResult, + QualityGate, + QualityPolicy, + HaltingRecord, + RefinementResult, +) +from markitect.prompts.quality.gates.schema_gate import SchemaValidationGate +from markitect.prompts.quality.gates.pattern_gate import PatternValidationGate +from markitect.prompts.quality.validator import QualityValidator +from markitect.prompts.quality.policy import HaltingPolicyEngine +from markitect.prompts.quality.refinement import RefinementLoop + +__all__ = [ + # Models + "GateType", + "ValidationStatus", + "HaltDecision", + "ValidationDiagnostic", + "ValidationResult", + "QualityGate", + "QualityPolicy", + "HaltingRecord", + "RefinementResult", + # Gates + "SchemaValidationGate", + "PatternValidationGate", + # Validator + "QualityValidator", + # Policy + "HaltingPolicyEngine", + # Refinement + "RefinementLoop", +] diff --git a/markitect/prompts/quality/gates/__init__.py b/markitect/prompts/quality/gates/__init__.py new file mode 100644 index 00000000..48b4aa5b --- /dev/null +++ b/markitect/prompts/quality/gates/__init__.py @@ -0,0 +1,13 @@ +""" +Quality gate implementations. + +Provides SchemaValidationGate and PatternValidationGate. +""" + +from markitect.prompts.quality.gates.schema_gate import SchemaValidationGate +from markitect.prompts.quality.gates.pattern_gate import PatternValidationGate + +__all__ = [ + "SchemaValidationGate", + "PatternValidationGate", +] diff --git a/markitect/prompts/quality/gates/pattern_gate.py b/markitect/prompts/quality/gates/pattern_gate.py new file mode 100644 index 00000000..4d012e5f --- /dev/null +++ b/markitect/prompts/quality/gates/pattern_gate.py @@ -0,0 +1,109 @@ +""" +Pattern validation quality gate. + +Validates content against required and forbidden regex patterns. +""" + +import re +import uuid +from typing import List, Optional + +from markitect.prompts.quality.models import ( + GateType, + QualityGate, + ValidationDiagnostic, + ValidationResult, + ValidationStatus, +) + + +class PatternValidationGate(QualityGate): + """ + Validates artifact content against regex patterns. + + Checks that all required patterns are present and no forbidden + patterns are found. + """ + + def __init__( + self, + required_patterns: Optional[List[str]] = None, + forbidden_patterns: Optional[List[str]] = None, + gate_id: Optional[str] = None, + name: str = "pattern", + ): + """ + Initialize with pattern lists. + + Args: + required_patterns: Regex patterns that must be found in content + forbidden_patterns: Regex patterns that must NOT be found in content + gate_id: Optional gate identifier + name: Human-readable gate name + """ + super().__init__( + gate_id=gate_id or str(uuid.uuid4()), + name=name, + gate_type=GateType.PATTERN, + ) + self.required_patterns = required_patterns or [] + self.forbidden_patterns = forbidden_patterns or [] + + def validate(self, content: str, artifact_id: str) -> ValidationResult: + """ + Validate content against required and forbidden patterns. + + Args: + content: Content string to validate + artifact_id: ID of the artifact being validated + + Returns: + ValidationResult with status and diagnostics + """ + diagnostics = [] + total_checks = len(self.required_patterns) + len(self.forbidden_patterns) + failures = 0 + + # Check required patterns + for pattern in self.required_patterns: + if not re.search(pattern, content): + diagnostics.append( + ValidationDiagnostic( + code="MISSING_PATTERN", + message=f"Required pattern not found: {pattern}", + severity="error", + ) + ) + failures += 1 + + # Check forbidden patterns + for pattern in self.forbidden_patterns: + match = re.search(pattern, content) + if match: + diagnostics.append( + ValidationDiagnostic( + code="FORBIDDEN_PATTERN", + message=f"Forbidden pattern found: {pattern} (matched: '{match.group()}')", + severity="error", + ) + ) + failures += 1 + + if total_checks == 0: + status = ValidationStatus.PASS + score = 1.0 + elif failures == 0: + status = ValidationStatus.PASS + score = 1.0 + else: + status = ValidationStatus.FAIL + score = max(0.0, 1.0 - failures / total_checks) + + return ValidationResult.create( + gate_id=self.id, + gate_type=self.gate_type, + artifact_id=artifact_id, + status=status, + score=score, + diagnostics=diagnostics, + ) diff --git a/markitect/prompts/quality/gates/schema_gate.py b/markitect/prompts/quality/gates/schema_gate.py new file mode 100644 index 00000000..58793310 --- /dev/null +++ b/markitect/prompts/quality/gates/schema_gate.py @@ -0,0 +1,123 @@ +""" +Schema validation quality gate. + +Implements FR-9.1: Validate generated artifacts against JSON schemas. +Uses the jsonschema library for validation. +""" + +import json +import uuid +from typing import Any, Dict, Optional + +import jsonschema + +from markitect.prompts.quality.models import ( + GateType, + QualityGate, + ValidationDiagnostic, + ValidationResult, + ValidationStatus, +) + + +class SchemaValidationGate(QualityGate): + """ + Validates artifact content against a JSON schema. + + Parses content as JSON and validates against the provided schema + using the jsonschema library. + """ + + def __init__( + self, + schema: Dict[str, Any], + gate_id: Optional[str] = None, + name: str = "schema", + ): + """ + Initialize with a JSON schema. + + Args: + schema: JSON Schema dictionary + gate_id: Optional gate identifier (auto-generated if not provided) + name: Human-readable gate name + """ + super().__init__( + gate_id=gate_id or str(uuid.uuid4()), + name=name, + gate_type=GateType.SCHEMA, + ) + self.schema = schema + + def validate(self, content: str, artifact_id: str) -> ValidationResult: + """ + Validate content against the JSON schema. + + Parses the content as JSON, then validates against the schema. + Returns FAIL if content is not valid JSON or fails schema validation. + + Args: + content: JSON content string to validate + artifact_id: ID of the artifact being validated + + Returns: + ValidationResult with status and diagnostics + """ + diagnostics = [] + + # Parse JSON + try: + data = json.loads(content) + except (json.JSONDecodeError, TypeError) as e: + diagnostics.append( + ValidationDiagnostic( + code="INVALID_JSON", + message=f"Content is not valid JSON: {e}", + severity="error", + ) + ) + return ValidationResult.create( + gate_id=self.id, + gate_type=self.gate_type, + artifact_id=artifact_id, + status=ValidationStatus.FAIL, + score=0.0, + diagnostics=diagnostics, + ) + + # Validate against schema + validator = jsonschema.Draft7Validator(self.schema) + errors = list(validator.iter_errors(data)) + + if not errors: + return ValidationResult.create( + gate_id=self.id, + gate_type=self.gate_type, + artifact_id=artifact_id, + status=ValidationStatus.PASS, + score=1.0, + diagnostics=[], + ) + + for error in errors: + path = ".".join(str(p) for p in error.absolute_path) or "(root)" + diagnostics.append( + ValidationDiagnostic( + code="SCHEMA_VIOLATION", + message=f"At '{path}': {error.message}", + severity="error", + ) + ) + + # Score based on proportion of passing validations + total_checks = len(errors) + 1 # approximate + score = max(0.0, 1.0 - len(errors) / total_checks) + + return ValidationResult.create( + gate_id=self.id, + gate_type=self.gate_type, + artifact_id=artifact_id, + status=ValidationStatus.FAIL, + score=score, + diagnostics=diagnostics, + ) diff --git a/markitect/prompts/quality/models.py b/markitect/prompts/quality/models.py new file mode 100644 index 00000000..39b76320 --- /dev/null +++ b/markitect/prompts/quality/models.py @@ -0,0 +1,283 @@ +""" +Data models for quality validation and halting policies. + +Implements FR-9: QualityGate Validation +Implements FR-10: Halting and Refinement Policy +""" + +import uuid +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional + + +class GateType(Enum): + """Type classification for quality gates.""" + SCHEMA = "schema" + PATTERN = "pattern" + CUSTOM = "custom" + + +class ValidationStatus(Enum): + """Outcome status of a quality gate check.""" + PASS = "pass" + FAIL = "fail" + WARNING = "warning" + SKIPPED = "skipped" + + +class HaltDecision(Enum): + """Decision outcome from halting policy evaluation.""" + CONTINUE = "continue" + HALT_QUALITY_MET = "halted_quality_met" + HALT_ITERATION_LIMIT = "halted_iteration_limit" + HALT_BUDGET_EXHAUSTED = "halted_budget_exhausted" + HALT_NO_IMPROVEMENT = "halted_no_improvement" + + +@dataclass +class ValidationDiagnostic: + """ + Single diagnostic message from a quality gate. + + Attributes: + code: Machine-readable diagnostic code + message: Human-readable description + severity: Severity level (error, warning, info) + """ + code: str + message: str + severity: str = "error" + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + "code": self.code, + "message": self.message, + "severity": self.severity, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ValidationDiagnostic": + """Create from dictionary.""" + return cls( + code=data["code"], + message=data["message"], + severity=data.get("severity", "error"), + ) + + +@dataclass +class ValidationResult: + """ + Result of applying a quality gate to an artifact. + + Implements FR-9.3: Record pass/fail results and diagnostics. + + Attributes: + id: Unique result identifier + gate_id: ID of the quality gate that produced this result + gate_type: Type of the quality gate + artifact_id: ID of the validated artifact + status: Pass/fail outcome + score: Optional quality score (0.0-1.0) + diagnostics: List of diagnostic messages + validated_at: When validation occurred + """ + id: str + gate_id: str + gate_type: GateType + artifact_id: str + status: ValidationStatus + score: Optional[float] = None + diagnostics: List[ValidationDiagnostic] = field(default_factory=list) + validated_at: datetime = field(default_factory=datetime.utcnow) + + @classmethod + def create( + cls, + gate_id: str, + gate_type: GateType, + artifact_id: str, + status: ValidationStatus, + score: Optional[float] = None, + diagnostics: Optional[List[ValidationDiagnostic]] = None, + ) -> "ValidationResult": + """Create a new ValidationResult.""" + return cls( + id=str(uuid.uuid4()), + gate_id=gate_id, + gate_type=gate_type, + artifact_id=artifact_id, + status=status, + score=score, + diagnostics=diagnostics or [], + ) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "id": self.id, + "gate_id": self.gate_id, + "gate_type": self.gate_type.value, + "artifact_id": self.artifact_id, + "status": self.status.value, + "score": self.score, + "diagnostics": [d.to_dict() for d in self.diagnostics], + "validated_at": self.validated_at.isoformat(), + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "ValidationResult": + """Create from dictionary.""" + return cls( + id=data["id"], + gate_id=data["gate_id"], + gate_type=GateType(data["gate_type"]), + artifact_id=data["artifact_id"], + status=ValidationStatus(data["status"]), + score=data.get("score"), + diagnostics=[ + ValidationDiagnostic.from_dict(d) + for d in data.get("diagnostics", []) + ], + validated_at=datetime.fromisoformat(data["validated_at"]), + ) + + +class QualityGate(ABC): + """ + Abstract base class for quality gates. + + Implements FR-9.1/FR-9.2: Pluggable validation framework + supporting multiple gates per artifact. + """ + + def __init__(self, gate_id: str, name: str, gate_type: GateType): + self.id = gate_id + self.name = name + self.gate_type = gate_type + + @abstractmethod + def validate(self, content: str, artifact_id: str) -> ValidationResult: + """ + Validate content against this quality gate. + + Args: + content: Content to validate + artifact_id: ID of the artifact being validated + + Returns: + ValidationResult with status and diagnostics + """ + pass + + +@dataclass +class QualityPolicy: + """ + Configuration for halting and refinement policy. + + Implements FR-10.1: Configurable QualityPolicies. + + Attributes: + max_iterations: Maximum refinement iterations + min_improvement: Minimum score improvement to continue + fail_on_gate_failure: Whether any gate failure halts execution + resource_budget: Maximum total runs allowed + required_gate_ids: Gate IDs that must pass for quality to be met + """ + max_iterations: int = 3 + min_improvement: float = 0.05 + fail_on_gate_failure: bool = True + resource_budget: int = 10 + required_gate_ids: List[str] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + "max_iterations": self.max_iterations, + "min_improvement": self.min_improvement, + "fail_on_gate_failure": self.fail_on_gate_failure, + "resource_budget": self.resource_budget, + "required_gate_ids": self.required_gate_ids, + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "QualityPolicy": + """Create from dictionary.""" + return cls( + max_iterations=data.get("max_iterations", 3), + min_improvement=data.get("min_improvement", 0.05), + fail_on_gate_failure=data.get("fail_on_gate_failure", True), + resource_budget=data.get("resource_budget", 10), + required_gate_ids=data.get("required_gate_ids", []), + ) + + +@dataclass +class HaltingRecord: + """ + Record of a halting decision. + + Implements FR-10.3: Record halting decisions in the RunManifest. + + Attributes: + decision: The halting decision + iteration: Current iteration number + max_iterations: Maximum allowed iterations + scores: Score history across iterations + reason: Human-readable reason for decision + recorded_at: When the decision was made + """ + decision: HaltDecision + iteration: int + max_iterations: int + scores: List[float] = field(default_factory=list) + reason: str = "" + recorded_at: datetime = field(default_factory=datetime.utcnow) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "decision": self.decision.value, + "iteration": self.iteration, + "max_iterations": self.max_iterations, + "scores": self.scores, + "reason": self.reason, + "recorded_at": self.recorded_at.isoformat(), + } + + +@dataclass +class RefinementResult: + """ + Result of a refinement loop execution. + + Attributes: + iterations_run: Number of iterations executed + final_results: Validation results from the last iteration + halting_record: Record of the halting decision + all_results: Validation results from all iterations + run_ids: List of run IDs produced during refinement + """ + iterations_run: int + final_results: List[ValidationResult] = field(default_factory=list) + halting_record: Optional[HaltingRecord] = None + all_results: List[List[ValidationResult]] = field(default_factory=list) + run_ids: List[str] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary.""" + return { + "iterations_run": self.iterations_run, + "final_results": [r.to_dict() for r in self.final_results], + "halting_record": self.halting_record.to_dict() if self.halting_record else None, + "all_results": [ + [r.to_dict() for r in iteration] + for iteration in self.all_results + ], + "run_ids": self.run_ids, + } diff --git a/markitect/prompts/quality/policy.py b/markitect/prompts/quality/policy.py new file mode 100644 index 00000000..f178487a --- /dev/null +++ b/markitect/prompts/quality/policy.py @@ -0,0 +1,153 @@ +""" +Halting policy engine for refinement control. + +Implements FR-10: Halting and Refinement Policy +Evaluates halting conditions based on quality gate results, +marginal improvement, iteration limits, and resource budgets. +""" + +from typing import List, Optional + +from markitect.prompts.quality.models import ( + HaltDecision, + HaltingRecord, + QualityPolicy, + ValidationResult, + ValidationStatus, +) + + +class HaltingPolicyEngine: + """ + Evaluates halting decisions based on configurable policies. + + Implements FR-10.2: Evaluate halting based on quality gate results, + marginal improvement metrics, iteration limits, and resource budgets. + + Implements FR-10.3: Record halting decisions. + """ + + def __init__(self, policy: QualityPolicy): + """ + Initialize with a quality policy. + + Args: + policy: Quality policy configuration + """ + self.policy = policy + + def evaluate( + self, + results: List[ValidationResult], + iteration: int, + score_history: Optional[List[float]] = None, + total_runs: int = 0, + ) -> HaltingRecord: + """ + Evaluate whether to halt or continue refinement. + + Checks in order: + 1. Resource budget exhaustion + 2. Iteration limit + 3. Quality met (all required gates pass) + 4. Marginal improvement below threshold + + Args: + results: Validation results from current iteration + iteration: Current iteration number (1-based) + score_history: Aggregate scores from previous iterations + total_runs: Total number of runs consumed + + Returns: + HaltingRecord with the decision + """ + score_history = score_history or [] + current_score = self._aggregate_score(results) + all_scores = score_history + [current_score] + + # Check resource budget + if total_runs >= self.policy.resource_budget: + return HaltingRecord( + decision=HaltDecision.HALT_BUDGET_EXHAUSTED, + iteration=iteration, + max_iterations=self.policy.max_iterations, + scores=all_scores, + reason=f"Resource budget exhausted: {total_runs}/{self.policy.resource_budget} runs used", + ) + + # Check iteration limit + if iteration >= self.policy.max_iterations: + return HaltingRecord( + decision=HaltDecision.HALT_ITERATION_LIMIT, + iteration=iteration, + max_iterations=self.policy.max_iterations, + scores=all_scores, + reason=f"Iteration limit reached: {iteration}/{self.policy.max_iterations}", + ) + + # Check if quality is met + if self._quality_met(results): + return HaltingRecord( + decision=HaltDecision.HALT_QUALITY_MET, + iteration=iteration, + max_iterations=self.policy.max_iterations, + scores=all_scores, + reason="All quality gates passed", + ) + + # Check marginal improvement + if len(all_scores) >= 2: + improvement = all_scores[-1] - all_scores[-2] + if improvement < self.policy.min_improvement: + return HaltingRecord( + decision=HaltDecision.HALT_NO_IMPROVEMENT, + iteration=iteration, + max_iterations=self.policy.max_iterations, + scores=all_scores, + reason=( + f"Marginal improvement {improvement:.4f} below " + f"threshold {self.policy.min_improvement}" + ), + ) + + # Continue refinement + return HaltingRecord( + decision=HaltDecision.CONTINUE, + iteration=iteration, + max_iterations=self.policy.max_iterations, + scores=all_scores, + reason="Continuing refinement", + ) + + def _quality_met(self, results: List[ValidationResult]) -> bool: + """ + Check if quality requirements are met. + + If required_gate_ids is set, only those gates must pass. + Otherwise, all gates must pass. + + Args: + results: Validation results to check + + Returns: + True if quality requirements are met + """ + if self.policy.required_gate_ids: + for gate_id in self.policy.required_gate_ids: + gate_results = [r for r in results if r.gate_id == gate_id] + if not gate_results: + return False + if any(r.status == ValidationStatus.FAIL for r in gate_results): + return False + return True + + return all(r.status == ValidationStatus.PASS for r in results) + + def _aggregate_score(self, results: List[ValidationResult]) -> float: + """Calculate aggregate score from results.""" + if not results: + return 0.0 + scores = [r.score for r in results if r.score is not None] + if not scores: + return 0.0 + return sum(scores) / len(scores) diff --git a/markitect/prompts/quality/refinement.py b/markitect/prompts/quality/refinement.py new file mode 100644 index 00000000..dcbb938f --- /dev/null +++ b/markitect/prompts/quality/refinement.py @@ -0,0 +1,108 @@ +""" +Refinement loop for iterative quality improvement. + +Implements FR-10: Halting and Refinement Policy +Execute → Validate → Halt or Refine cycle. +""" + +from typing import Callable, List, Optional, Tuple + +from markitect.prompts.quality.models import ( + HaltDecision, + QualityPolicy, + RefinementResult, + ValidationResult, +) +from markitect.prompts.quality.policy import HaltingPolicyEngine +from markitect.prompts.quality.validator import QualityValidator + + +class RefinementLoop: + """ + Iterative refinement loop with quality gate checks. + + Executes a cycle of: execute → validate → check halting → refine + until a halting condition is met. + """ + + def __init__( + self, + validator: QualityValidator, + policy: QualityPolicy, + ): + """ + Initialize with validator and policy. + + Args: + validator: Quality validator with configured gates + policy: Halting policy configuration + """ + self.validator = validator + self.policy = policy + self.policy_engine = HaltingPolicyEngine(policy) + + def run( + self, + execution_callback: Callable[[int, List[ValidationResult]], Tuple[str, str, str]], + artifact_id: str, + ) -> RefinementResult: + """ + Execute the refinement loop. + + The execution_callback is called each iteration with: + - iteration number (1-based) + - previous validation results (empty list on first iteration) + + It should return a tuple of (run_id, content, artifact_id). + + Args: + execution_callback: Callable that executes/refines and returns + (run_id, content, artifact_id) + artifact_id: ID of the artifact being refined + + Returns: + RefinementResult with complete iteration history + """ + result = RefinementResult(iterations_run=0) + score_history: List[float] = [] + prev_results: List[ValidationResult] = [] + + for iteration in range(1, self.policy.max_iterations + 1): + # Execute / refine + run_id, content, art_id = execution_callback(iteration, prev_results) + result.run_ids.append(run_id) + + # Validate + current_results = self.validator.validate_artifact( + content, art_id, run_id=run_id if self.validator.db_path else None, + ) + result.all_results.append(current_results) + result.iterations_run = iteration + + # Evaluate halting + halting_record = self.policy_engine.evaluate( + results=current_results, + iteration=iteration, + score_history=score_history, + total_runs=len(result.run_ids), + ) + + current_score = self.policy_engine._aggregate_score(current_results) + score_history.append(current_score) + + if halting_record.decision != HaltDecision.CONTINUE: + result.final_results = current_results + result.halting_record = halting_record + return result + + prev_results = current_results + + # Reached max iterations without explicit halt + result.final_results = prev_results + result.halting_record = self.policy_engine.evaluate( + results=prev_results, + iteration=self.policy.max_iterations, + score_history=score_history, + total_runs=len(result.run_ids), + ) + return result diff --git a/markitect/prompts/quality/validator.py b/markitect/prompts/quality/validator.py new file mode 100644 index 00000000..32141c7f --- /dev/null +++ b/markitect/prompts/quality/validator.py @@ -0,0 +1,294 @@ +""" +Quality validator for applying multiple gates to artifacts. + +Implements FR-9.2: Multiple QualityGates per artifact. +Implements FR-9.3: Record pass/fail results and diagnostics in RunManifest. +""" + +import sqlite3 +from pathlib import Path +from typing import Dict, Any, List, Optional + +from markitect.prompts.quality.models import ( + GateType, + QualityGate, + ValidationResult, + ValidationStatus, +) + + +# SQL schema for quality tables +QUALITY_TABLES_SQL = """ +CREATE TABLE IF NOT EXISTS quality_gates ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + gate_type TEXT NOT NULL, + config JSON NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS validation_results ( + id TEXT PRIMARY KEY, + run_id TEXT NOT NULL, + gate_id TEXT NOT NULL, + artifact_id TEXT, + status TEXT NOT NULL, + score REAL, + diagnostics JSON, + validated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX IF NOT EXISTS idx_validations_run ON validation_results(run_id); +CREATE INDEX IF NOT EXISTS idx_validations_artifact ON validation_results(artifact_id); +""" + + +class QualityValidator: + """ + Applies multiple quality gates to artifacts and records results. + + Implements FR-9.2 and FR-9.3. + """ + + def __init__( + self, + gates: Optional[List[QualityGate]] = None, + db_path: Optional[str] = None, + ): + """ + Initialize validator with quality gates. + + Args: + gates: List of quality gates to apply + db_path: Optional database path for persisting results + """ + self.gates: List[QualityGate] = gates or [] + self.db_path = db_path + if db_path: + self._initialize_tables() + + def _initialize_tables(self) -> None: + """Initialize quality tables if DB path is set.""" + db_dir = Path(self.db_path).parent + if db_dir and not db_dir.exists(): + db_dir.mkdir(parents=True, exist_ok=True) + + conn = sqlite3.connect(self.db_path) + try: + conn.executescript(QUALITY_TABLES_SQL) + conn.commit() + finally: + conn.close() + + def _get_connection(self) -> sqlite3.Connection: + """Get a database connection.""" + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row + return conn + + def add_gate(self, gate: QualityGate) -> None: + """ + Add a quality gate to the validator. + + Args: + gate: Quality gate to add + """ + self.gates.append(gate) + + def validate_artifact( + self, + content: str, + artifact_id: str, + run_id: Optional[str] = None, + ) -> List[ValidationResult]: + """ + Apply all quality gates to an artifact. + + Args: + content: Artifact content to validate + artifact_id: ID of the artifact + run_id: Optional run ID for persistence + + Returns: + List of ValidationResult from all gates + """ + results = [] + for gate in self.gates: + result = gate.validate(content, artifact_id) + results.append(result) + + if run_id and self.db_path: + self._persist_result(result, run_id) + + return results + + def all_passed(self, results: List[ValidationResult]) -> bool: + """ + Check if all validation results passed. + + Args: + results: List of validation results + + Returns: + True if all results have PASS status + """ + return all(r.status == ValidationStatus.PASS for r in results) + + def aggregate_score(self, results: List[ValidationResult]) -> float: + """ + Calculate aggregate score across all results. + + Args: + results: List of validation results + + Returns: + Average score (0.0-1.0), or 1.0 if no results + """ + if not results: + return 1.0 + scores = [r.score for r in results if r.score is not None] + if not scores: + return 1.0 + return sum(scores) / len(scores) + + def get_failed_gates( + self, + results: List[ValidationResult], + ) -> List[ValidationResult]: + """ + Get only failed validation results. + + Args: + results: List of validation results + + Returns: + List of results with FAIL status + """ + return [r for r in results if r.status == ValidationStatus.FAIL] + + def results_to_manifest_dict( + self, + results: List[ValidationResult], + ) -> Dict[str, Any]: + """ + Convert validation results to RunManifest-compatible dict. + + Implements FR-9.3: Results for RunManifest. + + Args: + results: List of validation results + + Returns: + Dictionary suitable for RunManifest.validation_results + """ + return { + "quality_gates": [r.to_dict() for r in results], + "all_passed": self.all_passed(results), + "aggregate_score": self.aggregate_score(results), + } + + def _persist_result( + self, + result: ValidationResult, + run_id: str, + ) -> None: + """Persist a validation result to the database.""" + import json + + conn = self._get_connection() + try: + conn.execute( + """ + INSERT INTO validation_results ( + id, run_id, gate_id, artifact_id, + status, score, diagnostics, validated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + result.id, + run_id, + result.gate_id, + result.artifact_id, + result.status.value, + result.score, + json.dumps([d.to_dict() for d in result.diagnostics]), + result.validated_at.isoformat(), + ), + ) + conn.commit() + finally: + conn.close() + + def get_results_for_run(self, run_id: str) -> List[Dict[str, Any]]: + """ + Retrieve persisted validation results for a run. + + Args: + run_id: Run identifier + + Returns: + List of result dictionaries + """ + if not self.db_path: + return [] + + import json + + conn = self._get_connection() + try: + cursor = conn.execute( + "SELECT * FROM validation_results WHERE run_id = ?", + (run_id,), + ) + results = [] + for row in cursor.fetchall(): + results.append({ + "id": row["id"], + "run_id": run_id, + "gate_id": row["gate_id"], + "artifact_id": row["artifact_id"], + "status": row["status"], + "score": row["score"], + "diagnostics": json.loads(row["diagnostics"]) if row["diagnostics"] else [], + "validated_at": row["validated_at"], + }) + return results + finally: + conn.close() + + def get_results_for_artifact(self, artifact_id: str) -> List[Dict[str, Any]]: + """ + Retrieve persisted validation results for an artifact. + + Args: + artifact_id: Artifact identifier + + Returns: + List of result dictionaries + """ + if not self.db_path: + return [] + + import json + + conn = self._get_connection() + try: + cursor = conn.execute( + "SELECT * FROM validation_results WHERE artifact_id = ?", + (artifact_id,), + ) + results = [] + for row in cursor.fetchall(): + results.append({ + "id": row["id"], + "run_id": row["run_id"], + "gate_id": row["gate_id"], + "artifact_id": artifact_id, + "status": row["status"], + "score": row["score"], + "diagnostics": json.loads(row["diagnostics"]) if row["diagnostics"] else [], + "validated_at": row["validated_at"], + }) + return results + finally: + conn.close() diff --git a/migrations/prompts/006_create_quality_tables.sql b/migrations/prompts/006_create_quality_tables.sql new file mode 100644 index 00000000..fd49b31a --- /dev/null +++ b/migrations/prompts/006_create_quality_tables.sql @@ -0,0 +1,25 @@ +-- Phase 7: Quality & Validation tables +-- quality_gates: registered quality gate configurations +-- validation_results: recorded pass/fail results from gate checks + +CREATE TABLE IF NOT EXISTS quality_gates ( + id TEXT PRIMARY KEY, + name TEXT NOT NULL, + gate_type TEXT NOT NULL, + config JSON NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS validation_results ( + id TEXT PRIMARY KEY, + run_id TEXT NOT NULL, + gate_id TEXT NOT NULL, + artifact_id TEXT, + status TEXT NOT NULL, + score REAL, + diagnostics JSON, + validated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE INDEX IF NOT EXISTS idx_validations_run ON validation_results(run_id); +CREATE INDEX IF NOT EXISTS idx_validations_artifact ON validation_results(artifact_id); diff --git a/tests/integration/prompts/test_halting_execution.py b/tests/integration/prompts/test_halting_execution.py new file mode 100644 index 00000000..4d2adf99 --- /dev/null +++ b/tests/integration/prompts/test_halting_execution.py @@ -0,0 +1,239 @@ +""" +Integration tests for halting execution with refinement loop. + +Tests the full execute → validate → halt or refine cycle with +real quality gates and persistence. +""" + +import json +import pytest +import tempfile +from pathlib import Path + +from markitect.prompts.models import Artifact, ArtifactType +from markitect.prompts.repositories.sqlite import SQLiteArtifactRepository +from markitect.prompts.quality.models import ( + HaltDecision, + QualityPolicy, + ValidationStatus, +) +from markitect.prompts.quality.gates.pattern_gate import PatternValidationGate +from markitect.prompts.quality.gates.schema_gate import SchemaValidationGate +from markitect.prompts.quality.validator import QualityValidator +from markitect.prompts.quality.refinement import RefinementLoop + + +@pytest.fixture +def temp_db(): + """Create temporary database for testing.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db_path = f.name + yield db_path + Path(db_path).unlink(missing_ok=True) + + +@pytest.fixture +def artifact_repo(temp_db): + """Create artifact repository.""" + return SQLiteArtifactRepository(temp_db) + + +class TestImmediateQualityMet: + """Tests where quality is met on the first iteration.""" + + def test_single_iteration_success(self, temp_db): + """Test refinement completes in one iteration when quality is met.""" + gate = PatternValidationGate( + required_patterns=[r"## Summary", r"## Conclusion"], + gate_id="gate-1", + ) + validator = QualityValidator(gates=[gate], db_path=temp_db) + policy = QualityPolicy(max_iterations=5) + loop = RefinementLoop(validator, policy) + + def execute(iteration, prev_results): + return ( + f"run-{iteration}", + "## Summary\nOverview.\n## Conclusion\nDone.", + "art-1", + ) + + result = loop.run(execute, "art-1") + + assert result.iterations_run == 1 + assert result.halting_record.decision == HaltDecision.HALT_QUALITY_MET + assert len(result.final_results) == 1 + assert result.final_results[0].status == ValidationStatus.PASS + + # Verify results persisted + persisted = validator.get_results_for_run("run-1") + assert len(persisted) == 1 + + +class TestIterativeRefinement: + """Tests for iterative refinement improving quality.""" + + def test_progressive_improvement(self, temp_db): + """Test refinement improves content over iterations.""" + gate = PatternValidationGate( + required_patterns=[r"## Summary", r"## Details", r"## Conclusion"], + gate_id="gate-1", + ) + validator = QualityValidator(gates=[gate], db_path=temp_db) + policy = QualityPolicy(max_iterations=5) + loop = RefinementLoop(validator, policy) + + versions = [ + "## Summary\nBasic.", # iter 1: missing 2 patterns + "## Summary\n## Details\nBetter.", # iter 2: missing 1 pattern + "## Summary\n## Details\n## Conclusion\nComplete.", # iter 3: all pass + ] + + def execute(iteration, prev_results): + content = versions[min(iteration - 1, len(versions) - 1)] + return (f"run-{iteration}", content, "art-1") + + result = loop.run(execute, "art-1") + + assert result.iterations_run == 3 + assert result.halting_record.decision == HaltDecision.HALT_QUALITY_MET + assert len(result.all_results) == 3 + + # Verify all iterations persisted + for i in range(1, 4): + persisted = validator.get_results_for_run(f"run-{i}") + assert len(persisted) == 1 + + +class TestIterationLimit: + """Tests for hitting iteration limits.""" + + def test_never_meets_quality(self, temp_db): + """Test refinement stops at iteration limit when quality never met.""" + gate = PatternValidationGate( + required_patterns=[r"NEVER_MATCHES_XYZ123"], + gate_id="gate-1", + ) + validator = QualityValidator(gates=[gate], db_path=temp_db) + policy = QualityPolicy(max_iterations=3, min_improvement=0.0) + loop = RefinementLoop(validator, policy) + + def execute(iteration, prev_results): + return (f"run-{iteration}", "always insufficient", "art-1") + + result = loop.run(execute, "art-1") + + assert result.iterations_run == 3 + assert result.halting_record.decision == HaltDecision.HALT_ITERATION_LIMIT + assert len(result.run_ids) == 3 + + +class TestBudgetExhaustion: + """Tests for resource budget exhaustion.""" + + def test_budget_limits_iterations(self, temp_db): + """Test budget exhaustion stops refinement.""" + gate = PatternValidationGate( + required_patterns=[r"UNREACHABLE"], + gate_id="gate-1", + ) + validator = QualityValidator(gates=[gate], db_path=temp_db) + policy = QualityPolicy(max_iterations=10, resource_budget=2) + loop = RefinementLoop(validator, policy) + + def execute(iteration, prev_results): + return (f"run-{iteration}", "content", "art-1") + + result = loop.run(execute, "art-1") + + assert result.iterations_run == 2 + assert result.halting_record.decision == HaltDecision.HALT_BUDGET_EXHAUSTED + + +class TestMultiGateRefinement: + """Tests for refinement with multiple quality gates.""" + + def test_all_gates_must_pass(self, temp_db): + """Test refinement continues until all gates pass.""" + gate_a = PatternValidationGate( + required_patterns=[r"## Summary"], + gate_id="gate-a", + ) + gate_b = PatternValidationGate( + forbidden_patterns=[r"TODO"], + gate_id="gate-b", + ) + validator = QualityValidator(gates=[gate_a, gate_b], db_path=temp_db) + policy = QualityPolicy(max_iterations=5) + loop = RefinementLoop(validator, policy) + + versions = [ + "## Summary\nTODO: finish this", # gate-a pass, gate-b fail + "## Summary\nAll clean content.", # both pass + ] + + def execute(iteration, prev_results): + content = versions[min(iteration - 1, len(versions) - 1)] + return (f"run-{iteration}", content, "art-1") + + result = loop.run(execute, "art-1") + + assert result.iterations_run == 2 + assert result.halting_record.decision == HaltDecision.HALT_QUALITY_MET + + +class TestRefinementWithSchemaGate: + """Tests for refinement with schema validation gates.""" + + def test_json_refinement(self, temp_db): + """Test refining JSON content to pass schema validation.""" + schema = { + "type": "object", + "required": ["title", "version", "sections"], + "properties": { + "title": {"type": "string"}, + "version": {"type": "integer"}, + "sections": {"type": "array"}, + }, + } + gate = SchemaValidationGate(schema=schema, gate_id="schema-1") + validator = QualityValidator(gates=[gate], db_path=temp_db) + policy = QualityPolicy(max_iterations=5) + loop = RefinementLoop(validator, policy) + + versions = [ + json.dumps({"title": "Doc"}), # missing version & sections + json.dumps({"title": "Doc", "version": 1}), # missing sections + json.dumps({"title": "Doc", "version": 1, "sections": []}), # complete + ] + + def execute(iteration, prev_results): + content = versions[min(iteration - 1, len(versions) - 1)] + return (f"run-{iteration}", content, "art-1") + + result = loop.run(execute, "art-1") + + assert result.iterations_run == 3 + assert result.halting_record.decision == HaltDecision.HALT_QUALITY_MET + + +class TestResultSerialization: + """Tests for RefinementResult serialization.""" + + def test_result_to_dict(self, temp_db): + """Test RefinementResult can be serialized.""" + gate = PatternValidationGate(gate_id="gate-1") + validator = QualityValidator(gates=[gate], db_path=temp_db) + policy = QualityPolicy(max_iterations=1) + loop = RefinementLoop(validator, policy) + + def execute(iteration, prev_results): + return ("run-1", "content", "art-1") + + result = loop.run(execute, "art-1") + d = result.to_dict() + + assert isinstance(d, dict) + assert "iterations_run" in d + assert "halting_record" in d + assert "run_ids" in d diff --git a/tests/integration/prompts/test_quality_validation.py b/tests/integration/prompts/test_quality_validation.py new file mode 100644 index 00000000..db0cd750 --- /dev/null +++ b/tests/integration/prompts/test_quality_validation.py @@ -0,0 +1,208 @@ +""" +Integration tests for full quality validation workflow. + +Tests applying quality gates to artifacts with real DB persistence, +manifest integration, and multi-gate validation. +""" + +import json +import pytest +import tempfile +from pathlib import Path + +from markitect.prompts.models import Artifact, ArtifactType +from markitect.prompts.repositories.sqlite import SQLiteArtifactRepository +from markitect.prompts.quality.models import ( + GateType, + ValidationStatus, +) +from markitect.prompts.quality.gates.schema_gate import SchemaValidationGate +from markitect.prompts.quality.gates.pattern_gate import PatternValidationGate +from markitect.prompts.quality.validator import QualityValidator + + +@pytest.fixture +def temp_db(): + """Create temporary database for testing.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db_path = f.name + yield db_path + Path(db_path).unlink(missing_ok=True) + + +@pytest.fixture +def artifact_repo(temp_db): + """Create artifact repository.""" + return SQLiteArtifactRepository(temp_db) + + +def _create_artifact(repo, name, content, art_type=ArtifactType.GENERATED): + """Helper to create and persist an artifact.""" + artifact = Artifact.create( + space_id="space-1", + name=name, + content=content, + artifact_type=art_type, + ) + return repo.create(artifact) + + +class TestSchemaValidationWorkflow: + """Full schema validation workflow with real DB.""" + + def test_validate_json_artifact_passes(self, temp_db, artifact_repo): + """Test validating a valid JSON artifact.""" + content = json.dumps({ + "name": "API Spec", + "version": 1, + "endpoints": ["/users", "/auth"], + }) + artifact = _create_artifact(artifact_repo, "api-spec", content) + + schema = { + "type": "object", + "required": ["name", "version", "endpoints"], + "properties": { + "name": {"type": "string"}, + "version": {"type": "integer"}, + "endpoints": {"type": "array", "items": {"type": "string"}}, + }, + } + gate = SchemaValidationGate(schema=schema, gate_id="schema-api") + validator = QualityValidator(gates=[gate], db_path=temp_db) + + results = validator.validate_artifact( + content, artifact.id, run_id="run-1" + ) + + assert len(results) == 1 + assert results[0].status == ValidationStatus.PASS + + # Verify persisted + persisted = validator.get_results_for_run("run-1") + assert len(persisted) == 1 + assert persisted[0]["status"] == "pass" + + def test_validate_json_artifact_fails(self, temp_db, artifact_repo): + """Test validating an invalid JSON artifact.""" + content = json.dumps({"name": "Incomplete"}) + artifact = _create_artifact(artifact_repo, "bad-spec", content) + + schema = { + "type": "object", + "required": ["name", "version"], + } + gate = SchemaValidationGate(schema=schema, gate_id="schema-strict") + validator = QualityValidator(gates=[gate], db_path=temp_db) + + results = validator.validate_artifact( + content, artifact.id, run_id="run-2" + ) + + assert results[0].status == ValidationStatus.FAIL + assert len(results[0].diagnostics) > 0 + + persisted = validator.get_results_for_run("run-2") + assert persisted[0]["status"] == "fail" + + +class TestPatternValidationWorkflow: + """Full pattern validation workflow with real DB.""" + + def test_validate_markdown_artifact(self, temp_db, artifact_repo): + """Test validating a markdown artifact against patterns.""" + content = "# API Documentation\n## Endpoints\n### Authentication\nOAuth2 flow." + artifact = _create_artifact(artifact_repo, "api-docs", content) + + gate = PatternValidationGate( + required_patterns=[r"## Endpoints", r"### Authentication"], + forbidden_patterns=[r"TODO", r"FIXME"], + gate_id="pattern-api", + ) + validator = QualityValidator(gates=[gate], db_path=temp_db) + + results = validator.validate_artifact( + content, artifact.id, run_id="run-3" + ) + + assert results[0].status == ValidationStatus.PASS + + def test_forbidden_pattern_detected(self, temp_db, artifact_repo): + """Test that forbidden patterns are caught.""" + content = "# Draft\n## Endpoints\nTODO: Add authentication." + artifact = _create_artifact(artifact_repo, "draft-docs", content) + + gate = PatternValidationGate( + required_patterns=[r"## Endpoints"], + forbidden_patterns=[r"TODO"], + gate_id="pattern-clean", + ) + validator = QualityValidator(gates=[gate], db_path=temp_db) + + results = validator.validate_artifact( + content, artifact.id, run_id="run-4" + ) + + assert results[0].status == ValidationStatus.FAIL + + +class TestMultiGateWorkflow: + """Tests applying multiple gates in a single validation.""" + + def test_multi_gate_validation(self, temp_db, artifact_repo): + """Test applying schema + pattern gates to an artifact.""" + content = json.dumps({ + "title": "Design Doc", + "sections": ["## Overview", "## Details"], + }) + artifact = _create_artifact(artifact_repo, "design-doc", content) + + schema_gate = SchemaValidationGate( + schema={ + "type": "object", + "required": ["title", "sections"], + }, + gate_id="schema-doc", + ) + pattern_gate = PatternValidationGate( + forbidden_patterns=[r"FIXME"], + gate_id="pattern-clean", + ) + validator = QualityValidator( + gates=[schema_gate, pattern_gate], + db_path=temp_db, + ) + + results = validator.validate_artifact( + content, artifact.id, run_id="run-5" + ) + + assert len(results) == 2 + assert all(r.status == ValidationStatus.PASS for r in results) + + # Check manifest dict + manifest = validator.results_to_manifest_dict(results) + assert manifest["all_passed"] is True + assert manifest["aggregate_score"] == 1.0 + + # Verify all persisted + persisted = validator.get_results_for_run("run-5") + assert len(persisted) == 2 + + def test_retrieve_by_artifact(self, temp_db, artifact_repo): + """Test retrieving results by artifact across multiple runs.""" + content = json.dumps({"name": "test"}) + artifact = _create_artifact(artifact_repo, "test-art", content) + + gate = SchemaValidationGate( + schema={"type": "object", "required": ["name"]}, + gate_id="schema-1", + ) + validator = QualityValidator(gates=[gate], db_path=temp_db) + + # Validate across two runs + validator.validate_artifact(content, artifact.id, run_id="run-a") + validator.validate_artifact(content, artifact.id, run_id="run-b") + + results = validator.get_results_for_artifact(artifact.id) + assert len(results) == 2 diff --git a/tests/unit/prompts/test_halting_policy.py b/tests/unit/prompts/test_halting_policy.py new file mode 100644 index 00000000..b9806468 --- /dev/null +++ b/tests/unit/prompts/test_halting_policy.py @@ -0,0 +1,221 @@ +""" +Unit tests for HaltingPolicyEngine. + +Tests halting decisions based on quality results, iteration limits, +marginal improvement, and resource budgets. +""" + +import pytest + +from markitect.prompts.quality.models import ( + GateType, + HaltDecision, + QualityPolicy, + ValidationResult, + ValidationStatus, +) +from markitect.prompts.quality.policy import HaltingPolicyEngine + + +def _make_result(status=ValidationStatus.PASS, score=1.0, gate_id="gate-1"): + """Helper to create a ValidationResult.""" + return ValidationResult.create( + gate_id=gate_id, + gate_type=GateType.PATTERN, + artifact_id="art-1", + status=status, + score=score, + ) + + +class TestQualityMetDecision: + """Tests for quality met halting.""" + + def test_all_pass_halts_quality_met(self): + """Test all gates passing triggers quality met halt.""" + policy = QualityPolicy(max_iterations=5) + engine = HaltingPolicyEngine(policy) + + results = [_make_result(ValidationStatus.PASS, 1.0)] + record = engine.evaluate(results, iteration=1) + + assert record.decision == HaltDecision.HALT_QUALITY_MET + assert "quality gates passed" in record.reason.lower() + + def test_required_gates_all_pass(self): + """Test required gates all passing triggers quality met.""" + policy = QualityPolicy( + max_iterations=5, + required_gate_ids=["required-gate"], + ) + engine = HaltingPolicyEngine(policy) + + results = [ + _make_result(ValidationStatus.PASS, 1.0, gate_id="required-gate"), + _make_result(ValidationStatus.FAIL, 0.5, gate_id="optional-gate"), + ] + record = engine.evaluate(results, iteration=1) + + assert record.decision == HaltDecision.HALT_QUALITY_MET + + def test_required_gate_fails_continues(self): + """Test required gate failing allows continuation.""" + policy = QualityPolicy( + max_iterations=5, + required_gate_ids=["required-gate"], + ) + engine = HaltingPolicyEngine(policy) + + results = [ + _make_result(ValidationStatus.FAIL, 0.5, gate_id="required-gate"), + ] + record = engine.evaluate(results, iteration=1) + + assert record.decision == HaltDecision.CONTINUE + + +class TestIterationLimitDecision: + """Tests for iteration limit halting.""" + + def test_at_iteration_limit(self): + """Test halting at iteration limit.""" + policy = QualityPolicy(max_iterations=3) + engine = HaltingPolicyEngine(policy) + + results = [_make_result(ValidationStatus.FAIL, 0.5)] + record = engine.evaluate(results, iteration=3) + + assert record.decision == HaltDecision.HALT_ITERATION_LIMIT + assert record.iteration == 3 + + def test_before_iteration_limit(self): + """Test not halting before iteration limit.""" + policy = QualityPolicy(max_iterations=5) + engine = HaltingPolicyEngine(policy) + + results = [_make_result(ValidationStatus.FAIL, 0.5)] + record = engine.evaluate(results, iteration=2) + + assert record.decision == HaltDecision.CONTINUE + + +class TestBudgetExhaustedDecision: + """Tests for resource budget exhaustion.""" + + def test_budget_exhausted(self): + """Test halting when budget is exhausted.""" + policy = QualityPolicy(max_iterations=10, resource_budget=5) + engine = HaltingPolicyEngine(policy) + + results = [_make_result(ValidationStatus.FAIL, 0.5)] + record = engine.evaluate(results, iteration=1, total_runs=5) + + assert record.decision == HaltDecision.HALT_BUDGET_EXHAUSTED + + def test_budget_not_exhausted(self): + """Test not halting when budget remains.""" + policy = QualityPolicy(max_iterations=10, resource_budget=10) + engine = HaltingPolicyEngine(policy) + + results = [_make_result(ValidationStatus.FAIL, 0.5)] + record = engine.evaluate(results, iteration=1, total_runs=3) + + assert record.decision == HaltDecision.CONTINUE + + +class TestMarginalImprovementDecision: + """Tests for marginal improvement halting.""" + + def test_no_improvement_halts(self): + """Test halting when improvement is below threshold.""" + policy = QualityPolicy(max_iterations=10, min_improvement=0.05) + engine = HaltingPolicyEngine(policy) + + results = [_make_result(ValidationStatus.FAIL, 0.52)] + record = engine.evaluate( + results, + iteration=2, + score_history=[0.50], # improvement: 0.02 < 0.05 + ) + + assert record.decision == HaltDecision.HALT_NO_IMPROVEMENT + + def test_sufficient_improvement_continues(self): + """Test continuing when improvement meets threshold.""" + policy = QualityPolicy(max_iterations=10, min_improvement=0.05) + engine = HaltingPolicyEngine(policy) + + results = [_make_result(ValidationStatus.FAIL, 0.60)] + record = engine.evaluate( + results, + iteration=2, + score_history=[0.50], # improvement: 0.10 >= 0.05 + ) + + assert record.decision == HaltDecision.CONTINUE + + def test_first_iteration_no_history(self): + """Test first iteration with no history continues.""" + policy = QualityPolicy(max_iterations=10, min_improvement=0.05) + engine = HaltingPolicyEngine(policy) + + results = [_make_result(ValidationStatus.FAIL, 0.50)] + record = engine.evaluate(results, iteration=1) + + assert record.decision == HaltDecision.CONTINUE + + +class TestPriorityOrder: + """Tests for the priority order of halting checks.""" + + def test_budget_checked_before_iteration(self): + """Test budget exhaustion takes priority over iteration limit.""" + policy = QualityPolicy(max_iterations=3, resource_budget=2) + engine = HaltingPolicyEngine(policy) + + results = [_make_result(ValidationStatus.FAIL, 0.5)] + record = engine.evaluate(results, iteration=3, total_runs=2) + + assert record.decision == HaltDecision.HALT_BUDGET_EXHAUSTED + + def test_iteration_checked_before_quality(self): + """Test iteration limit checked before quality met.""" + policy = QualityPolicy(max_iterations=2) + engine = HaltingPolicyEngine(policy) + + results = [_make_result(ValidationStatus.PASS, 1.0)] + # At iteration limit AND quality met — iteration limit wins + # Actually quality is checked after iteration, so quality would be checked + # But iteration 2 >= max 2 triggers first + record = engine.evaluate(results, iteration=2) + + assert record.decision == HaltDecision.HALT_ITERATION_LIMIT + + +class TestHaltingRecord: + """Tests for HaltingRecord attributes.""" + + def test_record_has_scores(self): + """Test halting record includes score history.""" + policy = QualityPolicy(max_iterations=5) + engine = HaltingPolicyEngine(policy) + + results = [_make_result(ValidationStatus.PASS, 0.9)] + record = engine.evaluate( + results, iteration=2, score_history=[0.5] + ) + + assert record.scores == [0.5, 0.9] + + def test_record_to_dict(self): + """Test halting record serialization.""" + policy = QualityPolicy(max_iterations=3) + engine = HaltingPolicyEngine(policy) + + results = [_make_result(ValidationStatus.PASS, 1.0)] + record = engine.evaluate(results, iteration=1) + + d = record.to_dict() + assert d["decision"] == "halted_quality_met" + assert d["iteration"] == 1 + assert d["max_iterations"] == 3 diff --git a/tests/unit/prompts/test_quality_gates.py b/tests/unit/prompts/test_quality_gates.py new file mode 100644 index 00000000..bcb8b033 --- /dev/null +++ b/tests/unit/prompts/test_quality_gates.py @@ -0,0 +1,303 @@ +""" +Unit tests for quality gate models and individual gate implementations. + +Tests QualityGate models, SchemaValidationGate, and PatternValidationGate. +""" + +import json +import pytest + +from markitect.prompts.quality.models import ( + GateType, + ValidationStatus, + ValidationDiagnostic, + ValidationResult, + QualityPolicy, + HaltDecision, + HaltingRecord, +) +from markitect.prompts.quality.gates.schema_gate import SchemaValidationGate +from markitect.prompts.quality.gates.pattern_gate import PatternValidationGate + + +class TestValidationDiagnostic: + """Tests for ValidationDiagnostic.""" + + def test_create_diagnostic(self): + """Test creating a diagnostic.""" + diag = ValidationDiagnostic( + code="TEST_ERROR", + message="Something went wrong", + severity="error", + ) + assert diag.code == "TEST_ERROR" + assert diag.severity == "error" + + def test_to_dict(self): + """Test diagnostic serialization.""" + diag = ValidationDiagnostic(code="E1", message="msg", severity="warning") + d = diag.to_dict() + assert d["code"] == "E1" + assert d["severity"] == "warning" + + def test_from_dict(self): + """Test diagnostic deserialization.""" + data = {"code": "E1", "message": "msg", "severity": "info"} + diag = ValidationDiagnostic.from_dict(data) + assert diag.code == "E1" + assert diag.severity == "info" + + +class TestValidationResult: + """Tests for ValidationResult.""" + + def test_create_result(self): + """Test creating a validation result.""" + result = ValidationResult.create( + gate_id="gate-1", + gate_type=GateType.SCHEMA, + artifact_id="art-1", + status=ValidationStatus.PASS, + score=1.0, + ) + assert result.gate_id == "gate-1" + assert result.status == ValidationStatus.PASS + assert result.score == 1.0 + + def test_result_unique_ids(self): + """Test that each result gets a unique ID.""" + r1 = ValidationResult.create( + gate_id="g", gate_type=GateType.PATTERN, + artifact_id="a", status=ValidationStatus.PASS, + ) + r2 = ValidationResult.create( + gate_id="g", gate_type=GateType.PATTERN, + artifact_id="a", status=ValidationStatus.PASS, + ) + assert r1.id != r2.id + + def test_to_dict_from_dict(self): + """Test round-trip serialization.""" + result = ValidationResult.create( + gate_id="g1", + gate_type=GateType.SCHEMA, + artifact_id="a1", + status=ValidationStatus.FAIL, + score=0.5, + diagnostics=[ + ValidationDiagnostic(code="E1", message="err", severity="error"), + ], + ) + d = result.to_dict() + restored = ValidationResult.from_dict(d) + assert restored.id == result.id + assert restored.status == ValidationStatus.FAIL + assert restored.score == 0.5 + assert len(restored.diagnostics) == 1 + + +class TestQualityPolicy: + """Tests for QualityPolicy.""" + + def test_default_policy(self): + """Test default policy values.""" + policy = QualityPolicy() + assert policy.max_iterations == 3 + assert policy.min_improvement == 0.05 + assert policy.fail_on_gate_failure is True + assert policy.resource_budget == 10 + + def test_to_dict_from_dict(self): + """Test round-trip serialization.""" + policy = QualityPolicy( + max_iterations=5, + min_improvement=0.1, + required_gate_ids=["g1", "g2"], + ) + d = policy.to_dict() + restored = QualityPolicy.from_dict(d) + assert restored.max_iterations == 5 + assert restored.min_improvement == 0.1 + assert restored.required_gate_ids == ["g1", "g2"] + + +class TestSchemaValidationGate: + """Tests for SchemaValidationGate.""" + + def test_valid_json_passes(self): + """Test valid JSON against schema passes.""" + schema = { + "type": "object", + "required": ["name", "version"], + "properties": { + "name": {"type": "string"}, + "version": {"type": "integer"}, + }, + } + gate = SchemaValidationGate(schema=schema, name="test-schema") + content = json.dumps({"name": "test", "version": 1}) + + result = gate.validate(content, "art-1") + assert result.status == ValidationStatus.PASS + assert result.score == 1.0 + assert len(result.diagnostics) == 0 + + def test_missing_required_field_fails(self): + """Test missing required field fails validation.""" + schema = { + "type": "object", + "required": ["name", "version"], + "properties": { + "name": {"type": "string"}, + "version": {"type": "integer"}, + }, + } + gate = SchemaValidationGate(schema=schema) + content = json.dumps({"name": "test"}) + + result = gate.validate(content, "art-1") + assert result.status == ValidationStatus.FAIL + assert any(d.code == "SCHEMA_VIOLATION" for d in result.diagnostics) + + def test_wrong_type_fails(self): + """Test wrong type fails validation.""" + schema = { + "type": "object", + "properties": { + "count": {"type": "integer"}, + }, + } + gate = SchemaValidationGate(schema=schema) + content = json.dumps({"count": "not-a-number"}) + + result = gate.validate(content, "art-1") + assert result.status == ValidationStatus.FAIL + + def test_invalid_json_fails(self): + """Test invalid JSON content fails.""" + schema = {"type": "object"} + gate = SchemaValidationGate(schema=schema) + + result = gate.validate("not json {{{", "art-1") + assert result.status == ValidationStatus.FAIL + assert result.score == 0.0 + assert any(d.code == "INVALID_JSON" for d in result.diagnostics) + + def test_gate_has_correct_type(self): + """Test gate type is SCHEMA.""" + gate = SchemaValidationGate(schema={"type": "object"}) + assert gate.gate_type == GateType.SCHEMA + + def test_empty_schema_passes_any_object(self): + """Test empty schema passes any valid JSON.""" + gate = SchemaValidationGate(schema={}) + result = gate.validate(json.dumps({"any": "thing"}), "art-1") + assert result.status == ValidationStatus.PASS + + def test_score_reflects_error_count(self): + """Test that score decreases with more errors.""" + schema = { + "type": "object", + "required": ["a", "b", "c"], + "properties": { + "a": {"type": "string"}, + "b": {"type": "string"}, + "c": {"type": "string"}, + }, + } + gate = SchemaValidationGate(schema=schema) + + # Missing all 3 required fields + result = gate.validate(json.dumps({}), "art-1") + assert result.status == ValidationStatus.FAIL + assert result.score < 1.0 + + +class TestPatternValidationGate: + """Tests for PatternValidationGate.""" + + def test_required_pattern_present(self): + """Test content with required patterns passes.""" + gate = PatternValidationGate( + required_patterns=[r"## Endpoints", r"### Authentication"], + ) + content = "# API\n## Endpoints\n### Authentication\nDetails here." + + result = gate.validate(content, "art-1") + assert result.status == ValidationStatus.PASS + assert result.score == 1.0 + + def test_required_pattern_missing(self): + """Test missing required pattern fails.""" + gate = PatternValidationGate( + required_patterns=[r"## Endpoints", r"### Authentication"], + ) + content = "# API\n## Endpoints\nNo auth section." + + result = gate.validate(content, "art-1") + assert result.status == ValidationStatus.FAIL + assert any(d.code == "MISSING_PATTERN" for d in result.diagnostics) + + def test_forbidden_pattern_absent(self): + """Test content without forbidden patterns passes.""" + gate = PatternValidationGate( + forbidden_patterns=[r"TODO", r"FIXME"], + ) + content = "This is a clean document." + + result = gate.validate(content, "art-1") + assert result.status == ValidationStatus.PASS + + def test_forbidden_pattern_present(self): + """Test content with forbidden pattern fails.""" + gate = PatternValidationGate( + forbidden_patterns=[r"TODO", r"FIXME"], + ) + content = "This needs work. TODO: fix this." + + result = gate.validate(content, "art-1") + assert result.status == ValidationStatus.FAIL + assert any(d.code == "FORBIDDEN_PATTERN" for d in result.diagnostics) + + def test_combined_required_and_forbidden(self): + """Test both required and forbidden patterns together.""" + gate = PatternValidationGate( + required_patterns=[r"## Summary"], + forbidden_patterns=[r"FIXME"], + ) + # Has required, no forbidden + result1 = gate.validate("## Summary\nAll good.", "art-1") + assert result1.status == ValidationStatus.PASS + + # Has required AND forbidden + result2 = gate.validate("## Summary\nFIXME: broken.", "art-1") + assert result2.status == ValidationStatus.FAIL + + def test_no_patterns_passes(self): + """Test gate with no patterns always passes.""" + gate = PatternValidationGate() + result = gate.validate("anything", "art-1") + assert result.status == ValidationStatus.PASS + + def test_gate_has_correct_type(self): + """Test gate type is PATTERN.""" + gate = PatternValidationGate() + assert gate.gate_type == GateType.PATTERN + + def test_score_proportional_to_failures(self): + """Test score is proportional to number of checks passed.""" + gate = PatternValidationGate( + required_patterns=[r"A", r"B", r"C", r"D"], + ) + # Only A is present → 3 out of 4 fail + result = gate.validate("A", "art-1") + assert result.status == ValidationStatus.FAIL + assert 0.0 < result.score < 1.0 + + def test_regex_pattern_matching(self): + """Test regex patterns work correctly.""" + gate = PatternValidationGate( + required_patterns=[r"\d{3}-\d{4}"], + ) + result = gate.validate("Call 555-1234 for info", "art-1") + assert result.status == ValidationStatus.PASS diff --git a/tests/unit/prompts/test_quality_validator.py b/tests/unit/prompts/test_quality_validator.py new file mode 100644 index 00000000..c09bbe58 --- /dev/null +++ b/tests/unit/prompts/test_quality_validator.py @@ -0,0 +1,264 @@ +""" +Unit tests for QualityValidator. + +Tests applying multiple gates, aggregating results, and persistence. +""" + +import json +import pytest +import tempfile +from pathlib import Path + +from markitect.prompts.quality.models import ( + GateType, + ValidationStatus, + ValidationResult, +) +from markitect.prompts.quality.gates.schema_gate import SchemaValidationGate +from markitect.prompts.quality.gates.pattern_gate import PatternValidationGate +from markitect.prompts.quality.validator import QualityValidator + + +@pytest.fixture +def temp_db(): + """Create temporary database for testing.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f: + db_path = f.name + yield db_path + Path(db_path).unlink(missing_ok=True) + + +@pytest.fixture +def schema_gate(): + """Create a simple schema gate.""" + return SchemaValidationGate( + schema={ + "type": "object", + "required": ["name"], + "properties": {"name": {"type": "string"}}, + }, + gate_id="schema-gate-1", + name="test-schema", + ) + + +@pytest.fixture +def pattern_gate(): + """Create a simple pattern gate.""" + return PatternValidationGate( + required_patterns=[r"## Summary"], + forbidden_patterns=[r"TODO"], + gate_id="pattern-gate-1", + name="test-pattern", + ) + + +class TestValidateArtifact: + """Tests for validating artifacts with multiple gates.""" + + def test_all_gates_pass(self, schema_gate, pattern_gate): + """Test all gates passing.""" + validator = QualityValidator(gates=[schema_gate, pattern_gate]) + + # Content that satisfies both gates (JSON for schema, text for pattern) + # Schema gate needs JSON, pattern gate needs text patterns + # Use separate validators for different content types + schema_validator = QualityValidator(gates=[schema_gate]) + results = schema_validator.validate_artifact( + json.dumps({"name": "test"}), "art-1" + ) + assert len(results) == 1 + assert results[0].status == ValidationStatus.PASS + + def test_pattern_gate_validates(self, pattern_gate): + """Test pattern gate validation.""" + validator = QualityValidator(gates=[pattern_gate]) + results = validator.validate_artifact( + "## Summary\nAll good here.", "art-1" + ) + assert len(results) == 1 + assert results[0].status == ValidationStatus.PASS + + def test_multiple_gates_mixed_results(self, pattern_gate): + """Test multiple gates with mixed pass/fail.""" + gate_a = PatternValidationGate( + required_patterns=[r"## Summary"], + gate_id="gate-a", + ) + gate_b = PatternValidationGate( + required_patterns=[r"## Missing Section"], + gate_id="gate-b", + ) + validator = QualityValidator(gates=[gate_a, gate_b]) + + results = validator.validate_artifact("## Summary\nContent.", "art-1") + assert len(results) == 2 + statuses = {r.gate_id: r.status for r in results} + assert statuses["gate-a"] == ValidationStatus.PASS + assert statuses["gate-b"] == ValidationStatus.FAIL + + def test_no_gates_returns_empty(self): + """Test validator with no gates returns empty list.""" + validator = QualityValidator() + results = validator.validate_artifact("content", "art-1") + assert results == [] + + +class TestAllPassed: + """Tests for the all_passed helper.""" + + def test_all_pass(self): + """Test all_passed returns True when all pass.""" + validator = QualityValidator() + results = [ + ValidationResult.create( + gate_id="g1", gate_type=GateType.PATTERN, + artifact_id="a", status=ValidationStatus.PASS, + ), + ValidationResult.create( + gate_id="g2", gate_type=GateType.PATTERN, + artifact_id="a", status=ValidationStatus.PASS, + ), + ] + assert validator.all_passed(results) is True + + def test_one_fails(self): + """Test all_passed returns False when one fails.""" + validator = QualityValidator() + results = [ + ValidationResult.create( + gate_id="g1", gate_type=GateType.PATTERN, + artifact_id="a", status=ValidationStatus.PASS, + ), + ValidationResult.create( + gate_id="g2", gate_type=GateType.PATTERN, + artifact_id="a", status=ValidationStatus.FAIL, + ), + ] + assert validator.all_passed(results) is False + + def test_empty_results(self): + """Test all_passed with empty list returns True.""" + validator = QualityValidator() + assert validator.all_passed([]) is True + + +class TestAggregateScore: + """Tests for aggregate score calculation.""" + + def test_average_scores(self): + """Test aggregate is average of scores.""" + validator = QualityValidator() + results = [ + ValidationResult.create( + gate_id="g1", gate_type=GateType.PATTERN, + artifact_id="a", status=ValidationStatus.PASS, score=1.0, + ), + ValidationResult.create( + gate_id="g2", gate_type=GateType.PATTERN, + artifact_id="a", status=ValidationStatus.FAIL, score=0.5, + ), + ] + assert validator.aggregate_score(results) == 0.75 + + def test_no_results(self): + """Test aggregate with no results returns 1.0.""" + validator = QualityValidator() + assert validator.aggregate_score([]) == 1.0 + + def test_none_scores_ignored(self): + """Test results with None scores are handled.""" + validator = QualityValidator() + results = [ + ValidationResult.create( + gate_id="g1", gate_type=GateType.PATTERN, + artifact_id="a", status=ValidationStatus.PASS, score=None, + ), + ] + assert validator.aggregate_score(results) == 1.0 + + +class TestGetFailedGates: + """Tests for getting failed gates.""" + + def test_get_failed(self): + """Test filtering failed results.""" + validator = QualityValidator() + results = [ + ValidationResult.create( + gate_id="g1", gate_type=GateType.PATTERN, + artifact_id="a", status=ValidationStatus.PASS, + ), + ValidationResult.create( + gate_id="g2", gate_type=GateType.PATTERN, + artifact_id="a", status=ValidationStatus.FAIL, + ), + ] + failed = validator.get_failed_gates(results) + assert len(failed) == 1 + assert failed[0].gate_id == "g2" + + +class TestResultsToManifest: + """Tests for converting results to manifest dict.""" + + def test_manifest_dict_format(self): + """Test manifest dict has correct structure.""" + validator = QualityValidator() + results = [ + ValidationResult.create( + gate_id="g1", gate_type=GateType.PATTERN, + artifact_id="a", status=ValidationStatus.PASS, score=1.0, + ), + ] + manifest = validator.results_to_manifest_dict(results) + assert "quality_gates" in manifest + assert manifest["all_passed"] is True + assert manifest["aggregate_score"] == 1.0 + + +class TestPersistence: + """Tests for persisting validation results.""" + + def test_persist_and_retrieve_by_run(self, temp_db, pattern_gate): + """Test persisting results and retrieving by run ID.""" + validator = QualityValidator(gates=[pattern_gate], db_path=temp_db) + + validator.validate_artifact( + "## Summary\nClean content.", "art-1", run_id="run-1" + ) + + results = validator.get_results_for_run("run-1") + assert len(results) == 1 + assert results[0]["status"] == "pass" + + def test_persist_and_retrieve_by_artifact(self, temp_db, pattern_gate): + """Test persisting results and retrieving by artifact ID.""" + validator = QualityValidator(gates=[pattern_gate], db_path=temp_db) + + validator.validate_artifact( + "## Summary\nClean content.", "art-1", run_id="run-1" + ) + + results = validator.get_results_for_artifact("art-1") + assert len(results) == 1 + assert results[0]["artifact_id"] == "art-1" + + def test_no_persistence_without_db(self, pattern_gate): + """Test no persistence when db_path is None.""" + validator = QualityValidator(gates=[pattern_gate]) + results = validator.validate_artifact( + "## Summary\nContent.", "art-1", run_id="run-1" + ) + assert len(results) == 1 + # No DB queries should work + assert validator.get_results_for_run("run-1") == [] + + def test_add_gate(self): + """Test adding a gate after construction.""" + validator = QualityValidator() + assert len(validator.gates) == 0 + + gate = PatternValidationGate(required_patterns=[r"test"]) + validator.add_gate(gate) + assert len(validator.gates) == 1 diff --git a/tests/unit/prompts/test_refinement_loop.py b/tests/unit/prompts/test_refinement_loop.py new file mode 100644 index 00000000..d603301d --- /dev/null +++ b/tests/unit/prompts/test_refinement_loop.py @@ -0,0 +1,219 @@ +""" +Unit tests for RefinementLoop. + +Tests the execute → validate → halt or refine cycle. +""" + +import json +import pytest +from typing import List, Tuple + +from markitect.prompts.quality.models import ( + HaltDecision, + QualityPolicy, + ValidationResult, + ValidationStatus, +) +from markitect.prompts.quality.gates.pattern_gate import PatternValidationGate +from markitect.prompts.quality.validator import QualityValidator +from markitect.prompts.quality.refinement import RefinementLoop + + +class TestRefinementLoopBasic: + """Tests for basic refinement loop operation.""" + + def test_immediate_quality_met(self): + """Test loop halts immediately when quality is met on first iteration.""" + gate = PatternValidationGate( + required_patterns=[r"## Summary"], + gate_id="gate-1", + ) + validator = QualityValidator(gates=[gate]) + policy = QualityPolicy(max_iterations=5) + loop = RefinementLoop(validator, policy) + + def callback(iteration, prev_results): + return (f"run-{iteration}", "## Summary\nGood content.", "art-1") + + result = loop.run(callback, "art-1") + + assert result.iterations_run == 1 + assert result.halting_record.decision == HaltDecision.HALT_QUALITY_MET + assert len(result.run_ids) == 1 + + def test_iterates_until_quality_met(self): + """Test loop iterates until quality requirements are met.""" + gate = PatternValidationGate( + required_patterns=[r"## Summary", r"## Details"], + gate_id="gate-1", + ) + validator = QualityValidator(gates=[gate]) + policy = QualityPolicy(max_iterations=5, min_improvement=0.0) + loop = RefinementLoop(validator, policy) + + contents = [ + "## Summary\nOnly summary.", # missing Details + "## Summary\nStill only summary.", # still missing, slightly different + "## Summary\n## Details\nBoth sections.", # complete + ] + + def callback(iteration, prev_results): + content = contents[min(iteration - 1, len(contents) - 1)] + return (f"run-{iteration}", content, "art-1") + + result = loop.run(callback, "art-1") + + assert result.iterations_run == 3 + assert result.halting_record.decision == HaltDecision.HALT_QUALITY_MET + assert len(result.run_ids) == 3 + + def test_hits_iteration_limit(self): + """Test loop stops at iteration limit.""" + gate = PatternValidationGate( + required_patterns=[r"IMPOSSIBLE_PATTERN_12345"], + gate_id="gate-1", + ) + validator = QualityValidator(gates=[gate]) + policy = QualityPolicy(max_iterations=3, min_improvement=0.0) + loop = RefinementLoop(validator, policy) + + def callback(iteration, prev_results): + return (f"run-{iteration}", "never matches", "art-1") + + result = loop.run(callback, "art-1") + + assert result.iterations_run == 3 + assert result.halting_record.decision == HaltDecision.HALT_ITERATION_LIMIT + + def test_budget_exhaustion(self): + """Test loop stops when resource budget is exhausted.""" + gate = PatternValidationGate( + required_patterns=[r"IMPOSSIBLE"], + gate_id="gate-1", + ) + validator = QualityValidator(gates=[gate]) + policy = QualityPolicy(max_iterations=10, resource_budget=2) + loop = RefinementLoop(validator, policy) + + def callback(iteration, prev_results): + return (f"run-{iteration}", "content", "art-1") + + result = loop.run(callback, "art-1") + + assert result.iterations_run == 2 + assert result.halting_record.decision == HaltDecision.HALT_BUDGET_EXHAUSTED + + +class TestRefinementLoopHistory: + """Tests for refinement loop tracking history.""" + + def test_all_results_tracked(self): + """Test all iteration results are tracked.""" + gate = PatternValidationGate( + required_patterns=[r"## Summary"], + gate_id="gate-1", + ) + validator = QualityValidator(gates=[gate]) + policy = QualityPolicy(max_iterations=3) + loop = RefinementLoop(validator, policy) + + iteration_count = [0] + + def callback(iteration, prev_results): + iteration_count[0] = iteration + if iteration < 3: + return (f"run-{iteration}", "no match", "art-1") + return (f"run-{iteration}", "## Summary\nDone.", "art-1") + + result = loop.run(callback, "art-1") + + assert len(result.all_results) == result.iterations_run + assert len(result.run_ids) == result.iterations_run + + def test_run_ids_collected(self): + """Test run IDs are collected from each iteration.""" + gate = PatternValidationGate(gate_id="gate-1") + validator = QualityValidator(gates=[gate]) + policy = QualityPolicy(max_iterations=3) + loop = RefinementLoop(validator, policy) + + def callback(iteration, prev_results): + return (f"run-{iteration}", "content", "art-1") + + result = loop.run(callback, "art-1") + + assert result.run_ids[0] == "run-1" + + def test_previous_results_passed_to_callback(self): + """Test callback receives previous iteration results.""" + gate = PatternValidationGate( + required_patterns=[r"## Final"], + gate_id="gate-1", + ) + validator = QualityValidator(gates=[gate]) + policy = QualityPolicy(max_iterations=3) + loop = RefinementLoop(validator, policy) + + received_prev: List[List[ValidationResult]] = [] + + def callback(iteration, prev_results): + received_prev.append(prev_results) + if iteration >= 2: + return (f"run-{iteration}", "## Final\nDone.", "art-1") + return (f"run-{iteration}", "incomplete", "art-1") + + result = loop.run(callback, "art-1") + + # First iteration gets empty prev_results + assert len(received_prev[0]) == 0 + # Second iteration gets results from first + assert len(received_prev[1]) == 1 + + +class TestRefinementLoopNoImprovement: + """Tests for no-improvement halting in refinement loop.""" + + def test_halts_on_no_improvement(self): + """Test loop halts when scores stop improving.""" + gate = PatternValidationGate( + required_patterns=[r"## Required"], + gate_id="gate-1", + ) + validator = QualityValidator(gates=[gate]) + policy = QualityPolicy( + max_iterations=10, + min_improvement=0.1, + ) + loop = RefinementLoop(validator, policy) + + # Both iterations produce same failing content → same score → no improvement + def callback(iteration, prev_results): + return (f"run-{iteration}", "no match at all", "art-1") + + result = loop.run(callback, "art-1") + + # Second iteration should detect no improvement and halt + assert result.iterations_run == 2 + assert result.halting_record.decision == HaltDecision.HALT_NO_IMPROVEMENT + + +class TestRefinementResultSerialization: + """Tests for RefinementResult serialization.""" + + def test_to_dict(self): + """Test RefinementResult serialization.""" + gate = PatternValidationGate(gate_id="gate-1") + validator = QualityValidator(gates=[gate]) + policy = QualityPolicy(max_iterations=1) + loop = RefinementLoop(validator, policy) + + def callback(iteration, prev_results): + return ("run-1", "content", "art-1") + + result = loop.run(callback, "art-1") + d = result.to_dict() + + assert "iterations_run" in d + assert "final_results" in d + assert "halting_record" in d + assert "run_ids" in d