Add adaptive cost-quality routing primitives
Some checks failed
CI / test (3.10) (push) Has been cancelled
CI / test (3.11) (push) Has been cancelled
CI / test (3.12) (push) Has been cancelled

This commit is contained in:
2026-05-17 21:32:27 +02:00
parent bf86a03c5d
commit c4ad4bb9f2
17 changed files with 2480 additions and 25 deletions

239
llm_connect/grading.py Normal file
View File

@@ -0,0 +1,239 @@
"""Baseline grading primitives for adaptive routing.
Graders compare a candidate adapter response against a caller-chosen baseline.
They produce normalised quality scores that can be recorded in a
``QualityLedger`` and consumed later by adaptive routing policy.
"""
from __future__ import annotations
import json
import re
from dataclasses import dataclass, field, replace
from typing import Any, Protocol
from llm_connect.adapter import LLMAdapter
from llm_connect.embedding_adapter import EmbeddingAdapter
from llm_connect.models import LLMResponse, RunConfig
from llm_connect.similarity import cosine_similarity
def _validate_score(value: float) -> float:
if not isinstance(value, (int, float)):
raise ValueError("quality_score must be a number between 0 and 1")
score = float(value)
if not 0 <= score <= 1:
raise ValueError("quality_score must be between 0 and 1")
return score
def _normalise_text(text: str) -> str:
return " ".join(text.strip().split())
@dataclass(frozen=True)
class GradingResult:
"""Structured result from comparing candidate output to baseline output."""
quality_score: float
notes: str
grader_id: str
baseline_response: LLMResponse
candidate_response: LLMResponse
def __post_init__(self) -> None:
if not str(self.grader_id).strip():
raise ValueError("grader_id must be a non-empty string")
object.__setattr__(self, "quality_score", _validate_score(self.quality_score))
object.__setattr__(self, "notes", str(self.notes))
class Judge(Protocol):
"""Compare baseline and candidate responses."""
grader_id: str
def judge(
self,
baseline_response: LLMResponse,
candidate_response: LLMResponse,
*,
prompt: str,
run_config: RunConfig,
) -> GradingResult:
"""Return a quality score for candidate relative to baseline."""
class BaselineGrader(Protocol):
"""Run baseline and candidate adapters, then judge the paired responses."""
def grade(
self,
baseline_adapter: LLMAdapter,
candidate_adapter: LLMAdapter,
prompt: str,
run_config: RunConfig,
) -> GradingResult:
"""Return a structured grading result."""
@dataclass
class ExactMatchJudge:
"""Judge that scores 1.0 when response text matches exactly after normalisation."""
normalize_whitespace: bool = True
case_sensitive: bool = True
grader_id: str = "exact-match"
def judge(
self,
baseline_response: LLMResponse,
candidate_response: LLMResponse,
*,
prompt: str,
run_config: RunConfig,
) -> GradingResult:
baseline_text = baseline_response.content
candidate_text = candidate_response.content
if self.normalize_whitespace:
baseline_text = _normalise_text(baseline_text)
candidate_text = _normalise_text(candidate_text)
if not self.case_sensitive:
baseline_text = baseline_text.casefold()
candidate_text = candidate_text.casefold()
matched = baseline_text == candidate_text
return GradingResult(
quality_score=1.0 if matched else 0.0,
notes="exact match" if matched else "candidate content differs from baseline",
grader_id=self.grader_id,
baseline_response=baseline_response,
candidate_response=candidate_response,
)
@dataclass
class EmbeddingSimilarityJudge:
"""Judge that maps cosine similarity between response embeddings to 0..1."""
embedding_adapter: EmbeddingAdapter
grader_id: str = "embedding-similarity"
def judge(
self,
baseline_response: LLMResponse,
candidate_response: LLMResponse,
*,
prompt: str,
run_config: RunConfig,
) -> GradingResult:
embeddings = self.embedding_adapter.embed(
[baseline_response.content, candidate_response.content]
)
if len(embeddings) != 2:
raise ValueError("EmbeddingSimilarityJudge expected exactly two embeddings")
raw_similarity = cosine_similarity(embeddings[0], embeddings[1])
quality_score = max(0.0, min(1.0, raw_similarity))
return GradingResult(
quality_score=quality_score,
notes=f"cosine similarity {raw_similarity:.4f}",
grader_id=self.grader_id,
baseline_response=baseline_response,
candidate_response=candidate_response,
)
@dataclass
class LLMJudge:
"""LLM-as-judge wrapper using a fixed rubric prompt and JSON response."""
judge_adapter: LLMAdapter
rubric: str = (
"Compare the candidate response to the baseline response. "
"Return JSON only with keys quality_score and notes. "
"quality_score must be a number from 0 to 1."
)
grader_id: str = "llm-judge"
seed: int | None = 0
def judge(
self,
baseline_response: LLMResponse,
candidate_response: LLMResponse,
*,
prompt: str,
run_config: RunConfig,
) -> GradingResult:
judge_prompt = self._build_prompt(prompt, baseline_response, candidate_response)
judge_config = self._judge_config(run_config)
response = self.judge_adapter.execute_prompt(judge_prompt, judge_config)
parsed = self._parse_judge_response(response.content)
return GradingResult(
quality_score=parsed["quality_score"],
notes=parsed["notes"],
grader_id=self.grader_id,
baseline_response=baseline_response,
candidate_response=candidate_response,
)
def _judge_config(self, run_config: RunConfig) -> RunConfig:
params: dict[str, Any] = dict(run_config.model_params)
if self.seed is not None:
params.setdefault("seed", self.seed)
return replace(run_config, temperature=0.0, model_params=params, budget_tracker=None)
def _build_prompt(
self,
prompt: str,
baseline_response: LLMResponse,
candidate_response: LLMResponse,
) -> str:
return (
f"{self.rubric}\n\n"
f"Original prompt:\n{prompt}\n\n"
f"Baseline response:\n{baseline_response.content}\n\n"
f"Candidate response:\n{candidate_response.content}\n"
)
def _parse_judge_response(self, content: str) -> dict[str, Any]:
try:
data = json.loads(content)
except json.JSONDecodeError:
match = re.search(r"\{.*\}", content, flags=re.DOTALL)
if not match:
raise ValueError("LLMJudge response did not contain JSON") from None
try:
data = json.loads(match.group(0))
except json.JSONDecodeError as exc:
raise ValueError("LLMJudge response JSON could not be parsed") from exc
if not isinstance(data, dict):
raise ValueError("LLMJudge response JSON must be an object")
return {
"quality_score": _validate_score(data.get("quality_score")),
"notes": str(data.get("notes", "")),
}
@dataclass
class PairedGrader:
"""Baseline grader that runs both adapters and delegates comparison to a judge."""
judge: Judge = field(default_factory=ExactMatchJudge)
def grade(
self,
baseline_adapter: LLMAdapter,
candidate_adapter: LLMAdapter,
prompt: str,
run_config: RunConfig,
) -> GradingResult:
baseline_response = baseline_adapter.execute_prompt(prompt, run_config)
candidate_response = candidate_adapter.execute_prompt(prompt, run_config)
return self.judge.judge(
baseline_response,
candidate_response,
prompt=prompt,
run_config=run_config,
)

318
llm_connect/quality.py Normal file
View File

@@ -0,0 +1,318 @@
"""Quality observations and append-only ledger support.
These primitives let callers record observed quality/cost outcomes for a
task type without baking consumer-specific routing policy into llm-connect.
"""
from __future__ import annotations
import json
import os
import threading
from contextlib import contextmanager
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Iterator, TextIO
_PATH_LOCKS: dict[Path, threading.Lock] = {}
_PATH_LOCKS_GUARD = threading.Lock()
def _utc_now() -> datetime:
return datetime.now(timezone.utc)
def _normalise_datetime(value: datetime | str) -> datetime:
if isinstance(value, datetime):
dt = value
elif isinstance(value, str):
dt = datetime.fromisoformat(value.replace("Z", "+00:00"))
else:
raise TypeError(f"Expected datetime or ISO string, got {type(value).__name__}")
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt.astimezone(timezone.utc)
def _serialise_datetime(value: datetime) -> str:
return _normalise_datetime(value).isoformat().replace("+00:00", "Z")
def _validate_non_negative_int(name: str, value: int) -> None:
if not isinstance(value, int) or value < 0:
raise ValueError(f"{name} must be a non-negative integer")
def _validate_non_negative_float(name: str, value: float) -> None:
if not isinstance(value, (int, float)) or float(value) < 0:
raise ValueError(f"{name} must be a non-negative number")
def _path_lock(path: Path) -> threading.Lock:
resolved = path.resolve()
with _PATH_LOCKS_GUARD:
lock = _PATH_LOCKS.get(resolved)
if lock is None:
lock = threading.Lock()
_PATH_LOCKS[resolved] = lock
return lock
def _lock_file(handle: TextIO) -> None:
if os.name == "nt":
import msvcrt
msvcrt.locking(handle.fileno(), msvcrt.LK_LOCK, 1)
else:
import fcntl
fcntl.flock(handle.fileno(), fcntl.LOCK_EX)
def _unlock_file(handle: TextIO) -> None:
if os.name == "nt":
import msvcrt
msvcrt.locking(handle.fileno(), msvcrt.LK_UNLCK, 1)
else:
import fcntl
fcntl.flock(handle.fileno(), fcntl.LOCK_UN)
@contextmanager
def _locked_file(path: Path, mode: str) -> Iterator[TextIO]:
path.parent.mkdir(parents=True, exist_ok=True)
local_lock = _path_lock(path)
with local_lock:
with path.open(mode, encoding="utf-8") as handle:
_lock_file(handle)
try:
yield handle
finally:
_unlock_file(handle)
@dataclass(frozen=True)
class QualityObservation:
"""Observed quality/cost outcome for one adapter on one task type."""
task_type: str
adapter_id: str
model_id: str
cost_usd: float
quality_score: float
latency_ms: float
tokens_in: int
tokens_out: int
baseline_adapter_id: str | None = None
recorded_at: datetime = field(default_factory=_utc_now)
tags: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
for name in ("task_type", "adapter_id", "model_id"):
if not str(getattr(self, name)).strip():
raise ValueError(f"{name} must be a non-empty string")
_validate_non_negative_float("cost_usd", self.cost_usd)
_validate_non_negative_float("latency_ms", self.latency_ms)
_validate_non_negative_int("tokens_in", self.tokens_in)
_validate_non_negative_int("tokens_out", self.tokens_out)
if not isinstance(self.quality_score, (int, float)):
raise ValueError("quality_score must be a number between 0 and 1")
if not 0 <= float(self.quality_score) <= 1:
raise ValueError("quality_score must be between 0 and 1")
object.__setattr__(self, "task_type", str(self.task_type))
object.__setattr__(self, "adapter_id", str(self.adapter_id))
object.__setattr__(self, "model_id", str(self.model_id))
object.__setattr__(self, "cost_usd", float(self.cost_usd))
object.__setattr__(self, "quality_score", float(self.quality_score))
object.__setattr__(self, "latency_ms", float(self.latency_ms))
object.__setattr__(self, "recorded_at", _normalise_datetime(self.recorded_at))
object.__setattr__(self, "tags", dict(self.tags))
@property
def total_tokens(self) -> int:
"""Return input plus output tokens."""
return self.tokens_in + self.tokens_out
def to_dict(self) -> dict[str, Any]:
"""Convert to a JSON-serialisable dictionary."""
return {
"task_type": self.task_type,
"adapter_id": self.adapter_id,
"model_id": self.model_id,
"cost_usd": self.cost_usd,
"quality_score": self.quality_score,
"latency_ms": self.latency_ms,
"tokens_in": self.tokens_in,
"tokens_out": self.tokens_out,
"baseline_adapter_id": self.baseline_adapter_id,
"recorded_at": _serialise_datetime(self.recorded_at),
"tags": dict(self.tags),
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "QualityObservation":
"""Create an observation from a JSON-decoded dictionary."""
return cls(
task_type=data["task_type"],
adapter_id=data["adapter_id"],
model_id=data["model_id"],
cost_usd=data["cost_usd"],
quality_score=data["quality_score"],
latency_ms=data["latency_ms"],
tokens_in=data["tokens_in"],
tokens_out=data["tokens_out"],
baseline_adapter_id=data.get("baseline_adapter_id"),
recorded_at=data.get("recorded_at", _utc_now()),
tags=data.get("tags") or {},
)
def is_stale(
observation: QualityObservation,
max_age: timedelta,
*,
now: datetime | None = None,
) -> bool:
"""Return whether *observation* is older than *max_age*."""
if max_age.total_seconds() < 0:
raise ValueError("max_age must be non-negative")
reference = _normalise_datetime(now or _utc_now())
return observation.recorded_at < reference - max_age
class QualityLedger:
"""Append-only JSONL store for :class:`QualityObservation` records."""
def __init__(self, path: str | Path):
self._path = Path(path)
@property
def path(self) -> Path:
"""Ledger file path."""
return self._path
def append(self, observation: QualityObservation) -> None:
"""Append one observation as a locked JSONL record."""
line = json.dumps(observation.to_dict(), sort_keys=True, separators=(",", ":"))
with _locked_file(self._path, "a") as handle:
handle.write(line + "\n")
handle.flush()
os.fsync(handle.fileno())
def read_all(self) -> list[QualityObservation]:
"""Return all parseable observations, skipping malformed lines."""
observations, _ = self._read_with_malformed_count()
return observations
def malformed_count(self) -> int:
"""Return the number of malformed lines currently skipped by reads."""
_, malformed = self._read_with_malformed_count()
return malformed
def by_task_type(self, task_type: str) -> list[QualityObservation]:
"""Return observations matching *task_type*."""
return [obs for obs in self.read_all() if obs.task_type == task_type]
def recent(
self,
limit: int | None = None,
*,
task_type: str | None = None,
adapter_id: str | None = None,
since: datetime | None = None,
) -> list[QualityObservation]:
"""Return newest observations first, optionally filtered."""
if limit is not None and limit < 0:
raise ValueError("limit must be non-negative")
cutoff = _normalise_datetime(since) if since is not None else None
observations = self.read_all()
if task_type is not None:
observations = [obs for obs in observations if obs.task_type == task_type]
if adapter_id is not None:
observations = [obs for obs in observations if obs.adapter_id == adapter_id]
if cutoff is not None:
observations = [obs for obs in observations if obs.recorded_at >= cutoff]
observations.sort(key=lambda obs: obs.recorded_at, reverse=True)
if limit is None:
return observations
return observations[:limit]
def mean_quality(
self,
task_type: str,
*,
adapter_id: str | None = None,
model_id: str | None = None,
max_age: timedelta | None = None,
min_observations: int = 1,
) -> float | None:
"""Return mean quality for matching observations, or ``None`` if absent."""
if min_observations <= 0:
raise ValueError("min_observations must be positive")
observations = self.by_task_type(task_type)
if adapter_id is not None:
observations = [obs for obs in observations if obs.adapter_id == adapter_id]
if model_id is not None:
observations = [obs for obs in observations if obs.model_id == model_id]
if max_age is not None:
observations = [obs for obs in observations if not is_stale(obs, max_age)]
if len(observations) < min_observations:
return None
return sum(obs.quality_score for obs in observations) / len(observations)
def prune_before(self, timestamp: datetime) -> int:
"""Remove valid observations recorded before *timestamp*.
Malformed lines are preserved because their timestamp cannot be trusted.
Returns the number of valid observation records removed.
"""
cutoff = _normalise_datetime(timestamp)
removed = 0
with _locked_file(self._path, "a+") as handle:
handle.seek(0)
lines = handle.readlines()
kept: list[str] = []
for line in lines:
try:
obs = QualityObservation.from_dict(json.loads(line))
except (json.JSONDecodeError, KeyError, TypeError, ValueError):
kept.append(line)
continue
if obs.recorded_at < cutoff:
removed += 1
else:
kept.append(line)
handle.seek(0)
handle.truncate()
handle.writelines(kept)
handle.flush()
os.fsync(handle.fileno())
return removed
def _read_with_malformed_count(self) -> tuple[list[QualityObservation], int]:
if not self._path.is_file():
return [], 0
observations: list[QualityObservation] = []
malformed = 0
with _locked_file(self._path, "r") as handle:
for line in handle:
if not line.strip():
continue
try:
observations.append(QualityObservation.from_dict(json.loads(line)))
except (json.JSONDecodeError, KeyError, TypeError, ValueError):
malformed += 1
return observations, malformed

View File

@@ -5,9 +5,11 @@ Maps task types to preferred adapters with optional cost-cap fallback.
"""
from dataclasses import dataclass, field
from typing import Optional, List
from datetime import datetime, timedelta, timezone
from typing import List, Mapping, Optional
from llm_connect.adapter import LLMAdapter
from llm_connect.quality import QualityLedger, QualityObservation
@dataclass
@@ -87,3 +89,172 @@ class RoutingPolicy:
raise LookupError(
f"No routing rule for task_type={task_type!r} and no default configured"
)
@dataclass(frozen=True)
class _CandidateMetrics:
adapter_id: str
adapter: LLMAdapter
mean_quality: float
mean_cost_usd: float
order: int
is_static_prefer: bool
@dataclass
class AdaptiveRoutingPolicy(RoutingPolicy):
"""Route to the cheapest adapter whose observed quality clears a floor.
The policy consults a :class:`~llm_connect.quality.QualityLedger` for
observations matching ``task_type`` and adapter id. When the ledger has no
qualifying observations, resolution falls through to ``RoutingPolicy`` so a
caller can use the same policy on day zero and after observations accrue.
"""
ledger: Optional[QualityLedger] = None
adapters_by_id: Mapping[str, LLMAdapter] = field(default_factory=dict)
window_size: int = 20
min_observations: int = 1
max_age: Optional[timedelta] = None
def __post_init__(self) -> None:
if self.window_size <= 0:
raise ValueError("window_size must be positive")
if self.min_observations <= 0:
raise ValueError("min_observations must be positive")
if self.max_age is not None and self.max_age.total_seconds() < 0:
raise ValueError("max_age must be non-negative")
def resolve(
self,
task_type: str,
estimated_cost_per_1k: Optional[float] = None,
*,
quality_floor: Optional[float] = None,
) -> LLMAdapter:
"""Return the adaptive adapter for *task_type*.
Args:
task_type: Logical task identifier.
estimated_cost_per_1k: Passed through to static routing fallback.
quality_floor: Minimum observed mean quality required for adaptive
selection. When omitted, static routing is used.
Returns:
The selected :class:`~llm_connect.adapter.LLMAdapter`.
"""
if quality_floor is None or self.ledger is None:
return super().resolve(task_type, estimated_cost_per_1k)
if not 0 <= quality_floor <= 1:
raise ValueError("quality_floor must be between 0 and 1")
metrics = self._qualifying_candidates(task_type, quality_floor)
if not metrics:
return super().resolve(task_type, estimated_cost_per_1k)
best = min(
metrics,
key=lambda candidate: (
candidate.mean_cost_usd,
0 if candidate.is_static_prefer else 1,
candidate.order,
),
)
return best.adapter
def _qualifying_candidates(
self,
task_type: str,
quality_floor: float,
) -> list[_CandidateMetrics]:
static_prefer = self._static_preferred_adapter(task_type)
candidates: list[_CandidateMetrics] = []
for order, (adapter_id, adapter) in enumerate(self._candidate_entries(task_type)):
observations = self._windowed_observations(task_type, adapter_id)
if len(observations) < self.min_observations:
continue
mean_quality = sum(obs.quality_score for obs in observations) / len(observations)
if mean_quality < quality_floor:
continue
mean_cost = sum(obs.cost_usd for obs in observations) / len(observations)
candidates.append(
_CandidateMetrics(
adapter_id=adapter_id,
adapter=adapter,
mean_quality=mean_quality,
mean_cost_usd=mean_cost,
order=order,
is_static_prefer=adapter is static_prefer,
)
)
return candidates
def _windowed_observations(
self,
task_type: str,
adapter_id: str,
) -> list[QualityObservation]:
if self.ledger is None:
return []
since = None
if self.max_age is not None:
since = datetime.now(timezone.utc) - self.max_age
return self.ledger.recent(
limit=self.window_size,
task_type=task_type,
adapter_id=adapter_id,
since=since,
)
def _candidate_entries(self, task_type: str) -> list[tuple[str, LLMAdapter]]:
entries: list[tuple[str, LLMAdapter]] = []
seen_ids: set[str] = set()
def add(adapter_id: str | None, adapter: LLMAdapter | None) -> None:
if adapter is None or adapter_id is None or adapter_id in seen_ids:
return
seen_ids.add(adapter_id)
entries.append((adapter_id, adapter))
for adapter_id, adapter in self.adapters_by_id.items():
add(adapter_id, adapter)
for adapter in self._static_candidate_adapters(task_type):
add(self._adapter_id_for(adapter), adapter)
return entries
def _static_candidate_adapters(self, task_type: str) -> list[LLMAdapter]:
for rule in self.rules:
if rule.task_type == task_type:
candidates = [rule.prefer]
if rule.fallback is not None:
candidates.append(rule.fallback)
if self.default is not None:
candidates.append(self.default)
return candidates
if self.default is not None:
return [self.default]
return []
def _static_preferred_adapter(self, task_type: str) -> LLMAdapter | None:
for rule in self.rules:
if rule.task_type == task_type:
return rule.prefer
return None
def _adapter_id_for(self, adapter: LLMAdapter) -> str | None:
for adapter_id, candidate in self.adapters_by_id.items():
if candidate is adapter:
return adapter_id
for attribute in ("adapter_id", "id", "name"):
value = getattr(adapter, attribute, None)
if isinstance(value, str) and value.strip():
return value
return None

177
llm_connect/shadowing.py Normal file
View File

@@ -0,0 +1,177 @@
"""Shadow-mode observation adapter for adaptive routing."""
from __future__ import annotations
import asyncio
import random
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass, field, replace
from typing import Any, Callable, Mapping
from llm_connect.adapter import LLMAdapter
from llm_connect.grading import BaselineGrader
from llm_connect.models import LLMResponse, RunConfig
from llm_connect.quality import QualityLedger, QualityObservation
def _default_cost_estimator(response: LLMResponse) -> float:
for key in ("cost_usd", "estimated_cost_usd", "cost"):
value = response.metadata.get(key)
if isinstance(value, (int, float)) and value >= 0:
return float(value)
return 0.0
class _StaticResponseAdapter(LLMAdapter):
"""Adapter shim that lets a BaselineGrader reuse an existing response."""
def __init__(self, response: LLMResponse):
self._response = response
def execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
return self._response
def validate_config(self, config: RunConfig) -> bool:
return True
@dataclass
class ShadowingAdapter(LLMAdapter):
"""Return candidate responses while recording sampled baseline grades.
Shadow work is best-effort: baseline, grading, or ledger failures are
reported to ``on_shadow_error`` when provided, but never alter the candidate
response returned to the caller.
"""
candidate_adapter: LLMAdapter
baseline_adapter: LLMAdapter
grader: BaselineGrader
ledger: QualityLedger
task_type: str
adapter_id: str
model_id: str | None = None
baseline_adapter_id: str | None = None
shadow_rate: float = 1.0
async_shadow: bool = False
random_source: random.Random = field(default_factory=random.Random, repr=False)
cost_estimator: Callable[[LLMResponse], float] = _default_cost_estimator
tags: Mapping[str, Any] = field(default_factory=dict)
on_shadow_error: Callable[[Exception], None] | None = None
_executor: ThreadPoolExecutor | None = field(default=None, init=False, repr=False)
_futures: list[Future[None]] = field(default_factory=list, init=False, repr=False)
_lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)
def __post_init__(self) -> None:
if not str(self.task_type).strip():
raise ValueError("task_type must be a non-empty string")
if not str(self.adapter_id).strip():
raise ValueError("adapter_id must be a non-empty string")
if not 0 <= self.shadow_rate <= 1:
raise ValueError("shadow_rate must be between 0 and 1")
if self.async_shadow:
self._executor = ThreadPoolExecutor(max_workers=1)
def execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
response = self.candidate_adapter.execute_prompt(prompt, config)
if self._should_shadow():
self._handle_shadow(prompt, config, response)
return response
async def async_execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
response = await self.candidate_adapter.async_execute_prompt(prompt, config)
if self._should_shadow():
if self.async_shadow:
self._schedule_shadow(prompt, config, response)
else:
await asyncio.to_thread(self._run_shadow, prompt, config, response)
return response
def validate_config(self, config: RunConfig) -> bool:
return self.candidate_adapter.validate_config(config)
def flush(self, timeout: float | None = None) -> None:
"""Wait for currently queued async shadow work to finish."""
with self._lock:
futures = list(self._futures)
self._futures.clear()
for future in futures:
future.result(timeout=timeout)
def shutdown(self, wait: bool = True) -> None:
"""Shut down the background shadow executor if one was created."""
if self._executor is not None:
self._executor.shutdown(wait=wait)
self._executor = None
def _should_shadow(self) -> bool:
if self.shadow_rate <= 0:
return False
if self.shadow_rate >= 1:
return True
with self._lock:
return self.random_source.random() < self.shadow_rate
def _handle_shadow(
self,
prompt: str,
config: RunConfig,
candidate_response: LLMResponse,
) -> None:
if self.async_shadow:
self._schedule_shadow(prompt, config, candidate_response)
else:
self._run_shadow(prompt, config, candidate_response)
def _schedule_shadow(
self,
prompt: str,
config: RunConfig,
candidate_response: LLMResponse,
) -> None:
if self._executor is None:
self._executor = ThreadPoolExecutor(max_workers=1)
future = self._executor.submit(self._run_shadow, prompt, config, candidate_response)
with self._lock:
self._futures = [item for item in self._futures if not item.done()]
self._futures.append(future)
def _run_shadow(
self,
prompt: str,
config: RunConfig,
candidate_response: LLMResponse,
) -> None:
try:
shadow_config = replace(config, budget_tracker=None)
result = self.grader.grade(
self.baseline_adapter,
_StaticResponseAdapter(candidate_response),
prompt,
shadow_config,
)
self.ledger.append(
QualityObservation(
task_type=self.task_type,
adapter_id=self.adapter_id,
model_id=self.model_id or candidate_response.model or config.model_name,
cost_usd=self.cost_estimator(candidate_response),
quality_score=result.quality_score,
latency_ms=float(candidate_response.metadata.get("latency_ms", 0.0)),
tokens_in=int(candidate_response.usage.get("prompt_tokens", 0)),
tokens_out=int(candidate_response.usage.get("completion_tokens", 0)),
baseline_adapter_id=self.baseline_adapter_id,
tags=dict(self.tags),
)
)
except Exception as exc:
self._report_shadow_error(exc)
def _report_shadow_error(self, exc: Exception) -> None:
if self.on_shadow_error is None:
return
try:
self.on_shadow_error(exc)
except Exception:
pass