Fixed and improved token tracking

This commit is contained in:
2026-05-23 13:59:05 +02:00
parent dd3279ea1a
commit c12091c2eb
29 changed files with 3549 additions and 278 deletions

View File

@@ -1,27 +1,48 @@
#!/usr/bin/env python3
"""PostToolUse hook: replace heuristic token events with real transcript-derived counts.
Fires after mcp__state-hub__update_task_status when status=done.
Fires after supported task completion tools when status=done.
Reads the Claude Code session transcript to compute the token delta since the
previous task completion, then PATCHes the heuristic event with real counts.
State is persisted per session in /tmp/custodian_tokens_<session_id>.json so
deltas are correctly scoped even when multiple tasks complete in one session.
State is persisted per session in a durable cache directory so deltas survive
restarts and multiple task completions in one session.
"""
import json
import os
import sys
import urllib.error
import urllib.request
from datetime import datetime, timezone
from pathlib import Path
# Base URL of the Custodian API; the CUSTODIAN_API environment variable overrides the local default.
API = os.environ.get("CUSTODIAN_API", "http://127.0.0.1:8000")
# NOTE(review): this TMPDIR-based value is overwritten by the assignment on the
# next line, so it has no effect at runtime — looks like a leftover of the
# previous /tmp-based state location; confirm and remove.
STATE_DIR = Path(os.environ.get("TMPDIR", "/tmp"))
# Directory holding the per-session state files and the health log. Defaults to
# ~/.cache/state-hub/token-hooks unless CUSTODIAN_TOKEN_STATE_DIR is set.
STATE_DIR = Path(os.environ.get("CUSTODIAN_TOKEN_STATE_DIR", Path.home() / ".cache" / "state-hub" / "token-hooks"))
# JSONL file that write_health() appends one record per line to.
HEALTH_LOG = STATE_DIR / "hook-health.jsonl"
# Sent as "parser_version" in the PATCH body so measured events can be traced
# back to the logic that produced them.
PARSER_VERSION = "claude-transcript-delta-v1"
# Substrings tested against the hook payload's tool name (`hint in tool_name`).
# NOTE(review): "task" is itself a substring of the other two entries, so with
# substring matching it alone decides the outcome — confirm that matching every
# tool whose name contains "task" is intended.
SUPPORTED_TOOL_HINTS = (
"update_task_status",
"tasks",
"task",
)
def read_transcript_totals(transcript_path: str) -> tuple[int, int]:
def utc_now() -> str:
    """Return the current time in UTC as a timezone-aware ISO 8601 string."""
    moment = datetime.now(tz=timezone.utc)
    return moment.isoformat()
def write_health(event: dict) -> None:
    """Append one health record to the hook's JSONL log, best effort.

    The record is *event* plus a "ts" timestamp; a "ts" key already present in
    *event* takes precedence. Filesystem errors are swallowed on purpose so
    that health logging can never make the hook itself fail.
    """
    try:
        STATE_DIR.mkdir(parents=True, exist_ok=True)
        with open(HEALTH_LOG, "a", encoding="utf-8") as log:
            record = {"ts": utc_now()}
            record.update(event)
            serialized = json.dumps(record, sort_keys=True)
            log.write(f"{serialized}\n")
    except OSError:
        # Logging is diagnostic only; an unwritable state dir is not fatal.
        return
def read_transcript_totals(transcript_path: str) -> tuple[int, int, int]:
"""Sum all usage entries in the transcript JSONL up to the current point."""
total_in = total_out = 0
total_in = total_out = cached_in = 0
try:
with open(transcript_path) as f:
for line in f:
@@ -29,10 +50,9 @@ def read_transcript_totals(transcript_path: str) -> tuple[int, int]:
entry = json.loads(line)
usage = entry.get("message", {}).get("usage", {})
if usage:
# Count all input token variants (direct + cache creation + cache read)
total_in += (
usage.get("input_tokens", 0)
+ usage.get("cache_creation_input_tokens", 0)
total_in += usage.get("input_tokens", 0)
cached_in += (
usage.get("cache_creation_input_tokens", 0)
+ usage.get("cache_read_input_tokens", 0)
)
total_out += usage.get("output_tokens", 0)
@@ -40,21 +60,22 @@ def read_transcript_totals(transcript_path: str) -> tuple[int, int]:
continue
except OSError:
pass
return total_in, total_out
return total_in, total_out, cached_in
def load_state(session_id: str) -> tuple[int, int, int]:
    """Load the totals recorded at the previous task completion of a session.

    Args:
        session_id: Session identifier used to build the state file name.

    Returns:
        ``(total_in, total_out, cached_in)`` as stored in the state file.
        Missing keys default to 0, so a state file written without
        ``cached_in`` still loads. ``(0, 0, 0)`` is returned when the file is
        missing, unreadable, not valid JSON, or does not contain a JSON object.
    """
    state_file = STATE_DIR / f"custodian_tokens_{session_id}.json"
    try:
        data = json.loads(state_file.read_text())
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError as well as the
        # UnicodeDecodeError raised by read_text() on undecodable bytes.
        return 0, 0, 0
    if not isinstance(data, dict):
        # Valid JSON that is not an object (e.g. `null` or `[]`) has no
        # totals to offer; treat it like a missing state file.
        return 0, 0, 0
    return data.get("total_in", 0), data.get("total_out", 0), data.get("cached_in", 0)
def save_state(session_id: str, total_in: int, total_out: int, cached_in: int) -> None:
    """Persist the current transcript totals for a session.

    Creates the state directory if needed and overwrites the session's state
    file with a JSON object holding ``total_in``, ``total_out`` and
    ``cached_in``. Filesystem errors propagate to the caller.
    """
    STATE_DIR.mkdir(parents=True, exist_ok=True)
    snapshot = {"total_in": total_in, "total_out": total_out, "cached_in": cached_in}
    target = STATE_DIR / f"custodian_tokens_{session_id}.json"
    target.write_text(json.dumps(snapshot))
def api_get(path: str):
@@ -75,51 +96,89 @@ def api_patch(path: str, data: dict):
return json.loads(r.read())
def extract_done_task(payload: dict) -> tuple[str | None, dict]:
    """Extract the completed task from a hook payload, if it describes one.

    Args:
        payload: Decoded hook payload. Only the ``tool_name`` and
            ``tool_input`` keys are read.

    Returns:
        ``(task_id, tool_input)`` when the tool name contains one of
        SUPPORTED_TOOL_HINTS and ``tool_input["status"]`` is ``"done"``. The
        task id is the first truthy value among ``task_id``, ``id`` and
        ``taskId`` and may be ``None`` if none is set. In every other case,
        including a null or non-string tool name and a tool input that is not
        a JSON object, the result is ``(None, {})``.
    """
    # A key that is present but null must not reach the substring test below.
    tool_name = payload.get("tool_name") or ""
    if not isinstance(tool_name, str):
        return None, {}
    if not any(hint in tool_name for hint in SUPPORTED_TOOL_HINTS):
        return None, {}

    tool_input = payload.get("tool_input")
    if not isinstance(tool_input, dict):
        # Covers a missing/null tool_input as well as unexpected scalar or
        # list values, which have no status to inspect.
        return None, {}
    if tool_input.get("status") != "done":
        return None, {}

    task_id = (
        tool_input.get("task_id")
        or tool_input.get("id")
        or tool_input.get("taskId")
    )
    return task_id, tool_input
def main() -> None:
try:
payload = json.loads(sys.stdin.read())
except json.JSONDecodeError:
return
tool_name = payload.get("tool_name", "")
if "update_task_status" not in tool_name:
return
tool_input = payload.get("tool_input", {})
if tool_input.get("status") != "done":
return
task_id = tool_input.get("task_id")
task_id, tool_input = extract_done_task(payload)
if not task_id:
write_health({"status": "skipped", "reason": "not_done_task_completion", "tool_name": payload.get("tool_name")})
return
transcript_path = payload.get("transcript_path", "")
session_id = payload.get("session_id", "unknown")
# Compute token delta for this task
current_in, current_out = read_transcript_totals(transcript_path)
last_in, last_out = load_state(session_id)
current_in, current_out, current_cached = read_transcript_totals(transcript_path)
last_in, last_out, last_cached = load_state(session_id)
delta_in = max(0, current_in - last_in)
delta_out = max(0, current_out - last_out)
save_state(session_id, current_in, current_out)
delta_cached = max(0, current_cached - last_cached)
save_state(session_id, current_in, current_out, current_cached)
if delta_in == 0 and delta_out == 0:
return # Nothing measurable — leave heuristic in place
if delta_in == 0 and delta_out == 0 and delta_cached == 0:
write_health({
"status": "skipped",
"reason": "zero_delta",
"session_id": session_id,
"task_id": task_id,
"source_path": transcript_path,
})
return
# Find the most recent heuristic event for this task and replace it
try:
events = api_get(f"/token-events/?task_id={task_id}&note=heuristic&limit=5")
except (urllib.error.URLError, OSError):
write_health({"status": "skipped", "reason": "api_offline", "session_id": session_id, "task_id": task_id})
return # API offline — leave heuristic as-is
if not events:
write_health({"status": "skipped", "reason": "no_fallback_event", "session_id": session_id, "task_id": task_id})
return
event_id = events[0]["id"]
model = tool_input.get("model")
agent = tool_input.get("agent")
patch_body: dict = {"tokens_in": delta_in, "tokens_out": delta_out, "note": "measured"}
patch_body: dict = {
"tokens_in": delta_in,
"tokens_out": delta_out,
"note": "measured",
"measurement_kind": "measured",
"source_provider": "claude_transcript",
"source_id": f"claude:{session_id}:task:{task_id}",
"source_path": transcript_path or None,
"parser_version": PARSER_VERSION,
"confidence": 1.0,
"cached_input_tokens": delta_cached,
"raw_total_tokens": delta_in + delta_out + delta_cached,
"raw_metadata": {
"hook": "post_tool_use",
"tool_name": payload.get("tool_name"),
"state_dir": str(STATE_DIR),
},
}
if model:
patch_body["model"] = model
if agent:
@@ -128,7 +187,19 @@ def main() -> None:
try:
api_patch(f"/token-events/{event_id}", patch_body)
except (urllib.error.URLError, OSError):
pass
write_health({"status": "skipped", "reason": "patch_failed", "session_id": session_id, "task_id": task_id})
return
write_health({
"status": "patched",
"session_id": session_id,
"task_id": task_id,
"event_id": event_id,
"tokens_in": delta_in,
"tokens_out": delta_out,
"cached_input_tokens": delta_cached,
"source_path": transcript_path,
})
if __name__ == "__main__":