generated from coulomb/repo-seed
Add adaptive cost-quality routing primitives
This commit is contained in:
@@ -5,9 +5,11 @@ Maps task types to preferred adapters with optional cost-cap fallback.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, List
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Mapping, Optional
|
||||
|
||||
from llm_connect.adapter import LLMAdapter
|
||||
from llm_connect.quality import QualityLedger, QualityObservation
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -87,3 +89,172 @@ class RoutingPolicy:
|
||||
raise LookupError(
|
||||
f"No routing rule for task_type={task_type!r} and no default configured"
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _CandidateMetrics:
|
||||
adapter_id: str
|
||||
adapter: LLMAdapter
|
||||
mean_quality: float
|
||||
mean_cost_usd: float
|
||||
order: int
|
||||
is_static_prefer: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class AdaptiveRoutingPolicy(RoutingPolicy):
|
||||
"""Route to the cheapest adapter whose observed quality clears a floor.
|
||||
|
||||
The policy consults a :class:`~llm_connect.quality.QualityLedger` for
|
||||
observations matching ``task_type`` and adapter id. When the ledger has no
|
||||
qualifying observations, resolution falls through to ``RoutingPolicy`` so a
|
||||
caller can use the same policy on day zero and after observations accrue.
|
||||
"""
|
||||
|
||||
ledger: Optional[QualityLedger] = None
|
||||
adapters_by_id: Mapping[str, LLMAdapter] = field(default_factory=dict)
|
||||
window_size: int = 20
|
||||
min_observations: int = 1
|
||||
max_age: Optional[timedelta] = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.window_size <= 0:
|
||||
raise ValueError("window_size must be positive")
|
||||
if self.min_observations <= 0:
|
||||
raise ValueError("min_observations must be positive")
|
||||
if self.max_age is not None and self.max_age.total_seconds() < 0:
|
||||
raise ValueError("max_age must be non-negative")
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
task_type: str,
|
||||
estimated_cost_per_1k: Optional[float] = None,
|
||||
*,
|
||||
quality_floor: Optional[float] = None,
|
||||
) -> LLMAdapter:
|
||||
"""Return the adaptive adapter for *task_type*.
|
||||
|
||||
Args:
|
||||
task_type: Logical task identifier.
|
||||
estimated_cost_per_1k: Passed through to static routing fallback.
|
||||
quality_floor: Minimum observed mean quality required for adaptive
|
||||
selection. When omitted, static routing is used.
|
||||
|
||||
Returns:
|
||||
The selected :class:`~llm_connect.adapter.LLMAdapter`.
|
||||
"""
|
||||
if quality_floor is None or self.ledger is None:
|
||||
return super().resolve(task_type, estimated_cost_per_1k)
|
||||
if not 0 <= quality_floor <= 1:
|
||||
raise ValueError("quality_floor must be between 0 and 1")
|
||||
|
||||
metrics = self._qualifying_candidates(task_type, quality_floor)
|
||||
if not metrics:
|
||||
return super().resolve(task_type, estimated_cost_per_1k)
|
||||
|
||||
best = min(
|
||||
metrics,
|
||||
key=lambda candidate: (
|
||||
candidate.mean_cost_usd,
|
||||
0 if candidate.is_static_prefer else 1,
|
||||
candidate.order,
|
||||
),
|
||||
)
|
||||
return best.adapter
|
||||
|
||||
def _qualifying_candidates(
|
||||
self,
|
||||
task_type: str,
|
||||
quality_floor: float,
|
||||
) -> list[_CandidateMetrics]:
|
||||
static_prefer = self._static_preferred_adapter(task_type)
|
||||
candidates: list[_CandidateMetrics] = []
|
||||
for order, (adapter_id, adapter) in enumerate(self._candidate_entries(task_type)):
|
||||
observations = self._windowed_observations(task_type, adapter_id)
|
||||
if len(observations) < self.min_observations:
|
||||
continue
|
||||
|
||||
mean_quality = sum(obs.quality_score for obs in observations) / len(observations)
|
||||
if mean_quality < quality_floor:
|
||||
continue
|
||||
|
||||
mean_cost = sum(obs.cost_usd for obs in observations) / len(observations)
|
||||
candidates.append(
|
||||
_CandidateMetrics(
|
||||
adapter_id=adapter_id,
|
||||
adapter=adapter,
|
||||
mean_quality=mean_quality,
|
||||
mean_cost_usd=mean_cost,
|
||||
order=order,
|
||||
is_static_prefer=adapter is static_prefer,
|
||||
)
|
||||
)
|
||||
return candidates
|
||||
|
||||
def _windowed_observations(
|
||||
self,
|
||||
task_type: str,
|
||||
adapter_id: str,
|
||||
) -> list[QualityObservation]:
|
||||
if self.ledger is None:
|
||||
return []
|
||||
|
||||
since = None
|
||||
if self.max_age is not None:
|
||||
since = datetime.now(timezone.utc) - self.max_age
|
||||
|
||||
return self.ledger.recent(
|
||||
limit=self.window_size,
|
||||
task_type=task_type,
|
||||
adapter_id=adapter_id,
|
||||
since=since,
|
||||
)
|
||||
|
||||
def _candidate_entries(self, task_type: str) -> list[tuple[str, LLMAdapter]]:
|
||||
entries: list[tuple[str, LLMAdapter]] = []
|
||||
seen_ids: set[str] = set()
|
||||
|
||||
def add(adapter_id: str | None, adapter: LLMAdapter | None) -> None:
|
||||
if adapter is None or adapter_id is None or adapter_id in seen_ids:
|
||||
return
|
||||
seen_ids.add(adapter_id)
|
||||
entries.append((adapter_id, adapter))
|
||||
|
||||
for adapter_id, adapter in self.adapters_by_id.items():
|
||||
add(adapter_id, adapter)
|
||||
|
||||
for adapter in self._static_candidate_adapters(task_type):
|
||||
add(self._adapter_id_for(adapter), adapter)
|
||||
|
||||
return entries
|
||||
|
||||
def _static_candidate_adapters(self, task_type: str) -> list[LLMAdapter]:
|
||||
for rule in self.rules:
|
||||
if rule.task_type == task_type:
|
||||
candidates = [rule.prefer]
|
||||
if rule.fallback is not None:
|
||||
candidates.append(rule.fallback)
|
||||
if self.default is not None:
|
||||
candidates.append(self.default)
|
||||
return candidates
|
||||
|
||||
if self.default is not None:
|
||||
return [self.default]
|
||||
return []
|
||||
|
||||
def _static_preferred_adapter(self, task_type: str) -> LLMAdapter | None:
|
||||
for rule in self.rules:
|
||||
if rule.task_type == task_type:
|
||||
return rule.prefer
|
||||
return None
|
||||
|
||||
def _adapter_id_for(self, adapter: LLMAdapter) -> str | None:
|
||||
for adapter_id, candidate in self.adapters_by_id.items():
|
||||
if candidate is adapter:
|
||||
return adapter_id
|
||||
|
||||
for attribute in ("adapter_id", "id", "name"):
|
||||
value = getattr(adapter, attribute, None)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user