Files
llm-connect/llm_connect/shadowing.py
tegwick c4ad4bb9f2
Some checks failed
CI / test (3.10) (push) Has been cancelled
CI / test (3.11) (push) Has been cancelled
CI / test (3.12) (push) Has been cancelled
Add adaptive cost-quality routing primitives
2026-05-17 21:32:27 +02:00

178 lines
6.6 KiB
Python

"""Shadow-mode observation adapter for adaptive routing."""
from __future__ import annotations

import asyncio
import concurrent.futures
import logging
import random
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass, field, replace
from typing import Any, Callable, Mapping

from llm_connect.adapter import LLMAdapter
from llm_connect.grading import BaselineGrader
from llm_connect.models import LLMResponse, RunConfig
from llm_connect.quality import QualityLedger, QualityObservation
def _default_cost_estimator(response: LLMResponse) -> float:
for key in ("cost_usd", "estimated_cost_usd", "cost"):
value = response.metadata.get(key)
if isinstance(value, (int, float)) and value >= 0:
return float(value)
return 0.0
class _StaticResponseAdapter(LLMAdapter):
    """Replay one pre-computed response through the adapter interface.

    ``BaselineGrader.grade`` is handed adapters rather than responses; this
    shim wraps a response that already exists so it can be graded without
    issuing another model call.
    """

    def __init__(self, response: LLMResponse):
        # The exact object handed back on every execute_prompt call.
        self._response = response

    def execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
        """Ignore ``prompt`` and ``config`` and return the stored response."""
        stored = self._response
        return stored

    def validate_config(self, config: RunConfig) -> bool:
        """Accept any configuration; the shim never sends a request."""
        accepted = True
        return accepted
@dataclass
class ShadowingAdapter(LLMAdapter):
    """Return candidate responses while recording sampled baseline grades.

    Every call is answered by ``candidate_adapter``. For a sampled fraction of
    calls (``shadow_rate``) the candidate response is graded against
    ``baseline_adapter`` with ``grader`` and the result is appended to
    ``ledger`` as a ``QualityObservation``.

    Shadow work is best-effort: baseline, grading, scheduling, or ledger
    failures are reported to ``on_shadow_error`` when provided (and logged at
    debug level otherwise), but never alter the candidate response returned to
    the caller.

    With ``async_shadow=True`` shadow work is queued on a single background
    worker thread; use :meth:`flush` to wait for it and :meth:`shutdown` to
    release the worker. Otherwise it runs before the response is returned.
    """

    candidate_adapter: LLMAdapter
    baseline_adapter: LLMAdapter
    grader: BaselineGrader
    ledger: QualityLedger
    task_type: str
    adapter_id: str
    # Recorded model id; falls back to the response's model, then config.model_name.
    model_id: str | None = None
    # Copied onto each observation as ``baseline_adapter_id``.
    baseline_adapter_id: str | None = None
    # Fraction of calls, in [0, 1], that are shadowed.
    shadow_rate: float = 1.0
    # Queue shadow work on a background thread instead of running it inline.
    async_shadow: bool = False
    # Source of the sampling decisions made in ``_should_shadow``.
    random_source: random.Random = field(default_factory=random.Random, repr=False)
    # Produces the ``cost_usd`` recorded for the candidate response.
    cost_estimator: Callable[[LLMResponse], float] = _default_cost_estimator
    # Copied into every observation's ``tags``.
    tags: Mapping[str, Any] = field(default_factory=dict)
    # Receives shadow failures; exceptions it raises are logged and swallowed.
    on_shadow_error: Callable[[Exception], None] | None = None
    _executor: ThreadPoolExecutor | None = field(default=None, init=False, repr=False)
    # Futures of queued background shadow runs, awaited by ``flush``.
    _futures: list[Future[None]] = field(default_factory=list, init=False, repr=False)
    # Guards ``random_source``, ``_executor`` and ``_futures``.
    _lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False)

    def __post_init__(self) -> None:
        """Validate identifiers and the sampling rate; start the worker if needed.

        Raises:
            ValueError: If ``task_type`` or ``adapter_id`` is blank, or
                ``shadow_rate`` is outside ``[0, 1]`` (NaN included).
        """
        if not str(self.task_type).strip():
            raise ValueError("task_type must be a non-empty string")
        if not str(self.adapter_id).strip():
            raise ValueError("adapter_id must be a non-empty string")
        if not 0 <= self.shadow_rate <= 1:
            raise ValueError("shadow_rate must be between 0 and 1")
        if self.async_shadow:
            self._executor = self._new_executor()

    def execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
        """Run ``prompt`` on the candidate adapter, shadowing it if sampled."""
        response = self.candidate_adapter.execute_prompt(prompt, config)
        if self._should_shadow():
            self._handle_shadow(prompt, config, response)
        return response

    async def async_execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
        """Async variant of :meth:`execute_prompt`.

        Inline (non-``async_shadow``) shadow work is moved to a thread with
        ``asyncio.to_thread`` so the event loop is not blocked while grading.
        """
        response = await self.candidate_adapter.async_execute_prompt(prompt, config)
        if self._should_shadow():
            if self.async_shadow:
                self._schedule_shadow(prompt, config, response)
            else:
                await asyncio.to_thread(self._run_shadow, prompt, config, response)
        return response

    def validate_config(self, config: RunConfig) -> bool:
        """Delegate validation to the candidate adapter."""
        return self.candidate_adapter.validate_config(config)

    def flush(self, timeout: float | None = None) -> None:
        """Wait for currently queued async shadow work to finish.

        Args:
            timeout: Overall number of seconds to wait for all queued work,
                or ``None`` to wait without limit.

        Raises:
            concurrent.futures.TimeoutError: If some work is still running
                when ``timeout`` expires. Unfinished work stays queued, so a
                later ``flush`` call waits for it again.
        """
        with self._lock:
            futures = list(self._futures)
            self._futures.clear()
        if not futures:
            return
        done, pending = concurrent.futures.wait(futures, timeout=timeout)
        if pending:
            with self._lock:
                # Re-track unfinished work (in submission order) so it is not
                # silently forgotten after a timeout.
                self._futures.extend(item for item in futures if item in pending)
        for future in futures:
            if future in done:
                # Surfaces anything stored on the future (e.g. cancellation);
                # _run_shadow itself never lets an Exception escape.
                future.result()
        if pending:
            raise concurrent.futures.TimeoutError(
                f"{len(pending)} shadow task(s) still running after {timeout} seconds"
            )

    def shutdown(self, wait: bool = True) -> None:
        """Shut down the background shadow executor if one was created.

        A later async shadow call lazily creates a fresh executor.
        """
        with self._lock:
            executor, self._executor = self._executor, None
        # Shut down outside the lock so waiting here never blocks other callers.
        if executor is not None:
            executor.shutdown(wait=wait)

    def _new_executor(self) -> ThreadPoolExecutor:
        """Create the single-worker executor used for async shadow work."""
        # One worker runs shadow jobs one at a time, in submission order.
        return ThreadPoolExecutor(max_workers=1, thread_name_prefix="shadowing-adapter")

    def _should_shadow(self) -> bool:
        """Decide whether the current call is shadowed."""
        if self.shadow_rate <= 0:
            return False
        if self.shadow_rate >= 1:
            return True
        # random.Random instances are shared state; serialize access.
        with self._lock:
            return self.random_source.random() < self.shadow_rate

    def _handle_shadow(
        self,
        prompt: str,
        config: RunConfig,
        candidate_response: LLMResponse,
    ) -> None:
        """Shadow a sync call either in the background or inline."""
        if self.async_shadow:
            self._schedule_shadow(prompt, config, candidate_response)
        else:
            self._run_shadow(prompt, config, candidate_response)

    def _schedule_shadow(
        self,
        prompt: str,
        config: RunConfig,
        candidate_response: LLMResponse,
    ) -> None:
        """Queue shadow work on the background executor.

        Scheduling failures are reported like any other shadow failure and
        never propagate to the caller.
        """
        try:
            with self._lock:
                if self._executor is None:
                    # Created lazily, e.g. after shutdown(); done under the
                    # lock so concurrent callers cannot each create (and leak)
                    # their own executor.
                    self._executor = self._new_executor()
                future = self._executor.submit(
                    self._run_shadow, prompt, config, candidate_response
                )
                # Drop finished futures so the list does not grow without bound.
                self._futures = [item for item in self._futures if not item.done()]
                self._futures.append(future)
        except Exception as exc:
            # submit() raises RuntimeError once the executor or interpreter is
            # shutting down; that must not fail the candidate call.
            self._report_shadow_error(exc)

    def _run_shadow(
        self,
        prompt: str,
        config: RunConfig,
        candidate_response: LLMResponse,
    ) -> None:
        """Grade ``candidate_response`` against the baseline and record it.

        Any ``Exception`` is routed to :meth:`_report_shadow_error` instead of
        propagating, so this is safe to run inline or on a worker thread.
        """
        try:
            # NOTE(review): presumably detaches the caller's budget tracker so
            # shadow calls are not charged against it — confirm.
            shadow_config = replace(config, budget_tracker=None)
            result = self.grader.grade(
                self.baseline_adapter,
                # Reuse the response the caller already received instead of
                # calling the candidate model a second time.
                _StaticResponseAdapter(candidate_response),
                prompt,
                shadow_config,
            )
            self.ledger.append(
                QualityObservation(
                    task_type=self.task_type,
                    adapter_id=self.adapter_id,
                    model_id=self.model_id or candidate_response.model or config.model_name,
                    cost_usd=self.cost_estimator(candidate_response),
                    quality_score=result.quality_score,
                    latency_ms=float(candidate_response.metadata.get("latency_ms", 0.0)),
                    tokens_in=int(candidate_response.usage.get("prompt_tokens", 0)),
                    tokens_out=int(candidate_response.usage.get("completion_tokens", 0)),
                    baseline_adapter_id=self.baseline_adapter_id,
                    tags=dict(self.tags),
                )
            )
        except Exception as exc:
            self._report_shadow_error(exc)

    def _report_shadow_error(self, exc: Exception) -> None:
        """Forward a shadow failure to ``on_shadow_error``; never raises."""
        logger = logging.getLogger(__name__)
        if self.on_shadow_error is None:
            # No handler configured: keep the failure visible in debug logs
            # instead of dropping it silently.
            logger.debug("Shadow evaluation failed", exc_info=exc)
            return
        try:
            self.on_shadow_error(exc)
        except Exception:
            # The callback is user code; its failure must not reach the caller.
            logger.warning("on_shadow_error callback raised", exc_info=True)