Files
llm-connect/tests/test_adaptive_routing.py
tegwick c4ad4bb9f2
Some checks failed
CI / test (3.10) (push) Has been cancelled
CI / test (3.11) (push) Has been cancelled
CI / test (3.12) (push) Has been cancelled
Add adaptive cost-quality routing primitives
2026-05-17 21:32:27 +02:00

182 lines
6.7 KiB
Python

"""
Tests for AdaptiveRoutingPolicy.
"""
from datetime import datetime, timedelta, timezone
from llm_connect.adapter import MockLLMAdapter
from llm_connect.quality import QualityLedger, QualityObservation
from llm_connect.routing import AdaptiveRoutingPolicy, RoutingRule
def append_observation(
ledger: QualityLedger,
*,
adapter_id: str,
quality_score: float,
cost_usd: float,
task_type: str = "summarize",
recorded_at: datetime | None = None,
) -> None:
ledger.append(
QualityObservation(
task_type=task_type,
adapter_id=adapter_id,
model_id=f"{adapter_id}-model",
cost_usd=cost_usd,
quality_score=quality_score,
latency_ms=100,
tokens_in=100,
tokens_out=50,
baseline_adapter_id="baseline",
recorded_at=recorded_at or datetime(2026, 5, 17, tzinfo=timezone.utc),
)
)
class TestAdaptiveRoutingPolicy:
def _adapter(self, name: str) -> MockLLMAdapter:
return MockLLMAdapter(mock_response=name)
def test_selects_cheapest_adapter_that_clears_quality_floor(self, tmp_path):
cheap = self._adapter("cheap")
smart = self._adapter("smart")
ledger = QualityLedger(tmp_path / "quality.jsonl")
append_observation(ledger, adapter_id="cheap", quality_score=0.7, cost_usd=0.01)
append_observation(ledger, adapter_id="smart", quality_score=0.9, cost_usd=0.03)
policy = AdaptiveRoutingPolicy(
rules=[RoutingRule("summarize", prefer=cheap)],
ledger=ledger,
adapters_by_id={"cheap": cheap, "smart": smart},
)
assert policy.resolve("summarize", quality_floor=0.8) is smart
def test_prefers_lower_observed_cost_when_multiple_adapters_clear_floor(self, tmp_path):
cheap = self._adapter("cheap")
smart = self._adapter("smart")
ledger = QualityLedger(tmp_path / "quality.jsonl")
append_observation(ledger, adapter_id="cheap", quality_score=0.9, cost_usd=0.01)
append_observation(ledger, adapter_id="smart", quality_score=0.95, cost_usd=0.03)
policy = AdaptiveRoutingPolicy(
rules=[RoutingRule("summarize", prefer=smart)],
ledger=ledger,
adapters_by_id={"cheap": cheap, "smart": smart},
)
assert policy.resolve("summarize", quality_floor=0.8) is cheap
def test_equal_cost_tie_prefers_static_rule_prefer(self, tmp_path):
candidate = self._adapter("candidate")
preferred = self._adapter("preferred")
ledger = QualityLedger(tmp_path / "quality.jsonl")
append_observation(ledger, adapter_id="candidate", quality_score=0.9, cost_usd=0.01)
append_observation(ledger, adapter_id="preferred", quality_score=0.9, cost_usd=0.01)
policy = AdaptiveRoutingPolicy(
rules=[RoutingRule("summarize", prefer=preferred)],
ledger=ledger,
adapters_by_id={"candidate": candidate, "preferred": preferred},
)
assert policy.resolve("summarize", quality_floor=0.8) is preferred
def test_cold_start_falls_through_to_static_policy(self, tmp_path):
preferred = self._adapter("preferred")
fallback = self._adapter("fallback")
policy = AdaptiveRoutingPolicy(
rules=[RoutingRule("summarize", prefer=preferred, fallback=fallback)],
ledger=QualityLedger(tmp_path / "quality.jsonl"),
adapters_by_id={"preferred": preferred, "fallback": fallback},
)
assert policy.resolve("summarize", quality_floor=0.8) is preferred
def test_window_size_changes_observed_mean_quality(self, tmp_path):
cheap = self._adapter("cheap")
smart = self._adapter("smart")
ledger = QualityLedger(tmp_path / "quality.jsonl")
append_observation(
ledger,
adapter_id="cheap",
quality_score=0.9,
cost_usd=0.01,
recorded_at=datetime(2026, 5, 16, tzinfo=timezone.utc),
)
append_observation(
ledger,
adapter_id="cheap",
quality_score=0.7,
cost_usd=0.01,
recorded_at=datetime(2026, 5, 17, tzinfo=timezone.utc),
)
append_observation(ledger, adapter_id="smart", quality_score=0.9, cost_usd=0.03)
recent_only = AdaptiveRoutingPolicy(
rules=[RoutingRule("summarize", prefer=smart)],
ledger=ledger,
adapters_by_id={"cheap": cheap, "smart": smart},
window_size=1,
)
wider_window = AdaptiveRoutingPolicy(
rules=[RoutingRule("summarize", prefer=smart)],
ledger=ledger,
adapters_by_id={"cheap": cheap, "smart": smart},
window_size=2,
)
assert recent_only.resolve("summarize", quality_floor=0.8) is smart
assert wider_window.resolve("summarize", quality_floor=0.8) is cheap
def test_stale_observations_are_ignored_by_max_age(self, tmp_path):
stale = self._adapter("stale")
fresh = self._adapter("fresh")
ledger = QualityLedger(tmp_path / "quality.jsonl")
append_observation(
ledger,
adapter_id="stale",
quality_score=1.0,
cost_usd=0.01,
recorded_at=datetime(2020, 1, 1, tzinfo=timezone.utc),
)
append_observation(
ledger,
adapter_id="fresh",
quality_score=0.9,
cost_usd=0.03,
recorded_at=datetime.now(timezone.utc),
)
policy = AdaptiveRoutingPolicy(
rules=[RoutingRule("summarize", prefer=stale)],
ledger=ledger,
adapters_by_id={"stale": stale, "fresh": fresh},
max_age=timedelta(days=1),
)
assert policy.resolve("summarize", quality_floor=0.8) is fresh
def test_static_fallback_chain_is_preserved_when_no_candidate_qualifies(self, tmp_path):
preferred = self._adapter("preferred")
fallback = self._adapter("fallback")
ledger = QualityLedger(tmp_path / "quality.jsonl")
append_observation(ledger, adapter_id="preferred", quality_score=0.6, cost_usd=0.01)
append_observation(ledger, adapter_id="fallback", quality_score=0.7, cost_usd=0.005)
policy = AdaptiveRoutingPolicy(
rules=[
RoutingRule(
"summarize",
prefer=preferred,
max_cost_per_1k=1.0,
fallback=fallback,
)
],
ledger=ledger,
adapters_by_id={"preferred": preferred, "fallback": fallback},
)
assert policy.resolve("summarize", 2.0, quality_floor=0.8) is fallback