generated from coulomb/repo-seed
464 lines
16 KiB
Python
464 lines
16 KiB
Python
"""Problem-class token estimators for common LLM workflow shapes."""
|
|
|
|
from __future__ import annotations

import math
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import Any, Protocol
|
|
|
|
|
|
# Heuristic words-per-token ratio. _words_to_tokens divides a word count by this
# value, so 0.75 words per token means roughly 1.33 tokens per word.
DEFAULT_WORDS_PER_TOKEN = 0.75
|
|
|
|
|
|
@dataclass(frozen=True)
class TokenEstimate:
    """Predicted prompt and completion token counts for a planned LLM call.

    Fields are normalised after construction: both token counts become
    non-negative ``int`` values and ``confidence`` becomes a float in ``[0, 1]``.
    Invalid values raise ``ValueError``.
    """

    prompt_tokens: int
    completion_tokens: int
    confidence: float = 0.5

    def __post_init__(self) -> None:
        # Validate every field (in declaration order) before storing any of them.
        # The dataclass is frozen, so object.__setattr__ is required to write back.
        normalised = {
            "prompt_tokens": _non_negative_int("prompt_tokens", self.prompt_tokens),
            "completion_tokens": _non_negative_int(
                "completion_tokens", self.completion_tokens
            ),
            "confidence": _bounded_float("confidence", self.confidence),
        }
        for field_name, field_value in normalised.items():
            object.__setattr__(self, field_name, field_value)
|
|
|
|
|
|
@dataclass(frozen=True)
class Observation:
    """Measured token usage paired with the dimensions of the call that produced it.

    ``dimensions`` is shallow-copied into a fresh ``dict`` so later changes to
    the caller's mapping do not leak in. Token counts are normalised to
    non-negative ``int`` values, and invalid counts raise ``ValueError``.
    """

    dimensions: dict[str, Any]
    prompt_tokens: int
    completion_tokens: int

    def __post_init__(self) -> None:
        # Build every normalised value first (same order as the fields), then
        # store them. Frozen dataclasses need object.__setattr__ for this.
        updates = (
            ("dimensions", dict(self.dimensions)),
            ("prompt_tokens", _non_negative_int("prompt_tokens", self.prompt_tokens)),
            (
                "completion_tokens",
                _non_negative_int("completion_tokens", self.completion_tokens),
            ),
        )
        for field_name, field_value in updates:
            object.__setattr__(self, field_name, field_value)
|
|
|
|
|
|
class ProblemClass(Protocol):
    """Structural interface every problem-class estimator is expected to satisfy.

    Built-in classes and consumer-defined ones share this shape: a few
    descriptive attributes plus ``estimate`` and ``fit``.
    """

    # Stable identifier; the registry uses it (stripped) as the lookup key.
    name: str
    # Names of the dimensions the estimator reads from its ``dimensions`` mapping.
    base_dimensions: tuple[str, ...]
    # Names of parameters that may be overridden or adapted by ``fit``.
    tunable_params: tuple[str, ...]
    # Current parameter values keyed by parameter name.
    params: dict[str, float]

    def estimate(
        self,
        dimensions: dict[str, Any],
        params: dict[str, Any] | None = None,
    ) -> TokenEstimate:
        """Predict token use for *dimensions*, optionally overriding some params."""
        ...

    def fit(
        self,
        observations: Sequence[Any],
        *,
        min_observations: int = 3,
    ) -> "ProblemClass":
        """Return an estimator whose params reflect the observed token use."""
        ...
|
|
|
|
|
|
class ProblemClassRegistry:
    """Name-indexed collection of problem-class estimators.

    Names are stripped of surrounding whitespace both when registering and when
    looking up, so ``" judge-eval "`` and ``"judge-eval"`` address one entry.
    """

    schema_version = 1

    def __init__(self, classes: Sequence[ProblemClass] | None = None) -> None:
        self._classes: dict[str, ProblemClass] = {}
        if classes:
            for entry in classes:
                self.register(entry)

    def get(self, name: str) -> ProblemClass | None:
        """Look up a registered class by *name*; ``None`` when it is unknown."""
        key = str(name).strip()
        return self._classes.get(key)

    def all(self) -> dict[str, ProblemClass]:
        """Return a shallow copy of the name -> class mapping."""
        return {**self._classes}

    def register(self, problem_class: ProblemClass, *, replace: bool = False) -> None:
        """Add *problem_class* under its stripped ``name``.

        Raises ValueError when the name is empty, or when it is already taken
        and *replace* is false.
        """
        key = str(problem_class.name).strip()
        if not key:
            raise ValueError("problem_class.name must be a non-empty string")
        if not replace and key in self._classes:
            raise ValueError(f"Problem class {key!r} is already registered")
        self._classes[key] = problem_class

    @classmethod
    def default(cls) -> "ProblemClassRegistry":
        """Build a registry pre-populated with one instance of each built-in class."""
        builtin_factories = (
            ChunkSummarizationProblemClass,
            EntityExtractionProblemClass,
            RelationExtractionProblemClass,
            JudgeEvalProblemClass,
            ReportSynthesisProblemClass,
        )
        return cls([factory() for factory in builtin_factories])
|
|
|
|
|
|
class _BaseProblemClass:
    """Shared machinery for the built-in problem classes.

    Subclasses set ``name``, ``base_dimensions``, ``tunable_params`` and
    ``seed_params``, and implement ``_estimate_tokens`` and ``_infer_param``.
    Only names listed in ``tunable_params`` may be overridden, and every
    parameter value must be a non-negative number.
    """

    name = ""
    base_dimensions: tuple[str, ...] = ()
    tunable_params: tuple[str, ...] = ()
    seed_params: Mapping[str, float] = {}

    def __init__(
        self,
        *,
        params: Mapping[str, Any] | None = None,
        confidence: float = 0.5,
    ) -> None:
        self.params: dict[str, float] = self._overlay_params(self.seed_params, params)
        self.confidence = _bounded_float("confidence", confidence)

    def _overlay_params(
        self,
        base: Mapping[str, float],
        overrides: Mapping[str, Any] | None,
    ) -> dict[str, float]:
        """Return a copy of *base* updated with validated *overrides*.

        Raises ValueError for an override whose name is not in
        ``tunable_params`` or whose value is not a non-negative number.
        """
        combined = dict(base)
        for key, raw_value in (overrides or {}).items():
            if key not in self.tunable_params:
                raise ValueError(f"Unknown parameter {key!r} for problem class {self.name!r}")
            combined[key] = _non_negative_float(key, raw_value)
        return combined

    def estimate(
        self,
        dimensions: dict[str, Any],
        params: dict[str, Any] | None = None,
    ) -> TokenEstimate:
        """Estimate prompt/completion tokens for one call described by *dimensions*.

        *params* overrides this instance's parameters for this call only. Raises
        ValueError for missing/invalid dimensions and for unknown or invalid
        parameter overrides.
        """
        checked_dimensions = dict(dimensions)
        self._validate_dimensions(checked_dimensions)
        effective_params = self._overlay_params(self.params, params)
        prompt_tokens, completion_tokens = self._estimate_tokens(
            checked_dimensions, effective_params
        )
        return TokenEstimate(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            confidence=self.confidence,
        )

    def fit(
        self,
        observations: Sequence[Any],
        *,
        min_observations: int = 3,
    ) -> ProblemClass:
        """Adapt tunable params to observed token use.

        Each tunable param is set to the mean of the values ``_infer_param``
        derives from the usable observations. Returns ``self`` unchanged when
        fewer than *min_observations* observations are usable or nothing could
        be inferred. Otherwise returns a new instance of the same class whose
        confidence is ``max(current, n / (n + 5))`` capped at 0.95. Raises
        ValueError when *min_observations* is not positive.
        """
        if min_observations <= 0:
            raise ValueError("min_observations must be positive")

        usable: list[Observation] = []
        for raw in observations:
            coerced = _coerce_observation(raw, self.name, self.base_dimensions)
            if coerced is not None:
                usable.append(coerced)
        if len(usable) < min_observations:
            return self

        fitted: dict[str, float] = {}
        for param in self.tunable_params:
            inferred: list[float] = []
            for observation in usable:
                candidate = self._infer_param(param, observation)
                if candidate is not None:
                    inferred.append(candidate)
            if inferred:
                fitted[param] = sum(inferred) / len(inferred)
        if not fitted:
            return self

        sample_size = len(usable)
        confidence = min(0.95, max(self.confidence, sample_size / (sample_size + 5)))
        return type(self)(params={**self.params, **fitted}, confidence=confidence)

    def _validate_dimensions(self, dimensions: Mapping[str, Any]) -> None:
        """Raise ValueError unless every base dimension is present and non-negative."""
        absent = [key for key in self.base_dimensions if key not in dimensions]
        if absent:
            raise ValueError(f"Missing dimensions for {self.name!r}: {', '.join(absent)}")
        for key in self.base_dimensions:
            _non_negative_float(key, dimensions[key])

    def _estimate_tokens(
        self,
        dimensions: Mapping[str, Any],
        params: Mapping[str, float],
    ) -> tuple[int, int]:
        """Return ``(prompt_tokens, completion_tokens)``; subclasses implement this."""
        raise NotImplementedError

    def _infer_param(self, param: str, observation: Observation) -> float | None:
        """Derive *param* from one observation (or ``None``); subclasses implement this."""
        raise NotImplementedError
|
|
|
|
|
|
class ChunkSummarizationProblemClass(_BaseProblemClass):
    """Summarise a single chunk; completion size scales with prompt size.

    Prompt tokens come from ``chunk_words + template_words``. Completion tokens
    are those prompt tokens times ``completion_ratio``.
    """

    name = "chunk-summarization"
    base_dimensions: tuple[str, ...] = ("chunk_words", "template_words")
    tunable_params: tuple[str, ...] = ("completion_ratio",)
    seed_params: Mapping[str, float] = {"completion_ratio": 0.25}

    def _estimate_tokens(
        self,
        dimensions: Mapping[str, Any],
        params: Mapping[str, float],
    ) -> tuple[int, int]:
        total_words = _dimension(dimensions, "chunk_words") + _dimension(
            dimensions, "template_words"
        )
        prompt_tokens = _words_to_tokens(total_words)
        return prompt_tokens, _round_tokens(prompt_tokens * params["completion_ratio"])

    def _infer_param(self, param: str, observation: Observation) -> float | None:
        # The ratio is undefined for an empty prompt, so such observations are skipped.
        if param == "completion_ratio" and observation.prompt_tokens != 0:
            return observation.completion_tokens / observation.prompt_tokens
        return None
|
|
|
|
|
|
class EntityExtractionProblemClass(_BaseProblemClass):
    """Extract entities from one chunk.

    Prompt tokens come from ``chunk_words + template_words``. Completion tokens
    are ``expected_entities * tokens_per_entity``.
    """

    name = "entity-extraction"
    base_dimensions: tuple[str, ...] = ("chunk_words", "template_words", "expected_entities")
    tunable_params: tuple[str, ...] = ("tokens_per_entity",)
    seed_params: Mapping[str, float] = {"tokens_per_entity": 70.0}

    def _estimate_tokens(
        self,
        dimensions: Mapping[str, Any],
        params: Mapping[str, float],
    ) -> tuple[int, int]:
        prompt_tokens = _words_to_tokens(
            _dimension(dimensions, "chunk_words") + _dimension(dimensions, "template_words")
        )
        completion_tokens = _round_tokens(
            _dimension(dimensions, "expected_entities") * params["tokens_per_entity"]
        )
        return prompt_tokens, completion_tokens

    def _infer_param(self, param: str, observation: Observation) -> float | None:
        """Return completion tokens per expected entity, or ``None`` if not applicable.

        The parameter name is checked before the observation is inspected, so an
        unrelated parameter yields ``None`` instead of failing on this class's
        dimensions. Observations with zero expected entities are skipped.
        """
        if param != "tokens_per_entity":
            return None
        expected_entities = _dimension(observation.dimensions, "expected_entities")
        if expected_entities <= 0:
            return None
        return observation.completion_tokens / expected_entities
|
|
|
|
|
|
class RelationExtractionProblemClass(_BaseProblemClass):
    """Extract relations from one chunk.

    Prompt tokens come from ``chunk_words + template_words``. Completion tokens
    are ``expected_relations * tokens_per_relation``.
    """

    name = "relation-extraction"
    base_dimensions: tuple[str, ...] = ("chunk_words", "template_words", "expected_relations")
    tunable_params: tuple[str, ...] = ("tokens_per_relation",)
    seed_params: Mapping[str, float] = {"tokens_per_relation": 80.0}

    def _estimate_tokens(
        self,
        dimensions: Mapping[str, Any],
        params: Mapping[str, float],
    ) -> tuple[int, int]:
        prompt_tokens = _words_to_tokens(
            _dimension(dimensions, "chunk_words") + _dimension(dimensions, "template_words")
        )
        completion_tokens = _round_tokens(
            _dimension(dimensions, "expected_relations") * params["tokens_per_relation"]
        )
        return prompt_tokens, completion_tokens

    def _infer_param(self, param: str, observation: Observation) -> float | None:
        """Return completion tokens per expected relation, or ``None`` if not applicable.

        The parameter name is checked before the observation is inspected, so an
        unrelated parameter yields ``None`` instead of failing on this class's
        dimensions. Observations with zero expected relations are skipped.
        """
        if param != "tokens_per_relation":
            return None
        expected_relations = _dimension(observation.dimensions, "expected_relations")
        if expected_relations <= 0:
            return None
        return observation.completion_tokens / expected_relations
|
|
|
|
|
|
class JudgeEvalProblemClass(_BaseProblemClass):
    """Judge/evaluate an artifact against a set of criteria.

    Prompt tokens come from ``artifact_words + template_words``. Completion
    tokens are ``n_criteria * tokens_per_criterion``.
    """

    name = "judge-eval"
    base_dimensions: tuple[str, ...] = ("artifact_words", "template_words", "n_criteria")
    tunable_params: tuple[str, ...] = ("tokens_per_criterion",)
    seed_params: Mapping[str, float] = {"tokens_per_criterion": 35.0}

    def _estimate_tokens(
        self,
        dimensions: Mapping[str, Any],
        params: Mapping[str, float],
    ) -> tuple[int, int]:
        prompt_tokens = _words_to_tokens(
            _dimension(dimensions, "artifact_words") + _dimension(dimensions, "template_words")
        )
        completion_tokens = _round_tokens(
            _dimension(dimensions, "n_criteria") * params["tokens_per_criterion"]
        )
        return prompt_tokens, completion_tokens

    def _infer_param(self, param: str, observation: Observation) -> float | None:
        """Return completion tokens per criterion, or ``None`` if not applicable.

        The parameter name is checked before the observation is inspected, so an
        unrelated parameter yields ``None`` instead of failing on this class's
        dimensions. Observations with zero criteria are skipped.
        """
        if param != "tokens_per_criterion":
            return None
        n_criteria = _dimension(observation.dimensions, "n_criteria")
        if n_criteria <= 0:
            return None
        return observation.completion_tokens / n_criteria
|
|
|
|
|
|
class ReportSynthesisProblemClass(_BaseProblemClass):
    """Synthesise a report from previously extracted chunks, entities and relations.

    Prompt tokens are the template words converted to tokens, plus a fixed
    prompt cost per chunk, entity and relation (each product rounded
    separately). Completion tokens are the single tunable
    ``base_completion_tokens``.
    """

    name = "report-synthesis"
    base_dimensions: tuple[str, ...] = ("n_chunks", "n_entities", "n_relations", "template_words")
    tunable_params: tuple[str, ...] = ("base_completion_tokens",)
    seed_params: Mapping[str, float] = {"base_completion_tokens": 400.0}

    # Prompt-token cost contributed by each chunk, entity and relation fed into
    # the report prompt. Subclasses may override these to match another prompt
    # layout without re-implementing _estimate_tokens.
    prompt_tokens_per_chunk = 40
    prompt_tokens_per_entity = 25
    prompt_tokens_per_relation = 35

    def _estimate_tokens(
        self,
        dimensions: Mapping[str, Any],
        params: Mapping[str, float],
    ) -> tuple[int, int]:
        prompt_tokens = _words_to_tokens(_dimension(dimensions, "template_words"))
        prompt_tokens += _round_tokens(
            _dimension(dimensions, "n_chunks") * self.prompt_tokens_per_chunk
        )
        prompt_tokens += _round_tokens(
            _dimension(dimensions, "n_entities") * self.prompt_tokens_per_entity
        )
        prompt_tokens += _round_tokens(
            _dimension(dimensions, "n_relations") * self.prompt_tokens_per_relation
        )
        return prompt_tokens, _round_tokens(params["base_completion_tokens"])

    def _infer_param(self, param: str, observation: Observation) -> float | None:
        """Use the observed completion size directly as ``base_completion_tokens``."""
        if param != "base_completion_tokens":
            return None
        return float(observation.completion_tokens)
|
|
|
|
|
|
def default_problem_class_registry() -> ProblemClassRegistry:
    """Build a new registry holding one instance of every built-in problem class.

    Convenience wrapper around ``ProblemClassRegistry.default()``.
    """
    registry = ProblemClassRegistry.default()
    return registry
|
|
|
|
|
|
def _coerce_observation(
    raw: Any,
    class_name: str,
    required_dimensions: tuple[str, ...],
) -> Observation | None:
    """Normalise one raw observation record, or return ``None`` if it is unusable.

    *raw* may be an :class:`Observation`, a mapping, or an object with token
    attributes. A record is unusable, and ``None`` is returned instead of
    raising, when:

    - it belongs to another problem class;
    - its token counts are missing or invalid;
    - any of *required_dimensions* is absent or not a non-negative number.

    Checking the dimensions here matters because ``dimensions`` payloads and
    ready-made ``Observation`` objects are not otherwise validated. Without the
    check, an incomplete record would crash the estimator's ``_infer_param``
    during ``fit`` instead of being skipped.
    """
    try:
        if isinstance(raw, Observation):
            observation: Observation | None = raw
        elif isinstance(raw, Mapping):
            observation = _coerce_mapping_observation(raw, class_name, required_dimensions)
        else:
            observation = _coerce_object_observation(raw, class_name, required_dimensions)
        if observation is not None:
            for name in required_dimensions:
                _non_negative_float(name, observation.dimensions[name])
    except (AttributeError, KeyError, OverflowError, TypeError, ValueError):
        # Best-effort parsing: a malformed record is skipped, not fatal.
        return None
    return observation
|
|
|
|
|
|
def _coerce_mapping_observation(
    raw: Mapping[str, Any],
    class_name: str,
    required_dimensions: tuple[str, ...],
) -> Observation | None:
    """Build an :class:`Observation` from a mapping-shaped usage record.

    The problem class is taken from ``raw["problem_class"]``, falling back to
    ``raw["tags"]["problem_class"]``. A record declaring a different class
    yields ``None``. Dimensions are searched in *raw*, then in its tags.
    Token counts may sit at the top level or under ``raw["usage"]``, under any
    of the accepted aliases. Missing pieces raise KeyError/ValueError.
    """
    tags = raw.get("tags")
    if not isinstance(tags, Mapping):
        tags = {}

    declared_class = raw.get("problem_class") or tags.get("problem_class")
    if declared_class is not None and str(declared_class) != class_name:
        return None

    dimensions = _dimensions_from_sources(required_dimensions, raw, tags)
    prompt_tokens = _token_value(raw, "prompt_tokens", "tokens_in", "actual_prompt_tokens")
    completion_tokens = _token_value(
        raw, "completion_tokens", "tokens_out", "actual_completion_tokens"
    )
    return Observation(
        dimensions=dimensions,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
|
|
|
|
|
|
def _coerce_object_observation(
    raw: Any,
    class_name: str,
    required_dimensions: tuple[str, ...],
) -> Observation | None:
    """Build an :class:`Observation` from an object exposing token-count attributes.

    The problem-class filter and the dimensions come from the object's ``tags``
    mapping; a missing or non-mapping ``tags`` is treated as empty. Token counts
    come from the first attribute present among ``tokens_in`` /
    ``prompt_tokens`` / ``actual_prompt_tokens`` (and the matching completion
    names). These are the same aliases the mapping variant accepts, with
    ``tokens_in``/``tokens_out`` checked first.

    Returns ``None`` when the tags name a different problem class. Raises
    KeyError when no token attribute exists, and ValueError for missing
    dimensions or invalid token counts.
    """
    raw_tags = getattr(raw, "tags", {}) or {}
    tags: Mapping[str, Any] = raw_tags if isinstance(raw_tags, Mapping) else {}
    problem_class = tags.get("problem_class")
    if problem_class is not None and str(problem_class) != class_name:
        return None
    dimensions = _dimensions_from_sources(required_dimensions, tags)
    return Observation(
        dimensions=dimensions,
        prompt_tokens=_token_attribute(raw, "tokens_in", "prompt_tokens", "actual_prompt_tokens"),
        completion_tokens=_token_attribute(
            raw, "tokens_out", "completion_tokens", "actual_completion_tokens"
        ),
    )


def _token_attribute(raw: Any, *names: str) -> Any:
    """Return the first attribute of *raw* found among *names*.

    Raises KeyError (rather than AttributeError) when none exist, so callers
    can treat it like a missing key in a mapping record.
    """
    missing = object()
    for name in names:
        value = getattr(raw, name, missing)
        if value is not missing:
            return value
    raise KeyError(names[0])
|
|
|
|
|
|
def _dimensions_from_sources(
|
|
required_dimensions: tuple[str, ...],
|
|
*sources: Mapping[str, Any],
|
|
) -> dict[str, Any]:
|
|
for source in sources:
|
|
candidate = source.get("dimensions")
|
|
if isinstance(candidate, Mapping):
|
|
return dict(candidate)
|
|
dimensions: dict[str, Any] = {}
|
|
for name in required_dimensions:
|
|
for source in sources:
|
|
if name in source:
|
|
dimensions[name] = source[name]
|
|
break
|
|
if len(dimensions) != len(required_dimensions):
|
|
raise ValueError("observation is missing required dimensions")
|
|
return dimensions
|
|
|
|
|
|
def _token_value(raw: Mapping[str, Any], *names: str) -> int:
    """Return the first token count found under *names*, as a non-negative int.

    Top-level keys of *raw* are tried first, then the same names inside a
    ``raw["usage"]`` mapping. Raises KeyError (named after ``names[0]``) when
    nothing matches, and ValueError for an invalid count.
    """
    hit = next((name for name in names if name in raw), None)
    if hit is not None:
        return _non_negative_int(hit, raw[hit])

    usage = raw.get("usage")
    if not isinstance(usage, Mapping):
        raise KeyError(names[0])
    hit = next((name for name in names if name in usage), None)
    if hit is None:
        raise KeyError(names[0])
    return _non_negative_int(hit, usage[hit])
|
|
|
|
|
|
def _dimension(dimensions: Mapping[str, Any], name: str) -> float:
    """Fetch *name* from *dimensions* as a validated non-negative float.

    Raises KeyError when the dimension is absent and ValueError when its value
    is not a non-negative number.
    """
    raw_value = dimensions[name]
    return _non_negative_float(name, raw_value)
|
|
|
|
|
|
def _words_to_tokens(words: float) -> int:
    """Convert a word count into an estimated token count.

    Zero words map to zero tokens. Any non-zero count yields at least one
    token, even when ``words / DEFAULT_WORDS_PER_TOKEN`` would round to zero.
    """
    if words != 0:
        return max(1, _round_tokens(words / DEFAULT_WORDS_PER_TOKEN))
    return 0
|
|
|
|
|
|
def _round_tokens(value: float) -> int:
|
|
return max(0, int(round(value)))
|
|
|
|
|
|
def _non_negative_int(name: str, value: Any) -> int:
|
|
if isinstance(value, bool):
|
|
raise ValueError(f"{name} must be a non-negative integer")
|
|
try:
|
|
integer = int(value)
|
|
except (TypeError, ValueError) as exc:
|
|
raise ValueError(f"{name} must be a non-negative integer") from exc
|
|
if integer < 0 or integer != float(value):
|
|
raise ValueError(f"{name} must be a non-negative integer")
|
|
return integer
|
|
|
|
|
|
def _non_negative_float(name: str, value: Any) -> float:
|
|
if isinstance(value, bool):
|
|
raise ValueError(f"{name} must be a non-negative number")
|
|
try:
|
|
number = float(value)
|
|
except (TypeError, ValueError) as exc:
|
|
raise ValueError(f"{name} must be a non-negative number") from exc
|
|
if number < 0:
|
|
raise ValueError(f"{name} must be a non-negative number")
|
|
return number
|
|
|
|
|
|
def _bounded_float(name: str, value: Any) -> float:
    """Coerce *value* to a float in the closed interval ``[0, 1]``.

    Non-numeric, boolean or negative input is rejected by
    ``_non_negative_float``. Values above 1, and NaN, raise ValueError here.
    """
    number = _non_negative_float(name, value)
    # A negated range check, so NaN (which compares False to everything) is
    # rejected instead of slipping past a plain ``number > 1`` test.
    if not number <= 1:
        raise ValueError(f"{name} must be between 0 and 1")
    return number
|