""" Tests for ShadowingAdapter. """ import asyncio import time from llm_connect.adapter import LLMAdapter from llm_connect.grading import ExactMatchJudge, PairedGrader from llm_connect.models import LLMResponse, RunConfig from llm_connect.quality import QualityLedger from llm_connect.shadowing import ShadowingAdapter class StaticAdapter(LLMAdapter): def __init__( self, content: str, *, model: str = "model", cost_usd: float = 0.0, fail: bool = False, delay_seconds: float = 0.0, ): self.content = content self.model = model self.cost_usd = cost_usd self.fail = fail self.delay_seconds = delay_seconds self.calls = 0 def execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse: self.calls += 1 if self.delay_seconds: time.sleep(self.delay_seconds) if self.fail: raise RuntimeError("adapter failed") return LLMResponse( content=self.content, model=self.model, usage={ "prompt_tokens": len(prompt.split()), "completion_tokens": len(self.content.split()), "total_tokens": len(prompt.split()) + len(self.content.split()), }, metadata={"cost_usd": self.cost_usd, "latency_ms": 42.0}, ) def validate_config(self, config: RunConfig) -> bool: return True def shadowing_adapter( tmp_path, *, candidate: StaticAdapter | None = None, baseline: StaticAdapter | None = None, shadow_rate: float = 1.0, async_shadow: bool = False, errors: list[Exception] | None = None, ) -> ShadowingAdapter: return ShadowingAdapter( candidate_adapter=candidate or StaticAdapter("same", model="candidate", cost_usd=0.02), baseline_adapter=baseline or StaticAdapter("same", model="baseline"), grader=PairedGrader(ExactMatchJudge()), ledger=QualityLedger(tmp_path / "quality.jsonl"), task_type="summarize", adapter_id="candidate", baseline_adapter_id="baseline", shadow_rate=shadow_rate, async_shadow=async_shadow, tags={"prompt_fingerprint": "fixture"}, on_shadow_error=errors.append if errors is not None else None, ) class TestShadowingAdapter: def test_sync_shadow_appends_quality_observation(self, tmp_path): adapter = shadowing_adapter(tmp_path) response = adapter.execute_prompt("hello world", RunConfig(model_name="candidate-model")) observations = adapter.ledger.read_all() assert response.content == "same" assert len(observations) == 1 assert observations[0].quality_score == 1.0 assert observations[0].cost_usd == 0.02 assert observations[0].tokens_in == 2 assert observations[0].tokens_out == 1 assert observations[0].baseline_adapter_id == "baseline" assert observations[0].tags["prompt_fingerprint"] == "fixture" def test_candidate_response_survives_baseline_failure(self, tmp_path): candidate = StaticAdapter("candidate", model="candidate") baseline = StaticAdapter("baseline", fail=True) errors: list[Exception] = [] adapter = shadowing_adapter( tmp_path, candidate=candidate, baseline=baseline, errors=errors, ) response = adapter.execute_prompt("prompt", RunConfig()) assert response.content == "candidate" assert candidate.calls == 1 assert baseline.calls == 1 assert adapter.ledger.read_all() == [] assert len(errors) == 1 def test_shadow_rate_zero_skips_baseline_and_ledger(self, tmp_path): baseline = StaticAdapter("same") adapter = shadowing_adapter(tmp_path, baseline=baseline, shadow_rate=0.0) for _ in range(3): adapter.execute_prompt("prompt", RunConfig()) assert baseline.calls == 0 assert adapter.ledger.read_all() == [] def test_shadow_rate_one_records_each_call(self, tmp_path): baseline = StaticAdapter("same") adapter = shadowing_adapter(tmp_path, baseline=baseline, shadow_rate=1.0) for _ in range(3): adapter.execute_prompt("prompt", RunConfig()) assert baseline.calls == 3 assert len(adapter.ledger.read_all()) == 3 def test_async_shadow_mode_flushes_background_work(self, tmp_path): baseline = StaticAdapter("same", delay_seconds=0.02) adapter = shadowing_adapter(tmp_path, baseline=baseline, async_shadow=True) response = adapter.execute_prompt("prompt", RunConfig()) adapter.flush(timeout=1) adapter.shutdown() assert response.content == "same" assert len(adapter.ledger.read_all()) == 1 def test_async_execute_prompt_records_shadow(self, tmp_path): adapter = shadowing_adapter(tmp_path) response = asyncio.run(adapter.async_execute_prompt("prompt", RunConfig())) assert response.content == "same" assert len(adapter.ledger.read_all()) == 1