"""Model rate registry for preview and post-hoc cost estimation.""" from __future__ import annotations from collections.abc import Mapping from dataclasses import dataclass from pathlib import Path from typing import Any DEFAULT_RATE_SOURCE_URL = "https://openrouter.ai/models" DEFAULT_RATE_CAPTURED_AT = "2026-05-17" DEFAULT_RATE_CURRENCY = "USD" @dataclass(frozen=True) class ModelRate: """USD-denominated list price for one model.""" model_id: str prompt_per_1k: float completion_per_1k: float currency: str = DEFAULT_RATE_CURRENCY source_url: str = "" captured_at: str = "" def __post_init__(self) -> None: model_id = str(self.model_id).strip() currency = str(self.currency or DEFAULT_RATE_CURRENCY).strip().upper() if not model_id: raise ValueError("model_id must be a non-empty string") if not currency: raise ValueError("currency must be a non-empty string") prompt_rate = _non_negative_float("prompt_per_1k", self.prompt_per_1k) completion_rate = _non_negative_float("completion_per_1k", self.completion_per_1k) object.__setattr__(self, "model_id", model_id) object.__setattr__(self, "prompt_per_1k", prompt_rate) object.__setattr__(self, "completion_per_1k", completion_rate) object.__setattr__(self, "currency", currency) object.__setattr__(self, "source_url", str(self.source_url or "")) object.__setattr__(self, "captured_at", str(self.captured_at or "")) class ModelRateRegistry: """Lookup table for model list prices.""" def __init__(self, rates: Mapping[str, ModelRate | Mapping[str, Any]] | None = None) -> None: self._rates: dict[str, ModelRate] = {} for model_id, rate in (rates or {}).items(): model_rate = _coerce_rate(model_id, rate) self._rates[model_rate.model_id] = model_rate def get(self, model_id: str) -> ModelRate | None: """Return the rate for *model_id*, or ``None`` when absent.""" return self._rates.get(str(model_id).strip()) def all(self) -> dict[str, ModelRate]: """Return a copy of the registry mapping.""" return dict(self._rates) @classmethod def default(cls) -> "ModelRateRegistry": """Return the bundled OpenRouter list-price snapshot.""" return cls(_default_rate_payload()) @classmethod def from_yaml(cls, path: Path | str) -> "ModelRateRegistry": """Load rates from a YAML file. The expected shape matches the historic infospace-bench table:: currency: USD source_url: https://openrouter.ai/models captured_at: "2026-05-17" rates: openai/gpt-4o-mini: prompt_per_1k: 0.00015 completion_per_1k: 0.00060 PyYAML is used when installed; otherwise a small parser handles this schema so llm-connect keeps its current lightweight dependency surface. """ payload = _load_yaml_mapping(Path(path)) return cls(_rates_from_payload(payload)) def merged_with(self, override: "ModelRateRegistry") -> "ModelRateRegistry": """Return a new registry where *override* entries win by model id.""" merged = self.all() merged.update(override.all()) return ModelRateRegistry(merged) _DEFAULT_RATES: dict[str, tuple[float, float]] = { "openai/gpt-4o-mini": (0.00015, 0.00060), "openai/gpt-4o": (0.0025, 0.01), "openai/gpt-4-turbo": (0.01, 0.03), "anthropic/claude-3.5-sonnet": (0.003, 0.015), "anthropic/claude-3.5-haiku": (0.0008, 0.004), "anthropic/claude-3-opus": (0.015, 0.075), "google/gemini-1.5-flash": (0.000075, 0.0003), "google/gemini-1.5-pro": (0.00125, 0.005), "meta-llama/llama-3.1-70b-instruct": (0.00059, 0.00079), } def _default_rate_payload() -> dict[str, ModelRate]: return { model_id: ModelRate( model_id=model_id, prompt_per_1k=prompt_rate, completion_per_1k=completion_rate, currency=DEFAULT_RATE_CURRENCY, source_url=DEFAULT_RATE_SOURCE_URL, captured_at=DEFAULT_RATE_CAPTURED_AT, ) for model_id, (prompt_rate, completion_rate) in _DEFAULT_RATES.items() } def _coerce_rate(model_id: str, rate: ModelRate | Mapping[str, Any]) -> ModelRate: if isinstance(rate, ModelRate): return rate if not isinstance(rate, Mapping): raise TypeError(f"Rate for {model_id!r} must be a ModelRate or mapping") return ModelRate( model_id=str(model_id), prompt_per_1k=rate["prompt_per_1k"], completion_per_1k=rate["completion_per_1k"], currency=str(rate.get("currency") or DEFAULT_RATE_CURRENCY), source_url=str(rate.get("source_url") or ""), captured_at=str(rate.get("captured_at") or ""), ) def _rates_from_payload(payload: Mapping[str, Any]) -> dict[str, ModelRate]: rates_payload = payload.get("rates") if not isinstance(rates_payload, Mapping): raise ValueError("Rate YAML must contain a 'rates' mapping") currency = str(payload.get("currency") or DEFAULT_RATE_CURRENCY) source_url = str(payload.get("source_url") or "") captured_at = str(payload.get("captured_at") or "") rates: dict[str, ModelRate] = {} for model_id, raw_rate in rates_payload.items(): if not isinstance(raw_rate, Mapping): raise ValueError(f"Rate entry for {model_id!r} must be a mapping") rates[str(model_id)] = ModelRate( model_id=str(model_id), prompt_per_1k=raw_rate["prompt_per_1k"], completion_per_1k=raw_rate["completion_per_1k"], currency=str(raw_rate.get("currency") or currency), source_url=str(raw_rate.get("source_url") or source_url), captured_at=str(raw_rate.get("captured_at") or captured_at), ) return rates def _non_negative_float(name: str, value: Any) -> float: if isinstance(value, bool): raise ValueError(f"{name} must be a non-negative number") try: number = float(value) except (TypeError, ValueError) as exc: raise ValueError(f"{name} must be a non-negative number") from exc if number < 0: raise ValueError(f"{name} must be a non-negative number") return number def _load_yaml_mapping(path: Path) -> Mapping[str, Any]: try: import yaml except ImportError: return _parse_rate_yaml(path.read_text(encoding="utf-8")) data = yaml.safe_load(path.read_text(encoding="utf-8")) or {} if not isinstance(data, Mapping): raise ValueError("Rate YAML root must be a mapping") return data def _parse_rate_yaml(text: str) -> dict[str, Any]: lines: list[tuple[int, str]] = [] for raw_line in text.splitlines(): line = _normalise_yaml_line(raw_line) if line is not None: lines.append(line) data: dict[str, Any] = {} index = 0 while index < len(lines): indent, content = lines[index] if indent != 0: raise ValueError("Only top-level mappings are supported in rate YAML") key, raw_value = _split_yaml_key_value(content) if key == "rates" and raw_value == "": rates, index = _parse_rates_block(lines, index + 1) data["rates"] = rates continue data[key] = _parse_yaml_scalar(raw_value) index += 1 return data def _parse_rates_block( lines: list[tuple[int, str]], index: int, ) -> tuple[dict[str, dict[str, Any]], int]: rates: dict[str, dict[str, Any]] = {} while index < len(lines): indent, content = lines[index] if indent == 0: break if indent != 2: raise ValueError("Rate model entries must be indented by two spaces") model_id, raw_value = _split_yaml_key_value(content) if raw_value: raise ValueError(f"Rate entry for {model_id!r} must be a nested mapping") entry: dict[str, Any] = {} index += 1 while index < len(lines): child_indent, child_content = lines[index] if child_indent <= indent: break if child_indent != 4: raise ValueError("Rate fields must be indented by four spaces") child_key, child_value = _split_yaml_key_value(child_content) entry[child_key] = _parse_yaml_scalar(child_value) index += 1 rates[model_id] = entry return rates, index def _normalise_yaml_line(line: str) -> tuple[int, str] | None: stripped = _strip_yaml_comment(line.rstrip()) if not stripped.strip(): return None indent = len(stripped) - len(stripped.lstrip(" ")) return indent, stripped.strip() def _strip_yaml_comment(line: str) -> str: quote: str | None = None for index, char in enumerate(line): if char in {"'", '"'}: quote = None if quote == char else char if quote is None else quote elif char == "#" and quote is None: return line[:index] return line def _split_yaml_key_value(content: str) -> tuple[str, str]: key, separator, value = content.partition(":") if not separator: raise ValueError(f"Invalid YAML mapping line: {content!r}") return key.strip().strip("'\""), value.strip() def _parse_yaml_scalar(value: str) -> Any: if value == "": return "" if (value.startswith('"') and value.endswith('"')) or ( value.startswith("'") and value.endswith("'") ): return value[1:-1] if value.lower() in {"null", "none", "~"}: return None try: if any(char in value for char in (".", "e", "E")): return float(value) return int(value) except ValueError: return value