feat(token-tracking): three-tier token recording on task done

A token event is now always created when a task is marked done via
PATCH /tasks/{id} (the endpoint the update_task_status MCP tool calls),
using the best available data:

  Tier 1 (best): exact tokens_in + tokens_out passed by agent
  Tier 2:        workplan_tokens_in + workplan_tokens_out split evenly
                 (integer division, remainder dropped) across all tasks
                 in the workstream (note="workplan")
  Tier 3 (fallback): heuristic 1000 in / 500 out (note="heuristic")

Non-done status changes never create a token event.
MCP tool updated with workplan_tokens_in/out params and tiered docs; it no
longer posts to /token-events itself, since the API now creates the event.
Ralph-workplan skill files updated with the three-tier guidance.
184 tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-29 18:28:18 +02:00
parent 3de29b5a9a
commit e247c20439
4 changed files with 119 additions and 47 deletions

View File

@@ -2,7 +2,7 @@ import uuid
from datetime import date
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from api.database import get_session
@@ -75,27 +75,44 @@ async def update_task(
raise HTTPException(status_code=404, detail="Task not found")
# Separate token fields from task fields
token_fields = {"tokens_in", "tokens_out", "model", "agent", "session_id"}
token_field_names = {"tokens_in", "tokens_out", "workplan_tokens_in", "workplan_tokens_out", "model", "agent", "session_id"}
update_data = body.model_dump(exclude_unset=True)
token_data = {k: update_data.pop(k) for k in list(update_data.keys()) if k in token_fields}
token_data = {k: update_data.pop(k) for k in list(update_data.keys()) if k in token_field_names}
for field, value in update_data.items():
setattr(task, field, value)
await session.commit()
await session.refresh(task)
# Create token event if token passthrough fields provided
if "tokens_in" in token_data and "tokens_out" in token_data:
# Token event — three-tier logic, only when marking done
if update_data.get("status") == "done":
if "tokens_in" in token_data and "tokens_out" in token_data:
# Tier 1: exact counts provided
tin, tout, tnote = token_data["tokens_in"], token_data["tokens_out"], None
elif "workplan_tokens_in" in token_data and "workplan_tokens_out" in token_data:
# Tier 2: prorate workplan total across task count
count_result = await session.execute(
select(func.count(Task.id)).where(Task.workstream_id == task.workstream_id)
)
task_count = max(count_result.scalar() or 1, 1)
tin = token_data["workplan_tokens_in"] // task_count
tout = token_data["workplan_tokens_out"] // task_count
tnote = "workplan"
else:
# Tier 3: heuristic fallback
tin, tout, tnote = 1000, 500, "heuristic"
event = TokenEvent(
task_id=task_id,
workstream_id=task.workstream_id,
tokens_in=token_data["tokens_in"],
tokens_out=token_data["tokens_out"],
tokens_in=tin,
tokens_out=tout,
model=token_data.get("model"),
agent=token_data.get("agent"),
session_id=token_data.get("session_id"),
ref_type="task",
ref_id=str(task_id),
note=tnote,
)
session.add(event)
await session.commit()

View File

@@ -38,9 +38,14 @@ class TaskUpdate(BaseModel):
needs_human: bool | None = None
intervention_note: str | None = None
parent_task_id: uuid.UUID | None = None
# Optional token passthrough — when provided, a token_event is created
# Token passthrough — three tiers (highest precision wins):
# 1. tokens_in + tokens_out → exact counts (best practice)
# 2. workplan_tokens_in + workplan_tokens_out → prorated across task count (note="workplan")
# 3. neither provided, status=done → heuristic 1000/500 (note="heuristic")
tokens_in: int | None = None
tokens_out: int | None = None
workplan_tokens_in: int | None = None
workplan_tokens_out: int | None = None
model: str | None = None
agent: str | None = None
session_id: str | None = None

View File

@@ -428,29 +428,51 @@ def update_task_status(
blocking_reason: Optional[str] = None,
tokens_in: Optional[int] = None,
tokens_out: Optional[int] = None,
workplan_tokens_in: Optional[int] = None,
workplan_tokens_out: Optional[int] = None,
model: Optional[str] = None,
agent: Optional[str] = None,
session_id: Optional[str] = None,
) -> str:
"""Update a task's status. blocking_reason is required when status='blocked'.
Optionally record token consumption in one call by passing tokens_in/tokens_out.
When provided, a token_event is created automatically with workstream_id and
repo_id auto-populated from the task.
When status='done', always records a token event using the best available data:
Tier 1 (best): pass tokens_in + tokens_out — exact counts from the session
Tier 2: pass workplan_tokens_in + workplan_tokens_out — total workplan
effort prorated across task count (note="workplan")
Tier 3 (fallback): no token args — heuristic 1000 in / 500 out (note="heuristic")
Best practice: read tokens from the Claude Code status bar and pass exact counts.
Args:
task_id: UUID of the task
status: todo | in_progress | blocked | done | cancelled
blocking_reason: required when status=blocked
tokens_in: optional input token count (triggers token_event creation)
tokens_out: optional output token count (required if tokens_in provided)
model: optional model identifier, e.g. 'claude-sonnet-4-6'
agent: optional agent name, e.g. 'custodian', 'ralph'
session_id: optional agent session identifier
tokens_in: exact input token count for this task (Tier 1)
tokens_out: exact output token count for this task (Tier 1)
workplan_tokens_in: total input tokens for the whole workplan (Tier 2)
workplan_tokens_out: total output tokens for the whole workplan (Tier 2)
model: model identifier, e.g. 'claude-sonnet-4-6'
agent: agent name, e.g. 'custodian', 'ralph'
session_id: agent session identifier
"""
body: dict[str, Any] = {"status": status}
body: dict[str, Any] = {
"status": status,
"model": model,
"agent": agent,
"session_id": session_id,
}
if blocking_reason:
body["blocking_reason"] = blocking_reason
if tokens_in is not None:
body["tokens_in"] = tokens_in
if tokens_out is not None:
body["tokens_out"] = tokens_out
if workplan_tokens_in is not None:
body["workplan_tokens_in"] = workplan_tokens_in
if workplan_tokens_out is not None:
body["workplan_tokens_out"] = workplan_tokens_out
task = _patch(f"/tasks/{task_id}", body)
_post("/progress", {
"task_id": task_id,
@@ -461,19 +483,6 @@ def update_task_status(
"detail": {"blocking_reason": blocking_reason},
})
if tokens_in is not None and tokens_out is not None:
_post("/token-events", {
"task_id": task_id,
"workstream_id": task.get("workstream_id"),
"tokens_in": tokens_in,
"tokens_out": tokens_out,
"model": model,
"agent": agent,
"session_id": session_id,
"ref_type": "task",
"ref_id": task_id,
})
return json.dumps(task, indent=2)

View File

@@ -1,8 +1,12 @@
"""
Token passthrough test: update_task_status with tokens_in/tokens_out
creates a token event automatically.
Token passthrough test: update_task_status creates a token event on done.
Tests the API-level behaviour (the MCP tool delegates to the same endpoints).
Three-tier logic:
Tier 1 — exact tokens_in/tokens_out provided
Tier 2 — workplan_tokens_in/out provided → prorated by task count (note="workplan")
Tier 3 — no token args, status=done → heuristic 1000/500 (note="heuristic")
Non-done status changes never create a token event.
"""
from __future__ import annotations
@@ -27,22 +31,21 @@ async def _create_workstream(client, topic_id):
return r.json()
async def _create_task(client, workstream_id):
r = await client.post("/tasks/", json={"workstream_id": workstream_id, "title": "my task"})
async def _create_task(client, workstream_id, title="my task"):
r = await client.post("/tasks/", json={"workstream_id": workstream_id, "title": title})
assert r.status_code == 201, r.text
return r.json()
@pytest.mark.asyncio
class TestTokenPassthrough:
async def test_update_status_with_tokens_creates_event(self, client):
"""PATCH /tasks/{id} with tokens_in/tokens_out creates a token_event."""
async def test_tier1_exact_tokens(self, client):
"""Tier 1: exact tokens_in/tokens_out → used as-is, no note."""
await _create_domain(client)
topic = await _create_topic(client)
ws = await _create_workstream(client, topic["id"])
task = await _create_task(client, ws["id"])
# Update task status with token data
r = await client.patch(f"/tasks/{task['id']}", json={
"status": "done",
"tokens_in": 1200,
@@ -53,10 +56,7 @@ class TestTokenPassthrough:
assert r.status_code == 200
assert r.json()["status"] == "done"
# Token event should now exist for this task
r2 = await client.get("/token-events/", params={"task_id": task["id"]})
assert r2.status_code == 200
events = r2.json()
events = (await client.get("/token-events/", params={"task_id": task["id"]})).json()
assert len(events) == 1
ev = events[0]
assert ev["tokens_in"] == 1200
@@ -65,9 +65,51 @@ class TestTokenPassthrough:
assert ev["model"] == "claude-sonnet-4-6"
assert ev["agent"] == "custodian"
assert ev["workstream_id"] == ws["id"]
assert ev["note"] is None
async def test_update_status_without_tokens_creates_no_event(self, client):
"""PATCH /tasks/{id} without token fields creates no token_event."""
async def test_tier2_workplan_prorated(self, client):
"""Tier 2: workplan totals prorated across 4 tasks → 250/125 each, note='workplan'."""
await _create_domain(client)
topic = await _create_topic(client)
ws = await _create_workstream(client, topic["id"])
# Create 4 tasks; mark the first done with workplan totals
task = await _create_task(client, ws["id"], "T1")
for title in ["T2", "T3", "T4"]:
await _create_task(client, ws["id"], title)
r = await client.patch(f"/tasks/{task['id']}", json={
"status": "done",
"workplan_tokens_in": 1000,
"workplan_tokens_out": 500,
})
assert r.status_code == 200
events = (await client.get("/token-events/", params={"task_id": task["id"]})).json()
assert len(events) == 1
ev = events[0]
assert ev["tokens_in"] == 250 # 1000 // 4
assert ev["tokens_out"] == 125 # 500 // 4
assert ev["note"] == "workplan"
async def test_tier3_heuristic_fallback(self, client):
"""Tier 3: status=done with no token args → heuristic 1000/500, note='heuristic'."""
await _create_domain(client)
topic = await _create_topic(client)
ws = await _create_workstream(client, topic["id"])
task = await _create_task(client, ws["id"])
r = await client.patch(f"/tasks/{task['id']}", json={"status": "done"})
assert r.status_code == 200
events = (await client.get("/token-events/", params={"task_id": task["id"]})).json()
assert len(events) == 1
ev = events[0]
assert ev["tokens_in"] == 1000
assert ev["tokens_out"] == 500
assert ev["note"] == "heuristic"
async def test_non_done_status_creates_no_event(self, client):
"""Non-done status updates never create a token event."""
await _create_domain(client)
topic = await _create_topic(client)
ws = await _create_workstream(client, topic["id"])
@@ -76,6 +118,5 @@ class TestTokenPassthrough:
r = await client.patch(f"/tasks/{task['id']}", json={"status": "in_progress"})
assert r.status_code == 200
r2 = await client.get("/token-events/", params={"task_id": task["id"]})
assert r2.status_code == 200
assert r2.json() == []
events = (await client.get("/token-events/", params={"task_id": task["id"]})).json()
assert events == []