#!/usr/bin/env python3 """PostToolUse hook: replace heuristic token events with real transcript-derived counts. Fires after supported task completion tools when status=done. Reads the Claude Code session transcript to compute the token delta since the previous task completion, then PATCHes the heuristic event with real counts. State is persisted per session in a durable cache directory so deltas survive restarts and multiple task completions in one session. """ import json import os import sys import urllib.error import urllib.request from datetime import datetime, timezone from pathlib import Path API = os.environ.get("CUSTODIAN_API", "http://127.0.0.1:8000") STATE_DIR = Path(os.environ.get("CUSTODIAN_TOKEN_STATE_DIR", Path.home() / ".cache" / "state-hub" / "token-hooks")) HEALTH_LOG = STATE_DIR / "hook-health.jsonl" PARSER_VERSION = "claude-transcript-delta-v1" SUPPORTED_TOOL_HINTS = ( "update_task_status", "tasks", "task", ) def utc_now() -> str: return datetime.now(timezone.utc).isoformat() def write_health(event: dict) -> None: try: STATE_DIR.mkdir(parents=True, exist_ok=True) with HEALTH_LOG.open("a", encoding="utf-8") as handle: handle.write(json.dumps({"ts": utc_now(), **event}, sort_keys=True) + "\n") except OSError: pass def read_transcript_totals(transcript_path: str) -> tuple[int, int, int]: """Sum all usage entries in the transcript JSONL up to the current point.""" total_in = total_out = cached_in = 0 try: with open(transcript_path) as f: for line in f: try: entry = json.loads(line) usage = entry.get("message", {}).get("usage", {}) if usage: total_in += usage.get("input_tokens", 0) cached_in += ( usage.get("cache_creation_input_tokens", 0) + usage.get("cache_read_input_tokens", 0) ) total_out += usage.get("output_tokens", 0) except (json.JSONDecodeError, TypeError): continue except OSError: pass return total_in, total_out, cached_in def load_state(session_id: str) -> tuple[int, int, int]: state_file = STATE_DIR / f"custodian_tokens_{session_id}.json" try: data = json.loads(state_file.read_text()) return data.get("total_in", 0), data.get("total_out", 0), data.get("cached_in", 0) except (OSError, json.JSONDecodeError): return 0, 0, 0 def save_state(session_id: str, total_in: int, total_out: int, cached_in: int) -> None: STATE_DIR.mkdir(parents=True, exist_ok=True) state_file = STATE_DIR / f"custodian_tokens_{session_id}.json" state_file.write_text(json.dumps({"total_in": total_in, "total_out": total_out, "cached_in": cached_in})) def api_get(path: str): req = urllib.request.Request(f"{API}{path}") with urllib.request.urlopen(req, timeout=5) as r: return json.loads(r.read()) def api_patch(path: str, data: dict): body = json.dumps(data).encode() req = urllib.request.Request( f"{API}{path}", data=body, headers={"Content-Type": "application/json"}, method="PATCH", ) with urllib.request.urlopen(req, timeout=5) as r: return json.loads(r.read()) def extract_done_task(payload: dict) -> tuple[str | None, dict]: tool_name = payload.get("tool_name", "") if not any(hint in tool_name for hint in SUPPORTED_TOOL_HINTS): return None, {} tool_input = payload.get("tool_input", {}) or {} status = tool_input.get("status") if status != "done": return None, {} task_id = ( tool_input.get("task_id") or tool_input.get("id") or tool_input.get("taskId") ) return task_id, tool_input def main() -> None: try: payload = json.loads(sys.stdin.read()) except json.JSONDecodeError: return task_id, tool_input = extract_done_task(payload) if not task_id: write_health({"status": "skipped", "reason": "not_done_task_completion", "tool_name": payload.get("tool_name")}) return transcript_path = payload.get("transcript_path", "") session_id = payload.get("session_id", "unknown") # Compute token delta for this task current_in, current_out, current_cached = read_transcript_totals(transcript_path) last_in, last_out, last_cached = load_state(session_id) delta_in = max(0, current_in - last_in) delta_out = max(0, current_out - last_out) delta_cached = max(0, current_cached - last_cached) save_state(session_id, current_in, current_out, current_cached) if delta_in == 0 and delta_out == 0 and delta_cached == 0: write_health({ "status": "skipped", "reason": "zero_delta", "session_id": session_id, "task_id": task_id, "source_path": transcript_path, }) return # Find the most recent heuristic event for this task and replace it try: events = api_get(f"/token-events/?task_id={task_id}¬e=heuristic&limit=5") except (urllib.error.URLError, OSError): write_health({"status": "skipped", "reason": "api_offline", "session_id": session_id, "task_id": task_id}) return # API offline — leave heuristic as-is if not events: write_health({"status": "skipped", "reason": "no_fallback_event", "session_id": session_id, "task_id": task_id}) return event_id = events[0]["id"] model = tool_input.get("model") agent = tool_input.get("agent") patch_body: dict = { "tokens_in": delta_in, "tokens_out": delta_out, "note": "measured", "measurement_kind": "measured", "source_provider": "claude_transcript", "source_id": f"claude:{session_id}:task:{task_id}", "source_path": transcript_path or None, "parser_version": PARSER_VERSION, "confidence": 1.0, "cached_input_tokens": delta_cached, "raw_total_tokens": delta_in + delta_out + delta_cached, "raw_metadata": { "hook": "post_tool_use", "tool_name": payload.get("tool_name"), "state_dir": str(STATE_DIR), }, } if model: patch_body["model"] = model if agent: patch_body["agent"] = agent try: api_patch(f"/token-events/{event_id}", patch_body) except (urllib.error.URLError, OSError): write_health({"status": "skipped", "reason": "patch_failed", "session_id": session_id, "task_id": task_id}) return write_health({ "status": "patched", "session_id": session_id, "task_id": task_id, "event_id": event_id, "tokens_in": delta_in, "tokens_out": delta_out, "cached_input_tokens": delta_cached, "source_path": transcript_path, }) if __name__ == "__main__": main()