generated from coulomb/repo-seed
165 lines
6.1 KiB
Python
165 lines
6.1 KiB
Python
"""
|
|
Tests for quality observations and the append-only quality ledger.
|
|
"""
|
|
|
|
import threading
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
import pytest
|
|
|
|
from llm_connect.quality import QualityLedger, QualityObservation, is_stale
|
|
|
|
|
|
def observation(
    *,
    task_type: str = "summarize",
    adapter_id: str = "openrouter:cheap",
    model_id: str = "cheap-model",
    quality_score: float = 0.8,
    recorded_at: datetime | None = None,
    tag: str | None = None,
) -> QualityObservation:
    """Build a ``QualityObservation`` filled with fixed test defaults.

    Only the fields the tests vary are exposed as keyword-only arguments;
    cost, latency, token counts and the baseline adapter are hard-coded.

    Args:
        task_type: Forwarded unchanged to ``QualityObservation``.
        adapter_id: Forwarded unchanged to ``QualityObservation``.
        model_id: Forwarded unchanged to ``QualityObservation``.
        quality_score: Forwarded unchanged to ``QualityObservation``.
        recorded_at: Timestamp to use; when omitted the observation is
            stamped 2026-05-17 00:00 UTC.
        tag: When not ``None``, stored under the ``"tag"`` key of the
            observation's ``tags`` mapping; otherwise ``tags`` is empty.

    Returns:
        The constructed ``QualityObservation``.
    """
    labels: dict[str, str] = {}
    if tag is not None:
        labels["tag"] = tag

    # Fall back to a fixed, timezone-aware instant so tests are deterministic.
    timestamp = recorded_at or datetime(2026, 5, 17, tzinfo=timezone.utc)

    return QualityObservation(
        task_type=task_type,
        adapter_id=adapter_id,
        model_id=model_id,
        cost_usd=0.01,
        quality_score=quality_score,
        latency_ms=123.4,
        tokens_in=100,
        tokens_out=50,
        baseline_adapter_id="claude-code",
        recorded_at=timestamp,
        tags=labels,
    )
|
|
|
|
|
|
class TestQualityObservation:
    """Serialization and validation expectations for ``QualityObservation``."""

    def test_round_trip_dict(self):
        """``to_dict`` followed by ``from_dict`` should give back an equal object."""
        original = observation(tag="a")
        payload = original.to_dict()
        restored = QualityObservation.from_dict(payload)

        assert restored == original
        # The helper uses tokens_in=100 and tokens_out=50.
        assert restored.total_tokens == 150
        assert restored.recorded_at.tzinfo is not None

    def test_naive_recorded_at_is_interpreted_as_utc(self):
        """A timestamp without tzinfo is expected to come back tagged as UTC."""
        naive_timestamp = datetime(2026, 5, 17, 12, 0, 0)
        built = observation(recorded_at=naive_timestamp)

        assert built.recorded_at.tzinfo == timezone.utc

    @pytest.mark.parametrize("score", [-0.1, 1.1])
    def test_quality_score_must_be_between_zero_and_one(self, score):
        """Scores just outside the [0, 1] range are expected to be rejected."""
        with pytest.raises(ValueError, match="quality_score"):
            observation(quality_score=score)

    def test_required_ids_must_be_non_empty(self):
        """An empty ``task_type`` is expected to raise ``ValueError``."""
        with pytest.raises(ValueError, match="task_type"):
            observation(task_type="")

    def test_non_negative_fields_are_enforced(self):
        """A negative ``tokens_in`` is expected to raise ``ValueError``."""
        # Every field is valid except tokens_in, so the error must name it.
        fields = {
            "task_type": "x",
            "adapter_id": "a",
            "model_id": "m",
            "cost_usd": 0,
            "quality_score": 1,
            "latency_ms": 0,
            "tokens_in": -1,
            "tokens_out": 0,
        }
        with pytest.raises(ValueError, match="tokens_in"):
            QualityObservation(**fields)
|
|
|
|
|
|
class TestQualityLedger:
    """Behavioural expectations for the JSONL-backed ``QualityLedger``."""

    def test_append_and_read_round_trip(self, tmp_path):
        """A single appended observation should be returned by ``read_all``."""
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        written = observation()

        ledger.append(written)

        assert ledger.read_all() == [written]

    def test_by_task_type_filters_observations(self, tmp_path):
        """``by_task_type`` should return only entries with the requested type."""
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        for kind in ("summarize", "extract"):
            ledger.append(observation(task_type=kind))

        matched = ledger.by_task_type("summarize")

        assert [entry.task_type for entry in matched] == ["summarize"]

    def test_recent_returns_newest_first_with_filters(self, tmp_path):
        """``recent`` should honour both the task-type filter and the limit."""
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        older = observation(recorded_at=datetime(2026, 5, 1, tzinfo=timezone.utc), tag="older")
        newer = observation(recorded_at=datetime(2026, 5, 2, tzinfo=timezone.utc), tag="newer")
        # Newest overall, but a different task type, so it must be filtered out.
        other = observation(
            task_type="extract",
            recorded_at=datetime(2026, 5, 3, tzinfo=timezone.utc),
            tag="other",
        )
        for entry in (older, newer, other):
            ledger.append(entry)

        recent = ledger.recent(limit=1, task_type="summarize")

        assert [entry.tags["tag"] for entry in recent] == ["newer"]

    def test_mean_quality_filters_by_adapter_and_minimum_count(self, tmp_path):
        """``mean_quality`` averages one adapter and yields ``None`` below the minimum."""
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        samples = [("a", 0.5), ("a", 1.0), ("b", 0.1)]
        for adapter, score in samples:
            ledger.append(observation(adapter_id=adapter, quality_score=score))

        assert ledger.mean_quality("summarize", adapter_id="a") == 0.75
        # Adapter "a" has only two observations, fewer than the three required.
        assert ledger.mean_quality("summarize", adapter_id="a", min_observations=3) is None

    def test_is_stale_uses_utc_reference(self):
        """``is_stale`` compares the observation's age against the allowed window."""
        recorded = observation(recorded_at=datetime(2026, 5, 1, tzinfo=timezone.utc))
        reference = datetime(2026, 5, 3, tzinfo=timezone.utc)

        # The observation is two days old relative to the reference instant.
        assert is_stale(recorded, timedelta(days=1), now=reference) is True
        assert is_stale(recorded, timedelta(days=3), now=reference) is False

    def test_prune_before_removes_old_valid_observations(self, tmp_path):
        """``prune_before`` should drop older entries and report how many it removed."""
        ledger = QualityLedger(tmp_path / "quality.jsonl")
        old = observation(recorded_at=datetime(2026, 5, 1, tzinfo=timezone.utc), tag="old")
        keep = observation(recorded_at=datetime(2026, 5, 2, tzinfo=timezone.utc), tag="keep")
        ledger.append(old)
        ledger.append(keep)

        # The cutoff equals the timestamp of "keep", which is expected to survive.
        cutoff = datetime(2026, 5, 2, tzinfo=timezone.utc)
        removed = ledger.prune_before(cutoff)

        assert removed == 1
        assert [entry.tags["tag"] for entry in ledger.read_all()] == ["keep"]

    def test_malformed_lines_are_skipped_and_counted(self, tmp_path):
        """An unparseable line should be left out of reads but counted as malformed."""
        ledger_file = tmp_path / "quality.jsonl"
        # Seed the file with one invalid JSON line before the ledger touches it.
        ledger_file.write_text("{not json}\n", encoding="utf-8")
        ledger = QualityLedger(ledger_file)
        ledger.append(observation())

        assert len(ledger.read_all()) == 1
        assert ledger.malformed_count() == 1

    def test_prune_preserves_malformed_lines(self, tmp_path):
        """Pruning should remove the valid old entry yet keep the malformed line."""
        ledger_file = tmp_path / "quality.jsonl"
        ledger_file.write_text("{not json}\n", encoding="utf-8")
        ledger = QualityLedger(ledger_file)
        ledger.append(observation(recorded_at=datetime(2026, 5, 1, tzinfo=timezone.utc)))

        cutoff = datetime(2026, 5, 2, tzinfo=timezone.utc)
        removed = ledger.prune_before(cutoff)

        assert removed == 1
        assert ledger.malformed_count() == 1
        assert ledger.read_all() == []

    def test_concurrent_writes_round_trip(self, tmp_path):
        """Appends issued from 25 threads should all be readable afterwards."""
        ledger = QualityLedger(tmp_path / "quality.jsonl")

        def append_one(index: int) -> None:
            # Tag each observation with its index so every write is identifiable.
            ledger.append(observation(tag=str(index)))

        workers = [threading.Thread(target=append_one, args=(i,)) for i in range(25)]
        for worker in workers:
            worker.start()
        for worker in workers:
            worker.join()

        stored = ledger.read_all()
        assert len(stored) == 25
        assert {entry.tags["tag"] for entry in stored} == {str(i) for i in range(25)}
|