generated from coulomb/repo-seed
feat(token-events): auto-capture real token counts via PostToolUse hook
- Add PATCH /token-events/{id} endpoint to correct heuristic events
- Add `note` filter to GET /token-events/ list
- Add TokenEventPatch schema
- Add task_token_hook.py: PostToolUse hook that reads the Claude Code
session transcript, computes per-task token delta, and replaces the
heuristic token event with real measured counts (note="measured")
- Register hook in ~/.claude/settings.json on mcp__state-hub__update_task_status
Covers both interactive sessions and ralph-workplan loops
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -10,7 +10,7 @@ from api.models.managed_repo import ManagedRepo
|
|||||||
from api.models.task import Task
|
from api.models.task import Task
|
||||||
from api.models.token_event import TokenEvent
|
from api.models.token_event import TokenEvent
|
||||||
from api.models.workstream import Workstream
|
from api.models.workstream import Workstream
|
||||||
from api.schemas.token_event import RepoTokenSummary, TokenEventCreate, TokenEventRead, TokenSummary
|
from api.schemas.token_event import RepoTokenSummary, TokenEventCreate, TokenEventPatch, TokenEventRead, TokenSummary
|
||||||
|
|
||||||
router = APIRouter(prefix="/token-events", tags=["token-events"])
|
router = APIRouter(prefix="/token-events", tags=["token-events"])
|
||||||
|
|
||||||
@@ -166,6 +166,22 @@ async def get_tokens_by_repo(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/{event_id}", response_model=TokenEventRead)
|
||||||
|
async def patch_token_event(
|
||||||
|
event_id: uuid.UUID,
|
||||||
|
body: TokenEventPatch,
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
|
) -> TokenEvent:
|
||||||
|
event = await session.get(TokenEvent, event_id)
|
||||||
|
if event is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Token event not found")
|
||||||
|
for field, value in body.model_dump(exclude_none=True).items():
|
||||||
|
setattr(event, field, value)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(event)
|
||||||
|
return event
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{event_id}", response_model=TokenEventRead)
|
@router.get("/{event_id}", response_model=TokenEventRead)
|
||||||
async def get_token_event(
|
async def get_token_event(
|
||||||
event_id: uuid.UUID,
|
event_id: uuid.UUID,
|
||||||
@@ -186,6 +202,7 @@ async def list_token_events(
|
|||||||
ref_id: str | None = None,
|
ref_id: str | None = None,
|
||||||
model: str | None = None,
|
model: str | None = None,
|
||||||
agent: str | None = None,
|
agent: str | None = None,
|
||||||
|
note: str | None = None,
|
||||||
limit: int = Query(100, le=1000),
|
limit: int = Query(100, le=1000),
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
) -> list[TokenEvent]:
|
) -> list[TokenEvent]:
|
||||||
@@ -204,6 +221,8 @@ async def list_token_events(
|
|||||||
q = q.where(TokenEvent.model == model)
|
q = q.where(TokenEvent.model == model)
|
||||||
if agent:
|
if agent:
|
||||||
q = q.where(TokenEvent.agent == agent)
|
q = q.where(TokenEvent.agent == agent)
|
||||||
|
if note:
|
||||||
|
q = q.where(TokenEvent.note == note)
|
||||||
q = q.order_by(TokenEvent.created_at.desc()).limit(limit)
|
q = q.order_by(TokenEvent.created_at.desc()).limit(limit)
|
||||||
result = await session.execute(q)
|
result = await session.execute(q)
|
||||||
return list(result.scalars().all())
|
return list(result.scalars().all())
|
||||||
|
|||||||
@@ -52,6 +52,14 @@ class TokenSummary(BaseModel):
|
|||||||
by_agent: dict[str, int]
|
by_agent: dict[str, int]
|
||||||
|
|
||||||
|
|
||||||
|
class TokenEventPatch(BaseModel):
|
||||||
|
tokens_in: int | None = None
|
||||||
|
tokens_out: int | None = None
|
||||||
|
note: str | None = None
|
||||||
|
model: str | None = None
|
||||||
|
agent: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class RepoTokenSummary(BaseModel):
|
class RepoTokenSummary(BaseModel):
|
||||||
repo_id: uuid.UUID
|
repo_id: uuid.UUID
|
||||||
repo_slug: str
|
repo_slug: str
|
||||||
|
|||||||
135
scripts/task_token_hook.py
Executable file
135
scripts/task_token_hook.py
Executable file
@@ -0,0 +1,135 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""PostToolUse hook: replace heuristic token events with real transcript-derived counts.
|
||||||
|
|
||||||
|
Fires after mcp__state-hub__update_task_status when status=done.
|
||||||
|
Reads the Claude Code session transcript to compute the token delta since the
|
||||||
|
previous task completion, then PATCHes the heuristic event with real counts.
|
||||||
|
|
||||||
|
State is persisted per session in /tmp/custodian_tokens_<session_id>.json so
|
||||||
|
deltas are correctly scoped even when multiple tasks complete in one session.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import urllib.error
|
||||||
|
import urllib.request
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
API = os.environ.get("CUSTODIAN_API", "http://127.0.0.1:8000")
|
||||||
|
STATE_DIR = Path(os.environ.get("TMPDIR", "/tmp"))
|
||||||
|
|
||||||
|
|
||||||
|
def read_transcript_totals(transcript_path: str) -> tuple[int, int]:
|
||||||
|
"""Sum all usage entries in the transcript JSONL up to the current point."""
|
||||||
|
total_in = total_out = 0
|
||||||
|
try:
|
||||||
|
with open(transcript_path) as f:
|
||||||
|
for line in f:
|
||||||
|
try:
|
||||||
|
entry = json.loads(line)
|
||||||
|
usage = entry.get("message", {}).get("usage", {})
|
||||||
|
if usage:
|
||||||
|
# Count all input token variants (direct + cache creation + cache read)
|
||||||
|
total_in += (
|
||||||
|
usage.get("input_tokens", 0)
|
||||||
|
+ usage.get("cache_creation_input_tokens", 0)
|
||||||
|
+ usage.get("cache_read_input_tokens", 0)
|
||||||
|
)
|
||||||
|
total_out += usage.get("output_tokens", 0)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
continue
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
return total_in, total_out
|
||||||
|
|
||||||
|
|
||||||
|
def load_state(session_id: str) -> tuple[int, int]:
|
||||||
|
state_file = STATE_DIR / f"custodian_tokens_{session_id}.json"
|
||||||
|
try:
|
||||||
|
data = json.loads(state_file.read_text())
|
||||||
|
return data.get("total_in", 0), data.get("total_out", 0)
|
||||||
|
except (OSError, json.JSONDecodeError):
|
||||||
|
return 0, 0
|
||||||
|
|
||||||
|
|
||||||
|
def save_state(session_id: str, total_in: int, total_out: int) -> None:
|
||||||
|
state_file = STATE_DIR / f"custodian_tokens_{session_id}.json"
|
||||||
|
state_file.write_text(json.dumps({"total_in": total_in, "total_out": total_out}))
|
||||||
|
|
||||||
|
|
||||||
|
def api_get(path: str):
|
||||||
|
req = urllib.request.Request(f"{API}{path}")
|
||||||
|
with urllib.request.urlopen(req, timeout=5) as r:
|
||||||
|
return json.loads(r.read())
|
||||||
|
|
||||||
|
|
||||||
|
def api_patch(path: str, data: dict):
|
||||||
|
body = json.dumps(data).encode()
|
||||||
|
req = urllib.request.Request(
|
||||||
|
f"{API}{path}",
|
||||||
|
data=body,
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
method="PATCH",
|
||||||
|
)
|
||||||
|
with urllib.request.urlopen(req, timeout=5) as r:
|
||||||
|
return json.loads(r.read())
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
try:
|
||||||
|
payload = json.loads(sys.stdin.read())
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return
|
||||||
|
|
||||||
|
tool_name = payload.get("tool_name", "")
|
||||||
|
if "update_task_status" not in tool_name:
|
||||||
|
return
|
||||||
|
|
||||||
|
tool_input = payload.get("tool_input", {})
|
||||||
|
if tool_input.get("status") != "done":
|
||||||
|
return
|
||||||
|
|
||||||
|
task_id = tool_input.get("task_id")
|
||||||
|
if not task_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
transcript_path = payload.get("transcript_path", "")
|
||||||
|
session_id = payload.get("session_id", "unknown")
|
||||||
|
|
||||||
|
# Compute token delta for this task
|
||||||
|
current_in, current_out = read_transcript_totals(transcript_path)
|
||||||
|
last_in, last_out = load_state(session_id)
|
||||||
|
delta_in = max(0, current_in - last_in)
|
||||||
|
delta_out = max(0, current_out - last_out)
|
||||||
|
save_state(session_id, current_in, current_out)
|
||||||
|
|
||||||
|
if delta_in == 0 and delta_out == 0:
|
||||||
|
return # Nothing measurable — leave heuristic in place
|
||||||
|
|
||||||
|
# Find the most recent heuristic event for this task and replace it
|
||||||
|
try:
|
||||||
|
events = api_get(f"/token-events/?task_id={task_id}¬e=heuristic&limit=5")
|
||||||
|
except (urllib.error.URLError, OSError):
|
||||||
|
return # API offline — leave heuristic as-is
|
||||||
|
|
||||||
|
if not events:
|
||||||
|
return
|
||||||
|
|
||||||
|
event_id = events[0]["id"]
|
||||||
|
model = tool_input.get("model")
|
||||||
|
agent = tool_input.get("agent")
|
||||||
|
|
||||||
|
patch_body: dict = {"tokens_in": delta_in, "tokens_out": delta_out, "note": "measured"}
|
||||||
|
if model:
|
||||||
|
patch_body["model"] = model
|
||||||
|
if agent:
|
||||||
|
patch_body["agent"] = agent
|
||||||
|
|
||||||
|
try:
|
||||||
|
api_patch(f"/token-events/{event_id}", patch_body)
|
||||||
|
except (urllib.error.URLError, OSError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user