feat(token-tracking): three-tier token recording on task done
Token events are now always created when update_task_status is called
with status="done", using the best available data:
Tier 1 (best): exact tokens_in + tokens_out passed by agent
Tier 2: workplan_tokens_in + workplan_tokens_out, each integer-divided
by the number of tasks in the task's workstream (note="workplan")
Tier 3 (fallback): heuristic 1000 in / 500 out (note="heuristic")
Non-done status changes never create a token event, even when token fields are supplied.
MCP tool updated with workplan_tokens_in/out params and tiered docs.
Ralph-workplan skill files updated with the three-tier guidance.
184 tests pass.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -2,7 +2,7 @@ import uuid
|
||||
from datetime import date
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from api.database import get_session
|
||||
@@ -75,27 +75,44 @@ async def update_task(
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
# Separate token fields from task fields
|
||||
token_fields = {"tokens_in", "tokens_out", "model", "agent", "session_id"}
|
||||
token_field_names = {"tokens_in", "tokens_out", "workplan_tokens_in", "workplan_tokens_out", "model", "agent", "session_id"}
|
||||
update_data = body.model_dump(exclude_unset=True)
|
||||
token_data = {k: update_data.pop(k) for k in list(update_data.keys()) if k in token_fields}
|
||||
token_data = {k: update_data.pop(k) for k in list(update_data.keys()) if k in token_field_names}
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(task, field, value)
|
||||
await session.commit()
|
||||
await session.refresh(task)
|
||||
|
||||
# Create token event if token passthrough fields provided
|
||||
if "tokens_in" in token_data and "tokens_out" in token_data:
|
||||
# Token event — three-tier logic, only when marking done
|
||||
if update_data.get("status") == "done":
|
||||
if "tokens_in" in token_data and "tokens_out" in token_data:
|
||||
# Tier 1: exact counts provided
|
||||
tin, tout, tnote = token_data["tokens_in"], token_data["tokens_out"], None
|
||||
elif "workplan_tokens_in" in token_data and "workplan_tokens_out" in token_data:
|
||||
# Tier 2: prorate workplan total across task count
|
||||
count_result = await session.execute(
|
||||
select(func.count(Task.id)).where(Task.workstream_id == task.workstream_id)
|
||||
)
|
||||
task_count = max(count_result.scalar() or 1, 1)
|
||||
tin = token_data["workplan_tokens_in"] // task_count
|
||||
tout = token_data["workplan_tokens_out"] // task_count
|
||||
tnote = "workplan"
|
||||
else:
|
||||
# Tier 3: heuristic fallback
|
||||
tin, tout, tnote = 1000, 500, "heuristic"
|
||||
|
||||
event = TokenEvent(
|
||||
task_id=task_id,
|
||||
workstream_id=task.workstream_id,
|
||||
tokens_in=token_data["tokens_in"],
|
||||
tokens_out=token_data["tokens_out"],
|
||||
tokens_in=tin,
|
||||
tokens_out=tout,
|
||||
model=token_data.get("model"),
|
||||
agent=token_data.get("agent"),
|
||||
session_id=token_data.get("session_id"),
|
||||
ref_type="task",
|
||||
ref_id=str(task_id),
|
||||
note=tnote,
|
||||
)
|
||||
session.add(event)
|
||||
await session.commit()
|
||||
|
||||
@@ -38,9 +38,14 @@ class TaskUpdate(BaseModel):
|
||||
needs_human: bool | None = None
|
||||
intervention_note: str | None = None
|
||||
parent_task_id: uuid.UUID | None = None
|
||||
# Optional token passthrough — when provided, a token_event is created
|
||||
# Token passthrough — three tiers (highest precision wins):
|
||||
# 1. tokens_in + tokens_out → exact counts (best practice)
|
||||
# 2. workplan_tokens_in + workplan_tokens_out → prorated across task count (note="workplan")
|
||||
# 3. neither provided, status=done → heuristic 1000/500 (note="heuristic")
|
||||
tokens_in: int | None = None
|
||||
tokens_out: int | None = None
|
||||
workplan_tokens_in: int | None = None
|
||||
workplan_tokens_out: int | None = None
|
||||
model: str | None = None
|
||||
agent: str | None = None
|
||||
session_id: str | None = None
|
||||
|
||||
@@ -428,29 +428,51 @@ def update_task_status(
|
||||
blocking_reason: Optional[str] = None,
|
||||
tokens_in: Optional[int] = None,
|
||||
tokens_out: Optional[int] = None,
|
||||
workplan_tokens_in: Optional[int] = None,
|
||||
workplan_tokens_out: Optional[int] = None,
|
||||
model: Optional[str] = None,
|
||||
agent: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Update a task's status. blocking_reason is required when status='blocked'.
|
||||
|
||||
Optionally record token consumption in one call by passing tokens_in/tokens_out.
|
||||
When provided, a token_event is created automatically with workstream_id and
|
||||
repo_id auto-populated from the task.
|
||||
When status='done', always records a token event using the best available data:
|
||||
Tier 1 (best): pass tokens_in + tokens_out — exact counts from the session
|
||||
Tier 2: pass workplan_tokens_in + workplan_tokens_out — total workplan
|
||||
effort prorated across task count (note="workplan")
|
||||
Tier 3 (fallback): no token args — heuristic 1000 in / 500 out (note="heuristic")
|
||||
|
||||
Best practice: read tokens from the Claude Code status bar and pass exact counts.
|
||||
|
||||
Args:
|
||||
task_id: UUID of the task
|
||||
status: todo | in_progress | blocked | done | cancelled
|
||||
blocking_reason: required when status=blocked
|
||||
tokens_in: optional input token count (triggers token_event creation)
|
||||
tokens_out: optional output token count (required if tokens_in provided)
|
||||
model: optional model identifier, e.g. 'claude-sonnet-4-6'
|
||||
agent: optional agent name, e.g. 'custodian', 'ralph'
|
||||
session_id: optional agent session identifier
|
||||
tokens_in: exact input token count for this task (Tier 1)
|
||||
tokens_out: exact output token count for this task (Tier 1)
|
||||
workplan_tokens_in: total input tokens for the whole workplan (Tier 2)
|
||||
workplan_tokens_out: total output tokens for the whole workplan (Tier 2)
|
||||
model: model identifier, e.g. 'claude-sonnet-4-6'
|
||||
agent: agent name, e.g. 'custodian', 'ralph'
|
||||
session_id: agent session identifier
|
||||
"""
|
||||
body: dict[str, Any] = {"status": status}
|
||||
body: dict[str, Any] = {
|
||||
"status": status,
|
||||
"model": model,
|
||||
"agent": agent,
|
||||
"session_id": session_id,
|
||||
}
|
||||
if blocking_reason:
|
||||
body["blocking_reason"] = blocking_reason
|
||||
if tokens_in is not None:
|
||||
body["tokens_in"] = tokens_in
|
||||
if tokens_out is not None:
|
||||
body["tokens_out"] = tokens_out
|
||||
if workplan_tokens_in is not None:
|
||||
body["workplan_tokens_in"] = workplan_tokens_in
|
||||
if workplan_tokens_out is not None:
|
||||
body["workplan_tokens_out"] = workplan_tokens_out
|
||||
|
||||
task = _patch(f"/tasks/{task_id}", body)
|
||||
_post("/progress", {
|
||||
"task_id": task_id,
|
||||
@@ -461,19 +483,6 @@ def update_task_status(
|
||||
"detail": {"blocking_reason": blocking_reason},
|
||||
})
|
||||
|
||||
if tokens_in is not None and tokens_out is not None:
|
||||
_post("/token-events", {
|
||||
"task_id": task_id,
|
||||
"workstream_id": task.get("workstream_id"),
|
||||
"tokens_in": tokens_in,
|
||||
"tokens_out": tokens_out,
|
||||
"model": model,
|
||||
"agent": agent,
|
||||
"session_id": session_id,
|
||||
"ref_type": "task",
|
||||
"ref_id": task_id,
|
||||
})
|
||||
|
||||
return json.dumps(task, indent=2)
|
||||
|
||||
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
"""
|
||||
Token passthrough test: update_task_status with tokens_in/tokens_out
|
||||
creates a token event automatically.
|
||||
Token passthrough test: update_task_status creates a token event on done.
|
||||
|
||||
Tests the API-level behaviour (the MCP tool delegates to the same endpoints).
|
||||
Three-tier logic:
|
||||
Tier 1 — exact tokens_in/tokens_out provided
|
||||
Tier 2 — workplan_tokens_in/out provided → prorated by task count (note="workplan")
|
||||
Tier 3 — no token args, status=done → heuristic 1000/500 (note="heuristic")
|
||||
|
||||
Non-done status changes never create a token event.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -27,22 +31,21 @@ async def _create_workstream(client, topic_id):
|
||||
return r.json()
|
||||
|
||||
|
||||
async def _create_task(client, workstream_id):
|
||||
r = await client.post("/tasks/", json={"workstream_id": workstream_id, "title": "my task"})
|
||||
async def _create_task(client, workstream_id, title="my task"):
|
||||
r = await client.post("/tasks/", json={"workstream_id": workstream_id, "title": title})
|
||||
assert r.status_code == 201, r.text
|
||||
return r.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestTokenPassthrough:
|
||||
async def test_update_status_with_tokens_creates_event(self, client):
|
||||
"""PATCH /tasks/{id} with tokens_in/tokens_out creates a token_event."""
|
||||
async def test_tier1_exact_tokens(self, client):
|
||||
"""Tier 1: exact tokens_in/tokens_out → used as-is, no note."""
|
||||
await _create_domain(client)
|
||||
topic = await _create_topic(client)
|
||||
ws = await _create_workstream(client, topic["id"])
|
||||
task = await _create_task(client, ws["id"])
|
||||
|
||||
# Update task status with token data
|
||||
r = await client.patch(f"/tasks/{task['id']}", json={
|
||||
"status": "done",
|
||||
"tokens_in": 1200,
|
||||
@@ -53,10 +56,7 @@ class TestTokenPassthrough:
|
||||
assert r.status_code == 200
|
||||
assert r.json()["status"] == "done"
|
||||
|
||||
# Token event should now exist for this task
|
||||
r2 = await client.get("/token-events/", params={"task_id": task["id"]})
|
||||
assert r2.status_code == 200
|
||||
events = r2.json()
|
||||
events = (await client.get("/token-events/", params={"task_id": task["id"]})).json()
|
||||
assert len(events) == 1
|
||||
ev = events[0]
|
||||
assert ev["tokens_in"] == 1200
|
||||
@@ -65,9 +65,51 @@ class TestTokenPassthrough:
|
||||
assert ev["model"] == "claude-sonnet-4-6"
|
||||
assert ev["agent"] == "custodian"
|
||||
assert ev["workstream_id"] == ws["id"]
|
||||
assert ev["note"] is None
|
||||
|
||||
async def test_update_status_without_tokens_creates_no_event(self, client):
|
||||
"""PATCH /tasks/{id} without token fields creates no token_event."""
|
||||
async def test_tier2_workplan_prorated(self, client):
|
||||
"""Tier 2: workplan totals prorated across 4 tasks → 250/125 each, note='workplan'."""
|
||||
await _create_domain(client)
|
||||
topic = await _create_topic(client)
|
||||
ws = await _create_workstream(client, topic["id"])
|
||||
# Create 4 tasks; mark the first done with workplan totals
|
||||
task = await _create_task(client, ws["id"], "T1")
|
||||
for title in ["T2", "T3", "T4"]:
|
||||
await _create_task(client, ws["id"], title)
|
||||
|
||||
r = await client.patch(f"/tasks/{task['id']}", json={
|
||||
"status": "done",
|
||||
"workplan_tokens_in": 1000,
|
||||
"workplan_tokens_out": 500,
|
||||
})
|
||||
assert r.status_code == 200
|
||||
|
||||
events = (await client.get("/token-events/", params={"task_id": task["id"]})).json()
|
||||
assert len(events) == 1
|
||||
ev = events[0]
|
||||
assert ev["tokens_in"] == 250 # 1000 // 4
|
||||
assert ev["tokens_out"] == 125 # 500 // 4
|
||||
assert ev["note"] == "workplan"
|
||||
|
||||
async def test_tier3_heuristic_fallback(self, client):
|
||||
"""Tier 3: status=done with no token args → heuristic 1000/500, note='heuristic'."""
|
||||
await _create_domain(client)
|
||||
topic = await _create_topic(client)
|
||||
ws = await _create_workstream(client, topic["id"])
|
||||
task = await _create_task(client, ws["id"])
|
||||
|
||||
r = await client.patch(f"/tasks/{task['id']}", json={"status": "done"})
|
||||
assert r.status_code == 200
|
||||
|
||||
events = (await client.get("/token-events/", params={"task_id": task["id"]})).json()
|
||||
assert len(events) == 1
|
||||
ev = events[0]
|
||||
assert ev["tokens_in"] == 1000
|
||||
assert ev["tokens_out"] == 500
|
||||
assert ev["note"] == "heuristic"
|
||||
|
||||
async def test_non_done_status_creates_no_event(self, client):
|
||||
"""Non-done status updates never create a token event."""
|
||||
await _create_domain(client)
|
||||
topic = await _create_topic(client)
|
||||
ws = await _create_workstream(client, topic["id"])
|
||||
@@ -76,6 +118,5 @@ class TestTokenPassthrough:
|
||||
r = await client.patch(f"/tasks/{task['id']}", json={"status": "in_progress"})
|
||||
assert r.status_code == 200
|
||||
|
||||
r2 = await client.get("/token-events/", params={"task_id": task["id"]})
|
||||
assert r2.status_code == 200
|
||||
assert r2.json() == []
|
||||
events = (await client.get("/token-events/", params={"task_id": task["id"]})).json()
|
||||
assert events == []
|
||||
|
||||
Reference in New Issue
Block a user