Files
llm-connect/tests/test_quality.py
tegwick c4ad4bb9f2
Some checks failed
CI / test (3.10) (push) Has been cancelled
CI / test (3.11) (push) Has been cancelled
CI / test (3.12) (push) Has been cancelled
Add adaptive cost-quality routing primitives
2026-05-17 21:32:27 +02:00

165 lines
6.1 KiB
Python

"""
Tests for quality observations and the append-only quality ledger.
"""
import threading
from datetime import datetime, timedelta, timezone
import pytest
from llm_connect.quality import QualityLedger, QualityObservation, is_stale
def observation(
    *,
    task_type: str = "summarize",
    adapter_id: str = "openrouter:cheap",
    model_id: str = "cheap-model",
    quality_score: float = 0.8,
    recorded_at: datetime | None = None,
    tag: str | None = None,
) -> QualityObservation:
    """Build a ``QualityObservation`` pre-filled with fixed test values.

    All arguments are keyword-only overrides. Cost, latency, token counts
    and the baseline adapter are hard-coded so individual tests only spell
    out the fields they care about.

    Args:
        task_type: Task type recorded on the observation.
        adapter_id: Adapter identifier recorded on the observation.
        model_id: Model identifier recorded on the observation.
        quality_score: Quality score recorded on the observation.
        recorded_at: Timestamp to record; when not supplied,
            2026-05-17 00:00 UTC is used.
        tag: When not ``None``, stored under the ``"tag"`` key of ``tags``;
            otherwise ``tags`` is an empty dict.

    Returns:
        The constructed observation.
    """
    if tag is None:
        tags = {}
    else:
        tags = {"tag": tag}

    timestamp = recorded_at if recorded_at else datetime(2026, 5, 17, tzinfo=timezone.utc)

    return QualityObservation(
        task_type=task_type,
        adapter_id=adapter_id,
        model_id=model_id,
        cost_usd=0.01,
        quality_score=quality_score,
        latency_ms=123.4,
        tokens_in=100,
        tokens_out=50,
        baseline_adapter_id="claude-code",
        recorded_at=timestamp,
        tags=tags,
    )
class TestQualityObservation:
    """Checks on ``QualityObservation`` serialization and field validation."""

    def test_round_trip_dict(self):
        """``from_dict(to_dict(obs))`` is expected to compare equal to ``obs``."""
        original = observation(tag="a")
        payload = original.to_dict()
        restored = QualityObservation.from_dict(payload)

        assert restored == original
        # The helper builds observations with 100 input and 50 output tokens.
        assert restored.total_tokens == 150
        assert restored.recorded_at.tzinfo is not None

    def test_naive_recorded_at_is_interpreted_as_utc(self):
        """A timestamp without tzinfo is expected to come back tagged as UTC."""
        naive_timestamp = datetime(2026, 5, 17, 12, 0, 0)
        obs = observation(recorded_at=naive_timestamp)

        assert obs.recorded_at.tzinfo == timezone.utc

    @pytest.mark.parametrize("score", [-0.1, 1.1])
    def test_quality_score_must_be_between_zero_and_one(self, score):
        """Scores just outside [0, 1] are expected to raise ``ValueError``."""
        with pytest.raises(ValueError, match="quality_score"):
            observation(quality_score=score)

    def test_required_ids_must_be_non_empty(self):
        """An empty ``task_type`` is expected to raise ``ValueError``."""
        with pytest.raises(ValueError, match="task_type"):
            observation(task_type="")

    def test_non_negative_fields_are_enforced(self):
        """A negative ``tokens_in`` is expected to raise ``ValueError``."""
        # Built outside the raises-block so only the constructor call is guarded.
        invalid_fields = {
            "task_type": "x",
            "adapter_id": "a",
            "model_id": "m",
            "cost_usd": 0,
            "quality_score": 1,
            "latency_ms": 0,
            "tokens_in": -1,
            "tokens_out": 0,
        }
        with pytest.raises(ValueError, match="tokens_in"):
            QualityObservation(**invalid_fields)
class TestQualityLedger:
    """Checks on ``QualityLedger`` reads, filters, pruning, and ``is_stale``."""

    def test_append_and_read_round_trip(self, tmp_path):
        """A single appended observation is expected back from ``read_all``."""
        ledger_path = tmp_path / "quality.jsonl"
        ledger = QualityLedger(ledger_path)
        written = observation()

        ledger.append(written)

        stored = ledger.read_all()
        assert stored == [written]

    def test_by_task_type_filters_observations(self, tmp_path):
        """``by_task_type`` is expected to return only the requested task type."""
        ledger_path = tmp_path / "quality.jsonl"
        ledger = QualityLedger(ledger_path)
        for task in ("summarize", "extract"):
            ledger.append(observation(task_type=task))

        matched_types = [entry.task_type for entry in ledger.by_task_type("summarize")]
        assert matched_types == ["summarize"]

    def test_recent_returns_newest_first_with_filters(self, tmp_path):
        """``recent(limit=1)`` is expected to yield the newest matching entry."""
        ledger_path = tmp_path / "quality.jsonl"
        ledger = QualityLedger(ledger_path)
        entries = [
            observation(recorded_at=datetime(2026, 5, 1, tzinfo=timezone.utc), tag="older"),
            observation(recorded_at=datetime(2026, 5, 2, tzinfo=timezone.utc), tag="newer"),
            # Newest overall, but a different task type, so it must be filtered out.
            observation(
                task_type="extract",
                recorded_at=datetime(2026, 5, 3, tzinfo=timezone.utc),
                tag="other",
            ),
        ]
        for entry in entries:
            ledger.append(entry)

        recent = ledger.recent(limit=1, task_type="summarize")

        returned_tags = [entry.tags["tag"] for entry in recent]
        assert returned_tags == ["newer"]

    def test_mean_quality_filters_by_adapter_and_minimum_count(self, tmp_path):
        """Mean covers one adapter only; too few observations gives ``None``."""
        ledger_path = tmp_path / "quality.jsonl"
        ledger = QualityLedger(ledger_path)
        scored = (("a", 0.5), ("a", 1.0), ("b", 0.1))
        for adapter, score in scored:
            ledger.append(observation(adapter_id=adapter, quality_score=score))

        mean_for_a = ledger.mean_quality("summarize", adapter_id="a")
        assert mean_for_a == 0.75

        # Adapter "a" has only two observations, one short of the minimum.
        gated_mean = ledger.mean_quality("summarize", adapter_id="a", min_observations=3)
        assert gated_mean is None

    def test_is_stale_uses_utc_reference(self):
        """A two-day-old entry is stale for a 1-day window, not a 3-day one."""
        recorded = datetime(2026, 5, 1, tzinfo=timezone.utc)
        reference = datetime(2026, 5, 3, tzinfo=timezone.utc)
        obs = observation(recorded_at=recorded)

        assert is_stale(obs, timedelta(days=1), now=reference) is True
        assert is_stale(obs, timedelta(days=3), now=reference) is False

    def test_prune_before_removes_old_valid_observations(self, tmp_path):
        """Pruning at a cutoff is expected to drop only the earlier entry."""
        ledger_path = tmp_path / "quality.jsonl"
        ledger = QualityLedger(ledger_path)
        cutoff = datetime(2026, 5, 2, tzinfo=timezone.utc)
        expired = observation(recorded_at=datetime(2026, 5, 1, tzinfo=timezone.utc), tag="old")
        # Recorded exactly at the cutoff; the test expects it to survive.
        retained = observation(recorded_at=datetime(2026, 5, 2, tzinfo=timezone.utc), tag="keep")
        for entry in (expired, retained):
            ledger.append(entry)

        removed = ledger.prune_before(cutoff)

        assert removed == 1
        surviving_tags = [entry.tags["tag"] for entry in ledger.read_all()]
        assert surviving_tags == ["keep"]

    def test_malformed_lines_are_skipped_and_counted(self, tmp_path):
        """An unparseable line is expected to be skipped on read but counted."""
        path = tmp_path / "quality.jsonl"
        # Seed the file with invalid JSON before the ledger is opened on it.
        path.write_text("{not json}\n", encoding="utf-8")
        ledger = QualityLedger(path)

        ledger.append(observation())

        readable = ledger.read_all()
        assert len(readable) == 1
        assert ledger.malformed_count() == 1

    def test_prune_preserves_malformed_lines(self, tmp_path):
        """Pruning is expected to drop the old entry but keep the bad line."""
        path = tmp_path / "quality.jsonl"
        # Seed the file with invalid JSON before the ledger is opened on it.
        path.write_text("{not json}\n", encoding="utf-8")
        ledger = QualityLedger(path)
        stale_entry = observation(recorded_at=datetime(2026, 5, 1, tzinfo=timezone.utc))
        ledger.append(stale_entry)

        cutoff = datetime(2026, 5, 2, tzinfo=timezone.utc)
        removed = ledger.prune_before(cutoff)

        assert removed == 1
        assert ledger.malformed_count() == 1
        assert ledger.read_all() == []

    def test_concurrent_writes_round_trip(self, tmp_path):
        """Appends from 25 threads are all expected to be readable afterwards."""
        ledger_path = tmp_path / "quality.jsonl"
        ledger = QualityLedger(ledger_path)

        def append_tagged(index: int) -> None:
            # Tag each observation with its index so every write is identifiable.
            ledger.append(observation(tag=str(index)))

        workers = [threading.Thread(target=append_tagged, args=(i,)) for i in range(25)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        observations = ledger.read_all()
        assert len(observations) == 25
        seen_tags = {entry.tags["tag"] for entry in observations}
        assert seen_tags == {str(i) for i in range(25)}