"""
RoutingPolicy — task-type-aware adapter selection (FR-2).

Maps task types to preferred adapters with optional cost-cap fallback.
:class:`AdaptiveRoutingPolicy` additionally consults a quality ledger and
picks the cheapest adapter whose observed mean quality clears a
caller-supplied floor.
"""

from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import List, Mapping, Optional

from llm_connect.adapter import LLMAdapter
from llm_connect.quality import QualityLedger, QualityObservation


@dataclass
class RoutingRule:
    """Single routing rule binding a task type to an adapter.

    Attributes:
        task_type: Logical task identifier (e.g. ``"triage"``, ``"summarise"``).
        prefer: Adapter to use when this rule matches.
        max_cost_per_1k: Optional cost ceiling (USD per 1 000 tokens). When the
            caller supplies ``estimated_cost_per_1k`` to
            :meth:`RoutingPolicy.resolve` and it exceeds this cap, *fallback*
            is returned instead of *prefer*.
        fallback: Adapter to use when the cost cap is breached. When this is
            ``None``, *prefer* is returned even if the cap is breached.
    """

    task_type: str
    prefer: LLMAdapter
    max_cost_per_1k: Optional[float] = None
    fallback: Optional[LLMAdapter] = None


@dataclass
class RoutingPolicy:
    """Route task types to LLM adapters.

    Rules are evaluated in order; the first match wins. When no rule matches,
    *default* is returned. If *default* is also absent, ``LookupError`` is
    raised.

    Example::

        policy = RoutingPolicy(
            rules=[
                RoutingRule("triage", prefer=fast_adapter,
                            max_cost_per_1k=0.5, fallback=cheap_adapter),
                RoutingRule("analysis", prefer=smart_adapter),
            ],
            default=cheap_adapter,
        )
        adapter = policy.resolve("triage")
    """

    rules: List[RoutingRule] = field(default_factory=list)
    default: Optional[LLMAdapter] = None

    def resolve(
        self,
        task_type: str,
        estimated_cost_per_1k: Optional[float] = None,
    ) -> LLMAdapter:
        """Return the adapter for *task_type*.

        Args:
            task_type: Logical task identifier.
            estimated_cost_per_1k: Caller-supplied cost estimate (USD / 1k
                tokens). When provided and a matching rule has
                ``max_cost_per_1k`` set, the rule's ``fallback`` is returned
                if the estimate exceeds the cap. A rule without a
                ``fallback`` keeps returning ``prefer`` regardless of cost.

        Returns:
            The selected :class:`~llm_connect.adapter.LLMAdapter`.

        Raises:
            LookupError: No matching rule and no *default* configured.
        """
        for rule in self.rules:
            if rule.task_type == task_type:
                # First matching rule decides; later rules are never consulted.
                if (
                    estimated_cost_per_1k is not None
                    and rule.max_cost_per_1k is not None
                    and estimated_cost_per_1k > rule.max_cost_per_1k
                    and rule.fallback is not None
                ):
                    return rule.fallback
                return rule.prefer
        if self.default is not None:
            return self.default
        raise LookupError(
            f"No routing rule for task_type={task_type!r} and no default configured"
        )


@dataclass(frozen=True)
class _CandidateMetrics:
    """Aggregated ledger statistics for one adapter that cleared the floor."""

    adapter_id: str
    adapter: LLMAdapter
    # Mean of ``quality_score`` over the windowed observations.
    mean_quality: float
    # Mean of ``cost_usd`` over the windowed observations; primary sort key.
    mean_cost_usd: float
    # Position in ``_candidate_entries``; final tie-breaker.
    order: int
    # True when this adapter is the matching rule's ``prefer`` adapter.
    is_static_prefer: bool


@dataclass
class AdaptiveRoutingPolicy(RoutingPolicy):
    """Route to the cheapest adapter whose observed quality clears a floor.

    The policy consults a :class:`~llm_connect.quality.QualityLedger` for
    observations matching ``task_type`` and adapter id. When the ledger has no
    qualifying observations, resolution falls through to ``RoutingPolicy`` so
    a caller can use the same policy on day zero and after observations
    accrue.
    """

    # Source of quality observations; when None, only static routing is used.
    ledger: Optional[QualityLedger] = None
    # Explicitly registered adapters, keyed by the id used in the ledger.
    adapters_by_id: Mapping[str, LLMAdapter] = field(default_factory=dict)
    # Passed as ``limit`` to ``ledger.recent`` for each candidate adapter.
    window_size: int = 20
    # Candidates with fewer windowed observations than this are ignored.
    min_observations: int = 1
    # When set, only observations newer than ``now - max_age`` are requested.
    max_age: Optional[timedelta] = None

    def __post_init__(self) -> None:
        """Validate the adaptive configuration.

        Raises:
            ValueError: ``window_size`` or ``min_observations`` is not
                positive, or ``max_age`` is negative.
        """
        if self.window_size <= 0:
            raise ValueError("window_size must be positive")
        if self.min_observations <= 0:
            raise ValueError("min_observations must be positive")
        if self.max_age is not None and self.max_age.total_seconds() < 0:
            raise ValueError("max_age must be non-negative")

    def resolve(
        self,
        task_type: str,
        estimated_cost_per_1k: Optional[float] = None,
        *,
        quality_floor: Optional[float] = None,
    ) -> LLMAdapter:
        """Return the adaptive adapter for *task_type*.

        Among candidate adapters with at least ``min_observations`` windowed
        ledger observations whose mean ``quality_score`` is at least
        *quality_floor*, the one with the lowest mean ``cost_usd`` wins. Ties
        prefer the matching rule's ``prefer`` adapter, then candidate order
        (iteration order of ``adapters_by_id``, followed by static rule
        candidates).

        Args:
            task_type: Logical task identifier.
            estimated_cost_per_1k: Passed through to static routing fallback.
            quality_floor: Minimum observed mean quality required for adaptive
                selection, in ``[0, 1]``. When omitted, static routing is used.

        Returns:
            The selected :class:`~llm_connect.adapter.LLMAdapter`.

        Raises:
            ValueError: *quality_floor* is outside ``[0, 1]`` (or NaN).
            LookupError: Static routing is used and finds no matching rule
                and no *default*.
        """
        # Validate before any short-circuit so an out-of-range floor is
        # rejected consistently, whether or not a ledger is configured.
        if quality_floor is not None and not 0 <= quality_floor <= 1:
            raise ValueError("quality_floor must be between 0 and 1")
        if quality_floor is None or self.ledger is None:
            return super().resolve(task_type, estimated_cost_per_1k)
        metrics = self._qualifying_candidates(task_type, quality_floor)
        if not metrics:
            # No adapter has enough qualifying evidence yet: behave statically.
            return super().resolve(task_type, estimated_cost_per_1k)
        best = min(
            metrics,
            key=lambda candidate: (
                candidate.mean_cost_usd,
                0 if candidate.is_static_prefer else 1,
                candidate.order,
            ),
        )
        return best.adapter

    def _qualifying_candidates(
        self,
        task_type: str,
        quality_floor: float,
    ) -> list[_CandidateMetrics]:
        """Return metrics for candidates with enough observations above the floor."""
        static_prefer = self._static_preferred_adapter(task_type)
        candidates: list[_CandidateMetrics] = []
        for order, (adapter_id, adapter) in enumerate(self._candidate_entries(task_type)):
            observations = self._windowed_observations(task_type, adapter_id)
            # min_observations >= 1 (enforced in __post_init__), so the means
            # below never divide by zero.
            if len(observations) < self.min_observations:
                continue
            mean_quality = sum(obs.quality_score for obs in observations) / len(observations)
            if mean_quality < quality_floor:
                continue
            mean_cost = sum(obs.cost_usd for obs in observations) / len(observations)
            candidates.append(
                _CandidateMetrics(
                    adapter_id=adapter_id,
                    adapter=adapter,
                    mean_quality=mean_quality,
                    mean_cost_usd=mean_cost,
                    order=order,
                    is_static_prefer=adapter is static_prefer,
                )
            )
        return candidates

    def _windowed_observations(
        self,
        task_type: str,
        adapter_id: str,
    ) -> list[QualityObservation]:
        """Fetch recent ledger observations for one task type / adapter pair."""
        if self.ledger is None:
            return []
        since = None
        if self.max_age is not None:
            since = datetime.now(timezone.utc) - self.max_age
        return self.ledger.recent(
            limit=self.window_size,
            task_type=task_type,
            adapter_id=adapter_id,
            since=since,
        )

    def _candidate_entries(self, task_type: str) -> list[tuple[str, LLMAdapter]]:
        """List ``(adapter_id, adapter)`` candidates, de-duplicated by id.

        Registered ``adapters_by_id`` come first, then the static candidates
        for *task_type*. Static adapters whose id cannot be determined are
        skipped, since the ledger is queried by id.
        """
        entries: list[tuple[str, LLMAdapter]] = []
        seen_ids: set[str] = set()

        def add(adapter_id: Optional[str], adapter: Optional[LLMAdapter]) -> None:
            if adapter is None or adapter_id is None or adapter_id in seen_ids:
                return
            seen_ids.add(adapter_id)
            entries.append((adapter_id, adapter))

        for adapter_id, adapter in self.adapters_by_id.items():
            add(adapter_id, adapter)
        for adapter in self._static_candidate_adapters(task_type):
            add(self._adapter_id_for(adapter), adapter)
        return entries

    def _static_candidate_adapters(self, task_type: str) -> list[LLMAdapter]:
        """Return the first matching rule's prefer/fallback plus the default."""
        for rule in self.rules:
            if rule.task_type == task_type:
                candidates = [rule.prefer]
                if rule.fallback is not None:
                    candidates.append(rule.fallback)
                if self.default is not None:
                    candidates.append(self.default)
                return candidates
        if self.default is not None:
            return [self.default]
        return []

    def _static_preferred_adapter(self, task_type: str) -> Optional[LLMAdapter]:
        """Return the first matching rule's ``prefer`` adapter, if any."""
        for rule in self.rules:
            if rule.task_type == task_type:
                return rule.prefer
        return None

    def _adapter_id_for(self, adapter: LLMAdapter) -> Optional[str]:
        """Find a ledger id for *adapter*.

        An identity match in ``adapters_by_id`` wins. Otherwise the first
        non-blank string among the ``adapter_id``, ``id`` and ``name``
        attributes is used. Returns ``None`` when neither applies.
        """
        for adapter_id, candidate in self.adapters_by_id.items():
            if candidate is adapter:
                return adapter_id
        for attribute in ("adapter_id", "id", "name"):
            value = getattr(adapter, attribute, None)
            if isinstance(value, str) and value.strip():
                return value
        return None