Files
llm-connect/llm_connect/rates.py
tegwick c11c6afa3f
Some checks failed
CI / test (3.10) (push) Has been cancelled
CI / test (3.11) (push) Has been cancelled
CI / test (3.12) (push) Has been cancelled
Implement-LLM-WP-0005-cost-model-estimators
2026-05-19 05:02:20 +02:00

274 lines
9.7 KiB
Python

"""Model rate registry for preview and post-hoc cost estimation."""
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass
from pathlib import Path
from typing import Any
# Provenance metadata stamped onto every entry of the bundled default rate table.
DEFAULT_RATE_SOURCE_URL = "https://openrouter.ai/models"
DEFAULT_RATE_CAPTURED_AT = "2026-05-17"
# Also the fallback currency whenever a rate or YAML payload does not name one.
DEFAULT_RATE_CURRENCY = "USD"
@dataclass(frozen=True)
class ModelRate:
    """List price for one model, with provenance metadata.

    ``prompt_per_1k`` and ``completion_per_1k`` are the per-1k prompt and
    completion prices (presumably per 1,000 tokens -- confirm with callers).
    The price is denominated in ``currency``, which defaults to
    ``DEFAULT_RATE_CURRENCY``.

    All fields are normalised on construction: ``model_id`` is stripped,
    ``currency`` is stripped and upper-cased, both rates are coerced to
    non-negative floats, and the metadata fields become plain strings.

    Raises:
        ValueError: If ``model_id`` or ``currency`` is blank after stripping,
            or if either rate is rejected by ``_non_negative_float``.
    """

    model_id: str
    prompt_per_1k: float
    completion_per_1k: float
    currency: str = DEFAULT_RATE_CURRENCY
    source_url: str = ""
    captured_at: str = ""

    def __post_init__(self) -> None:
        """Validate the raw field values and store their normalised forms."""
        cleaned_id = str(self.model_id).strip()
        cleaned_currency = str(self.currency or DEFAULT_RATE_CURRENCY).strip().upper()
        if not cleaned_id:
            raise ValueError("model_id must be a non-empty string")
        if not cleaned_currency:
            raise ValueError("currency must be a non-empty string")

        normalised = {
            "model_id": cleaned_id,
            "prompt_per_1k": _non_negative_float("prompt_per_1k", self.prompt_per_1k),
            "completion_per_1k": _non_negative_float(
                "completion_per_1k", self.completion_per_1k
            ),
            "currency": cleaned_currency,
            "source_url": str(self.source_url or ""),
            "captured_at": str(self.captured_at or ""),
        }
        # The dataclass is frozen, so normal attribute assignment is blocked;
        # object.__setattr__ is the sanctioned way to write during init.
        for field_name, field_value in normalised.items():
            object.__setattr__(self, field_name, field_value)
class ModelRateRegistry:
    """In-memory table of model list prices keyed by model id."""

    def __init__(self, rates: Mapping[str, ModelRate | Mapping[str, Any]] | None = None) -> None:
        """Build the registry from ``ModelRate`` objects or raw rate mappings.

        Each value is passed through ``_coerce_rate``; entries are stored under
        the resulting ``ModelRate.model_id`` (the normalised id), not the
        incoming key. ``None`` or an empty mapping yields an empty registry.
        """
        source = rates or {}
        coerced = (_coerce_rate(model_id, rate) for model_id, rate in source.items())
        self._rates: dict[str, ModelRate] = {entry.model_id: entry for entry in coerced}

    def get(self, model_id: str) -> ModelRate | None:
        """Look up *model_id* (whitespace-stripped); ``None`` if it is unknown."""
        lookup_key = str(model_id).strip()
        return self._rates.get(lookup_key)

    def all(self) -> dict[str, ModelRate]:
        """Return a shallow copy of the id -> rate mapping."""
        return self._rates.copy()

    @classmethod
    def default(cls) -> "ModelRateRegistry":
        """Build a registry from the bundled OpenRouter list-price snapshot."""
        bundled = _default_rate_payload()
        return cls(bundled)

    @classmethod
    def from_yaml(cls, path: Path | str) -> "ModelRateRegistry":
        """Build a registry from a YAML rate file.

        The file follows the historic infospace-bench layout::

            currency: USD
            source_url: https://openrouter.ai/models
            captured_at: "2026-05-17"
            rates:
              openai/gpt-4o-mini:
                prompt_per_1k: 0.00015
                completion_per_1k: 0.00060

        PyYAML does the parsing when it is installed. Otherwise a small
        built-in parser that understands just this schema is used, so
        llm-connect keeps its lightweight dependency surface.
        """
        yaml_path = Path(path)
        payload = _load_yaml_mapping(yaml_path)
        return cls(_rates_from_payload(payload))

    def merged_with(self, override: "ModelRateRegistry") -> "ModelRateRegistry":
        """Return a new registry; on a shared model id the *override* entry wins."""
        return ModelRateRegistry({**self.all(), **override.all()})
# Bundled price snapshot: model id -> (prompt_per_1k, completion_per_1k).
# _default_rate_payload() expands each pair into a ModelRate carrying the
# DEFAULT_RATE_* currency and provenance values.
_DEFAULT_RATES: dict[str, tuple[float, float]] = {
    "openai/gpt-4o-mini": (0.00015, 0.00060),
    "openai/gpt-4o": (0.0025, 0.01),
    "openai/gpt-4-turbo": (0.01, 0.03),
    "anthropic/claude-3.5-sonnet": (0.003, 0.015),
    "anthropic/claude-3.5-haiku": (0.0008, 0.004),
    "anthropic/claude-3-opus": (0.015, 0.075),
    "google/gemini-1.5-flash": (0.000075, 0.0003),
    "google/gemini-1.5-pro": (0.00125, 0.005),
    "meta-llama/llama-3.1-70b-instruct": (0.00059, 0.00079),
}
def _default_rate_payload() -> dict[str, ModelRate]:
    """Expand ``_DEFAULT_RATES`` into ``ModelRate`` objects keyed by model id.

    Every entry is tagged with the module-level default currency, source URL
    and capture date.
    """
    payload: dict[str, ModelRate] = {}
    for model_id, price_pair in _DEFAULT_RATES.items():
        prompt_rate, completion_rate = price_pair
        payload[model_id] = ModelRate(
            model_id=model_id,
            prompt_per_1k=prompt_rate,
            completion_per_1k=completion_rate,
            currency=DEFAULT_RATE_CURRENCY,
            source_url=DEFAULT_RATE_SOURCE_URL,
            captured_at=DEFAULT_RATE_CAPTURED_AT,
        )
    return payload
def _coerce_rate(model_id: str, rate: ModelRate | Mapping[str, Any]) -> ModelRate:
    """Turn one registry input value into a ``ModelRate``.

    Args:
        model_id: Key the rate was supplied under; used as the model id when
            *rate* is a plain mapping.
        rate: Either a ready ``ModelRate`` (returned unchanged) or a mapping
            with ``prompt_per_1k`` and ``completion_per_1k`` plus optional
            ``currency``, ``source_url`` and ``captured_at``.

    Raises:
        TypeError: If *rate* is neither a ``ModelRate`` nor a mapping.
        KeyError: If a mapping lacks one of the two required rate fields.
    """
    if isinstance(rate, ModelRate):
        return rate
    if isinstance(rate, Mapping):
        fields = {
            "model_id": str(model_id),
            "prompt_per_1k": rate["prompt_per_1k"],
            "completion_per_1k": rate["completion_per_1k"],
            "currency": str(rate.get("currency") or DEFAULT_RATE_CURRENCY),
            "source_url": str(rate.get("source_url") or ""),
            "captured_at": str(rate.get("captured_at") or ""),
        }
        return ModelRate(**fields)
    raise TypeError(f"Rate for {model_id!r} must be a ModelRate or mapping")
def _rates_from_payload(payload: Mapping[str, Any]) -> dict[str, ModelRate]:
    """Convert a parsed rate-file payload into ``ModelRate`` objects.

    Top-level ``currency``, ``source_url`` and ``captured_at`` act as defaults
    for every entry under ``rates``; an entry may override any of them.

    Raises:
        ValueError: If ``rates`` is missing or not a mapping, or if any entry
            under it is not a mapping.
        KeyError: If an entry lacks ``prompt_per_1k`` or ``completion_per_1k``.
    """
    rates_payload = payload.get("rates")
    if not isinstance(rates_payload, Mapping):
        raise ValueError("Rate YAML must contain a 'rates' mapping")

    # File-level metadata that individual entries fall back to.
    file_defaults = {
        "currency": str(payload.get("currency") or DEFAULT_RATE_CURRENCY),
        "source_url": str(payload.get("source_url") or ""),
        "captured_at": str(payload.get("captured_at") or ""),
    }

    parsed: dict[str, ModelRate] = {}
    for raw_id, entry in rates_payload.items():
        if not isinstance(entry, Mapping):
            raise ValueError(f"Rate entry for {raw_id!r} must be a mapping")
        key = str(raw_id)
        prompt_rate = entry["prompt_per_1k"]
        completion_rate = entry["completion_per_1k"]
        metadata = {
            field_name: str(entry.get(field_name) or fallback)
            for field_name, fallback in file_defaults.items()
        }
        parsed[key] = ModelRate(
            model_id=key,
            prompt_per_1k=prompt_rate,
            completion_per_1k=completion_rate,
            **metadata,
        )
    return parsed
def _non_negative_float(name: str, value: Any) -> float:
if isinstance(value, bool):
raise ValueError(f"{name} must be a non-negative number")
try:
number = float(value)
except (TypeError, ValueError) as exc:
raise ValueError(f"{name} must be a non-negative number") from exc
if number < 0:
raise ValueError(f"{name} must be a non-negative number")
return number
def _load_yaml_mapping(path: Path) -> Mapping[str, Any]:
    """Read *path* as UTF-8 and parse it into a mapping.

    PyYAML's ``safe_load`` is used when the package is importable; otherwise
    the text is handed to the built-in ``_parse_rate_yaml`` fallback.

    Raises:
        ValueError: If PyYAML parses the document to something other than a
            mapping (an empty document counts as an empty mapping).
    """
    try:
        import yaml
    except ImportError:
        yaml = None

    text = path.read_text(encoding="utf-8")
    if yaml is None:
        return _parse_rate_yaml(text)

    loaded = yaml.safe_load(text)
    data = loaded if loaded else {}
    if isinstance(data, Mapping):
        return data
    raise ValueError("Rate YAML root must be a mapping")
def _parse_rate_yaml(text: str) -> dict[str, Any]:
    """Parse the rate-file YAML subset without PyYAML.

    Supports top-level ``key: scalar`` pairs plus one nested ``rates:`` block
    (handled by ``_parse_rates_block``). Blank and comment-only lines are
    dropped before parsing.

    Raises:
        ValueError: If an indented line appears where a top-level key is
            expected, or if a line is not a ``key: value`` pair.
    """
    normalised = (_normalise_yaml_line(raw_line) for raw_line in text.splitlines())
    lines: list[tuple[int, str]] = [entry for entry in normalised if entry is not None]

    data: dict[str, Any] = {}
    cursor = 0
    total = len(lines)
    while cursor < total:
        indent, content = lines[cursor]
        if indent != 0:
            raise ValueError("Only top-level mappings are supported in rate YAML")
        key, raw_value = _split_yaml_key_value(content)
        if key != "rates" or raw_value != "":
            data[key] = _parse_yaml_scalar(raw_value)
            cursor += 1
        else:
            # A bare "rates:" opens the nested block; the helper returns the
            # index of the first line after it.
            data["rates"], cursor = _parse_rates_block(lines, cursor + 1)
    return data
def _parse_rates_block(
    lines: list[tuple[int, str]],
    index: int,
) -> tuple[dict[str, dict[str, Any]], int]:
    """Parse the body of a ``rates:`` block starting at ``lines[index]``.

    Model ids must sit at an indent of two spaces with an empty value, and
    their fields at an indent of four spaces.

    Returns:
        The ``model id -> field mapping`` table, and the index of the first
        line that is not part of the block (the next top-level line, or
        ``len(lines)``).

    Raises:
        ValueError: On unexpected indentation, or when a model line carries
            an inline value instead of a nested mapping.
    """
    parsed: dict[str, dict[str, Any]] = {}
    total = len(lines)
    while index < total:
        indent, content = lines[index]
        if indent == 0:
            # Back at top level: the block is finished.
            break
        if indent != 2:
            raise ValueError("Rate model entries must be indented by two spaces")
        model_id, raw_value = _split_yaml_key_value(content)
        if raw_value:
            raise ValueError(f"Rate entry for {model_id!r} must be a nested mapping")
        index += 1

        # Consume every following line that is indented deeper than the model.
        fields: dict[str, Any] = {}
        while index < total and lines[index][0] > indent:
            field_indent, field_content = lines[index]
            if field_indent != 4:
                raise ValueError("Rate fields must be indented by four spaces")
            field_key, field_value = _split_yaml_key_value(field_content)
            fields[field_key] = _parse_yaml_scalar(field_value)
            index += 1
        parsed[model_id] = fields
    return parsed, index
def _normalise_yaml_line(line: str) -> tuple[int, str] | None:
    """Reduce one raw YAML line to ``(indent, content)``.

    Trailing whitespace and comments are removed first. ``indent`` is the
    number of leading spaces and ``content`` is the remaining text with
    surrounding whitespace stripped.

    Returns:
        ``None`` for lines that carry no mapping data: blank lines,
        comment-only lines, and the YAML document markers ``---`` / ``...``
        at column 0. Otherwise the ``(indent, content)`` pair.
    """
    without_comment = _strip_yaml_comment(line.rstrip())
    content = without_comment.strip()
    if not content:
        return None
    indent = len(without_comment) - len(without_comment.lstrip(" "))
    # Document start/end markers are valid YAML that PyYAML accepts. Skipping
    # them stops the fallback parser rejecting them as invalid mapping lines.
    if indent == 0 and content in {"---", "..."}:
        return None
    return indent, content
def _strip_yaml_comment(line: str) -> str:
quote: str | None = None
for index, char in enumerate(line):
if char in {"'", '"'}:
quote = None if quote == char else char if quote is None else quote
elif char == "#" and quote is None:
return line[:index]
return line
def _split_yaml_key_value(content: str) -> tuple[str, str]:
key, separator, value = content.partition(":")
if not separator:
raise ValueError(f"Invalid YAML mapping line: {content!r}")
return key.strip().strip("'\""), value.strip()
def _parse_yaml_scalar(value: str) -> Any:
if value == "":
return ""
if (value.startswith('"') and value.endswith('"')) or (
value.startswith("'") and value.endswith("'")
):
return value[1:-1]
if value.lower() in {"null", "none", "~"}:
return None
try:
if any(char in value for char in (".", "e", "E")):
return float(value)
return int(value)
except ValueError:
return value