Files
state-hub/scripts/token_reconcile.py

240 lines
9.1 KiB
Python

#!/usr/bin/env python3
"""Reconcile token evidence from local agent sources against State Hub.
Dry-run is the default. Use ``--apply`` to upsert measured source events and
``--zero-superseded-fallbacks`` to zero task fallback rows that are covered by
source-backed measurements.
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import urllib.parse
import urllib.request
from collections import Counter, defaultdict
from datetime import datetime
from pathlib import Path
from typing import Any
# Put the parent of this script's directory at the front of ``sys.path`` (only
# if it is not already listed) so the ``api.services`` imports that follow can
# be resolved when the file is executed directly.
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from api.services.token_sources import collect_claude_transcripts, collect_codex_sessions, parse_iso # noqa: E402
from api.services.token_sources.attribution import repo_refs_from_api, resolve_repo # noqa: E402
# Default State Hub base URL: taken from the STATE_HUB_API environment variable
# when set, otherwise the local address below. Used as the --api-base default.
DEFAULT_API = os.environ.get("STATE_HUB_API", "http://127.0.0.1:8000")
# Value written to the ``note`` field of fallback rows that main() zeroes when
# --zero-superseded-fallbacks is given together with --apply.
SUPERSEDED_HEURISTIC_NOTE = "heuristic_superseded_by_source_measurement"
def http_json(api_base: str, method: str, path: str, body: dict[str, Any] | None = None) -> Any:
    """Send one JSON request and return the decoded JSON response.

    Args:
        api_base: Base URL; trailing slashes are ignored.
        method: HTTP method name passed to ``urllib.request.Request``.
        path: Path (and optional query string) joined to ``api_base`` with
            exactly one ``/``; leading slashes are ignored.
        body: Optional mapping serialised as UTF-8 JSON for the request body.

    Returns:
        The parsed response, or ``None`` when the response body is empty.

    Raises:
        Whatever ``urllib.request.urlopen`` or ``json.loads`` raise; nothing
        is caught here.
    """
    target = "/".join((api_base.rstrip("/"), path.lstrip("/")))
    payload = json.dumps(body).encode("utf-8") if body is not None else None
    request = urllib.request.Request(
        target,
        data=payload,
        headers={"Content-Type": "application/json"},
        method=method,
    )
    with urllib.request.urlopen(request, timeout=30) as response:
        raw = response.read()
        # An empty body is decoded as JSON null rather than raising.
        return json.loads(raw if raw else b"null")
def list_events(api_base: str, params: dict[str, Any]) -> list[dict[str, Any]]:
    """Fetch every token event matching *params*, following offset pagination.

    Pages of 1000 are requested from ``/token-events/`` until a page is empty,
    is not a list, or is shorter than the page size.

    Args:
        api_base: Base URL handed to ``http_json``.
        params: Query parameters; ``limit`` and ``offset`` are overridden here.

    Returns:
        All events from the fetched pages, in the order they were returned.
    """
    page_size = 1000
    collected: list[dict[str, Any]] = []
    offset = 0
    while True:
        query = urllib.parse.urlencode({**params, "limit": page_size, "offset": offset})
        page = http_json(api_base, "GET", f"/token-events/?{query}")
        if not (isinstance(page, list) and page):
            return collected
        collected.extend(page)
        # A short page means the server has nothing further to return.
        if len(page) < page_size:
            return collected
        offset += page_size
def find_home(explicit: str | None, env_name: str, default: Path) -> Path | None:
    """Pick the first existing directory among the configured home locations.

    Candidates are tried in priority order: the explicit value (when
    non-empty), the value of the ``env_name`` environment variable (when set
    and non-empty), and finally ``default``.

    Returns:
        The first candidate that is an existing directory, or ``None`` when
        none of them is.
    """
    from_env = os.environ.get(env_name)
    ordered = [Path(raw) for raw in (explicit, from_env) if raw]
    ordered.append(default)
    return next((path for path in ordered if path.is_dir()), None)
def event_total(event: dict[str, Any]) -> int:
    """Return ``tokens_in`` plus ``tokens_out`` for *event*.

    A missing, ``None`` or otherwise falsy field counts as 0; other values are
    converted with ``int()``.
    """
    return sum(int(event.get(field) or 0) for field in ("tokens_in", "tokens_out"))
def source_index(events: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
    """Index *events* by their source identifier.

    The key is the event's ``source_id`` or, when that is missing or falsy,
    its ``ref_id``. Events whose key is not a string are left out. When
    several events share a key, the last one in *events* is kept.
    """
    keyed = ((event.get("source_id") or event.get("ref_id"), event) for event in events)
    return {key: event for key, event in keyed if isinstance(key, str)}
def print_report(report: dict[str, Any]) -> None:
    """Print *report* to stdout as JSON with 2-space indentation and sorted keys.

    Values the JSON encoder cannot serialise are rendered with ``str()``.
    """
    rendered = json.dumps(report, default=str, sort_keys=True, indent=2)
    print(rendered)
def main() -> int:
    """Reconcile local agent token sources against State Hub and print a report.

    Steps:
      1. Parse the command line and locate the Codex and Claude home
         directories.
      2. Collect source records from each home that was found.
      3. Fetch repos, existing token events and progress events from the API.
      4. Plan an upsert for every record that has no existing event or whose
         existing event holds fewer tokens than the record.
      5. With ``--apply``: send the upserts, optionally zero fallback rows
         (``--zero-superseded-fallbacks``) and post a progress entry.
      6. Print the JSON report.

    Without ``--apply`` only GET requests are issued.

    Returns:
        1 when the canary fails (progress events exist for the window but the
        measured token total after the plan is 0), otherwise 0.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--since", default="2026-05-19", help="UTC date/time to reconcile from")
    parser.add_argument("--api-base", default=DEFAULT_API)
    parser.add_argument("--codex-home")
    parser.add_argument("--claude-home")
    parser.add_argument("--apply", action="store_true", help="upsert measured source events")
    parser.add_argument(
        "--zero-superseded-fallbacks",
        action="store_true",
        help="with --apply, zero heuristic fallback rows after measured source ingestion",
    )
    args = parser.parse_args()
    since = parse_iso(args.since)
    since_param = since.isoformat()

    codex_home = find_home(args.codex_home, "CODEX_HOME", Path.home() / ".codex")
    if codex_home is None:
        # NOTE(review): user-specific last-resort location (looks like a WSL
        # mount of a Windows profile); kept as-is for compatibility.
        windows_codex = Path("/mnt/c/Users/bernd.worsch/.codex")
        codex_home = windows_codex if windows_codex.is_dir() else None
    claude_home = find_home(args.claude_home, "CLAUDE_HOME", Path.home() / ".claude")

    # Collect source records and remember, per source, where they came from.
    records = []
    source_health: dict[str, dict[str, Any]] = {}
    if codex_home:
        codex_records = collect_codex_sessions(codex_home, since)
        records.extend(codex_records)
        source_health["codex_session"] = {"home": str(codex_home), "sessions_found": len(codex_records)}
    else:
        source_health["codex_session"] = {"home": None, "sessions_found": 0, "warning": "Codex home not found"}
    if claude_home:
        claude_records = collect_claude_transcripts(claude_home, since)
        records.extend(claude_records)
        source_health["claude_transcript"] = {"home": str(claude_home), "sessions_found": len(claude_records)}
    else:
        source_health["claude_transcript"] = {"home": None, "sessions_found": 0, "warning": "Claude home not found"}

    # Current server-side state for the reconciliation window.
    repos = repo_refs_from_api(http_json(args.api_base, "GET", "/repos/"))
    existing_events = list_events(args.api_base, {"since": since_param, "include_superseded": "true"})
    existing_by_source = source_index(existing_events)
    fallback_events = [
        event for event in existing_events
        if event.get("source_provider") == "task_fallback" or event.get("note") == "heuristic"
    ]
    superseded_events = [
        event for event in existing_events
        if event.get("measurement_kind") == "superseded" or str(event.get("note") or "").startswith("heuristic_superseded")
    ]

    # Plan upserts: skip records whose existing event already holds at least
    # as many tokens as the record.
    planned_upserts = []
    unattributed = 0
    stale = 0
    # id() of existing events that a planned upsert with the same source_id
    # will overwrite; they must not be counted again in the after-plan total.
    replaced_by_plan: set[int] = set()
    source_totals: dict[str, int] = defaultdict(int)
    for record in records:
        source_totals[record.source_provider] += record.tokens_total
        existing = existing_by_source.get(record.source_id)
        if existing and event_total(existing) >= record.tokens_total:
            continue
        if existing:
            stale += 1
            if existing.get("source_id") == record.source_id:
                replaced_by_plan.add(id(existing))
        match = resolve_repo(record.cwd, repos)
        if match is None:
            unattributed += 1
        planned_upserts.append((record, match))

    # Data-quality diagnostics over the existing measured events.
    source_ids = [
        event.get("source_id")
        for event in existing_events
        if event.get("source_id") and event.get("measurement_kind") == "measured"
    ]
    duplicate_sources = {
        source_id: count for source_id, count in Counter(source_ids).items() if count > 1
    }
    missing_provenance = [
        event for event in existing_events
        if event.get("measurement_kind") == "measured" and not event.get("source_id")
    ]

    progress_response = http_json(
        args.api_base, "GET", f"/progress/?since={urllib.parse.quote(since_param)}&limit=1000"
    )
    # Treat anything that is not a list as "no progress events" so the canary
    # and the reported count always agree.
    progress_events = progress_response if isinstance(progress_response, list) else []

    # Existing measured rows that the plan replaces are excluded; otherwise a
    # stale session would be counted twice (old row plus its planned upsert).
    measured_total = sum(
        event_total(event)
        for event in existing_events
        if event.get("measurement_kind") == "measured" and id(event) not in replaced_by_plan
    ) + sum(record.tokens_total for record, _ in planned_upserts)
    canary_failed = bool(progress_events) and measured_total == 0

    report = {
        "since": since.isoformat(),
        "apply": args.apply,
        "sources": source_health,
        "sessions_found": len(records),
        "source_tokens_total": dict(source_totals),
        "events_existing": len(existing_events),
        "events_to_upsert": len(planned_upserts),
        "sessions_stale": stale,
        "fallback_events": len(fallback_events),
        "superseded_events": len(superseded_events),
        "unattributed_source_records": unattributed,
        "missing_provenance_events": len(missing_provenance),
        "duplicate_source_ids": duplicate_sources,
        "progress_events": len(progress_events),
        "measured_tokens_total_after_plan": measured_total,
        "canary_failed": canary_failed,
    }

    if args.apply:
        for record, match in planned_upserts:
            payload = record.to_token_event_payload(repo_id=match.repo_id if match else None)
            payload["raw_metadata"] = {
                # ``or {}`` also covers a payload that carries raw_metadata=None.
                **(payload.get("raw_metadata") or {}),
                "repo_slug": match.slug if match else None,
                "attribution_method": match.method if match else None,
            }
            http_json(args.api_base, "POST", "/token-events/upsert", payload)
        if args.zero_superseded_fallbacks:
            # NOTE(review): every fallback row in the window is zeroed, not
            # only rows matched to a specific measured source - confirm that
            # this is the intended meaning of "superseded".
            zero_patch = {
                "tokens_in": 0,
                "tokens_out": 0,
                "note": SUPERSEDED_HEURISTIC_NOTE,
                "measurement_kind": "superseded",
                "source_provider": "task_fallback",
                "confidence": 0.0,
                "raw_total_tokens": 0,
            }
            for event in fallback_events:
                # Rows that already equal the patch on every field were zeroed
                # by an earlier run; re-sending the PATCH would be a no-op.
                if all(event.get(key) == value for key, value in zero_patch.items()):
                    continue
                http_json(args.api_base, "PATCH", f"/token-events/{event['id']}", zero_patch)
        # Record the reconciliation run itself; only done when applying so a
        # dry run never writes to the API.
        http_json(
            args.api_base,
            "POST",
            "/progress/",
            {
                "summary": (
                    "Token reconciliation: "
                    f"{len(records)} source records, {len(planned_upserts)} upserts, "
                    f"{len(fallback_events)} fallback events, canary_failed={canary_failed}"
                ),
                "event_type": "token_reconciliation",
                "author": "codex",
                "detail": report,
            },
        )

    print_report(report)
    return 1 if canary_failed else 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())