feat(token-tracking): three-tier token recording on task done

A token event is now always created when a task is marked done, i.e. when
the update_task_status MCP tool (or PATCH /tasks/{id} directly) is called
with status="done". The event uses the best available data:

  Tier 1 (best): exact tokens_in + tokens_out passed by agent
  Tier 2:        workplan_tokens_in + workplan_tokens_out split evenly
                 (integer division) across the number of tasks in the
                 task's workstream (note="workplan")
  Tier 3 (fallback): heuristic 1000 in / 500 out (note="heuristic")

Non-done status changes never create a token event, even when token
fields are passed.
MCP tool updated with workplan_tokens_in/out params and tiered docs.
Ralph-workplan skill files updated with the three-tier guidance.
184 tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-29 18:28:18 +02:00
parent 58e1bafce9
commit fdfd4365cd
4 changed files with 119 additions and 47 deletions

View File

@@ -2,7 +2,7 @@ import uuid
from datetime import date from datetime import date
from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session from api.database import get_session
@@ -75,27 +75,44 @@ async def update_task(
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")
# Separate token fields from task fields # Separate token fields from task fields
token_fields = {"tokens_in", "tokens_out", "model", "agent", "session_id"} token_field_names = {"tokens_in", "tokens_out", "workplan_tokens_in", "workplan_tokens_out", "model", "agent", "session_id"}
update_data = body.model_dump(exclude_unset=True) update_data = body.model_dump(exclude_unset=True)
token_data = {k: update_data.pop(k) for k in list(update_data.keys()) if k in token_fields} token_data = {k: update_data.pop(k) for k in list(update_data.keys()) if k in token_field_names}
for field, value in update_data.items(): for field, value in update_data.items():
setattr(task, field, value) setattr(task, field, value)
await session.commit() await session.commit()
await session.refresh(task) await session.refresh(task)
# Create token event if token passthrough fields provided # Token event — three-tier logic, only when marking done
if "tokens_in" in token_data and "tokens_out" in token_data: if update_data.get("status") == "done":
if "tokens_in" in token_data and "tokens_out" in token_data:
# Tier 1: exact counts provided
tin, tout, tnote = token_data["tokens_in"], token_data["tokens_out"], None
elif "workplan_tokens_in" in token_data and "workplan_tokens_out" in token_data:
# Tier 2: prorate workplan total across task count
count_result = await session.execute(
select(func.count(Task.id)).where(Task.workstream_id == task.workstream_id)
)
task_count = max(count_result.scalar() or 1, 1)
tin = token_data["workplan_tokens_in"] // task_count
tout = token_data["workplan_tokens_out"] // task_count
tnote = "workplan"
else:
# Tier 3: heuristic fallback
tin, tout, tnote = 1000, 500, "heuristic"
event = TokenEvent( event = TokenEvent(
task_id=task_id, task_id=task_id,
workstream_id=task.workstream_id, workstream_id=task.workstream_id,
tokens_in=token_data["tokens_in"], tokens_in=tin,
tokens_out=token_data["tokens_out"], tokens_out=tout,
model=token_data.get("model"), model=token_data.get("model"),
agent=token_data.get("agent"), agent=token_data.get("agent"),
session_id=token_data.get("session_id"), session_id=token_data.get("session_id"),
ref_type="task", ref_type="task",
ref_id=str(task_id), ref_id=str(task_id),
note=tnote,
) )
session.add(event) session.add(event)
await session.commit() await session.commit()

View File

@@ -38,9 +38,14 @@ class TaskUpdate(BaseModel):
needs_human: bool | None = None needs_human: bool | None = None
intervention_note: str | None = None intervention_note: str | None = None
parent_task_id: uuid.UUID | None = None parent_task_id: uuid.UUID | None = None
# Optional token passthrough — when provided, a token_event is created # Token passthrough — three tiers (highest precision wins):
# 1. tokens_in + tokens_out → exact counts (best practice)
# 2. workplan_tokens_in + workplan_tokens_out → prorated across task count (note="workplan")
# 3. neither provided, status=done → heuristic 1000/500 (note="heuristic")
tokens_in: int | None = None tokens_in: int | None = None
tokens_out: int | None = None tokens_out: int | None = None
workplan_tokens_in: int | None = None
workplan_tokens_out: int | None = None
model: str | None = None model: str | None = None
agent: str | None = None agent: str | None = None
session_id: str | None = None session_id: str | None = None

View File

@@ -428,29 +428,51 @@ def update_task_status(
blocking_reason: Optional[str] = None, blocking_reason: Optional[str] = None,
tokens_in: Optional[int] = None, tokens_in: Optional[int] = None,
tokens_out: Optional[int] = None, tokens_out: Optional[int] = None,
workplan_tokens_in: Optional[int] = None,
workplan_tokens_out: Optional[int] = None,
model: Optional[str] = None, model: Optional[str] = None,
agent: Optional[str] = None, agent: Optional[str] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
) -> str: ) -> str:
"""Update a task's status. blocking_reason is required when status='blocked'. """Update a task's status. blocking_reason is required when status='blocked'.
Optionally record token consumption in one call by passing tokens_in/tokens_out. When status='done', always records a token event using the best available data:
When provided, a token_event is created automatically with workstream_id and Tier 1 (best): pass tokens_in + tokens_out — exact counts from the session
repo_id auto-populated from the task. Tier 2: pass workplan_tokens_in + workplan_tokens_out — total workplan
effort prorated across task count (note="workplan")
Tier 3 (fallback): no token args — heuristic 1000 in / 500 out (note="heuristic")
Best practice: read tokens from the Claude Code status bar and pass exact counts.
Args: Args:
task_id: UUID of the task task_id: UUID of the task
status: todo | in_progress | blocked | done | cancelled status: todo | in_progress | blocked | done | cancelled
blocking_reason: required when status=blocked blocking_reason: required when status=blocked
tokens_in: optional input token count (triggers token_event creation) tokens_in: exact input token count for this task (Tier 1)
tokens_out: optional output token count (required if tokens_in provided) tokens_out: exact output token count for this task (Tier 1)
model: optional model identifier, e.g. 'claude-sonnet-4-6' workplan_tokens_in: total input tokens for the whole workplan (Tier 2)
agent: optional agent name, e.g. 'custodian', 'ralph' workplan_tokens_out: total output tokens for the whole workplan (Tier 2)
session_id: optional agent session identifier model: model identifier, e.g. 'claude-sonnet-4-6'
agent: agent name, e.g. 'custodian', 'ralph'
session_id: agent session identifier
""" """
body: dict[str, Any] = {"status": status} body: dict[str, Any] = {
"status": status,
"model": model,
"agent": agent,
"session_id": session_id,
}
if blocking_reason: if blocking_reason:
body["blocking_reason"] = blocking_reason body["blocking_reason"] = blocking_reason
if tokens_in is not None:
body["tokens_in"] = tokens_in
if tokens_out is not None:
body["tokens_out"] = tokens_out
if workplan_tokens_in is not None:
body["workplan_tokens_in"] = workplan_tokens_in
if workplan_tokens_out is not None:
body["workplan_tokens_out"] = workplan_tokens_out
task = _patch(f"/tasks/{task_id}", body) task = _patch(f"/tasks/{task_id}", body)
_post("/progress", { _post("/progress", {
"task_id": task_id, "task_id": task_id,
@@ -461,19 +483,6 @@ def update_task_status(
"detail": {"blocking_reason": blocking_reason}, "detail": {"blocking_reason": blocking_reason},
}) })
if tokens_in is not None and tokens_out is not None:
_post("/token-events", {
"task_id": task_id,
"workstream_id": task.get("workstream_id"),
"tokens_in": tokens_in,
"tokens_out": tokens_out,
"model": model,
"agent": agent,
"session_id": session_id,
"ref_type": "task",
"ref_id": task_id,
})
return json.dumps(task, indent=2) return json.dumps(task, indent=2)

View File

@@ -1,8 +1,12 @@
""" """
Token passthrough test: update_task_status with tokens_in/tokens_out Token passthrough test: update_task_status creates a token event on done.
creates a token event automatically.
Tests the API-level behaviour (the MCP tool delegates to the same endpoints). Three-tier logic:
Tier 1 — exact tokens_in/tokens_out provided
Tier 2 — workplan_tokens_in/out provided → prorated by task count (note="workplan")
Tier 3 — no token args, status=done → heuristic 1000/500 (note="heuristic")
Non-done status changes never create a token event.
""" """
from __future__ import annotations from __future__ import annotations
@@ -27,22 +31,21 @@ async def _create_workstream(client, topic_id):
return r.json() return r.json()
async def _create_task(client, workstream_id): async def _create_task(client, workstream_id, title="my task"):
r = await client.post("/tasks/", json={"workstream_id": workstream_id, "title": "my task"}) r = await client.post("/tasks/", json={"workstream_id": workstream_id, "title": title})
assert r.status_code == 201, r.text assert r.status_code == 201, r.text
return r.json() return r.json()
@pytest.mark.asyncio @pytest.mark.asyncio
class TestTokenPassthrough: class TestTokenPassthrough:
async def test_update_status_with_tokens_creates_event(self, client): async def test_tier1_exact_tokens(self, client):
"""PATCH /tasks/{id} with tokens_in/tokens_out creates a token_event.""" """Tier 1: exact tokens_in/tokens_out → used as-is, no note."""
await _create_domain(client) await _create_domain(client)
topic = await _create_topic(client) topic = await _create_topic(client)
ws = await _create_workstream(client, topic["id"]) ws = await _create_workstream(client, topic["id"])
task = await _create_task(client, ws["id"]) task = await _create_task(client, ws["id"])
# Update task status with token data
r = await client.patch(f"/tasks/{task['id']}", json={ r = await client.patch(f"/tasks/{task['id']}", json={
"status": "done", "status": "done",
"tokens_in": 1200, "tokens_in": 1200,
@@ -53,10 +56,7 @@ class TestTokenPassthrough:
assert r.status_code == 200 assert r.status_code == 200
assert r.json()["status"] == "done" assert r.json()["status"] == "done"
# Token event should now exist for this task events = (await client.get("/token-events/", params={"task_id": task["id"]})).json()
r2 = await client.get("/token-events/", params={"task_id": task["id"]})
assert r2.status_code == 200
events = r2.json()
assert len(events) == 1 assert len(events) == 1
ev = events[0] ev = events[0]
assert ev["tokens_in"] == 1200 assert ev["tokens_in"] == 1200
@@ -65,9 +65,51 @@ class TestTokenPassthrough:
assert ev["model"] == "claude-sonnet-4-6" assert ev["model"] == "claude-sonnet-4-6"
assert ev["agent"] == "custodian" assert ev["agent"] == "custodian"
assert ev["workstream_id"] == ws["id"] assert ev["workstream_id"] == ws["id"]
assert ev["note"] is None
async def test_update_status_without_tokens_creates_no_event(self, client): async def test_tier2_workplan_prorated(self, client):
"""PATCH /tasks/{id} without token fields creates no token_event.""" """Tier 2: workplan totals prorated across 4 tasks → 250/125 each, note='workplan'."""
await _create_domain(client)
topic = await _create_topic(client)
ws = await _create_workstream(client, topic["id"])
# Create 4 tasks; mark the first done with workplan totals
task = await _create_task(client, ws["id"], "T1")
for title in ["T2", "T3", "T4"]:
await _create_task(client, ws["id"], title)
r = await client.patch(f"/tasks/{task['id']}", json={
"status": "done",
"workplan_tokens_in": 1000,
"workplan_tokens_out": 500,
})
assert r.status_code == 200
events = (await client.get("/token-events/", params={"task_id": task["id"]})).json()
assert len(events) == 1
ev = events[0]
assert ev["tokens_in"] == 250 # 1000 // 4
assert ev["tokens_out"] == 125 # 500 // 4
assert ev["note"] == "workplan"
async def test_tier3_heuristic_fallback(self, client):
"""Tier 3: status=done with no token args → heuristic 1000/500, note='heuristic'."""
await _create_domain(client)
topic = await _create_topic(client)
ws = await _create_workstream(client, topic["id"])
task = await _create_task(client, ws["id"])
r = await client.patch(f"/tasks/{task['id']}", json={"status": "done"})
assert r.status_code == 200
events = (await client.get("/token-events/", params={"task_id": task["id"]})).json()
assert len(events) == 1
ev = events[0]
assert ev["tokens_in"] == 1000
assert ev["tokens_out"] == 500
assert ev["note"] == "heuristic"
async def test_non_done_status_creates_no_event(self, client):
"""Non-done status updates never create a token event."""
await _create_domain(client) await _create_domain(client)
topic = await _create_topic(client) topic = await _create_topic(client)
ws = await _create_workstream(client, topic["id"]) ws = await _create_workstream(client, topic["id"])
@@ -76,6 +118,5 @@ class TestTokenPassthrough:
r = await client.patch(f"/tasks/{task['id']}", json={"status": "in_progress"}) r = await client.patch(f"/tasks/{task['id']}", json={"status": "in_progress"})
assert r.status_code == 200 assert r.status_code == 200
r2 = await client.get("/token-events/", params={"task_id": task["id"]}) events = (await client.get("/token-events/", params={"task_id": task["id"]})).json()
assert r2.status_code == 200 assert events == []
assert r2.json() == []