Files
llm-connect/llm_connect/quality.py
tegwick c4ad4bb9f2
Some checks failed
CI / test (3.10) (push) Has been cancelled
CI / test (3.11) (push) Has been cancelled
CI / test (3.12) (push) Has been cancelled
Add adaptive cost-quality routing primitives
2026-05-17 21:32:27 +02:00

319 lines
11 KiB
Python

"""Quality observations and append-only ledger support.
These primitives let callers record observed quality/cost outcomes for a
task type without baking consumer-specific routing policy into llm-connect.
"""
from __future__ import annotations
import json
import os
import threading
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Iterator, TextIO
# Guards every read/insert on the registry below.
_PATH_LOCKS_GUARD = threading.Lock()
# Process-wide registry: resolved file path -> thread lock serialising access to it.
_PATH_LOCKS: dict[Path, threading.Lock] = {}
def _utc_now() -> datetime:
    """Return the current wall-clock time as a timezone-aware UTC datetime."""
    utc = timezone.utc
    return datetime.now(tz=utc)
def _normalise_datetime(value: datetime | str) -> datetime:
    """Coerce *value* into a timezone-aware datetime expressed in UTC.

    Strings are parsed with ``datetime.fromisoformat`` after rewriting ``Z``
    to ``+00:00``, so ``Z``-suffixed timestamps parse even where
    ``fromisoformat`` does not understand that suffix. Naive datetimes are
    assumed to already be UTC and are tagged as such; aware datetimes are
    converted to UTC.

    Raises:
        TypeError: *value* is neither a ``datetime`` nor a ``str``.
        ValueError: *value* is a string that ``fromisoformat`` rejects.
    """
    if isinstance(value, str):
        parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
    elif isinstance(value, datetime):
        parsed = value
    else:
        raise TypeError(f"Expected datetime or ISO string, got {type(value).__name__}")
    if parsed.tzinfo is not None:
        return parsed.astimezone(timezone.utc)
    return parsed.replace(tzinfo=timezone.utc)
def _serialise_datetime(value: datetime) -> str:
    """Render *value* as an ISO 8601 UTC string using the ``Z`` suffix."""
    text = _normalise_datetime(value).isoformat()
    # After normalisation the offset is always +00:00; emit it as "Z".
    return text.replace("+00:00", "Z")
def _validate_non_negative_int(name: str, value: int) -> None:
    """Raise ``ValueError`` unless *value* is a non-negative ``int``.

    ``bool`` is rejected explicitly: it subclasses ``int``, so ``True`` would
    otherwise be accepted (and serialised as ``true``) as a count of 1.

    Args:
        name: Field name used in the error message.
        value: Candidate value.

    Raises:
        ValueError: *value* is a bool, not an int, or negative.
    """
    if isinstance(value, bool) or not isinstance(value, int) or value < 0:
        raise ValueError(f"{name} must be a non-negative integer")
def _validate_non_negative_float(name: str, value: float) -> None:
    """Raise ``ValueError`` unless *value* is a finite, non-negative number.

    ``int`` and ``float`` are accepted. Rejected explicitly:

    * ``bool`` — an ``int`` subclass, never a meaningful cost/latency;
    * NaN — it compares false with everything, so a bare ``< 0`` test lets
      it through;
    * infinities, and ints too large to convert to ``float``.

    Rejecting non-finite values keeps them out of the JSONL ledger, where
    ``json.dumps`` would write the non-standard ``NaN``/``Infinity`` tokens.

    Args:
        name: Field name used in the error message.
        value: Candidate value.

    Raises:
        ValueError: *value* fails any of the checks above.
    """
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise ValueError(f"{name} must be a non-negative number")
    try:
        number = float(value)
    except OverflowError:
        # int too large for a float: effectively infinite.
        raise ValueError(f"{name} must be a non-negative number") from None
    # The chained comparison is False for NaN, negatives and +inf alike.
    if not 0 <= number < float("inf"):
        raise ValueError(f"{name} must be a non-negative number")
def _path_lock(path: Path) -> threading.Lock:
    """Return the process-wide thread lock registered for *path*.

    The path is resolved first, so different spellings of the same file map
    to one lock. Lookup and insertion both happen under
    ``_PATH_LOCKS_GUARD``, so two threads never create competing locks for
    the same key.
    """
    key = path.resolve()
    with _PATH_LOCKS_GUARD:
        try:
            return _PATH_LOCKS[key]
        except KeyError:
            created = threading.Lock()
            _PATH_LOCKS[key] = created
            return created
def _msvcrt_lock_first_byte(handle: TextIO, *, unlock: bool) -> None:
    """Lock (or unlock) byte 0 of *handle*'s file via ``msvcrt.locking``.

    ``msvcrt.locking`` acts on bytes starting at the descriptor's *current*
    position. Anchoring at offset 0 makes every process contend for the same
    byte. It also lets the unlock hit exactly the region that was locked,
    even though reads/writes moved the position in between. The raw
    descriptor position is restored afterwards, so the buffered I/O layer's
    view of the file stays consistent.
    """
    import msvcrt

    fd = handle.fileno()
    saved = os.lseek(fd, 0, os.SEEK_CUR)
    os.lseek(fd, 0, os.SEEK_SET)
    try:
        msvcrt.locking(fd, msvcrt.LK_UNLCK if unlock else msvcrt.LK_LOCK, 1)
    finally:
        os.lseek(fd, saved, os.SEEK_SET)


def _lock_file(handle: TextIO) -> None:
    """Block until this process holds an exclusive OS-level lock on *handle*.

    Uses ``fcntl.flock(LOCK_EX)`` on POSIX and ``msvcrt.locking`` on a fixed
    byte on Windows.
    """
    if os.name == "nt":
        _msvcrt_lock_first_byte(handle, unlock=False)
    else:
        import fcntl

        fcntl.flock(handle.fileno(), fcntl.LOCK_EX)


def _unlock_file(handle: TextIO) -> None:
    """Release the lock taken by ``_lock_file`` on *handle*.

    Buffered writes are flushed first, so they reach the file while the lock
    is still held.
    """
    handle.flush()
    if os.name == "nt":
        _msvcrt_lock_first_byte(handle, unlock=True)
    else:
        import fcntl

        fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
@contextmanager
def _locked_file(path: Path, mode: str) -> Iterator[TextIO]:
    """Yield *path* opened in *mode* (UTF-8) while holding both locks.

    The parent directory is created if missing. The per-path thread lock is
    acquired first, then the file is opened and OS-locked. On exit the OS
    lock is released, the file closed, and the thread lock released, in that
    order.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with _path_lock(path), path.open(mode, encoding="utf-8") as handle:
        _lock_file(handle)
        try:
            yield handle
        finally:
            _unlock_file(handle)
@dataclass(frozen=True)
class QualityObservation:
    """Observed quality/cost outcome for one adapter on one task type.

    Instances are immutable and normalised on construction:

    * identifier fields are coerced to ``str``;
    * cost, latency and quality are coerced to ``float``;
    * ``recorded_at`` becomes an aware UTC datetime;
    * ``tags`` is shallow-copied, so later changes to the caller's dict do
      not leak in.

    Instances are hashable; ``tags`` takes part in equality but not in the
    hash.

    Raises:
        ValueError: an identifier is ``None`` or blank, a cost/latency/token
            field fails its non-negative check, or ``quality_score`` is not a
            number in ``[0, 1]``.
        TypeError: ``recorded_at`` is neither a ``datetime`` nor a string.
    """

    task_type: str
    adapter_id: str
    model_id: str
    cost_usd: float
    quality_score: float
    latency_ms: float
    tokens_in: int
    tokens_out: int
    baseline_adapter_id: str | None = None
    # Defaults to the current UTC time when not supplied.
    recorded_at: datetime = field(default_factory=_utc_now)
    # Free-form metadata. Excluded from the generated __hash__ (dicts are
    # unhashable), so the frozen dataclass can live in sets / dict keys.
    tags: dict[str, Any] = field(default_factory=dict, hash=False)

    def __post_init__(self) -> None:
        """Validate every field, then store the normalised values."""
        for name in ("task_type", "adapter_id", "model_id"):
            raw = getattr(self, name)
            # Reject None explicitly: str(None) == "None" would otherwise pass
            # the emptiness check and be stored as a literal identifier.
            if raw is None or not str(raw).strip():
                raise ValueError(f"{name} must be a non-empty string")
        _validate_non_negative_float("cost_usd", self.cost_usd)
        _validate_non_negative_float("latency_ms", self.latency_ms)
        _validate_non_negative_int("tokens_in", self.tokens_in)
        _validate_non_negative_int("tokens_out", self.tokens_out)
        if not isinstance(self.quality_score, (int, float)):
            raise ValueError("quality_score must be a number between 0 and 1")
        if not 0 <= float(self.quality_score) <= 1:
            raise ValueError("quality_score must be between 0 and 1")
        # The dataclass is frozen, so normalised values have to be written
        # through object.__setattr__.
        object.__setattr__(self, "task_type", str(self.task_type))
        object.__setattr__(self, "adapter_id", str(self.adapter_id))
        object.__setattr__(self, "model_id", str(self.model_id))
        object.__setattr__(self, "cost_usd", float(self.cost_usd))
        object.__setattr__(self, "quality_score", float(self.quality_score))
        object.__setattr__(self, "latency_ms", float(self.latency_ms))
        object.__setattr__(self, "recorded_at", _normalise_datetime(self.recorded_at))
        object.__setattr__(self, "tags", dict(self.tags))

    @property
    def total_tokens(self) -> int:
        """Return input plus output tokens."""
        return self.tokens_in + self.tokens_out

    def to_dict(self) -> dict[str, Any]:
        """Convert to a plain dictionary.

        The result is JSON-serialisable provided the ``tags`` values are.
        ``recorded_at`` is rendered as an ISO 8601 string with a ``Z`` suffix.
        """
        return {
            "task_type": self.task_type,
            "adapter_id": self.adapter_id,
            "model_id": self.model_id,
            "cost_usd": self.cost_usd,
            "quality_score": self.quality_score,
            "latency_ms": self.latency_ms,
            "tokens_in": self.tokens_in,
            "tokens_out": self.tokens_out,
            "baseline_adapter_id": self.baseline_adapter_id,
            "recorded_at": _serialise_datetime(self.recorded_at),
            "tags": dict(self.tags),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "QualityObservation":
        """Create an observation from a JSON-decoded dictionary.

        A missing ``recorded_at`` key defaults to the current UTC time. A
        missing or null ``tags`` becomes an empty dict.

        Raises:
            KeyError: a required key is absent.
            TypeError, ValueError: a field fails validation.
        """
        # Only call _utc_now() when the key is actually absent; an explicit
        # null still reaches validation (and raises) as before.
        recorded_at = data["recorded_at"] if "recorded_at" in data else _utc_now()
        return cls(
            task_type=data["task_type"],
            adapter_id=data["adapter_id"],
            model_id=data["model_id"],
            cost_usd=data["cost_usd"],
            quality_score=data["quality_score"],
            latency_ms=data["latency_ms"],
            tokens_in=data["tokens_in"],
            tokens_out=data["tokens_out"],
            baseline_adapter_id=data.get("baseline_adapter_id"),
            recorded_at=recorded_at,
            tags=data.get("tags") or {},
        )
def is_stale(
    observation: QualityObservation,
    max_age: timedelta,
    *,
    now: datetime | None = None,
) -> bool:
    """Report whether *observation* was recorded more than *max_age* ago.

    Args:
        observation: The observation to check.
        max_age: Maximum tolerated age; must not be negative.
        now: Reference time (naive values are treated as UTC). When not
            given, the current UTC time is used.

    Returns:
        ``True`` if ``observation.recorded_at`` is strictly earlier than
        ``now - max_age``.

    Raises:
        ValueError: *max_age* is negative.
    """
    if max_age.total_seconds() < 0:
        raise ValueError("max_age must be non-negative")
    anchor = _normalise_datetime(now if now else _utc_now())
    oldest_fresh = anchor - max_age
    return observation.recorded_at < oldest_fresh
class QualityLedger:
    """Append-only JSONL store for :class:`QualityObservation` records.

    Each line is one JSON object produced by
    :meth:`QualityObservation.to_dict`. Reads and writes go through
    ``_locked_file`` (per-path thread lock plus OS file lock). Reads skip
    non-blank lines that cannot be decoded into an observation instead of
    failing.
    """

    def __init__(self, path: str | Path):
        self._path = Path(path)

    @property
    def path(self) -> Path:
        """Ledger file path."""
        return self._path

    def append(self, observation: QualityObservation) -> None:
        """Append one observation as a locked, fsync'd JSONL record."""
        # Compact, key-sorted encoding: one record per line, deterministic.
        line = json.dumps(observation.to_dict(), sort_keys=True, separators=(",", ":"))
        with _locked_file(self._path, "a") as handle:
            handle.write(line + "\n")
            handle.flush()
            os.fsync(handle.fileno())

    def read_all(self) -> list[QualityObservation]:
        """Return all parseable observations in file order, skipping malformed lines."""
        observations, _ = self._read_with_malformed_count()
        return observations

    def malformed_count(self) -> int:
        """Return the number of non-blank lines that reads currently skip."""
        _, malformed = self._read_with_malformed_count()
        return malformed

    def by_task_type(self, task_type: str) -> list[QualityObservation]:
        """Return observations whose ``task_type`` equals *task_type*, in file order."""
        return [obs for obs in self.read_all() if obs.task_type == task_type]

    def recent(
        self,
        limit: int | None = None,
        *,
        task_type: str | None = None,
        adapter_id: str | None = None,
        since: datetime | None = None,
    ) -> list[QualityObservation]:
        """Return observations newest first, optionally filtered.

        Args:
            limit: Maximum number of results; ``None`` means unlimited.
            task_type: Keep only this task type.
            adapter_id: Keep only this adapter.
            since: Keep only observations recorded at or after this time
                (naive values are treated as UTC).

        Raises:
            ValueError: *limit* is negative.
        """
        if limit is not None and limit < 0:
            raise ValueError("limit must be non-negative")
        cutoff = _normalise_datetime(since) if since is not None else None
        observations = self.read_all()
        if task_type is not None:
            observations = [obs for obs in observations if obs.task_type == task_type]
        if adapter_id is not None:
            observations = [obs for obs in observations if obs.adapter_id == adapter_id]
        if cutoff is not None:
            observations = [obs for obs in observations if obs.recorded_at >= cutoff]
        observations.sort(key=lambda obs: obs.recorded_at, reverse=True)
        if limit is None:
            return observations
        return observations[:limit]

    def mean_quality(
        self,
        task_type: str,
        *,
        adapter_id: str | None = None,
        model_id: str | None = None,
        max_age: timedelta | None = None,
        min_observations: int = 1,
    ) -> float | None:
        """Return the mean ``quality_score`` of matching observations.

        Args:
            task_type: Task type to aggregate.
            adapter_id: Optional adapter filter.
            model_id: Optional model filter.
            max_age: If given, ignore observations older than this, measured
                against a single "now" taken once per call.
            min_observations: Minimum number of matches required.

        Returns:
            The arithmetic mean, or ``None`` if fewer than *min_observations*
            observations match.

        Raises:
            ValueError: *min_observations* is not positive, or *max_age* is
                negative.
        """
        if min_observations <= 0:
            raise ValueError("min_observations must be positive")
        if max_age is not None and max_age.total_seconds() < 0:
            # Validate eagerly: previously a negative max_age was only
            # reported when at least one observation reached is_stale().
            raise ValueError("max_age must be non-negative")
        observations = self.by_task_type(task_type)
        if adapter_id is not None:
            observations = [obs for obs in observations if obs.adapter_id == adapter_id]
        if model_id is not None:
            observations = [obs for obs in observations if obs.model_id == model_id]
        if max_age is not None:
            # One reference time so every observation is judged against the
            # same cutoff.
            now = _utc_now()
            observations = [
                obs for obs in observations if not is_stale(obs, max_age, now=now)
            ]
        if len(observations) < min_observations:
            return None
        return sum(obs.quality_score for obs in observations) / len(observations)

    def prune_before(self, timestamp: datetime) -> int:
        """Remove valid observations recorded before *timestamp*.

        Malformed lines are preserved because their timestamp cannot be
        trusted. The rewrite happens under the file lock and is fsync'd. A
        ledger file that does not exist is left uncreated.

        Returns:
            The number of valid observation records removed.
        """
        cutoff = _normalise_datetime(timestamp)
        if not self._path.is_file():
            return 0
        removed = 0
        with _locked_file(self._path, "a+") as handle:
            handle.seek(0)
            kept: list[str] = []
            for line in handle.readlines():
                obs = self._parse_line(line)
                if obs is not None and obs.recorded_at < cutoff:
                    removed += 1
                else:
                    kept.append(line)
            # A final line without its newline (e.g. an interrupted write)
            # would otherwise swallow the start of the next appended record.
            if kept and not kept[-1].endswith("\n"):
                kept[-1] += "\n"
            # "a+" writes always go to end-of-file; truncating to zero first
            # makes that equivalent to rewriting from the start.
            handle.seek(0)
            handle.truncate()
            handle.writelines(kept)
            handle.flush()
            os.fsync(handle.fileno())
        return removed

    @staticmethod
    def _parse_line(line: str) -> QualityObservation | None:
        """Decode one JSONL line into an observation, or return ``None`` if malformed.

        Malformed covers:

        * invalid JSON, or JSON that is not an object;
        * missing keys;
        * values rejected by :class:`QualityObservation` validation;
        * dates or numbers that overflow while being normalised.
        """
        try:
            return QualityObservation.from_dict(json.loads(line))
        except (json.JSONDecodeError, KeyError, TypeError, ValueError, OverflowError):
            return None

    def _read_with_malformed_count(self) -> tuple[list[QualityObservation], int]:
        """Return (parseable observations in file order, count of malformed lines).

        Blank lines are ignored and not counted as malformed.
        """
        if not self._path.is_file():
            return [], 0
        observations: list[QualityObservation] = []
        malformed = 0
        try:
            with _locked_file(self._path, "r") as handle:
                for line in handle:
                    if not line.strip():
                        continue
                    obs = self._parse_line(line)
                    if obs is None:
                        malformed += 1
                    else:
                        observations.append(obs)
        except FileNotFoundError:
            # The file vanished between is_file() and open(): treat as empty.
            return [], 0
        return observations, malformed