""" Budget and usage registry for infospaces. Layer 1 of the three-layer design (see IB-WP-0019): - This module persists per-infospace plan snapshots, usage rollups, and plan-vs-actual variance under `output/budget/`. - Layer 2 (cross-application observations for adaptive routing) lives in llm-connect's QualityLedger (LLM-WP-0004). - Layer 3 (organizational rollup) is state-hub `record_token_event`. """ from __future__ import annotations import hashlib import json from datetime import datetime, timezone from pathlib import Path from typing import Any, Callable import yaml RATES_FILENAME = "model-rates.yaml" _PACKAGE_RATES_PATH = Path(__file__).parent / "model_rates.yaml" BUDGET_DIR = Path("output/budget") PLANS_FILE = BUDGET_DIR / "plans.yaml" USAGE_FILE = BUDGET_DIR / "usage.yaml" SUMMARY_FILE = BUDGET_DIR / "summary.yaml" PLAN_RETENTION_DEFAULT = 50 PLANS_SCHEMA_VERSION = 1 USAGE_SCHEMA_VERSION = 1 SUMMARY_SCHEMA_VERSION = 1 _SNAPSHOT_FINGERPRINT_FIELDS = ( "stage", "selected_chunk_count", "selected_chunk_ids", "selected_chapter_numbers", "total_provider_calls_estimate", "total_prompt_tokens_estimate", "estimated_cost_usd", "cost_per_1k_tokens", "max_calls", "cost_cap", ) def record_plan_snapshot( root: str | Path, summary: dict[str, Any], *, retention: int = PLAN_RETENTION_DEFAULT, ) -> str: """Persist a compact plan summary to ``output/budget/plans.yaml``. Returns the snapshot_id assigned to this entry. If a snapshot with the same fingerprint already exists at the head of the list, its ``recorded_at`` is refreshed instead of producing a duplicate entry. """ root_path = Path(root) budget_path = root_path / PLANS_FILE budget_path.parent.mkdir(parents=True, exist_ok=True) snapshot = _build_snapshot(summary) payload = _read_plans(budget_path) snapshots = payload.get("snapshots") or [] pruned_count = int(payload.get("pruned_count") or 0) if snapshots and snapshots[-1].get("snapshot_id") == snapshot["snapshot_id"]: snapshots[-1]["recorded_at"] = snapshot["recorded_at"] else: snapshots.append(snapshot) if retention > 0 and len(snapshots) > retention: overflow = len(snapshots) - retention pruned_count += overflow snapshots = snapshots[overflow:] _write_plans( budget_path, { "schema_version": PLANS_SCHEMA_VERSION, "pruned_count": pruned_count, "snapshots": snapshots, }, ) return snapshot["snapshot_id"] def read_plan_snapshots(root: str | Path) -> list[dict[str, Any]]: """Return the persisted plan snapshots in chronological order.""" payload = _read_plans(Path(root) / PLANS_FILE) return list(payload.get("snapshots") or []) def latest_plan_snapshot_id(root: str | Path) -> str | None: snapshots = read_plan_snapshots(root) if not snapshots: return None return snapshots[-1].get("snapshot_id") def record_run_usage( root: str | Path, workflow_results: list[dict[str, Any]], *, snapshot_id: str | None = None, duration_seconds: float | None = None, started_at: str | None = None, cost_resolver: Any | None = None, ) -> dict[str, Any]: """Aggregate per-call usage from completed workflow run records. ``cost_resolver`` is a callable ``(provider, model, prompt_tokens, completion_tokens) -> float | None`` used to fill ``cost_usd_estimated`` when the adapter did not return a cost. Left as ``None`` here; T03 wires the rate-table resolver in. """ root_path = Path(root) usage_path = root_path / USAGE_FILE usage_path.parent.mkdir(parents=True, exist_ok=True) buckets: dict[tuple, dict[str, Any]] = {} workflow_summaries: list[dict[str, Any]] = [] for workflow in workflow_results or []: if not isinstance(workflow, dict): continue workflow_id = str(workflow.get("workflow_id") or "") workflow_summary = { "run_id": workflow.get("run_id"), "workflow_id": workflow_id, "status": workflow.get("status"), "stage_count": len(workflow.get("stages") or []), } workflow_summaries.append(workflow_summary) for stage in workflow.get("stages") or []: if not isinstance(stage, dict): continue provider = str(stage.get("provider") or "") if not provider: continue metadata = stage.get("metadata") or {} model = str(metadata.get("model") or "") usage = metadata.get("usage") or {} prompt_tokens = int(usage.get("prompt_tokens") or 0) completion_tokens = int(usage.get("completion_tokens") or 0) reported_cost = _coerce_float(usage.get("cost")) bucket_key = (workflow_id, str(stage.get("stage_id") or ""), provider, model) bucket = buckets.setdefault( bucket_key, { "workflow_id": workflow_id, "stage_id": str(stage.get("stage_id") or ""), "provider": provider, "model": model, "calls": 0, "prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0, "cost_usd_known": 0.0, "cost_usd_estimated": 0.0, "cost_status": "known" if reported_cost is not None else "unknown", "cost_estimated_for_calls": 0, }, ) bucket["calls"] += 1 bucket["prompt_tokens"] += prompt_tokens bucket["completion_tokens"] += completion_tokens bucket["total_tokens"] += prompt_tokens + completion_tokens if reported_cost is not None: bucket["cost_usd_known"] = round(bucket["cost_usd_known"] + reported_cost, 6) bucket["cost_status"] = "known" elif cost_resolver is not None: estimated = cost_resolver(provider, model, prompt_tokens, completion_tokens) if estimated is not None: bucket["cost_usd_estimated"] = round( bucket["cost_usd_estimated"] + float(estimated), 6 ) bucket["cost_estimated_for_calls"] += 1 if bucket["cost_status"] != "known": bucket["cost_status"] = "estimated" per_bucket = list(buckets.values()) for bucket in per_bucket: if bucket["cost_usd_estimated"] == 0.0 and bucket["cost_estimated_for_calls"] == 0: bucket["cost_usd_estimated"] = None rollup = { "total_calls": sum(b["calls"] for b in per_bucket), "total_prompt_tokens": sum(b["prompt_tokens"] for b in per_bucket), "total_completion_tokens": sum(b["completion_tokens"] for b in per_bucket), "total_tokens": sum(b["total_tokens"] for b in per_bucket), "total_cost_usd_known": round(sum(b["cost_usd_known"] for b in per_bucket), 6), "total_cost_usd_estimated": round( sum(b["cost_usd_estimated"] or 0.0 for b in per_bucket), 6 ) or None, } completed_at = _now() entry = { "run_index": _next_run_index(usage_path), "started_at": started_at, "completed_at": completed_at, "duration_seconds": duration_seconds, "snapshot_id": snapshot_id, "workflows": workflow_summaries, "rollup": rollup, "per_bucket": per_bucket, } payload = _read_usage(usage_path) runs = list(payload.get("runs") or []) runs.append(entry) _write_usage( usage_path, {"schema_version": USAGE_SCHEMA_VERSION, "runs": runs}, ) return entry def read_usage_runs(root: str | Path) -> list[dict[str, Any]]: payload = _read_usage(Path(root) / USAGE_FILE) return list(payload.get("runs") or []) def record_run_variance( root: str | Path, run_entry: dict[str, Any], ) -> dict[str, Any]: """Compute and persist plan-vs-actual variance for the just-completed run. Reads the plan snapshot referenced by ``run_entry['snapshot_id']`` from ``output/budget/plans.yaml``, derives call/token/cost variance, writes the result to ``output/budget/summary.yaml`` (overwrite), and returns it. When no snapshot is referenced or the snapshot cannot be located, the variance payload is still written with null comparison fields so the consumer always sees a current summary. """ root_path = Path(root) summary_path = root_path / SUMMARY_FILE summary_path.parent.mkdir(parents=True, exist_ok=True) snapshot_id = run_entry.get("snapshot_id") snapshot = _lookup_snapshot(root_path, snapshot_id) if snapshot_id else None rollup = run_entry.get("rollup") or {} actual_calls = int(rollup.get("total_calls") or 0) actual_tokens = int(rollup.get("total_tokens") or 0) actual_prompt_tokens = int(rollup.get("total_prompt_tokens") or 0) actual_cost_known = _coerce_float(rollup.get("total_cost_usd_known")) or 0.0 actual_cost_estimated = _coerce_float(rollup.get("total_cost_usd_estimated")) or 0.0 actual_cost_total = round(actual_cost_known + actual_cost_estimated, 6) if snapshot is not None: estimated_calls = int(snapshot.get("total_provider_calls_estimate") or 0) estimated_prompt_tokens = int(snapshot.get("total_prompt_tokens_estimate") or 0) estimated_cost = _coerce_float(snapshot.get("estimated_cost_usd")) else: estimated_calls = None estimated_prompt_tokens = None estimated_cost = None summary = { "schema_version": SUMMARY_SCHEMA_VERSION, "recorded_at": _now(), "run_index": run_entry.get("run_index"), "snapshot_id": snapshot_id, "snapshot_resolved": snapshot is not None, "calls": _variance_pair(estimated_calls, actual_calls), "prompt_tokens": _variance_pair(estimated_prompt_tokens, actual_prompt_tokens), "total_tokens": _variance_pair(estimated_prompt_tokens, actual_tokens), "cost_usd": { "estimated": estimated_cost, "actual_known": actual_cost_known, "actual_estimated_from_rates": actual_cost_estimated, "actual_total": actual_cost_total, **_variance_delta_ratio(estimated_cost, actual_cost_total), }, "per_workflow": _per_workflow_variance(snapshot, run_entry), "duration_seconds": run_entry.get("duration_seconds"), } summary_path.write_text(yaml.safe_dump(summary, sort_keys=False), encoding="utf-8") return summary def read_run_variance(root: str | Path) -> dict[str, Any] | None: path = Path(root) / SUMMARY_FILE if not path.is_file(): return None try: data = yaml.safe_load(path.read_text(encoding="utf-8")) except yaml.YAMLError: return None return data if isinstance(data, dict) else None def _lookup_snapshot(root: Path, snapshot_id: str) -> dict[str, Any] | None: for snap in reversed(read_plan_snapshots(root)): if snap.get("snapshot_id") == snapshot_id: return snap return None def _variance_pair(estimated: int | None, actual: int) -> dict[str, Any]: delta = None if estimated is None else actual - estimated ratio = _safe_ratio(actual, estimated) return { "estimated": estimated, "actual": actual, "delta": delta, "ratio": ratio, } def _variance_delta_ratio(estimated: float | None, actual: float) -> dict[str, Any]: delta = None if estimated is None else round(actual - estimated, 6) ratio = _safe_ratio(actual, estimated) return {"delta": delta, "ratio": ratio} def _safe_ratio(actual: float | int, estimated: float | int | None) -> float | None: if estimated in (None, 0, 0.0): return None return round(float(actual) / float(estimated), 4) def _per_workflow_variance( snapshot: dict[str, Any] | None, run_entry: dict[str, Any] ) -> list[dict[str, Any]]: actuals: dict[str, dict[str, int]] = {} for bucket in run_entry.get("per_bucket") or []: workflow_id = bucket.get("workflow_id") or "" if not workflow_id: continue agg = actuals.setdefault( workflow_id, {"calls": 0, "prompt_tokens": 0, "completion_tokens": 0} ) agg["calls"] += int(bucket.get("calls") or 0) agg["prompt_tokens"] += int(bucket.get("prompt_tokens") or 0) agg["completion_tokens"] += int(bucket.get("completion_tokens") or 0) estimates: dict[str, dict[str, int]] = {} if snapshot is not None: for entry in snapshot.get("per_workflow") or []: workflow_id = entry.get("workflow_id") or "" if not workflow_id: continue estimates[workflow_id] = { "calls": int(entry.get("calls") or 0), "prompt_words_estimate": int(entry.get("prompt_words_estimate") or 0), } workflow_ids = sorted(set(actuals) | set(estimates)) out: list[dict[str, Any]] = [] for workflow_id in workflow_ids: actual = actuals.get(workflow_id, {"calls": 0, "prompt_tokens": 0}) estimate = estimates.get(workflow_id) estimated_calls = estimate["calls"] if estimate else None out.append( { "workflow_id": workflow_id, "calls": _variance_pair(estimated_calls, actual["calls"]), "prompt_tokens_actual": actual["prompt_tokens"], "prompt_words_estimate": estimate["prompt_words_estimate"] if estimate else None, } ) return out def load_rate_table(workspace: Path | str | None = None) -> dict[str, dict[str, float]]: """Load the model rate table, with optional workspace override. Returns a mapping ``model_id -> {prompt_per_1k, completion_per_1k}``. The workspace override (``/model-rates.yaml``) is overlaid on top of the package default, so individual models can be tweaked without copying the whole table. """ rates: dict[str, dict[str, float]] = {} for path in (_PACKAGE_RATES_PATH, _workspace_rate_path(workspace)): if path is None or not path.is_file(): continue try: data = yaml.safe_load(path.read_text(encoding="utf-8")) except yaml.YAMLError: continue if not isinstance(data, dict): continue for model, entry in (data.get("rates") or {}).items(): if not isinstance(entry, dict): continue prompt = _coerce_float(entry.get("prompt_per_1k")) completion = _coerce_float(entry.get("completion_per_1k")) if prompt is None and completion is None: continue rates[str(model)] = { "prompt_per_1k": prompt if prompt is not None else 0.0, "completion_per_1k": completion if completion is not None else 0.0, } return rates def estimate_cost_usd( model: str, prompt_tokens: int, completion_tokens: int, rate_table: dict[str, dict[str, float]], ) -> float | None: entry = rate_table.get(model) if entry is None: return None prompt_rate = float(entry.get("prompt_per_1k") or 0.0) completion_rate = float(entry.get("completion_per_1k") or 0.0) cost = (prompt_tokens / 1000.0) * prompt_rate + ( completion_tokens / 1000.0 ) * completion_rate return round(cost, 6) def make_cost_resolver( workspace: Path | str | None, ) -> Callable[[str, str, int, int], float | None]: """Return a resolver suitable for ``record_run_usage(..., cost_resolver=...)``.""" rates = load_rate_table(workspace) def _resolve(provider: str, model: str, prompt_tokens: int, completion_tokens: int) -> float | None: if not model: return None return estimate_cost_usd(model, prompt_tokens, completion_tokens, rates) return _resolve def _workspace_rate_path(workspace: Path | str | None) -> Path | None: if workspace is None: return None candidate = Path(workspace) / RATES_FILENAME return candidate def _coerce_float(value: Any) -> float | None: if value is None: return None try: return float(value) except (TypeError, ValueError): return None def _next_run_index(usage_path: Path) -> int: payload = _read_usage(usage_path) return len(payload.get("runs") or []) + 1 def _read_usage(path: Path) -> dict[str, Any]: if not path.is_file(): return {"schema_version": USAGE_SCHEMA_VERSION, "runs": []} try: data = yaml.safe_load(path.read_text(encoding="utf-8")) except yaml.YAMLError: return {"schema_version": USAGE_SCHEMA_VERSION, "runs": []} if not isinstance(data, dict): return {"schema_version": USAGE_SCHEMA_VERSION, "runs": []} return data def _write_usage(path: Path, payload: dict[str, Any]) -> None: path.write_text(yaml.safe_dump(payload, sort_keys=False), encoding="utf-8") def _build_snapshot(summary: dict[str, Any]) -> dict[str, Any]: filters = { "stage": summary.get("stage"), "chapter_filter": summary.get("chapter_filter"), "chunk_filter": summary.get("chunk_filter"), "from_chapter": summary.get("from_chapter"), "to_chapter": summary.get("to_chapter"), } fingerprint_source = { key: summary.get(key) for key in _SNAPSHOT_FINGERPRINT_FIELDS } fingerprint_source["filters"] = filters snapshot_id = _fingerprint(fingerprint_source) return { "snapshot_id": snapshot_id, "recorded_at": _now(), "stage": summary.get("stage"), "filters": filters, "selected_chunk_count": summary.get("selected_chunk_count"), "selected_chunk_ids": list(summary.get("selected_chunk_ids") or []), "selected_chapter_numbers": list(summary.get("selected_chapter_numbers") or []), "per_workflow": list(summary.get("per_workflow") or []), "total_provider_calls_estimate": summary.get("total_provider_calls_estimate"), "total_prompt_tokens_estimate": summary.get("total_prompt_tokens_estimate"), "total_prompt_words_estimate": summary.get("total_prompt_words_estimate"), "estimated_cost_usd": summary.get("estimated_cost_usd"), "cost_per_1k_tokens": summary.get("cost_per_1k_tokens"), "max_calls": summary.get("max_calls"), "cost_cap": summary.get("cost_cap"), "exceeds_max_calls": bool(summary.get("exceeds_max_calls")), "exceeds_cost_cap": bool(summary.get("exceeds_cost_cap")), } def _fingerprint(payload: dict[str, Any]) -> str: serialised = json.dumps(payload, sort_keys=True, default=str) return hashlib.sha256(serialised.encode("utf-8")).hexdigest()[:12] def _read_plans(path: Path) -> dict[str, Any]: if not path.is_file(): return {"schema_version": PLANS_SCHEMA_VERSION, "pruned_count": 0, "snapshots": []} try: data = yaml.safe_load(path.read_text(encoding="utf-8")) except yaml.YAMLError: return {"schema_version": PLANS_SCHEMA_VERSION, "pruned_count": 0, "snapshots": []} if not isinstance(data, dict): return {"schema_version": PLANS_SCHEMA_VERSION, "pruned_count": 0, "snapshots": []} return data def _write_plans(path: Path, payload: dict[str, Any]) -> None: path.write_text(yaml.safe_dump(payload, sort_keys=False), encoding="utf-8") def _now() -> str: return datetime.now(timezone.utc).isoformat()