Implement-LLM-WP-0005-cost-model-estimators
Some checks failed
CI / test (3.10) (push) Has been cancelled
CI / test (3.11) (push) Has been cancelled
CI / test (3.12) (push) Has been cancelled

This commit is contained in:
2026-05-19 05:02:20 +02:00
parent 0054afe689
commit c11c6afa3f
16 changed files with 1525 additions and 10 deletions

View File

@@ -15,6 +15,7 @@ Quick start::
from llm_connect.adapter import ErrorLLMAdapter, LLMAdapter, MockLLMAdapter
from llm_connect.claude_code import ClaudeCodeAdapter
from llm_connect.config import LLMConfig, load_config
from llm_connect.costs import CostEstimate, CostModel, estimate_cost
from llm_connect.embedding_adapter import EmbeddingAdapter
from llm_connect.embedding_cache import EmbeddingCache
from llm_connect.embedding_factory import create_embedding_adapter
@@ -42,7 +43,20 @@ from llm_connect.grading import (
from llm_connect.models import BudgetTracker, LLMResponse, RunConfig
from llm_connect.openai import OpenAIAdapter
from llm_connect.openrouter import OpenRouterAdapter
from llm_connect.problem_classes import (
ChunkSummarizationProblemClass,
EntityExtractionProblemClass,
JudgeEvalProblemClass,
Observation,
ProblemClass,
ProblemClassRegistry,
RelationExtractionProblemClass,
ReportSynthesisProblemClass,
TokenEstimate,
default_problem_class_registry,
)
from llm_connect.quality import QualityLedger, QualityObservation, is_stale
from llm_connect.rates import ModelRate, ModelRateRegistry
from llm_connect.routing import AdaptiveRoutingPolicy, RoutingPolicy, RoutingRule
from llm_connect.server import LLMServer
from llm_connect.shadowing import ShadowingAdapter
@@ -95,4 +109,19 @@ __all__ = [
"AdaptiveRoutingPolicy",
"ShadowingAdapter",
"LLMServer",
"ModelRate",
"ModelRateRegistry",
"CostEstimate",
"CostModel",
"estimate_cost",
"TokenEstimate",
"Observation",
"ProblemClass",
"ProblemClassRegistry",
"default_problem_class_registry",
"ChunkSummarizationProblemClass",
"EntityExtractionProblemClass",
"RelationExtractionProblemClass",
"JudgeEvalProblemClass",
"ReportSynthesisProblemClass",
]

143
llm_connect/cli.py Normal file
View File

@@ -0,0 +1,143 @@
"""Command-line helpers for llm-connect registries."""
from __future__ import annotations
import argparse
import json
from collections.abc import Iterable, Mapping
from pathlib import Path
from typing import Any
from llm_connect.problem_classes import ProblemClass, ProblemClassRegistry
from llm_connect.quality import QualityLedger
from llm_connect.rates import ModelRateRegistry
def main(argv: list[str] | None = None) -> int:
    """Parse *argv* and dispatch to the handler of the selected sub-command.

    Args:
        argv: Argument list to parse; ``None`` lets ``argparse`` fall back to
            the process arguments.

    Returns:
        The handler's result converted to ``int``.
    """
    namespace = _build_parser().parse_args(argv)
    handler = namespace.func
    return int(handler(namespace))
def _build_parser() -> argparse.ArgumentParser:
    """Assemble the ``llm-connect`` parser with its ``rates`` and ``classes`` groups.

    Each leaf command stores its handler under ``func`` in the parsed
    namespace so the caller can dispatch without inspecting command names.
    """
    root = argparse.ArgumentParser(prog="llm-connect")
    groups = root.add_subparsers(dest="command", required=True)

    # llm-connect rates show [--rates PATH] [--json]
    rates_group = groups.add_parser("rates", help="Inspect model rate registries")
    rates_actions = rates_group.add_subparsers(dest="rates_command", required=True)
    show_rates = rates_actions.add_parser("show", help="Show model rates")
    show_rates.add_argument("--rates", type=Path, help="YAML registry overlay")
    show_rates.add_argument("--json", action="store_true", help="Emit JSON")
    show_rates.set_defaults(func=_rates_show)

    # llm-connect classes show [--json]
    classes_group = groups.add_parser("classes", help="Inspect problem classes")
    classes_actions = classes_group.add_subparsers(dest="classes_command", required=True)
    show_classes = classes_actions.add_parser("show", help="Show problem classes")
    show_classes.add_argument("--json", action="store_true", help="Emit JSON")
    show_classes.set_defaults(func=_classes_show)

    # llm-connect classes fit LEDGER [--class NAME] [--min-observations N] [--json]
    fit_classes = classes_actions.add_parser(
        "fit",
        help="Fit problem-class params from a ledger",
    )
    fit_classes.add_argument("ledger", type=Path, help="QualityLedger JSONL path")
    fit_classes.add_argument("--class", dest="class_name", help="Fit one class by name")
    fit_classes.add_argument("--min-observations", type=int, default=3)
    fit_classes.add_argument("--json", action="store_true", help="Emit JSON")
    fit_classes.set_defaults(func=_classes_fit)

    return root
def _rates_show(args: argparse.Namespace) -> int:
    """Print the model rate table, optionally overlaid with a YAML registry.

    Reads ``args.rates`` (optional overlay path) and ``args.json`` (output
    format switch). Rows are emitted in ascending ``model_id`` order.

    Returns:
        ``0`` after printing.
    """
    registry = ModelRateRegistry.default()
    overlay_path = args.rates
    if overlay_path:
        overlay = ModelRateRegistry.from_yaml(overlay_path)
        registry = registry.merged_with(overlay)
    ordered = sorted(registry.all().items())

    if not args.json:
        print("model_id\tprompt_per_1k\tcompletion_per_1k\tcurrency\tcaptured_at")
        for model_id, rate in ordered:
            print(
                f"{model_id}\t{rate.prompt_per_1k:g}\t{rate.completion_per_1k:g}\t"
                f"{rate.currency}\t{rate.captured_at}"
            )
        return 0

    payload = {}
    for model_id, rate in ordered:
        payload[model_id] = {
            "prompt_per_1k": rate.prompt_per_1k,
            "completion_per_1k": rate.completion_per_1k,
            "currency": rate.currency,
            "source_url": rate.source_url,
            "captured_at": rate.captured_at,
        }
    print(json.dumps(payload, indent=2, sort_keys=True))
    return 0
def _classes_show(args: argparse.Namespace) -> int:
    """Print the built-in problem classes as JSON or as a tab-separated table.

    Returns:
        ``0`` after printing.
    """
    registered = ProblemClassRegistry.default().all()
    if args.json:
        payload = _classes_payload(registered.values())
        print(json.dumps(payload, indent=2, sort_keys=True))
        return 0

    print("name\tdimensions\ttunable_params\tcurrent_params")
    ordered = sorted(registered.values(), key=lambda entry: entry.name)
    for entry in ordered:
        dimensions = ", ".join(entry.base_dimensions)
        tunables = ", ".join(entry.tunable_params)
        current = _format_params(entry.params)
        print(f"{entry.name}\t{dimensions}\t{tunables}\t{current}")
    return 0
def _classes_fit(args: argparse.Namespace) -> int:
    """Fit problem-class parameters from a ledger file and print the result.

    Reads ``args.ledger``, ``args.class_name`` (optional single class),
    ``args.min_observations`` and ``args.json``.

    Returns:
        ``0`` after printing.

    Raises:
        SystemExit: If ``--min-observations`` is not positive or the requested
            class name is not registered.
    """
    minimum = args.min_observations
    if minimum <= 0:
        raise SystemExit("--min-observations must be positive")

    registry = ProblemClassRegistry.default()
    selected: list[ProblemClass]
    if args.class_name:
        chosen = registry.get(args.class_name)
        if chosen is None:
            raise SystemExit(f"Unknown problem class: {args.class_name}")
        selected = [chosen]
    else:
        selected = list(registry.all().values())

    observations = QualityLedger(args.ledger).read_all()
    fitted: list[ProblemClass] = []
    for candidate in selected:
        fitted.append(candidate.fit(observations, min_observations=minimum))

    if args.json:
        print(json.dumps(_classes_payload(fitted), indent=2, sort_keys=True))
        return 0

    print("name\tfitted_params\tconfidence")
    for entry in sorted(fitted, key=lambda item: item.name):
        # Not every estimator is guaranteed to expose ``confidence``; default to 0.5.
        confidence = getattr(entry, "confidence", 0.5)
        print(f"{entry.name}\t{_format_params(entry.params)}\t{confidence:g}")
    return 0
def _classes_payload(classes: Iterable[ProblemClass]) -> dict[str, dict[str, Any]]:
return {
problem_class.name: {
"base_dimensions": list(problem_class.base_dimensions),
"tunable_params": list(problem_class.tunable_params),
"params": dict(problem_class.params),
"confidence": getattr(problem_class, "confidence", 0.5),
}
for problem_class in sorted(classes, key=lambda item: item.name)
}
def _format_params(params: Mapping[str, float]) -> str:
return ", ".join(f"{key}={value:g}" for key, value in sorted(dict(params).items()))
# Script entry point: run the CLI and use main()'s return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())

74
llm_connect/costs.py Normal file
View File

@@ -0,0 +1,74 @@
"""Cost estimation over model rates and token counts."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from llm_connect.rates import ModelRateRegistry
@dataclass(frozen=True)
class CostEstimate:
    """Cost estimate split by prompt and completion token spend.

    Immutable. ``estimate_cost`` in this module returns an instance with
    ``cost_usd=None`` and ``cost_source="unknown"`` when no rate is found, and
    otherwise fills all four fields.
    """

    # Total estimated cost (prompt + completion), or None when no rate is known.
    cost_usd: float | None
    # Provenance label: "unknown" or "rate_table:<model_id>" as produced by estimate_cost.
    cost_source: str
    # Prompt-token share of the total; None unless a rate was found.
    prompt_cost_usd: float | None = None
    # Completion-token share of the total; None unless a rate was found.
    completion_cost_usd: float | None = None
def estimate_cost(
    model_id: str,
    prompt_tokens: int,
    completion_tokens: int = 0,
    *,
    registry: ModelRateRegistry | None = None,
) -> CostEstimate:
    """Estimate USD cost for token counts using *registry*.

    Unknown models return ``CostEstimate(None, "unknown")`` so callers can
    record uncertainty explicitly instead of treating missing prices as zero.

    Args:
        model_id: Identifier looked up in the rate registry.
        prompt_tokens: Non-negative prompt token count.
        completion_tokens: Non-negative completion token count.
        registry: Rate registry to consult; ``None`` selects
            ``ModelRateRegistry.default()``.

    Returns:
        A ``CostEstimate`` with total, prompt and completion costs, tagged with
        ``rate_table:<model_id>``; or the "unknown" estimate when the registry
        has no rate for *model_id*.

    Raises:
        ValueError: If a token count is not a non-negative integer.
    """
    prompt_count = _non_negative_int("prompt_tokens", prompt_tokens)
    completion_count = _non_negative_int("completion_tokens", completion_tokens)

    # Compare against None explicitly: a caller-supplied registry must be
    # honoured even if it is (or later becomes) falsy, e.g. an empty registry.
    rates = ModelRateRegistry.default() if registry is None else registry
    rate = rates.get(model_id)
    if rate is None:
        return CostEstimate(cost_usd=None, cost_source="unknown")

    # Rates are quoted per 1,000 tokens.
    prompt_cost = (prompt_count / 1000.0) * rate.prompt_per_1k
    completion_cost = (completion_count / 1000.0) * rate.completion_per_1k
    return CostEstimate(
        cost_usd=prompt_cost + completion_cost,
        cost_source=f"rate_table:{rate.model_id}",
        prompt_cost_usd=prompt_cost,
        completion_cost_usd=completion_cost,
    )
@dataclass(frozen=True)
class CostModel:
    """Object-style facade over the module-level ``estimate_cost`` function.

    Holds an optional rate registry that is forwarded on every call.
    """

    registry: ModelRateRegistry | None = None

    def estimate_cost(
        self,
        model_id: str,
        prompt_tokens: int,
        completion_tokens: int = 0,
    ) -> CostEstimate:
        """Delegate to the module-level ``estimate_cost`` with ``self.registry``."""
        # Inside the method body the bare name resolves to the module-level
        # function, not to this method.
        bound_registry = self.registry
        return estimate_cost(
            model_id,
            prompt_tokens,
            completion_tokens,
            registry=bound_registry,
        )
def _non_negative_int(name: str, value: Any) -> int:
if isinstance(value, bool) or not isinstance(value, int) or value < 0:
raise ValueError(f"{name} must be a non-negative integer")
return value

View File

@@ -0,0 +1,463 @@
"""Problem-class token estimators for common LLM workflow shapes."""
from __future__ import annotations
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import Any, Protocol
# Words represented by one token; word counts are divided by this value to
# obtain token estimates (see _words_to_tokens).
DEFAULT_WORDS_PER_TOKEN = 0.75
@dataclass(frozen=True)
class TokenEstimate:
    """Estimated prompt and completion token counts for a planned LLM call.

    ``confidence`` must lie in ``[0, 1]``; token counts must be non-negative
    integers. All three fields are validated and normalised on construction.
    """

    prompt_tokens: int
    completion_tokens: int
    confidence: float = 0.5

    def __post_init__(self) -> None:
        """Validate every field, then store the normalised values."""
        # Validate everything first so nothing is written if a field is invalid.
        normalised = {
            "prompt_tokens": _non_negative_int("prompt_tokens", self.prompt_tokens),
            "completion_tokens": _non_negative_int(
                "completion_tokens", self.completion_tokens
            ),
            "confidence": _bounded_float("confidence", self.confidence),
        }
        # The dataclass is frozen, so bypass its __setattr__ guard.
        for field_name, field_value in normalised.items():
            object.__setattr__(self, field_name, field_value)
@dataclass(frozen=True)
class Observation:
    """Measured token usage together with the problem dimensions behind it."""

    dimensions: dict[str, Any]
    prompt_tokens: int
    completion_tokens: int

    def __post_init__(self) -> None:
        """Copy ``dimensions`` and normalise both token counts."""
        # The dataclass is frozen, so write through object.__setattr__.
        store = object.__setattr__
        # Own a private copy so later mutation of the caller's dict has no effect.
        store(self, "dimensions", dict(self.dimensions))
        prompt = _non_negative_int("prompt_tokens", self.prompt_tokens)
        store(self, "prompt_tokens", prompt)
        completion = _non_negative_int("completion_tokens", self.completion_tokens)
        store(self, "completion_tokens", completion)
class ProblemClass(Protocol):
    """Estimator contract implemented by built-in and consumer classes.

    This is a structural (``typing.Protocol``) interface: any object exposing
    these attributes and methods satisfies it without inheriting from it.
    """

    # Registry key; ProblemClassRegistry strips surrounding whitespace before use.
    name: str
    # Names of the dimensions an estimate call must supply.
    base_dimensions: tuple[str, ...]
    # Names of the parameters that may be overridden or fitted.
    tunable_params: tuple[str, ...]
    # Current parameter values.
    params: dict[str, float]

    def estimate(
        self,
        dimensions: dict[str, Any],
        params: dict[str, Any] | None = None,
    ) -> TokenEstimate:
        """Estimate token use from dimensions and optional parameter overrides."""
        ...

    def fit(
        self,
        observations: Sequence[Any],
        *,
        min_observations: int = 3,
    ) -> "ProblemClass":
        """Return an estimator with params adapted from observed token use."""
        ...
class ProblemClassRegistry:
    """Name-indexed collection of problem-class estimators."""

    schema_version = 1

    def __init__(self, classes: Sequence[ProblemClass] | None = None) -> None:
        """Create a registry and register each entry of *classes*, if given."""
        self._classes: dict[str, ProblemClass] = {}
        if classes:
            for entry in classes:
                self.register(entry)

    def get(self, name: str) -> ProblemClass | None:
        """Look up a class by *name* (whitespace-stripped); ``None`` if absent."""
        key = str(name).strip()
        return self._classes.get(key)

    def all(self) -> dict[str, ProblemClass]:
        """Return a shallow copy of the name-to-class mapping."""
        return dict(self._classes)

    def register(self, problem_class: ProblemClass, *, replace: bool = False) -> None:
        """Store *problem_class* under its stripped ``name``.

        Raises:
            ValueError: If the name is empty, or already registered while
                *replace* is false.
        """
        key = str(problem_class.name).strip()
        if not key:
            raise ValueError("problem_class.name must be a non-empty string")
        if not replace and key in self._classes:
            raise ValueError(f"Problem class {key!r} is already registered")
        self._classes[key] = problem_class

    @classmethod
    def default(cls) -> "ProblemClassRegistry":
        """Build a registry holding one fresh instance of each built-in class."""
        builtin_types = (
            ChunkSummarizationProblemClass,
            EntityExtractionProblemClass,
            RelationExtractionProblemClass,
            JudgeEvalProblemClass,
            ReportSynthesisProblemClass,
        )
        return cls([builtin() for builtin in builtin_types])
class _BaseProblemClass:
    """Shared estimate/fit machinery for the built-in problem classes.

    Subclasses set ``name``, ``base_dimensions``, ``tunable_params`` and
    ``seed_params`` and implement ``_estimate_tokens`` and ``_infer_param``.
    """

    name = ""
    base_dimensions: tuple[str, ...] = ()
    tunable_params: tuple[str, ...] = ()
    seed_params: Mapping[str, float] = {}

    def __init__(
        self,
        *,
        params: Mapping[str, Any] | None = None,
        confidence: float = 0.5,
    ) -> None:
        """Start from ``seed_params`` and apply validated *params* overrides.

        Raises:
            ValueError: If an override names a non-tunable parameter, a value
                is not a non-negative number, or *confidence* is outside
                ``[0, 1]``.
        """
        self.params: dict[str, float] = self._merge_params(self.seed_params, params)
        self.confidence = _bounded_float("confidence", confidence)

    def _merge_params(
        self,
        base: Mapping[str, float],
        overrides: Mapping[str, Any] | None,
    ) -> dict[str, float]:
        """Return a copy of *base* updated with validated *overrides*."""
        merged = dict(base)
        for key, value in (overrides or {}).items():
            if key not in self.tunable_params:
                raise ValueError(f"Unknown parameter {key!r} for problem class {self.name!r}")
            merged[key] = _non_negative_float(key, value)
        return merged

    def estimate(
        self,
        dimensions: dict[str, Any],
        params: dict[str, Any] | None = None,
    ) -> TokenEstimate:
        """Estimate token use for *dimensions*, with optional per-call overrides.

        Args:
            dimensions: Must contain every name in ``base_dimensions`` with a
                non-negative numeric value.
            params: Optional overrides for tunable parameters; they apply to
                this call only and do not modify ``self.params``.

        Returns:
            A ``TokenEstimate`` carrying this instance's ``confidence``.

        Raises:
            ValueError: On missing or invalid dimensions, or invalid overrides.
        """
        dimensions = dict(dimensions)
        self._validate_dimensions(dimensions)
        merged_params = self._merge_params(self.params, params)
        prompt_tokens, completion_tokens = self._estimate_tokens(dimensions, merged_params)
        return TokenEstimate(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            confidence=self.confidence,
        )

    def fit(
        self,
        observations: Sequence[Any],
        *,
        min_observations: int = 3,
    ) -> ProblemClass:
        """Return an estimator whose tunable params are averaged from *observations*.

        Observations that cannot be coerced for this class are ignored. When
        fewer than *min_observations* remain, or no parameter can be inferred,
        ``self`` is returned unchanged.

        Raises:
            ValueError: If *min_observations* is not positive.
        """
        if min_observations <= 0:
            raise ValueError("min_observations must be positive")
        parsed = [
            observation
            for observation in (
                _coerce_observation(raw, self.name, self.base_dimensions) for raw in observations
            )
            if observation is not None
        ]
        if len(parsed) < min_observations:
            return self
        fitted: dict[str, float] = {}
        for param in self.tunable_params:
            values = [
                value
                for value in (
                    self._infer_param(param, observation) for observation in parsed
                )
                if value is not None
            ]
            if values:
                fitted[param] = sum(values) / len(values)
        if not fitted:
            return self
        # Confidence grows with the sample size (n / (n + 5)), never drops
        # below the current value, and is capped at 0.95.
        confidence = min(0.95, max(self.confidence, len(parsed) / (len(parsed) + 5)))
        return type(self)(params={**self.params, **fitted}, confidence=confidence)

    def _validate_dimensions(self, dimensions: Mapping[str, Any]) -> None:
        """Raise ``ValueError`` unless every base dimension is present and valid."""
        missing = [name for name in self.base_dimensions if name not in dimensions]
        if missing:
            raise ValueError(f"Missing dimensions for {self.name!r}: {', '.join(missing)}")
        for name in self.base_dimensions:
            _non_negative_float(name, dimensions[name])

    def _estimate_tokens(
        self,
        dimensions: Mapping[str, Any],
        params: Mapping[str, float],
    ) -> tuple[int, int]:
        """Return ``(prompt_tokens, completion_tokens)``; subclasses implement this."""
        raise NotImplementedError

    def _infer_param(self, param: str, observation: Observation) -> float | None:
        """Return the value of *param* implied by *observation*, or ``None``."""
        raise NotImplementedError
class ChunkSummarizationProblemClass(_BaseProblemClass):
    """Estimator for summarising one chunk: completion scales with prompt size."""

    name = "chunk-summarization"
    base_dimensions: tuple[str, ...] = ("chunk_words", "template_words")
    tunable_params: tuple[str, ...] = ("completion_ratio",)
    seed_params: Mapping[str, float] = {"completion_ratio": 0.25}

    def _estimate_tokens(
        self,
        dimensions: Mapping[str, Any],
        params: Mapping[str, float],
    ) -> tuple[int, int]:
        """Prompt covers chunk plus template words; completion is a ratio of the prompt."""
        chunk_words = _dimension(dimensions, "chunk_words")
        template_words = _dimension(dimensions, "template_words")
        prompt = _words_to_tokens(chunk_words + template_words)
        completion = _round_tokens(prompt * params["completion_ratio"])
        return prompt, completion

    def _infer_param(self, param: str, observation: Observation) -> float | None:
        """Infer ``completion_ratio`` as completion/prompt; ``None`` when undefined."""
        if param == "completion_ratio" and observation.prompt_tokens != 0:
            return observation.completion_tokens / observation.prompt_tokens
        return None
class EntityExtractionProblemClass(_BaseProblemClass):
    """Estimator for entity extraction: completion scales with expected entities."""

    name = "entity-extraction"
    base_dimensions: tuple[str, ...] = ("chunk_words", "template_words", "expected_entities")
    tunable_params: tuple[str, ...] = ("tokens_per_entity",)
    seed_params: Mapping[str, float] = {"tokens_per_entity": 70.0}

    def _estimate_tokens(
        self,
        dimensions: Mapping[str, Any],
        params: Mapping[str, float],
    ) -> tuple[int, int]:
        """Prompt covers chunk plus template words; completion is entities x rate."""
        chunk_words = _dimension(dimensions, "chunk_words")
        template_words = _dimension(dimensions, "template_words")
        prompt = _words_to_tokens(chunk_words + template_words)
        entity_count = _dimension(dimensions, "expected_entities")
        completion = _round_tokens(entity_count * params["tokens_per_entity"])
        return prompt, completion

    def _infer_param(self, param: str, observation: Observation) -> float | None:
        """Infer ``tokens_per_entity`` as completion tokens per expected entity."""
        # The dimension is read (and validated) before the parameter name is checked.
        entity_count = _dimension(observation.dimensions, "expected_entities")
        unusable = param != "tokens_per_entity" or entity_count <= 0
        return None if unusable else observation.completion_tokens / entity_count
class RelationExtractionProblemClass(_BaseProblemClass):
    """Estimator for relation extraction: completion scales with expected relations."""

    name = "relation-extraction"
    base_dimensions: tuple[str, ...] = ("chunk_words", "template_words", "expected_relations")
    tunable_params: tuple[str, ...] = ("tokens_per_relation",)
    seed_params: Mapping[str, float] = {"tokens_per_relation": 80.0}

    def _estimate_tokens(
        self,
        dimensions: Mapping[str, Any],
        params: Mapping[str, float],
    ) -> tuple[int, int]:
        """Prompt covers chunk plus template words; completion is relations x rate."""
        chunk_words = _dimension(dimensions, "chunk_words")
        template_words = _dimension(dimensions, "template_words")
        prompt = _words_to_tokens(chunk_words + template_words)
        relation_count = _dimension(dimensions, "expected_relations")
        completion = _round_tokens(relation_count * params["tokens_per_relation"])
        return prompt, completion

    def _infer_param(self, param: str, observation: Observation) -> float | None:
        """Infer ``tokens_per_relation`` as completion tokens per expected relation."""
        # The dimension is read (and validated) before the parameter name is checked.
        relation_count = _dimension(observation.dimensions, "expected_relations")
        unusable = param != "tokens_per_relation" or relation_count <= 0
        return None if unusable else observation.completion_tokens / relation_count
class JudgeEvalProblemClass(_BaseProblemClass):
    """Estimator for judge evaluations: completion scales with criteria count."""

    name = "judge-eval"
    base_dimensions: tuple[str, ...] = ("artifact_words", "template_words", "n_criteria")
    tunable_params: tuple[str, ...] = ("tokens_per_criterion",)
    seed_params: Mapping[str, float] = {"tokens_per_criterion": 35.0}

    def _estimate_tokens(
        self,
        dimensions: Mapping[str, Any],
        params: Mapping[str, float],
    ) -> tuple[int, int]:
        """Prompt covers artifact plus template words; completion is criteria x rate."""
        artifact_words = _dimension(dimensions, "artifact_words")
        template_words = _dimension(dimensions, "template_words")
        prompt = _words_to_tokens(artifact_words + template_words)
        criteria_count = _dimension(dimensions, "n_criteria")
        completion = _round_tokens(criteria_count * params["tokens_per_criterion"])
        return prompt, completion

    def _infer_param(self, param: str, observation: Observation) -> float | None:
        """Infer ``tokens_per_criterion`` as completion tokens per criterion."""
        # The dimension is read (and validated) before the parameter name is checked.
        criteria_count = _dimension(observation.dimensions, "n_criteria")
        unusable = param != "tokens_per_criterion" or criteria_count <= 0
        return None if unusable else observation.completion_tokens / criteria_count
class ReportSynthesisProblemClass(_BaseProblemClass):
    """Estimator for report synthesis: fixed completion size, itemised prompt size."""

    name = "report-synthesis"
    base_dimensions: tuple[str, ...] = ("n_chunks", "n_entities", "n_relations", "template_words")
    tunable_params: tuple[str, ...] = ("base_completion_tokens",)
    seed_params: Mapping[str, float] = {"base_completion_tokens": 400.0}

    def _estimate_tokens(
        self,
        dimensions: Mapping[str, Any],
        params: Mapping[str, float],
    ) -> tuple[int, int]:
        """Prompt is the template plus a per-item allowance; completion is a constant."""
        prompt = _words_to_tokens(_dimension(dimensions, "template_words"))
        # Prompt tokens added per chunk, entity and relation respectively.
        per_item_tokens = (("n_chunks", 40), ("n_entities", 25), ("n_relations", 35))
        for dimension_name, weight in per_item_tokens:
            prompt += _round_tokens(_dimension(dimensions, dimension_name) * weight)
        completion = _round_tokens(params["base_completion_tokens"])
        return prompt, completion

    def _infer_param(self, param: str, observation: Observation) -> float | None:
        """Infer ``base_completion_tokens`` as the observed completion token count."""
        if param == "base_completion_tokens":
            return float(observation.completion_tokens)
        return None
def default_problem_class_registry() -> ProblemClassRegistry:
    """Create a registry pre-populated with the built-in problem classes.

    Function-style alias for ``ProblemClassRegistry.default()``.
    """
    registry = ProblemClassRegistry.default()
    return registry
def _coerce_observation(
    raw: Any,
    class_name: str,
    required_dimensions: tuple[str, ...],
) -> Observation | None:
    """Convert *raw* into an ``Observation`` usable for *class_name*, or ``None``.

    Accepts ``Observation`` instances, mappings, and objects exposing ``tags``,
    ``tokens_in`` and ``tokens_out``. Unusable records are skipped (``None``)
    rather than raised, so one bad ledger row cannot abort a fit.

    Args:
        raw: Candidate observation in any supported shape.
        class_name: Problem-class name the observation must belong to.
        required_dimensions: Dimension names that must be present with
            non-negative numeric values.

    Returns:
        The coerced observation, or ``None`` when *raw* targets another class,
        lacks required data, or carries invalid values.
    """
    try:
        if isinstance(raw, Observation):
            observation: Observation | None = raw
        elif isinstance(raw, Mapping):
            observation = _coerce_mapping_observation(raw, class_name, required_dimensions)
        else:
            observation = _coerce_object_observation(raw, class_name, required_dimensions)
        if observation is None:
            return None
        # Check the dimensions up front on every path: ready-made Observation
        # instances and nested "dimensions" mappings are otherwise unchecked
        # and would crash later inside _infer_param.
        for name in required_dimensions:
            _non_negative_float(name, observation.dimensions[name])
        return observation
    except (AttributeError, KeyError, TypeError, ValueError):
        # AttributeError covers objects that lack tokens_in / tokens_out.
        return None
def _coerce_mapping_observation(
    raw: Mapping[str, Any],
    class_name: str,
    required_dimensions: tuple[str, ...],
) -> Observation | None:
    """Build an ``Observation`` from a mapping-shaped record.

    The problem class is read from ``raw["problem_class"]`` or, failing that,
    from ``raw["tags"]["problem_class"]``. Records naming a different class
    yield ``None``; records naming no class are accepted.

    Raises:
        KeyError: If no recognised token-count key is present.
        ValueError: If dimensions are missing or token counts are invalid.
    """
    tags: Mapping[str, Any] = {}
    candidate_tags = raw.get("tags")
    if isinstance(candidate_tags, Mapping):
        tags = candidate_tags

    declared_class = raw.get("problem_class") or tags.get("problem_class")
    if declared_class is not None and str(declared_class) != class_name:
        return None

    return Observation(
        _dimensions_from_sources(required_dimensions, raw, tags),
        _token_value(raw, "prompt_tokens", "tokens_in", "actual_prompt_tokens"),
        _token_value(raw, "completion_tokens", "tokens_out", "actual_completion_tokens"),
    )
def _coerce_object_observation(
    raw: Any,
    class_name: str,
    required_dimensions: tuple[str, ...],
) -> Observation | None:
    """Build an ``Observation`` from an object with ``tags``/``tokens_in``/``tokens_out``.

    Dimensions and the problem class are read from ``raw.tags`` only. Objects
    tagged with a different class yield ``None``.

    Raises:
        AttributeError: If *raw* lacks ``tokens_in`` or ``tokens_out``.
        ValueError: If dimensions are missing or token counts are invalid.
    """
    tags: Mapping[str, Any] = {}
    candidate_tags = getattr(raw, "tags", {}) or {}
    if isinstance(candidate_tags, Mapping):
        tags = candidate_tags

    declared_class = tags.get("problem_class")
    if declared_class is not None and str(declared_class) != class_name:
        return None

    dimensions = _dimensions_from_sources(required_dimensions, tags)
    prompt = raw.tokens_in
    completion = raw.tokens_out
    return Observation(
        dimensions=dimensions,
        prompt_tokens=prompt,
        completion_tokens=completion,
    )
def _dimensions_from_sources(
required_dimensions: tuple[str, ...],
*sources: Mapping[str, Any],
) -> dict[str, Any]:
for source in sources:
candidate = source.get("dimensions")
if isinstance(candidate, Mapping):
return dict(candidate)
dimensions: dict[str, Any] = {}
for name in required_dimensions:
for source in sources:
if name in source:
dimensions[name] = source[name]
break
if len(dimensions) != len(required_dimensions):
raise ValueError("observation is missing required dimensions")
return dimensions
def _token_value(raw: Mapping[str, Any], *names: str) -> int:
    """Return the first token count found under any of *names*.

    Top-level keys of *raw* are searched first, in the order given; a nested
    ``raw["usage"]`` mapping is searched next when present.

    Raises:
        KeyError: With ``names[0]`` when none of the names is found.
        ValueError: If the value found is not a non-negative integer.
    """
    containers: list[Mapping[str, Any]] = [raw]
    nested_usage = raw.get("usage")
    if isinstance(nested_usage, Mapping):
        containers.append(nested_usage)

    for container in containers:
        for name in names:
            if name in container:
                return _non_negative_int(name, container[name])
    raise KeyError(names[0])
def _dimension(dimensions: Mapping[str, Any], name: str) -> float:
    """Fetch ``dimensions[name]`` and return it as a validated non-negative float.

    Raises:
        KeyError: If *name* is absent.
        ValueError: If the value is not a non-negative number.
    """
    raw_value = dimensions[name]
    return _non_negative_float(name, raw_value)
def _words_to_tokens(words: float) -> int:
    """Convert a word count to a token estimate using ``DEFAULT_WORDS_PER_TOKEN``.

    Zero words give zero tokens; any other count gives at least one token.
    """
    if words == 0:
        return 0
    estimated = _round_tokens(words / DEFAULT_WORDS_PER_TOKEN)
    return estimated if estimated >= 1 else 1
def _round_tokens(value: float) -> int:
return max(0, int(round(value)))
def _non_negative_int(name: str, value: Any) -> int:
if isinstance(value, bool):
raise ValueError(f"{name} must be a non-negative integer")
try:
integer = int(value)
except (TypeError, ValueError) as exc:
raise ValueError(f"{name} must be a non-negative integer") from exc
if integer < 0 or integer != float(value):
raise ValueError(f"{name} must be a non-negative integer")
return integer
def _non_negative_float(name: str, value: Any) -> float:
if isinstance(value, bool):
raise ValueError(f"{name} must be a non-negative number")
try:
number = float(value)
except (TypeError, ValueError) as exc:
raise ValueError(f"{name} must be a non-negative number") from exc
if number < 0:
raise ValueError(f"{name} must be a non-negative number")
return number
def _bounded_float(name: str, value: Any) -> float:
    """Coerce *value* to a float within the closed interval ``[0, 1]``.

    Args:
        name: Field name used in the error messages.
        value: Candidate value.

    Returns:
        The value as a ``float``.

    Raises:
        ValueError: If the value is not a non-negative number, exceeds 1, or
            is NaN.
    """
    number = _non_negative_float(name, value)
    # Written as "not <= 1" rather than "> 1" so NaN, which fails every
    # comparison, is rejected instead of slipping through.
    if not number <= 1:
        raise ValueError(f"{name} must be between 0 and 1")
    return number

273
llm_connect/rates.py Normal file
View File

@@ -0,0 +1,273 @@
"""Model rate registry for preview and post-hoc cost estimation."""
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass
from pathlib import Path
from typing import Any
# Provenance stamped onto every bundled default rate (see _default_rate_payload).
DEFAULT_RATE_SOURCE_URL = "https://openrouter.ai/models"
DEFAULT_RATE_CAPTURED_AT = "2026-05-17"
# Currency used whenever a rate entry or YAML file does not specify one.
DEFAULT_RATE_CURRENCY = "USD"
@dataclass(frozen=True)
class ModelRate:
    """List price for one model, quoted per 1,000 prompt and completion tokens.

    Fields are normalised on construction: ``model_id`` is stripped,
    ``currency`` is stripped and upper-cased (defaulting to
    ``DEFAULT_RATE_CURRENCY``), rates become non-negative floats, and the
    provenance fields become strings.
    """

    model_id: str
    prompt_per_1k: float
    completion_per_1k: float
    currency: str = DEFAULT_RATE_CURRENCY
    source_url: str = ""
    captured_at: str = ""

    def __post_init__(self) -> None:
        """Validate and normalise all fields.

        Raises:
            ValueError: If ``model_id`` or ``currency`` is blank, or a rate is
                not a non-negative number.
        """
        cleaned_id = str(self.model_id).strip()
        cleaned_currency = str(self.currency or DEFAULT_RATE_CURRENCY).strip().upper()
        if not cleaned_id:
            raise ValueError("model_id must be a non-empty string")
        if not cleaned_currency:
            raise ValueError("currency must be a non-empty string")

        normalised = {
            "model_id": cleaned_id,
            "prompt_per_1k": _non_negative_float("prompt_per_1k", self.prompt_per_1k),
            "completion_per_1k": _non_negative_float(
                "completion_per_1k", self.completion_per_1k
            ),
            "currency": cleaned_currency,
            "source_url": str(self.source_url or ""),
            "captured_at": str(self.captured_at or ""),
        }
        # The dataclass is frozen, so bypass its __setattr__ guard.
        for field_name, field_value in normalised.items():
            object.__setattr__(self, field_name, field_value)
class ModelRateRegistry:
    """Mapping from model id to its ``ModelRate`` list price."""

    def __init__(self, rates: Mapping[str, ModelRate | Mapping[str, Any]] | None = None) -> None:
        """Populate the registry from *rates*.

        Values may be ``ModelRate`` instances or plain mappings of rate fields;
        entries are stored under the normalised ``model_id`` of the resulting
        rate.
        """
        self._rates: dict[str, ModelRate] = {}
        if rates:
            for model_id, raw_rate in rates.items():
                entry = _coerce_rate(model_id, raw_rate)
                self._rates[entry.model_id] = entry

    def get(self, model_id: str) -> ModelRate | None:
        """Look up the rate for *model_id* (whitespace-stripped); ``None`` if unknown."""
        key = str(model_id).strip()
        return self._rates.get(key)

    def all(self) -> dict[str, ModelRate]:
        """Return a shallow copy of the id-to-rate mapping."""
        return dict(self._rates)

    @classmethod
    def default(cls) -> "ModelRateRegistry":
        """Build a registry from the bundled default rate table."""
        return cls(_default_rate_payload())

    @classmethod
    def from_yaml(cls, path: Path | str) -> "ModelRateRegistry":
        """Build a registry from a YAML rate file.

        Expected document shape::

            currency: USD
            source_url: https://openrouter.ai/models
            captured_at: "2026-05-17"
            rates:
              openai/gpt-4o-mini:
                prompt_per_1k: 0.00015
                completion_per_1k: 0.00060

        PyYAML is used when it can be imported; otherwise a small built-in
        parser for exactly this schema is used, so no extra dependency is
        required.
        """
        document = _load_yaml_mapping(Path(path))
        return cls(_rates_from_payload(document))

    def merged_with(self, override: "ModelRateRegistry") -> "ModelRateRegistry":
        """Return a new registry combining both; *override* wins on shared ids."""
        combined = {**self.all(), **override.all()}
        return ModelRateRegistry(combined)
# Bundled price snapshot: model id -> (prompt_per_1k, completion_per_1k).
# _default_rate_payload() expands each pair into a ModelRate stamped with the
# DEFAULT_RATE_* currency, source URL and capture date.
_DEFAULT_RATES: dict[str, tuple[float, float]] = {
    "openai/gpt-4o-mini": (0.00015, 0.00060),
    "openai/gpt-4o": (0.0025, 0.01),
    "openai/gpt-4-turbo": (0.01, 0.03),
    "anthropic/claude-3.5-sonnet": (0.003, 0.015),
    "anthropic/claude-3.5-haiku": (0.0008, 0.004),
    "anthropic/claude-3-opus": (0.015, 0.075),
    "google/gemini-1.5-flash": (0.000075, 0.0003),
    "google/gemini-1.5-pro": (0.00125, 0.005),
    "meta-llama/llama-3.1-70b-instruct": (0.00059, 0.00079),
}
def _default_rate_payload() -> dict[str, ModelRate]:
    """Expand ``_DEFAULT_RATES`` into ``ModelRate`` objects with default provenance."""
    payload: dict[str, ModelRate] = {}
    for model_id, price_pair in _DEFAULT_RATES.items():
        prompt_rate, completion_rate = price_pair
        payload[model_id] = ModelRate(
            model_id=model_id,
            prompt_per_1k=prompt_rate,
            completion_per_1k=completion_rate,
            currency=DEFAULT_RATE_CURRENCY,
            source_url=DEFAULT_RATE_SOURCE_URL,
            captured_at=DEFAULT_RATE_CAPTURED_AT,
        )
    return payload
def _coerce_rate(model_id: str, rate: ModelRate | Mapping[str, Any]) -> ModelRate:
    """Return *rate* as a ``ModelRate``, building one from a mapping if needed.

    An existing ``ModelRate`` is returned as-is (its own ``model_id`` is kept).

    Raises:
        TypeError: If *rate* is neither a ``ModelRate`` nor a mapping.
        KeyError: If a mapping lacks ``prompt_per_1k`` or ``completion_per_1k``.
    """
    if isinstance(rate, ModelRate):
        return rate
    if isinstance(rate, Mapping):
        currency = rate.get("currency") or DEFAULT_RATE_CURRENCY
        source_url = rate.get("source_url") or ""
        captured_at = rate.get("captured_at") or ""
        return ModelRate(
            model_id=str(model_id),
            prompt_per_1k=rate["prompt_per_1k"],
            completion_per_1k=rate["completion_per_1k"],
            currency=str(currency),
            source_url=str(source_url),
            captured_at=str(captured_at),
        )
    raise TypeError(f"Rate for {model_id!r} must be a ModelRate or mapping")
def _rates_from_payload(payload: Mapping[str, Any]) -> dict[str, ModelRate]:
    """Turn a parsed rate document into ``ModelRate`` objects keyed by model id.

    Top-level ``currency``, ``source_url`` and ``captured_at`` act as defaults
    for entries that do not set their own.

    Raises:
        ValueError: If ``rates`` is missing or not a mapping, or an entry is
            not a mapping.
        KeyError: If an entry lacks ``prompt_per_1k`` or ``completion_per_1k``.
    """
    entries = payload.get("rates")
    if not isinstance(entries, Mapping):
        raise ValueError("Rate YAML must contain a 'rates' mapping")

    # File-level fallbacks for per-entry provenance fields.
    fallback = {
        "currency": str(payload.get("currency") or DEFAULT_RATE_CURRENCY),
        "source_url": str(payload.get("source_url") or ""),
        "captured_at": str(payload.get("captured_at") or ""),
    }

    result: dict[str, ModelRate] = {}
    for raw_id, fields in entries.items():
        if not isinstance(fields, Mapping):
            raise ValueError(f"Rate entry for {raw_id!r} must be a mapping")
        key = str(raw_id)
        result[key] = ModelRate(
            model_id=key,
            prompt_per_1k=fields["prompt_per_1k"],
            completion_per_1k=fields["completion_per_1k"],
            currency=str(fields.get("currency") or fallback["currency"]),
            source_url=str(fields.get("source_url") or fallback["source_url"]),
            captured_at=str(fields.get("captured_at") or fallback["captured_at"]),
        )
    return result
def _non_negative_float(name: str, value: Any) -> float:
if isinstance(value, bool):
raise ValueError(f"{name} must be a non-negative number")
try:
number = float(value)
except (TypeError, ValueError) as exc:
raise ValueError(f"{name} must be a non-negative number") from exc
if number < 0:
raise ValueError(f"{name} must be a non-negative number")
return number
def _load_yaml_mapping(path: Path) -> Mapping[str, Any]:
    """Read *path* as UTF-8 and parse it into a mapping.

    Uses ``yaml.safe_load`` when PyYAML can be imported, and the built-in
    ``_parse_rate_yaml`` fallback otherwise. An empty document parsed by
    PyYAML is treated as an empty mapping.

    Raises:
        ValueError: If PyYAML parses the document to something other than a
            mapping.
    """
    try:
        import yaml
    except ImportError:
        yaml = None

    if yaml is None:
        return _parse_rate_yaml(path.read_text(encoding="utf-8"))

    document = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
    if isinstance(document, Mapping):
        return document
    raise ValueError("Rate YAML root must be a mapping")
def _parse_rate_yaml(text: str) -> dict[str, Any]:
    """Parse the rate-file subset of YAML without PyYAML.

    Supported: top-level ``key: scalar`` pairs and one ``rates:`` block of
    model entries, each holding scalar fields.

    Raises:
        ValueError: On indented content outside the ``rates`` block or on
            malformed lines.
    """
    parsed_lines: list[tuple[int, str]] = [
        entry
        for entry in (_normalise_yaml_line(raw_line) for raw_line in text.splitlines())
        if entry is not None
    ]

    document: dict[str, Any] = {}
    cursor = 0
    total = len(parsed_lines)
    while cursor < total:
        indent, content = parsed_lines[cursor]
        if indent != 0:
            raise ValueError("Only top-level mappings are supported in rate YAML")
        key, raw_value = _split_yaml_key_value(content)
        if key == "rates" and raw_value == "":
            # The block parser consumes the nested lines and reports where it stopped.
            document["rates"], cursor = _parse_rates_block(parsed_lines, cursor + 1)
        else:
            document[key] = _parse_yaml_scalar(raw_value)
            cursor += 1
    return document
def _parse_rates_block(
    lines: list[tuple[int, str]],
    index: int,
) -> tuple[dict[str, dict[str, Any]], int]:
    """Parse the body of a ``rates:`` block, starting at ``lines[index]``.

    Model ids must be indented by two spaces and their fields by four.

    Returns:
        ``(rates, next_index)`` where *next_index* is the first line not
        consumed: the next top-level line, or ``len(lines)``.

    Raises:
        ValueError: On unexpected indentation or a model line carrying an
            inline value.
    """
    collected: dict[str, dict[str, Any]] = {}
    total = len(lines)
    while index < total:
        indent, content = lines[index]
        if indent == 0:
            break
        if indent != 2:
            raise ValueError("Rate model entries must be indented by two spaces")
        model_id, inline_value = _split_yaml_key_value(content)
        if inline_value:
            raise ValueError(f"Rate entry for {model_id!r} must be a nested mapping")
        index += 1

        # Consume every line indented deeper than the model id.
        fields: dict[str, Any] = {}
        while index < total and lines[index][0] > indent:
            field_indent, field_content = lines[index]
            if field_indent != 4:
                raise ValueError("Rate fields must be indented by four spaces")
            field_key, field_value = _split_yaml_key_value(field_content)
            fields[field_key] = _parse_yaml_scalar(field_value)
            index += 1
        collected[model_id] = fields
    return collected, index
def _normalise_yaml_line(line: str) -> tuple[int, str] | None:
    """Reduce a raw line to ``(indent, content)``; ``None`` for blank/comment-only lines.

    ``indent`` counts leading space characters only; ``content`` has comments
    and surrounding whitespace removed.
    """
    without_comment = _strip_yaml_comment(line.rstrip())
    content = without_comment.strip()
    if not content:
        return None
    leading_spaces = len(without_comment) - len(without_comment.lstrip(" "))
    return leading_spaces, content
def _strip_yaml_comment(line: str) -> str:
quote: str | None = None
for index, char in enumerate(line):
if char in {"'", '"'}:
quote = None if quote == char else char if quote is None else quote
elif char == "#" and quote is None:
return line[:index]
return line
def _split_yaml_key_value(content: str) -> tuple[str, str]:
key, separator, value = content.partition(":")
if not separator:
raise ValueError(f"Invalid YAML mapping line: {content!r}")
return key.strip().strip("'\""), value.strip()
def _parse_yaml_scalar(value: str) -> Any:
if value == "":
return ""
if (value.startswith('"') and value.endswith('"')) or (
value.startswith("'") and value.endswith("'")
):
return value[1:-1]
if value.lower() in {"null", "none", "~"}:
return None
try:
if any(char in value for char in (".", "e", "E")):
return float(value)
return int(value)
except ValueError:
return value