generated from coulomb/repo-seed
207 lines
6.9 KiB
Python
Executable File
207 lines
6.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""PostToolUse hook: replace heuristic token events with real transcript-derived counts.
|
|
|
|
Fires after supported task completion tools when status=done.
|
|
Reads the Claude Code session transcript to compute the token delta since the
|
|
previous task completion, then PATCHes the heuristic event with real counts.
|
|
|
|
State is persisted per session in a durable cache directory so deltas survive
|
|
restarts and multiple task completions in one session.
|
|
"""
|
|
import json
|
|
import os
|
|
import sys
|
|
import urllib.error
|
|
import urllib.request
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
# Base URL that request paths are appended to; override with CUSTODIAN_API.
API = os.environ.get("CUSTODIAN_API", "http://127.0.0.1:8000")

# Directory holding the per-session state files and the health log;
# override with CUSTODIAN_TOKEN_STATE_DIR.
STATE_DIR = Path(
    os.environ.get(
        "CUSTODIAN_TOKEN_STATE_DIR",
        Path.home().joinpath(".cache", "state-hub", "token-hooks"),
    )
)

# Append-only JSONL file that write_health() adds one record per run to.
HEALTH_LOG = STATE_DIR.joinpath("hook-health.jsonl")

# Sent as "parser_version" in the PATCH body so measured events can be traced
# back to this parser.
PARSER_VERSION = "claude-transcript-delta-v1"

# A tool is considered a task-completion tool when its name contains any of
# these substrings.
SUPPORTED_TOOL_HINTS = ("update_task_status", "tasks", "task")
|
|
|
|
|
|
def utc_now() -> str:
    """Return the current time in UTC as an ISO 8601 string (offset included)."""
    moment = datetime.now(tz=timezone.utc)
    return moment.isoformat()
|
|
|
|
|
|
def write_health(event: dict) -> None:
    """Append one timestamped record to the hook health log.

    The record is ``event`` plus a ``"ts"`` key holding the current UTC time
    (a ``"ts"`` key already present in ``event`` takes precedence), serialized
    as a single JSON line with sorted keys.

    Logging is best-effort: any ``OSError`` raised while creating the state
    directory or writing the file is swallowed so a broken log never stops
    the hook.
    """
    try:
        STATE_DIR.mkdir(parents=True, exist_ok=True)
        with HEALTH_LOG.open("a", encoding="utf-8") as log_file:
            record = {"ts": utc_now(), **event}
            serialized = json.dumps(record, sort_keys=True)
            log_file.write(serialized + "\n")
    except OSError:
        # Deliberately ignored: health logging must never fail the hook.
        pass
|
|
|
|
|
|
def _token_count(value: object) -> int:
    """Return *value* when it is a real integer, otherwise 0."""
    if isinstance(value, int) and not isinstance(value, bool):
        return value
    return 0


def read_transcript_totals(transcript_path: str) -> tuple[int, int, int]:
    """Sum all usage entries in the transcript JSONL up to the current point.

    Each line is parsed as JSON; for every object that carries a
    ``message.usage`` mapping the ``input_tokens``, ``output_tokens`` and
    cache token fields (``cache_creation_input_tokens`` +
    ``cache_read_input_tokens``) are accumulated.

    Lines that are not valid JSON, are not JSON objects, or have no usable
    ``message.usage`` mapping are skipped. Token fields that are missing or
    not integers count as 0, so one malformed field never discards the other
    counts of the same entry.

    Args:
        transcript_path: Path to the transcript JSONL file.

    Returns:
        ``(total_in, total_out, cached_in)``. All zeros when the path is
        empty or the file cannot be read.
    """
    total_in = total_out = cached_in = 0
    if not transcript_path:
        return total_in, total_out, cached_in

    try:
        # errors="replace" keeps a stray undecodable byte from aborting the
        # whole read; the affected line simply fails JSON parsing or is
        # counted with the fields that survived.
        with open(transcript_path, encoding="utf-8", errors="replace") as handle:
            for line in handle:
                try:
                    entry = json.loads(line)
                except ValueError:
                    # Blank or non-JSON line.
                    continue
                if not isinstance(entry, dict):
                    continue
                message = entry.get("message")
                if not isinstance(message, dict):
                    continue
                usage = message.get("usage")
                if not isinstance(usage, dict):
                    continue
                total_in += _token_count(usage.get("input_tokens"))
                cached_in += _token_count(
                    usage.get("cache_creation_input_tokens")
                ) + _token_count(usage.get("cache_read_input_tokens"))
                total_out += _token_count(usage.get("output_tokens"))
    except (OSError, ValueError):
        # Unreadable file (or an invalid path such as one with an embedded
        # NUL): report whatever was accumulated so far.
        pass
    return total_in, total_out, cached_in
|
|
|
|
|
|
def load_state(session_id: str) -> tuple[int, int, int]:
    """Load the token totals recorded at the previous run for this session.

    Reads ``custodian_tokens_<session_id>.json`` from ``STATE_DIR``.

    Args:
        session_id: Session identifier used to build the state file name.

    Returns:
        ``(total_in, total_out, cached_in)``. ``(0, 0, 0)`` when the file is
        missing, unreadable, not valid JSON, or not a JSON object. Individual
        values that are missing or not integers are returned as 0 so the
        caller can always do arithmetic on the result.
    """
    state_file = STATE_DIR / f"custodian_tokens_{session_id}.json"
    try:
        data = json.loads(state_file.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers both malformed JSON and undecodable bytes.
        return 0, 0, 0
    if not isinstance(data, dict):
        return 0, 0, 0

    def stored(key: str) -> int:
        value = data.get(key, 0)
        if isinstance(value, int) and not isinstance(value, bool):
            return value
        return 0

    return stored("total_in"), stored("total_out"), stored("cached_in")
|
|
|
|
|
|
def save_state(session_id: str, total_in: int, total_out: int, cached_in: int) -> None:
    """Persist the current token totals for this session.

    Writes ``custodian_tokens_<session_id>.json`` in ``STATE_DIR`` (created
    if needed). The JSON is written to a temporary sibling file and then
    moved over the state file with ``os.replace``, so a reader never sees a
    half-written file.

    Args:
        session_id: Session identifier used to build the state file name.
        total_in: Input token total to store.
        total_out: Output token total to store.
        cached_in: Cached input token total to store.

    Raises:
        OSError: If the directory or file cannot be written.
    """
    STATE_DIR.mkdir(parents=True, exist_ok=True)
    state_file = STATE_DIR / f"custodian_tokens_{session_id}.json"
    serialized = json.dumps({"total_in": total_in, "total_out": total_out, "cached_in": cached_in})

    # The PID in the name keeps two concurrent hook processes from sharing
    # one temporary file.
    tmp_file = state_file.with_name(f"{state_file.name}.{os.getpid()}.tmp")
    try:
        tmp_file.write_text(serialized, encoding="utf-8")
        os.replace(tmp_file, state_file)
    except OSError:
        # Do not leave a stray temporary file behind; then report the failure.
        try:
            tmp_file.unlink()
        except OSError:
            pass
        raise
|
|
|
|
|
|
def api_get(path: str):
    """GET ``API + path`` and return the response body parsed as JSON.

    The request uses a 5 second timeout. Errors from ``urllib`` and from
    JSON decoding propagate to the caller.
    """
    request = urllib.request.Request(f"{API}{path}")
    with urllib.request.urlopen(request, timeout=5) as response:
        raw_body = response.read()
        return json.loads(raw_body)
|
|
|
|
|
|
def api_patch(path: str, data: dict):
    """PATCH ``API + path`` with ``data`` as a JSON body; return the parsed reply.

    The request carries a ``Content-Type: application/json`` header and uses
    a 5 second timeout. Errors from ``urllib`` and from JSON decoding
    propagate to the caller.
    """
    encoded_body = json.dumps(data).encode()
    request = urllib.request.Request(
        f"{API}{path}",
        data=encoded_body,
        headers={"Content-Type": "application/json"},
        method="PATCH",
    )
    with urllib.request.urlopen(request, timeout=5) as response:
        raw_body = response.read()
        return json.loads(raw_body)
|
|
|
|
|
|
def extract_done_task(payload: dict) -> tuple[str | None, dict]:
    """Pull the task id out of a hook payload that marks a task as done.

    Returns ``(None, {})`` unless the payload's ``tool_name`` contains one of
    ``SUPPORTED_TOOL_HINTS`` and its ``tool_input`` has ``status == "done"``.
    Otherwise returns ``(task_id, tool_input)``, where ``task_id`` is the
    first truthy value among the ``task_id``, ``id`` and ``taskId`` keys of
    ``tool_input`` (or the value of ``taskId`` when none is truthy).
    """
    name = payload.get("tool_name", "")
    supported = False
    for hint in SUPPORTED_TOOL_HINTS:
        if hint in name:
            supported = True
            break
    if not supported:
        return None, {}

    # A missing or null tool_input is treated as an empty mapping.
    arguments = payload.get("tool_input", {}) or {}
    if arguments.get("status") != "done":
        return None, {}

    # Same semantics as `a or b or c`: stop at the first truthy value,
    # otherwise keep the last one looked up.
    found = None
    for key in ("task_id", "id", "taskId"):
        found = arguments.get(key)
        if found:
            break
    return found, arguments
|
|
|
|
|
|
def main() -> None:
    """Replace a task's heuristic token event with transcript-measured counts.

    Steps:
      1. Read the hook payload as JSON from stdin; exit silently when it is
         not valid JSON or not a JSON object.
      2. Ask ``extract_done_task`` whether this is a task completion with
         status "done"; if not, log a "skipped" health record and stop.
      3. Sum the transcript's token usage, subtract the totals saved at the
         previous run for this session, and save the new totals.
      4. If the delta is non-zero, fetch the heuristic token events for the
         task and PATCH the first one with the measured delta.

    Every outcome after step 1 is recorded through ``write_health``. API
    failures leave the heuristic event untouched.
    """
    # Imported here so this function does not rely on the file-level imports
    # providing urllib.parse.
    from urllib.parse import quote, urlencode

    try:
        payload = json.loads(sys.stdin.read())
    except json.JSONDecodeError:
        return
    if not isinstance(payload, dict):
        # Valid JSON but not an object: nothing to inspect.
        return

    task_id, tool_input = extract_done_task(payload)
    if not task_id:
        write_health({"status": "skipped", "reason": "not_done_task_completion", "tool_name": payload.get("tool_name")})
        return

    # `or` also covers keys that are present but null.
    transcript_path = payload.get("transcript_path") or ""
    session_id = payload.get("session_id") or "unknown"

    # Compute token delta for this task
    current_in, current_out, current_cached = read_transcript_totals(transcript_path)
    last_in, last_out, last_cached = load_state(session_id)
    # Clamp at zero so totals that went backwards never yield a negative delta.
    delta_in = max(0, current_in - last_in)
    delta_out = max(0, current_out - last_out)
    delta_cached = max(0, current_cached - last_cached)
    # State advances on every completion, even if the API calls below fail,
    # so the next task's delta starts from this point.
    save_state(session_id, current_in, current_out, current_cached)

    if delta_in == 0 and delta_out == 0 and delta_cached == 0:
        write_health({
            "status": "skipped",
            "reason": "zero_delta",
            "session_id": session_id,
            "task_id": task_id,
            "source_path": transcript_path,
        })
        return

    # Find the heuristic event(s) for this task. The query string is built
    # with urlencode so the separators and the task id are always valid.
    query = urlencode({"task_id": task_id, "note": "heuristic", "limit": 5})
    try:
        events = api_get(f"/token-events/?{query}")
    except (urllib.error.URLError, OSError, ValueError) as exc:
        # ValueError: the API answered but the body was not valid JSON.
        write_health({
            "status": "skipped",
            "reason": "api_offline",
            "session_id": session_id,
            "task_id": task_id,
            "error": type(exc).__name__,
        })
        return  # API offline — leave heuristic as-is

    if not events:
        write_health({"status": "skipped", "reason": "no_fallback_event", "session_id": session_id, "task_id": task_id})
        return

    # NOTE(review): assumes the endpoint returns a list of event objects with
    # the most recent first — confirm against the API.
    try:
        event_id = events[0]["id"]
    except (KeyError, IndexError, TypeError):
        write_health({"status": "skipped", "reason": "no_fallback_event", "session_id": session_id, "task_id": task_id})
        return

    model = tool_input.get("model")
    agent = tool_input.get("agent")

    patch_body: dict = {
        "tokens_in": delta_in,
        "tokens_out": delta_out,
        "note": "measured",
        "measurement_kind": "measured",
        "source_provider": "claude_transcript",
        "source_id": f"claude:{session_id}:task:{task_id}",
        "source_path": transcript_path or None,
        "parser_version": PARSER_VERSION,
        "confidence": 1.0,
        "cached_input_tokens": delta_cached,
        "raw_total_tokens": delta_in + delta_out + delta_cached,
        "raw_metadata": {
            "hook": "post_tool_use",
            "tool_name": payload.get("tool_name"),
            "state_dir": str(STATE_DIR),
        },
    }
    # model/agent are only sent when the tool input supplied them.
    if model:
        patch_body["model"] = model
    if agent:
        patch_body["agent"] = agent

    try:
        api_patch(f"/token-events/{quote(str(event_id), safe='')}", patch_body)
    except (urllib.error.URLError, OSError):
        write_health({"status": "skipped", "reason": "patch_failed", "session_id": session_id, "task_id": task_id})
        return

    write_health({
        "status": "patched",
        "session_id": session_id,
        "task_id": task_id,
        "event_id": event_id,
        "tokens_in": delta_in,
        "tokens_out": delta_out,
        "cached_input_tokens": delta_cached,
        "source_path": transcript_path,
    })
|
|
|
|
|
|
# Run the hook only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|