Add adaptive cost-quality routing primitives
Some checks failed
CI / test (3.10) (push) Has been cancelled
CI / test (3.11) (push) Has been cancelled
CI / test (3.12) (push) Has been cancelled

This commit is contained in:
2026-05-17 21:32:27 +02:00
parent bf86a03c5d
commit c4ad4bb9f2
17 changed files with 2480 additions and 25 deletions

View File

@@ -5,9 +5,11 @@ Maps task types to preferred adapters with optional cost-cap fallback.
"""
from dataclasses import dataclass, field
from typing import Optional, List
from datetime import datetime, timedelta, timezone
from typing import List, Mapping, Optional
from llm_connect.adapter import LLMAdapter
from llm_connect.quality import QualityLedger, QualityObservation
@dataclass
@@ -87,3 +89,172 @@ class RoutingPolicy:
raise LookupError(
f"No routing rule for task_type={task_type!r} and no default configured"
)
@dataclass(frozen=True)
class _CandidateMetrics:
    """Immutable summary of one adapter's windowed ledger observations.

    Instances are built while scanning routing candidates and are compared
    by mean cost first, then static preference, then discovery order.
    """

    # Identifier used when querying the quality ledger for this adapter.
    adapter_id: str
    # Adapter instance handed back to the caller if this candidate wins.
    adapter: LLMAdapter
    # Arithmetic mean of ``quality_score`` over the windowed observations.
    mean_quality: float
    # Arithmetic mean of ``cost_usd`` over the same observations.
    mean_cost_usd: float
    # Position of the adapter in the candidate enumeration (final tie-breaker).
    order: int
    # True when the adapter is the static rule's ``prefer`` for the task type.
    is_static_prefer: bool
@dataclass
class AdaptiveRoutingPolicy(RoutingPolicy):
    """Route to the cheapest adapter whose observed quality clears a floor.

    The policy consults a :class:`~llm_connect.quality.QualityLedger` for
    observations matching ``task_type`` and adapter id. When the ledger has no
    qualifying observations, resolution falls through to ``RoutingPolicy`` so a
    caller can use the same policy on day zero and after observations accrue.

    Attributes:
        ledger: Source of quality observations. ``None`` disables adaptive
            selection, so every call uses static routing.
        adapters_by_id: Explicitly registered adapters keyed by the id used to
            query the ledger. They are enumerated before adapters discovered
            from the static rules.
        window_size: ``limit`` passed to the ledger for each adapter.
        min_observations: Number of observations an adapter needs before it
            is eligible for adaptive selection.
        max_age: When set, the ledger is queried with a ``since`` cutoff of
            ``now (UTC) - max_age``.
    """

    ledger: Optional[QualityLedger] = None
    adapters_by_id: Mapping[str, LLMAdapter] = field(default_factory=dict)
    window_size: int = 20
    min_observations: int = 1
    max_age: Optional[timedelta] = None

    def __post_init__(self) -> None:
        """Validate the window configuration.

        Raises:
            ValueError: If ``window_size`` or ``min_observations`` is not
                positive, or ``max_age`` is negative.
        """
        # NOTE(review): this does not chain to a parent ``__post_init__``;
        # confirm ``RoutingPolicy`` does not define one.
        if self.window_size <= 0:
            raise ValueError("window_size must be positive")
        if self.min_observations <= 0:
            raise ValueError("min_observations must be positive")
        if self.max_age is not None and self.max_age.total_seconds() < 0:
            raise ValueError("max_age must be non-negative")

    def resolve(
        self,
        task_type: str,
        estimated_cost_per_1k: Optional[float] = None,
        *,
        quality_floor: Optional[float] = None,
    ) -> LLMAdapter:
        """Return the adaptive adapter for *task_type*.

        Args:
            task_type: Logical task identifier.
            estimated_cost_per_1k: Passed through to static routing fallback.
            quality_floor: Minimum observed mean quality required for adaptive
                selection. When omitted, static routing is used.

        Returns:
            The selected :class:`~llm_connect.adapter.LLMAdapter`.

        Raises:
            ValueError: If ``quality_floor`` is supplied and lies outside
                ``[0, 1]``, whether or not a ledger is configured.
            LookupError: May propagate from the static routing fallback.
        """
        if quality_floor is None:
            return super().resolve(task_type, estimated_cost_per_1k)

        # Validate before looking at the ledger so a bad floor is reported on
        # day zero rather than only once a ledger has been attached.
        if not 0 <= quality_floor <= 1:
            raise ValueError("quality_floor must be between 0 and 1")

        if self.ledger is None:
            return super().resolve(task_type, estimated_cost_per_1k)

        metrics = self._qualifying_candidates(task_type, quality_floor)
        if not metrics:
            return super().resolve(task_type, estimated_cost_per_1k)

        # Cheapest first; ties go to the static rule's preferred adapter, then
        # to whichever candidate was enumerated earliest.
        best = min(
            metrics,
            key=lambda candidate: (
                candidate.mean_cost_usd,
                0 if candidate.is_static_prefer else 1,
                candidate.order,
            ),
        )
        return best.adapter

    def _qualifying_candidates(
        self,
        task_type: str,
        quality_floor: float,
    ) -> list[_CandidateMetrics]:
        """Collect metrics for adapters whose mean quality meets the floor.

        Adapters with fewer than ``min_observations`` windowed observations,
        or with a mean quality below ``quality_floor``, are skipped.
        """
        static_prefer = self._static_preferred_adapter(task_type)
        candidates: list[_CandidateMetrics] = []
        for order, (adapter_id, adapter) in enumerate(
            self._candidate_entries(task_type)
        ):
            observations = self._windowed_observations(task_type, adapter_id)
            count = len(observations)
            # The explicit zero check keeps the means below well defined even
            # if ``min_observations`` was lowered after construction.
            if count == 0 or count < self.min_observations:
                continue
            mean_quality = sum(obs.quality_score for obs in observations) / count
            if mean_quality < quality_floor:
                continue
            mean_cost = sum(obs.cost_usd for obs in observations) / count
            candidates.append(
                _CandidateMetrics(
                    adapter_id=adapter_id,
                    adapter=adapter,
                    mean_quality=mean_quality,
                    mean_cost_usd=mean_cost,
                    order=order,
                    is_static_prefer=adapter is static_prefer,
                )
            )
        return candidates

    def _windowed_observations(
        self,
        task_type: str,
        adapter_id: str,
    ) -> list[QualityObservation]:
        """Fetch up to ``window_size`` ledger observations for one adapter.

        Returns an empty list when no ledger is configured.
        """
        if self.ledger is None:
            return []
        since = None
        if self.max_age is not None:
            # Timezone-aware cutoff; presumably ledger timestamps are UTC-aware
            # as well -- confirm against ``QualityLedger.recent``.
            since = datetime.now(timezone.utc) - self.max_age
        return self.ledger.recent(
            limit=self.window_size,
            task_type=task_type,
            adapter_id=adapter_id,
            since=since,
        )

    def _candidate_entries(self, task_type: str) -> list[tuple[str, LLMAdapter]]:
        """List ``(adapter_id, adapter)`` pairs eligible for *task_type*.

        Explicitly registered adapters come first, followed by adapters named
        by the static rules. Only the first entry per adapter id is kept.
        """
        entries: list[tuple[str, LLMAdapter]] = []
        seen_ids: set[str] = set()

        def add(adapter_id: str | None, adapter: LLMAdapter | None) -> None:
            # Skip adapters without a usable id and ids already recorded.
            if adapter is None or adapter_id is None or adapter_id in seen_ids:
                return
            seen_ids.add(adapter_id)
            entries.append((adapter_id, adapter))

        for adapter_id, adapter in self.adapters_by_id.items():
            add(adapter_id, adapter)
        for adapter in self._static_candidate_adapters(task_type):
            add(self._adapter_id_for(adapter), adapter)
        return entries

    def _static_candidate_adapters(self, task_type: str) -> list[LLMAdapter]:
        """Return adapters the static configuration names for *task_type*.

        For the first matching rule this is its ``prefer`` adapter, then its
        ``fallback`` and the policy ``default`` when present. Without a
        matching rule only the ``default`` (if any) is returned.
        """
        for rule in self.rules:
            if rule.task_type == task_type:
                candidates = [rule.prefer]
                if rule.fallback is not None:
                    candidates.append(rule.fallback)
                if self.default is not None:
                    candidates.append(self.default)
                return candidates
        if self.default is not None:
            return [self.default]
        return []

    def _static_preferred_adapter(self, task_type: str) -> LLMAdapter | None:
        """Return ``prefer`` from the first rule matching *task_type*, if any."""
        for rule in self.rules:
            if rule.task_type == task_type:
                return rule.prefer
        return None

    def _adapter_id_for(self, adapter: LLMAdapter) -> str | None:
        """Derive the ledger id for *adapter*.

        An identity match in ``adapters_by_id`` wins. Otherwise the first
        non-blank string among the adapter's ``adapter_id``, ``id`` and
        ``name`` attributes is returned unchanged. ``None`` means no id could
        be determined.
        """
        for adapter_id, candidate in self.adapters_by_id.items():
            if candidate is adapter:
                return adapter_id
        for attribute in ("adapter_id", "id", "name"):
            value = getattr(adapter, attribute, None)
            if isinstance(value, str) and value.strip():
                return value
        return None