generated from coulomb/repo-seed
319 lines
11 KiB
Python
319 lines
11 KiB
Python
"""Quality observations and append-only ledger support.

These primitives let callers record observed quality/cost outcomes for a
task type without baking consumer-specific routing policy into llm-connect.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import threading
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timedelta, timezone
|
|
from pathlib import Path
|
|
from typing import Any, Iterator, TextIO
|
|
|
|
|
|
# Serialises creation of new entries in ``_PATH_LOCKS`` (see ``_path_lock``).
_PATH_LOCKS_GUARD = threading.Lock()

# Process-wide registry: one ``threading.Lock`` per resolved ledger path.
_PATH_LOCKS: dict[Path, threading.Lock] = dict()
|
|
|
|
|
|
def _utc_now() -> datetime:
    """Return the current instant as a timezone-aware datetime in UTC."""
    utc = timezone.utc
    return datetime.now(tz=utc)
|
|
|
|
|
|
def _normalise_datetime(value: datetime | str) -> datetime:
    """Coerce *value* into a timezone-aware datetime expressed in UTC.

    Strings are parsed with ``datetime.fromisoformat`` after rewriting a
    ``Z`` designator to ``+00:00`` (older ``fromisoformat`` versions reject
    ``Z``). Naive datetimes are treated as already being in UTC; aware ones
    are converted to UTC.

    Raises:
        TypeError: if *value* is neither a datetime nor a string.
        ValueError: if a string cannot be parsed by ``fromisoformat``.
    """
    if isinstance(value, str):
        parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
    elif isinstance(value, datetime):
        parsed = value
    else:
        raise TypeError(f"Expected datetime or ISO string, got {type(value).__name__}")

    if parsed.tzinfo is not None:
        return parsed.astimezone(timezone.utc)
    # No offset information: attach UTC rather than assuming local time.
    return parsed.replace(tzinfo=timezone.utc)
|
|
|
|
|
|
def _serialise_datetime(value: datetime) -> str:
    """Render *value* as an ISO 8601 UTC string using the compact ``Z`` suffix."""
    iso_text = _normalise_datetime(value).isoformat()
    return iso_text.replace("+00:00", "Z")
|
|
|
|
|
|
def _validate_non_negative_int(name: str, value: int) -> None:
    """Raise ``ValueError`` unless *value* is an ``int`` that is >= 0.

    ``bool`` is rejected explicitly: it subclasses ``int``, so ``True`` and
    ``False`` would otherwise be accepted as the counts 1 and 0.

    Args:
        name: field name used in the error message.
        value: candidate value.

    Raises:
        ValueError: if *value* is not a non-negative, non-bool integer.
    """
    if isinstance(value, bool) or not isinstance(value, int) or value < 0:
        raise ValueError(f"{name} must be a non-negative integer")
|
|
|
|
|
|
def _validate_non_negative_float(name: str, value: float) -> None:
    """Raise ``ValueError`` unless *value* is a real number that is >= 0.

    ``int`` and ``float`` are accepted; ``bool`` is rejected even though it
    subclasses ``int``. NaN is rejected: every comparison with NaN is false,
    so a plain ``value < 0`` test would let it through.

    Args:
        name: field name used in the error message.
        value: candidate value.

    Raises:
        ValueError: if *value* is not a non-negative, non-bool number.
    """
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        raise ValueError(f"{name} must be a non-negative number")
    # ``not value >= 0`` is true for negatives AND for NaN. Comparing the
    # value directly (no float()) also avoids OverflowError on huge ints.
    if not value >= 0:
        raise ValueError(f"{name} must be a non-negative number")
|
|
|
|
|
|
def _path_lock(path: Path) -> threading.Lock:
    """Return the process-wide ``threading.Lock`` associated with *path*.

    The path is resolved first so different spellings of the same file share
    one lock. Lookup and creation both happen under ``_PATH_LOCKS_GUARD`` so
    concurrent first callers cannot each create a separate lock.
    """
    key = path.resolve()
    with _PATH_LOCKS_GUARD:
        try:
            return _PATH_LOCKS[key]
        except KeyError:
            created = _PATH_LOCKS[key] = threading.Lock()
            return created
|
|
|
|
|
|
def _lock_file(handle: TextIO) -> None:
    """Acquire an exclusive OS-level lock on the file behind *handle*.

    POSIX uses ``fcntl.flock`` (whole-file, blocks until acquired).

    On Windows, ``msvcrt.locking`` locks a byte range that starts at the
    descriptor's *current position*. Append handles open positioned at EOF,
    while read handles start at 0. So we always lock byte 0, which makes
    every handle contend for the same region regardless of open mode. The
    original position is then restored. ``LK_LOCK`` retries for a limited
    time and raises ``OSError`` if the byte stays locked.
    """
    if os.name == "nt":
        import msvcrt

        position = handle.tell()
        handle.seek(0)
        try:
            msvcrt.locking(handle.fileno(), msvcrt.LK_LOCK, 1)
        finally:
            handle.seek(position)
    else:
        import fcntl

        fcntl.flock(handle.fileno(), fcntl.LOCK_EX)


def _unlock_file(handle: TextIO) -> None:
    """Release the OS-level lock taken on *handle* by ``_lock_file``.

    On Windows the unlock also applies from the current position, which has
    usually moved since locking (after reads or writes). We seek back to
    byte 0, the byte ``_lock_file`` locks, before unlocking. This leaves the
    handle positioned at 0. ``tell()`` is deliberately avoided here because
    text handles disable it during an interrupted ``for line in handle``
    loop.
    """
    if os.name == "nt":
        import msvcrt

        handle.seek(0)
        msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1)
    else:
        import fcntl

        fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
|
|
|
|
|
|
@contextmanager
def _locked_file(path: Path, mode: str) -> Iterator[TextIO]:
    """Yield *path* opened in UTF-8 text *mode* under both ledger locks.

    Locks are acquired in this order:
      1. the per-path ``threading.Lock`` from ``_path_lock``;
      2. the file open;
      3. the OS-level file lock from ``_lock_file``.
    They are released in reverse order. The parent directory is created
    first if it is missing.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with _path_lock(path), path.open(mode, encoding="utf-8") as handle:
        _lock_file(handle)
        try:
            yield handle
        finally:
            _unlock_file(handle)
|
|
|
|
|
|
@dataclass(frozen=True)
class QualityObservation:
    """Observed quality/cost outcome for one adapter on one task type.

    Instances are immutable. ``__post_init__`` validates every field and
    normalises types:
      - the identifier fields become ``str``;
      - ``cost_usd``, ``quality_score`` and ``latency_ms`` become ``float``;
      - ``recorded_at`` becomes a timezone-aware UTC datetime;
      - ``tags`` is shallow-copied into a new ``dict``.
    """

    task_type: str
    adapter_id: str
    model_id: str
    cost_usd: float  # validated >= 0
    quality_score: float  # validated to lie in the closed interval [0, 1]
    latency_ms: float  # validated >= 0
    tokens_in: int  # validated non-negative int
    tokens_out: int  # validated non-negative int
    baseline_adapter_id: str | None = None
    recorded_at: datetime = field(default_factory=_utc_now)
    # ``hash=False``: dicts are unhashable, so including ``tags`` made the
    # generated ``__hash__`` of this frozen dataclass raise ``TypeError``.
    # ``tags`` still takes part in ``__eq__``, so equal objects hash equally.
    tags: dict[str, Any] = field(default_factory=dict, hash=False)

    def __post_init__(self) -> None:
        """Validate every field and store it in its canonical type.

        Raises:
            ValueError: if an identifier is missing or blank, a measurement
                is negative or not a number, ``quality_score`` is outside
                ``[0, 1]``, or ``recorded_at`` is an unparseable string.
            TypeError: if ``recorded_at`` is neither a datetime nor a string.
        """
        for name in ("task_type", "adapter_id", "model_id"):
            value = getattr(self, name)
            # ``str(None)`` is the non-empty text "None". Reject None (e.g. a
            # JSON ``null``) explicitly instead of storing that literal.
            if value is None or not str(value).strip():
                raise ValueError(f"{name} must be a non-empty string")

        _validate_non_negative_float("cost_usd", self.cost_usd)
        _validate_non_negative_float("latency_ms", self.latency_ms)
        _validate_non_negative_int("tokens_in", self.tokens_in)
        _validate_non_negative_int("tokens_out", self.tokens_out)
        # ``bool`` subclasses ``int``; without this check ``True`` would be
        # accepted as a perfect score of 1.0.
        if isinstance(self.quality_score, bool) or not isinstance(
            self.quality_score, (int, float)
        ):
            raise ValueError("quality_score must be a number between 0 and 1")
        # The chained comparison is False for NaN, so NaN is rejected too.
        if not 0 <= float(self.quality_score) <= 1:
            raise ValueError("quality_score must be between 0 and 1")

        # Frozen dataclass: bypass the blocked __setattr__ to store the
        # normalised values.
        object.__setattr__(self, "task_type", str(self.task_type))
        object.__setattr__(self, "adapter_id", str(self.adapter_id))
        object.__setattr__(self, "model_id", str(self.model_id))
        object.__setattr__(self, "cost_usd", float(self.cost_usd))
        object.__setattr__(self, "quality_score", float(self.quality_score))
        object.__setattr__(self, "latency_ms", float(self.latency_ms))
        object.__setattr__(self, "recorded_at", _normalise_datetime(self.recorded_at))
        object.__setattr__(self, "tags", dict(self.tags))

    @property
    def total_tokens(self) -> int:
        """Return input plus output tokens."""
        return self.tokens_in + self.tokens_out

    def to_dict(self) -> dict[str, Any]:
        """Convert to a JSON-serialisable dictionary.

        ``recorded_at`` is rendered as an ISO 8601 UTC string with a ``Z``
        suffix. ``tags`` is shallow-copied, so mutating the result does not
        affect this instance.
        """
        return {
            "task_type": self.task_type,
            "adapter_id": self.adapter_id,
            "model_id": self.model_id,
            "cost_usd": self.cost_usd,
            "quality_score": self.quality_score,
            "latency_ms": self.latency_ms,
            "tokens_in": self.tokens_in,
            "tokens_out": self.tokens_out,
            "baseline_adapter_id": self.baseline_adapter_id,
            "recorded_at": _serialise_datetime(self.recorded_at),
            "tags": dict(self.tags),
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "QualityObservation":
        """Create an observation from a JSON-decoded dictionary.

        The measurement and identifier keys produced by :meth:`to_dict` are
        required. ``baseline_adapter_id`` and ``tags`` are optional (a falsy
        ``tags`` becomes ``{}``). A missing ``recorded_at`` falls back to the
        current time, the same default the field itself uses.

        Raises:
            KeyError: if a required key is missing.
            TypeError, ValueError: if a value fails validation.
        """
        return cls(
            task_type=data["task_type"],
            adapter_id=data["adapter_id"],
            model_id=data["model_id"],
            cost_usd=data["cost_usd"],
            quality_score=data["quality_score"],
            latency_ms=data["latency_ms"],
            tokens_in=data["tokens_in"],
            tokens_out=data["tokens_out"],
            baseline_adapter_id=data.get("baseline_adapter_id"),
            # Consult the clock only when the timestamp is actually absent.
            recorded_at=data["recorded_at"] if "recorded_at" in data else _utc_now(),
            tags=data.get("tags") or {},
        )
|
|
|
|
|
|
def is_stale(
    observation: QualityObservation,
    max_age: timedelta,
    *,
    now: datetime | None = None,
) -> bool:
    """Report whether *observation* was recorded more than *max_age* ago.

    Args:
        observation: the observation to check.
        max_age: maximum permitted age; must not be negative.
        now: reference instant (normalised to UTC); defaults to the current
            time.

    Returns:
        ``True`` when ``observation.recorded_at`` is strictly earlier than
        ``now - max_age``.

    Raises:
        ValueError: if *max_age* is negative.
    """
    if max_age.total_seconds() < 0:
        raise ValueError("max_age must be non-negative")
    oldest_fresh = _normalise_datetime(now or _utc_now()) - max_age
    return oldest_fresh > observation.recorded_at
|
|
|
|
|
|
class QualityLedger:
    """Append-only JSONL store for :class:`QualityObservation` records.

    Each line holds one observation serialised from
    ``QualityObservation.to_dict``. All file access goes through
    ``_locked_file``. Non-blank lines that fail to parse are skipped by
    reads and reported by :meth:`malformed_count`.
    """

    def __init__(self, path: str | Path):
        self._path = Path(path)

    @property
    def path(self) -> Path:
        """Ledger file path."""
        return self._path

    def append(self, observation: QualityObservation) -> None:
        """Append one observation as a single JSONL record.

        Serialisation happens before the lock is taken, so an error from
        ``json.dumps`` (e.g. non-serialisable ``tags``) leaves the file
        untouched. The record is written, flushed and ``fsync``-ed while the
        lock is held.
        """
        line = json.dumps(observation.to_dict(), sort_keys=True, separators=(",", ":"))
        with _locked_file(self._path, "a") as handle:
            handle.write(line + "\n")
            handle.flush()
            os.fsync(handle.fileno())

    def read_all(self) -> list[QualityObservation]:
        """Return all parseable observations in file order, skipping malformed lines."""
        observations, _ = self._read_with_malformed_count()
        return observations

    def malformed_count(self) -> int:
        """Return the number of non-blank lines that currently fail to parse."""
        _, malformed = self._read_with_malformed_count()
        return malformed

    def by_task_type(self, task_type: str) -> list[QualityObservation]:
        """Return observations whose ``task_type`` equals *task_type*, in file order."""
        return [obs for obs in self.read_all() if obs.task_type == task_type]

    def recent(
        self,
        limit: int | None = None,
        *,
        task_type: str | None = None,
        adapter_id: str | None = None,
        since: datetime | None = None,
    ) -> list[QualityObservation]:
        """Return observations sorted newest first, optionally filtered.

        Args:
            limit: maximum number of results; ``None`` means no cap.
            task_type: keep only observations with this exact task type.
            adapter_id: keep only observations with this exact adapter id.
            since: inclusive lower bound on ``recorded_at`` (normalised to UTC).

        Raises:
            ValueError: if *limit* is negative.
        """
        if limit is not None and limit < 0:
            raise ValueError("limit must be non-negative")

        cutoff = _normalise_datetime(since) if since is not None else None
        observations = self.read_all()
        if task_type is not None:
            observations = [obs for obs in observations if obs.task_type == task_type]
        if adapter_id is not None:
            observations = [obs for obs in observations if obs.adapter_id == adapter_id]
        if cutoff is not None:
            observations = [obs for obs in observations if obs.recorded_at >= cutoff]

        observations.sort(key=lambda obs: obs.recorded_at, reverse=True)
        if limit is None:
            return observations
        return observations[:limit]

    def mean_quality(
        self,
        task_type: str,
        *,
        adapter_id: str | None = None,
        model_id: str | None = None,
        max_age: timedelta | None = None,
        min_observations: int = 1,
    ) -> float | None:
        """Return the mean ``quality_score`` of matching observations.

        Args:
            task_type: task type to match exactly.
            adapter_id: optional exact adapter filter.
            model_id: optional exact model filter.
            max_age: if given, drop observations older than this relative to
                a single "now" taken once for the whole call.
            min_observations: minimum number of matches required.

        Returns:
            The arithmetic mean, or ``None`` when fewer than
            *min_observations* observations match.

        Raises:
            ValueError: if *min_observations* is not positive or *max_age*
                is negative.
        """
        if min_observations <= 0:
            raise ValueError("min_observations must be positive")
        # Validate up front. is_stale only raises when it is actually called,
        # so a negative max_age used to be accepted whenever nothing matched.
        if max_age is not None and max_age.total_seconds() < 0:
            raise ValueError("max_age must be non-negative")

        observations = self.by_task_type(task_type)
        if adapter_id is not None:
            observations = [obs for obs in observations if obs.adapter_id == adapter_id]
        if model_id is not None:
            observations = [obs for obs in observations if obs.model_id == model_id]
        if max_age is not None:
            # One reference instant, so every observation is judged against
            # the same cutoff.
            now = _utc_now()
            observations = [
                obs for obs in observations if not is_stale(obs, max_age, now=now)
            ]

        if len(observations) < min_observations:
            return None
        return sum(obs.quality_score for obs in observations) / len(observations)

    def prune_before(self, timestamp: datetime) -> int:
        """Remove valid observations recorded before *timestamp*.

        Malformed lines are preserved because their timestamp cannot be
        trusted. Every retained line is rewritten newline-terminated. The
        file is rewritten in place (truncate + write) under the ledger lock;
        if the file does not exist, opening in ``"a+"`` creates it empty.

        Returns the number of valid observation records removed.
        """
        cutoff = _normalise_datetime(timestamp)
        removed = 0
        with _locked_file(self._path, "a+") as handle:
            handle.seek(0)
            lines = handle.readlines()
            kept: list[str] = []
            for line in lines:
                try:
                    obs = QualityObservation.from_dict(json.loads(line))
                except (json.JSONDecodeError, KeyError, TypeError, ValueError):
                    obs = None
                if obs is not None and obs.recorded_at < cutoff:
                    removed += 1
                    continue
                # A trailing unterminated line (e.g. from an interrupted
                # write) would otherwise have the next appended record glued
                # onto it, corrupting both.
                kept.append(line if line.endswith("\n") else line + "\n")

            handle.seek(0)
            handle.truncate()
            handle.writelines(kept)
            handle.flush()
            os.fsync(handle.fileno())
        return removed

    def _read_with_malformed_count(self) -> tuple[list[QualityObservation], int]:
        """Parse the ledger into ``(observations, malformed_line_count)``.

        Returns ``([], 0)`` when the file does not exist. Blank lines are
        ignored and not counted as malformed.
        """
        if not self._path.is_file():
            return [], 0

        observations: list[QualityObservation] = []
        malformed = 0
        with _locked_file(self._path, "r") as handle:
            for line in handle:
                if not line.strip():
                    continue
                try:
                    observations.append(QualityObservation.from_dict(json.loads(line)))
                except (json.JSONDecodeError, KeyError, TypeError, ValueError):
                    malformed += 1
        return observations, malformed
|