Files
llm-connect/llm_connect/routing.py
tegwick c4ad4bb9f2
Some checks failed
CI / test (3.10) (push) Has been cancelled
CI / test (3.11) (push) Has been cancelled
CI / test (3.12) (push) Has been cancelled
Add adaptive cost-quality routing primitives
2026-05-17 21:32:27 +02:00

261 lines
8.8 KiB
Python

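"""
RoutingPolicy — task-type-aware adapter selection (FR-2).
Maps task types to preferred adapters with optional cost-cap fallback.
AdaptiveRoutingPolicy extends this with ledger-driven selection of the
cheapest adapter whose observed quality clears a caller-supplied floor.
"""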
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import List, Mapping, Optional
from llm_connect.adapter import LLMAdapter
from llm_connect.quality import QualityLedger, QualityObservation
@dataclass
class RoutingRule:
    """Bind one task type to a preferred adapter, with an optional cost guard.

    A rule is plain data: it carries no behaviour of its own. The cost
    comparison described below is carried out by
    :meth:`RoutingPolicy.resolve` when it finds a rule whose ``task_type``
    matches.

    Attributes:
        task_type: Name of the logical task the rule applies to
            (for instance ``"triage"`` or ``"summarise"``).
        prefer: Adapter handed out when the rule matches.
        max_cost_per_1k: Optional ceiling in USD per 1 000 tokens. If the
            caller passes ``estimated_cost_per_1k`` to
            :meth:`RoutingPolicy.resolve` and that estimate is strictly
            greater than this ceiling, *fallback* is used in place of
            *prefer*.
        fallback: Adapter handed out once the ceiling is exceeded.
    """

    # Required fields come first so positional construction stays
    # ``RoutingRule(task_type, prefer, ...)``.
    task_type: str
    prefer: LLMAdapter

    # Both halves of the cost guard are optional.
    max_cost_per_1k: Optional[float] = None
    fallback: Optional[LLMAdapter] = None
@dataclass
class RoutingPolicy:
    """Pick an LLM adapter for a task type from an ordered rule list.

    The rule list is scanned front to back and the first rule whose
    ``task_type`` equals the requested one decides the outcome. If nothing
    matches, *default* is used; with no *default* either, ``LookupError`` is
    raised.

    Example::

        policy = RoutingPolicy(
            rules=[
                RoutingRule("triage", prefer=fast_adapter, max_cost_per_1k=0.5, fallback=cheap_adapter),
                RoutingRule("analysis", prefer=smart_adapter),
            ],
            default=cheap_adapter,
        )
        adapter = policy.resolve("triage")
    """

    rules: List[RoutingRule] = field(default_factory=list)
    default: Optional[LLMAdapter] = None

    def resolve(
        self,
        task_type: str,
        estimated_cost_per_1k: Optional[float] = None,
    ) -> LLMAdapter:
        """Select the adapter that should serve *task_type*.

        Args:
            task_type: Logical task identifier.
            estimated_cost_per_1k: Optional cost estimate from the caller
                (USD / 1k tokens). It only matters when the matching rule
                defines both ``max_cost_per_1k`` and ``fallback``: an
                estimate strictly above the cap selects the fallback.

        Returns:
            The selected :class:`~llm_connect.adapter.LLMAdapter`.

        Raises:
            LookupError: No rule matches and no *default* is configured.
        """
        # Only the first rule with a matching task type is ever consulted.
        matched = next(
            (candidate for candidate in self.rules if candidate.task_type == task_type),
            None,
        )

        if matched is None:
            if self.default is None:
                raise LookupError(
                    f"No routing rule for task_type={task_type!r} and no default configured"
                )
            return self.default

        cap = matched.max_cost_per_1k
        over_cap = (
            estimated_cost_per_1k is not None
            and cap is not None
            and estimated_cost_per_1k > cap
        )
        # A breached cap without a fallback still yields the preferred adapter.
        if over_cap and matched.fallback is not None:
            return matched.fallback
        return matched.prefer
@dataclass(frozen=True)
class _CandidateMetrics:
    """Immutable summary of one adapter's observed performance.

    Instances are built by ``AdaptiveRoutingPolicy._qualifying_candidates``
    and compared in ``AdaptiveRoutingPolicy.resolve`` to pick the cheapest
    qualifying adapter.
    """

    # Identifier the observations were looked up under, and the adapter itself.
    adapter_id: str
    adapter: LLMAdapter

    # Arithmetic means over the observations considered for this adapter.
    mean_quality: float
    mean_cost_usd: float

    # Tie-break inputs: position in the candidate list, and whether this
    # adapter is the matching static rule's ``prefer``.
    order: int
    is_static_prefer: bool
@dataclass
class AdaptiveRoutingPolicy(RoutingPolicy):
    """Route to the cheapest adapter whose observed quality clears a floor.

    The policy consults a :class:`~llm_connect.quality.QualityLedger` for
    observations matching ``task_type`` and adapter id. When the ledger has no
    qualifying observations, resolution falls through to ``RoutingPolicy`` so a
    caller can use the same policy on day zero and after observations accrue.

    Attributes:
        ledger: Source of quality observations. ``None`` disables adaptive
            selection entirely.
        adapters_by_id: Adapters eligible for adaptive selection, keyed by
            the id used to look up their observations in the ledger.
        window_size: ``limit`` passed to the ledger when fetching
            observations for one adapter.
        min_observations: Fewest observations an adapter needs before its
            means are trusted.
        max_age: When set, only observations newer than ``now - max_age``
            (UTC) are requested from the ledger.
    """

    ledger: Optional[QualityLedger] = None
    adapters_by_id: Mapping[str, LLMAdapter] = field(default_factory=dict)
    window_size: int = 20
    # NOTE(review): if the ledger honours ``limit`` strictly, a value above
    # ``window_size`` means no adapter can ever qualify — confirm with callers.
    min_observations: int = 1
    max_age: Optional[timedelta] = None

    def __post_init__(self) -> None:
        """Reject window parameters that cannot describe a valid window.

        Raises:
            ValueError: ``window_size`` or ``min_observations`` is not
                positive, or ``max_age`` is negative.
        """
        if self.window_size <= 0:
            raise ValueError("window_size must be positive")
        if self.min_observations <= 0:
            raise ValueError("min_observations must be positive")
        if self.max_age is not None and self.max_age.total_seconds() < 0:
            raise ValueError("max_age must be non-negative")

    def resolve(
        self,
        task_type: str,
        estimated_cost_per_1k: Optional[float] = None,
        *,
        quality_floor: Optional[float] = None,
    ) -> LLMAdapter:
        """Return the adaptive adapter for *task_type*.

        Args:
            task_type: Logical task identifier.
            estimated_cost_per_1k: Passed through to static routing fallback.
            quality_floor: Minimum observed mean quality required for adaptive
                selection. When omitted, static routing is used.

        Returns:
            The selected :class:`~llm_connect.adapter.LLMAdapter`. Among the
            qualifying candidates the lowest mean cost wins; ties go to the
            matching rule's ``prefer`` adapter, then to the earliest
            candidate.

        Raises:
            ValueError: *quality_floor* is given but lies outside ``[0, 1]``.
                This is checked whether or not a ledger is attached.
            LookupError: Static routing was used and found neither a matching
                rule nor a default.
        """
        if quality_floor is None:
            return super().resolve(task_type, estimated_cost_per_1k)

        # Validate before looking at the ledger: otherwise a bad floor passes
        # silently while no ledger is attached and starts raising only once
        # one is configured.
        if not 0 <= quality_floor <= 1:
            raise ValueError("quality_floor must be between 0 and 1")

        if self.ledger is None:
            return super().resolve(task_type, estimated_cost_per_1k)

        metrics = self._qualifying_candidates(task_type, quality_floor)
        if not metrics:
            return super().resolve(task_type, estimated_cost_per_1k)

        best = min(
            metrics,
            key=lambda candidate: (
                candidate.mean_cost_usd,
                0 if candidate.is_static_prefer else 1,
                candidate.order,
            ),
        )
        return best.adapter

    def _qualifying_candidates(
        self,
        task_type: str,
        quality_floor: float,
    ) -> list[_CandidateMetrics]:
        """Summarise every candidate that has enough data and meets the floor.

        Candidates with fewer than ``min_observations`` observations, or with
        a mean quality score below *quality_floor*, are left out.
        """
        static_prefer = self._static_preferred_adapter(task_type)
        candidates: list[_CandidateMetrics] = []
        for order, (adapter_id, adapter) in enumerate(self._candidate_entries(task_type)):
            observations = self._windowed_observations(task_type, adapter_id)
            if len(observations) < self.min_observations:
                continue
            mean_quality = sum(obs.quality_score for obs in observations) / len(observations)
            if mean_quality < quality_floor:
                continue
            mean_cost = sum(obs.cost_usd for obs in observations) / len(observations)
            candidates.append(
                _CandidateMetrics(
                    adapter_id=adapter_id,
                    adapter=adapter,
                    mean_quality=mean_quality,
                    mean_cost_usd=mean_cost,
                    order=order,
                    # Identity, not equality: the same object as the rule's prefer.
                    is_static_prefer=adapter is static_prefer,
                )
            )
        return candidates

    def _windowed_observations(
        self,
        task_type: str,
        adapter_id: str,
    ) -> list[QualityObservation]:
        """Fetch the ledger observations for one adapter and task type.

        Returns an empty list when no ledger is attached. Otherwise the
        result is whatever ``ledger.recent`` returns for ``window_size`` as
        the limit and, when ``max_age`` is set, a UTC cutoff of
        ``now - max_age``.
        """
        if self.ledger is None:
            return []
        since = None
        if self.max_age is not None:
            since = datetime.now(timezone.utc) - self.max_age
        return self.ledger.recent(
            limit=self.window_size,
            task_type=task_type,
            adapter_id=adapter_id,
            since=since,
        )

    def _candidate_entries(self, task_type: str) -> list[tuple[str, LLMAdapter]]:
        """List ``(adapter_id, adapter)`` pairs to evaluate, without duplicate ids.

        Entries from ``adapters_by_id`` come first, followed by the adapters
        named by static routing for *task_type*. Adapters for which no id can
        be determined are skipped.
        """
        entries: list[tuple[str, LLMAdapter]] = []
        seen_ids: set[str] = set()

        def add(adapter_id: str | None, adapter: LLMAdapter | None) -> None:
            # The first occurrence of an id wins; later duplicates are dropped.
            if adapter is None or adapter_id is None or adapter_id in seen_ids:
                return
            seen_ids.add(adapter_id)
            entries.append((adapter_id, adapter))

        for adapter_id, adapter in self.adapters_by_id.items():
            add(adapter_id, adapter)
        for adapter in self._static_candidate_adapters(task_type):
            add(self._adapter_id_for(adapter), adapter)
        return entries

    def _static_candidate_adapters(self, task_type: str) -> list[LLMAdapter]:
        """Return the adapters static routing could pick for *task_type*.

        For the first matching rule this is its ``prefer``, then its
        ``fallback`` and the policy ``default`` when those are set. Without a
        matching rule it is just the ``default``, or nothing.
        """
        for rule in self.rules:
            if rule.task_type == task_type:
                candidates = [rule.prefer]
                if rule.fallback is not None:
                    candidates.append(rule.fallback)
                if self.default is not None:
                    candidates.append(self.default)
                return candidates
        if self.default is not None:
            return [self.default]
        return []

    def _static_preferred_adapter(self, task_type: str) -> LLMAdapter | None:
        """Return ``prefer`` of the first rule matching *task_type*, if any."""
        for rule in self.rules:
            if rule.task_type == task_type:
                return rule.prefer
        return None

    def _adapter_id_for(self, adapter: LLMAdapter) -> str | None:
        """Work out the ledger id for *adapter*, or ``None`` if there is none.

        An identity match in ``adapters_by_id`` takes precedence. Otherwise
        the first non-blank string found among the adapter's ``adapter_id``,
        ``id`` and ``name`` attributes is used, returned unstripped.
        """
        for adapter_id, candidate in self.adapters_by_id.items():
            if candidate is adapter:
                return adapter_id
        for attribute in ("adapter_id", "id", "name"):
            value = getattr(adapter, attribute, None)
            if isinstance(value, str) and value.strip():
                return value
        return None