generated from coulomb/repo-seed
Add adaptive cost-quality routing primitives
This commit is contained in:
149
tests/test_shadowing.py
Normal file
149
tests/test_shadowing.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
Tests for ShadowingAdapter.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from llm_connect.adapter import LLMAdapter
|
||||
from llm_connect.grading import ExactMatchJudge, PairedGrader
|
||||
from llm_connect.models import LLMResponse, RunConfig
|
||||
from llm_connect.quality import QualityLedger
|
||||
from llm_connect.shadowing import ShadowingAdapter
|
||||
|
||||
|
||||
class StaticAdapter(LLMAdapter):
    """Canned-response adapter used as a test double.

    Each call to ``execute_prompt`` bumps ``calls``, optionally sleeps,
    optionally raises, and otherwise returns the same fixed ``content``.
    """

    def __init__(
        self,
        content: str,
        *,
        model: str = "model",
        cost_usd: float = 0.0,
        fail: bool = False,
        delay_seconds: float = 0.0,
    ):
        """Store the canned response and the knobs that shape each call.

        Args:
            content: Text returned as the response content.
            model: Model name reported on the response.
            cost_usd: Value placed under ``metadata["cost_usd"]``.
            fail: When true, ``execute_prompt`` raises ``RuntimeError``.
            delay_seconds: Seconds to sleep before responding (0 disables).
        """
        self.calls = 0  # number of execute_prompt invocations so far
        self.content = content
        self.model = model
        self.cost_usd = cost_usd
        self.delay_seconds = delay_seconds
        self.fail = fail

    def execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
        """Return the canned response, counting the call first.

        The call counter is incremented before the optional delay and the
        optional failure, so failed calls are counted too. ``config`` is
        accepted but not read.

        Raises:
            RuntimeError: If the adapter was constructed with ``fail=True``.
        """
        self.calls += 1
        if self.delay_seconds:
            time.sleep(self.delay_seconds)
        if self.fail:
            raise RuntimeError("adapter failed")

        # Token counts are whitespace-delimited word counts.
        prompt_tokens = len(prompt.split())
        completion_tokens = len(self.content.split())
        usage = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        # Latency is a fixed placeholder value, not a measurement.
        metadata = {"cost_usd": self.cost_usd, "latency_ms": 42.0}
        return LLMResponse(
            content=self.content,
            model=self.model,
            usage=usage,
            metadata=metadata,
        )

    def validate_config(self, config: RunConfig) -> bool:
        """Accept every configuration unconditionally."""
        return True
|
||||
|
||||
|
||||
def shadowing_adapter(
    tmp_path,
    *,
    candidate: StaticAdapter | None = None,
    baseline: StaticAdapter | None = None,
    shadow_rate: float = 1.0,
    async_shadow: bool = False,
    errors: list[Exception] | None = None,
) -> ShadowingAdapter:
    """Assemble a ShadowingAdapter for tests, backed by a ledger in ``tmp_path``.

    Args:
        tmp_path: Directory in which ``quality.jsonl`` is placed.
        candidate: Candidate adapter; defaults to a ``StaticAdapter`` that
            returns ``"same"`` with model ``"candidate"`` and cost 0.02.
        baseline: Baseline adapter; defaults to a ``StaticAdapter`` that
            returns ``"same"`` with model ``"baseline"``.
        shadow_rate: Forwarded unchanged to ``ShadowingAdapter``.
        async_shadow: Forwarded unchanged to ``ShadowingAdapter``.
        errors: When given, its ``append`` is installed as the
            ``on_shadow_error`` callback; otherwise no callback is set.

    Returns:
        The configured ``ShadowingAdapter``.
    """
    candidate_adapter = candidate or StaticAdapter("same", model="candidate", cost_usd=0.02)
    baseline_adapter = baseline or StaticAdapter("same", model="baseline")
    paired_grader = PairedGrader(ExactMatchJudge())
    quality_ledger = QualityLedger(tmp_path / "quality.jsonl")

    error_callback = None
    if errors is not None:
        error_callback = errors.append

    return ShadowingAdapter(
        candidate_adapter=candidate_adapter,
        baseline_adapter=baseline_adapter,
        grader=paired_grader,
        ledger=quality_ledger,
        task_type="summarize",
        adapter_id="candidate",
        baseline_adapter_id="baseline",
        shadow_rate=shadow_rate,
        async_shadow=async_shadow,
        tags={"prompt_fingerprint": "fixture"},
        on_shadow_error=error_callback,
    )
|
||||
|
||||
|
||||
class TestShadowingAdapter:
    """Behavioural tests for ShadowingAdapter using static stand-in adapters."""

    def test_sync_shadow_appends_quality_observation(self, tmp_path):
        """One synchronous call records exactly one observation in the ledger."""
        adapter = shadowing_adapter(tmp_path)

        response = adapter.execute_prompt("hello world", RunConfig(model_name="candidate-model"))

        observations = adapter.ledger.read_all()
        assert response.content == "same"
        assert len(observations) == 1
        observation = observations[0]
        # Candidate and baseline both return "same".
        assert observation.quality_score == 1.0
        assert observation.cost_usd == 0.02
        # "hello world" splits into two words; the canned reply "same" is one.
        assert observation.tokens_in == 2
        assert observation.tokens_out == 1
        assert observation.baseline_adapter_id == "baseline"
        assert observation.tags["prompt_fingerprint"] == "fixture"

    def test_candidate_response_survives_baseline_failure(self, tmp_path):
        """A raising baseline does not prevent the candidate response from returning."""
        candidate = StaticAdapter("candidate", model="candidate")
        baseline = StaticAdapter("baseline", fail=True)
        errors: list[Exception] = []
        adapter = shadowing_adapter(
            tmp_path,
            candidate=candidate,
            baseline=baseline,
            errors=errors,
        )

        response = adapter.execute_prompt("prompt", RunConfig())

        assert response.content == "candidate"
        assert candidate.calls == 1
        assert baseline.calls == 1
        # Nothing is recorded, and the failure reaches the error callback once.
        assert adapter.ledger.read_all() == []
        assert len(errors) == 1

    def test_shadow_rate_zero_skips_baseline_and_ledger(self, tmp_path):
        """With shadow_rate=0.0 the baseline is never called and nothing is recorded."""
        baseline = StaticAdapter("same")
        adapter = shadowing_adapter(tmp_path, baseline=baseline, shadow_rate=0.0)

        for _ in range(3):
            adapter.execute_prompt("prompt", RunConfig())

        assert baseline.calls == 0
        assert adapter.ledger.read_all() == []

    def test_shadow_rate_one_records_each_call(self, tmp_path):
        """With shadow_rate=1.0 every call runs the baseline and is recorded."""
        baseline = StaticAdapter("same")
        adapter = shadowing_adapter(tmp_path, baseline=baseline, shadow_rate=1.0)

        for _ in range(3):
            adapter.execute_prompt("prompt", RunConfig())

        assert baseline.calls == 3
        assert len(adapter.ledger.read_all()) == 3

    def test_async_shadow_mode_flushes_background_work(self, tmp_path):
        """After flush() and shutdown() in async-shadow mode, one observation exists."""
        baseline = StaticAdapter("same", delay_seconds=0.02)
        adapter = shadowing_adapter(tmp_path, baseline=baseline, async_shadow=True)

        try:
            response = adapter.execute_prompt("prompt", RunConfig())
            adapter.flush(timeout=1)
        finally:
            # Always shut down so a failing call or flush cannot leave the
            # adapter's background work running for the rest of the session.
            adapter.shutdown()

        assert response.content == "same"
        assert len(adapter.ledger.read_all()) == 1

    def test_async_execute_prompt_records_shadow(self, tmp_path):
        """The coroutine entry point also records one observation per call."""
        adapter = shadowing_adapter(tmp_path)

        response = asyncio.run(adapter.async_execute_prompt("prompt", RunConfig()))

        assert response.content == "same"
        assert len(adapter.ledger.read_all()) == 1
|
||||
Reference in New Issue
Block a user