generated from coulomb/repo-seed
Implement-LLM-WP-0005-cost-model-estimators
This commit is contained in:
143
llm_connect/cli.py
Normal file
143
llm_connect/cli.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Command-line helpers for llm-connect registries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from collections.abc import Iterable, Mapping
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from llm_connect.problem_classes import ProblemClass, ProblemClassRegistry
|
||||
from llm_connect.quality import QualityLedger
|
||||
from llm_connect.rates import ModelRateRegistry
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
"""Run the ``llm-connect`` command."""
|
||||
parser = _build_parser()
|
||||
args = parser.parse_args(argv)
|
||||
return int(args.func(args))
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(prog="llm-connect")
|
||||
commands = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
rates = commands.add_parser("rates", help="Inspect model rate registries")
|
||||
rate_commands = rates.add_subparsers(dest="rates_command", required=True)
|
||||
rate_show = rate_commands.add_parser("show", help="Show model rates")
|
||||
rate_show.add_argument("--rates", type=Path, help="YAML registry overlay")
|
||||
rate_show.add_argument("--json", action="store_true", help="Emit JSON")
|
||||
rate_show.set_defaults(func=_rates_show)
|
||||
|
||||
classes = commands.add_parser("classes", help="Inspect problem classes")
|
||||
class_commands = classes.add_subparsers(dest="classes_command", required=True)
|
||||
class_show = class_commands.add_parser("show", help="Show problem classes")
|
||||
class_show.add_argument("--json", action="store_true", help="Emit JSON")
|
||||
class_show.set_defaults(func=_classes_show)
|
||||
|
||||
class_fit = class_commands.add_parser("fit", help="Fit problem-class params from a ledger")
|
||||
class_fit.add_argument("ledger", type=Path, help="QualityLedger JSONL path")
|
||||
class_fit.add_argument("--class", dest="class_name", help="Fit one class by name")
|
||||
class_fit.add_argument("--min-observations", type=int, default=3)
|
||||
class_fit.add_argument("--json", action="store_true", help="Emit JSON")
|
||||
class_fit.set_defaults(func=_classes_fit)
|
||||
return parser
|
||||
|
||||
|
||||
def _rates_show(args: argparse.Namespace) -> int:
|
||||
registry = ModelRateRegistry.default()
|
||||
if args.rates:
|
||||
registry = registry.merged_with(ModelRateRegistry.from_yaml(args.rates))
|
||||
rates = registry.all()
|
||||
if args.json:
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
model_id: {
|
||||
"prompt_per_1k": rate.prompt_per_1k,
|
||||
"completion_per_1k": rate.completion_per_1k,
|
||||
"currency": rate.currency,
|
||||
"source_url": rate.source_url,
|
||||
"captured_at": rate.captured_at,
|
||||
}
|
||||
for model_id, rate in sorted(rates.items())
|
||||
},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
)
|
||||
)
|
||||
return 0
|
||||
|
||||
print("model_id\tprompt_per_1k\tcompletion_per_1k\tcurrency\tcaptured_at")
|
||||
for model_id, rate in sorted(rates.items()):
|
||||
print(
|
||||
f"{model_id}\t{rate.prompt_per_1k:g}\t{rate.completion_per_1k:g}\t"
|
||||
f"{rate.currency}\t{rate.captured_at}"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
def _classes_show(args: argparse.Namespace) -> int:
|
||||
classes = ProblemClassRegistry.default().all()
|
||||
if args.json:
|
||||
print(json.dumps(_classes_payload(classes.values()), indent=2, sort_keys=True))
|
||||
return 0
|
||||
|
||||
print("name\tdimensions\ttunable_params\tcurrent_params")
|
||||
for problem_class in sorted(classes.values(), key=lambda item: item.name):
|
||||
print(
|
||||
f"{problem_class.name}\t{', '.join(problem_class.base_dimensions)}\t"
|
||||
f"{', '.join(problem_class.tunable_params)}\t{_format_params(problem_class.params)}"
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
def _classes_fit(args: argparse.Namespace) -> int:
|
||||
if args.min_observations <= 0:
|
||||
raise SystemExit("--min-observations must be positive")
|
||||
registry = ProblemClassRegistry.default()
|
||||
classes = registry.all()
|
||||
if args.class_name:
|
||||
problem_class = registry.get(args.class_name)
|
||||
if problem_class is None:
|
||||
raise SystemExit(f"Unknown problem class: {args.class_name}")
|
||||
selected: list[ProblemClass] = [problem_class]
|
||||
else:
|
||||
selected = list(classes.values())
|
||||
|
||||
observations = QualityLedger(args.ledger).read_all()
|
||||
fitted: list[ProblemClass] = [
|
||||
problem_class.fit(observations, min_observations=args.min_observations)
|
||||
for problem_class in selected
|
||||
]
|
||||
if args.json:
|
||||
print(json.dumps(_classes_payload(fitted), indent=2, sort_keys=True))
|
||||
return 0
|
||||
|
||||
print("name\tfitted_params\tconfidence")
|
||||
for problem_class in sorted(fitted, key=lambda item: item.name):
|
||||
confidence = getattr(problem_class, "confidence", 0.5)
|
||||
print(f"{problem_class.name}\t{_format_params(problem_class.params)}\t{confidence:g}")
|
||||
return 0
|
||||
|
||||
|
||||
def _classes_payload(classes: Iterable[ProblemClass]) -> dict[str, dict[str, Any]]:
|
||||
return {
|
||||
problem_class.name: {
|
||||
"base_dimensions": list(problem_class.base_dimensions),
|
||||
"tunable_params": list(problem_class.tunable_params),
|
||||
"params": dict(problem_class.params),
|
||||
"confidence": getattr(problem_class, "confidence", 0.5),
|
||||
}
|
||||
for problem_class in sorted(classes, key=lambda item: item.name)
|
||||
}
|
||||
|
||||
|
||||
def _format_params(params: Mapping[str, float]) -> str:
|
||||
return ", ".join(f"{key}={value:g}" for key, value in sorted(dict(params).items()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user