generated from coulomb/repo-seed
Add adaptive cost-quality routing primitives
This commit is contained in:
164
tests/test_quality.py
Normal file
164
tests/test_quality.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
Tests for quality observations and the append-only quality ledger.
|
||||
"""
|
||||
|
||||
import threading
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from llm_connect.quality import QualityLedger, QualityObservation, is_stale
|
||||
|
||||
|
||||
def observation(
    *,
    task_type: str = "summarize",
    adapter_id: str = "openrouter:cheap",
    model_id: str = "cheap-model",
    quality_score: float = 0.8,
    recorded_at: datetime | None = None,
    tag: str | None = None,
) -> QualityObservation:
    """Build a ``QualityObservation`` populated with fixed test defaults.

    Only the fields the tests vary are exposed as keyword-only parameters;
    cost, latency, token counts and the baseline adapter are constants.

    Args:
        task_type: Task type stored on the observation.
        adapter_id: Adapter identifier stored on the observation.
        model_id: Model identifier stored on the observation.
        quality_score: Quality score stored on the observation.
        recorded_at: Timestamp to record; when not supplied,
            2026-05-17 00:00 UTC is used.
        tag: When given, stored as ``{"tag": tag}`` in ``tags``;
            otherwise ``tags`` is an empty dict.

    Returns:
        The constructed ``QualityObservation``.
    """
    timestamp = recorded_at or datetime(2026, 5, 17, tzinfo=timezone.utc)
    labels = {} if tag is None else {"tag": tag}
    return QualityObservation(
        task_type=task_type,
        adapter_id=adapter_id,
        model_id=model_id,
        cost_usd=0.01,
        quality_score=quality_score,
        latency_ms=123.4,
        tokens_in=100,
        tokens_out=50,
        baseline_adapter_id="claude-code",
        recorded_at=timestamp,
        tags=labels,
    )
|
||||
|
||||
|
||||
class TestQualityObservation:
    """Serialisation and validation checks for ``QualityObservation``."""

    def test_round_trip_dict(self):
        """``from_dict(to_dict())`` yields an observation equal to the source."""
        original = observation(tag="a")
        payload = original.to_dict()
        restored = QualityObservation.from_dict(payload)

        assert restored == original
        # The helper supplies tokens_in=100 and tokens_out=50.
        assert restored.total_tokens == 150
        assert restored.recorded_at.tzinfo is not None

    def test_naive_recorded_at_is_interpreted_as_utc(self):
        """A timestamp without tzinfo ends up carrying UTC."""
        naive = datetime(2026, 5, 17, 12, 0, 0)
        built = observation(recorded_at=naive)

        assert built.recorded_at.tzinfo == timezone.utc

    @pytest.mark.parametrize("score", [-0.1, 1.1])
    def test_quality_score_must_be_between_zero_and_one(self, score):
        """Scores just outside [0, 1] raise ``ValueError`` naming the field."""
        with pytest.raises(ValueError, match="quality_score"):
            observation(quality_score=score)

    def test_required_ids_must_be_non_empty(self):
        """An empty ``task_type`` raises ``ValueError`` naming the field."""
        with pytest.raises(ValueError, match="task_type"):
            observation(task_type="")

    def test_non_negative_fields_are_enforced(self):
        """A negative ``tokens_in`` raises ``ValueError`` naming the field."""
        # Every other field is valid so that tokens_in is the only offender.
        fields = {
            "task_type": "x",
            "adapter_id": "a",
            "model_id": "m",
            "cost_usd": 0,
            "quality_score": 1,
            "latency_ms": 0,
            "tokens_in": -1,
            "tokens_out": 0,
        }
        with pytest.raises(ValueError, match="tokens_in"):
            QualityObservation(**fields)
|
||||
|
||||
|
||||
class TestQualityLedger:
    """Behavioural checks for ``QualityLedger`` and the ``is_stale`` helper."""

    def test_append_and_read_round_trip(self, tmp_path):
        """One appended observation is returned unchanged by ``read_all``."""
        store = QualityLedger(tmp_path / "quality.jsonl")
        entry = observation()

        store.append(entry)

        assert store.read_all() == [entry]

    def test_by_task_type_filters_observations(self, tmp_path):
        """``by_task_type`` yields only observations of the requested type."""
        store = QualityLedger(tmp_path / "quality.jsonl")
        for kind in ("summarize", "extract"):
            store.append(observation(task_type=kind))

        matched = [item.task_type for item in store.by_task_type("summarize")]

        assert matched == ["summarize"]

    def test_recent_returns_newest_first_with_filters(self, tmp_path):
        """``recent`` honours both ``limit`` and ``task_type``."""
        store = QualityLedger(tmp_path / "quality.jsonl")
        older = observation(
            recorded_at=datetime(2026, 5, 1, tzinfo=timezone.utc), tag="older"
        )
        newer = observation(
            recorded_at=datetime(2026, 5, 2, tzinfo=timezone.utc), tag="newer"
        )
        # Newest overall, but a different task type, so the filter must drop it.
        other = observation(
            task_type="extract",
            recorded_at=datetime(2026, 5, 3, tzinfo=timezone.utc),
            tag="other",
        )
        for entry in (older, newer, other):
            store.append(entry)

        latest = store.recent(limit=1, task_type="summarize")

        assert [item.tags["tag"] for item in latest] == ["newer"]

    def test_mean_quality_filters_by_adapter_and_minimum_count(self, tmp_path):
        """``mean_quality`` averages one adapter and respects ``min_observations``."""
        store = QualityLedger(tmp_path / "quality.jsonl")
        samples = (("a", 0.5), ("a", 1.0), ("b", 0.1))
        for adapter, score in samples:
            store.append(observation(adapter_id=adapter, quality_score=score))

        assert store.mean_quality("summarize", adapter_id="a") == 0.75
        # Adapter "a" has only two observations, fewer than the required three.
        assert store.mean_quality("summarize", adapter_id="a", min_observations=3) is None

    def test_is_stale_uses_utc_reference(self):
        """An observation two days old is stale at 1 day but not at 3 days."""
        entry = observation(recorded_at=datetime(2026, 5, 1, tzinfo=timezone.utc))
        reference = datetime(2026, 5, 3, tzinfo=timezone.utc)

        assert is_stale(entry, timedelta(days=1), now=reference) is True
        assert is_stale(entry, timedelta(days=3), now=reference) is False

    def test_prune_before_removes_old_valid_observations(self, tmp_path):
        """``prune_before`` drops earlier entries and reports how many went."""
        store = QualityLedger(tmp_path / "quality.jsonl")
        stale = observation(
            recorded_at=datetime(2026, 5, 1, tzinfo=timezone.utc), tag="old"
        )
        # Recorded exactly at the cutoff; the test expects it to survive.
        fresh = observation(
            recorded_at=datetime(2026, 5, 2, tzinfo=timezone.utc), tag="keep"
        )
        for entry in (stale, fresh):
            store.append(entry)

        dropped = store.prune_before(datetime(2026, 5, 2, tzinfo=timezone.utc))

        assert dropped == 1
        assert [item.tags["tag"] for item in store.read_all()] == ["keep"]

    def test_malformed_lines_are_skipped_and_counted(self, tmp_path):
        """An unparsable line is left out of ``read_all`` but still counted."""
        ledger_file = tmp_path / "quality.jsonl"
        # Seed the file with a line that is not valid JSON before appending.
        ledger_file.write_text("{not json}\n", encoding="utf-8")
        store = QualityLedger(ledger_file)
        store.append(observation())

        assert len(store.read_all()) == 1
        assert store.malformed_count() == 1

    def test_prune_preserves_malformed_lines(self, tmp_path):
        """Pruning removes the old observation but keeps the malformed line."""
        ledger_file = tmp_path / "quality.jsonl"
        ledger_file.write_text("{not json}\n", encoding="utf-8")
        store = QualityLedger(ledger_file)
        store.append(
            observation(recorded_at=datetime(2026, 5, 1, tzinfo=timezone.utc))
        )

        dropped = store.prune_before(datetime(2026, 5, 2, tzinfo=timezone.utc))

        assert dropped == 1
        assert store.malformed_count() == 1
        assert store.read_all() == []

    def test_concurrent_writes_round_trip(self, tmp_path):
        """25 threads each append once; every tagged entry is read back."""
        store = QualityLedger(tmp_path / "quality.jsonl")

        def append_one(index: int) -> None:
            store.append(observation(tag=str(index)))

        workers = [
            threading.Thread(target=append_one, args=(number,))
            for number in range(25)
        ]
        # Start every thread before joining any so the appends can overlap.
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        stored = store.read_all()
        expected_tags = {str(number) for number in range(25)}

        assert len(stored) == 25
        assert {item.tags["tag"] for item in stored} == expected_tags
|
||||
Reference in New Issue
Block a user