# generated from coulomb/repo-seed (75 lines, 2.2 KiB, Python)
"""Cost estimation over model rates and token counts."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from llm_connect.rates import ModelRateRegistry
|
|
|
|
|
|
@dataclass(frozen=True)
class CostEstimate:
    """Immutable USD cost figure with an optional prompt/completion breakdown.

    ``cost_usd`` may be ``None`` to signal that no price could be determined;
    ``cost_source`` labels where the figure came from. The two breakdown
    fields default to ``None`` when no split is available.
    """

    # Total estimated spend in USD, or None when the price is unknown.
    cost_usd: float | None
    # Provenance label for this estimate (e.g. "unknown").
    cost_source: str
    # Share of the total attributable to prompt tokens, if computed.
    prompt_cost_usd: float | None = None
    # Share of the total attributable to completion tokens, if computed.
    completion_cost_usd: float | None = None
|
|
|
|
|
|
def estimate_cost(
    model_id: str,
    prompt_tokens: int,
    completion_tokens: int = 0,
    *,
    registry: ModelRateRegistry | None = None,
) -> CostEstimate:
    """Estimate USD cost for token counts using *registry*.

    Args:
        model_id: Identifier looked up in the rate registry.
        prompt_tokens: Number of prompt tokens; must be a non-negative ``int``.
        completion_tokens: Number of completion tokens; must be a
            non-negative ``int``. Defaults to 0.
        registry: Rate table to consult. ``None`` selects
            ``ModelRateRegistry.default()``; any other object, including one
            that happens to be falsy, is used exactly as given.

    Returns:
        A ``CostEstimate`` whose ``cost_usd`` is the sum of the prompt and
        completion costs, each priced per 1,000 tokens, with ``cost_source``
        set to ``"rate_table:<rate model id>"``. Unknown models return
        ``CostEstimate(None, "unknown")`` so callers can record uncertainty
        explicitly instead of treating missing prices as zero.

    Raises:
        ValueError: If either token count is negative, a ``bool``, or not an
            ``int``.
    """
    prompt_count = _non_negative_int("prompt_tokens", prompt_tokens)
    completion_count = _non_negative_int("completion_tokens", completion_tokens)

    # Compare against None instead of relying on truthiness: a registry the
    # caller passed explicitly must not be replaced by the defaults just
    # because it evaluates as falsy (e.g. an empty table defining __len__).
    rates = ModelRateRegistry.default() if registry is None else registry
    rate = rates.get(model_id)
    if rate is None:
        return CostEstimate(cost_usd=None, cost_source="unknown")

    # Rates are quoted per 1,000 tokens.
    prompt_cost = (prompt_count / 1000.0) * rate.prompt_per_1k
    completion_cost = (completion_count / 1000.0) * rate.completion_per_1k
    return CostEstimate(
        cost_usd=prompt_cost + completion_cost,
        cost_source=f"rate_table:{rate.model_id}",
        prompt_cost_usd=prompt_cost,
        completion_cost_usd=completion_cost,
    )
|
|
|
|
|
|
@dataclass(frozen=True)
class CostModel:
    """Object-style front end to the module-level ``estimate_cost`` function.

    Stores an optional rate registry and forwards every estimate request to
    ``estimate_cost`` with that registry attached.
    """

    # Registry forwarded to estimate_cost; None lets it choose its default.
    registry: ModelRateRegistry | None = None

    def estimate_cost(
        self,
        model_id: str,
        prompt_tokens: int,
        completion_tokens: int = 0,
    ) -> CostEstimate:
        """Return the cost estimate for these token counts using ``self.registry``."""
        # The bare name below resolves to the module-level function, not this
        # method: a class body is not an enclosing scope for its methods.
        return estimate_cost(
            model_id=model_id,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            registry=self.registry,
        )
|
|
|
|
|
|
def _non_negative_int(name: str, value: Any) -> int:
    """Return *value* unchanged when it is a genuine non-negative ``int``.

    ``bool`` is rejected explicitly even though it subclasses ``int``; every
    rejection raises ``ValueError`` naming the offending parameter.
    """
    genuine_int = isinstance(value, int) and not isinstance(value, bool)
    if genuine_int and value >= 0:
        return value
    raise ValueError(f"{name} must be a non-negative integer")
|