generated from coulomb/repo-seed
Implement-LLM-WP-0005-cost-model-estimators
This commit is contained in:
25
contracts/functional/costs.md
Normal file
25
contracts/functional/costs.md
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# Cost Estimates
|
||||||
|
|
||||||
|
`llm_connect.costs` converts token estimates or observed token counts into
|
||||||
|
USD estimates using `ModelRateRegistry`.
|
||||||
|
|
||||||
|
## Contract
|
||||||
|
|
||||||
|
```python
|
||||||
|
from llm_connect import estimate_cost
|
||||||
|
|
||||||
|
estimate = estimate_cost("openai/gpt-4o-mini", 28_000, 7_500)
|
||||||
|
```
|
||||||
|
|
||||||
|
For known models the result is:
|
||||||
|
|
||||||
|
- `cost_usd`: prompt plus completion estimate.
|
||||||
|
- `prompt_cost_usd`: prompt-token component.
|
||||||
|
- `completion_cost_usd`: completion-token component.
|
||||||
|
- `cost_source`: `rate_table:<model_id>`.
|
||||||
|
|
||||||
|
Unknown models return `CostEstimate(cost_usd=None, cost_source="unknown")`.
|
||||||
|
Missing rates are never silently treated as zero cost.
|
||||||
|
|
||||||
|
The module also exposes `CostModel(registry=...)` for callers that prefer to
|
||||||
|
carry a registry object and call `model.estimate_cost(...)`.
|
||||||
46
contracts/functional/problem-classes.md
Normal file
46
contracts/functional/problem-classes.md
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
# Problem Classes
|
||||||
|
|
||||||
|
`llm_connect.problem_classes` provides generic token estimators for recurring
|
||||||
|
LLM workflow shapes.
|
||||||
|
|
||||||
|
## Contract
|
||||||
|
|
||||||
|
Every problem class exposes:
|
||||||
|
|
||||||
|
- `name`: stable registry key.
|
||||||
|
- `base_dimensions`: required dimension names supplied by consumers.
|
||||||
|
- `tunable_params`: parameters that can be overridden or fitted.
|
||||||
|
- `estimate(dimensions, params=None) -> TokenEstimate`.
|
||||||
|
- `fit(observations, min_observations=3) -> ProblemClass`.
|
||||||
|
|
||||||
|
`TokenEstimate` contains `prompt_tokens`, `completion_tokens`, and a
|
||||||
|
`confidence` score from `0` to `1`.
|
||||||
|
|
||||||
|
## Built-Ins
|
||||||
|
|
||||||
|
| Name | Dimensions | Tunable params |
|
||||||
|
|---|---|---|
|
||||||
|
| `chunk-summarization` | `chunk_words`, `template_words` | `completion_ratio` |
|
||||||
|
| `entity-extraction` | `chunk_words`, `template_words`, `expected_entities` | `tokens_per_entity` |
|
||||||
|
| `relation-extraction` | `chunk_words`, `template_words`, `expected_relations` | `tokens_per_relation` |
|
||||||
|
| `judge-eval` | `artifact_words`, `template_words`, `n_criteria` | `tokens_per_criterion` |
|
||||||
|
| `report-synthesis` | `n_chunks`, `n_entities`, `n_relations`, `template_words` | `base_completion_tokens` |
|
||||||
|
|
||||||
|
## Observations
|
||||||
|
|
||||||
|
`fit()` accepts either `Observation` objects or `QualityObservation` rows whose
|
||||||
|
`tags` include:
|
||||||
|
|
||||||
|
```python
|
||||||
|
{
|
||||||
|
"problem_class": "entity-extraction",
|
||||||
|
"dimensions": {
|
||||||
|
"chunk_words": 900,
|
||||||
|
"template_words": 200,
|
||||||
|
"expected_entities": 4,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
When fewer than `min_observations` usable rows are present, fitting falls back
|
||||||
|
to the current parameters.
|
||||||
30
contracts/functional/rates.md
Normal file
30
contracts/functional/rates.md
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# Model Rate Registry
|
||||||
|
|
||||||
|
`llm_connect.rates` owns static model list prices used for planning and
|
||||||
|
post-hoc estimates.
|
||||||
|
|
||||||
|
## Contract
|
||||||
|
|
||||||
|
- `ModelRate` records `model_id`, prompt and completion rates in USD per
|
||||||
|
1,000 tokens, `currency`, `source_url`, and `captured_at`.
|
||||||
|
- `ModelRateRegistry.default()` returns the bundled OpenRouter snapshot
|
||||||
|
captured on `2026-05-17`.
|
||||||
|
- `ModelRateRegistry.from_yaml(path)` accepts the package/consumer override
|
||||||
|
shape:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
schema_version: 1
|
||||||
|
currency: USD
|
||||||
|
source_url: https://openrouter.ai/models
|
||||||
|
captured_at: "2026-05-17"
|
||||||
|
rates:
|
||||||
|
openai/gpt-4o-mini:
|
||||||
|
prompt_per_1k: 0.00015
|
||||||
|
completion_per_1k: 0.00060
|
||||||
|
```
|
||||||
|
|
||||||
|
- `merged_with(override)` returns a new registry where matching override
|
||||||
|
entries replace default entries by `model_id`.
|
||||||
|
|
||||||
|
Rates are a static snapshot. Consumers decide whether `captured_at` is fresh
|
||||||
|
enough for their workflow.
|
||||||
100
docs/infospace-bench-cost-model-migration.md
Normal file
100
docs/infospace-bench-cost-model-migration.md
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
# infospace-bench Cost Estimator Migration
|
||||||
|
|
||||||
|
`infospace-bench` can replace its local rate table and coarse word-count
|
||||||
|
budget math with the primitives added in `LLM-WP-0005`.
|
||||||
|
|
||||||
|
## Rate Table
|
||||||
|
|
||||||
|
- Drop `src/infospace_bench/model_rates.yaml` after the dependency is bumped.
|
||||||
|
- Load `ModelRateRegistry.default()` from `llm-connect`.
|
||||||
|
- Keep the workspace-level `model-rates.yaml` override and merge it with
|
||||||
|
`default().merged_with(ModelRateRegistry.from_yaml(path))`.
|
||||||
|
- Preserve `--cost-per-1k` as an explicit blended-rate override. When supplied,
|
||||||
|
it should win over the registry and report `cost_source="cost_per_1k_blended"`.
|
||||||
|
|
||||||
|
## Plan Summary Sketch
|
||||||
|
|
||||||
|
```python
|
||||||
|
from llm_connect import (
|
||||||
|
CostEstimate,
|
||||||
|
ModelRateRegistry,
|
||||||
|
ProblemClassRegistry,
|
||||||
|
estimate_cost,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def plan_generation_summary(...):
|
||||||
|
problem_classes = ProblemClassRegistry.default()
|
||||||
|
rates = ModelRateRegistry.default()
|
||||||
|
workspace_rates = _workspace_rate_path(root_path)
|
||||||
|
if workspace_rates.exists():
|
||||||
|
rates = rates.merged_with(ModelRateRegistry.from_yaml(workspace_rates))
|
||||||
|
|
||||||
|
total_prompt_tokens = 0
|
||||||
|
total_completion_tokens = 0
|
||||||
|
per_stage = []
|
||||||
|
for workflow_id in workflow_ids:
|
||||||
|
class_name, dimensions = _problem_class_for_workflow(
|
||||||
|
workflow_id,
|
||||||
|
selected_chunks=selected,
|
||||||
|
template_words=template_words,
|
||||||
|
entities_per_chunk=entities_per_chunk,
|
||||||
|
)
|
||||||
|
estimate = problem_classes.get(class_name).estimate(dimensions)
|
||||||
|
calls = _calls_for_workflow(workflow_id, selected, entities_per_chunk)
|
||||||
|
prompt_tokens = estimate.prompt_tokens * calls
|
||||||
|
completion_tokens = estimate.completion_tokens * calls
|
||||||
|
total_prompt_tokens += prompt_tokens
|
||||||
|
total_completion_tokens += completion_tokens
|
||||||
|
per_stage.append(
|
||||||
|
{
|
||||||
|
"workflow_id": workflow_id,
|
||||||
|
"problem_class": class_name,
|
||||||
|
"calls": calls,
|
||||||
|
"prompt_tokens_estimate": prompt_tokens,
|
||||||
|
"completion_tokens_estimate": completion_tokens,
|
||||||
|
"confidence": estimate.confidence,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
if cost_per_1k_tokens > 0:
|
||||||
|
total_tokens = total_prompt_tokens + total_completion_tokens
|
||||||
|
cost = (total_tokens / 1000.0) * cost_per_1k_tokens
|
||||||
|
cost_source = "cost_per_1k_blended"
|
||||||
|
elif model:
|
||||||
|
cost_estimate = estimate_cost(
|
||||||
|
model,
|
||||||
|
total_prompt_tokens,
|
||||||
|
total_completion_tokens,
|
||||||
|
registry=rates,
|
||||||
|
)
|
||||||
|
cost = cost_estimate.cost_usd
|
||||||
|
cost_source = cost_estimate.cost_source
|
||||||
|
else:
|
||||||
|
cost = None
|
||||||
|
cost_source = None
|
||||||
|
|
||||||
|
return {
|
||||||
|
"per_workflow": per_stage,
|
||||||
|
"total_prompt_tokens_estimate": total_prompt_tokens,
|
||||||
|
"estimated_completion_tokens": total_completion_tokens,
|
||||||
|
"estimated_cost_usd": round(cost, 6) if cost is not None else None,
|
||||||
|
"cost_source": cost_source,
|
||||||
|
...
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Workflow Mapping
|
||||||
|
|
||||||
|
Initial mapping can stay intentionally thin:
|
||||||
|
|
||||||
|
| infospace-bench workflow | llm-connect problem class |
|
||||||
|
|---|---|
|
||||||
|
| `summarize-source` | `chunk-summarization` |
|
||||||
|
| entity extraction workflows | `entity-extraction` |
|
||||||
|
| relation extraction workflows | `relation-extraction` |
|
||||||
|
| `generic-source-evaluations` | `judge-eval` |
|
||||||
|
| final report or rollup synthesis | `report-synthesis` |
|
||||||
|
|
||||||
|
The consumer still owns structure-specific dimensions such as selected chunk
|
||||||
|
counts, profile template word counts, and expected entities per chunk.
|
||||||
@@ -15,6 +15,7 @@ Quick start::
|
|||||||
from llm_connect.adapter import ErrorLLMAdapter, LLMAdapter, MockLLMAdapter
|
from llm_connect.adapter import ErrorLLMAdapter, LLMAdapter, MockLLMAdapter
|
||||||
from llm_connect.claude_code import ClaudeCodeAdapter
|
from llm_connect.claude_code import ClaudeCodeAdapter
|
||||||
from llm_connect.config import LLMConfig, load_config
|
from llm_connect.config import LLMConfig, load_config
|
||||||
|
from llm_connect.costs import CostEstimate, CostModel, estimate_cost
|
||||||
from llm_connect.embedding_adapter import EmbeddingAdapter
|
from llm_connect.embedding_adapter import EmbeddingAdapter
|
||||||
from llm_connect.embedding_cache import EmbeddingCache
|
from llm_connect.embedding_cache import EmbeddingCache
|
||||||
from llm_connect.embedding_factory import create_embedding_adapter
|
from llm_connect.embedding_factory import create_embedding_adapter
|
||||||
@@ -42,7 +43,20 @@ from llm_connect.grading import (
|
|||||||
from llm_connect.models import BudgetTracker, LLMResponse, RunConfig
|
from llm_connect.models import BudgetTracker, LLMResponse, RunConfig
|
||||||
from llm_connect.openai import OpenAIAdapter
|
from llm_connect.openai import OpenAIAdapter
|
||||||
from llm_connect.openrouter import OpenRouterAdapter
|
from llm_connect.openrouter import OpenRouterAdapter
|
||||||
|
from llm_connect.problem_classes import (
|
||||||
|
ChunkSummarizationProblemClass,
|
||||||
|
EntityExtractionProblemClass,
|
||||||
|
JudgeEvalProblemClass,
|
||||||
|
Observation,
|
||||||
|
ProblemClass,
|
||||||
|
ProblemClassRegistry,
|
||||||
|
RelationExtractionProblemClass,
|
||||||
|
ReportSynthesisProblemClass,
|
||||||
|
TokenEstimate,
|
||||||
|
default_problem_class_registry,
|
||||||
|
)
|
||||||
from llm_connect.quality import QualityLedger, QualityObservation, is_stale
|
from llm_connect.quality import QualityLedger, QualityObservation, is_stale
|
||||||
|
from llm_connect.rates import ModelRate, ModelRateRegistry
|
||||||
from llm_connect.routing import AdaptiveRoutingPolicy, RoutingPolicy, RoutingRule
|
from llm_connect.routing import AdaptiveRoutingPolicy, RoutingPolicy, RoutingRule
|
||||||
from llm_connect.server import LLMServer
|
from llm_connect.server import LLMServer
|
||||||
from llm_connect.shadowing import ShadowingAdapter
|
from llm_connect.shadowing import ShadowingAdapter
|
||||||
@@ -95,4 +109,19 @@ __all__ = [
|
|||||||
"AdaptiveRoutingPolicy",
|
"AdaptiveRoutingPolicy",
|
||||||
"ShadowingAdapter",
|
"ShadowingAdapter",
|
||||||
"LLMServer",
|
"LLMServer",
|
||||||
|
"ModelRate",
|
||||||
|
"ModelRateRegistry",
|
||||||
|
"CostEstimate",
|
||||||
|
"CostModel",
|
||||||
|
"estimate_cost",
|
||||||
|
"TokenEstimate",
|
||||||
|
"Observation",
|
||||||
|
"ProblemClass",
|
||||||
|
"ProblemClassRegistry",
|
||||||
|
"default_problem_class_registry",
|
||||||
|
"ChunkSummarizationProblemClass",
|
||||||
|
"EntityExtractionProblemClass",
|
||||||
|
"RelationExtractionProblemClass",
|
||||||
|
"JudgeEvalProblemClass",
|
||||||
|
"ReportSynthesisProblemClass",
|
||||||
]
|
]
|
||||||
|
|||||||
143
llm_connect/cli.py
Normal file
143
llm_connect/cli.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
"""Command-line helpers for llm-connect registries."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from collections.abc import Iterable, Mapping
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llm_connect.problem_classes import ProblemClass, ProblemClassRegistry
|
||||||
|
from llm_connect.quality import QualityLedger
|
||||||
|
from llm_connect.rates import ModelRateRegistry
|
||||||
|
|
||||||
|
|
||||||
|
def main(argv: list[str] | None = None) -> int:
|
||||||
|
"""Run the ``llm-connect`` command."""
|
||||||
|
parser = _build_parser()
|
||||||
|
args = parser.parse_args(argv)
|
||||||
|
return int(args.func(args))
|
||||||
|
|
||||||
|
|
||||||
|
def _build_parser() -> argparse.ArgumentParser:
|
||||||
|
parser = argparse.ArgumentParser(prog="llm-connect")
|
||||||
|
commands = parser.add_subparsers(dest="command", required=True)
|
||||||
|
|
||||||
|
rates = commands.add_parser("rates", help="Inspect model rate registries")
|
||||||
|
rate_commands = rates.add_subparsers(dest="rates_command", required=True)
|
||||||
|
rate_show = rate_commands.add_parser("show", help="Show model rates")
|
||||||
|
rate_show.add_argument("--rates", type=Path, help="YAML registry overlay")
|
||||||
|
rate_show.add_argument("--json", action="store_true", help="Emit JSON")
|
||||||
|
rate_show.set_defaults(func=_rates_show)
|
||||||
|
|
||||||
|
classes = commands.add_parser("classes", help="Inspect problem classes")
|
||||||
|
class_commands = classes.add_subparsers(dest="classes_command", required=True)
|
||||||
|
class_show = class_commands.add_parser("show", help="Show problem classes")
|
||||||
|
class_show.add_argument("--json", action="store_true", help="Emit JSON")
|
||||||
|
class_show.set_defaults(func=_classes_show)
|
||||||
|
|
||||||
|
class_fit = class_commands.add_parser("fit", help="Fit problem-class params from a ledger")
|
||||||
|
class_fit.add_argument("ledger", type=Path, help="QualityLedger JSONL path")
|
||||||
|
class_fit.add_argument("--class", dest="class_name", help="Fit one class by name")
|
||||||
|
class_fit.add_argument("--min-observations", type=int, default=3)
|
||||||
|
class_fit.add_argument("--json", action="store_true", help="Emit JSON")
|
||||||
|
class_fit.set_defaults(func=_classes_fit)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def _rates_show(args: argparse.Namespace) -> int:
|
||||||
|
registry = ModelRateRegistry.default()
|
||||||
|
if args.rates:
|
||||||
|
registry = registry.merged_with(ModelRateRegistry.from_yaml(args.rates))
|
||||||
|
rates = registry.all()
|
||||||
|
if args.json:
|
||||||
|
print(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
model_id: {
|
||||||
|
"prompt_per_1k": rate.prompt_per_1k,
|
||||||
|
"completion_per_1k": rate.completion_per_1k,
|
||||||
|
"currency": rate.currency,
|
||||||
|
"source_url": rate.source_url,
|
||||||
|
"captured_at": rate.captured_at,
|
||||||
|
}
|
||||||
|
for model_id, rate in sorted(rates.items())
|
||||||
|
},
|
||||||
|
indent=2,
|
||||||
|
sort_keys=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
print("model_id\tprompt_per_1k\tcompletion_per_1k\tcurrency\tcaptured_at")
|
||||||
|
for model_id, rate in sorted(rates.items()):
|
||||||
|
print(
|
||||||
|
f"{model_id}\t{rate.prompt_per_1k:g}\t{rate.completion_per_1k:g}\t"
|
||||||
|
f"{rate.currency}\t{rate.captured_at}"
|
||||||
|
)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _classes_show(args: argparse.Namespace) -> int:
|
||||||
|
classes = ProblemClassRegistry.default().all()
|
||||||
|
if args.json:
|
||||||
|
print(json.dumps(_classes_payload(classes.values()), indent=2, sort_keys=True))
|
||||||
|
return 0
|
||||||
|
|
||||||
|
print("name\tdimensions\ttunable_params\tcurrent_params")
|
||||||
|
for problem_class in sorted(classes.values(), key=lambda item: item.name):
|
||||||
|
print(
|
||||||
|
f"{problem_class.name}\t{', '.join(problem_class.base_dimensions)}\t"
|
||||||
|
f"{', '.join(problem_class.tunable_params)}\t{_format_params(problem_class.params)}"
|
||||||
|
)
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _classes_fit(args: argparse.Namespace) -> int:
|
||||||
|
if args.min_observations <= 0:
|
||||||
|
raise SystemExit("--min-observations must be positive")
|
||||||
|
registry = ProblemClassRegistry.default()
|
||||||
|
classes = registry.all()
|
||||||
|
if args.class_name:
|
||||||
|
problem_class = registry.get(args.class_name)
|
||||||
|
if problem_class is None:
|
||||||
|
raise SystemExit(f"Unknown problem class: {args.class_name}")
|
||||||
|
selected: list[ProblemClass] = [problem_class]
|
||||||
|
else:
|
||||||
|
selected = list(classes.values())
|
||||||
|
|
||||||
|
observations = QualityLedger(args.ledger).read_all()
|
||||||
|
fitted: list[ProblemClass] = [
|
||||||
|
problem_class.fit(observations, min_observations=args.min_observations)
|
||||||
|
for problem_class in selected
|
||||||
|
]
|
||||||
|
if args.json:
|
||||||
|
print(json.dumps(_classes_payload(fitted), indent=2, sort_keys=True))
|
||||||
|
return 0
|
||||||
|
|
||||||
|
print("name\tfitted_params\tconfidence")
|
||||||
|
for problem_class in sorted(fitted, key=lambda item: item.name):
|
||||||
|
confidence = getattr(problem_class, "confidence", 0.5)
|
||||||
|
print(f"{problem_class.name}\t{_format_params(problem_class.params)}\t{confidence:g}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _classes_payload(classes: Iterable[ProblemClass]) -> dict[str, dict[str, Any]]:
|
||||||
|
return {
|
||||||
|
problem_class.name: {
|
||||||
|
"base_dimensions": list(problem_class.base_dimensions),
|
||||||
|
"tunable_params": list(problem_class.tunable_params),
|
||||||
|
"params": dict(problem_class.params),
|
||||||
|
"confidence": getattr(problem_class, "confidence", 0.5),
|
||||||
|
}
|
||||||
|
for problem_class in sorted(classes, key=lambda item: item.name)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _format_params(params: Mapping[str, float]) -> str:
|
||||||
|
return ", ".join(f"{key}={value:g}" for key, value in sorted(dict(params).items()))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
74
llm_connect/costs.py
Normal file
74
llm_connect/costs.py
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
"""Cost estimation over model rates and token counts."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from llm_connect.rates import ModelRateRegistry
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CostEstimate:
|
||||||
|
"""Cost estimate split by prompt and completion token spend."""
|
||||||
|
|
||||||
|
cost_usd: float | None
|
||||||
|
cost_source: str
|
||||||
|
prompt_cost_usd: float | None = None
|
||||||
|
completion_cost_usd: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_cost(
|
||||||
|
model_id: str,
|
||||||
|
prompt_tokens: int,
|
||||||
|
completion_tokens: int = 0,
|
||||||
|
*,
|
||||||
|
registry: ModelRateRegistry | None = None,
|
||||||
|
) -> CostEstimate:
|
||||||
|
"""Estimate USD cost for token counts using *registry*.
|
||||||
|
|
||||||
|
Unknown models return ``CostEstimate(None, "unknown")`` so callers can
|
||||||
|
record uncertainty explicitly instead of treating missing prices as zero.
|
||||||
|
"""
|
||||||
|
prompt_count = _non_negative_int("prompt_tokens", prompt_tokens)
|
||||||
|
completion_count = _non_negative_int("completion_tokens", completion_tokens)
|
||||||
|
rates = registry or ModelRateRegistry.default()
|
||||||
|
rate = rates.get(model_id)
|
||||||
|
if rate is None:
|
||||||
|
return CostEstimate(cost_usd=None, cost_source="unknown")
|
||||||
|
|
||||||
|
prompt_cost = (prompt_count / 1000.0) * rate.prompt_per_1k
|
||||||
|
completion_cost = (completion_count / 1000.0) * rate.completion_per_1k
|
||||||
|
return CostEstimate(
|
||||||
|
cost_usd=prompt_cost + completion_cost,
|
||||||
|
cost_source=f"rate_table:{rate.model_id}",
|
||||||
|
prompt_cost_usd=prompt_cost,
|
||||||
|
completion_cost_usd=completion_cost,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CostModel:
|
||||||
|
"""Small wrapper for callers that prefer an object over a free function."""
|
||||||
|
|
||||||
|
registry: ModelRateRegistry | None = None
|
||||||
|
|
||||||
|
def estimate_cost(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
prompt_tokens: int,
|
||||||
|
completion_tokens: int = 0,
|
||||||
|
) -> CostEstimate:
|
||||||
|
"""Estimate cost using this model's registry."""
|
||||||
|
return estimate_cost(
|
||||||
|
model_id,
|
||||||
|
prompt_tokens,
|
||||||
|
completion_tokens,
|
||||||
|
registry=self.registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _non_negative_int(name: str, value: Any) -> int:
|
||||||
|
if isinstance(value, bool) or not isinstance(value, int) or value < 0:
|
||||||
|
raise ValueError(f"{name} must be a non-negative integer")
|
||||||
|
return value
|
||||||
463
llm_connect/problem_classes.py
Normal file
463
llm_connect/problem_classes.py
Normal file
@@ -0,0 +1,463 @@
|
|||||||
|
"""Problem-class token estimators for common LLM workflow shapes."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_WORDS_PER_TOKEN = 0.75
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TokenEstimate:
|
||||||
|
"""Prompt/completion token estimate for a prospective LLM call."""
|
||||||
|
|
||||||
|
prompt_tokens: int
|
||||||
|
completion_tokens: int
|
||||||
|
confidence: float = 0.5
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
prompt_tokens = _non_negative_int("prompt_tokens", self.prompt_tokens)
|
||||||
|
completion_tokens = _non_negative_int("completion_tokens", self.completion_tokens)
|
||||||
|
confidence = _bounded_float("confidence", self.confidence)
|
||||||
|
object.__setattr__(self, "prompt_tokens", prompt_tokens)
|
||||||
|
object.__setattr__(self, "completion_tokens", completion_tokens)
|
||||||
|
object.__setattr__(self, "confidence", confidence)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Observation:
|
||||||
|
"""Actual token use paired with the problem dimensions that produced it."""
|
||||||
|
|
||||||
|
dimensions: dict[str, Any]
|
||||||
|
prompt_tokens: int
|
||||||
|
completion_tokens: int
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
object.__setattr__(self, "dimensions", dict(self.dimensions))
|
||||||
|
object.__setattr__(self, "prompt_tokens", _non_negative_int("prompt_tokens", self.prompt_tokens))
|
||||||
|
object.__setattr__(
|
||||||
|
self,
|
||||||
|
"completion_tokens",
|
||||||
|
_non_negative_int("completion_tokens", self.completion_tokens),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ProblemClass(Protocol):
|
||||||
|
"""Estimator contract implemented by built-in and consumer classes."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
base_dimensions: tuple[str, ...]
|
||||||
|
tunable_params: tuple[str, ...]
|
||||||
|
params: dict[str, float]
|
||||||
|
|
||||||
|
def estimate(
|
||||||
|
self,
|
||||||
|
dimensions: dict[str, Any],
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
) -> TokenEstimate:
|
||||||
|
"""Estimate token use from dimensions and optional parameter overrides."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def fit(
|
||||||
|
self,
|
||||||
|
observations: Sequence[Any],
|
||||||
|
*,
|
||||||
|
min_observations: int = 3,
|
||||||
|
) -> "ProblemClass":
|
||||||
|
"""Return an estimator with params adapted from observed token use."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class ProblemClassRegistry:
|
||||||
|
"""Registry keyed by stable problem-class names."""
|
||||||
|
|
||||||
|
schema_version = 1
|
||||||
|
|
||||||
|
def __init__(self, classes: Sequence[ProblemClass] | None = None) -> None:
|
||||||
|
self._classes: dict[str, ProblemClass] = {}
|
||||||
|
for problem_class in classes or ():
|
||||||
|
self.register(problem_class)
|
||||||
|
|
||||||
|
def get(self, name: str) -> ProblemClass | None:
|
||||||
|
"""Return a registered class by name."""
|
||||||
|
return self._classes.get(str(name).strip())
|
||||||
|
|
||||||
|
def all(self) -> dict[str, ProblemClass]:
|
||||||
|
"""Return a copy of registered problem classes."""
|
||||||
|
return dict(self._classes)
|
||||||
|
|
||||||
|
def register(self, problem_class: ProblemClass, *, replace: bool = False) -> None:
|
||||||
|
"""Register *problem_class* under its name."""
|
||||||
|
name = str(problem_class.name).strip()
|
||||||
|
if not name:
|
||||||
|
raise ValueError("problem_class.name must be a non-empty string")
|
||||||
|
if name in self._classes and not replace:
|
||||||
|
raise ValueError(f"Problem class {name!r} is already registered")
|
||||||
|
self._classes[name] = problem_class
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default(cls) -> "ProblemClassRegistry":
|
||||||
|
"""Return the built-in problem-class registry."""
|
||||||
|
return cls(
|
||||||
|
[
|
||||||
|
ChunkSummarizationProblemClass(),
|
||||||
|
EntityExtractionProblemClass(),
|
||||||
|
RelationExtractionProblemClass(),
|
||||||
|
JudgeEvalProblemClass(),
|
||||||
|
ReportSynthesisProblemClass(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _BaseProblemClass:
|
||||||
|
name = ""
|
||||||
|
base_dimensions: tuple[str, ...] = ()
|
||||||
|
tunable_params: tuple[str, ...] = ()
|
||||||
|
seed_params: Mapping[str, float] = {}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
params: Mapping[str, Any] | None = None,
|
||||||
|
confidence: float = 0.5,
|
||||||
|
) -> None:
|
||||||
|
merged = dict(self.seed_params)
|
||||||
|
for key, value in (params or {}).items():
|
||||||
|
if key not in self.tunable_params:
|
||||||
|
raise ValueError(f"Unknown parameter {key!r} for problem class {self.name!r}")
|
||||||
|
merged[key] = _non_negative_float(key, value)
|
||||||
|
self.params: dict[str, float] = merged
|
||||||
|
self.confidence = _bounded_float("confidence", confidence)
|
||||||
|
|
||||||
|
def estimate(
|
||||||
|
self,
|
||||||
|
dimensions: dict[str, Any],
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
) -> TokenEstimate:
|
||||||
|
dimensions = dict(dimensions)
|
||||||
|
self._validate_dimensions(dimensions)
|
||||||
|
merged_params = dict(self.params)
|
||||||
|
for key, value in (params or {}).items():
|
||||||
|
if key not in self.tunable_params:
|
||||||
|
raise ValueError(f"Unknown parameter {key!r} for problem class {self.name!r}")
|
||||||
|
merged_params[key] = _non_negative_float(key, value)
|
||||||
|
prompt_tokens, completion_tokens = self._estimate_tokens(dimensions, merged_params)
|
||||||
|
return TokenEstimate(
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
confidence=self.confidence,
|
||||||
|
)
|
||||||
|
|
||||||
|
def fit(
|
||||||
|
self,
|
||||||
|
observations: Sequence[Any],
|
||||||
|
*,
|
||||||
|
min_observations: int = 3,
|
||||||
|
) -> ProblemClass:
|
||||||
|
if min_observations <= 0:
|
||||||
|
raise ValueError("min_observations must be positive")
|
||||||
|
parsed = [
|
||||||
|
observation
|
||||||
|
for observation in (
|
||||||
|
_coerce_observation(raw, self.name, self.base_dimensions) for raw in observations
|
||||||
|
)
|
||||||
|
if observation is not None
|
||||||
|
]
|
||||||
|
if len(parsed) < min_observations:
|
||||||
|
return self
|
||||||
|
|
||||||
|
fitted: dict[str, float] = {}
|
||||||
|
for param in self.tunable_params:
|
||||||
|
values = [
|
||||||
|
value
|
||||||
|
for value in (
|
||||||
|
self._infer_param(param, observation) for observation in parsed
|
||||||
|
)
|
||||||
|
if value is not None
|
||||||
|
]
|
||||||
|
if values:
|
||||||
|
fitted[param] = sum(values) / len(values)
|
||||||
|
if not fitted:
|
||||||
|
return self
|
||||||
|
|
||||||
|
confidence = min(0.95, max(self.confidence, len(parsed) / (len(parsed) + 5)))
|
||||||
|
return type(self)(params={**self.params, **fitted}, confidence=confidence)
|
||||||
|
|
||||||
|
def _validate_dimensions(self, dimensions: Mapping[str, Any]) -> None:
|
||||||
|
missing = [name for name in self.base_dimensions if name not in dimensions]
|
||||||
|
if missing:
|
||||||
|
raise ValueError(f"Missing dimensions for {self.name!r}: {', '.join(missing)}")
|
||||||
|
for name in self.base_dimensions:
|
||||||
|
_non_negative_float(name, dimensions[name])
|
||||||
|
|
||||||
|
def _estimate_tokens(
|
||||||
|
self,
|
||||||
|
dimensions: Mapping[str, Any],
|
||||||
|
params: Mapping[str, float],
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def _infer_param(self, param: str, observation: Observation) -> float | None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkSummarizationProblemClass(_BaseProblemClass):
|
||||||
|
name = "chunk-summarization"
|
||||||
|
base_dimensions: tuple[str, ...] = ("chunk_words", "template_words")
|
||||||
|
tunable_params: tuple[str, ...] = ("completion_ratio",)
|
||||||
|
seed_params: Mapping[str, float] = {"completion_ratio": 0.25}
|
||||||
|
|
||||||
|
def _estimate_tokens(
|
||||||
|
self,
|
||||||
|
dimensions: Mapping[str, Any],
|
||||||
|
params: Mapping[str, float],
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
prompt_tokens = _words_to_tokens(
|
||||||
|
_dimension(dimensions, "chunk_words") + _dimension(dimensions, "template_words")
|
||||||
|
)
|
||||||
|
completion_tokens = _round_tokens(prompt_tokens * params["completion_ratio"])
|
||||||
|
return prompt_tokens, completion_tokens
|
||||||
|
|
||||||
|
def _infer_param(self, param: str, observation: Observation) -> float | None:
|
||||||
|
if param != "completion_ratio" or observation.prompt_tokens == 0:
|
||||||
|
return None
|
||||||
|
return observation.completion_tokens / observation.prompt_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class EntityExtractionProblemClass(_BaseProblemClass):
|
||||||
|
name = "entity-extraction"
|
||||||
|
base_dimensions: tuple[str, ...] = ("chunk_words", "template_words", "expected_entities")
|
||||||
|
tunable_params: tuple[str, ...] = ("tokens_per_entity",)
|
||||||
|
seed_params: Mapping[str, float] = {"tokens_per_entity": 70.0}
|
||||||
|
|
||||||
|
def _estimate_tokens(
|
||||||
|
self,
|
||||||
|
dimensions: Mapping[str, Any],
|
||||||
|
params: Mapping[str, float],
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
prompt_tokens = _words_to_tokens(
|
||||||
|
_dimension(dimensions, "chunk_words") + _dimension(dimensions, "template_words")
|
||||||
|
)
|
||||||
|
completion_tokens = _round_tokens(
|
||||||
|
_dimension(dimensions, "expected_entities") * params["tokens_per_entity"]
|
||||||
|
)
|
||||||
|
return prompt_tokens, completion_tokens
|
||||||
|
|
||||||
|
def _infer_param(self, param: str, observation: Observation) -> float | None:
|
||||||
|
expected_entities = _dimension(observation.dimensions, "expected_entities")
|
||||||
|
if param != "tokens_per_entity" or expected_entities <= 0:
|
||||||
|
return None
|
||||||
|
return observation.completion_tokens / expected_entities
|
||||||
|
|
||||||
|
|
||||||
|
class RelationExtractionProblemClass(_BaseProblemClass):
|
||||||
|
name = "relation-extraction"
|
||||||
|
base_dimensions: tuple[str, ...] = ("chunk_words", "template_words", "expected_relations")
|
||||||
|
tunable_params: tuple[str, ...] = ("tokens_per_relation",)
|
||||||
|
seed_params: Mapping[str, float] = {"tokens_per_relation": 80.0}
|
||||||
|
|
||||||
|
def _estimate_tokens(
|
||||||
|
self,
|
||||||
|
dimensions: Mapping[str, Any],
|
||||||
|
params: Mapping[str, float],
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
prompt_tokens = _words_to_tokens(
|
||||||
|
_dimension(dimensions, "chunk_words") + _dimension(dimensions, "template_words")
|
||||||
|
)
|
||||||
|
completion_tokens = _round_tokens(
|
||||||
|
_dimension(dimensions, "expected_relations") * params["tokens_per_relation"]
|
||||||
|
)
|
||||||
|
return prompt_tokens, completion_tokens
|
||||||
|
|
||||||
|
def _infer_param(self, param: str, observation: Observation) -> float | None:
|
||||||
|
expected_relations = _dimension(observation.dimensions, "expected_relations")
|
||||||
|
if param != "tokens_per_relation" or expected_relations <= 0:
|
||||||
|
return None
|
||||||
|
return observation.completion_tokens / expected_relations
|
||||||
|
|
||||||
|
|
||||||
|
class JudgeEvalProblemClass(_BaseProblemClass):
|
||||||
|
name = "judge-eval"
|
||||||
|
base_dimensions: tuple[str, ...] = ("artifact_words", "template_words", "n_criteria")
|
||||||
|
tunable_params: tuple[str, ...] = ("tokens_per_criterion",)
|
||||||
|
seed_params: Mapping[str, float] = {"tokens_per_criterion": 35.0}
|
||||||
|
|
||||||
|
def _estimate_tokens(
|
||||||
|
self,
|
||||||
|
dimensions: Mapping[str, Any],
|
||||||
|
params: Mapping[str, float],
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
prompt_tokens = _words_to_tokens(
|
||||||
|
_dimension(dimensions, "artifact_words") + _dimension(dimensions, "template_words")
|
||||||
|
)
|
||||||
|
completion_tokens = _round_tokens(
|
||||||
|
_dimension(dimensions, "n_criteria") * params["tokens_per_criterion"]
|
||||||
|
)
|
||||||
|
return prompt_tokens, completion_tokens
|
||||||
|
|
||||||
|
def _infer_param(self, param: str, observation: Observation) -> float | None:
|
||||||
|
n_criteria = _dimension(observation.dimensions, "n_criteria")
|
||||||
|
if param != "tokens_per_criterion" or n_criteria <= 0:
|
||||||
|
return None
|
||||||
|
return observation.completion_tokens / n_criteria
|
||||||
|
|
||||||
|
|
||||||
|
class ReportSynthesisProblemClass(_BaseProblemClass):
|
||||||
|
name = "report-synthesis"
|
||||||
|
base_dimensions: tuple[str, ...] = ("n_chunks", "n_entities", "n_relations", "template_words")
|
||||||
|
tunable_params: tuple[str, ...] = ("base_completion_tokens",)
|
||||||
|
seed_params: Mapping[str, float] = {"base_completion_tokens": 400.0}
|
||||||
|
|
||||||
|
def _estimate_tokens(
|
||||||
|
self,
|
||||||
|
dimensions: Mapping[str, Any],
|
||||||
|
params: Mapping[str, float],
|
||||||
|
) -> tuple[int, int]:
|
||||||
|
prompt_tokens = _words_to_tokens(_dimension(dimensions, "template_words"))
|
||||||
|
prompt_tokens += _round_tokens(_dimension(dimensions, "n_chunks") * 40)
|
||||||
|
prompt_tokens += _round_tokens(_dimension(dimensions, "n_entities") * 25)
|
||||||
|
prompt_tokens += _round_tokens(_dimension(dimensions, "n_relations") * 35)
|
||||||
|
return prompt_tokens, _round_tokens(params["base_completion_tokens"])
|
||||||
|
|
||||||
|
def _infer_param(self, param: str, observation: Observation) -> float | None:
|
||||||
|
if param != "base_completion_tokens":
|
||||||
|
return None
|
||||||
|
return float(observation.completion_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def default_problem_class_registry() -> ProblemClassRegistry:
|
||||||
|
"""Return the built-in problem-class registry."""
|
||||||
|
return ProblemClassRegistry.default()
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_observation(
|
||||||
|
raw: Any,
|
||||||
|
class_name: str,
|
||||||
|
required_dimensions: tuple[str, ...],
|
||||||
|
) -> Observation | None:
|
||||||
|
try:
|
||||||
|
if isinstance(raw, Observation):
|
||||||
|
return raw
|
||||||
|
if isinstance(raw, Mapping):
|
||||||
|
return _coerce_mapping_observation(raw, class_name, required_dimensions)
|
||||||
|
return _coerce_object_observation(raw, class_name, required_dimensions)
|
||||||
|
except (KeyError, TypeError, ValueError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_mapping_observation(
|
||||||
|
raw: Mapping[str, Any],
|
||||||
|
class_name: str,
|
||||||
|
required_dimensions: tuple[str, ...],
|
||||||
|
) -> Observation | None:
|
||||||
|
raw_tags = raw.get("tags")
|
||||||
|
tags: Mapping[str, Any] = raw_tags if isinstance(raw_tags, Mapping) else {}
|
||||||
|
problem_class = raw.get("problem_class") or tags.get("problem_class")
|
||||||
|
if problem_class is not None and str(problem_class) != class_name:
|
||||||
|
return None
|
||||||
|
dimensions = _dimensions_from_sources(required_dimensions, raw, tags)
|
||||||
|
prompt_tokens = _token_value(raw, "prompt_tokens", "tokens_in", "actual_prompt_tokens")
|
||||||
|
completion_tokens = _token_value(
|
||||||
|
raw,
|
||||||
|
"completion_tokens",
|
||||||
|
"tokens_out",
|
||||||
|
"actual_completion_tokens",
|
||||||
|
)
|
||||||
|
return Observation(dimensions, prompt_tokens, completion_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_object_observation(
|
||||||
|
raw: Any,
|
||||||
|
class_name: str,
|
||||||
|
required_dimensions: tuple[str, ...],
|
||||||
|
) -> Observation | None:
|
||||||
|
raw_tags = getattr(raw, "tags", {}) or {}
|
||||||
|
tags: Mapping[str, Any] = raw_tags if isinstance(raw_tags, Mapping) else {}
|
||||||
|
problem_class = tags.get("problem_class")
|
||||||
|
if problem_class is not None and str(problem_class) != class_name:
|
||||||
|
return None
|
||||||
|
dimensions = _dimensions_from_sources(required_dimensions, tags)
|
||||||
|
return Observation(
|
||||||
|
dimensions=dimensions,
|
||||||
|
prompt_tokens=getattr(raw, "tokens_in"),
|
||||||
|
completion_tokens=getattr(raw, "tokens_out"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _dimensions_from_sources(
|
||||||
|
required_dimensions: tuple[str, ...],
|
||||||
|
*sources: Mapping[str, Any],
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
for source in sources:
|
||||||
|
candidate = source.get("dimensions")
|
||||||
|
if isinstance(candidate, Mapping):
|
||||||
|
return dict(candidate)
|
||||||
|
dimensions: dict[str, Any] = {}
|
||||||
|
for name in required_dimensions:
|
||||||
|
for source in sources:
|
||||||
|
if name in source:
|
||||||
|
dimensions[name] = source[name]
|
||||||
|
break
|
||||||
|
if len(dimensions) != len(required_dimensions):
|
||||||
|
raise ValueError("observation is missing required dimensions")
|
||||||
|
return dimensions
|
||||||
|
|
||||||
|
|
||||||
|
def _token_value(raw: Mapping[str, Any], *names: str) -> int:
|
||||||
|
for name in names:
|
||||||
|
if name in raw:
|
||||||
|
return _non_negative_int(name, raw[name])
|
||||||
|
usage = raw.get("usage")
|
||||||
|
if isinstance(usage, Mapping):
|
||||||
|
for name in names:
|
||||||
|
if name in usage:
|
||||||
|
return _non_negative_int(name, usage[name])
|
||||||
|
raise KeyError(names[0])
|
||||||
|
|
||||||
|
|
||||||
|
def _dimension(dimensions: Mapping[str, Any], name: str) -> float:
|
||||||
|
return _non_negative_float(name, dimensions[name])
|
||||||
|
|
||||||
|
|
||||||
|
def _words_to_tokens(words: float) -> int:
|
||||||
|
if words == 0:
|
||||||
|
return 0
|
||||||
|
return max(1, _round_tokens(words / DEFAULT_WORDS_PER_TOKEN))
|
||||||
|
|
||||||
|
|
||||||
|
def _round_tokens(value: float) -> int:
|
||||||
|
return max(0, int(round(value)))
|
||||||
|
|
||||||
|
|
||||||
|
def _non_negative_int(name: str, value: Any) -> int:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
raise ValueError(f"{name} must be a non-negative integer")
|
||||||
|
try:
|
||||||
|
integer = int(value)
|
||||||
|
except (TypeError, ValueError) as exc:
|
||||||
|
raise ValueError(f"{name} must be a non-negative integer") from exc
|
||||||
|
if integer < 0 or integer != float(value):
|
||||||
|
raise ValueError(f"{name} must be a non-negative integer")
|
||||||
|
return integer
|
||||||
|
|
||||||
|
|
||||||
|
def _non_negative_float(name: str, value: Any) -> float:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
raise ValueError(f"{name} must be a non-negative number")
|
||||||
|
try:
|
||||||
|
number = float(value)
|
||||||
|
except (TypeError, ValueError) as exc:
|
||||||
|
raise ValueError(f"{name} must be a non-negative number") from exc
|
||||||
|
if number < 0:
|
||||||
|
raise ValueError(f"{name} must be a non-negative number")
|
||||||
|
return number
|
||||||
|
|
||||||
|
|
||||||
|
def _bounded_float(name: str, value: Any) -> float:
|
||||||
|
number = _non_negative_float(name, value)
|
||||||
|
if number > 1:
|
||||||
|
raise ValueError(f"{name} must be between 0 and 1")
|
||||||
|
return number
|
||||||
273
llm_connect/rates.py
Normal file
273
llm_connect/rates.py
Normal file
@@ -0,0 +1,273 @@
|
|||||||
|
"""Model rate registry for preview and post-hoc cost estimation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_RATE_SOURCE_URL = "https://openrouter.ai/models"
|
||||||
|
DEFAULT_RATE_CAPTURED_AT = "2026-05-17"
|
||||||
|
DEFAULT_RATE_CURRENCY = "USD"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ModelRate:
|
||||||
|
"""USD-denominated list price for one model."""
|
||||||
|
|
||||||
|
model_id: str
|
||||||
|
prompt_per_1k: float
|
||||||
|
completion_per_1k: float
|
||||||
|
currency: str = DEFAULT_RATE_CURRENCY
|
||||||
|
source_url: str = ""
|
||||||
|
captured_at: str = ""
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
model_id = str(self.model_id).strip()
|
||||||
|
currency = str(self.currency or DEFAULT_RATE_CURRENCY).strip().upper()
|
||||||
|
if not model_id:
|
||||||
|
raise ValueError("model_id must be a non-empty string")
|
||||||
|
if not currency:
|
||||||
|
raise ValueError("currency must be a non-empty string")
|
||||||
|
prompt_rate = _non_negative_float("prompt_per_1k", self.prompt_per_1k)
|
||||||
|
completion_rate = _non_negative_float("completion_per_1k", self.completion_per_1k)
|
||||||
|
|
||||||
|
object.__setattr__(self, "model_id", model_id)
|
||||||
|
object.__setattr__(self, "prompt_per_1k", prompt_rate)
|
||||||
|
object.__setattr__(self, "completion_per_1k", completion_rate)
|
||||||
|
object.__setattr__(self, "currency", currency)
|
||||||
|
object.__setattr__(self, "source_url", str(self.source_url or ""))
|
||||||
|
object.__setattr__(self, "captured_at", str(self.captured_at or ""))
|
||||||
|
|
||||||
|
|
||||||
|
class ModelRateRegistry:
|
||||||
|
"""Lookup table for model list prices."""
|
||||||
|
|
||||||
|
def __init__(self, rates: Mapping[str, ModelRate | Mapping[str, Any]] | None = None) -> None:
|
||||||
|
self._rates: dict[str, ModelRate] = {}
|
||||||
|
for model_id, rate in (rates or {}).items():
|
||||||
|
model_rate = _coerce_rate(model_id, rate)
|
||||||
|
self._rates[model_rate.model_id] = model_rate
|
||||||
|
|
||||||
|
def get(self, model_id: str) -> ModelRate | None:
|
||||||
|
"""Return the rate for *model_id*, or ``None`` when absent."""
|
||||||
|
return self._rates.get(str(model_id).strip())
|
||||||
|
|
||||||
|
def all(self) -> dict[str, ModelRate]:
|
||||||
|
"""Return a copy of the registry mapping."""
|
||||||
|
return dict(self._rates)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def default(cls) -> "ModelRateRegistry":
|
||||||
|
"""Return the bundled OpenRouter list-price snapshot."""
|
||||||
|
return cls(_default_rate_payload())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_yaml(cls, path: Path | str) -> "ModelRateRegistry":
|
||||||
|
"""Load rates from a YAML file.
|
||||||
|
|
||||||
|
The expected shape matches the historic infospace-bench table::
|
||||||
|
|
||||||
|
currency: USD
|
||||||
|
source_url: https://openrouter.ai/models
|
||||||
|
captured_at: "2026-05-17"
|
||||||
|
rates:
|
||||||
|
openai/gpt-4o-mini:
|
||||||
|
prompt_per_1k: 0.00015
|
||||||
|
completion_per_1k: 0.00060
|
||||||
|
|
||||||
|
PyYAML is used when installed; otherwise a small parser handles this
|
||||||
|
schema so llm-connect keeps its current lightweight dependency surface.
|
||||||
|
"""
|
||||||
|
payload = _load_yaml_mapping(Path(path))
|
||||||
|
return cls(_rates_from_payload(payload))
|
||||||
|
|
||||||
|
def merged_with(self, override: "ModelRateRegistry") -> "ModelRateRegistry":
|
||||||
|
"""Return a new registry where *override* entries win by model id."""
|
||||||
|
merged = self.all()
|
||||||
|
merged.update(override.all())
|
||||||
|
return ModelRateRegistry(merged)
|
||||||
|
|
||||||
|
|
||||||
|
_DEFAULT_RATES: dict[str, tuple[float, float]] = {
|
||||||
|
"openai/gpt-4o-mini": (0.00015, 0.00060),
|
||||||
|
"openai/gpt-4o": (0.0025, 0.01),
|
||||||
|
"openai/gpt-4-turbo": (0.01, 0.03),
|
||||||
|
"anthropic/claude-3.5-sonnet": (0.003, 0.015),
|
||||||
|
"anthropic/claude-3.5-haiku": (0.0008, 0.004),
|
||||||
|
"anthropic/claude-3-opus": (0.015, 0.075),
|
||||||
|
"google/gemini-1.5-flash": (0.000075, 0.0003),
|
||||||
|
"google/gemini-1.5-pro": (0.00125, 0.005),
|
||||||
|
"meta-llama/llama-3.1-70b-instruct": (0.00059, 0.00079),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _default_rate_payload() -> dict[str, ModelRate]:
|
||||||
|
return {
|
||||||
|
model_id: ModelRate(
|
||||||
|
model_id=model_id,
|
||||||
|
prompt_per_1k=prompt_rate,
|
||||||
|
completion_per_1k=completion_rate,
|
||||||
|
currency=DEFAULT_RATE_CURRENCY,
|
||||||
|
source_url=DEFAULT_RATE_SOURCE_URL,
|
||||||
|
captured_at=DEFAULT_RATE_CAPTURED_AT,
|
||||||
|
)
|
||||||
|
for model_id, (prompt_rate, completion_rate) in _DEFAULT_RATES.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_rate(model_id: str, rate: ModelRate | Mapping[str, Any]) -> ModelRate:
|
||||||
|
if isinstance(rate, ModelRate):
|
||||||
|
return rate
|
||||||
|
if not isinstance(rate, Mapping):
|
||||||
|
raise TypeError(f"Rate for {model_id!r} must be a ModelRate or mapping")
|
||||||
|
return ModelRate(
|
||||||
|
model_id=str(model_id),
|
||||||
|
prompt_per_1k=rate["prompt_per_1k"],
|
||||||
|
completion_per_1k=rate["completion_per_1k"],
|
||||||
|
currency=str(rate.get("currency") or DEFAULT_RATE_CURRENCY),
|
||||||
|
source_url=str(rate.get("source_url") or ""),
|
||||||
|
captured_at=str(rate.get("captured_at") or ""),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _rates_from_payload(payload: Mapping[str, Any]) -> dict[str, ModelRate]:
|
||||||
|
rates_payload = payload.get("rates")
|
||||||
|
if not isinstance(rates_payload, Mapping):
|
||||||
|
raise ValueError("Rate YAML must contain a 'rates' mapping")
|
||||||
|
|
||||||
|
currency = str(payload.get("currency") or DEFAULT_RATE_CURRENCY)
|
||||||
|
source_url = str(payload.get("source_url") or "")
|
||||||
|
captured_at = str(payload.get("captured_at") or "")
|
||||||
|
rates: dict[str, ModelRate] = {}
|
||||||
|
for model_id, raw_rate in rates_payload.items():
|
||||||
|
if not isinstance(raw_rate, Mapping):
|
||||||
|
raise ValueError(f"Rate entry for {model_id!r} must be a mapping")
|
||||||
|
rates[str(model_id)] = ModelRate(
|
||||||
|
model_id=str(model_id),
|
||||||
|
prompt_per_1k=raw_rate["prompt_per_1k"],
|
||||||
|
completion_per_1k=raw_rate["completion_per_1k"],
|
||||||
|
currency=str(raw_rate.get("currency") or currency),
|
||||||
|
source_url=str(raw_rate.get("source_url") or source_url),
|
||||||
|
captured_at=str(raw_rate.get("captured_at") or captured_at),
|
||||||
|
)
|
||||||
|
return rates
|
||||||
|
|
||||||
|
|
||||||
|
def _non_negative_float(name: str, value: Any) -> float:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
raise ValueError(f"{name} must be a non-negative number")
|
||||||
|
try:
|
||||||
|
number = float(value)
|
||||||
|
except (TypeError, ValueError) as exc:
|
||||||
|
raise ValueError(f"{name} must be a non-negative number") from exc
|
||||||
|
if number < 0:
|
||||||
|
raise ValueError(f"{name} must be a non-negative number")
|
||||||
|
return number
|
||||||
|
|
||||||
|
|
||||||
|
def _load_yaml_mapping(path: Path) -> Mapping[str, Any]:
|
||||||
|
try:
|
||||||
|
import yaml
|
||||||
|
except ImportError:
|
||||||
|
return _parse_rate_yaml(path.read_text(encoding="utf-8"))
|
||||||
|
|
||||||
|
data = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
|
||||||
|
if not isinstance(data, Mapping):
|
||||||
|
raise ValueError("Rate YAML root must be a mapping")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_rate_yaml(text: str) -> dict[str, Any]:
|
||||||
|
lines: list[tuple[int, str]] = []
|
||||||
|
for raw_line in text.splitlines():
|
||||||
|
line = _normalise_yaml_line(raw_line)
|
||||||
|
if line is not None:
|
||||||
|
lines.append(line)
|
||||||
|
data: dict[str, Any] = {}
|
||||||
|
index = 0
|
||||||
|
while index < len(lines):
|
||||||
|
indent, content = lines[index]
|
||||||
|
if indent != 0:
|
||||||
|
raise ValueError("Only top-level mappings are supported in rate YAML")
|
||||||
|
key, raw_value = _split_yaml_key_value(content)
|
||||||
|
if key == "rates" and raw_value == "":
|
||||||
|
rates, index = _parse_rates_block(lines, index + 1)
|
||||||
|
data["rates"] = rates
|
||||||
|
continue
|
||||||
|
data[key] = _parse_yaml_scalar(raw_value)
|
||||||
|
index += 1
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_rates_block(
|
||||||
|
lines: list[tuple[int, str]],
|
||||||
|
index: int,
|
||||||
|
) -> tuple[dict[str, dict[str, Any]], int]:
|
||||||
|
rates: dict[str, dict[str, Any]] = {}
|
||||||
|
while index < len(lines):
|
||||||
|
indent, content = lines[index]
|
||||||
|
if indent == 0:
|
||||||
|
break
|
||||||
|
if indent != 2:
|
||||||
|
raise ValueError("Rate model entries must be indented by two spaces")
|
||||||
|
model_id, raw_value = _split_yaml_key_value(content)
|
||||||
|
if raw_value:
|
||||||
|
raise ValueError(f"Rate entry for {model_id!r} must be a nested mapping")
|
||||||
|
entry: dict[str, Any] = {}
|
||||||
|
index += 1
|
||||||
|
while index < len(lines):
|
||||||
|
child_indent, child_content = lines[index]
|
||||||
|
if child_indent <= indent:
|
||||||
|
break
|
||||||
|
if child_indent != 4:
|
||||||
|
raise ValueError("Rate fields must be indented by four spaces")
|
||||||
|
child_key, child_value = _split_yaml_key_value(child_content)
|
||||||
|
entry[child_key] = _parse_yaml_scalar(child_value)
|
||||||
|
index += 1
|
||||||
|
rates[model_id] = entry
|
||||||
|
return rates, index
|
||||||
|
|
||||||
|
|
||||||
|
def _normalise_yaml_line(line: str) -> tuple[int, str] | None:
|
||||||
|
stripped = _strip_yaml_comment(line.rstrip())
|
||||||
|
if not stripped.strip():
|
||||||
|
return None
|
||||||
|
indent = len(stripped) - len(stripped.lstrip(" "))
|
||||||
|
return indent, stripped.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_yaml_comment(line: str) -> str:
|
||||||
|
quote: str | None = None
|
||||||
|
for index, char in enumerate(line):
|
||||||
|
if char in {"'", '"'}:
|
||||||
|
quote = None if quote == char else char if quote is None else quote
|
||||||
|
elif char == "#" and quote is None:
|
||||||
|
return line[:index]
|
||||||
|
return line
|
||||||
|
|
||||||
|
|
||||||
|
def _split_yaml_key_value(content: str) -> tuple[str, str]:
|
||||||
|
key, separator, value = content.partition(":")
|
||||||
|
if not separator:
|
||||||
|
raise ValueError(f"Invalid YAML mapping line: {content!r}")
|
||||||
|
return key.strip().strip("'\""), value.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_yaml_scalar(value: str) -> Any:
|
||||||
|
if value == "":
|
||||||
|
return ""
|
||||||
|
if (value.startswith('"') and value.endswith('"')) or (
|
||||||
|
value.startswith("'") and value.endswith("'")
|
||||||
|
):
|
||||||
|
return value[1:-1]
|
||||||
|
if value.lower() in {"null", "none", "~"}:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
if any(char in value for char in (".", "e", "E")):
|
||||||
|
return float(value)
|
||||||
|
return int(value)
|
||||||
|
except ValueError:
|
||||||
|
return value
|
||||||
@@ -11,6 +11,9 @@ dependencies = [
|
|||||||
"toml",
|
"toml",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[project.scripts]
|
||||||
|
llm-connect = "llm_connect.cli:main"
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
dev = [
|
dev = [
|
||||||
"pytest>=7.0",
|
"pytest>=7.0",
|
||||||
|
|||||||
54
tests/test_cli.py
Normal file
54
tests/test_cli.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
import json
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from llm_connect.cli import main
|
||||||
|
from llm_connect.quality import QualityLedger, QualityObservation
|
||||||
|
|
||||||
|
|
||||||
|
def test_rates_show_json_outputs_default_registry(capsys):
|
||||||
|
assert main(["rates", "show", "--json"]) == 0
|
||||||
|
|
||||||
|
payload = json.loads(capsys.readouterr().out)
|
||||||
|
|
||||||
|
assert payload["openai/gpt-4o-mini"]["prompt_per_1k"] == 0.00015
|
||||||
|
|
||||||
|
|
||||||
|
def test_classes_show_lists_builtins(capsys):
|
||||||
|
assert main(["classes", "show"]) == 0
|
||||||
|
|
||||||
|
output = capsys.readouterr().out
|
||||||
|
|
||||||
|
assert "chunk-summarization" in output
|
||||||
|
assert "entity-extraction" in output
|
||||||
|
|
||||||
|
|
||||||
|
def test_classes_fit_reads_quality_ledger(tmp_path, capsys):
|
||||||
|
ledger = QualityLedger(tmp_path / "quality.jsonl")
|
||||||
|
for _ in range(3):
|
||||||
|
ledger.append(
|
||||||
|
QualityObservation(
|
||||||
|
task_type="extract",
|
||||||
|
adapter_id="openrouter",
|
||||||
|
model_id="openai/gpt-4o-mini",
|
||||||
|
cost_usd=0.001,
|
||||||
|
quality_score=0.9,
|
||||||
|
latency_ms=100,
|
||||||
|
tokens_in=500,
|
||||||
|
tokens_out=350,
|
||||||
|
recorded_at=datetime(2026, 5, 19, tzinfo=timezone.utc),
|
||||||
|
tags={
|
||||||
|
"problem_class": "entity-extraction",
|
||||||
|
"dimensions": {
|
||||||
|
"chunk_words": 300,
|
||||||
|
"template_words": 100,
|
||||||
|
"expected_entities": 5,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert main(["classes", "fit", str(ledger.path), "--class", "entity-extraction", "--json"]) == 0
|
||||||
|
|
||||||
|
payload = json.loads(capsys.readouterr().out)
|
||||||
|
|
||||||
|
assert payload["entity-extraction"]["params"]["tokens_per_entity"] == 70
|
||||||
49
tests/test_costs.py
Normal file
49
tests/test_costs.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from llm_connect.costs import CostEstimate, CostModel, estimate_cost
|
||||||
|
from llm_connect.rates import ModelRate, ModelRateRegistry
|
||||||
|
|
||||||
|
|
||||||
|
def test_known_model_cost_matches_lefevre_smoke_budget():
|
||||||
|
estimate = estimate_cost("openai/gpt-4o-mini", 28_000, 7_500)
|
||||||
|
|
||||||
|
assert estimate.cost_source == "rate_table:openai/gpt-4o-mini"
|
||||||
|
assert estimate.cost_usd == pytest.approx(0.0087)
|
||||||
|
assert estimate.cost_usd == pytest.approx(0.009, rel=0.2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_model_returns_unknown_without_zeroing_cost():
|
||||||
|
estimate = estimate_cost("unknown/model", 100, 50)
|
||||||
|
|
||||||
|
assert estimate == CostEstimate(cost_usd=None, cost_source="unknown")
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_override_controls_estimate():
|
||||||
|
registry = ModelRateRegistry(
|
||||||
|
{
|
||||||
|
"vendor/model": ModelRate(
|
||||||
|
"vendor/model",
|
||||||
|
prompt_per_1k=1.0,
|
||||||
|
completion_per_1k=2.0,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
estimate = estimate_cost("vendor/model", 1_000, 500, registry=registry)
|
||||||
|
|
||||||
|
assert estimate.cost_usd == pytest.approx(2.0)
|
||||||
|
assert estimate.prompt_cost_usd == pytest.approx(1.0)
|
||||||
|
assert estimate.completion_cost_usd == pytest.approx(1.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_zero_tokens_are_valid_and_cost_zero_for_known_model():
|
||||||
|
estimate = CostModel().estimate_cost("openai/gpt-4o-mini", 0, 0)
|
||||||
|
|
||||||
|
assert estimate.cost_usd == 0
|
||||||
|
assert estimate.prompt_cost_usd == 0
|
||||||
|
assert estimate.completion_cost_usd == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_negative_tokens_are_rejected():
|
||||||
|
with pytest.raises(ValueError, match="prompt_tokens"):
|
||||||
|
estimate_cost("openai/gpt-4o-mini", -1, 0)
|
||||||
@@ -24,3 +24,27 @@ def test_wp_0004_primitives_are_exported_from_package_root():
|
|||||||
for name in expected_names:
|
for name in expected_names:
|
||||||
assert hasattr(llm_connect, name)
|
assert hasattr(llm_connect, name)
|
||||||
assert name in llm_connect.__all__
|
assert name in llm_connect.__all__
|
||||||
|
|
||||||
|
|
||||||
|
def test_wp_0005_primitives_are_exported_from_package_root():
|
||||||
|
expected_names = [
|
||||||
|
"ModelRate",
|
||||||
|
"ModelRateRegistry",
|
||||||
|
"CostEstimate",
|
||||||
|
"CostModel",
|
||||||
|
"estimate_cost",
|
||||||
|
"TokenEstimate",
|
||||||
|
"Observation",
|
||||||
|
"ProblemClass",
|
||||||
|
"ProblemClassRegistry",
|
||||||
|
"default_problem_class_registry",
|
||||||
|
"ChunkSummarizationProblemClass",
|
||||||
|
"EntityExtractionProblemClass",
|
||||||
|
"RelationExtractionProblemClass",
|
||||||
|
"JudgeEvalProblemClass",
|
||||||
|
"ReportSynthesisProblemClass",
|
||||||
|
]
|
||||||
|
|
||||||
|
for name in expected_names:
|
||||||
|
assert hasattr(llm_connect, name)
|
||||||
|
assert name in llm_connect.__all__
|
||||||
|
|||||||
137
tests/test_problem_classes.py
Normal file
137
tests/test_problem_classes.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llm_connect.problem_classes import (
|
||||||
|
EntityExtractionProblemClass,
|
||||||
|
Observation,
|
||||||
|
ProblemClassRegistry,
|
||||||
|
TokenEstimate,
|
||||||
|
)
|
||||||
|
from llm_connect.quality import QualityObservation
|
||||||
|
|
||||||
|
|
||||||
|
DIMENSIONS_BY_CLASS = {
|
||||||
|
"chunk-summarization": [
|
||||||
|
{"chunk_words": 900, "template_words": 150},
|
||||||
|
{"chunk_words": 400, "template_words": 125},
|
||||||
|
{"chunk_words": 1200, "template_words": 200},
|
||||||
|
],
|
||||||
|
"entity-extraction": [
|
||||||
|
{"chunk_words": 900, "template_words": 200, "expected_entities": 4},
|
||||||
|
{"chunk_words": 450, "template_words": 180, "expected_entities": 6},
|
||||||
|
{"chunk_words": 1200, "template_words": 220, "expected_entities": 8},
|
||||||
|
],
|
||||||
|
"relation-extraction": [
|
||||||
|
{"chunk_words": 900, "template_words": 200, "expected_relations": 3},
|
||||||
|
{"chunk_words": 450, "template_words": 180, "expected_relations": 5},
|
||||||
|
{"chunk_words": 1200, "template_words": 220, "expected_relations": 7},
|
||||||
|
],
|
||||||
|
"judge-eval": [
|
||||||
|
{"artifact_words": 700, "template_words": 180, "n_criteria": 4},
|
||||||
|
{"artifact_words": 300, "template_words": 160, "n_criteria": 5},
|
||||||
|
{"artifact_words": 1100, "template_words": 200, "n_criteria": 6},
|
||||||
|
],
|
||||||
|
"report-synthesis": [
|
||||||
|
{"n_chunks": 5, "n_entities": 20, "n_relations": 8, "template_words": 250},
|
||||||
|
{"n_chunks": 8, "n_entities": 30, "n_relations": 12, "template_words": 250},
|
||||||
|
{"n_chunks": 2, "n_entities": 10, "n_relations": 3, "template_words": 180},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_registry_exposes_builtin_classes():
|
||||||
|
registry = ProblemClassRegistry.default()
|
||||||
|
|
||||||
|
assert set(registry.all()) == set(DIMENSIONS_BY_CLASS)
|
||||||
|
assert registry.schema_version == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("name,dimensions_list", DIMENSIONS_BY_CLASS.items())
|
||||||
|
def test_builtin_estimators_produce_token_estimates(name, dimensions_list):
|
||||||
|
problem_class = ProblemClassRegistry.default().get(name)
|
||||||
|
|
||||||
|
estimate = problem_class.estimate(dimensions_list[0])
|
||||||
|
|
||||||
|
assert isinstance(estimate, TokenEstimate)
|
||||||
|
assert estimate.prompt_tokens >= 0
|
||||||
|
assert estimate.completion_tokens >= 0
|
||||||
|
assert 0 <= estimate.confidence <= 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("name,dimensions_list", DIMENSIONS_BY_CLASS.items())
|
||||||
|
def test_fit_recovers_seeded_params_from_synthetic_observations(name, dimensions_list):
|
||||||
|
seeded = ProblemClassRegistry.default().get(name)
|
||||||
|
param_name = seeded.tunable_params[0]
|
||||||
|
off_seed = type(seeded)(params={param_name: seeded.params[param_name] * 2})
|
||||||
|
observations = []
|
||||||
|
for dimensions in dimensions_list:
|
||||||
|
estimate = seeded.estimate(dimensions)
|
||||||
|
observations.append(
|
||||||
|
Observation(
|
||||||
|
dimensions=dimensions,
|
||||||
|
prompt_tokens=estimate.prompt_tokens,
|
||||||
|
completion_tokens=estimate.completion_tokens,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
fitted = off_seed.fit(observations, min_observations=3)
|
||||||
|
|
||||||
|
assert fitted.params[param_name] == pytest.approx(seeded.params[param_name], rel=0.1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fit_uses_quality_ledger_observation_shape():
|
||||||
|
problem_class = EntityExtractionProblemClass(params={"tokens_per_entity": 10})
|
||||||
|
observations = [
|
||||||
|
QualityObservation(
|
||||||
|
task_type="extract",
|
||||||
|
adapter_id="openrouter",
|
||||||
|
model_id="openai/gpt-4o-mini",
|
||||||
|
cost_usd=0.001,
|
||||||
|
quality_score=0.9,
|
||||||
|
latency_ms=100,
|
||||||
|
tokens_in=500,
|
||||||
|
tokens_out=350,
|
||||||
|
recorded_at=datetime(2026, 5, 19, tzinfo=timezone.utc),
|
||||||
|
tags={
|
||||||
|
"problem_class": "entity-extraction",
|
||||||
|
"dimensions": {
|
||||||
|
"chunk_words": 300,
|
||||||
|
"template_words": 100,
|
||||||
|
"expected_entities": 5,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for _ in range(3)
|
||||||
|
]
|
||||||
|
|
||||||
|
fitted = problem_class.fit(observations)
|
||||||
|
|
||||||
|
assert fitted.params["tokens_per_entity"] == pytest.approx(70)
|
||||||
|
|
||||||
|
|
||||||
|
def test_fit_keeps_seed_when_sample_is_too_small():
|
||||||
|
problem_class = EntityExtractionProblemClass()
|
||||||
|
estimate = problem_class.estimate(
|
||||||
|
{"chunk_words": 300, "template_words": 100, "expected_entities": 5}
|
||||||
|
)
|
||||||
|
|
||||||
|
fitted = problem_class.fit(
|
||||||
|
[
|
||||||
|
Observation(
|
||||||
|
dimensions={"chunk_words": 300, "template_words": 100, "expected_entities": 5},
|
||||||
|
prompt_tokens=estimate.prompt_tokens,
|
||||||
|
completion_tokens=estimate.completion_tokens,
|
||||||
|
)
|
||||||
|
],
|
||||||
|
min_observations=3,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert fitted is problem_class
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_dimensions_are_rejected():
|
||||||
|
problem_class = ProblemClassRegistry.default().get("chunk-summarization")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Missing dimensions"):
|
||||||
|
problem_class.estimate({"chunk_words": 100})
|
||||||
65
tests/test_rates.py
Normal file
65
tests/test_rates.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from llm_connect.rates import ModelRate, ModelRateRegistry
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_registry_contains_openrouter_seed_models():
|
||||||
|
registry = ModelRateRegistry.default()
|
||||||
|
rates = registry.all()
|
||||||
|
|
||||||
|
assert len(rates) >= 9
|
||||||
|
assert rates["openai/gpt-4o-mini"].captured_at == "2026-05-17"
|
||||||
|
assert rates["openai/gpt-4o-mini"].source_url == "https://openrouter.ai/models"
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_yaml_loads_package_shape(tmp_path):
|
||||||
|
path = tmp_path / "model-rates.yaml"
|
||||||
|
path.write_text(
|
||||||
|
"""
|
||||||
|
schema_version: 1
|
||||||
|
currency: USD
|
||||||
|
source_url: https://example.test/rates
|
||||||
|
captured_at: "2026-05-19"
|
||||||
|
rates:
|
||||||
|
vendor/model:
|
||||||
|
prompt_per_1k: 0.1
|
||||||
|
completion_per_1k: 0.2
|
||||||
|
""",
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
registry = ModelRateRegistry.from_yaml(path)
|
||||||
|
rate = registry.get("vendor/model")
|
||||||
|
|
||||||
|
assert rate == ModelRate(
|
||||||
|
model_id="vendor/model",
|
||||||
|
prompt_per_1k=0.1,
|
||||||
|
completion_per_1k=0.2,
|
||||||
|
currency="USD",
|
||||||
|
source_url="https://example.test/rates",
|
||||||
|
captured_at="2026-05-19",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_merged_with_overrides_matching_model():
|
||||||
|
base = ModelRateRegistry.default()
|
||||||
|
override = ModelRateRegistry(
|
||||||
|
{
|
||||||
|
"openai/gpt-4o-mini": ModelRate(
|
||||||
|
"openai/gpt-4o-mini",
|
||||||
|
prompt_per_1k=1,
|
||||||
|
completion_per_1k=2,
|
||||||
|
captured_at="override",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
merged = base.merged_with(override)
|
||||||
|
|
||||||
|
assert merged.get("openai/gpt-4o-mini").prompt_per_1k == 1
|
||||||
|
assert merged.get("openai/gpt-4o-mini").captured_at == "override"
|
||||||
|
|
||||||
|
|
||||||
|
def test_negative_rates_are_rejected():
|
||||||
|
with pytest.raises(ValueError, match="prompt_per_1k"):
|
||||||
|
ModelRate("bad/model", prompt_per_1k=-1, completion_per_1k=0)
|
||||||
@@ -4,7 +4,7 @@ type: workplan
|
|||||||
title: "Cost Model and Problem-Class Token Estimators"
|
title: "Cost Model and Problem-Class Token Estimators"
|
||||||
domain: custodian
|
domain: custodian
|
||||||
repo: llm-connect
|
repo: llm-connect
|
||||||
status: proposed
|
status: finished
|
||||||
owner: llm-connect
|
owner: llm-connect
|
||||||
planning_priority: high
|
planning_priority: high
|
||||||
planning_order: 5
|
planning_order: 5
|
||||||
@@ -21,7 +21,7 @@ state_hub_workstream_id: "869196c5-551b-4eef-b8d8-cca6f770a9b0"
|
|||||||
|
|
||||||
# LLM-WP-0005 — Cost Model and Problem-Class Token Estimators
|
# LLM-WP-0005 — Cost Model and Problem-Class Token Estimators
|
||||||
|
|
||||||
**status:** proposed
|
**status:** finished
|
||||||
**owner:** llm-connect
|
**owner:** llm-connect
|
||||||
|
|
||||||
## Purpose
|
## Purpose
|
||||||
@@ -184,7 +184,7 @@ keeps working (LLM-WP-0004 owns the ledger).
|
|||||||
id: T01
|
id: T01
|
||||||
title: 'ModelRate + ModelRateRegistry data model, YAML loader, default-registry seed of nine OpenRouter models'
|
title: 'ModelRate + ModelRateRegistry data model, YAML loader, default-registry seed of nine OpenRouter models'
|
||||||
priority: high
|
priority: high
|
||||||
status: todo
|
status: done
|
||||||
state_hub_task_id: "535d3f12-911e-4b6a-87c3-b539c5986671"
|
state_hub_task_id: "535d3f12-911e-4b6a-87c3-b539c5986671"
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -192,7 +192,7 @@ state_hub_task_id: "535d3f12-911e-4b6a-87c3-b539c5986671"
|
|||||||
id: T02
|
id: T02
|
||||||
title: 'CostModel.estimate_cost() pure function; tests for known model, unknown model, registry override, zero-token edge'
|
title: 'CostModel.estimate_cost() pure function; tests for known model, unknown model, registry override, zero-token edge'
|
||||||
priority: high
|
priority: high
|
||||||
status: todo
|
status: done
|
||||||
state_hub_task_id: "691dd985-6a97-432d-8bf0-6cb99a9fbdcc"
|
state_hub_task_id: "691dd985-6a97-432d-8bf0-6cb99a9fbdcc"
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -200,7 +200,7 @@ state_hub_task_id: "691dd985-6a97-432d-8bf0-6cb99a9fbdcc"
|
|||||||
id: T03
|
id: T03
|
||||||
title: 'ProblemClass protocol + TokenEstimate + ProblemClassRegistry'
|
title: 'ProblemClass protocol + TokenEstimate + ProblemClassRegistry'
|
||||||
priority: high
|
priority: high
|
||||||
status: todo
|
status: done
|
||||||
state_hub_task_id: "ecf263d2-f40a-460e-9195-4e01135ef727"
|
state_hub_task_id: "ecf263d2-f40a-460e-9195-4e01135ef727"
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -208,7 +208,7 @@ state_hub_task_id: "ecf263d2-f40a-460e-9195-4e01135ef727"
|
|||||||
id: T04
|
id: T04
|
||||||
title: 'Built-in classes: chunk-summarization, entity-extraction, relation-extraction, judge-eval, report-synthesis'
|
title: 'Built-in classes: chunk-summarization, entity-extraction, relation-extraction, judge-eval, report-synthesis'
|
||||||
priority: high
|
priority: high
|
||||||
status: todo
|
status: done
|
||||||
state_hub_task_id: "f1860b10-7467-4ce3-9775-ab293cef3ed0"
|
state_hub_task_id: "f1860b10-7467-4ce3-9775-ab293cef3ed0"
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -216,7 +216,7 @@ state_hub_task_id: "f1860b10-7467-4ce3-9775-ab293cef3ed0"
|
|||||||
id: T05
|
id: T05
|
||||||
title: 'ProblemClass.fit() adapts tunable params from QualityLedger observations'
|
title: 'ProblemClass.fit() adapts tunable params from QualityLedger observations'
|
||||||
priority: medium
|
priority: medium
|
||||||
status: todo
|
status: done
|
||||||
state_hub_task_id: "950b74e9-ede8-477a-b6b7-c7af423d4ebb"
|
state_hub_task_id: "950b74e9-ede8-477a-b6b7-c7af423d4ebb"
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -224,7 +224,7 @@ state_hub_task_id: "950b74e9-ede8-477a-b6b7-c7af423d4ebb"
|
|||||||
id: T06
|
id: T06
|
||||||
title: 'CLI helpers: llm-connect rates show, llm-connect classes show, llm-connect classes fit <ledger>'
|
title: 'CLI helpers: llm-connect rates show, llm-connect classes show, llm-connect classes fit <ledger>'
|
||||||
priority: medium
|
priority: medium
|
||||||
status: todo
|
status: done
|
||||||
state_hub_task_id: "c47eca5f-4cb3-4f88-ac1b-38a9ae18e7e6"
|
state_hub_task_id: "c47eca5f-4cb3-4f88-ac1b-38a9ae18e7e6"
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -232,7 +232,7 @@ state_hub_task_id: "c47eca5f-4cb3-4f88-ac1b-38a9ae18e7e6"
|
|||||||
id: T07
|
id: T07
|
||||||
title: 'Functional contract docs under contracts/functional/ for rates, costs, and problem classes'
|
title: 'Functional contract docs under contracts/functional/ for rates, costs, and problem classes'
|
||||||
priority: medium
|
priority: medium
|
||||||
status: todo
|
status: done
|
||||||
state_hub_task_id: "c15fd1dc-48c3-40e9-abca-ba3ffe3684f9"
|
state_hub_task_id: "c15fd1dc-48c3-40e9-abca-ba3ffe3684f9"
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -240,7 +240,7 @@ state_hub_task_id: "c15fd1dc-48c3-40e9-abca-ba3ffe3684f9"
|
|||||||
id: T08
|
id: T08
|
||||||
title: 'Consumer migration note for infospace-bench: replace plan_generation_summary cost+token math with llm-connect calls'
|
title: 'Consumer migration note for infospace-bench: replace plan_generation_summary cost+token math with llm-connect calls'
|
||||||
priority: medium
|
priority: medium
|
||||||
status: todo
|
status: done
|
||||||
state_hub_task_id: "2993932a-334c-49f9-bb74-6ef4d3cbffcb"
|
state_hub_task_id: "2993932a-334c-49f9-bb74-6ef4d3cbffcb"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user