generated from coulomb/repo-seed
261 lines
8.8 KiB
Python
261 lines
8.8 KiB
Python
"""
|
|
RoutingPolicy — task-type-aware adapter selection (FR-2).
|
|
|
|
Maps task types to preferred adapters with optional cost-cap fallback.
|
|
"""
|
|
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import List, Mapping, Optional
|
|
|
|
from llm_connect.adapter import LLMAdapter
|
|
from llm_connect.quality import QualityLedger, QualityObservation
|
|
|
|
|
|
@dataclass
class RoutingRule:
    """Bind one logical task type to the adapter that should serve it.

    A rule always names a preferred adapter. It may also carry a cost ceiling
    (USD per 1 000 tokens) together with a fallback adapter. The ceiling is
    only consulted by :meth:`RoutingPolicy.resolve` when the caller passes
    ``estimated_cost_per_1k``.

    Attributes:
        task_type: Task identifier this rule matches, e.g. ``"triage"`` or
            ``"summarise"``.
        prefer: Adapter returned when the rule matches and the cost ceiling
            is not breached.
        max_cost_per_1k: Optional ceiling in USD per 1 000 tokens. If the
            caller's estimate is strictly greater than this value and a
            *fallback* is configured, *fallback* is returned instead of
            *prefer*.
        fallback: Adapter returned when the ceiling is breached. While this
            is ``None`` the ceiling has no effect and *prefer* is always used.
    """

    task_type: str
    prefer: LLMAdapter
    max_cost_per_1k: float | None = None
    fallback: LLMAdapter | None = None
|
|
|
|
|
|
@dataclass
class RoutingPolicy:
    """Pick an LLM adapter for a task type from an ordered list of rules.

    Rules are scanned front to back, and the first one whose ``task_type``
    equals the requested task decides the result. When nothing matches,
    *default* is used. When *default* is unset as well, ``LookupError`` is
    raised.

    Example::

        policy = RoutingPolicy(
            rules=[
                RoutingRule("triage", prefer=fast_adapter, max_cost_per_1k=0.5, fallback=cheap_adapter),
                RoutingRule("analysis", prefer=smart_adapter),
            ],
            default=cheap_adapter,
        )
        adapter = policy.resolve("triage")
    """

    rules: List[RoutingRule] = field(default_factory=list)
    default: LLMAdapter | None = None

    def resolve(
        self,
        task_type: str,
        estimated_cost_per_1k: Optional[float] = None,
    ) -> LLMAdapter:
        """Select the adapter that should handle *task_type*.

        Args:
            task_type: Logical task identifier to look up.
            estimated_cost_per_1k: Optional caller estimate in USD per 1k
                tokens. It only matters when the matching rule sets both
                ``max_cost_per_1k`` and ``fallback``. In that case, an estimate
                strictly above the cap selects the rule's ``fallback``.

        Returns:
            The chosen :class:`~llm_connect.adapter.LLMAdapter`.

        Raises:
            LookupError: No rule matches *task_type* and no *default* is set.
        """
        # First rule for this task type wins; later duplicates are ignored.
        matched = next(
            (candidate for candidate in self.rules if candidate.task_type == task_type),
            None,
        )

        if matched is None:
            # Nothing specific to this task: use the policy-wide default.
            if self.default is None:
                raise LookupError(
                    f"No routing rule for task_type={task_type!r} and no default configured"
                )
            return self.default

        # Switch to the fallback only when the estimate and the cap are both
        # present, the estimate is strictly above the cap, and a fallback exists.
        over_cap = (
            estimated_cost_per_1k is not None
            and matched.max_cost_per_1k is not None
            and estimated_cost_per_1k > matched.max_cost_per_1k
            and matched.fallback is not None
        )
        return matched.fallback if over_cap else matched.prefer
|
|
|
|
|
|
@dataclass(frozen=True)
class _CandidateMetrics:
    """Immutable summary of one adapter's windowed record for a task type.

    Instances are built by ``AdaptiveRoutingPolicy._qualifying_candidates``
    and ranked by ``AdaptiveRoutingPolicy.resolve``.
    """

    # Id under which this adapter's observations were looked up in the ledger.
    adapter_id: str
    # Adapter object handed back to the caller if this candidate is chosen.
    adapter: LLMAdapter
    # Arithmetic mean of ``quality_score`` over the windowed observations.
    mean_quality: float
    # Arithmetic mean of ``cost_usd`` over the same observations.
    mean_cost_usd: float
    # Position in the candidate enumeration; used as the final tie-breaker.
    order: int
    # True when this adapter is the matching static rule's ``prefer`` adapter.
    is_static_prefer: bool
|
|
|
|
|
|
@dataclass
class AdaptiveRoutingPolicy(RoutingPolicy):
    """Route to the cheapest adapter whose observed quality clears a floor.

    The policy consults a :class:`~llm_connect.quality.QualityLedger` for
    observations matching ``task_type`` and adapter id. When the ledger has no
    qualifying observations, resolution falls through to ``RoutingPolicy`` so a
    caller can use the same policy on day zero and after observations accrue.

    Attributes:
        ledger: Source of quality observations. While ``None``, every call is
            routed statically.
        adapters_by_id: Extra adapters, keyed by the id used in the ledger.
            They are considered alongside the adapters named by the static
            rules.
        window_size: Passed as ``limit`` to ``ledger.recent`` for each
            candidate adapter. Must be positive.
        min_observations: Minimum number of windowed observations an adapter
            needs before it can be selected adaptively. Must be positive.
            NOTE(review): a value above *window_size* presumably means no
            adapter ever qualifies — confirm ``ledger.recent`` honours
            ``limit``.
        max_age: If set, ``since=now(UTC) - max_age`` is passed to
            ``ledger.recent``. Must be non-negative.
    """

    ledger: Optional[QualityLedger] = None
    adapters_by_id: Mapping[str, LLMAdapter] = field(default_factory=dict)
    window_size: int = 20
    min_observations: int = 1
    max_age: Optional[timedelta] = None

    def __post_init__(self) -> None:
        """Validate the adaptive-routing configuration.

        Raises:
            ValueError: *window_size* or *min_observations* is not positive,
                or *max_age* is negative.
        """
        if self.window_size <= 0:
            raise ValueError("window_size must be positive")
        if self.min_observations <= 0:
            raise ValueError("min_observations must be positive")
        if self.max_age is not None and self.max_age.total_seconds() < 0:
            raise ValueError("max_age must be non-negative")

    def resolve(
        self,
        task_type: str,
        estimated_cost_per_1k: Optional[float] = None,
        *,
        quality_floor: Optional[float] = None,
    ) -> LLMAdapter:
        """Return the adaptive adapter for *task_type*.

        Args:
            task_type: Logical task identifier.
            estimated_cost_per_1k: Passed through to static routing fallback.
                It is not consulted when an adapter is chosen adaptively.
            quality_floor: Minimum observed mean quality required for adaptive
                selection, in ``[0, 1]``. When omitted, static routing is used.

        Returns:
            The selected :class:`~llm_connect.adapter.LLMAdapter`.

        Raises:
            ValueError: *quality_floor* is outside ``[0, 1]`` or is NaN.
            LookupError: Propagated from static routing when no rule matches
                and no default is configured.
        """
        if quality_floor is None:
            return super().resolve(task_type, estimated_cost_per_1k)

        # Validate before looking at the ledger so a bad floor is reported
        # even when no ledger is configured. (NaN also fails this chained
        # comparison.)
        if not 0 <= quality_floor <= 1:
            raise ValueError("quality_floor must be between 0 and 1")

        if self.ledger is None:
            return super().resolve(task_type, estimated_cost_per_1k)

        metrics = self._qualifying_candidates(task_type, quality_floor)
        if not metrics:
            return super().resolve(task_type, estimated_cost_per_1k)

        # Cheapest mean cost wins. Ties go to the static rule's preferred
        # adapter, then to enumeration order, so the pick is deterministic.
        best = min(
            metrics,
            key=lambda candidate: (
                candidate.mean_cost_usd,
                0 if candidate.is_static_prefer else 1,
                candidate.order,
            ),
        )
        return best.adapter

    def _qualifying_candidates(
        self,
        task_type: str,
        quality_floor: float,
    ) -> list[_CandidateMetrics]:
        """Summarise every candidate adapter that clears the quality floor.

        A candidate qualifies when it has at least *min_observations* windowed
        observations and their mean ``quality_score`` is ``>= quality_floor``.
        """
        static_prefer = self._static_preferred_adapter(task_type)
        candidates: list[_CandidateMetrics] = []
        for order, (adapter_id, adapter) in enumerate(self._candidate_entries(task_type)):
            observations = self._windowed_observations(task_type, adapter_id)
            if len(observations) < self.min_observations:
                continue

            count = len(observations)
            mean_quality = sum(obs.quality_score for obs in observations) / count
            # ``not >=`` rather than ``<``: a NaN mean (from any NaN score)
            # must be rejected, and every comparison with NaN is False.
            if not mean_quality >= quality_floor:
                continue

            mean_cost = sum(obs.cost_usd for obs in observations) / count
            candidates.append(
                _CandidateMetrics(
                    adapter_id=adapter_id,
                    adapter=adapter,
                    mean_quality=mean_quality,
                    mean_cost_usd=mean_cost,
                    order=order,
                    is_static_prefer=adapter is static_prefer,
                )
            )
        return candidates

    def _windowed_observations(
        self,
        task_type: str,
        adapter_id: str,
    ) -> list[QualityObservation]:
        """Fetch the observation window for one adapter on one task type.

        Returns ``[]`` without a ledger. Otherwise forwards ``window_size``,
        the task/adapter filters, and (when *max_age* is set) a UTC ``since``
        cutoff to ``ledger.recent``.
        """
        if self.ledger is None:
            return []

        since = None
        if self.max_age is not None:
            since = datetime.now(timezone.utc) - self.max_age

        return self.ledger.recent(
            limit=self.window_size,
            task_type=task_type,
            adapter_id=adapter_id,
            since=since,
        )

    def _candidate_entries(self, task_type: str) -> list[tuple[str, LLMAdapter]]:
        """List ``(adapter_id, adapter)`` pairs to evaluate, de-duplicated by id.

        Entries from *adapters_by_id* come first, in mapping order. They are
        followed by the static candidates for *task_type* that resolve to an
        id. Adapters without a resolvable id, and ids already seen, are
        skipped.
        """
        entries: list[tuple[str, LLMAdapter]] = []
        seen_ids: set[str] = set()

        def add(adapter_id: str | None, adapter: LLMAdapter | None) -> None:
            if adapter is None or adapter_id is None or adapter_id in seen_ids:
                return
            seen_ids.add(adapter_id)
            entries.append((adapter_id, adapter))

        for adapter_id, adapter in self.adapters_by_id.items():
            add(adapter_id, adapter)

        for adapter in self._static_candidate_adapters(task_type):
            add(self._adapter_id_for(adapter), adapter)

        return entries

    def _static_candidate_adapters(self, task_type: str) -> list[LLMAdapter]:
        """Return the adapters static routing could hand out for *task_type*.

        For the first matching rule, this is its ``prefer`` plus its
        ``fallback`` and the policy ``default`` when those are set. With no
        matching rule it is just the ``default`` (or nothing).
        """
        for rule in self.rules:
            if rule.task_type == task_type:
                candidates = [rule.prefer]
                if rule.fallback is not None:
                    candidates.append(rule.fallback)
                if self.default is not None:
                    candidates.append(self.default)
                return candidates

        if self.default is not None:
            return [self.default]
        return []

    def _static_preferred_adapter(self, task_type: str) -> LLMAdapter | None:
        """Return the first matching rule's ``prefer`` adapter, or ``None``."""
        for rule in self.rules:
            if rule.task_type == task_type:
                return rule.prefer
        return None

    def _adapter_id_for(self, adapter: LLMAdapter) -> str | None:
        """Resolve the ledger id for *adapter*.

        First tries a reverse lookup by identity in *adapters_by_id*. Then it
        uses the first non-blank string among the adapter's ``adapter_id``,
        ``id`` and ``name`` attributes, returned unstripped. Returns ``None``
        when neither yields an id.
        """
        for adapter_id, candidate in self.adapters_by_id.items():
            if candidate is adapter:
                return adapter_id

        for attribute in ("adapter_id", "id", "name"):
            value = getattr(adapter, attribute, None)
            if isinstance(value, str) and value.strip():
                return value
        return None
|