generated from coulomb/repo-seed
Add adaptive cost-quality routing primitives
This commit is contained in:
181
tests/test_adaptive_routing.py
Normal file
181
tests/test_adaptive_routing.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""
|
||||
Tests for AdaptiveRoutingPolicy.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from llm_connect.adapter import MockLLMAdapter
|
||||
from llm_connect.quality import QualityLedger, QualityObservation
|
||||
from llm_connect.routing import AdaptiveRoutingPolicy, RoutingRule
|
||||
|
||||
|
||||
def append_observation(
|
||||
ledger: QualityLedger,
|
||||
*,
|
||||
adapter_id: str,
|
||||
quality_score: float,
|
||||
cost_usd: float,
|
||||
task_type: str = "summarize",
|
||||
recorded_at: datetime | None = None,
|
||||
) -> None:
|
||||
ledger.append(
|
||||
QualityObservation(
|
||||
task_type=task_type,
|
||||
adapter_id=adapter_id,
|
||||
model_id=f"{adapter_id}-model",
|
||||
cost_usd=cost_usd,
|
||||
quality_score=quality_score,
|
||||
latency_ms=100,
|
||||
tokens_in=100,
|
||||
tokens_out=50,
|
||||
baseline_adapter_id="baseline",
|
||||
recorded_at=recorded_at or datetime(2026, 5, 17, tzinfo=timezone.utc),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TestAdaptiveRoutingPolicy:
|
||||
def _adapter(self, name: str) -> MockLLMAdapter:
|
||||
return MockLLMAdapter(mock_response=name)
|
||||
|
||||
def test_selects_cheapest_adapter_that_clears_quality_floor(self, tmp_path):
|
||||
cheap = self._adapter("cheap")
|
||||
smart = self._adapter("smart")
|
||||
ledger = QualityLedger(tmp_path / "quality.jsonl")
|
||||
append_observation(ledger, adapter_id="cheap", quality_score=0.7, cost_usd=0.01)
|
||||
append_observation(ledger, adapter_id="smart", quality_score=0.9, cost_usd=0.03)
|
||||
|
||||
policy = AdaptiveRoutingPolicy(
|
||||
rules=[RoutingRule("summarize", prefer=cheap)],
|
||||
ledger=ledger,
|
||||
adapters_by_id={"cheap": cheap, "smart": smart},
|
||||
)
|
||||
|
||||
assert policy.resolve("summarize", quality_floor=0.8) is smart
|
||||
|
||||
def test_prefers_lower_observed_cost_when_multiple_adapters_clear_floor(self, tmp_path):
|
||||
cheap = self._adapter("cheap")
|
||||
smart = self._adapter("smart")
|
||||
ledger = QualityLedger(tmp_path / "quality.jsonl")
|
||||
append_observation(ledger, adapter_id="cheap", quality_score=0.9, cost_usd=0.01)
|
||||
append_observation(ledger, adapter_id="smart", quality_score=0.95, cost_usd=0.03)
|
||||
|
||||
policy = AdaptiveRoutingPolicy(
|
||||
rules=[RoutingRule("summarize", prefer=smart)],
|
||||
ledger=ledger,
|
||||
adapters_by_id={"cheap": cheap, "smart": smart},
|
||||
)
|
||||
|
||||
assert policy.resolve("summarize", quality_floor=0.8) is cheap
|
||||
|
||||
def test_equal_cost_tie_prefers_static_rule_prefer(self, tmp_path):
|
||||
candidate = self._adapter("candidate")
|
||||
preferred = self._adapter("preferred")
|
||||
ledger = QualityLedger(tmp_path / "quality.jsonl")
|
||||
append_observation(ledger, adapter_id="candidate", quality_score=0.9, cost_usd=0.01)
|
||||
append_observation(ledger, adapter_id="preferred", quality_score=0.9, cost_usd=0.01)
|
||||
|
||||
policy = AdaptiveRoutingPolicy(
|
||||
rules=[RoutingRule("summarize", prefer=preferred)],
|
||||
ledger=ledger,
|
||||
adapters_by_id={"candidate": candidate, "preferred": preferred},
|
||||
)
|
||||
|
||||
assert policy.resolve("summarize", quality_floor=0.8) is preferred
|
||||
|
||||
def test_cold_start_falls_through_to_static_policy(self, tmp_path):
|
||||
preferred = self._adapter("preferred")
|
||||
fallback = self._adapter("fallback")
|
||||
policy = AdaptiveRoutingPolicy(
|
||||
rules=[RoutingRule("summarize", prefer=preferred, fallback=fallback)],
|
||||
ledger=QualityLedger(tmp_path / "quality.jsonl"),
|
||||
adapters_by_id={"preferred": preferred, "fallback": fallback},
|
||||
)
|
||||
|
||||
assert policy.resolve("summarize", quality_floor=0.8) is preferred
|
||||
|
||||
def test_window_size_changes_observed_mean_quality(self, tmp_path):
|
||||
cheap = self._adapter("cheap")
|
||||
smart = self._adapter("smart")
|
||||
ledger = QualityLedger(tmp_path / "quality.jsonl")
|
||||
append_observation(
|
||||
ledger,
|
||||
adapter_id="cheap",
|
||||
quality_score=0.9,
|
||||
cost_usd=0.01,
|
||||
recorded_at=datetime(2026, 5, 16, tzinfo=timezone.utc),
|
||||
)
|
||||
append_observation(
|
||||
ledger,
|
||||
adapter_id="cheap",
|
||||
quality_score=0.7,
|
||||
cost_usd=0.01,
|
||||
recorded_at=datetime(2026, 5, 17, tzinfo=timezone.utc),
|
||||
)
|
||||
append_observation(ledger, adapter_id="smart", quality_score=0.9, cost_usd=0.03)
|
||||
|
||||
recent_only = AdaptiveRoutingPolicy(
|
||||
rules=[RoutingRule("summarize", prefer=smart)],
|
||||
ledger=ledger,
|
||||
adapters_by_id={"cheap": cheap, "smart": smart},
|
||||
window_size=1,
|
||||
)
|
||||
wider_window = AdaptiveRoutingPolicy(
|
||||
rules=[RoutingRule("summarize", prefer=smart)],
|
||||
ledger=ledger,
|
||||
adapters_by_id={"cheap": cheap, "smart": smart},
|
||||
window_size=2,
|
||||
)
|
||||
|
||||
assert recent_only.resolve("summarize", quality_floor=0.8) is smart
|
||||
assert wider_window.resolve("summarize", quality_floor=0.8) is cheap
|
||||
|
||||
def test_stale_observations_are_ignored_by_max_age(self, tmp_path):
|
||||
stale = self._adapter("stale")
|
||||
fresh = self._adapter("fresh")
|
||||
ledger = QualityLedger(tmp_path / "quality.jsonl")
|
||||
append_observation(
|
||||
ledger,
|
||||
adapter_id="stale",
|
||||
quality_score=1.0,
|
||||
cost_usd=0.01,
|
||||
recorded_at=datetime(2020, 1, 1, tzinfo=timezone.utc),
|
||||
)
|
||||
append_observation(
|
||||
ledger,
|
||||
adapter_id="fresh",
|
||||
quality_score=0.9,
|
||||
cost_usd=0.03,
|
||||
recorded_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
policy = AdaptiveRoutingPolicy(
|
||||
rules=[RoutingRule("summarize", prefer=stale)],
|
||||
ledger=ledger,
|
||||
adapters_by_id={"stale": stale, "fresh": fresh},
|
||||
max_age=timedelta(days=1),
|
||||
)
|
||||
|
||||
assert policy.resolve("summarize", quality_floor=0.8) is fresh
|
||||
|
||||
def test_static_fallback_chain_is_preserved_when_no_candidate_qualifies(self, tmp_path):
|
||||
preferred = self._adapter("preferred")
|
||||
fallback = self._adapter("fallback")
|
||||
ledger = QualityLedger(tmp_path / "quality.jsonl")
|
||||
append_observation(ledger, adapter_id="preferred", quality_score=0.6, cost_usd=0.01)
|
||||
append_observation(ledger, adapter_id="fallback", quality_score=0.7, cost_usd=0.005)
|
||||
|
||||
policy = AdaptiveRoutingPolicy(
|
||||
rules=[
|
||||
RoutingRule(
|
||||
"summarize",
|
||||
prefer=preferred,
|
||||
max_cost_per_1k=1.0,
|
||||
fallback=fallback,
|
||||
)
|
||||
],
|
||||
ledger=ledger,
|
||||
adapters_by_id={"preferred": preferred, "fallback": fallback},
|
||||
)
|
||||
|
||||
assert policy.resolve("summarize", 2.0, quality_floor=0.8) is fallback
|
||||
Reference in New Issue
Block a user