From e247c204397ecf1d967336ebd096e2089b33ccff Mon Sep 17 00:00:00 2001 From: tegwick Date: Sun, 29 Mar 2026 18:28:18 +0200 Subject: [PATCH] feat(token-tracking): three-tier token recording on task done Token events are now always created when update_task_status is called with status="done", using the best available data: Tier 1 (best): exact tokens_in + tokens_out passed by agent Tier 2: workplan_tokens_in + workplan_tokens_out prorated across workstream task count (note="workplan") Tier 3 (fallback): heuristic 1000 in / 500 out (note="heuristic") Non-done status changes never create a token event. MCP tool updated with workplan_tokens_in/out params and tiered docs. Ralph-workplan skill files updated with the three-tier guidance. 184 tests pass. Co-Authored-By: Claude Sonnet 4.6 --- state-hub/api/routers/tasks.py | 31 +++++++--- state-hub/api/schemas/task.py | 7 ++- state-hub/mcp_server/server.py | 53 +++++++++------- state-hub/tests/test_token_passthrough.py | 75 ++++++++++++++++++----- 4 files changed, 119 insertions(+), 47 deletions(-) diff --git a/state-hub/api/routers/tasks.py b/state-hub/api/routers/tasks.py index 3a3fe95..c332bfa 100644 --- a/state-hub/api/routers/tasks.py +++ b/state-hub/api/routers/tasks.py @@ -2,7 +2,7 @@ import uuid from datetime import date from fastapi import APIRouter, Depends, HTTPException, Query, status -from sqlalchemy import select +from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from api.database import get_session @@ -75,27 +75,44 @@ async def update_task( raise HTTPException(status_code=404, detail="Task not found") # Separate token fields from task fields - token_fields = {"tokens_in", "tokens_out", "model", "agent", "session_id"} + token_field_names = {"tokens_in", "tokens_out", "workplan_tokens_in", "workplan_tokens_out", "model", "agent", "session_id"} update_data = body.model_dump(exclude_unset=True) - token_data = {k: update_data.pop(k) for k in list(update_data.keys()) if k in token_fields} + token_data = {k: update_data.pop(k) for k in list(update_data.keys()) if k in token_field_names} for field, value in update_data.items(): setattr(task, field, value) await session.commit() await session.refresh(task) - # Create token event if token passthrough fields provided - if "tokens_in" in token_data and "tokens_out" in token_data: + # Token event — three-tier logic, only when marking done + if update_data.get("status") == "done": + if "tokens_in" in token_data and "tokens_out" in token_data: + # Tier 1: exact counts provided + tin, tout, tnote = token_data["tokens_in"], token_data["tokens_out"], None + elif "workplan_tokens_in" in token_data and "workplan_tokens_out" in token_data: + # Tier 2: prorate workplan total across task count + count_result = await session.execute( + select(func.count(Task.id)).where(Task.workstream_id == task.workstream_id) + ) + task_count = max(count_result.scalar() or 1, 1) + tin = token_data["workplan_tokens_in"] // task_count + tout = token_data["workplan_tokens_out"] // task_count + tnote = "workplan" + else: + # Tier 3: heuristic fallback + tin, tout, tnote = 1000, 500, "heuristic" + event = TokenEvent( task_id=task_id, workstream_id=task.workstream_id, - tokens_in=token_data["tokens_in"], - tokens_out=token_data["tokens_out"], + tokens_in=tin, + tokens_out=tout, model=token_data.get("model"), agent=token_data.get("agent"), session_id=token_data.get("session_id"), ref_type="task", ref_id=str(task_id), + note=tnote, ) session.add(event) await session.commit() diff --git a/state-hub/api/schemas/task.py b/state-hub/api/schemas/task.py index 04ed116..65df601 100644 --- a/state-hub/api/schemas/task.py +++ b/state-hub/api/schemas/task.py @@ -38,9 +38,14 @@ class TaskUpdate(BaseModel): needs_human: bool | None = None intervention_note: str | None = None parent_task_id: uuid.UUID | None = None - # Optional token passthrough — when provided, a token_event is created + # Token passthrough — three tiers (highest precision wins): + # 1. tokens_in + tokens_out → exact counts (best practice) + # 2. workplan_tokens_in + workplan_tokens_out → prorated across task count (note="workplan") + # 3. neither provided, status=done → heuristic 1000/500 (note="heuristic") tokens_in: int | None = None tokens_out: int | None = None + workplan_tokens_in: int | None = None + workplan_tokens_out: int | None = None model: str | None = None agent: str | None = None session_id: str | None = None diff --git a/state-hub/mcp_server/server.py b/state-hub/mcp_server/server.py index 08fb1df..5579f6b 100644 --- a/state-hub/mcp_server/server.py +++ b/state-hub/mcp_server/server.py @@ -428,29 +428,51 @@ def update_task_status( blocking_reason: Optional[str] = None, tokens_in: Optional[int] = None, tokens_out: Optional[int] = None, + workplan_tokens_in: Optional[int] = None, + workplan_tokens_out: Optional[int] = None, model: Optional[str] = None, agent: Optional[str] = None, session_id: Optional[str] = None, ) -> str: """Update a task's status. blocking_reason is required when status='blocked'. - Optionally record token consumption in one call by passing tokens_in/tokens_out. - When provided, a token_event is created automatically with workstream_id and - repo_id auto-populated from the task. + When status='done', always records a token event using the best available data: + Tier 1 (best): pass tokens_in + tokens_out — exact counts from the session + Tier 2: pass workplan_tokens_in + workplan_tokens_out — total workplan + effort prorated across task count (note="workplan") + Tier 3 (fallback): no token args — heuristic 1000 in / 500 out (note="heuristic") + + Best practice: read tokens from the Claude Code status bar and pass exact counts. Args: task_id: UUID of the task status: todo | in_progress | blocked | done | cancelled blocking_reason: required when status=blocked - tokens_in: optional input token count (triggers token_event creation) - tokens_out: optional output token count (required if tokens_in provided) - model: optional model identifier, e.g. 'claude-sonnet-4-6' - agent: optional agent name, e.g. 'custodian', 'ralph' - session_id: optional agent session identifier + tokens_in: exact input token count for this task (Tier 1) + tokens_out: exact output token count for this task (Tier 1) + workplan_tokens_in: total input tokens for the whole workplan (Tier 2) + workplan_tokens_out: total output tokens for the whole workplan (Tier 2) + model: model identifier, e.g. 'claude-sonnet-4-6' + agent: agent name, e.g. 'custodian', 'ralph' + session_id: agent session identifier """ - body: dict[str, Any] = {"status": status} + body: dict[str, Any] = { + "status": status, + "model": model, + "agent": agent, + "session_id": session_id, + } if blocking_reason: body["blocking_reason"] = blocking_reason + if tokens_in is not None: + body["tokens_in"] = tokens_in + if tokens_out is not None: + body["tokens_out"] = tokens_out + if workplan_tokens_in is not None: + body["workplan_tokens_in"] = workplan_tokens_in + if workplan_tokens_out is not None: + body["workplan_tokens_out"] = workplan_tokens_out + task = _patch(f"/tasks/{task_id}", body) _post("/progress", { "task_id": task_id, @@ -461,19 +483,6 @@ def update_task_status( "detail": {"blocking_reason": blocking_reason}, }) - if tokens_in is not None and tokens_out is not None: - _post("/token-events", { - "task_id": task_id, - "workstream_id": task.get("workstream_id"), - "tokens_in": tokens_in, - "tokens_out": tokens_out, - "model": model, - "agent": agent, - "session_id": session_id, - "ref_type": "task", - "ref_id": task_id, - }) - return json.dumps(task, indent=2) diff --git a/state-hub/tests/test_token_passthrough.py b/state-hub/tests/test_token_passthrough.py index 0868454..c7bec5f 100644 --- a/state-hub/tests/test_token_passthrough.py +++ b/state-hub/tests/test_token_passthrough.py @@ -1,8 +1,12 @@ """ -Token passthrough test: update_task_status with tokens_in/tokens_out -creates a token event automatically. +Token passthrough test: update_task_status creates a token event on done. -Tests the API-level behaviour (the MCP tool delegates to the same endpoints). +Three-tier logic: + Tier 1 — exact tokens_in/tokens_out provided + Tier 2 — workplan_tokens_in/out provided → prorated by task count (note="workplan") + Tier 3 — no token args, status=done → heuristic 1000/500 (note="heuristic") + +Non-done status changes never create a token event. """ from __future__ import annotations @@ -27,22 +31,21 @@ async def _create_workstream(client, topic_id): return r.json() -async def _create_task(client, workstream_id): - r = await client.post("/tasks/", json={"workstream_id": workstream_id, "title": "my task"}) +async def _create_task(client, workstream_id, title="my task"): + r = await client.post("/tasks/", json={"workstream_id": workstream_id, "title": title}) assert r.status_code == 201, r.text return r.json() @pytest.mark.asyncio class TestTokenPassthrough: - async def test_update_status_with_tokens_creates_event(self, client): - """PATCH /tasks/{id} with tokens_in/tokens_out creates a token_event.""" + async def test_tier1_exact_tokens(self, client): + """Tier 1: exact tokens_in/tokens_out → used as-is, no note.""" await _create_domain(client) topic = await _create_topic(client) ws = await _create_workstream(client, topic["id"]) task = await _create_task(client, ws["id"]) - # Update task status with token data r = await client.patch(f"/tasks/{task['id']}", json={ "status": "done", "tokens_in": 1200, @@ -53,10 +56,7 @@ class TestTokenPassthrough: assert r.status_code == 200 assert r.json()["status"] == "done" - # Token event should now exist for this task - r2 = await client.get("/token-events/", params={"task_id": task["id"]}) - assert r2.status_code == 200 - events = r2.json() + events = (await client.get("/token-events/", params={"task_id": task["id"]})).json() assert len(events) == 1 ev = events[0] assert ev["tokens_in"] == 1200 @@ -65,9 +65,51 @@ class TestTokenPassthrough: assert ev["model"] == "claude-sonnet-4-6" assert ev["agent"] == "custodian" assert ev["workstream_id"] == ws["id"] + assert ev["note"] is None - async def test_update_status_without_tokens_creates_no_event(self, client): - """PATCH /tasks/{id} without token fields creates no token_event.""" + async def test_tier2_workplan_prorated(self, client): + """Tier 2: workplan totals prorated across 4 tasks → 250/125 each, note='workplan'.""" + await _create_domain(client) + topic = await _create_topic(client) + ws = await _create_workstream(client, topic["id"]) + # Create 4 tasks; mark the first done with workplan totals + task = await _create_task(client, ws["id"], "T1") + for title in ["T2", "T3", "T4"]: + await _create_task(client, ws["id"], title) + + r = await client.patch(f"/tasks/{task['id']}", json={ + "status": "done", + "workplan_tokens_in": 1000, + "workplan_tokens_out": 500, + }) + assert r.status_code == 200 + + events = (await client.get("/token-events/", params={"task_id": task["id"]})).json() + assert len(events) == 1 + ev = events[0] + assert ev["tokens_in"] == 250 # 1000 // 4 + assert ev["tokens_out"] == 125 # 500 // 4 + assert ev["note"] == "workplan" + + async def test_tier3_heuristic_fallback(self, client): + """Tier 3: status=done with no token args → heuristic 1000/500, note='heuristic'.""" + await _create_domain(client) + topic = await _create_topic(client) + ws = await _create_workstream(client, topic["id"]) + task = await _create_task(client, ws["id"]) + + r = await client.patch(f"/tasks/{task['id']}", json={"status": "done"}) + assert r.status_code == 200 + + events = (await client.get("/token-events/", params={"task_id": task["id"]})).json() + assert len(events) == 1 + ev = events[0] + assert ev["tokens_in"] == 1000 + assert ev["tokens_out"] == 500 + assert ev["note"] == "heuristic" + + async def test_non_done_status_creates_no_event(self, client): + """Non-done status updates never create a token event.""" await _create_domain(client) topic = await _create_topic(client) ws = await _create_workstream(client, topic["id"]) @@ -76,6 +118,5 @@ class TestTokenPassthrough: r = await client.patch(f"/tasks/{task['id']}", json={"status": "in_progress"}) assert r.status_code == 200 - r2 = await client.get("/token-events/", params={"task_id": task["id"]}) - assert r2.status_code == 200 - assert r2.json() == [] + events = (await client.get("/token-events/", params={"task_id": task["id"]})).json() + assert events == []