"""Problem-class token estimators for common LLM workflow shapes.""" from __future__ import annotations from collections.abc import Mapping, Sequence from dataclasses import dataclass from typing import Any, Protocol DEFAULT_WORDS_PER_TOKEN = 0.75 @dataclass(frozen=True) class TokenEstimate: """Prompt/completion token estimate for a prospective LLM call.""" prompt_tokens: int completion_tokens: int confidence: float = 0.5 def __post_init__(self) -> None: prompt_tokens = _non_negative_int("prompt_tokens", self.prompt_tokens) completion_tokens = _non_negative_int("completion_tokens", self.completion_tokens) confidence = _bounded_float("confidence", self.confidence) object.__setattr__(self, "prompt_tokens", prompt_tokens) object.__setattr__(self, "completion_tokens", completion_tokens) object.__setattr__(self, "confidence", confidence) @dataclass(frozen=True) class Observation: """Actual token use paired with the problem dimensions that produced it.""" dimensions: dict[str, Any] prompt_tokens: int completion_tokens: int def __post_init__(self) -> None: object.__setattr__(self, "dimensions", dict(self.dimensions)) object.__setattr__(self, "prompt_tokens", _non_negative_int("prompt_tokens", self.prompt_tokens)) object.__setattr__( self, "completion_tokens", _non_negative_int("completion_tokens", self.completion_tokens), ) class ProblemClass(Protocol): """Estimator contract implemented by built-in and consumer classes.""" name: str base_dimensions: tuple[str, ...] tunable_params: tuple[str, ...] params: dict[str, float] def estimate( self, dimensions: dict[str, Any], params: dict[str, Any] | None = None, ) -> TokenEstimate: """Estimate token use from dimensions and optional parameter overrides.""" ... def fit( self, observations: Sequence[Any], *, min_observations: int = 3, ) -> "ProblemClass": """Return an estimator with params adapted from observed token use.""" ... class ProblemClassRegistry: """Registry keyed by stable problem-class names.""" schema_version = 1 def __init__(self, classes: Sequence[ProblemClass] | None = None) -> None: self._classes: dict[str, ProblemClass] = {} for problem_class in classes or (): self.register(problem_class) def get(self, name: str) -> ProblemClass | None: """Return a registered class by name.""" return self._classes.get(str(name).strip()) def all(self) -> dict[str, ProblemClass]: """Return a copy of registered problem classes.""" return dict(self._classes) def register(self, problem_class: ProblemClass, *, replace: bool = False) -> None: """Register *problem_class* under its name.""" name = str(problem_class.name).strip() if not name: raise ValueError("problem_class.name must be a non-empty string") if name in self._classes and not replace: raise ValueError(f"Problem class {name!r} is already registered") self._classes[name] = problem_class @classmethod def default(cls) -> "ProblemClassRegistry": """Return the built-in problem-class registry.""" return cls( [ ChunkSummarizationProblemClass(), EntityExtractionProblemClass(), RelationExtractionProblemClass(), JudgeEvalProblemClass(), ReportSynthesisProblemClass(), ] ) class _BaseProblemClass: name = "" base_dimensions: tuple[str, ...] = () tunable_params: tuple[str, ...] = () seed_params: Mapping[str, float] = {} def __init__( self, *, params: Mapping[str, Any] | None = None, confidence: float = 0.5, ) -> None: merged = dict(self.seed_params) for key, value in (params or {}).items(): if key not in self.tunable_params: raise ValueError(f"Unknown parameter {key!r} for problem class {self.name!r}") merged[key] = _non_negative_float(key, value) self.params: dict[str, float] = merged self.confidence = _bounded_float("confidence", confidence) def estimate( self, dimensions: dict[str, Any], params: dict[str, Any] | None = None, ) -> TokenEstimate: dimensions = dict(dimensions) self._validate_dimensions(dimensions) merged_params = dict(self.params) for key, value in (params or {}).items(): if key not in self.tunable_params: raise ValueError(f"Unknown parameter {key!r} for problem class {self.name!r}") merged_params[key] = _non_negative_float(key, value) prompt_tokens, completion_tokens = self._estimate_tokens(dimensions, merged_params) return TokenEstimate( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, confidence=self.confidence, ) def fit( self, observations: Sequence[Any], *, min_observations: int = 3, ) -> ProblemClass: if min_observations <= 0: raise ValueError("min_observations must be positive") parsed = [ observation for observation in ( _coerce_observation(raw, self.name, self.base_dimensions) for raw in observations ) if observation is not None ] if len(parsed) < min_observations: return self fitted: dict[str, float] = {} for param in self.tunable_params: values = [ value for value in ( self._infer_param(param, observation) for observation in parsed ) if value is not None ] if values: fitted[param] = sum(values) / len(values) if not fitted: return self confidence = min(0.95, max(self.confidence, len(parsed) / (len(parsed) + 5))) return type(self)(params={**self.params, **fitted}, confidence=confidence) def _validate_dimensions(self, dimensions: Mapping[str, Any]) -> None: missing = [name for name in self.base_dimensions if name not in dimensions] if missing: raise ValueError(f"Missing dimensions for {self.name!r}: {', '.join(missing)}") for name in self.base_dimensions: _non_negative_float(name, dimensions[name]) def _estimate_tokens( self, dimensions: Mapping[str, Any], params: Mapping[str, float], ) -> tuple[int, int]: raise NotImplementedError def _infer_param(self, param: str, observation: Observation) -> float | None: raise NotImplementedError class ChunkSummarizationProblemClass(_BaseProblemClass): name = "chunk-summarization" base_dimensions: tuple[str, ...] = ("chunk_words", "template_words") tunable_params: tuple[str, ...] = ("completion_ratio",) seed_params: Mapping[str, float] = {"completion_ratio": 0.25} def _estimate_tokens( self, dimensions: Mapping[str, Any], params: Mapping[str, float], ) -> tuple[int, int]: prompt_tokens = _words_to_tokens( _dimension(dimensions, "chunk_words") + _dimension(dimensions, "template_words") ) completion_tokens = _round_tokens(prompt_tokens * params["completion_ratio"]) return prompt_tokens, completion_tokens def _infer_param(self, param: str, observation: Observation) -> float | None: if param != "completion_ratio" or observation.prompt_tokens == 0: return None return observation.completion_tokens / observation.prompt_tokens class EntityExtractionProblemClass(_BaseProblemClass): name = "entity-extraction" base_dimensions: tuple[str, ...] = ("chunk_words", "template_words", "expected_entities") tunable_params: tuple[str, ...] = ("tokens_per_entity",) seed_params: Mapping[str, float] = {"tokens_per_entity": 70.0} def _estimate_tokens( self, dimensions: Mapping[str, Any], params: Mapping[str, float], ) -> tuple[int, int]: prompt_tokens = _words_to_tokens( _dimension(dimensions, "chunk_words") + _dimension(dimensions, "template_words") ) completion_tokens = _round_tokens( _dimension(dimensions, "expected_entities") * params["tokens_per_entity"] ) return prompt_tokens, completion_tokens def _infer_param(self, param: str, observation: Observation) -> float | None: expected_entities = _dimension(observation.dimensions, "expected_entities") if param != "tokens_per_entity" or expected_entities <= 0: return None return observation.completion_tokens / expected_entities class RelationExtractionProblemClass(_BaseProblemClass): name = "relation-extraction" base_dimensions: tuple[str, ...] = ("chunk_words", "template_words", "expected_relations") tunable_params: tuple[str, ...] = ("tokens_per_relation",) seed_params: Mapping[str, float] = {"tokens_per_relation": 80.0} def _estimate_tokens( self, dimensions: Mapping[str, Any], params: Mapping[str, float], ) -> tuple[int, int]: prompt_tokens = _words_to_tokens( _dimension(dimensions, "chunk_words") + _dimension(dimensions, "template_words") ) completion_tokens = _round_tokens( _dimension(dimensions, "expected_relations") * params["tokens_per_relation"] ) return prompt_tokens, completion_tokens def _infer_param(self, param: str, observation: Observation) -> float | None: expected_relations = _dimension(observation.dimensions, "expected_relations") if param != "tokens_per_relation" or expected_relations <= 0: return None return observation.completion_tokens / expected_relations class JudgeEvalProblemClass(_BaseProblemClass): name = "judge-eval" base_dimensions: tuple[str, ...] = ("artifact_words", "template_words", "n_criteria") tunable_params: tuple[str, ...] = ("tokens_per_criterion",) seed_params: Mapping[str, float] = {"tokens_per_criterion": 35.0} def _estimate_tokens( self, dimensions: Mapping[str, Any], params: Mapping[str, float], ) -> tuple[int, int]: prompt_tokens = _words_to_tokens( _dimension(dimensions, "artifact_words") + _dimension(dimensions, "template_words") ) completion_tokens = _round_tokens( _dimension(dimensions, "n_criteria") * params["tokens_per_criterion"] ) return prompt_tokens, completion_tokens def _infer_param(self, param: str, observation: Observation) -> float | None: n_criteria = _dimension(observation.dimensions, "n_criteria") if param != "tokens_per_criterion" or n_criteria <= 0: return None return observation.completion_tokens / n_criteria class ReportSynthesisProblemClass(_BaseProblemClass): name = "report-synthesis" base_dimensions: tuple[str, ...] = ("n_chunks", "n_entities", "n_relations", "template_words") tunable_params: tuple[str, ...] = ("base_completion_tokens",) seed_params: Mapping[str, float] = {"base_completion_tokens": 400.0} def _estimate_tokens( self, dimensions: Mapping[str, Any], params: Mapping[str, float], ) -> tuple[int, int]: prompt_tokens = _words_to_tokens(_dimension(dimensions, "template_words")) prompt_tokens += _round_tokens(_dimension(dimensions, "n_chunks") * 40) prompt_tokens += _round_tokens(_dimension(dimensions, "n_entities") * 25) prompt_tokens += _round_tokens(_dimension(dimensions, "n_relations") * 35) return prompt_tokens, _round_tokens(params["base_completion_tokens"]) def _infer_param(self, param: str, observation: Observation) -> float | None: if param != "base_completion_tokens": return None return float(observation.completion_tokens) def default_problem_class_registry() -> ProblemClassRegistry: """Return the built-in problem-class registry.""" return ProblemClassRegistry.default() def _coerce_observation( raw: Any, class_name: str, required_dimensions: tuple[str, ...], ) -> Observation | None: try: if isinstance(raw, Observation): return raw if isinstance(raw, Mapping): return _coerce_mapping_observation(raw, class_name, required_dimensions) return _coerce_object_observation(raw, class_name, required_dimensions) except (KeyError, TypeError, ValueError): return None def _coerce_mapping_observation( raw: Mapping[str, Any], class_name: str, required_dimensions: tuple[str, ...], ) -> Observation | None: raw_tags = raw.get("tags") tags: Mapping[str, Any] = raw_tags if isinstance(raw_tags, Mapping) else {} problem_class = raw.get("problem_class") or tags.get("problem_class") if problem_class is not None and str(problem_class) != class_name: return None dimensions = _dimensions_from_sources(required_dimensions, raw, tags) prompt_tokens = _token_value(raw, "prompt_tokens", "tokens_in", "actual_prompt_tokens") completion_tokens = _token_value( raw, "completion_tokens", "tokens_out", "actual_completion_tokens", ) return Observation(dimensions, prompt_tokens, completion_tokens) def _coerce_object_observation( raw: Any, class_name: str, required_dimensions: tuple[str, ...], ) -> Observation | None: raw_tags = getattr(raw, "tags", {}) or {} tags: Mapping[str, Any] = raw_tags if isinstance(raw_tags, Mapping) else {} problem_class = tags.get("problem_class") if problem_class is not None and str(problem_class) != class_name: return None dimensions = _dimensions_from_sources(required_dimensions, tags) return Observation( dimensions=dimensions, prompt_tokens=getattr(raw, "tokens_in"), completion_tokens=getattr(raw, "tokens_out"), ) def _dimensions_from_sources( required_dimensions: tuple[str, ...], *sources: Mapping[str, Any], ) -> dict[str, Any]: for source in sources: candidate = source.get("dimensions") if isinstance(candidate, Mapping): return dict(candidate) dimensions: dict[str, Any] = {} for name in required_dimensions: for source in sources: if name in source: dimensions[name] = source[name] break if len(dimensions) != len(required_dimensions): raise ValueError("observation is missing required dimensions") return dimensions def _token_value(raw: Mapping[str, Any], *names: str) -> int: for name in names: if name in raw: return _non_negative_int(name, raw[name]) usage = raw.get("usage") if isinstance(usage, Mapping): for name in names: if name in usage: return _non_negative_int(name, usage[name]) raise KeyError(names[0]) def _dimension(dimensions: Mapping[str, Any], name: str) -> float: return _non_negative_float(name, dimensions[name]) def _words_to_tokens(words: float) -> int: if words == 0: return 0 return max(1, _round_tokens(words / DEFAULT_WORDS_PER_TOKEN)) def _round_tokens(value: float) -> int: return max(0, int(round(value))) def _non_negative_int(name: str, value: Any) -> int: if isinstance(value, bool): raise ValueError(f"{name} must be a non-negative integer") try: integer = int(value) except (TypeError, ValueError) as exc: raise ValueError(f"{name} must be a non-negative integer") from exc if integer < 0 or integer != float(value): raise ValueError(f"{name} must be a non-negative integer") return integer def _non_negative_float(name: str, value: Any) -> float: if isinstance(value, bool): raise ValueError(f"{name} must be a non-negative number") try: number = float(value) except (TypeError, ValueError) as exc: raise ValueError(f"{name} must be a non-negative number") from exc if number < 0: raise ValueError(f"{name} must be a non-negative number") return number def _bounded_float(name: str, value: Any) -> float: number = _non_negative_float(name, value) if number > 1: raise ValueError(f"{name} must be between 0 and 1") return number