Files
llm-connect/llm_connect/cli.py
tegwick c11c6afa3f
Some checks failed
CI / test (3.10) (push) Has been cancelled
CI / test (3.11) (push) Has been cancelled
CI / test (3.12) (push) Has been cancelled
Implement-LLM-WP-0005-cost-model-estimators
2026-05-19 05:02:20 +02:00

144 lines
5.4 KiB
Python

"""Command-line helpers for llm-connect registries."""
from __future__ import annotations
import argparse
import json
from collections.abc import Iterable, Mapping
from pathlib import Path
from typing import Any
from llm_connect.problem_classes import ProblemClass, ProblemClassRegistry
from llm_connect.quality import QualityLedger
from llm_connect.rates import ModelRateRegistry
def main(argv: list[str] | None = None) -> int:
    """Parse the command line and dispatch to the selected subcommand.

    Args:
        argv: Argument strings to parse. ``None`` is handed to
            ``ArgumentParser.parse_args`` unchanged, which then reads the
            process arguments itself.

    Returns:
        The value returned by the handler stored on the parsed namespace
        as ``func``, coerced to ``int``.
    """
    namespace = _build_parser().parse_args(argv)
    handler = namespace.func
    exit_status = handler(namespace)
    return int(exit_status)
def _build_parser() -> argparse.ArgumentParser:
    """Assemble the ``llm-connect`` argument parser.

    The command tree is two levels deep; every leaf command records its
    handler under the ``func`` default:

    * ``rates show`` with optional ``--rates`` and ``--json``
    * ``classes show`` with optional ``--json``
    * ``classes fit <ledger>`` with optional ``--class``,
      ``--min-observations`` (default 3) and ``--json``

    Returns:
        The fully configured top-level parser.
    """

    def add_json_flag(leaf: argparse.ArgumentParser) -> None:
        # Every leaf command exposes the same machine-readable output switch.
        leaf.add_argument("--json", action="store_true", help="Emit JSON")

    root = argparse.ArgumentParser(prog="llm-connect")
    top_level = root.add_subparsers(dest="command", required=True)

    # rates ...
    rates_group = top_level.add_parser("rates", help="Inspect model rate registries")
    rates_actions = rates_group.add_subparsers(dest="rates_command", required=True)

    show_rates = rates_actions.add_parser("show", help="Show model rates")
    show_rates.add_argument("--rates", type=Path, help="YAML registry overlay")
    add_json_flag(show_rates)
    show_rates.set_defaults(func=_rates_show)

    # classes ...
    classes_group = top_level.add_parser("classes", help="Inspect problem classes")
    classes_actions = classes_group.add_subparsers(dest="classes_command", required=True)

    show_classes = classes_actions.add_parser("show", help="Show problem classes")
    add_json_flag(show_classes)
    show_classes.set_defaults(func=_classes_show)

    fit_classes = classes_actions.add_parser(
        "fit", help="Fit problem-class params from a ledger"
    )
    fit_classes.add_argument("ledger", type=Path, help="QualityLedger JSONL path")
    fit_classes.add_argument("--class", dest="class_name", help="Fit one class by name")
    fit_classes.add_argument("--min-observations", type=int, default=3)
    add_json_flag(fit_classes)
    fit_classes.set_defaults(func=_classes_fit)

    return root
def _rates_show(args: argparse.Namespace) -> int:
    """Print the model rate registry as JSON or as a tab-separated table.

    The registry starts from ``ModelRateRegistry.default()``; when
    ``args.rates`` is set, the registry loaded from that YAML path is
    combined into it through ``merged_with``.

    Args:
        args: Parsed namespace; ``rates`` and ``json`` are read.

    Returns:
        Always ``0``.
    """
    registry = ModelRateRegistry.default()
    if args.rates:
        overlay = ModelRateRegistry.from_yaml(args.rates)
        registry = registry.merged_with(overlay)

    # Both output formats list models in sorted model-id order.
    ordered = sorted(registry.all().items())

    if not args.json:
        print("model_id\tprompt_per_1k\tcompletion_per_1k\tcurrency\tcaptured_at")
        for model_id, rate in ordered:
            cells = (
                f"{model_id}",
                f"{rate.prompt_per_1k:g}",
                f"{rate.completion_per_1k:g}",
                f"{rate.currency}",
                f"{rate.captured_at}",
            )
            print("\t".join(cells))
        return 0

    document = {}
    for model_id, rate in ordered:
        document[model_id] = {
            "prompt_per_1k": rate.prompt_per_1k,
            "completion_per_1k": rate.completion_per_1k,
            "currency": rate.currency,
            "source_url": rate.source_url,
            "captured_at": rate.captured_at,
        }
    print(json.dumps(document, indent=2, sort_keys=True))
    return 0
def _classes_show(args: argparse.Namespace) -> int:
    """Print the default problem-class registry as JSON or as a table.

    Args:
        args: Parsed namespace; only ``json`` is read.

    Returns:
        Always ``0``.
    """
    registered = ProblemClassRegistry.default().all()

    if args.json:
        document = _classes_payload(registered.values())
        print(json.dumps(document, indent=2, sort_keys=True))
        return 0

    print("name\tdimensions\ttunable_params\tcurrent_params")
    # Rows are listed alphabetically by class name.
    for entry in sorted(registered.values(), key=lambda item: item.name):
        dimensions = ", ".join(entry.base_dimensions)
        tunables = ", ".join(entry.tunable_params)
        current = _format_params(entry.params)
        print(f"{entry.name}\t{dimensions}\t{tunables}\t{current}")
    return 0
def _classes_fit(args: argparse.Namespace) -> int:
    """Fit problem-class parameters from a quality ledger and print them.

    Observations are read from ``QualityLedger(args.ledger).read_all()`` and
    handed to ``fit(observations, min_observations=...)`` on each selected
    class. With ``--class`` only that class is fitted; otherwise every class
    in the default registry is.

    Args:
        args: Parsed namespace; ``min_observations``, ``class_name``,
            ``ledger`` and ``json`` are read.

    Returns:
        Always ``0`` when fitting completes.

    Raises:
        SystemExit: If ``--min-observations`` is not positive, or if the
            registry lookup for ``--class`` yields ``None``.
    """
    # Validate before touching the registry or the ledger file.
    if args.min_observations <= 0:
        raise SystemExit("--min-observations must be positive")

    registry = ProblemClassRegistry.default()

    # Test against None, not truthiness: an explicit empty ``--class ""``
    # must go through the same lookup as any other name instead of being
    # mistaken for "no class given" and silently fitting every class.
    if args.class_name is not None:
        problem_class = registry.get(args.class_name)
        if problem_class is None:
            raise SystemExit(f"Unknown problem class: {args.class_name}")
        selected: list[ProblemClass] = [problem_class]
    else:
        # The full listing is only needed when no single class was requested.
        selected = list(registry.all().values())

    observations = QualityLedger(args.ledger).read_all()
    fitted: list[ProblemClass] = [
        problem_class.fit(observations, min_observations=args.min_observations)
        for problem_class in selected
    ]

    if args.json:
        print(json.dumps(_classes_payload(fitted), indent=2, sort_keys=True))
        return 0

    print("name\tfitted_params\tconfidence")
    for problem_class in sorted(fitted, key=lambda item: item.name):
        # Fitted objects that carry no ``confidence`` attribute report 0.5.
        confidence = getattr(problem_class, "confidence", 0.5)
        print(f"{problem_class.name}\t{_format_params(problem_class.params)}\t{confidence:g}")
    return 0
def _classes_payload(classes: Iterable[ProblemClass]) -> dict[str, dict[str, Any]]:
    """Build a JSON-ready mapping of problem classes keyed by name.

    Args:
        classes: Objects exposing ``name``, ``base_dimensions``,
            ``tunable_params`` and ``params``; ``confidence`` is optional.

    Returns:
        A dict, filled in ascending ``name`` order, whose values hold list
        copies of the dimensions and tunable params, a dict copy of
        ``params``, and ``confidence`` (0.5 when the attribute is absent).
    """
    payload: dict[str, dict[str, Any]] = {}
    for entry in sorted(classes, key=lambda item: item.name):
        payload[entry.name] = {
            "base_dimensions": list(entry.base_dimensions),
            "tunable_params": list(entry.tunable_params),
            "params": dict(entry.params),
            "confidence": getattr(entry, "confidence", 0.5),
        }
    return payload
def _format_params(params: Mapping[str, float]) -> str:
    """Render a parameter mapping as ``key=value`` pairs on one line.

    Args:
        params: Parameter names mapped to numeric values.

    Returns:
        The pairs in sorted order, each value formatted with the ``g``
        presentation type, joined by ``", "``; an empty string for an
        empty mapping.
    """
    ordered_pairs = sorted(dict(params).items())
    rendered = []
    for key, value in ordered_pairs:
        rendered.append(f"{key}={value:g}")
    return ", ".join(rendered)
if __name__ == "__main__":
    # Executed as a script: hand main()'s return value to the interpreter
    # as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)