generated from coulomb/repo-seed
Add adaptive cost-quality routing primitives
This commit is contained in:
318
llm_connect/quality.py
Normal file
318
llm_connect/quality.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""Quality observations and append-only ledger support.
|
||||
|
||||
These primitives let callers record observed quality/cost outcomes for a
|
||||
task type without baking consumer-specific routing policy into llm-connect.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterator, TextIO
|
||||
|
||||
|
||||
_PATH_LOCKS: dict[Path, threading.Lock] = {}
|
||||
_PATH_LOCKS_GUARD = threading.Lock()
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _normalise_datetime(value: datetime | str) -> datetime:
|
||||
if isinstance(value, datetime):
|
||||
dt = value
|
||||
elif isinstance(value, str):
|
||||
dt = datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
else:
|
||||
raise TypeError(f"Expected datetime or ISO string, got {type(value).__name__}")
|
||||
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=timezone.utc)
|
||||
return dt.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def _serialise_datetime(value: datetime) -> str:
    """Render *value* as an ISO-8601 UTC string ending in ``Z``."""
    iso_text = _normalise_datetime(value).isoformat()
    return iso_text.replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _validate_non_negative_int(name: str, value: int) -> None:
|
||||
if not isinstance(value, int) or value < 0:
|
||||
raise ValueError(f"{name} must be a non-negative integer")
|
||||
|
||||
|
||||
def _validate_non_negative_float(name: str, value: float) -> None:
|
||||
if not isinstance(value, (int, float)) or float(value) < 0:
|
||||
raise ValueError(f"{name} must be a non-negative number")
|
||||
|
||||
|
||||
def _path_lock(path: Path) -> threading.Lock:
    """Return the in-process lock registered for *path*.

    The path is resolved before lookup, and a new lock is registered the
    first time a given resolved path is requested. Registry access happens
    under ``_PATH_LOCKS_GUARD``.
    """
    key = path.resolve()
    with _PATH_LOCKS_GUARD:
        try:
            return _PATH_LOCKS[key]
        except KeyError:
            created = _PATH_LOCKS[key] = threading.Lock()
            return created
|
||||
|
||||
|
||||
def _lock_file(handle: TextIO) -> None:
|
||||
if os.name == "nt":
|
||||
import msvcrt
|
||||
|
||||
msvcrt.locking(handle.fileno(), msvcrt.LK_LOCK, 1)
|
||||
else:
|
||||
import fcntl
|
||||
|
||||
fcntl.flock(handle.fileno(), fcntl.LOCK_EX)
|
||||
|
||||
|
||||
def _unlock_file(handle: TextIO) -> None:
|
||||
if os.name == "nt":
|
||||
import msvcrt
|
||||
|
||||
msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1)
|
||||
else:
|
||||
import fcntl
|
||||
|
||||
fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
|
||||
|
||||
|
||||
@contextmanager
def _locked_file(path: Path, mode: str) -> Iterator[TextIO]:
    """Open *path* in *mode* as UTF-8 text and yield it while locked.

    Parent directories are created if missing. The in-process lock for the
    path is held around the whole open/lock/use/unlock/close sequence, and
    the file lock is released even when the caller's block raises.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with _path_lock(path), path.open(mode, encoding="utf-8") as stream:
        _lock_file(stream)
        try:
            yield stream
        finally:
            _unlock_file(stream)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class QualityObservation:
    """One measured quality/cost result for an adapter on a task type.

    On construction the identifiers are coerced with ``str``, the cost,
    quality and latency values are stored as floats, ``recorded_at`` is
    normalised to an aware UTC datetime, and ``tags`` is shallow-copied.
    """

    task_type: str
    adapter_id: str
    model_id: str
    cost_usd: float
    quality_score: float
    latency_ms: float
    tokens_in: int
    tokens_out: int
    baseline_adapter_id: str | None = None
    recorded_at: datetime = field(default_factory=_utc_now)
    tags: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate all fields, then store their normalised forms.

        Raises:
            ValueError: If an identifier is blank, a cost/latency/token value
                is rejected by its validator, or ``quality_score`` is not a
                number within ``[0, 1]``.
        """
        for label in ("task_type", "adapter_id", "model_id"):
            if str(getattr(self, label)).strip() == "":
                raise ValueError(f"{label} must be a non-empty string")

        for label in ("cost_usd", "latency_ms"):
            _validate_non_negative_float(label, getattr(self, label))
        for label in ("tokens_in", "tokens_out"):
            _validate_non_negative_int(label, getattr(self, label))

        score = self.quality_score
        if not isinstance(score, (int, float)):
            raise ValueError("quality_score must be a number between 0 and 1")
        # Kept as a negated chained comparison so NaN is rejected too.
        if not 0 <= float(score) <= 1:
            raise ValueError("quality_score must be between 0 and 1")

        # The dataclass is frozen, so assignment goes through
        # object.__setattr__.
        normalised = {
            "task_type": str(self.task_type),
            "adapter_id": str(self.adapter_id),
            "model_id": str(self.model_id),
            "cost_usd": float(self.cost_usd),
            "quality_score": float(self.quality_score),
            "latency_ms": float(self.latency_ms),
            "recorded_at": _normalise_datetime(self.recorded_at),
            "tags": dict(self.tags),
        }
        for label, coerced in normalised.items():
            object.__setattr__(self, label, coerced)

    @property
    def total_tokens(self) -> int:
        """Sum of input and output tokens."""
        return sum((self.tokens_in, self.tokens_out))

    def to_dict(self) -> dict[str, Any]:
        """Return a dictionary form with ``recorded_at`` as an ISO string."""
        plain_fields = (
            "task_type",
            "adapter_id",
            "model_id",
            "cost_usd",
            "quality_score",
            "latency_ms",
            "tokens_in",
            "tokens_out",
            "baseline_adapter_id",
        )
        payload: dict[str, Any] = {
            label: getattr(self, label) for label in plain_fields
        }
        payload["recorded_at"] = _serialise_datetime(self.recorded_at)
        payload["tags"] = dict(self.tags)
        return payload

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "QualityObservation":
        """Build an observation from a JSON-decoded dictionary.

        ``baseline_adapter_id``, ``recorded_at`` and ``tags`` are optional;
        a missing ``recorded_at`` defaults to the current UTC time and a
        missing or empty ``tags`` becomes an empty dict.

        Raises:
            KeyError: If a required key is absent.
        """
        required_keys = (
            "task_type",
            "adapter_id",
            "model_id",
            "cost_usd",
            "quality_score",
            "latency_ms",
            "tokens_in",
            "tokens_out",
        )
        required = {key: data[key] for key in required_keys}
        return cls(
            **required,
            baseline_adapter_id=data.get("baseline_adapter_id"),
            recorded_at=data.get("recorded_at", _utc_now()),
            tags=data.get("tags") or {},
        )
|
||||
|
||||
|
||||
def is_stale(
    observation: QualityObservation,
    max_age: timedelta,
    *,
    now: datetime | None = None,
) -> bool:
    """Return whether *observation* is older than *max_age*.

    Args:
        observation: Record whose ``recorded_at`` is examined.
        max_age: Maximum allowed age; an observation exactly this old is
            not stale.
        now: Reference instant; defaults to the current UTC time.

    Raises:
        ValueError: If *max_age* is negative.
    """
    if max_age.total_seconds() < 0:
        raise ValueError("max_age must be non-negative")
    reference = _normalise_datetime(now or _utc_now())
    # Compare the observation's age with max_age instead of computing
    # ``reference - max_age``: that subtraction raises OverflowError when
    # max_age reaches back before datetime.min (for example timedelta.max
    # used to mean "never stale"). The difference of two datetimes always
    # fits in a timedelta.
    return reference - observation.recorded_at > max_age
|
||||
|
||||
|
||||
class QualityLedger:
    """Append-only JSONL store for :class:`QualityObservation` records.

    Each observation is one JSON object per line. All file access goes
    through ``_locked_file``. Lines that cannot be turned into an observation
    are skipped by reads and preserved by pruning.
    """

    def __init__(self, path: str | Path):
        """Bind the ledger to *path*; the file is not opened or created here."""
        self._path = Path(path)

    @property
    def path(self) -> Path:
        """Ledger file path."""
        return self._path

    def append(self, observation: QualityObservation) -> None:
        """Append one observation as a locked JSONL record.

        The record is flushed and fsynced before the lock is released.
        """
        line = json.dumps(observation.to_dict(), sort_keys=True, separators=(",", ":"))
        with _locked_file(self._path, "a") as handle:
            handle.write(line + "\n")
            handle.flush()
            os.fsync(handle.fileno())

    def read_all(self) -> list[QualityObservation]:
        """Return all parseable observations, skipping malformed lines."""
        observations, _ = self._read_with_malformed_count()
        return observations

    def malformed_count(self) -> int:
        """Return the number of malformed lines currently skipped by reads."""
        _, malformed = self._read_with_malformed_count()
        return malformed

    def by_task_type(self, task_type: str) -> list[QualityObservation]:
        """Return observations matching *task_type*."""
        return [obs for obs in self.read_all() if obs.task_type == task_type]

    def recent(
        self,
        limit: int | None = None,
        *,
        task_type: str | None = None,
        adapter_id: str | None = None,
        since: datetime | None = None,
    ) -> list[QualityObservation]:
        """Return newest observations first, optionally filtered.

        Args:
            limit: Maximum number of results; ``None`` returns all matches.
            task_type: Keep only observations with this task type.
            adapter_id: Keep only observations from this adapter.
            since: Keep only observations recorded at or after this instant.

        Raises:
            ValueError: If *limit* is negative.
        """
        if limit is not None and limit < 0:
            raise ValueError("limit must be non-negative")

        cutoff = _normalise_datetime(since) if since is not None else None
        observations = self.read_all()
        if task_type is not None:
            observations = [obs for obs in observations if obs.task_type == task_type]
        if adapter_id is not None:
            observations = [obs for obs in observations if obs.adapter_id == adapter_id]
        if cutoff is not None:
            observations = [obs for obs in observations if obs.recorded_at >= cutoff]

        observations.sort(key=lambda obs: obs.recorded_at, reverse=True)
        if limit is None:
            return observations
        return observations[:limit]

    def mean_quality(
        self,
        task_type: str,
        *,
        adapter_id: str | None = None,
        model_id: str | None = None,
        max_age: timedelta | None = None,
        min_observations: int = 1,
    ) -> float | None:
        """Return mean quality for matching observations, or ``None`` if absent.

        Args:
            task_type: Task type to average over.
            adapter_id: Restrict to this adapter when given.
            model_id: Restrict to this model when given.
            max_age: Ignore observations older than this when given.
            min_observations: Minimum number of matches required for a result.

        Returns:
            The arithmetic mean of ``quality_score``, or ``None`` when fewer
            than *min_observations* observations match.

        Raises:
            ValueError: If *min_observations* is not positive.
        """
        if min_observations <= 0:
            raise ValueError("min_observations must be positive")

        observations = self.by_task_type(task_type)
        if adapter_id is not None:
            observations = [obs for obs in observations if obs.adapter_id == adapter_id]
        if model_id is not None:
            observations = [obs for obs in observations if obs.model_id == model_id]
        if max_age is not None:
            # Take the reference instant once so every observation is judged
            # against the same clock reading.
            now = _utc_now()
            observations = [
                obs for obs in observations if not is_stale(obs, max_age, now=now)
            ]

        if len(observations) < min_observations:
            return None
        return sum(obs.quality_score for obs in observations) / len(observations)

    def prune_before(self, timestamp: datetime) -> int:
        """Remove valid observations recorded before *timestamp*.

        Malformed lines are preserved because their timestamp cannot be trusted.
        Returns the number of valid observation records removed.
        """
        cutoff = _normalise_datetime(timestamp)
        removed = 0
        with _locked_file(self._path, "a+") as handle:
            handle.seek(0)
            lines = handle.readlines()
            kept: list[str] = []
            for line in lines:
                obs = self._parse_line(line)
                if obs is not None and obs.recorded_at < cutoff:
                    removed += 1
                else:
                    # Either malformed (kept verbatim) or recent enough.
                    kept.append(line)

            # Rewrite in place while still holding the lock.
            handle.seek(0)
            handle.truncate()
            handle.writelines(kept)
            handle.flush()
            os.fsync(handle.fileno())
        return removed

    @staticmethod
    def _parse_line(line: str) -> QualityObservation | None:
        """Decode one ledger line, returning ``None`` when it is malformed."""
        try:
            return QualityObservation.from_dict(json.loads(line))
        except (json.JSONDecodeError, KeyError, TypeError, ValueError, OverflowError):
            # OverflowError: a JSON integer too large for float() in a numeric
            # field; treat it like any other unusable record.
            return None

    def _read_with_malformed_count(self) -> tuple[list[QualityObservation], int]:
        """Return ``(observations, malformed_line_count)`` for the ledger file.

        A missing file yields ``([], 0)``. Blank lines are ignored and are
        not counted as malformed.
        """
        if not self._path.is_file():
            return [], 0

        observations: list[QualityObservation] = []
        malformed = 0
        with _locked_file(self._path, "r") as handle:
            for line in handle:
                if not line.strip():
                    continue
                obs = self._parse_line(line)
                if obs is None:
                    malformed += 1
                else:
                    observations.append(obs)
        return observations, malformed
|
||||
Reference in New Issue
Block a user