generated from coulomb/repo-seed
110 lines
3.0 KiB
Python
110 lines
3.0 KiB
Python
"""
|
|
Integration coverage for the adaptive routing workplan flow.
|
|
"""
|
|
|
|
from datetime import datetime, timezone
|
|
|
|
from examples.adaptive_routing_fixture_batch import populate_ledger
|
|
from llm_connect.adapter import MockLLMAdapter
|
|
from llm_connect.quality import QualityLedger, QualityObservation
|
|
from llm_connect.routing import AdaptiveRoutingPolicy, RoutingRule
|
|
|
|
|
|
def append_quality(
    ledger: QualityLedger,
    adapter_id: str,
    quality_score: float,
    cost_usd: float,
    *,
    recorded_at: datetime,
) -> None:
    """Append one synthetic ``summarize`` observation to ``ledger``.

    Only the adapter, its quality score, its cost and the timestamp vary
    between calls; every other field is pinned to a constant so the tests
    differ solely in the values they pass here.

    Args:
        ledger: Ledger that receives the observation via ``ledger.append``.
        adapter_id: Adapter the observation is attributed to; the model id
            is derived from it as ``"<adapter_id>-model"``.
        quality_score: Quality score recorded for the observation.
        cost_usd: Cost in USD recorded for the observation.
        recorded_at: Timestamp stored on the observation (keyword-only).
    """
    observation = QualityObservation(
        # Identity of the observed call.
        task_type="summarize",
        adapter_id=adapter_id,
        model_id=f"{adapter_id}-model",
        baseline_adapter_id="baseline",
        # Caller-controlled measurements.
        quality_score=quality_score,
        cost_usd=cost_usd,
        recorded_at=recorded_at,
        # Fixed filler metrics, identical for every observation.
        latency_ms=100,
        tokens_in=100,
        tokens_out=50,
    )
    ledger.append(observation)
|
|
|
|
|
|
def test_adaptive_policy_converges_to_cheapest_qualifying_adapter(tmp_path):
    """Check that resolution moves to cheaper adapters as quality data arrives.

    The scenario runs in three stages against a ledger file under
    ``tmp_path``:

    1. Empty ledger: ``resolve`` returns the rule's preferred adapter, or
       the fallback when a second positional argument of ``2.0`` is given.
    2. One observation per adapter, with "cheap" scoring 0.7 against a
       floor of 0.8: ``resolve`` returns "mid".
    3. A later "cheap" observation scoring 0.95: ``resolve`` returns
       "cheap".
    """
    adapters = {name: MockLLMAdapter(name) for name in ("cheap", "mid", "smart")}
    cheap, mid, smart = adapters["cheap"], adapters["mid"], adapters["smart"]

    ledger = QualityLedger(tmp_path / "quality.jsonl")
    summarize_rule = RoutingRule(
        "summarize",
        prefer=smart,
        max_cost_per_1k=1.0,
        fallback=mid,
    )
    policy = AdaptiveRoutingPolicy(
        rules=[summarize_rule],
        ledger=ledger,
        adapters_by_id=adapters,
        window_size=2,
    )

    # Stage 1: nothing recorded yet, so only the static rule can decide.
    assert policy.resolve("summarize", quality_floor=0.8) is smart
    # NOTE(review): 2.0 presumably is a cost figure above max_cost_per_1k
    # that triggers the fallback -- confirm against the resolve() signature.
    assert policy.resolve("summarize", 2.0, quality_floor=0.8) is mid

    # Stage 2: one observation per adapter, all stamped with the same hour.
    first_round_at = datetime(2026, 5, 17, 10, tzinfo=timezone.utc)
    first_round = (
        ("cheap", 0.7, 0.01),
        ("mid", 0.86, 0.02),
        ("smart", 0.95, 0.05),
    )
    for adapter_name, score, cost in first_round:
        append_quality(
            ledger,
            adapter_name,
            quality_score=score,
            cost_usd=cost,
            recorded_at=first_round_at,
        )

    assert policy.resolve("summarize", quality_floor=0.8) is mid

    # Stage 3: an hour later "cheap" records a high score at the same cost.
    append_quality(
        ledger,
        "cheap",
        quality_score=0.95,
        cost_usd=0.01,
        recorded_at=datetime(2026, 5, 17, 11, tzinfo=timezone.utc),
    )

    assert policy.resolve("summarize", quality_floor=0.8) is cheap
|
|
|
|
|
|
def test_fixture_batch_populates_three_candidate_observations_per_task(tmp_path):
    """Check the fixture batch records three adapters for each task type.

    ``populate_ledger`` fills a fresh ledger under ``tmp_path``. The
    observations read back are grouped by task type, and the test asserts
    that exactly the three expected task types appear and that each was
    observed with three distinct adapter ids.
    """
    ledger = QualityLedger(tmp_path / "quality.jsonl")
    populate_ledger(ledger)

    # Map each task type to the set of distinct adapter ids seen for it.
    adapters_per_task: dict[str, set[str]] = {}
    for entry in ledger.read_all():
        seen = adapters_per_task.get(entry.task_type)
        if seen is None:
            seen = adapters_per_task[entry.task_type] = set()
        seen.add(entry.adapter_id)

    expected_task_types = {
        "summarize-source",
        "extract-relations",
        "evaluate-entity",
    }
    assert adapters_per_task.keys() == expected_task_types
    for adapter_ids in adapters_per_task.values():
        assert len(adapter_ids) == 3
|