#!/usr/bin/env python3
"""PostToolUse hook: replace heuristic token events with real transcript-derived counts.

Fires after mcp__state-hub__update_task_status when status=done.
Reads the Claude Code session transcript to compute the token delta since
the previous task completion, then PATCHes the heuristic event with real counts.

State is persisted per session in $TMPDIR/custodian_tokens_<session_id>.json
(/tmp when TMPDIR is unset) so deltas are correctly scoped even when multiple
tasks complete in one session.
"""

import json
import os
import sys
import urllib.error
import urllib.parse
import urllib.request
from pathlib import Path

API = os.environ.get("CUSTODIAN_API", "http://127.0.0.1:8000")
STATE_DIR = Path(os.environ.get("TMPDIR", "/tmp"))

# Usage fields that all count as input tokens (direct + cache creation + cache read).
_INPUT_TOKEN_KEYS = (
    "input_tokens",
    "cache_creation_input_tokens",
    "cache_read_input_tokens",
)


def _state_file(session_id: str) -> Path:
    """Return the path of the per-session running-totals file."""
    return STATE_DIR / f"custodian_tokens_{session_id}.json"


def _token_count(mapping: dict, key: str) -> int:
    """Return ``mapping[key]`` when it is a real integer, otherwise 0."""
    value = mapping.get(key)
    # bool is an int subclass; a stray true/false must not count as 1/0 tokens.
    if isinstance(value, int) and not isinstance(value, bool):
        return value
    return 0


def read_transcript_totals(transcript_path: str) -> tuple[int, int]:
    """Sum all usage entries in the transcript JSONL up to the current point.

    Each line is parsed as JSON and its ``message.usage`` object, when
    present, is added to the running totals. Lines that are not valid JSON,
    are not objects, or carry no usable ``usage`` object are skipped.

    Args:
        transcript_path: Path to the session transcript (one JSON document
            per line). An empty path yields ``(0, 0)``.

    Returns:
        ``(total_in, total_out)``. ``total_in`` includes direct, cache-creation
        and cache-read input tokens. An unreadable file yields whatever was
        summed before the failure (``(0, 0)`` if it could not be opened).
    """
    total_in = total_out = 0
    if not transcript_path:
        return total_in, total_out
    try:
        # errors="replace" keeps one undecodable byte from aborting the scan.
        with open(transcript_path, encoding="utf-8", errors="replace") as f:
            for line in f:
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if not isinstance(entry, dict):
                    continue
                message = entry.get("message")
                if not isinstance(message, dict):
                    continue
                usage = message.get("usage")
                if not isinstance(usage, dict):
                    continue
                total_in += sum(_token_count(usage, key) for key in _INPUT_TOKEN_KEYS)
                total_out += _token_count(usage, "output_tokens")
    except (OSError, ValueError):
        # Missing/unreadable file or an invalid path: report what we have.
        pass
    return total_in, total_out


def load_state(session_id: str) -> tuple[int, int]:
    """Return the totals recorded at the previous task completion.

    Args:
        session_id: Session identifier used to name the state file.

    Returns:
        ``(total_in, total_out)`` from the state file, or ``(0, 0)`` when the
        file is missing, unreadable, or does not contain a JSON object.
    """
    try:
        data = json.loads(_state_file(session_id).read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers both malformed JSON and undecodable bytes.
        return 0, 0
    if not isinstance(data, dict):
        return 0, 0
    return _token_count(data, "total_in"), _token_count(data, "total_out")


def save_state(session_id: str, total_in: int, total_out: int) -> None:
    """Persist the current transcript totals for the next delta computation.

    Args:
        session_id: Session identifier used to name the state file.
        total_in: Cumulative input tokens seen so far in the transcript.
        total_out: Cumulative output tokens seen so far in the transcript.

    Raises:
        OSError: If the state file cannot be written.
    """
    payload = json.dumps({"total_in": total_in, "total_out": total_out})
    _state_file(session_id).write_text(payload, encoding="utf-8")


def _request_json(req: urllib.request.Request):
    """Send *req* with a 5 second timeout and decode the JSON response body."""
    with urllib.request.urlopen(req, timeout=5) as r:
        return json.loads(r.read())


def api_get(path: str):
    """GET ``API + path`` and return the decoded JSON response.

    Raises:
        urllib.error.URLError: On connection failures and HTTP error statuses.
        OSError: On lower-level socket errors such as timeouts.
        ValueError: If the URL is invalid or the body is not valid JSON.
    """
    return _request_json(urllib.request.Request(f"{API}{path}"))


def api_patch(path: str, data: dict):
    """PATCH ``API + path`` with *data* as a JSON body; return the decoded reply.

    Raises:
        urllib.error.URLError: On connection failures and HTTP error statuses.
        OSError: On lower-level socket errors such as timeouts.
        ValueError: If the URL is invalid or the body is not valid JSON.
    """
    req = urllib.request.Request(
        f"{API}{path}",
        data=json.dumps(data).encode(),
        headers={"Content-Type": "application/json"},
        method="PATCH",
    )
    return _request_json(req)


def main() -> None:
    """Run the hook: read the payload from stdin and patch the heuristic event.

    Every failure path returns quietly. The hook is best-effort, so when
    anything goes wrong the heuristic event is simply left in place.
    """
    try:
        payload = json.loads(sys.stdin.read())
    except json.JSONDecodeError:
        return
    if not isinstance(payload, dict):
        return

    tool_name = payload.get("tool_name")
    if not isinstance(tool_name, str) or "update_task_status" not in tool_name:
        return

    tool_input = payload.get("tool_input")
    if not isinstance(tool_input, dict) or tool_input.get("status") != "done":
        return

    task_id = tool_input.get("task_id")
    if not task_id:
        return

    transcript_path = payload.get("transcript_path") or ""
    session_id = payload.get("session_id") or "unknown"

    # Compute token delta for this task
    current_in, current_out = read_transcript_totals(transcript_path)
    last_in, last_out = load_state(session_id)
    delta_in = max(0, current_in - last_in)
    delta_out = max(0, current_out - last_out)
    try:
        save_state(session_id, current_in, current_out)
    except OSError:
        # The baseline could not be persisted, so the next delta may include
        # this task's tokens. The measurement for this task is still valid.
        pass

    if delta_in == 0 and delta_out == 0:
        return  # Nothing measurable — leave heuristic in place

    # Find the heuristic event for this task and replace it. The first
    # returned event is used; presumably the API lists newest first
    # (TODO confirm against the token-events endpoint).
    query = urllib.parse.urlencode(
        {"task_id": task_id, "note": "heuristic", "limit": 5}
    )
    try:
        events = api_get(f"/token-events/?{query}")
    except (urllib.error.URLError, OSError, ValueError):
        return  # API offline or unusable reply — leave heuristic as-is

    if not isinstance(events, list) or not events:
        return
    first_event = events[0]
    event_id = first_event.get("id") if isinstance(first_event, dict) else None
    if event_id is None:
        return

    patch_body: dict = {
        "tokens_in": delta_in,
        "tokens_out": delta_out,
        "note": "measured",
    }
    model = tool_input.get("model")
    agent = tool_input.get("agent")
    if model:
        patch_body["model"] = model
    if agent:
        patch_body["agent"] = agent

    event_path = urllib.parse.quote(str(event_id), safe="")
    try:
        api_patch(f"/token-events/{event_path}", patch_body)
    except (urllib.error.URLError, OSError, ValueError):
        pass  # Best-effort: the heuristic event stays if the PATCH fails


if __name__ == "__main__":
    main()