generated from coulomb/repo-seed
274 lines
9.7 KiB
Python
274 lines
9.7 KiB
Python
"""Model rate registry for preview and post-hoc cost estimation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Mapping
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
|
|
# Source page recorded on every rate in the bundled default snapshot.
DEFAULT_RATE_SOURCE_URL: str = "https://openrouter.ai/models"
# Capture date (ISO 8601 string) recorded on the bundled default snapshot.
DEFAULT_RATE_CAPTURED_AT: str = "2026-05-17"
# Currency code used when a rate does not specify one.
DEFAULT_RATE_CURRENCY: str = "USD"
|
|
|
|
|
|
@dataclass(frozen=True)
class ModelRate:
    """List price for a single model, quoted per 1,000 tokens.

    Values are normalised when the instance is created:

    * ``model_id`` is converted to ``str`` and stripped;
    * ``currency`` falls back to ``DEFAULT_RATE_CURRENCY`` when falsy, then is
      stripped and upper-cased;
    * both prices go through ``_non_negative_float``;
    * ``source_url`` and ``captured_at`` become strings (``""`` when falsy).

    Raises:
        ValueError: if the normalised ``model_id`` or ``currency`` is empty, or
            if a price is rejected by ``_non_negative_float``.
    """

    model_id: str
    prompt_per_1k: float
    completion_per_1k: float
    currency: str = DEFAULT_RATE_CURRENCY
    source_url: str = ""
    captured_at: str = ""

    def __post_init__(self) -> None:
        normalised_id = str(self.model_id).strip()
        normalised_currency = str(self.currency or DEFAULT_RATE_CURRENCY).strip().upper()
        if not normalised_id:
            raise ValueError("model_id must be a non-empty string")
        if not normalised_currency:
            raise ValueError("currency must be a non-empty string")

        cleaned = {
            "model_id": normalised_id,
            "prompt_per_1k": _non_negative_float("prompt_per_1k", self.prompt_per_1k),
            "completion_per_1k": _non_negative_float(
                "completion_per_1k", self.completion_per_1k
            ),
            "currency": normalised_currency,
            "source_url": str(self.source_url or ""),
            "captured_at": str(self.captured_at or ""),
        }
        # frozen=True forbids normal attribute assignment, so write through object.
        for field_name, field_value in cleaned.items():
            object.__setattr__(self, field_name, field_value)
|
|
|
|
|
|
class ModelRateRegistry:
    """Table of ``ModelRate`` entries keyed by their normalised model id."""

    def __init__(self, rates: Mapping[str, ModelRate | Mapping[str, Any]] | None = None) -> None:
        """Populate the registry from *rates*.

        Each value is passed through ``_coerce_rate`` and stored under the
        resulting rate's own ``model_id``. A ``ModelRate`` value therefore
        keeps its own id even when the mapping key differs. Later entries
        replace earlier ones with the same id.
        """
        table: dict[str, ModelRate] = {}
        if rates:
            for key, raw_rate in rates.items():
                coerced = _coerce_rate(key, raw_rate)
                table[coerced.model_id] = coerced
        self._rates = table

    def get(self, model_id: str) -> ModelRate | None:
        """Look up *model_id*, ignoring surrounding whitespace.

        Returns ``None`` when the model is unknown.
        """
        lookup_key = str(model_id).strip()
        return self._rates.get(lookup_key)

    def all(self) -> dict[str, ModelRate]:
        """Return a shallow copy of the id-to-rate mapping."""
        return self._rates.copy()

    @classmethod
    def default(cls) -> "ModelRateRegistry":
        """Build a registry from the bundled OpenRouter price snapshot."""
        bundled = _default_rate_payload()
        return cls(bundled)

    @classmethod
    def from_yaml(cls, path: Path | str) -> "ModelRateRegistry":
        """Build a registry from a YAML rate table at *path*.

        The file layout follows the historic infospace-bench table::

            currency: USD
            source_url: https://openrouter.ai/models
            captured_at: "2026-05-17"
            rates:
              openai/gpt-4o-mini:
                prompt_per_1k: 0.00015
                completion_per_1k: 0.00060

        PyYAML is used when it is installed. Otherwise a small built-in parser
        that understands just this schema is used, so llm-connect does not
        gain a hard YAML dependency.
        """
        document = _load_yaml_mapping(Path(path))
        parsed_rates = _rates_from_payload(document)
        return cls(parsed_rates)

    def merged_with(self, override: "ModelRateRegistry") -> "ModelRateRegistry":
        """Return a new registry combining both tables.

        Entries from *override* win when the model id collides.
        """
        combined = {**self.all(), **override.all()}
        return ModelRateRegistry(combined)
|
|
|
|
|
|
# Bundled list prices as (prompt, completion) per 1,000 tokens, keyed by
# OpenRouter model id. Insertion order is preserved.
_DEFAULT_RATES: dict[str, tuple[float, float]] = {
    model_id: (prompt_price, completion_price)
    for model_id, prompt_price, completion_price in (
        ("openai/gpt-4o-mini", 0.00015, 0.00060),
        ("openai/gpt-4o", 0.0025, 0.01),
        ("openai/gpt-4-turbo", 0.01, 0.03),
        ("anthropic/claude-3.5-sonnet", 0.003, 0.015),
        ("anthropic/claude-3.5-haiku", 0.0008, 0.004),
        ("anthropic/claude-3-opus", 0.015, 0.075),
        ("google/gemini-1.5-flash", 0.000075, 0.0003),
        ("google/gemini-1.5-pro", 0.00125, 0.005),
        ("meta-llama/llama-3.1-70b-instruct", 0.00059, 0.00079),
    )
}
|
|
|
|
|
|
def _default_rate_payload() -> dict[str, ModelRate]:
    """Expand ``_DEFAULT_RATES`` into ``ModelRate`` objects.

    Every rate is stamped with the bundled snapshot metadata
    (``DEFAULT_RATE_CURRENCY``, ``DEFAULT_RATE_SOURCE_URL``,
    ``DEFAULT_RATE_CAPTURED_AT``).
    """
    payload: dict[str, ModelRate] = {}
    for model_id, prices in _DEFAULT_RATES.items():
        prompt_price, completion_price = prices
        payload[model_id] = ModelRate(
            model_id=model_id,
            prompt_per_1k=prompt_price,
            completion_per_1k=completion_price,
            currency=DEFAULT_RATE_CURRENCY,
            source_url=DEFAULT_RATE_SOURCE_URL,
            captured_at=DEFAULT_RATE_CAPTURED_AT,
        )
    return payload
|
|
|
|
|
|
def _coerce_rate(model_id: str, rate: ModelRate | Mapping[str, Any]) -> ModelRate:
    """Normalise one registry value into a ``ModelRate``.

    A ``ModelRate`` is returned unchanged. A mapping must provide
    ``prompt_per_1k`` and ``completion_per_1k``, and may provide
    ``currency``, ``source_url`` and ``captured_at``.

    Raises:
        TypeError: if *rate* is neither a ``ModelRate`` nor a mapping.
        KeyError: if a required price key is missing from the mapping.
        ValueError: propagated from ``ModelRate`` validation.
    """
    if not isinstance(rate, (ModelRate, Mapping)):
        raise TypeError(f"Rate for {model_id!r} must be a ModelRate or mapping")
    if isinstance(rate, ModelRate):
        return rate

    identifier = str(model_id)
    prompt_price = rate["prompt_per_1k"]
    completion_price = rate["completion_per_1k"]
    return ModelRate(
        model_id=identifier,
        prompt_per_1k=prompt_price,
        completion_per_1k=completion_price,
        currency=str(rate.get("currency") or DEFAULT_RATE_CURRENCY),
        source_url=str(rate.get("source_url") or ""),
        captured_at=str(rate.get("captured_at") or ""),
    )
|
|
|
|
|
|
def _rates_from_payload(payload: Mapping[str, Any]) -> dict[str, ModelRate]:
    """Convert a parsed rate-table document into ``ModelRate`` objects.

    The document-level ``currency``, ``source_url`` and ``captured_at``
    values act as defaults; any entry under ``rates`` may override them.

    Raises:
        ValueError: if ``rates`` is missing or not a mapping, if an entry is
            not a mapping, or if ``ModelRate`` rejects an entry's values.
        KeyError: if an entry lacks ``prompt_per_1k`` or ``completion_per_1k``.
    """
    entries = payload.get("rates")
    if not isinstance(entries, Mapping):
        raise ValueError("Rate YAML must contain a 'rates' mapping")

    document_defaults = {
        "currency": str(payload.get("currency") or DEFAULT_RATE_CURRENCY),
        "source_url": str(payload.get("source_url") or ""),
        "captured_at": str(payload.get("captured_at") or ""),
    }

    parsed: dict[str, ModelRate] = {}
    for raw_id, entry in entries.items():
        if not isinstance(entry, Mapping):
            raise ValueError(f"Rate entry for {raw_id!r} must be a mapping")
        model_key = str(raw_id)
        metadata = {
            field: str(entry.get(field) or fallback)
            for field, fallback in document_defaults.items()
        }
        parsed[model_key] = ModelRate(
            model_id=model_key,
            prompt_per_1k=entry["prompt_per_1k"],
            completion_per_1k=entry["completion_per_1k"],
            **metadata,
        )
    return parsed
|
|
|
|
|
|
def _non_negative_float(name: str, value: Any) -> float:
    """Coerce *value* to a finite, non-negative ``float``.

    Args:
        name: Field name used in error messages.
        value: Anything ``float()`` accepts (numbers, numeric strings, ...).
            ``bool`` is rejected explicitly even though it subclasses ``int``.

    Returns:
        The value converted to ``float``.

    Raises:
        ValueError: if *value* is a bool, cannot be converted, is too large to
            represent as a float, is negative, or is NaN/infinite.
    """
    import math

    if isinstance(value, bool):
        raise ValueError(f"{name} must be a non-negative number")
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError) as exc:
        # OverflowError comes from ints beyond float range (e.g. 10**400);
        # report it through the same ValueError contract as other bad input.
        raise ValueError(f"{name} must be a non-negative number") from exc
    if number < 0:
        raise ValueError(f"{name} must be a non-negative number")
    if not math.isfinite(number):
        # NaN slips through "< 0", and an infinite price would poison every
        # cost estimate computed from it.
        raise ValueError(f"{name} must be a finite non-negative number")
    return number
|
|
|
|
|
|
def _load_yaml_mapping(path: Path) -> Mapping[str, Any]:
    """Read *path* and return its top-level YAML mapping.

    ``yaml.safe_load`` is used when PyYAML is importable; otherwise the
    built-in rate-table parser (``_parse_rate_yaml``) handles the text.

    The file is decoded as ``utf-8-sig``, so a leading byte-order mark is
    dropped. Otherwise the fallback parser would glue the BOM onto the first
    key (e.g. ``rates`` would no longer be recognised). Files without a BOM
    decode exactly as with plain ``utf-8``.

    Raises:
        ValueError: if the PyYAML result is not a mapping, or if the fallback
            parser rejects the text.
    """
    text = path.read_text(encoding="utf-8-sig")
    try:
        import yaml
    except ImportError:
        return _parse_rate_yaml(text)

    # An empty document loads as None; treat it as an empty mapping.
    data = yaml.safe_load(text) or {}
    if not isinstance(data, Mapping):
        raise ValueError("Rate YAML root must be a mapping")
    return data
|
|
|
|
|
|
def _parse_rate_yaml(text: str) -> dict[str, Any]:
    """Parse the rate-table YAML subset without PyYAML.

    Top-level lines must be unindented ``key: value`` pairs. The special key
    ``rates`` with an empty value introduces a nested block, which is
    delegated to ``_parse_rates_block``. Blank and comment-only lines are
    skipped.

    Raises:
        ValueError: if a top-level line is indented, or a helper rejects a line.
    """
    lines: list[tuple[int, str]] = [
        normalised
        for normalised in map(_normalise_yaml_line, text.splitlines())
        if normalised is not None
    ]

    document: dict[str, Any] = {}
    position = 0
    while position < len(lines):
        indent, content = lines[position]
        if indent:
            raise ValueError("Only top-level mappings are supported in rate YAML")
        key, raw_value = _split_yaml_key_value(content)
        if key == "rates" and not raw_value:
            document["rates"], position = _parse_rates_block(lines, position + 1)
        else:
            document[key] = _parse_yaml_scalar(raw_value)
            position += 1
    return document
|
|
|
|
|
|
def _parse_rates_block(
    lines: list[tuple[int, str]],
    index: int,
) -> tuple[dict[str, dict[str, Any]], int]:
    """Parse the model entries nested under ``rates:``.

    Starting at *index*, consumes model-id lines indented by exactly two
    spaces, each followed by field lines indented by exactly four spaces.
    Parsing stops at the next unindented line.

    Returns:
        ``(rates, next_index)``: the parsed entries and the index of the first
        line not consumed.

    Raises:
        ValueError: on unexpected indentation, or when a model line carries an
            inline value instead of a nested mapping.
    """
    rates: dict[str, dict[str, Any]] = {}
    total = len(lines)
    while index < total:
        indent, content = lines[index]
        if indent == 0:
            break
        if indent != 2:
            raise ValueError("Rate model entries must be indented by two spaces")
        model_id, raw_value = _split_yaml_key_value(content)
        if raw_value:
            raise ValueError(f"Rate entry for {model_id!r} must be a nested mapping")

        index += 1
        fields: dict[str, Any] = {}
        # Consume the field lines that sit deeper than the model-id line.
        while index < total and lines[index][0] > indent:
            child_indent, child_content = lines[index]
            if child_indent != 4:
                raise ValueError("Rate fields must be indented by four spaces")
            field_name, field_value = _split_yaml_key_value(child_content)
            fields[field_name] = _parse_yaml_scalar(field_value)
            index += 1
        rates[model_id] = fields
    return rates, index
|
|
|
|
|
|
def _normalise_yaml_line(line: str) -> tuple[int, str] | None:
    """Split one raw YAML line into ``(indent, content)``.

    Trailing whitespace and comments are removed first. Blank and
    comment-only lines return ``None``. Indentation is the count of leading
    spaces.

    Raises:
        ValueError: if the leading whitespace contains a tab or other
            non-space character. YAML forbids tabs in indentation, and
            previously such a line was silently read as indent 0. That
            detached nested rate fields and loaded an empty ``rates`` table.
    """
    stripped = _strip_yaml_comment(line.rstrip())
    content = stripped.strip()
    if not content:
        return None
    indent = len(stripped) - len(stripped.lstrip(" "))
    if stripped[indent].isspace():
        raise ValueError(f"Rate YAML indentation must use spaces, not tabs: {line!r}")
    return indent, content
|
|
|
|
|
|
def _strip_yaml_comment(line: str) -> str:
    """Return *line* with any trailing YAML comment removed.

    The scanner follows the YAML rules the rate tables rely on:

    * ``#`` starts a comment only at the start of the line or after
      whitespace, so plain values such as ``https://host/models#pricing``
      are kept intact;
    * ``#`` inside a single- or double-quoted span is literal;
    * a quote character opens a quoted span only at the start of a token
      (line start, after whitespace, or directly after ``:``), so an
      apostrophe inside plain text such as ``Bob's`` does not hide a later
      comment;
    * inside single quotes ``''`` is an escaped quote, and inside double
      quotes a backslash escapes the following character.
    """
    quote: str | None = None
    index = 0
    while index < len(line):
        char = line[index]
        previous = line[index - 1] if index else ""
        if quote is not None:
            if quote == '"' and char == "\\":
                index += 1  # skip the escaped character
            elif char == quote:
                if quote == "'" and line[index + 1 : index + 2] == "'":
                    index += 1  # '' is a literal quote inside single quotes
                else:
                    quote = None
        elif char in {"'", '"'}:
            if not previous or previous.isspace() or previous == ":":
                quote = char
        elif char == "#" and (not previous or previous.isspace()):
            return line[:index]
        index += 1
    return line
|
|
|
|
|
|
def _split_yaml_key_value(content: str) -> tuple[str, str]:
    """Split a ``key: value`` line into ``(key, raw_value)``.

    A quoted key (``"..."`` or ``'...'``) ends at its closing quote, so colons
    inside it are kept. For an unquoted key, the separator is the first ``:``
    followed by whitespace or the end of the line, as in YAML. This keeps
    OpenRouter ids such as ``openai/gpt-4o:free`` whole. If no such colon
    exists, the first ``:`` is used, which preserves the parser's lenient
    handling of ``key:value``.

    Raises:
        ValueError: if *content* contains no ``:`` at all.
    """
    if content[:1] in {"'", '"'}:
        closing = content.find(content[0], 1)
        if closing != -1:
            remainder = content[closing + 1 :].lstrip()
            if remainder.startswith(":"):
                return content[1:closing], remainder[1:].strip()

    separator = next(
        (
            position
            for position, char in enumerate(content)
            if char == ":" and content[position + 1 : position + 2] in {"", " ", "\t"}
        ),
        content.find(":"),
    )
    if separator == -1:
        raise ValueError(f"Invalid YAML mapping line: {content!r}")
    key = content[:separator].strip().strip("'\"")
    return key, content[separator + 1 :].strip()
|
|
|
|
|
|
def _parse_yaml_scalar(value: str) -> Any:
    """Convert a raw scalar from the fallback YAML parser to a Python value.

    Rules, applied in order:

    * ``""`` stays ``""``;
    * a value wrapped in matching quotes (at least two characters) is
      unquoted, and single-quoted values decode the YAML ``''`` escape to
      ``'``;
    * ``null`` / ``none`` / ``~`` (any case) become ``None``;
    * values containing ``.``, ``e`` or ``E`` are tried as ``float``, and
      others as ``int``;
    * anything that fails conversion is returned unchanged as a string.
    """
    if value == "":
        return ""
    # Require two characters so a lone quote is not mistaken for "".
    if len(value) >= 2 and value[0] == value[-1] and value[0] in {"'", '"'}:
        inner = value[1:-1]
        return inner.replace("''", "'") if value[0] == "'" else inner
    if value.lower() in {"null", "none", "~"}:
        return None
    try:
        if any(char in value for char in (".", "e", "E")):
            return float(value)
        return int(value)
    except ValueError:
        return value
|