Files
state-hub/scripts/task_token_hook.py

207 lines
6.9 KiB
Python
Executable File

#!/usr/bin/env python3
"""PostToolUse hook: replace heuristic token events with real transcript-derived counts.
Fires after supported task completion tools when status=done.
Reads the Claude Code session transcript to compute the token delta since the
previous task completion, then PATCHes the heuristic event with real counts.
State is persisted per session in a durable cache directory so deltas survive
restarts and multiple task completions in one session.
"""
import json
import os
import sys
import urllib.error
import urllib.request
from datetime import datetime, timezone
from pathlib import Path
# Base URL that api_get()/api_patch() prefix to every request path.
# Override with the CUSTODIAN_API environment variable.
API = os.getenv("CUSTODIAN_API", "http://127.0.0.1:8000")

# Directory holding the per-session token baselines and the health log.
# Override with the CUSTODIAN_TOKEN_STATE_DIR environment variable.
STATE_DIR = Path(
    os.getenv(
        "CUSTODIAN_TOKEN_STATE_DIR",
        Path.home() / ".cache" / "state-hub" / "token-hooks",
    )
)

# JSONL file that write_health() appends to.
HEALTH_LOG = STATE_DIR.joinpath("hook-health.jsonl")

# Sent as "parser_version" with every patched token event.
PARSER_VERSION = "claude-transcript-delta-v1"

# Substrings looked for in the hook payload's tool name. Matching is by
# substring, so "task" alone already covers the two longer hints.
SUPPORTED_TOOL_HINTS = ("update_task_status", "tasks", "task")
def utc_now() -> str:
    """Return the current time in UTC as a timezone-aware ISO 8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
def write_health(event: dict) -> None:
    """Append one timestamped record to the hook health log (best effort).

    The record is ``event`` with a ``ts`` key holding the current UTC time
    added (a ``ts`` key already present in ``event`` wins), written as a
    single JSON line with sorted keys. ``STATE_DIR`` is created on demand.
    Filesystem errors are swallowed so that health logging can never make
    the hook itself fail.

    Args:
        event: JSON-serialisable fields describing the hook outcome.
    """
    try:
        STATE_DIR.mkdir(parents=True, exist_ok=True)
        with HEALTH_LOG.open("a", encoding="utf-8") as log:
            record = {"ts": utc_now(), **event}
            serialised = json.dumps(record, sort_keys=True)
            log.write(serialised + "\n")
    except OSError:
        # Deliberately ignored: the health log is diagnostic only.
        return
def read_transcript_totals(transcript_path: str) -> tuple[int, int, int]:
    """Sum every usage record found in a transcript JSONL file.

    Each line is parsed as JSON; when it is an object whose ``message.usage``
    is itself an object, its ``input_tokens``, ``output_tokens``,
    ``cache_creation_input_tokens`` and ``cache_read_input_tokens`` counters
    are added to the running totals. Missing or null counters count as 0.

    A line is skipped as a whole (never partially counted) when it is not
    valid JSON, does not have the ``{"message": {"usage": {...}}}`` shape, or
    carries a counter that is not an integer.

    Args:
        transcript_path: Path of the transcript file to scan.

    Returns:
        ``(total_in, total_out, cached_in)`` where ``cached_in`` is the sum
        of the cache-creation and cache-read counters. All zeros when the
        file cannot be opened or read.
    """
    counter_keys = (
        "input_tokens",
        "output_tokens",
        "cache_creation_input_tokens",
        "cache_read_input_tokens",
    )
    total_in = total_out = cached_in = 0
    try:
        # Explicit UTF-8 keeps the result independent of the platform's
        # default encoding; undecodable bytes are replaced so that one bad
        # line cannot abort the whole scan.
        with open(transcript_path, encoding="utf-8", errors="replace") as f:
            for line in f:
                try:
                    entry = json.loads(line)
                except ValueError:
                    continue
                message = entry.get("message") if isinstance(entry, dict) else None
                usage = message.get("usage") if isinstance(message, dict) else None
                if not isinstance(usage, dict):
                    continue
                counts = [usage.get(key) or 0 for key in counter_keys]
                if not all(isinstance(count, int) for count in counts):
                    continue
                # Validate first, add afterwards: a malformed counter must
                # not leave this line half-applied to the totals.
                line_in, line_out, cache_created, cache_read = counts
                total_in += line_in
                total_out += line_out
                cached_in += cache_created + cache_read
    except (OSError, TypeError, ValueError):
        # OSError: missing or unreadable file. TypeError/ValueError: the path
        # itself is unusable (e.g. None, or a string with an embedded NUL).
        pass
    return total_in, total_out, cached_in
def load_state(session_id: str) -> tuple[int, int, int]:
    """Load the token baseline previously saved for a session.

    Reads ``STATE_DIR / custodian_tokens_<session_id>.json`` and returns the
    stored ``total_in``, ``total_out`` and ``cached_in`` values.

    Args:
        session_id: Session identifier used to build the state file name.

    Returns:
        ``(total_in, total_out, cached_in)``. The whole tuple is zero when the
        file is missing, unreadable, not valid JSON, or not a JSON object; an
        individual field is zero when it is absent or not an integer, so the
        caller can always subtract the result from fresh totals.
    """
    state_file = STATE_DIR / f"custodian_tokens_{session_id}.json"
    try:
        data = json.loads(state_file.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        return 0, 0, 0
    if not isinstance(data, dict):
        return 0, 0, 0

    def stored(key: str) -> int:
        value = data.get(key, 0)
        return value if isinstance(value, int) else 0

    return stored("total_in"), stored("total_out"), stored("cached_in")
def save_state(session_id: str, total_in: int, total_out: int, cached_in: int) -> None:
    """Persist the token baseline for a session.

    Writes ``{"total_in", "total_out", "cached_in"}`` as JSON to
    ``STATE_DIR / custodian_tokens_<session_id>.json``, creating ``STATE_DIR``
    if needed.

    The data is written to a temporary sibling file and then renamed over the
    target, so a reader never sees a half-written file. A truncated state
    file would be read back as "no baseline" and make the next delta re-count
    the entire transcript.

    Args:
        session_id: Session identifier used to build the state file name.
        total_in: Cumulative input tokens to store.
        total_out: Cumulative output tokens to store.
        cached_in: Cumulative cached input tokens to store.

    Raises:
        OSError: If the directory or file cannot be written.
    """
    STATE_DIR.mkdir(parents=True, exist_ok=True)
    state_file = STATE_DIR / f"custodian_tokens_{session_id}.json"
    serialised = json.dumps({"total_in": total_in, "total_out": total_out, "cached_in": cached_in})
    # The PID suffix keeps concurrent hook processes from sharing a temp file.
    tmp_file = state_file.with_name(f"{state_file.name}.{os.getpid()}.tmp")
    try:
        tmp_file.write_text(serialised, encoding="utf-8")
        tmp_file.replace(state_file)
    except OSError:
        # Do not leave a stray temp file behind; the error still propagates.
        tmp_file.unlink(missing_ok=True)
        raise
def api_get(path: str):
    """Issue a GET to ``API`` + ``path`` and return the decoded JSON body.

    The request uses a 5-second timeout. Errors raised by ``urllib`` or by
    JSON decoding are not handled here and propagate to the caller.
    """
    request = urllib.request.Request(f"{API}{path}")
    with urllib.request.urlopen(request, timeout=5) as response:
        raw_body = response.read()
    return json.loads(raw_body)
def api_patch(path: str, data: dict):
    """Issue a JSON PATCH to ``API`` + ``path`` and return the decoded reply.

    ``data`` is serialised with ``json.dumps`` and sent with a
    ``Content-Type: application/json`` header and a 5-second timeout. Errors
    raised by ``urllib`` or by JSON decoding propagate to the caller.
    """
    encoded = json.dumps(data).encode()
    request = urllib.request.Request(
        url=f"{API}{path}",
        data=encoded,
        headers={"Content-Type": "application/json"},
        method="PATCH",
    )
    with urllib.request.urlopen(request, timeout=5) as response:
        raw_body = response.read()
    return json.loads(raw_body)
def extract_done_task(payload: dict) -> tuple[str | None, dict]:
tool_name = payload.get("tool_name", "")
if not any(hint in tool_name for hint in SUPPORTED_TOOL_HINTS):
return None, {}
tool_input = payload.get("tool_input", {}) or {}
status = tool_input.get("status")
if status != "done":
return None, {}
task_id = (
tool_input.get("task_id")
or tool_input.get("id")
or tool_input.get("taskId")
)
return task_id, tool_input
def main() -> None:
    """Hook entry point: replace a task's heuristic token event with measured counts.

    Reads the hook payload (JSON) from stdin. When it describes a supported
    task tool call with ``status == "done"``:

    1. The transcript totals are read and compared with the baseline saved
       for the session, giving the token delta attributed to this task.
    2. The baseline is advanced to the current totals.
    3. The first heuristic token event the API returns for the task is
       PATCHed with the delta.

    Every outcome except an unusable stdin payload is recorded through
    ``write_health``. API and filesystem failures are logged and end the run
    quietly instead of raising.
    """
    # Local import: only this function builds URLs.
    from urllib.parse import quote, urlencode

    try:
        payload = json.loads(sys.stdin.read())
    except json.JSONDecodeError:
        return
    if not isinstance(payload, dict):
        # Valid JSON but not an object: there is nothing to inspect.
        return

    task_id, tool_input = extract_done_task(payload)
    if not task_id:
        write_health({"status": "skipped", "reason": "not_done_task_completion", "tool_name": payload.get("tool_name")})
        return

    # `or` (rather than a .get default) also covers keys that are present
    # but null in the payload.
    transcript_path = payload.get("transcript_path") or ""
    session_id = payload.get("session_id") or "unknown"

    # Compute token delta for this task
    current_in, current_out, current_cached = read_transcript_totals(transcript_path)
    last_in, last_out, last_cached = load_state(session_id)
    delta_in = max(0, current_in - last_in)
    delta_out = max(0, current_out - last_out)
    delta_cached = max(0, current_cached - last_cached)

    # All-zero totals are what read_transcript_totals returns for a missing
    # or unreadable transcript. Persisting them would wipe the baseline and
    # make the next completion re-count the whole session, so the baseline
    # is only advanced when something was actually read.
    if current_in or current_out or current_cached:
        try:
            save_state(session_id, current_in, current_out, current_cached)
        except OSError:
            write_health({"status": "skipped", "reason": "state_write_failed", "session_id": session_id, "task_id": task_id})
            return

    if delta_in == 0 and delta_out == 0 and delta_cached == 0:
        write_health({
            "status": "skipped",
            "reason": "zero_delta",
            "session_id": session_id,
            "task_id": task_id,
            "source_path": transcript_path,
        })
        return

    # Find the heuristic event for this task and replace it. The first
    # returned event is used — assumes the API lists newest first; TODO confirm.
    # urlencode keeps a task id containing '&', '#', spaces, etc. from
    # corrupting the query string.
    query = urlencode({"task_id": task_id, "note": "heuristic", "limit": 5})
    try:
        events = api_get(f"/token-events/?{query}")
    except (urllib.error.URLError, OSError):
        write_health({"status": "skipped", "reason": "api_offline", "session_id": session_id, "task_id": task_id})
        return  # API offline — leave heuristic as-is
    except ValueError:
        # The API answered, but not with parseable JSON (or the URL was invalid).
        write_health({"status": "skipped", "reason": "api_bad_response", "session_id": session_id, "task_id": task_id})
        return

    first_event = events[0] if isinstance(events, list) and events else None
    event_id = first_event.get("id") if isinstance(first_event, dict) else None
    if event_id is None:
        write_health({"status": "skipped", "reason": "no_fallback_event", "session_id": session_id, "task_id": task_id})
        return

    model = tool_input.get("model")
    agent = tool_input.get("agent")
    patch_body: dict = {
        "tokens_in": delta_in,
        "tokens_out": delta_out,
        "note": "measured",
        "measurement_kind": "measured",
        "source_provider": "claude_transcript",
        "source_id": f"claude:{session_id}:task:{task_id}",
        "source_path": transcript_path or None,
        "parser_version": PARSER_VERSION,
        "confidence": 1.0,
        "cached_input_tokens": delta_cached,
        "raw_total_tokens": delta_in + delta_out + delta_cached,
        "raw_metadata": {
            "hook": "post_tool_use",
            "tool_name": payload.get("tool_name"),
            "state_dir": str(STATE_DIR),
        },
    }
    if model:
        patch_body["model"] = model
    if agent:
        patch_body["agent"] = agent

    try:
        api_patch(f"/token-events/{quote(str(event_id), safe='')}", patch_body)
    except (urllib.error.URLError, OSError):
        write_health({"status": "skipped", "reason": "patch_failed", "session_id": session_id, "task_id": task_id})
        return
    except ValueError:
        # urlopen raises HTTPError for non-2xx replies, so reaching the JSON
        # decode means the server accepted the PATCH; only its reply body
        # (e.g. an empty 204) was not parseable. Treat it as patched.
        pass

    write_health({
        "status": "patched",
        "session_id": session_id,
        "task_id": task_id,
        "event_id": event_id,
        "tokens_in": delta_in,
        "tokens_out": delta_out,
        "cached_input_tokens": delta_cached,
        "source_path": transcript_path,
    })


if __name__ == "__main__":
    main()