generated from coulomb/repo-seed
182 lines
6.7 KiB
Python
182 lines
6.7 KiB
Python
"""
Tests for AdaptiveRoutingPolicy.
"""
|
|
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from llm_connect.adapter import MockLLMAdapter
|
|
from llm_connect.quality import QualityLedger, QualityObservation
|
|
from llm_connect.routing import AdaptiveRoutingPolicy, RoutingRule
|
|
|
|
|
|
def append_observation(
    ledger: QualityLedger,
    *,
    adapter_id: str,
    quality_score: float,
    cost_usd: float,
    task_type: str = "summarize",
    recorded_at: datetime | None = None,
) -> None:
    """Append one synthetic ``QualityObservation`` to *ledger*.

    Only the fields a test cares about are parameters; the rest are fixed
    filler values (latency 100 ms, 100 tokens in, 50 tokens out, baseline
    adapter ``"baseline"``). The model id is derived as ``"<adapter_id>-model"``.

    Args:
        ledger: Ledger that receives the observation via ``ledger.append``.
        adapter_id: Identifier of the adapter the observation is recorded for.
        quality_score: Quality score stored on the observation.
        cost_usd: Cost stored on the observation.
        task_type: Task type stored on the observation.
        recorded_at: Timestamp of the observation; when omitted, a fixed
            2026-05-17 UTC timestamp is used.
    """
    # Fall back to a fixed timestamp when the caller does not supply one.
    timestamp = (
        recorded_at
        if recorded_at
        else datetime(2026, 5, 17, tzinfo=timezone.utc)
    )
    observation = QualityObservation(
        adapter_id=adapter_id,
        model_id=f"{adapter_id}-model",
        task_type=task_type,
        quality_score=quality_score,
        cost_usd=cost_usd,
        latency_ms=100,
        tokens_in=100,
        tokens_out=50,
        baseline_adapter_id="baseline",
        recorded_at=timestamp,
    )
    ledger.append(observation)
|
|
|
|
|
|
class TestAdaptiveRoutingPolicy:
    """Scenario tests for ``AdaptiveRoutingPolicy.resolve`` backed by a ``QualityLedger``."""

    def _adapter(self, name: str) -> MockLLMAdapter:
        """Return a ``MockLLMAdapter`` whose ``mock_response`` is *name*."""
        return MockLLMAdapter(mock_response=name)

    def test_selects_cheapest_adapter_that_clears_quality_floor(self, tmp_path):
        """With "cheap" at 0.7 and "smart" at 0.9, a 0.8 floor must resolve to "smart"."""
        cheap, smart = self._adapter("cheap"), self._adapter("smart")
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        seed_rows = (("cheap", 0.7, 0.01), ("smart", 0.9, 0.03))
        for adapter_id, score, cost in seed_rows:
            append_observation(
                ledger,
                adapter_id=adapter_id,
                quality_score=score,
                cost_usd=cost,
            )

        policy = AdaptiveRoutingPolicy(
            rules=[RoutingRule("summarize", prefer=cheap)],
            ledger=ledger,
            adapters_by_id={"cheap": cheap, "smart": smart},
        )

        chosen = policy.resolve("summarize", quality_floor=0.8)
        assert chosen is smart

    def test_prefers_lower_observed_cost_when_multiple_adapters_clear_floor(self, tmp_path):
        """Both adapters clear the 0.8 floor; the lower-cost "cheap" must beat the static preference."""
        cheap, smart = self._adapter("cheap"), self._adapter("smart")
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        seed_rows = (("cheap", 0.9, 0.01), ("smart", 0.95, 0.03))
        for adapter_id, score, cost in seed_rows:
            append_observation(
                ledger,
                adapter_id=adapter_id,
                quality_score=score,
                cost_usd=cost,
            )

        # The static rule prefers "smart"; the observed cost is expected to override it.
        policy = AdaptiveRoutingPolicy(
            rules=[RoutingRule("summarize", prefer=smart)],
            ledger=ledger,
            adapters_by_id={"cheap": cheap, "smart": smart},
        )

        chosen = policy.resolve("summarize", quality_floor=0.8)
        assert chosen is cheap

    def test_equal_cost_tie_prefers_static_rule_prefer(self, tmp_path):
        """Identical score and cost for both adapters must resolve to the rule's ``prefer``."""
        candidate = self._adapter("candidate")
        preferred = self._adapter("preferred")
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        # Same quality and cost for both, recorded candidate-first.
        for adapter_id in ("candidate", "preferred"):
            append_observation(
                ledger,
                adapter_id=adapter_id,
                quality_score=0.9,
                cost_usd=0.01,
            )

        policy = AdaptiveRoutingPolicy(
            rules=[RoutingRule("summarize", prefer=preferred)],
            ledger=ledger,
            adapters_by_id={"candidate": candidate, "preferred": preferred},
        )

        chosen = policy.resolve("summarize", quality_floor=0.8)
        assert chosen is preferred

    def test_cold_start_falls_through_to_static_policy(self, tmp_path):
        """With no observations appended, resolution must return the rule's ``prefer``."""
        preferred = self._adapter("preferred")
        fallback = self._adapter("fallback")
        static_rule = RoutingRule("summarize", prefer=preferred, fallback=fallback)
        policy = AdaptiveRoutingPolicy(
            rules=[static_rule],
            ledger=QualityLedger(tmp_path / "quality.jsonl"),
            adapters_by_id={"preferred": preferred, "fallback": fallback},
        )

        chosen = policy.resolve("summarize", quality_floor=0.8)
        assert chosen is preferred

    def test_window_size_changes_observed_mean_quality(self, tmp_path):
        """"cheap" has scores 0.9 then 0.7: window 1 must pick "smart", window 2 must pick "cheap"."""
        cheap, smart = self._adapter("cheap"), self._adapter("smart")
        ledger = QualityLedger(tmp_path / "quality.jsonl")

        # Older observation first, then the more recent (lower) one.
        append_observation(
            ledger,
            adapter_id="cheap",
            quality_score=0.9,
            cost_usd=0.01,
            recorded_at=datetime(2026, 5, 16, tzinfo=timezone.utc),
        )
        append_observation(
            ledger,
            adapter_id="cheap",
            quality_score=0.7,
            cost_usd=0.01,
            recorded_at=datetime(2026, 5, 17, tzinfo=timezone.utc),
        )
        append_observation(
            ledger,
            adapter_id="smart",
            quality_score=0.9,
            cost_usd=0.03,
        )

        def build_policy(window: int) -> AdaptiveRoutingPolicy:
            # Same rules, ledger and adapters; only the window size varies.
            return AdaptiveRoutingPolicy(
                rules=[RoutingRule("summarize", prefer=smart)],
                ledger=ledger,
                adapters_by_id={"cheap": cheap, "smart": smart},
                window_size=window,
            )

        recent_only = build_policy(1)
        wider_window = build_policy(2)

        narrow_choice = recent_only.resolve("summarize", quality_floor=0.8)
        assert narrow_choice is smart
        wide_choice = wider_window.resolve("summarize", quality_floor=0.8)
        assert wide_choice is cheap

    def test_stale_observations_are_ignored_by_max_age(self, tmp_path):
        """A 2020 observation must not count under a one-day ``max_age``; "fresh" wins."""
        stale, fresh = self._adapter("stale"), self._adapter("fresh")
        ledger = QualityLedger(tmp_path / "quality.jsonl")

        # "stale" looks better and cheaper, but its only observation is from 2020.
        append_observation(
            ledger,
            adapter_id="stale",
            quality_score=1.0,
            cost_usd=0.01,
            recorded_at=datetime(2020, 1, 1, tzinfo=timezone.utc),
        )
        # "fresh" is stamped with the current UTC time.
        append_observation(
            ledger,
            adapter_id="fresh",
            quality_score=0.9,
            cost_usd=0.03,
            recorded_at=datetime.now(timezone.utc),
        )

        policy = AdaptiveRoutingPolicy(
            rules=[RoutingRule("summarize", prefer=stale)],
            ledger=ledger,
            adapters_by_id={"stale": stale, "fresh": fresh},
            max_age=timedelta(days=1),
        )

        chosen = policy.resolve("summarize", quality_floor=0.8)
        assert chosen is fresh

    def test_static_fallback_chain_is_preserved_when_no_candidate_qualifies(self, tmp_path):
        """Neither adapter reaches the 0.8 floor; resolution must return the rule's ``fallback``."""
        preferred = self._adapter("preferred")
        fallback = self._adapter("fallback")
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        # Both recorded scores (0.6 and 0.7) are below the requested floor.
        seed_rows = (("preferred", 0.6, 0.01), ("fallback", 0.7, 0.005))
        for adapter_id, score, cost in seed_rows:
            append_observation(
                ledger,
                adapter_id=adapter_id,
                quality_score=score,
                cost_usd=cost,
            )

        static_rule = RoutingRule(
            "summarize",
            prefer=preferred,
            max_cost_per_1k=1.0,
            fallback=fallback,
        )
        policy = AdaptiveRoutingPolicy(
            rules=[static_rule],
            ledger=ledger,
            adapters_by_id={"preferred": preferred, "fallback": fallback},
        )

        # NOTE(review): the positional 2.0 is above the rule's max_cost_per_1k of 1.0 —
        # presumably a cost estimate that triggers the static fallback; confirm against resolve().
        chosen = policy.resolve("summarize", 2.0, quality_floor=0.8)
        assert chosen is fallback
|