generated from coulomb/repo-seed
Add adaptive cost-quality routing primitives
This commit is contained in:
90
tests/test_adaptive_integration.py
Normal file
90
tests/test_adaptive_integration.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Integration coverage for the adaptive routing workplan flow.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from llm_connect.adapter import MockLLMAdapter
|
||||
from llm_connect.quality import QualityLedger, QualityObservation
|
||||
from llm_connect.routing import AdaptiveRoutingPolicy, RoutingRule
|
||||
|
||||
|
||||
def append_quality(
    ledger: QualityLedger,
    adapter_id: str,
    quality_score: float,
    cost_usd: float,
    *,
    recorded_at: datetime,
) -> None:
    """Append one ``"summarize"`` observation for ``adapter_id`` to ``ledger``.

    Only the score, cost and timestamp vary between calls. Latency and token
    counts are fixed filler values, the model id is derived from the adapter
    id, and the baseline adapter id is always ``"baseline"``.
    """
    entry = QualityObservation(
        task_type="summarize",
        adapter_id=adapter_id,
        model_id=f"{adapter_id}-model",
        baseline_adapter_id="baseline",
        quality_score=quality_score,
        cost_usd=cost_usd,
        latency_ms=100,
        tokens_in=100,
        tokens_out=50,
        recorded_at=recorded_at,
    )
    ledger.append(entry)
|
||||
|
||||
|
||||
def test_adaptive_policy_converges_to_cheapest_qualifying_adapter(tmp_path):
    """Walk one policy through an empty, a partly and a fully favourable ledger.

    The same ``resolve`` call is expected to move from ``smart`` to ``mid`` to
    ``cheap`` as observations are appended to the ledger.
    """
    cheap, mid, smart = (MockLLMAdapter(name) for name in ("cheap", "mid", "smart"))
    ledger = QualityLedger(tmp_path / "quality.jsonl")
    summarize_rule = RoutingRule(
        "summarize",
        prefer=smart,
        max_cost_per_1k=1.0,
        fallback=mid,
    )
    policy = AdaptiveRoutingPolicy(
        rules=[summarize_rule],
        ledger=ledger,
        adapters_by_id={"cheap": cheap, "mid": mid, "smart": smart},
        window_size=2,
    )

    # Nothing recorded yet: expect the rule's `prefer`, or its `fallback`
    # when 2.0 is passed positionally.
    # NOTE(review): the 2.0 is presumably a cost figure above
    # max_cost_per_1k=1.0 -- confirm against AdaptiveRoutingPolicy.resolve.
    assert policy.resolve("summarize", quality_floor=0.8) is smart
    assert policy.resolve("summarize", 2.0, quality_floor=0.8) is mid

    # One observation per adapter, all sharing a timestamp; only "cheap"
    # scores below the 0.8 floor used in the assertions.
    first_round = datetime(2026, 5, 17, 10, tzinfo=timezone.utc)
    for adapter_id, score, cost in (
        ("cheap", 0.7, 0.01),
        ("mid", 0.86, 0.02),
        ("smart", 0.95, 0.05),
    ):
        append_quality(
            ledger,
            adapter_id,
            quality_score=score,
            cost_usd=cost,
            recorded_at=first_round,
        )

    assert policy.resolve("summarize", quality_floor=0.8) is mid

    # A later, higher-scoring observation for "cheap" is expected to flip
    # the decision to it.
    second_round = datetime(2026, 5, 17, 11, tzinfo=timezone.utc)
    append_quality(
        ledger,
        "cheap",
        quality_score=0.95,
        cost_usd=0.01,
        recorded_at=second_round,
    )

    assert policy.resolve("summarize", quality_floor=0.8) is cheap
|
||||
181
tests/test_adaptive_routing.py
Normal file
181
tests/test_adaptive_routing.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
Tests for AdaptiveRoutingPolicy.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from llm_connect.adapter import MockLLMAdapter
|
||||
from llm_connect.quality import QualityLedger, QualityObservation
|
||||
from llm_connect.routing import AdaptiveRoutingPolicy, RoutingRule
|
||||
|
||||
|
||||
def append_observation(
    ledger: QualityLedger,
    *,
    adapter_id: str,
    quality_score: float,
    cost_usd: float,
    task_type: str = "summarize",
    recorded_at: datetime | None = None,
) -> None:
    """Append a single observation for ``adapter_id`` to ``ledger``.

    Latency and token counts are fixed filler values and the baseline adapter
    id is always ``"baseline"``. When ``recorded_at`` is not supplied the
    observation is stamped 2026-05-17 00:00 UTC.
    """
    timestamp = recorded_at if recorded_at else datetime(2026, 5, 17, tzinfo=timezone.utc)
    entry = QualityObservation(
        task_type=task_type,
        adapter_id=adapter_id,
        model_id=f"{adapter_id}-model",
        baseline_adapter_id="baseline",
        quality_score=quality_score,
        cost_usd=cost_usd,
        latency_ms=100,
        tokens_in=100,
        tokens_out=50,
        recorded_at=timestamp,
    )
    ledger.append(entry)
|
||||
|
||||
|
||||
class TestAdaptiveRoutingPolicy:
    """Tests for ``AdaptiveRoutingPolicy.resolve`` driven by ledger observations."""

    def _adapter(self, name: str) -> MockLLMAdapter:
        """Return a mock adapter whose ``mock_response`` is ``name``."""
        return MockLLMAdapter(mock_response=name)

    def test_selects_cheapest_adapter_that_clears_quality_floor(self, tmp_path):
        """``cheap`` scored 0.7, under the 0.8 floor, so ``smart`` is expected."""
        cheap = self._adapter("cheap")
        smart = self._adapter("smart")
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        append_observation(ledger, adapter_id="cheap", quality_score=0.7, cost_usd=0.01)
        append_observation(ledger, adapter_id="smart", quality_score=0.9, cost_usd=0.03)

        policy = AdaptiveRoutingPolicy(
            rules=[RoutingRule("summarize", prefer=cheap)],
            ledger=ledger,
            adapters_by_id={"cheap": cheap, "smart": smart},
        )

        assert policy.resolve("summarize", quality_floor=0.8) is smart

    def test_prefers_lower_observed_cost_when_multiple_adapters_clear_floor(self, tmp_path):
        """Both adapters clear the floor; the lower-cost ``cheap`` is expected."""
        cheap = self._adapter("cheap")
        smart = self._adapter("smart")
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        append_observation(ledger, adapter_id="cheap", quality_score=0.9, cost_usd=0.01)
        append_observation(ledger, adapter_id="smart", quality_score=0.95, cost_usd=0.03)

        policy = AdaptiveRoutingPolicy(
            rules=[RoutingRule("summarize", prefer=smart)],
            ledger=ledger,
            adapters_by_id={"cheap": cheap, "smart": smart},
        )

        assert policy.resolve("summarize", quality_floor=0.8) is cheap

    def test_equal_cost_tie_prefers_static_rule_prefer(self, tmp_path):
        """With identical score and cost, the rule's ``prefer`` adapter is expected."""
        candidate = self._adapter("candidate")
        preferred = self._adapter("preferred")
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        append_observation(ledger, adapter_id="candidate", quality_score=0.9, cost_usd=0.01)
        append_observation(ledger, adapter_id="preferred", quality_score=0.9, cost_usd=0.01)

        policy = AdaptiveRoutingPolicy(
            rules=[RoutingRule("summarize", prefer=preferred)],
            ledger=ledger,
            adapters_by_id={"candidate": candidate, "preferred": preferred},
        )

        assert policy.resolve("summarize", quality_floor=0.8) is preferred

    def test_cold_start_falls_through_to_static_policy(self, tmp_path):
        """With an empty ledger, the rule's ``prefer`` adapter is expected."""
        preferred = self._adapter("preferred")
        fallback = self._adapter("fallback")
        policy = AdaptiveRoutingPolicy(
            rules=[RoutingRule("summarize", prefer=preferred, fallback=fallback)],
            ledger=QualityLedger(tmp_path / "quality.jsonl"),
            adapters_by_id={"preferred": preferred, "fallback": fallback},
        )

        assert policy.resolve("summarize", quality_floor=0.8) is preferred

    def test_window_size_changes_observed_mean_quality(self, tmp_path):
        """Two policies over one ledger differ only in ``window_size``.

        ``cheap`` has an older high score and a newer low score. Looking at
        one observation it is expected to miss the floor; looking at two it
        is expected to clear it.
        """
        cheap = self._adapter("cheap")
        smart = self._adapter("smart")
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        # 0.95 rather than 0.9: assuming the policy averages the windowed
        # scores (as the test name says), (0.95 + 0.7) / 2 = 0.825 sits clear
        # of the 0.8 floor. With 0.9 the mean is exactly 0.8, so the outcome
        # would hinge on float rounding and on the floor being inclusive.
        append_observation(
            ledger,
            adapter_id="cheap",
            quality_score=0.95,
            cost_usd=0.01,
            recorded_at=datetime(2026, 5, 16, tzinfo=timezone.utc),
        )
        append_observation(
            ledger,
            adapter_id="cheap",
            quality_score=0.7,
            cost_usd=0.01,
            recorded_at=datetime(2026, 5, 17, tzinfo=timezone.utc),
        )
        append_observation(ledger, adapter_id="smart", quality_score=0.9, cost_usd=0.03)

        recent_only = AdaptiveRoutingPolicy(
            rules=[RoutingRule("summarize", prefer=smart)],
            ledger=ledger,
            adapters_by_id={"cheap": cheap, "smart": smart},
            window_size=1,
        )
        wider_window = AdaptiveRoutingPolicy(
            rules=[RoutingRule("summarize", prefer=smart)],
            ledger=ledger,
            adapters_by_id={"cheap": cheap, "smart": smart},
            window_size=2,
        )

        assert recent_only.resolve("summarize", quality_floor=0.8) is smart
        assert wider_window.resolve("summarize", quality_floor=0.8) is cheap

    def test_stale_observations_are_ignored_by_max_age(self, tmp_path):
        """A perfect but years-old score must not win under ``max_age`` of one day."""
        stale = self._adapter("stale")
        fresh = self._adapter("fresh")
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        append_observation(
            ledger,
            adapter_id="stale",
            quality_score=1.0,
            cost_usd=0.01,
            recorded_at=datetime(2020, 1, 1, tzinfo=timezone.utc),
        )
        # Stamped with the current time so it is inside the one-day window
        # whenever the test runs.
        append_observation(
            ledger,
            adapter_id="fresh",
            quality_score=0.9,
            cost_usd=0.03,
            recorded_at=datetime.now(timezone.utc),
        )

        policy = AdaptiveRoutingPolicy(
            rules=[RoutingRule("summarize", prefer=stale)],
            ledger=ledger,
            adapters_by_id={"stale": stale, "fresh": fresh},
            max_age=timedelta(days=1),
        )

        assert policy.resolve("summarize", quality_floor=0.8) is fresh

    def test_static_fallback_chain_is_preserved_when_no_candidate_qualifies(self, tmp_path):
        """Neither adapter clears the floor; the rule's ``fallback`` is expected.

        NOTE(review): the positional 2.0 passed to ``resolve`` is presumably a
        cost figure above ``max_cost_per_1k=1.0`` -- confirm against the
        policy implementation.
        """
        preferred = self._adapter("preferred")
        fallback = self._adapter("fallback")
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        append_observation(ledger, adapter_id="preferred", quality_score=0.6, cost_usd=0.01)
        append_observation(ledger, adapter_id="fallback", quality_score=0.7, cost_usd=0.005)

        policy = AdaptiveRoutingPolicy(
            rules=[
                RoutingRule(
                    "summarize",
                    prefer=preferred,
                    max_cost_per_1k=1.0,
                    fallback=fallback,
                )
            ],
            ledger=ledger,
            adapters_by_id={"preferred": preferred, "fallback": fallback},
        )

        assert policy.resolve("summarize", 2.0, quality_floor=0.8) is fallback
|
||||
198
tests/test_grading.py
Normal file
198
tests/test_grading.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Tests for baseline grading and built-in judges.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from llm_connect.adapter import MockLLMAdapter
|
||||
from llm_connect.embedding_adapter import EmbeddingAdapter
|
||||
from llm_connect.grading import (
|
||||
EmbeddingSimilarityJudge,
|
||||
ExactMatchJudge,
|
||||
GradingResult,
|
||||
LLMJudge,
|
||||
PairedGrader,
|
||||
)
|
||||
from llm_connect.models import LLMResponse, RunConfig
|
||||
|
||||
|
||||
class StaticEmbeddingAdapter(EmbeddingAdapter):
    """Embedding adapter double that hands back a canned list of vectors."""

    def __init__(self, embeddings: list[list[float]]):
        # The most recent `texts` argument given to embed(); None until the
        # first call.
        self.seen_texts: list[str] | None = None
        # Returned unchanged from every embed() call, whatever the input.
        self.embeddings = embeddings

    def validate(self) -> bool:
        """Always report the adapter as valid."""
        return True

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Remember ``texts`` and return the canned embeddings."""
        self.seen_texts = texts
        return self.embeddings
|
||||
|
||||
|
||||
def response(content: str, model: str = "m") -> LLMResponse:
    """Build an ``LLMResponse`` carrying only ``content`` and ``model``."""
    fields = {"content": content, "model": model}
    return LLMResponse(**fields)
|
||||
|
||||
|
||||
class TestGradingResult:
    """Constructor validation of ``GradingResult``."""

    def test_score_must_be_between_zero_and_one(self):
        """A ``quality_score`` of 1.5 must raise ``ValueError`` naming the field."""
        fields = {
            "quality_score": 1.5,
            "notes": "bad",
            "grader_id": "g",
            "baseline_response": response("a"),
            "candidate_response": response("b"),
        }
        with pytest.raises(ValueError, match="quality_score"):
            GradingResult(**fields)

    def test_grader_id_must_be_non_empty(self):
        """An empty ``grader_id`` must raise ``ValueError`` naming the field."""
        fields = {
            "quality_score": 1.0,
            "notes": "ok",
            "grader_id": "",
            "baseline_response": response("a"),
            "candidate_response": response("a"),
        }
        with pytest.raises(ValueError, match="grader_id"):
            GradingResult(**fields)
|
||||
|
||||
|
||||
class TestExactMatchJudge:
    """Scoring behaviour of ``ExactMatchJudge``."""

    def test_scores_one_for_normalised_match(self):
        """Identical contents score 1.0 and both responses are carried through."""
        baseline_reply = response("hello world")
        candidate_reply = response("hello world")
        verdict = ExactMatchJudge().judge(
            baseline_reply,
            candidate_reply,
            prompt="p",
            run_config=RunConfig(),
        )
        assert verdict.quality_score == 1.0
        assert verdict.baseline_response.content == "hello world"
        assert verdict.candidate_response.content == "hello world"

    def test_scores_zero_for_difference(self):
        """Different contents score 0.0."""
        judge = ExactMatchJudge()
        verdict = judge.judge(
            response("hello"),
            response("goodbye"),
            prompt="p",
            run_config=RunConfig(),
        )
        assert verdict.quality_score == 0.0

    def test_case_insensitive_mode(self):
        """With ``case_sensitive=False``, contents differing only in case score 1.0."""
        judge = ExactMatchJudge(case_sensitive=False)
        verdict = judge.judge(
            response("Hello"),
            response("hello"),
            prompt="p",
            run_config=RunConfig(),
        )
        assert verdict.quality_score == 1.0
|
||||
|
||||
|
||||
class TestEmbeddingSimilarityJudge:
    """Scoring behaviour of ``EmbeddingSimilarityJudge`` with canned embeddings."""

    def test_scores_cosine_similarity(self):
        """Parallel vectors of different length score 1.0.

        Also checks that the judge embeds the baseline text first and the
        candidate text second, in a single ``embed`` call.
        """
        embedding_adapter = StaticEmbeddingAdapter([[1.0, 0.0], [0.5, 0.0]])
        result = EmbeddingSimilarityJudge(embedding_adapter).judge(
            response("baseline"),
            response("candidate"),
            prompt="p",
            run_config=RunConfig(),
        )
        # The score is a computed float (dot product over norms), so compare
        # with a tolerance rather than relying on the exact order of
        # operations inside the judge.
        assert result.quality_score == pytest.approx(1.0)
        assert embedding_adapter.seen_texts == ["baseline", "candidate"]

    def test_negative_similarity_clamps_to_zero(self):
        """Opposite vectors must score exactly 0.0, not a negative value."""
        embedding_adapter = StaticEmbeddingAdapter([[1.0, 0.0], [-1.0, 0.0]])
        result = EmbeddingSimilarityJudge(embedding_adapter).judge(
            response("baseline"),
            response("candidate"),
            prompt="p",
            run_config=RunConfig(),
        )
        # Kept exact: the test name states the value is clamped, so anything
        # other than 0.0 is a failure.
        assert result.quality_score == 0.0

    def test_wrong_embedding_count_raises(self):
        """An adapter returning one embedding must raise ``ValueError``."""
        embedding_adapter = StaticEmbeddingAdapter([[1.0, 0.0]])
        with pytest.raises(ValueError, match="two embeddings"):
            EmbeddingSimilarityJudge(embedding_adapter).judge(
                response("baseline"),
                response("candidate"),
                prompt="p",
                run_config=RunConfig(),
            )
|
||||
|
||||
|
||||
class TestLLMJudge:
    """Parsing and request-shaping behaviour of ``LLMJudge``."""

    def test_parses_json_judge_response(self):
        """A plain JSON reply is parsed and the judge call uses a pinned config.

        The config seen by the judge adapter must have temperature 0.0, keep
        the caller's ``model_params`` entry, carry ``seed`` 0 and have no
        budget tracker.
        """
        canned_reply = '{"quality_score": 0.75, "notes": "mostly equivalent"}'
        judge_adapter = MockLLMAdapter(mock_response=canned_reply)
        caller_config = RunConfig(model_params={"existing": True})

        judge = LLMJudge(judge_adapter)
        result = judge.judge(
            response("baseline answer"),
            response("candidate answer"),
            prompt="original prompt",
            run_config=caller_config,
        )

        assert result.quality_score == 0.75
        assert result.notes == "mostly equivalent"

        # Both answers must appear in the prompt sent to the judge adapter.
        sent_prompt = judge_adapter.last_prompt
        assert "baseline answer" in sent_prompt
        assert "candidate answer" in sent_prompt

        sent_config = judge_adapter.last_config
        assert sent_config.temperature == 0.0
        assert sent_config.model_params["existing"] is True
        assert sent_config.model_params["seed"] == 0
        assert sent_config.budget_tracker is None

    def test_extracts_json_from_wrapped_response(self):
        """JSON preceded by free text is still parsed; integer 1 compares equal to 1.0."""
        canned_reply = 'Here is the result: {"quality_score": 1, "notes": "same"}'
        judge = LLMJudge(MockLLMAdapter(mock_response=canned_reply))
        result = judge.judge(
            response("a"),
            response("a"),
            prompt="p",
            run_config=RunConfig(),
        )
        assert result.quality_score == 1.0
        assert result.notes == "same"

    def test_invalid_judge_response_raises(self):
        """A reply containing no JSON must raise ``ValueError`` mentioning JSON."""
        judge = LLMJudge(MockLLMAdapter(mock_response="not json"))
        with pytest.raises(ValueError, match="JSON"):
            judge.judge(
                response("a"),
                response("b"),
                prompt="p",
                run_config=RunConfig(),
            )
|
||||
|
||||
|
||||
class TestPairedGrader:
    """``PairedGrader.grade`` wiring between two adapters and a judge."""

    def test_runs_both_adapters_and_preserves_responses(self):
        """Each adapter is called once with the prompt and its reply is kept."""
        baseline = MockLLMAdapter(mock_response="same")
        candidate = MockLLMAdapter(mock_response="same")
        grader = PairedGrader(ExactMatchJudge())

        result = grader.grade(
            baseline,
            candidate,
            "prompt",
            RunConfig(model_name="mock-model"),
        )

        assert result.quality_score == 1.0
        for reply in (result.baseline_response, result.candidate_response):
            assert reply.content == "same"
        for adapter in (baseline, candidate):
            assert adapter.call_count == 1
        for adapter in (baseline, candidate):
            assert adapter.last_prompt == "prompt"

    def test_uses_custom_judge(self):
        """Differing adapter replies score 0.0 under the supplied judge."""
        baseline = MockLLMAdapter(mock_response="a")
        candidate = MockLLMAdapter(mock_response="b")
        grader = PairedGrader(ExactMatchJudge())
        result = grader.grade(baseline, candidate, "prompt", RunConfig())
        assert result.quality_score == 0.0
|
||||
164
tests/test_quality.py
Normal file
164
tests/test_quality.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
Tests for quality observations and the append-only quality ledger.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from llm_connect.quality import QualityLedger, QualityObservation, is_stale
|
||||
|
||||
|
||||
def observation(
    *,
    task_type: str = "summarize",
    adapter_id: str = "openrouter:cheap",
    model_id: str = "cheap-model",
    quality_score: float = 0.8,
    recorded_at: datetime | None = None,
    tag: str | None = None,
) -> QualityObservation:
    """Build a ``QualityObservation`` with fixed cost, latency and token counts.

    ``tag``, when given, is stored as ``tags["tag"]``; otherwise ``tags`` is
    empty. A missing ``recorded_at`` becomes 2026-05-17 00:00 UTC.
    """
    if tag is None:
        tags = {}
    else:
        tags = {"tag": tag}
    timestamp = recorded_at if recorded_at else datetime(2026, 5, 17, tzinfo=timezone.utc)
    return QualityObservation(
        task_type=task_type,
        adapter_id=adapter_id,
        model_id=model_id,
        baseline_adapter_id="claude-code",
        quality_score=quality_score,
        cost_usd=0.01,
        latency_ms=123.4,
        tokens_in=100,
        tokens_out=50,
        recorded_at=timestamp,
        tags=tags,
    )
|
||||
|
||||
|
||||
class TestQualityObservation:
    """Serialisation and validation of ``QualityObservation``."""

    def test_round_trip_dict(self):
        """``from_dict(to_dict())`` yields an equal, timezone-aware observation."""
        original = observation(tag="a")
        as_dict = original.to_dict()
        restored = QualityObservation.from_dict(as_dict)
        assert restored == original
        # The helper uses tokens_in=100 and tokens_out=50.
        assert restored.total_tokens == 150
        assert restored.recorded_at.tzinfo is not None

    def test_naive_recorded_at_is_interpreted_as_utc(self):
        """A naive ``recorded_at`` ends up with ``tzinfo`` equal to UTC."""
        naive_timestamp = datetime(2026, 5, 17, 12, 0, 0)
        obs = observation(recorded_at=naive_timestamp)
        assert obs.recorded_at.tzinfo == timezone.utc

    @pytest.mark.parametrize("score", [-0.1, 1.1])
    def test_quality_score_must_be_between_zero_and_one(self, score):
        """Scores just outside [0, 1] raise ``ValueError`` naming the field."""
        with pytest.raises(ValueError, match="quality_score"):
            observation(quality_score=score)

    def test_required_ids_must_be_non_empty(self):
        """An empty ``task_type`` raises ``ValueError`` naming the field."""
        blank = ""
        with pytest.raises(ValueError, match="task_type"):
            observation(task_type=blank)

    def test_non_negative_fields_are_enforced(self):
        """A negative ``tokens_in`` raises ``ValueError`` naming the field."""
        fields = {
            "task_type": "x",
            "adapter_id": "a",
            "model_id": "m",
            "cost_usd": 0,
            "quality_score": 1,
            "latency_ms": 0,
            "tokens_in": -1,
            "tokens_out": 0,
        }
        with pytest.raises(ValueError, match="tokens_in"):
            QualityObservation(**fields)
|
||||
|
||||
|
||||
class TestQualityLedger:
    """Read, filter, prune and concurrency behaviour of ``QualityLedger``."""

    def test_append_and_read_round_trip(self, tmp_path):
        """One appended observation is read back unchanged."""
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        written = observation()
        ledger.append(written)
        assert ledger.read_all() == [written]

    def test_by_task_type_filters_observations(self, tmp_path):
        """``by_task_type`` returns only observations of the requested type."""
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        for task_type in ("summarize", "extract"):
            ledger.append(observation(task_type=task_type))
        matched_types = [obs.task_type for obs in ledger.by_task_type("summarize")]
        assert matched_types == ["summarize"]

    def test_recent_returns_newest_first_with_filters(self, tmp_path):
        """``recent`` with a task filter and limit 1 returns the newest match.

        The newest observation overall has a different task type and must not
        be returned.
        """
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        entries = [
            observation(recorded_at=datetime(2026, 5, 1, tzinfo=timezone.utc), tag="older"),
            observation(recorded_at=datetime(2026, 5, 2, tzinfo=timezone.utc), tag="newer"),
            observation(
                task_type="extract",
                recorded_at=datetime(2026, 5, 3, tzinfo=timezone.utc),
                tag="other",
            ),
        ]
        for entry in entries:
            ledger.append(entry)

        newest = ledger.recent(limit=1, task_type="summarize")
        assert [obs.tags["tag"] for obs in newest] == ["newer"]

    def test_mean_quality_filters_by_adapter_and_minimum_count(self, tmp_path):
        """Mean covers one adapter only; too few observations gives ``None``."""
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        for adapter_id, score in (("a", 0.5), ("a", 1.0), ("b", 0.1)):
            ledger.append(observation(adapter_id=adapter_id, quality_score=score))

        mean_for_a = ledger.mean_quality("summarize", adapter_id="a")
        assert mean_for_a == 0.75
        # Adapter "a" has two observations, fewer than the three required.
        gated_mean = ledger.mean_quality("summarize", adapter_id="a", min_observations=3)
        assert gated_mean is None

    def test_is_stale_uses_utc_reference(self):
        """A two-day-old observation is stale at one day but not at three."""
        reference = datetime(2026, 5, 3, tzinfo=timezone.utc)
        obs = observation(recorded_at=datetime(2026, 5, 1, tzinfo=timezone.utc))
        one_day = timedelta(days=1)
        three_days = timedelta(days=3)
        assert is_stale(obs, one_day, now=reference) is True
        assert is_stale(obs, three_days, now=reference) is False

    def test_prune_before_removes_old_valid_observations(self, tmp_path):
        """Pruning at a cutoff drops the older entry and keeps the one at the cutoff."""
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        cutoff = datetime(2026, 5, 2, tzinfo=timezone.utc)
        ledger.append(observation(recorded_at=datetime(2026, 5, 1, tzinfo=timezone.utc), tag="old"))
        ledger.append(observation(recorded_at=datetime(2026, 5, 2, tzinfo=timezone.utc), tag="keep"))

        removed = ledger.prune_before(cutoff)

        assert removed == 1
        remaining_tags = [obs.tags["tag"] for obs in ledger.read_all()]
        assert remaining_tags == ["keep"]

    def test_malformed_lines_are_skipped_and_counted(self, tmp_path):
        """An unparsable line is excluded from reads but counted as malformed."""
        ledger_file = tmp_path / "quality.jsonl"
        ledger_file.write_text("{not json}\n", encoding="utf-8")
        ledger = QualityLedger(ledger_file)
        ledger.append(observation())

        valid_entries = ledger.read_all()
        assert len(valid_entries) == 1
        assert ledger.malformed_count() == 1

    def test_prune_preserves_malformed_lines(self, tmp_path):
        """Pruning removes the old valid entry and leaves the malformed line counted."""
        ledger_file = tmp_path / "quality.jsonl"
        ledger_file.write_text("{not json}\n", encoding="utf-8")
        ledger = QualityLedger(ledger_file)
        ledger.append(observation(recorded_at=datetime(2026, 5, 1, tzinfo=timezone.utc)))

        cutoff = datetime(2026, 5, 2, tzinfo=timezone.utc)
        removed = ledger.prune_before(cutoff)

        assert removed == 1
        assert ledger.malformed_count() == 1
        assert ledger.read_all() == []

    def test_concurrent_writes_round_trip(self, tmp_path):
        """25 threads each append one tagged observation; all 25 must be readable."""
        ledger = QualityLedger(tmp_path / "quality.jsonl")

        def write_tagged(position: int) -> None:
            ledger.append(observation(tag=str(position)))

        workers = [
            threading.Thread(target=write_tagged, args=(position,))
            for position in range(25)
        ]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        stored = ledger.read_all()
        assert len(stored) == 25
        stored_tags = {obs.tags["tag"] for obs in stored}
        assert stored_tags == {str(position) for position in range(25)}
|
||||
149
tests/test_shadowing.py
Normal file
149
tests/test_shadowing.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
Tests for ShadowingAdapter.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from llm_connect.adapter import LLMAdapter
|
||||
from llm_connect.grading import ExactMatchJudge, PairedGrader
|
||||
from llm_connect.models import LLMResponse, RunConfig
|
||||
from llm_connect.quality import QualityLedger
|
||||
from llm_connect.shadowing import ShadowingAdapter
|
||||
|
||||
|
||||
class StaticAdapter(LLMAdapter):
    """LLM adapter double that returns fixed content and counts its calls.

    It can optionally sleep before answering (``delay_seconds``) or raise
    instead of answering (``fail``).
    """

    def __init__(
        self,
        content: str,
        *,
        model: str = "model",
        cost_usd: float = 0.0,
        fail: bool = False,
        delay_seconds: float = 0.0,
    ):
        # Number of execute_prompt() invocations, including failing ones.
        self.calls = 0
        self.content = content
        self.model = model
        self.cost_usd = cost_usd
        self.fail = fail
        self.delay_seconds = delay_seconds

    def validate_config(self, config: RunConfig) -> bool:
        """Accept every configuration."""
        return True

    def execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
        """Return the canned response, or raise ``RuntimeError`` when ``fail`` is set.

        The call is counted first, then the optional delay is applied, then
        the failure check runs. Token usage is whitespace-separated word
        counts of the prompt and of the canned content.
        """
        self.calls += 1
        if self.delay_seconds:
            time.sleep(self.delay_seconds)
        if self.fail:
            raise RuntimeError("adapter failed")

        prompt_tokens = len(prompt.split())
        completion_tokens = len(self.content.split())
        usage = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }
        metadata = {"cost_usd": self.cost_usd, "latency_ms": 42.0}
        return LLMResponse(
            content=self.content,
            model=self.model,
            usage=usage,
            metadata=metadata,
        )
|
||||
|
||||
|
||||
def shadowing_adapter(
    tmp_path,
    *,
    candidate: StaticAdapter | None = None,
    baseline: StaticAdapter | None = None,
    shadow_rate: float = 1.0,
    async_shadow: bool = False,
    errors: list[Exception] | None = None,
) -> ShadowingAdapter:
    """Build a ``ShadowingAdapter`` over a fresh ledger file in ``tmp_path``.

    Missing adapters default to ``StaticAdapter`` doubles that both answer
    ``"same"``. When ``errors`` is given, its ``append`` is installed as the
    ``on_shadow_error`` callback; otherwise no callback is passed.
    """
    if not candidate:
        candidate = StaticAdapter("same", model="candidate", cost_usd=0.02)
    if not baseline:
        baseline = StaticAdapter("same", model="baseline")
    error_sink = None if errors is None else errors.append

    return ShadowingAdapter(
        candidate_adapter=candidate,
        baseline_adapter=baseline,
        grader=PairedGrader(ExactMatchJudge()),
        ledger=QualityLedger(tmp_path / "quality.jsonl"),
        task_type="summarize",
        adapter_id="candidate",
        baseline_adapter_id="baseline",
        shadow_rate=shadow_rate,
        async_shadow=async_shadow,
        tags={"prompt_fingerprint": "fixture"},
        on_shadow_error=error_sink,
    )
|
||||
|
||||
|
||||
class TestShadowingAdapter:
    """Behaviour of ``ShadowingAdapter`` in sync, sampled and async-shadow modes."""

    def test_sync_shadow_appends_quality_observation(self, tmp_path):
        """One call records one observation built from the candidate's response.

        The prompt has two words and the canned content one, which is what
        the fixture adapter reports as prompt and completion tokens.
        """
        adapter = shadowing_adapter(tmp_path)

        response = adapter.execute_prompt("hello world", RunConfig(model_name="candidate-model"))

        observations = adapter.ledger.read_all()
        assert response.content == "same"
        assert len(observations) == 1
        assert observations[0].quality_score == 1.0
        assert observations[0].cost_usd == 0.02
        assert observations[0].tokens_in == 2
        assert observations[0].tokens_out == 1
        assert observations[0].baseline_adapter_id == "baseline"
        assert observations[0].tags["prompt_fingerprint"] == "fixture"

    def test_candidate_response_survives_baseline_failure(self, tmp_path):
        """A failing baseline must not affect the returned candidate response.

        Nothing is written to the ledger and the error callback receives
        exactly one exception.
        """
        candidate = StaticAdapter("candidate", model="candidate")
        baseline = StaticAdapter("baseline", fail=True)
        errors: list[Exception] = []
        adapter = shadowing_adapter(
            tmp_path,
            candidate=candidate,
            baseline=baseline,
            errors=errors,
        )

        response = adapter.execute_prompt("prompt", RunConfig())

        assert response.content == "candidate"
        assert candidate.calls == 1
        assert baseline.calls == 1
        assert adapter.ledger.read_all() == []
        assert len(errors) == 1

    def test_shadow_rate_zero_skips_baseline_and_ledger(self, tmp_path):
        """With ``shadow_rate=0.0`` the baseline is never called and nothing is recorded."""
        baseline = StaticAdapter("same")
        adapter = shadowing_adapter(tmp_path, baseline=baseline, shadow_rate=0.0)

        for _ in range(3):
            adapter.execute_prompt("prompt", RunConfig())

        assert baseline.calls == 0
        assert adapter.ledger.read_all() == []

    def test_shadow_rate_one_records_each_call(self, tmp_path):
        """With ``shadow_rate=1.0`` every call runs the baseline and is recorded."""
        baseline = StaticAdapter("same")
        adapter = shadowing_adapter(tmp_path, baseline=baseline, shadow_rate=1.0)

        for _ in range(3):
            adapter.execute_prompt("prompt", RunConfig())

        assert baseline.calls == 3
        assert len(adapter.ledger.read_all()) == 3

    def test_async_shadow_mode_flushes_background_work(self, tmp_path):
        """After ``flush`` the delayed shadow run has been written to the ledger."""
        baseline = StaticAdapter("same", delay_seconds=0.02)
        adapter = shadowing_adapter(tmp_path, baseline=baseline, async_shadow=True)

        # shutdown() runs in `finally` so the adapter is always shut down,
        # even when execute_prompt() or flush() raises; previously a failure
        # there skipped shutdown() entirely.
        try:
            response = adapter.execute_prompt("prompt", RunConfig())
            adapter.flush(timeout=1)
        finally:
            adapter.shutdown()

        assert response.content == "same"
        assert len(adapter.ledger.read_all()) == 1

    def test_async_execute_prompt_records_shadow(self, tmp_path):
        """The coroutine entry point returns the candidate reply and records once."""
        adapter = shadowing_adapter(tmp_path)

        response = asyncio.run(adapter.async_execute_prompt("prompt", RunConfig()))

        assert response.content == "same"
        assert len(adapter.ledger.read_all()) == 1
|
||||
Reference in New Issue
Block a user