generated from coulomb/repo-seed
Add adaptive cost-quality routing primitives
This commit is contained in:
239
llm_connect/grading.py
Normal file
239
llm_connect/grading.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""Baseline grading primitives for adaptive routing.
|
||||
|
||||
Graders compare a candidate adapter response against a caller-chosen baseline.
|
||||
They produce normalised quality scores that can be recorded in a
|
||||
``QualityLedger`` and consumed later by adaptive routing policy.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import Any, Protocol
|
||||
|
||||
from llm_connect.adapter import LLMAdapter
|
||||
from llm_connect.embedding_adapter import EmbeddingAdapter
|
||||
from llm_connect.models import LLMResponse, RunConfig
|
||||
from llm_connect.similarity import cosine_similarity
|
||||
|
||||
|
||||
def _validate_score(value: float) -> float:
|
||||
if not isinstance(value, (int, float)):
|
||||
raise ValueError("quality_score must be a number between 0 and 1")
|
||||
score = float(value)
|
||||
if not 0 <= score <= 1:
|
||||
raise ValueError("quality_score must be between 0 and 1")
|
||||
return score
|
||||
|
||||
|
||||
def _normalise_text(text: str) -> str:
|
||||
return " ".join(text.strip().split())
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class GradingResult:
    """Outcome of judging one candidate response against one baseline.

    On construction ``quality_score`` is validated with ``_validate_score``
    and stored as a float, and ``notes`` is coerced to ``str``.  The two
    responses are stored as given.

    Raises:
        ValueError: If ``grader_id`` is blank or ``quality_score`` is invalid.
    """

    quality_score: float
    notes: str
    grader_id: str
    baseline_response: LLMResponse
    candidate_response: LLMResponse

    def __post_init__(self) -> None:
        """Validate ``grader_id`` and normalise ``quality_score`` / ``notes``."""
        grader_label = str(self.grader_id).strip()
        if not grader_label:
            raise ValueError("grader_id must be a non-empty string")

        # The dataclass is frozen, so normalised values must be written
        # through ``object.__setattr__``.
        checked_score = _validate_score(self.quality_score)
        object.__setattr__(self, "quality_score", checked_score)
        notes_text = str(self.notes)
        object.__setattr__(self, "notes", notes_text)
|
||||
|
||||
|
||||
class Judge(Protocol):
    """Structural interface for objects that score a candidate vs. a baseline."""

    # Identifier of the judge implementation.
    grader_id: str

    def judge(
        self,
        baseline_response: LLMResponse,
        candidate_response: LLMResponse,
        *,
        prompt: str,
        run_config: RunConfig,
    ) -> GradingResult:
        """Score *candidate_response* relative to *baseline_response*.

        ``prompt`` and ``run_config`` are passed by keyword.
        """
        ...
|
||||
|
||||
|
||||
class BaselineGrader(Protocol):
    """Structural interface for graders that compare two adapters on a prompt."""

    def grade(
        self,
        baseline_adapter: LLMAdapter,
        candidate_adapter: LLMAdapter,
        prompt: str,
        run_config: RunConfig,
    ) -> GradingResult:
        """Grade the candidate adapter against the baseline for *prompt*."""
        ...
|
||||
|
||||
|
||||
@dataclass
class ExactMatchJudge:
    """Binary judge: 1.0 when both response texts are equal, otherwise 0.0.

    Before comparison the texts are whitespace-normalised when
    ``normalize_whitespace`` is true, and case-folded when ``case_sensitive``
    is false.
    """

    normalize_whitespace: bool = True
    case_sensitive: bool = True
    grader_id: str = "exact-match"

    def judge(
        self,
        baseline_response: LLMResponse,
        candidate_response: LLMResponse,
        *,
        prompt: str,
        run_config: RunConfig,
    ) -> GradingResult:
        """Compare the ``content`` of both responses.

        ``prompt`` and ``run_config`` are accepted for interface
        compatibility and are not used by this judge.
        """
        texts = [baseline_response.content, candidate_response.content]
        if self.normalize_whitespace:
            texts = [_normalise_text(item) for item in texts]
        if not self.case_sensitive:
            texts = [item.casefold() for item in texts]

        expected, actual = texts
        if expected == actual:
            score, note = 1.0, "exact match"
        else:
            score, note = 0.0, "candidate content differs from baseline"

        return GradingResult(
            quality_score=score,
            notes=note,
            grader_id=self.grader_id,
            baseline_response=baseline_response,
            candidate_response=candidate_response,
        )
|
||||
|
||||
|
||||
@dataclass
class EmbeddingSimilarityJudge:
    """Judge that maps cosine similarity between response embeddings to 0..1.

    Both response texts are embedded in a single ``embed`` call; the cosine
    similarity of the two vectors is clamped into ``[0, 1]`` and used as the
    quality score.
    """

    embedding_adapter: EmbeddingAdapter
    grader_id: str = "embedding-similarity"

    def judge(
        self,
        baseline_response: LLMResponse,
        candidate_response: LLMResponse,
        *,
        prompt: str,
        run_config: RunConfig,
    ) -> GradingResult:
        """Score the candidate by embedding similarity to the baseline.

        ``prompt`` and ``run_config`` are accepted for interface
        compatibility and are not used by this judge.

        Raises:
            ValueError: If the embedding adapter does not return exactly two
                embeddings, or if the similarity is NaN.
        """
        embeddings = self.embedding_adapter.embed(
            [baseline_response.content, candidate_response.content]
        )
        if len(embeddings) != 2:
            raise ValueError("EmbeddingSimilarityJudge expected exactly two embeddings")

        raw_similarity = cosine_similarity(embeddings[0], embeddings[1])
        # NaN is the only value that differs from itself.  It must be rejected
        # explicitly: ``max(0.0, min(1.0, nan))`` evaluates to 1.0, which would
        # record an undefined similarity as a perfect score.
        if raw_similarity != raw_similarity:
            raise ValueError("EmbeddingSimilarityJudge could not compute a similarity (NaN)")

        # Negative similarities are clamped to 0 so the score stays in range.
        quality_score = max(0.0, min(1.0, raw_similarity))
        return GradingResult(
            quality_score=quality_score,
            notes=f"cosine similarity {raw_similarity:.4f}",
            grader_id=self.grader_id,
            baseline_response=baseline_response,
            candidate_response=candidate_response,
        )
|
||||
|
||||
|
||||
@dataclass
class LLMJudge:
    """LLM-as-judge wrapper using a fixed rubric prompt and JSON response.

    The judge model receives the rubric, the original prompt and both
    responses, and must answer with a JSON object holding ``quality_score``
    (a number in ``[0, 1]``) and ``notes``.
    """

    judge_adapter: LLMAdapter
    rubric: str = (
        "Compare the candidate response to the baseline response. "
        "Return JSON only with keys quality_score and notes. "
        "quality_score must be a number from 0 to 1."
    )
    grader_id: str = "llm-judge"
    # Seed placed into ``model_params`` unless the caller already set one;
    # ``None`` disables seeding.
    seed: int | None = 0

    def judge(
        self,
        baseline_response: LLMResponse,
        candidate_response: LLMResponse,
        *,
        prompt: str,
        run_config: RunConfig,
    ) -> GradingResult:
        """Ask the judge model to score the candidate against the baseline.

        Raises:
            ValueError: If the judge reply holds no parseable JSON object or
                its ``quality_score`` is invalid.
        """
        judge_prompt = self._build_prompt(prompt, baseline_response, candidate_response)
        judge_config = self._judge_config(run_config)
        response = self.judge_adapter.execute_prompt(judge_prompt, judge_config)
        parsed = self._parse_judge_response(response.content)
        return GradingResult(
            quality_score=parsed["quality_score"],
            notes=parsed["notes"],
            grader_id=self.grader_id,
            baseline_response=baseline_response,
            candidate_response=candidate_response,
        )

    def _judge_config(self, run_config: RunConfig) -> RunConfig:
        """Derive the judge's config: temperature 0, no budget tracker, seeded."""
        params: dict[str, Any] = dict(run_config.model_params)
        if self.seed is not None:
            # ``setdefault`` keeps a seed the caller supplied explicitly.
            params.setdefault("seed", self.seed)
        return replace(run_config, temperature=0.0, model_params=params, budget_tracker=None)

    def _build_prompt(
        self,
        prompt: str,
        baseline_response: LLMResponse,
        candidate_response: LLMResponse,
    ) -> str:
        """Assemble the rubric, original prompt and both responses."""
        return (
            f"{self.rubric}\n\n"
            f"Original prompt:\n{prompt}\n\n"
            f"Baseline response:\n{baseline_response.content}\n\n"
            f"Candidate response:\n{candidate_response.content}\n"
        )

    def _parse_judge_response(self, content: str) -> dict[str, Any]:
        """Extract ``quality_score`` and ``notes`` from the judge reply.

        Raises:
            ValueError: If no JSON object can be decoded, the decoded value is
                not an object, or ``quality_score`` is missing or invalid.
        """
        data = self._extract_json(content)
        if not isinstance(data, dict):
            raise ValueError("LLMJudge response JSON must be an object")
        notes = data.get("notes")
        return {
            "quality_score": _validate_score(data.get("quality_score")),
            # A JSON ``null`` means "no notes", not the literal text "None".
            "notes": "" if notes is None else str(notes),
        }

    @staticmethod
    def _extract_json(content: str) -> Any:
        """Decode *content* as JSON, tolerating prose around a JSON object.

        The whole reply is tried first.  Otherwise each ``{`` is tried in turn
        with ``raw_decode``, which stops at the end of the first complete JSON
        value, so braces in trailing commentary do not break the parse.
        """
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            pass

        start = content.find("{")
        if start == -1 or content.find("}", start) == -1:
            raise ValueError("LLMJudge response did not contain JSON")

        decoder = json.JSONDecoder()
        last_error: json.JSONDecodeError | None = None
        while start != -1:
            try:
                payload, _end = decoder.raw_decode(content, start)
            except json.JSONDecodeError as exc:
                last_error = exc
                start = content.find("{", start + 1)
            else:
                return payload
        raise ValueError("LLMJudge response JSON could not be parsed") from last_error
|
||||
|
||||
|
||||
@dataclass
class PairedGrader:
    """Grader that executes both adapters on one prompt and lets a judge score them.

    The baseline adapter runs first, then the candidate; both receive the same
    prompt and run configuration.  ``ExactMatchJudge`` is the default judge.
    """

    judge: Judge = field(default_factory=ExactMatchJudge)

    def grade(
        self,
        baseline_adapter: LLMAdapter,
        candidate_adapter: LLMAdapter,
        prompt: str,
        run_config: RunConfig,
    ) -> GradingResult:
        """Run both adapters and return the judge's verdict on their responses."""
        baseline_output, candidate_output = (
            adapter.execute_prompt(prompt, run_config)
            for adapter in (baseline_adapter, candidate_adapter)
        )
        verdict = self.judge.judge(
            baseline_output,
            candidate_output,
            prompt=prompt,
            run_config=run_config,
        )
        return verdict
|
||||
318
llm_connect/quality.py
Normal file
318
llm_connect/quality.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""Quality observations and append-only ledger support.
|
||||
|
||||
These primitives let callers record observed quality/cost outcomes for a
|
||||
task type without baking consumer-specific routing policy into llm-connect.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterator, TextIO
|
||||
|
||||
|
||||
# Registry of one ``threading.Lock`` per resolved file path; entries are
# created on demand by ``_path_lock``.
_PATH_LOCKS: dict[Path, threading.Lock] = {}
# Guards lookups and insertions in ``_PATH_LOCKS``.
_PATH_LOCKS_GUARD = threading.Lock()
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _normalise_datetime(value: datetime | str) -> datetime:
|
||||
if isinstance(value, datetime):
|
||||
dt = value
|
||||
elif isinstance(value, str):
|
||||
dt = datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||
else:
|
||||
raise TypeError(f"Expected datetime or ISO string, got {type(value).__name__}")
|
||||
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=timezone.utc)
|
||||
return dt.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def _serialise_datetime(value: datetime) -> str:
    """Render *value* as a UTC ISO-8601 string that ends in ``Z``."""
    iso_text = _normalise_datetime(value).isoformat()
    # ``isoformat`` spells the UTC offset as ``+00:00``; use the short form.
    return iso_text.replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _validate_non_negative_int(name: str, value: int) -> None:
|
||||
if not isinstance(value, int) or value < 0:
|
||||
raise ValueError(f"{name} must be a non-negative integer")
|
||||
|
||||
|
||||
def _validate_non_negative_float(name: str, value: float) -> None:
|
||||
if not isinstance(value, (int, float)) or float(value) < 0:
|
||||
raise ValueError(f"{name} must be a non-negative number")
|
||||
|
||||
|
||||
def _path_lock(path: Path) -> threading.Lock:
    """Return the lock registered for *path*, creating it on first use.

    Paths are keyed by their resolved form, and the registry is only touched
    while ``_PATH_LOCKS_GUARD`` is held.
    """
    key = path.resolve()
    with _PATH_LOCKS_GUARD:
        try:
            return _PATH_LOCKS[key]
        except KeyError:
            created = threading.Lock()
            _PATH_LOCKS[key] = created
            return created
|
||||
|
||||
|
||||
def _lock_file(handle: TextIO) -> None:
|
||||
if os.name == "nt":
|
||||
import msvcrt
|
||||
|
||||
msvcrt.locking(handle.fileno(), msvcrt.LK_LOCK, 1)
|
||||
else:
|
||||
import fcntl
|
||||
|
||||
fcntl.flock(handle.fileno(), fcntl.LOCK_EX)
|
||||
|
||||
|
||||
def _unlock_file(handle: TextIO) -> None:
|
||||
if os.name == "nt":
|
||||
import msvcrt
|
||||
|
||||
msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1)
|
||||
else:
|
||||
import fcntl
|
||||
|
||||
fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
|
||||
|
||||
|
||||
@contextmanager
def _locked_file(path: Path, mode: str) -> Iterator[TextIO]:
    """Open *path* as UTF-8 text and yield the handle while holding both locks.

    The parent directory is created if missing.  The per-path thread lock is
    acquired first, then the file is opened and locked at OS level.  The OS
    lock is released before the file is closed, and the thread lock last.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with _path_lock(path), path.open(mode, encoding="utf-8") as handle:
        _lock_file(handle)
        try:
            yield handle
        finally:
            _unlock_file(handle)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class QualityObservation:
    """One recorded quality/cost measurement for an adapter on a task type.

    Construction validates every numeric field and normalises the stored
    values: identifiers become ``str``, cost/quality/latency become ``float``,
    ``recorded_at`` becomes an aware UTC datetime and ``tags`` is copied into
    a new ``dict``.
    """

    task_type: str
    adapter_id: str
    model_id: str
    cost_usd: float
    quality_score: float
    latency_ms: float
    tokens_in: int
    tokens_out: int
    baseline_adapter_id: str | None = None
    recorded_at: datetime = field(default_factory=_utc_now)
    tags: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Validate the fields, then store their normalised forms.

        Raises:
            ValueError: If an identifier is blank, a numeric field is
                negative or of the wrong type, or ``quality_score`` is
                outside ``[0, 1]``.
        """
        for label in ("task_type", "adapter_id", "model_id"):
            if not str(getattr(self, label)).strip():
                raise ValueError(f"{label} must be a non-empty string")

        for label in ("cost_usd", "latency_ms"):
            _validate_non_negative_float(label, getattr(self, label))
        for label in ("tokens_in", "tokens_out"):
            _validate_non_negative_int(label, getattr(self, label))

        score = self.quality_score
        if not isinstance(score, (int, float)):
            raise ValueError("quality_score must be a number between 0 and 1")
        if not 0 <= float(score) <= 1:
            raise ValueError("quality_score must be between 0 and 1")

        # The dataclass is frozen, so normalised values are written through
        # ``object.__setattr__``.
        normalised = {
            "task_type": str(self.task_type),
            "adapter_id": str(self.adapter_id),
            "model_id": str(self.model_id),
            "cost_usd": float(self.cost_usd),
            "quality_score": float(score),
            "latency_ms": float(self.latency_ms),
            "recorded_at": _normalise_datetime(self.recorded_at),
            "tags": dict(self.tags),
        }
        for label, cleaned in normalised.items():
            object.__setattr__(self, label, cleaned)

    @property
    def total_tokens(self) -> int:
        """Sum of ``tokens_in`` and ``tokens_out``."""
        return self.tokens_in + self.tokens_out

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serialisable mapping of this observation.

        ``recorded_at`` is rendered as an ISO-8601 string and ``tags`` is
        copied.
        """
        payload: dict[str, Any] = {
            label: getattr(self, label)
            for label in (
                "task_type",
                "adapter_id",
                "model_id",
                "cost_usd",
                "quality_score",
                "latency_ms",
                "tokens_in",
                "tokens_out",
                "baseline_adapter_id",
            )
        }
        payload["recorded_at"] = _serialise_datetime(self.recorded_at)
        payload["tags"] = dict(self.tags)
        return payload

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "QualityObservation":
        """Build an observation from a JSON-decoded mapping.

        ``baseline_adapter_id``, ``recorded_at`` and ``tags`` are optional; a
        missing ``recorded_at`` defaults to the current UTC time.

        Raises:
            KeyError: If a required key is absent.
        """
        required = (
            "task_type",
            "adapter_id",
            "model_id",
            "cost_usd",
            "quality_score",
            "latency_ms",
            "tokens_in",
            "tokens_out",
        )
        kwargs: dict[str, Any] = {label: data[label] for label in required}
        kwargs["baseline_adapter_id"] = data.get("baseline_adapter_id")
        kwargs["recorded_at"] = data.get("recorded_at", _utc_now())
        kwargs["tags"] = data.get("tags") or {}
        return cls(**kwargs)
|
||||
|
||||
|
||||
def is_stale(
    observation: QualityObservation,
    max_age: timedelta,
    *,
    now: datetime | None = None,
) -> bool:
    """Tell whether *observation* was recorded more than *max_age* ago.

    Args:
        observation: Observation whose ``recorded_at`` is checked.
        max_age: Maximum allowed age; must not be negative.
        now: Reference time; the current UTC time when omitted.

    Raises:
        ValueError: If *max_age* is negative.
    """
    if max_age.total_seconds() < 0:
        raise ValueError("max_age must be non-negative")
    reference = _normalise_datetime(now or _utc_now())
    oldest_allowed = reference - max_age
    return observation.recorded_at < oldest_allowed
|
||||
|
||||
|
||||
class QualityLedger:
    """Append-only JSONL store for :class:`QualityObservation` records.

    Each observation is one JSON object per line.  All file access goes
    through ``_locked_file``.  Lines that cannot be parsed into an observation
    are skipped by reads and left untouched by pruning.
    """

    def __init__(self, path: str | Path):
        """Bind the ledger to *path*; the file is not opened or created here."""
        self._path = Path(path)

    @property
    def path(self) -> Path:
        """Ledger file path."""
        return self._path

    def append(self, observation: QualityObservation) -> None:
        """Append one observation as a locked JSONL record and fsync it."""
        line = json.dumps(observation.to_dict(), sort_keys=True, separators=(",", ":"))
        with _locked_file(self._path, "a") as handle:
            handle.write(line + "\n")
            handle.flush()
            os.fsync(handle.fileno())

    def read_all(self) -> list[QualityObservation]:
        """Return all parseable observations, skipping malformed lines."""
        observations, _ = self._read_with_malformed_count()
        return observations

    def malformed_count(self) -> int:
        """Return the number of malformed lines currently skipped by reads."""
        _, malformed = self._read_with_malformed_count()
        return malformed

    def by_task_type(self, task_type: str) -> list[QualityObservation]:
        """Return observations matching *task_type*, in file order."""
        return [obs for obs in self.read_all() if obs.task_type == task_type]

    def recent(
        self,
        limit: int | None = None,
        *,
        task_type: str | None = None,
        adapter_id: str | None = None,
        since: datetime | None = None,
    ) -> list[QualityObservation]:
        """Return newest observations first, optionally filtered.

        Args:
            limit: Maximum number of observations to return; ``None`` for all.
            task_type: Keep only observations with this task type.
            adapter_id: Keep only observations with this adapter id.
            since: Keep only observations recorded at or after this time.

        Raises:
            ValueError: If *limit* is negative.
        """
        if limit is not None and limit < 0:
            raise ValueError("limit must be non-negative")

        cutoff = _normalise_datetime(since) if since is not None else None
        observations = self.read_all()
        if task_type is not None:
            observations = [obs for obs in observations if obs.task_type == task_type]
        if adapter_id is not None:
            observations = [obs for obs in observations if obs.adapter_id == adapter_id]
        if cutoff is not None:
            observations = [obs for obs in observations if obs.recorded_at >= cutoff]

        observations.sort(key=lambda obs: obs.recorded_at, reverse=True)
        if limit is None:
            return observations
        return observations[:limit]

    def mean_quality(
        self,
        task_type: str,
        *,
        adapter_id: str | None = None,
        model_id: str | None = None,
        max_age: timedelta | None = None,
        min_observations: int = 1,
    ) -> float | None:
        """Return mean quality for matching observations, or ``None`` if absent.

        ``None`` is returned when fewer than *min_observations* observations
        survive the filters.

        Raises:
            ValueError: If *min_observations* is not positive.
        """
        if min_observations <= 0:
            raise ValueError("min_observations must be positive")

        observations = self.by_task_type(task_type)
        if adapter_id is not None:
            observations = [obs for obs in observations if obs.adapter_id == adapter_id]
        if model_id is not None:
            observations = [obs for obs in observations if obs.model_id == model_id]
        if max_age is not None:
            # One reference time for the whole batch, so every observation is
            # judged against the same instant.
            reference = _utc_now()
            observations = [
                obs for obs in observations if not is_stale(obs, max_age, now=reference)
            ]

        if len(observations) < min_observations:
            return None
        return sum(obs.quality_score for obs in observations) / len(observations)

    def prune_before(self, timestamp: datetime) -> int:
        """Remove valid observations recorded before *timestamp*.

        Malformed lines are preserved because their timestamp cannot be
        trusted.  A missing ledger is left missing, and the file is only
        rewritten when at least one record is actually removed.

        Returns:
            The number of valid observation records removed.
        """
        cutoff = _normalise_datetime(timestamp)
        # Nothing to prune; avoid creating directories or an empty file as a
        # side effect of opening in append mode.
        if not self._path.is_file():
            return 0

        removed = 0
        with _locked_file(self._path, "a+") as handle:
            handle.seek(0)
            kept: list[str] = []
            for line in handle.readlines():
                try:
                    obs = QualityObservation.from_dict(json.loads(line))
                except (json.JSONDecodeError, KeyError, TypeError, ValueError):
                    kept.append(line)
                    continue
                if obs.recorded_at < cutoff:
                    removed += 1
                else:
                    kept.append(line)

            if removed:
                # Truncate-and-rewrite is not atomic, so only do it when the
                # contents really change.
                handle.seek(0)
                handle.truncate()
                handle.writelines(kept)
                handle.flush()
                os.fsync(handle.fileno())
        return removed

    def _read_with_malformed_count(self) -> tuple[list[QualityObservation], int]:
        """Read the ledger, returning parsed observations and the bad-line count.

        Blank lines are ignored and not counted as malformed.  A missing file
        yields ``([], 0)``.
        """
        if not self._path.is_file():
            return [], 0

        observations: list[QualityObservation] = []
        malformed = 0
        with _locked_file(self._path, "r") as handle:
            for line in handle:
                if not line.strip():
                    continue
                try:
                    observations.append(QualityObservation.from_dict(json.loads(line)))
                except (json.JSONDecodeError, KeyError, TypeError, ValueError):
                    malformed += 1
        return observations, malformed
|
||||
@@ -5,9 +5,11 @@ Maps task types to preferred adapters with optional cost-cap fallback.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Mapping, Optional
|
||||
|
||||
from llm_connect.adapter import LLMAdapter
|
||||
from llm_connect.quality import QualityLedger, QualityObservation
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -87,3 +89,172 @@ class RoutingPolicy:
|
||||
raise LookupError(
|
||||
f"No routing rule for task_type={task_type!r} and no default configured"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _CandidateMetrics:
    """Aggregated ledger metrics for one adapter considered by adaptive routing."""

    # Id under which the adapter's observations were looked up in the ledger.
    adapter_id: str
    # Adapter instance returned to the caller if this candidate is selected.
    adapter: LLMAdapter
    # Mean ``quality_score`` over the adapter's windowed observations.
    mean_quality: float
    # Mean ``cost_usd`` over the same observations.
    mean_cost_usd: float
    # Position of the adapter in the candidate list; the final tie-breaker.
    order: int
    # True when this adapter is the static rule's ``prefer`` for the task type.
    is_static_prefer: bool
|
||||
|
||||
|
||||
@dataclass
class AdaptiveRoutingPolicy(RoutingPolicy):
    """Pick the cheapest adapter whose observed quality meets a floor.

    Observations come from a :class:`~llm_connect.quality.QualityLedger`,
    filtered by ``task_type`` and adapter id.  When no adapter has qualifying
    observations the decision is delegated to ``RoutingPolicy.resolve``, so
    the same policy works before and after observations accumulate.

    Attributes:
        ledger: Source of observations; adaptive routing is off when ``None``.
        adapters_by_id: Adapters keyed by the id used in ledger records.
        window_size: Maximum number of newest observations read per adapter.
        min_observations: Observations an adapter needs to be considered.
            Since at most ``window_size`` observations are read, a value
            above ``window_size`` means no adapter can ever qualify.
        max_age: When set, observations older than this are ignored.
    """

    ledger: Optional[QualityLedger] = None
    adapters_by_id: Mapping[str, LLMAdapter] = field(default_factory=dict)
    window_size: int = 20
    min_observations: int = 1
    max_age: Optional[timedelta] = None

    def __post_init__(self) -> None:
        """Reject non-positive window sizes/counts and a negative ``max_age``."""
        if self.window_size <= 0:
            raise ValueError("window_size must be positive")
        if self.min_observations <= 0:
            raise ValueError("min_observations must be positive")
        age_limit = self.max_age
        if age_limit is not None and age_limit.total_seconds() < 0:
            raise ValueError("max_age must be non-negative")

    def resolve(
        self,
        task_type: str,
        estimated_cost_per_1k: Optional[float] = None,
        *,
        quality_floor: Optional[float] = None,
    ) -> LLMAdapter:
        """Return the adaptive adapter for *task_type*.

        Args:
            task_type: Logical task identifier.
            estimated_cost_per_1k: Passed through to static routing fallback.
            quality_floor: Minimum observed mean quality required for adaptive
                selection.  When omitted, static routing is used.

        Returns:
            The selected :class:`~llm_connect.adapter.LLMAdapter`.  Candidates
            are ranked by mean cost, then static preference, then order.

        Raises:
            ValueError: If *quality_floor* is outside ``[0, 1]`` while a
                ledger is configured.
        """
        adaptive_enabled = quality_floor is not None and self.ledger is not None
        if not adaptive_enabled:
            return super().resolve(task_type, estimated_cost_per_1k)
        if not 0 <= quality_floor <= 1:
            raise ValueError("quality_floor must be between 0 and 1")

        qualifying = self._qualifying_candidates(task_type, quality_floor)
        if not qualifying:
            return super().resolve(task_type, estimated_cost_per_1k)

        def rank(candidate: _CandidateMetrics) -> tuple:
            # Cheapest first; ties go to the statically preferred adapter,
            # then to the earliest candidate.
            return (
                candidate.mean_cost_usd,
                0 if candidate.is_static_prefer else 1,
                candidate.order,
            )

        return min(qualifying, key=rank).adapter

    def _qualifying_candidates(
        self,
        task_type: str,
        quality_floor: float,
    ) -> list[_CandidateMetrics]:
        """Collect metrics for every candidate that meets the quality floor.

        A candidate is dropped when it has fewer than ``min_observations``
        windowed observations or its mean quality is below *quality_floor*.
        """
        # NOTE(review): each candidate triggers its own ledger read via
        # ``_windowed_observations``.
        preferred = self._static_preferred_adapter(task_type)
        qualifying: list[_CandidateMetrics] = []
        for position, (adapter_id, adapter) in enumerate(self._candidate_entries(task_type)):
            window = self._windowed_observations(task_type, adapter_id)
            sample_count = len(window)
            if sample_count < self.min_observations:
                continue

            average_quality = sum(obs.quality_score for obs in window) / sample_count
            if average_quality < quality_floor:
                continue

            average_cost = sum(obs.cost_usd for obs in window) / sample_count
            qualifying.append(
                _CandidateMetrics(
                    adapter_id=adapter_id,
                    adapter=adapter,
                    mean_quality=average_quality,
                    mean_cost_usd=average_cost,
                    order=position,
                    is_static_prefer=adapter is preferred,
                )
            )
        return qualifying

    def _windowed_observations(
        self,
        task_type: str,
        adapter_id: str,
    ) -> list[QualityObservation]:
        """Return up to ``window_size`` newest observations for one adapter."""
        ledger = self.ledger
        if ledger is None:
            return []

        cutoff = None
        if self.max_age is not None:
            cutoff = datetime.now(timezone.utc) - self.max_age

        return ledger.recent(
            limit=self.window_size,
            task_type=task_type,
            adapter_id=adapter_id,
            since=cutoff,
        )

    def _candidate_entries(self, task_type: str) -> list[tuple[str, LLMAdapter]]:
        """List ``(adapter_id, adapter)`` pairs to evaluate, without duplicate ids.

        Adapters from ``adapters_by_id`` come first, followed by the static
        candidates for *task_type* whose id can be determined.
        """
        # A dict keeps insertion order and makes the duplicate check direct.
        by_id: dict[str, LLMAdapter] = {}

        def register(adapter_id: str | None, adapter: LLMAdapter | None) -> None:
            if adapter is not None and adapter_id is not None and adapter_id not in by_id:
                by_id[adapter_id] = adapter

        for adapter_id, adapter in self.adapters_by_id.items():
            register(adapter_id, adapter)

        for adapter in self._static_candidate_adapters(task_type):
            register(self._adapter_id_for(adapter), adapter)

        return list(by_id.items())

    def _static_candidate_adapters(self, task_type: str) -> list[LLMAdapter]:
        """Return the adapters static routing could choose for *task_type*.

        For the first matching rule this is ``prefer`` followed by the rule's
        ``fallback`` and the policy ``default`` when those are set.  Without a
        matching rule it is just the ``default``, if any.
        """
        matching = next((rule for rule in self.rules if rule.task_type == task_type), None)
        if matching is None:
            return [] if self.default is None else [self.default]

        extras = [
            adapter for adapter in (matching.fallback, self.default) if adapter is not None
        ]
        return [matching.prefer, *extras]

    def _static_preferred_adapter(self, task_type: str) -> LLMAdapter | None:
        """Return ``prefer`` from the first rule for *task_type*, or ``None``."""
        return next(
            (rule.prefer for rule in self.rules if rule.task_type == task_type),
            None,
        )

    def _adapter_id_for(self, adapter: LLMAdapter) -> str | None:
        """Find the ledger id for *adapter*.

        An identity match in ``adapters_by_id`` wins.  Otherwise the first
        non-blank string among the adapter's ``adapter_id``, ``id`` and
        ``name`` attributes is used.  Returns ``None`` when neither applies.
        """
        for known_id, known_adapter in self.adapters_by_id.items():
            if known_adapter is adapter:
                return known_id

        for attribute in ("adapter_id", "id", "name"):
            label = getattr(adapter, attribute, None)
            if isinstance(label, str) and label.strip():
                return label
        return None
|
||||
|
||||
177
llm_connect/shadowing.py
Normal file
177
llm_connect/shadowing.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""Shadow-mode observation adapter for adaptive routing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import threading
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass, field, replace
|
||||
from typing import Any, Callable, Mapping
|
||||
|
||||
from llm_connect.adapter import LLMAdapter
|
||||
from llm_connect.grading import BaselineGrader
|
||||
from llm_connect.models import LLMResponse, RunConfig
|
||||
from llm_connect.quality import QualityLedger, QualityObservation
|
||||
|
||||
|
||||
def _default_cost_estimator(response: LLMResponse) -> float:
|
||||
for key in ("cost_usd", "estimated_cost_usd", "cost"):
|
||||
value = response.metadata.get(key)
|
||||
if isinstance(value, (int, float)) and value >= 0:
|
||||
return float(value)
|
||||
return 0.0
|
||||
|
||||
|
||||
class _StaticResponseAdapter(LLMAdapter):
    """Adapter that replays one stored response instead of calling a model.

    It lets a ``BaselineGrader`` treat an existing candidate response as if
    the candidate adapter had just produced it.
    """

    def __init__(self, response: LLMResponse):
        """Remember *response* for later replay."""
        self._stored_response = response

    def execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
        """Return the stored response; *prompt* and *config* are ignored."""
        return self._stored_response

    def validate_config(self, config: RunConfig) -> bool:
        """Accept every configuration."""
        return True
|
||||
|
||||
|
||||
@dataclass
class ShadowingAdapter(LLMAdapter):
    """Return candidate responses while recording sampled baseline grades.

    Shadow work is best-effort: baseline, grading, scheduling or ledger
    failures are reported to ``on_shadow_error`` when provided, but never
    alter the candidate response returned to the caller.

    A fraction ``shadow_rate`` of calls is shadowed.  With ``async_shadow``
    the shadow work runs on a single background worker thread; use
    :meth:`flush` to wait for it and :meth:`shutdown` to stop the worker.
    """

    candidate_adapter: LLMAdapter
    baseline_adapter: LLMAdapter
    grader: BaselineGrader
    ledger: QualityLedger
    task_type: str
    adapter_id: str
    model_id: str | None = None
    baseline_adapter_id: str | None = None
    shadow_rate: float = 1.0
    async_shadow: bool = False
    random_source: random.Random = field(default_factory=random.Random, repr=False)
    cost_estimator: Callable[[LLMResponse], float] = _default_cost_estimator
    tags: Mapping[str, Any] = field(default_factory=dict)
    on_shadow_error: Callable[[Exception], None] | None = None
    _executor: ThreadPoolExecutor | None = field(default=None, init=False, repr=False)
    _futures: list[Future[None]] = field(default_factory=list, init=False, repr=False)
    # Guards ``_executor``, ``_futures`` and draws from ``random_source``.
    _lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)

    def __post_init__(self) -> None:
        """Validate identifiers and ``shadow_rate``; start the worker if async.

        Raises:
            ValueError: If ``task_type`` or ``adapter_id`` is blank, or
                ``shadow_rate`` is outside ``[0, 1]``.
        """
        if not str(self.task_type).strip():
            raise ValueError("task_type must be a non-empty string")
        if not str(self.adapter_id).strip():
            raise ValueError("adapter_id must be a non-empty string")
        if not 0 <= self.shadow_rate <= 1:
            raise ValueError("shadow_rate must be between 0 and 1")
        if self.async_shadow:
            self._executor = ThreadPoolExecutor(max_workers=1)

    def execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
        """Run the candidate adapter and, when sampled, shadow-grade its response."""
        response = self.candidate_adapter.execute_prompt(prompt, config)
        if self._should_shadow():
            self._handle_shadow(prompt, config, response)
        return response

    async def async_execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
        """Async variant of :meth:`execute_prompt`.

        Without ``async_shadow`` the shadow work runs in a worker thread via
        ``asyncio.to_thread`` and is awaited before returning.
        """
        response = await self.candidate_adapter.async_execute_prompt(prompt, config)
        if self._should_shadow():
            if self.async_shadow:
                self._schedule_shadow(prompt, config, response)
            else:
                await asyncio.to_thread(self._run_shadow, prompt, config, response)
        return response

    def validate_config(self, config: RunConfig) -> bool:
        """Delegate configuration validation to the candidate adapter."""
        return self.candidate_adapter.validate_config(config)

    def flush(self, timeout: float | None = None) -> None:
        """Wait for currently queued async shadow work to finish.

        Args:
            timeout: Seconds to wait for each pending task; ``None`` waits
                indefinitely.

        If a wait times out the exception from ``Future.result`` propagates,
        and the unfinished tasks stay tracked so a later ``flush`` can still
        wait for them.
        """
        with self._lock:
            pending = list(self._futures)
        try:
            for future in pending:
                future.result(timeout=timeout)
        finally:
            # Forget only what has completed.  Clearing the list up front
            # would lose track of still-running tasks on timeout.
            with self._lock:
                self._futures = [item for item in self._futures if not item.done()]

    def shutdown(self, wait: bool = True) -> None:
        """Shut down the background shadow executor if one was created."""
        # Detach the executor under the lock, but shut it down outside the
        # lock so other threads are not blocked while the worker drains.
        with self._lock:
            executor, self._executor = self._executor, None
        if executor is not None:
            executor.shutdown(wait=wait)

    def _should_shadow(self) -> bool:
        """Decide whether this call is sampled for shadowing."""
        if self.shadow_rate <= 0:
            return False
        if self.shadow_rate >= 1:
            return True
        with self._lock:
            return self.random_source.random() < self.shadow_rate

    def _handle_shadow(
        self,
        prompt: str,
        config: RunConfig,
        candidate_response: LLMResponse,
    ) -> None:
        """Run the shadow work inline, or queue it when ``async_shadow`` is set."""
        if self.async_shadow:
            self._schedule_shadow(prompt, config, candidate_response)
        else:
            self._run_shadow(prompt, config, candidate_response)

    def _schedule_shadow(
        self,
        prompt: str,
        config: RunConfig,
        candidate_response: LLMResponse,
    ) -> None:
        """Queue shadow work on the background worker, creating it if needed."""
        with self._lock:
            # Created under the lock so concurrent callers cannot each build
            # an executor and leak one of them.
            if self._executor is None:
                self._executor = ThreadPoolExecutor(max_workers=1)
            executor = self._executor
        try:
            future = executor.submit(self._run_shadow, prompt, config, candidate_response)
        except RuntimeError as exc:
            # ``submit`` raises RuntimeError once the executor has been shut
            # down.  Shadowing is best-effort, so report it rather than fail
            # the candidate call.
            self._report_shadow_error(exc)
            return
        with self._lock:
            self._futures = [item for item in self._futures if not item.done()]
            self._futures.append(future)

    def _run_shadow(
        self,
        prompt: str,
        config: RunConfig,
        candidate_response: LLMResponse,
    ) -> None:
        """Grade the candidate against the baseline and append an observation.

        The grader is given a replay adapter holding *candidate_response*, and
        a copy of *config* with ``budget_tracker`` cleared.  Any exception is
        passed to ``_report_shadow_error`` instead of propagating.
        """
        try:
            shadow_config = replace(config, budget_tracker=None)
            result = self.grader.grade(
                self.baseline_adapter,
                _StaticResponseAdapter(candidate_response),
                prompt,
                shadow_config,
            )
            self.ledger.append(
                QualityObservation(
                    task_type=self.task_type,
                    adapter_id=self.adapter_id,
                    model_id=self.model_id or candidate_response.model or config.model_name,
                    cost_usd=self.cost_estimator(candidate_response),
                    quality_score=result.quality_score,
                    latency_ms=float(candidate_response.metadata.get("latency_ms", 0.0)),
                    tokens_in=int(candidate_response.usage.get("prompt_tokens", 0)),
                    tokens_out=int(candidate_response.usage.get("completion_tokens", 0)),
                    baseline_adapter_id=self.baseline_adapter_id,
                    tags=dict(self.tags),
                )
            )
        except Exception as exc:
            self._report_shadow_error(exc)

    def _report_shadow_error(self, exc: Exception) -> None:
        """Pass *exc* to ``on_shadow_error`` if set, ignoring handler failures."""
        if self.on_shadow_error is None:
            return
        try:
            self.on_shadow_error(exc)
        except Exception:
            # Deliberate: a faulty error handler must not break the caller's
            # request or the shadow worker.
            pass
|
||||
Reference in New Issue
Block a user