Files
llm-connect/tests/test_problem_classes.py
tegwick c11c6afa3f
Some checks failed
CI / test (3.10) (push) Has been cancelled
CI / test (3.11) (push) Has been cancelled
CI / test (3.12) (push) Has been cancelled
Implement-LLM-WP-0005-cost-model-estimators
2026-05-19 05:02:20 +02:00

138 lines
4.7 KiB
Python

from datetime import datetime, timezone
import pytest
from llm_connect.problem_classes import (
EntityExtractionProblemClass,
Observation,
ProblemClassRegistry,
TokenEstimate,
)
from llm_connect.quality import QualityObservation
# Representative estimator inputs for every built-in problem class, keyed by
# the class's registry name.
#
# - The key set must equal the names exposed by ProblemClassRegistry.default()
#   (asserted in test_default_registry_exposes_builtin_classes).
# - The per-class lists feed the parametrized estimate and fit-recovery tests.
#   The fit-recovery test calls fit(..., min_observations=3) on one synthetic
#   observation per entry, so keep at least three entries per class.
# - Dict insertion order determines the order of the parametrized test cases.
# NOTE(review): the inner keys are presumably the dimension names each
# estimator requires; confirm against llm_connect.problem_classes.
DIMENSIONS_BY_CLASS = {
    "chunk-summarization": [
        {"chunk_words": 900, "template_words": 150},
        {"chunk_words": 400, "template_words": 125},
        {"chunk_words": 1200, "template_words": 200},
    ],
    "entity-extraction": [
        {"chunk_words": 900, "template_words": 200, "expected_entities": 4},
        {"chunk_words": 450, "template_words": 180, "expected_entities": 6},
        {"chunk_words": 1200, "template_words": 220, "expected_entities": 8},
    ],
    "relation-extraction": [
        {"chunk_words": 900, "template_words": 200, "expected_relations": 3},
        {"chunk_words": 450, "template_words": 180, "expected_relations": 5},
        {"chunk_words": 1200, "template_words": 220, "expected_relations": 7},
    ],
    "judge-eval": [
        {"artifact_words": 700, "template_words": 180, "n_criteria": 4},
        {"artifact_words": 300, "template_words": 160, "n_criteria": 5},
        {"artifact_words": 1100, "template_words": 200, "n_criteria": 6},
    ],
    "report-synthesis": [
        {"n_chunks": 5, "n_entities": 20, "n_relations": 8, "template_words": 250},
        {"n_chunks": 8, "n_entities": 30, "n_relations": 12, "template_words": 250},
        {"n_chunks": 2, "n_entities": 10, "n_relations": 3, "template_words": 180},
    ],
}
def test_default_registry_exposes_builtin_classes():
    """The default registry ships exactly the classes this module covers.

    Also pins the registry's schema version so a format bump is noticed.
    """
    default_registry = ProblemClassRegistry.default()
    registered_names = set(default_registry.all())
    expected_names = set(DIMENSIONS_BY_CLASS)
    assert registered_names == expected_names
    assert default_registry.schema_version == 1
@pytest.mark.parametrize("name,dimensions_list", DIMENSIONS_BY_CLASS.items())
def test_builtin_estimators_produce_token_estimates(name, dimensions_list):
    """Every built-in estimator returns a well-formed ``TokenEstimate``.

    Each dimension set listed for the class is estimated, not just the first
    one. A failure for one particular input shape therefore shows up here,
    and the failing dimensions are included in the assertion message.
    """
    problem_class = ProblemClassRegistry.default().get(name)
    for dimensions in dimensions_list:
        estimate = problem_class.estimate(dimensions)
        assert isinstance(estimate, TokenEstimate), dimensions
        assert estimate.prompt_tokens >= 0, dimensions
        assert estimate.completion_tokens >= 0, dimensions
        # Confidence is treated as a probability-like score in [0, 1].
        assert 0 <= estimate.confidence <= 1, dimensions
@pytest.mark.parametrize("name,dimensions_list", DIMENSIONS_BY_CLASS.items())
def test_fit_recovers_seeded_params_from_synthetic_observations(name, dimensions_list):
    """Fitting on the seed model's own predictions recovers the seed params.

    The first tunable parameter is doubled to build a deliberately wrong
    starting model. That model is then fitted on synthetic observations
    produced by the seeded (default) model. The fitted parameter must come
    back to within 10% of the seed value.
    """
    seeded = ProblemClassRegistry.default().get(name)
    param_name = seeded.tunable_params[0]
    seed_value = seeded.params[param_name]
    off_seed = type(seeded)(params={param_name: seed_value * 2})

    # Guard against a vacuous pass. If doubling does not move the parameter
    # outside the recovery tolerance (e.g. a seed value of 0), the final
    # assertion would succeed even if fit() did nothing at all.
    assert off_seed.params[param_name] != pytest.approx(seed_value, rel=0.1)

    observations = []
    for dimensions in dimensions_list:
        estimate = seeded.estimate(dimensions)
        observations.append(
            Observation(
                dimensions=dimensions,
                prompt_tokens=estimate.prompt_tokens,
                completion_tokens=estimate.completion_tokens,
            )
        )

    fitted = off_seed.fit(observations, min_observations=3)

    assert fitted.params[param_name] == pytest.approx(seed_value, rel=0.1)
def test_fit_uses_quality_ledger_observation_shape():
    """``fit`` accepts ``QualityObservation`` ledger rows as training data.

    The dimensions are carried inside ``tags["dimensions"]`` and the token
    counts in ``tokens_in``/``tokens_out``. Three identical rows report 350
    output tokens for 5 expected entities. Starting from
    ``tokens_per_entity=10``, the fit is expected to land on 70 (= 350 / 5).
    The completion estimate presumably scales as
    tokens_per_entity * expected_entities.
    """
    starting_model = EntityExtractionProblemClass(params={"tokens_per_entity": 10})

    def ledger_row():
        # Build a fresh record (with its own tags dict) on every call,
        # standing in for one ledger entry.
        return QualityObservation(
            task_type="extract",
            adapter_id="openrouter",
            model_id="openai/gpt-4o-mini",
            cost_usd=0.001,
            quality_score=0.9,
            latency_ms=100,
            tokens_in=500,
            tokens_out=350,
            recorded_at=datetime(2026, 5, 19, tzinfo=timezone.utc),
            tags={
                "problem_class": "entity-extraction",
                "dimensions": {
                    "chunk_words": 300,
                    "template_words": 100,
                    "expected_entities": 5,
                },
            },
        )

    ledger = [ledger_row(), ledger_row(), ledger_row()]

    result = starting_model.fit(ledger)

    assert result.params["tokens_per_entity"] == pytest.approx(70)
def test_fit_keeps_seed_when_sample_is_too_small():
    """``fit`` returns the very same instance when given too few samples.

    A single observation checked against ``min_observations=3`` must leave
    the seeded model untouched. The check is identity, not just equal
    params.
    """
    seed_model = EntityExtractionProblemClass()
    dims = {"chunk_words": 300, "template_words": 100, "expected_entities": 5}
    predicted = seed_model.estimate(dims)
    lone_sample = Observation(
        dimensions=dims,
        prompt_tokens=predicted.prompt_tokens,
        completion_tokens=predicted.completion_tokens,
    )

    result = seed_model.fit([lone_sample], min_observations=3)

    assert result is seed_model
def test_missing_dimensions_are_rejected():
    """Estimating with an incomplete dimension dict raises ``ValueError``."""
    summarizer = ProblemClassRegistry.default().get("chunk-summarization")
    # Only "chunk_words" is supplied; "template_words" is left out.
    incomplete = {"chunk_words": 100}
    with pytest.raises(ValueError, match="Missing dimensions"):
        summarizer.estimate(incomplete)