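# Generated from coulomb/repo-seed (150 lines, 5.0 KiB, Python).
"""Tests for ShadowingAdapter.

Covers shadow sampling via ``shadow_rate``, quality-ledger recording,
flushing of async shadow work, and isolation of the candidate response
from baseline failures.
"""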
|
|
|
|
import asyncio
|
|
import time
|
|
|
|
from llm_connect.adapter import LLMAdapter
|
|
from llm_connect.grading import ExactMatchJudge, PairedGrader
|
|
from llm_connect.models import LLMResponse, RunConfig
|
|
from llm_connect.quality import QualityLedger
|
|
from llm_connect.shadowing import ShadowingAdapter
|
|
|
|
|
|
class StaticAdapter(LLMAdapter):
    """Deterministic in-memory LLMAdapter used as a test double.

    Every call answers with the same canned ``content``. The adapter can
    optionally sleep before answering or raise to simulate a failing
    provider. ``calls`` counts how many times ``execute_prompt`` was entered.
    """

    def __init__(
        self,
        content: str,
        *,
        model: str = "model",
        cost_usd: float = 0.0,
        fail: bool = False,
        delay_seconds: float = 0.0,
    ):
        self.content = content
        self.model = model
        self.cost_usd = cost_usd
        self.fail = fail
        self.delay_seconds = delay_seconds
        # Bumped on entry to execute_prompt, i.e. before any delay or failure,
        # so failing calls are counted too.
        self.calls = 0

    def execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
        """Return the canned response, optionally after a delay or by raising.

        Token counts are whitespace-split word counts of ``prompt`` and of the
        canned content. ``config`` is ignored. The reported latency is a fixed
        42.0 ms, independent of ``delay_seconds``.

        Raises:
            RuntimeError: when the adapter was built with ``fail=True``.
        """
        self.calls += 1
        if self.delay_seconds:
            time.sleep(self.delay_seconds)
        if self.fail:
            raise RuntimeError("adapter failed")

        n_prompt = len(prompt.split())
        n_completion = len(self.content.split())
        usage = {
            "prompt_tokens": n_prompt,
            "completion_tokens": n_completion,
            "total_tokens": n_prompt + n_completion,
        }
        metadata = {"cost_usd": self.cost_usd, "latency_ms": 42.0}
        return LLMResponse(
            content=self.content,
            model=self.model,
            usage=usage,
            metadata=metadata,
        )

    def validate_config(self, config: RunConfig) -> bool:
        """Accept every configuration unconditionally."""
        return True
|
|
|
|
|
|
def shadowing_adapter(
    tmp_path,
    *,
    candidate: StaticAdapter | None = None,
    baseline: StaticAdapter | None = None,
    shadow_rate: float = 1.0,
    async_shadow: bool = False,
    errors: list[Exception] | None = None,
) -> ShadowingAdapter:
    """Build a ShadowingAdapter wired to a fresh quality ledger.

    Args:
        tmp_path: directory (pytest's ``tmp_path``) in which the ledger file
            ``quality.jsonl`` is created.
        candidate: adapter under test. Defaults to a StaticAdapter answering
            "same" as model "candidate" with ``cost_usd=0.02``.
        baseline: comparison adapter. Defaults to a StaticAdapter answering
            "same" as model "baseline".
        shadow_rate: forwarded unchanged to ShadowingAdapter.
        async_shadow: forwarded unchanged to ShadowingAdapter.
        errors: when given, ``errors.append`` is installed as the
            ``on_shadow_error`` callback so tests can inspect reported errors.
            When omitted, no callback is installed.

    Returns:
        The configured ShadowingAdapter. Grading uses ``ExactMatchJudge``, the
        task type is "summarize", and the tags are
        ``{"prompt_fingerprint": "fixture"}``.
    """
    # Explicit None checks match the ``StaticAdapter | None`` contract. A
    # truthiness test (``x or default``) would also replace a falsy adapter.
    if candidate is None:
        candidate = StaticAdapter("same", model="candidate", cost_usd=0.02)
    if baseline is None:
        baseline = StaticAdapter("same", model="baseline")
    on_shadow_error = errors.append if errors is not None else None

    return ShadowingAdapter(
        candidate_adapter=candidate,
        baseline_adapter=baseline,
        grader=PairedGrader(ExactMatchJudge()),
        ledger=QualityLedger(tmp_path / "quality.jsonl"),
        task_type="summarize",
        adapter_id="candidate",
        baseline_adapter_id="baseline",
        shadow_rate=shadow_rate,
        async_shadow=async_shadow,
        tags={"prompt_fingerprint": "fixture"},
        on_shadow_error=on_shadow_error,
    )
|
|
|
|
|
|
class TestShadowingAdapter:
    """Behavioral tests for ShadowingAdapter built via ``shadowing_adapter``."""

    def test_sync_shadow_appends_quality_observation(self, tmp_path):
        """A synchronous shadowed call writes one fully-populated observation."""
        adapter = shadowing_adapter(tmp_path)

        response = adapter.execute_prompt("hello world", RunConfig(model_name="candidate-model"))

        observations = adapter.ledger.read_all()
        assert response.content == "same"
        assert len(observations) == 1
        assert observations[0].quality_score == 1.0
        assert observations[0].cost_usd == 0.02
        # "hello world" -> 2 prompt tokens; "same" -> 1 completion token.
        assert observations[0].tokens_in == 2
        assert observations[0].tokens_out == 1
        assert observations[0].baseline_adapter_id == "baseline"
        assert observations[0].tags["prompt_fingerprint"] == "fixture"

    def test_candidate_response_survives_baseline_failure(self, tmp_path):
        """A failing baseline is reported but never breaks the candidate path."""
        candidate = StaticAdapter("candidate", model="candidate")
        baseline = StaticAdapter("baseline", fail=True)
        errors: list[Exception] = []
        adapter = shadowing_adapter(
            tmp_path,
            candidate=candidate,
            baseline=baseline,
            errors=errors,
        )

        response = adapter.execute_prompt("prompt", RunConfig())

        assert response.content == "candidate"
        assert candidate.calls == 1
        assert baseline.calls == 1
        assert adapter.ledger.read_all() == []
        assert len(errors) == 1

    def test_shadow_rate_zero_skips_baseline_and_ledger(self, tmp_path):
        """With shadow_rate=0.0 the baseline is never called and nothing is logged."""
        baseline = StaticAdapter("same")
        adapter = shadowing_adapter(tmp_path, baseline=baseline, shadow_rate=0.0)

        for _ in range(3):
            adapter.execute_prompt("prompt", RunConfig())

        assert baseline.calls == 0
        assert adapter.ledger.read_all() == []

    def test_shadow_rate_one_records_each_call(self, tmp_path):
        """With shadow_rate=1.0 every call is shadowed and recorded."""
        baseline = StaticAdapter("same")
        adapter = shadowing_adapter(tmp_path, baseline=baseline, shadow_rate=1.0)

        for _ in range(3):
            adapter.execute_prompt("prompt", RunConfig())

        assert baseline.calls == 3
        assert len(adapter.ledger.read_all()) == 3

    def test_async_shadow_mode_flushes_background_work(self, tmp_path):
        """In async mode, flush() waits for the delayed shadow call to be recorded."""
        baseline = StaticAdapter("same", delay_seconds=0.02)
        adapter = shadowing_adapter(tmp_path, baseline=baseline, async_shadow=True)

        try:
            response = adapter.execute_prompt("prompt", RunConfig())
            adapter.flush(timeout=1)
        finally:
            # Shut down even if the call or flush raises, so async shadow
            # work started by this test cannot outlive it.
            adapter.shutdown()

        assert response.content == "same"
        assert baseline.calls == 1
        assert len(adapter.ledger.read_all()) == 1

    def test_async_execute_prompt_records_shadow(self, tmp_path):
        """The coroutine entry point also records a shadow observation."""
        adapter = shadowing_adapter(tmp_path)

        response = asyncio.run(adapter.async_execute_prompt("prompt", RunConfig()))

        assert response.content == "same"
        assert len(adapter.ledger.read_all()) == 1
|