Files
llm-connect/llm_connect/costs.py
tegwick c11c6afa3f
Some checks failed
CI / test (3.10) (push) Has been cancelled
CI / test (3.11) (push) Has been cancelled
CI / test (3.12) (push) Has been cancelled
Implement-LLM-WP-0005-cost-model-estimators
2026-05-19 05:02:20 +02:00

75 lines
2.2 KiB
Python

"""Cost estimation over model rates and token counts."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from llm_connect.rates import ModelRateRegistry
@dataclass(frozen=True)
class CostEstimate:
    """Immutable outcome of a cost estimation, broken down by token kind.

    Attributes:
        cost_usd: Total estimated cost in USD, or ``None`` when no price
            could be determined.
        cost_source: Short label naming where the estimate came from.
        prompt_cost_usd: Share of the total attributed to prompt tokens;
            ``None`` when it was not computed.
        completion_cost_usd: Share of the total attributed to completion
            tokens; ``None`` when it was not computed.
    """

    cost_usd: float | None
    cost_source: str
    prompt_cost_usd: float | None = None
    completion_cost_usd: float | None = None
def estimate_cost(
    model_id: str,
    prompt_tokens: int,
    completion_tokens: int = 0,
    *,
    registry: ModelRateRegistry | None = None,
) -> CostEstimate:
    """Estimate USD cost for token counts using *registry*.

    Unknown models return ``CostEstimate(None, "unknown")`` so callers can
    record uncertainty explicitly instead of treating missing prices as zero.

    Args:
        model_id: Identifier looked up in the rate registry.
        prompt_tokens: Number of prompt tokens; must be a non-negative int.
        completion_tokens: Number of completion tokens; must be a
            non-negative int.
        registry: Rate registry to consult. Only ``None`` selects
            ``ModelRateRegistry.default()``; any registry object that is
            passed in is used as given.

    Returns:
        A ``CostEstimate`` holding the total plus the prompt/completion
        split, with ``cost_source`` set to ``"rate_table:<model id>"``;
        or the ``"unknown"`` estimate when the registry has no rate for
        *model_id*.

    Raises:
        ValueError: If either token count is not a non-negative integer.
    """
    # Validate both counts before touching the registry so bad input fails
    # the same way regardless of which models are known.
    prompt_count = _non_negative_int("prompt_tokens", prompt_tokens)
    completion_count = _non_negative_int("completion_tokens", completion_tokens)

    # Compare against None explicitly: a caller-supplied registry that happens
    # to be falsy (e.g. an empty one) must not be swapped for the default
    # table, which would invent prices the caller never configured.
    rates = ModelRateRegistry.default() if registry is None else registry
    rate = rates.get(model_id)
    if rate is None:
        return CostEstimate(cost_usd=None, cost_source="unknown")

    # Rates are quoted per 1000 tokens.
    prompt_cost = (prompt_count / 1000.0) * rate.prompt_per_1k
    completion_cost = (completion_count / 1000.0) * rate.completion_per_1k
    return CostEstimate(
        cost_usd=prompt_cost + completion_cost,
        cost_source=f"rate_table:{rate.model_id}",
        prompt_cost_usd=prompt_cost,
        completion_cost_usd=completion_cost,
    )
@dataclass(frozen=True)
class CostModel:
    """Object-style facade over the module-level ``estimate_cost`` function.

    Attributes:
        registry: Rate registry handed to every estimate. ``None`` is
            forwarded unchanged, leaving the choice of registry to
            ``estimate_cost``.
    """

    registry: ModelRateRegistry | None = None

    def estimate_cost(
        self,
        model_id: str,
        prompt_tokens: int,
        completion_tokens: int = 0,
    ) -> CostEstimate:
        """Delegate to the free function, supplying this instance's registry.

        Args:
            model_id: Identifier of the model to price.
            prompt_tokens: Number of prompt tokens.
            completion_tokens: Number of completion tokens.

        Returns:
            Whatever the module-level ``estimate_cost`` returns for these
            arguments.
        """
        # Inside a method body the bare name resolves to the module-level
        # function, not to this method.
        bound_registry = self.registry
        return estimate_cost(
            model_id, prompt_tokens, completion_tokens, registry=bound_registry
        )
def _non_negative_int(name: str, value: Any) -> int:
    """Return *value* unchanged if it is an ``int`` that is zero or greater.

    Args:
        name: Parameter name used in the error message.
        value: Candidate value to validate.

    Raises:
        ValueError: If *value* is a ``bool``, is not an ``int``, or is
            negative.
    """
    # bool is a subclass of int, so it has to be excluded explicitly.
    wrong_type = isinstance(value, bool) or not isinstance(value, int)
    # The sign comparison only runs once the type check has passed.
    if wrong_type or value < 0:
        raise ValueError(f"{name} must be a non-negative integer")
    return value