Fixed and improved token tracking

This commit is contained in:
2026-05-23 13:59:05 +02:00
parent dd3279ea1a
commit c12091c2eb
29 changed files with 3549 additions and 278 deletions

View File

@@ -75,23 +75,47 @@ async def update_task(
if task is None:
raise HTTPException(status_code=404, detail="Task not found")
previous_status = task.status.value
# Separate token fields from task fields
token_field_names = {"tokens_in", "tokens_out", "workplan_tokens_in", "workplan_tokens_out", "token_note", "model", "agent", "session_id"}
token_field_names = {
"tokens_in",
"tokens_out",
"workplan_tokens_in",
"workplan_tokens_out",
"token_note",
"model",
"agent",
"session_id",
"suppress_token_event",
}
update_data = body.model_dump(exclude_unset=True)
token_data = {k: update_data.pop(k) for k in list(update_data.keys()) if k in token_field_names}
suppress_token_event = bool(token_data.pop("suppress_token_event", False))
for field, value in update_data.items():
setattr(task, field, value)
await session.commit()
await session.refresh(task)
# Token event — three-tier logic, only when marking done
if update_data.get("status") == "done":
# Token event — three-tier logic, only for an intentional transition to done.
status_update = update_data.get("status")
new_status = status_update.value if hasattr(status_update, "value") else status_update
if (
new_status == "done"
and previous_status != "done"
and not suppress_token_event
):
if "tokens_in" in token_data and "tokens_out" in token_data:
# Tier 1: exact counts — default note "measured"; caller may override with token_note
tin = token_data["tokens_in"]
tout = token_data["tokens_out"]
tnote = token_data.get("token_note") or "measured"
measurement_kind = "measured"
source_provider = "manual"
confidence = 1.0
source_id = f"task:{task_id}:manual"
raw_metadata = {"input_source": "task_status_patch"}
elif "workplan_tokens_in" in token_data and "workplan_tokens_out" in token_data:
# Tier 2: prorate workplan total across task count
count_result = await session.execute(
@@ -101,9 +125,24 @@ async def update_task(
tin = token_data["workplan_tokens_in"] // task_count
tout = token_data["workplan_tokens_out"] // task_count
tnote = "workplan"
measurement_kind = "allocated"
source_provider = "manual"
confidence = 0.7
source_id = f"task:{task_id}:workplan-allocation"
raw_metadata = {
"allocation_method": "workplan_prorated",
"workplan_tokens_in": token_data["workplan_tokens_in"],
"workplan_tokens_out": token_data["workplan_tokens_out"],
"task_count": task_count,
}
else:
# Tier 3: heuristic fallback
tin, tout, tnote = 1000, 500, "heuristic"
measurement_kind = "estimated"
source_provider = "task_fallback"
confidence = 0.35
source_id = f"task:{task_id}:heuristic"
raw_metadata = {"estimation_method": "fixed_task_done_fallback"}
# Resolve repo_id via workstream
ws = await session.get(Workstream, task.workstream_id)
@@ -121,6 +160,12 @@ async def update_task(
ref_type="task",
ref_id=str(task_id),
note=tnote,
measurement_kind=measurement_kind,
source_provider=source_provider,
source_id=source_id,
confidence=confidence,
raw_total_tokens=tin + tout,
raw_metadata=raw_metadata,
)
session.add(event)
await session.commit()

View File

@@ -1,5 +1,7 @@
import uuid
from collections import defaultdict
from datetime import datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
@@ -10,18 +12,95 @@ from api.models.managed_repo import ManagedRepo
from api.models.task import Task
from api.models.token_event import TokenEvent
from api.models.workstream import Workstream
from api.schemas.token_event import RepoTokenSummary, TokenEventCreate, TokenEventPatch, TokenEventRead, TokenSummary
from api.schemas.token_event import (
RepoTokenSummary,
TokenAggregateRow,
TokenAggregateSummary,
TokenEventCreate,
TokenEventPatch,
TokenEventRead,
TokenQualitySummary,
TokenSummary,
)
# Router for the token-event endpoints; every route below is served under /token-events.
router = APIRouter(prefix="/token-events", tags=["token-events"])
# Fallback confidence score per measurement_kind. _apply_event_defaults uses it
# when an event payload carries no explicit "confidence" value; kinds missing
# from this table fall back to 0.35 there.
DEFAULT_CONFIDENCE = {
"measured": 1.0,
"allocated": 0.70,
"estimated": 0.35,
"superseded": 0.0,
}
@router.post("/", response_model=TokenEventRead, status_code=status.HTTP_201_CREATED)
async def create_token_event(
body: TokenEventCreate,
session: AsyncSession = Depends(get_session),
) -> TokenEvent:
data = body.model_dump()
# Default "parser_version" per source_provider. _apply_event_defaults applies
# it only when the payload has no parser_version of its own; providers not
# listed here get no default.
SOURCE_PARSER_DEFAULTS = {
"codex_session": "codex-desktop-v1",
"claude_transcript": "claude-transcript-v1",
"llm_connect": "llm-connect-v1",
}
def _event_total(event: TokenEvent) -> int:
    """Return the combined input and output token count of ``event``."""
    return sum((event.tokens_in, event.tokens_out))
def _infer_measurement_kind(data: dict[str, Any]) -> str:
if data.get("measurement_kind"):
return str(data["measurement_kind"])
note = data.get("note")
if note == "heuristic_superseded_by_codex_backfill":
return "superseded"
if note == "workplan":
return "allocated"
if note == "heuristic":
return "estimated"
if note == "measured" or str(note or "").startswith("backfill:codex-session"):
return "measured"
provider = data.get("source_provider")
if provider in {"codex_session", "claude_transcript", "llm_connect"}:
return "measured"
return "estimated"
def _infer_source_provider(data: dict[str, Any], measurement_kind: str) -> str:
if data.get("source_provider"):
return str(data["source_provider"])
note = data.get("note")
ref_id = str(data.get("ref_id") or "")
agent = str(data.get("agent") or "").lower()
if note == "heuristic":
return "task_fallback"
if ref_id.startswith("codex:") or str(note or "").startswith("backfill:codex-session"):
return "codex_session"
if measurement_kind == "measured" and "claude" in agent:
return "claude_transcript"
return "manual"
def _apply_event_defaults(data: dict[str, Any]) -> dict[str, Any]:
    """Fill in provenance and bookkeeping defaults on an event payload.

    The payload is mutated in place and also returned. ``measurement_kind``
    and ``source_provider`` are always overwritten with their resolved values;
    every other field is only set when it is absent.

    Args:
        data: Event payload to normalise.

    Returns:
        The same ``data`` dict, with defaults applied.
    """
    kind = _infer_measurement_kind(data)
    provider = _infer_source_provider(data, kind)
    data.update(measurement_kind=kind, source_provider=provider)

    # For the ingestion providers, derive a source id from the reference or
    # session id when none was supplied.
    known_provider = provider in {"codex_session", "claude_transcript", "llm_connect"}
    if known_provider and not data.get("source_id"):
        derived_id = data.get("ref_id") or data.get("session_id")
        if derived_id:
            data["source_id"] = str(derived_id)

    # An event with a source id inherits its creation time as source time.
    if data.get("source_id") and data.get("created_at") and not data.get("source_created_at"):
        data["source_created_at"] = data["created_at"]

    data.setdefault("confidence", DEFAULT_CONFIDENCE.get(kind, 0.35))
    data.setdefault("cached_input_tokens", 0)
    data.setdefault("reasoning_output_tokens", 0)
    combined_tokens = (data.get("tokens_in") or 0) + (data.get("tokens_out") or 0)
    data.setdefault("raw_total_tokens", combined_tokens)
    data.setdefault("raw_metadata", {})

    default_parser = SOURCE_PARSER_DEFAULTS.get(provider)
    if default_parser is not None:
        data.setdefault("parser_version", default_parser)
    return data
async def _populate_relationship_defaults(data: dict[str, Any], session: AsyncSession) -> dict[str, Any]:
# Auto-populate workstream_id from task if not provided
if data.get("task_id") and not data.get("workstream_id"):
task = await session.get(Task, data["task_id"])
@@ -33,6 +112,34 @@ async def create_token_event(
ws = await session.get(Workstream, data["workstream_id"])
if ws and ws.repo_id:
data["repo_id"] = ws.repo_id
return data
async def _find_source_event(data: dict[str, Any], session: AsyncSession) -> TokenEvent | None:
    """Look up an already-stored event that shares this payload's source identity.

    The identity is the triple ``(measurement_kind, source_provider,
    source_id)``. Payloads without a ``source_id`` never match.

    Args:
        data: Event payload; ``measurement_kind`` and ``source_provider`` must
            already be populated when ``source_id`` is set.
        session: Database session used for the lookup.

    Returns:
        The earliest-created matching event, or ``None`` when there is no
        match or the payload has no ``source_id``.
    """
    source_id = data.get("source_id")
    if not source_id:
        return None
    # Duplicate rows for one source identity can exist (the quality endpoint
    # reports them), and scalar_one_or_none() would raise MultipleResultsFound
    # on them. Take the oldest match deterministically instead.
    query = (
        select(TokenEvent)
        .where(
            TokenEvent.measurement_kind == data["measurement_kind"],
            TokenEvent.source_provider == data["source_provider"],
            TokenEvent.source_id == source_id,
        )
        .order_by(TokenEvent.created_at)
        .limit(1)
    )
    result = await session.execute(query)
    return result.scalars().first()
async def _create_or_upsert_event(data: dict[str, Any], session: AsyncSession) -> TokenEvent:
data = _apply_event_defaults(data)
data = await _populate_relationship_defaults(data, session)
existing = await _find_source_event(data, session)
if existing is not None:
for field, value in data.items():
setattr(existing, field, value)
await session.commit()
await session.refresh(existing)
return existing
event = TokenEvent(**data)
session.add(event)
@@ -41,6 +148,77 @@ async def create_token_event(
return event
def _filter_query(
    q,
    *,
    task_id: uuid.UUID | None = None,
    workstream_id: uuid.UUID | None = None,
    repo_id: uuid.UUID | None = None,
    ref_type: str | None = None,
    ref_id: str | None = None,
    model: str | None = None,
    agent: str | None = None,
    note: str | None = None,
    measurement_kind: str | None = None,
    source_provider: str | None = None,
    since: datetime | None = None,
    until: datetime | None = None,
    include_superseded: bool = True,
    unattributed: bool = False,
):
    """Narrow a ``TokenEvent`` select with the optional filters given.

    Falsy filter values are ignored. ``since`` is inclusive and ``until`` is
    exclusive, both compared against ``created_at``. With
    ``include_superseded=False`` events whose measurement kind is
    ``"superseded"`` are dropped. With ``unattributed=True`` only events
    lacking repo, workstream and task links are kept.

    Returns:
        The query with every applicable ``where`` clause appended.
    """
    # Plain equality filters, applied in declaration order.
    equality_filters = (
        (TokenEvent.task_id, task_id),
        (TokenEvent.workstream_id, workstream_id),
        (TokenEvent.repo_id, repo_id),
        (TokenEvent.ref_type, ref_type),
        (TokenEvent.ref_id, ref_id),
        (TokenEvent.model, model),
        (TokenEvent.agent, agent),
        (TokenEvent.note, note),
        (TokenEvent.measurement_kind, measurement_kind),
        (TokenEvent.source_provider, source_provider),
    )
    for column, wanted in equality_filters:
        if wanted:
            q = q.where(column == wanted)

    # Half-open time window on created_at: [since, until).
    if since:
        q = q.where(TokenEvent.created_at >= since)
    if until:
        q = q.where(TokenEvent.created_at < until)

    if not include_superseded:
        q = q.where(TokenEvent.measurement_kind != "superseded")

    if unattributed:
        q = q.where(
            TokenEvent.repo_id.is_(None),
            TokenEvent.workstream_id.is_(None),
            TokenEvent.task_id.is_(None),
        )
    return q
@router.post("/", response_model=TokenEventRead, status_code=status.HTTP_201_CREATED)
async def create_token_event(
    body: TokenEventCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Handle ``POST /``: store a token event and respond with 201.

    Fields left as ``None`` in the request body are dropped before the
    payload is handed to ``_create_or_upsert_event``.
    """
    payload = body.model_dump(exclude_none=True)
    event = await _create_or_upsert_event(payload, session)
    return event
@router.post("/upsert", response_model=TokenEventRead)
async def upsert_token_event(
    body: TokenEventCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Handle ``POST /upsert``: create or update a token event.

    Fields left as ``None`` in the request body are dropped before the
    payload is handed to ``_create_or_upsert_event``.
    """
    payload = body.model_dump(exclude_none=True)
    event = await _create_or_upsert_event(payload, session)
    return event
@router.get("/summary/", response_model=TokenSummary)
async def get_token_summary(
scope: str = Query(..., description="task|workstream|repo|commit|release|session"),
@@ -80,11 +258,16 @@ async def get_token_summary(
by_model: dict[str, int] = defaultdict(int)
by_agent: dict[str, int] = defaultdict(int)
by_measurement_kind: dict[str, int] = defaultdict(int)
by_source_provider: dict[str, int] = defaultdict(int)
for e in events:
total = _event_total(e)
if e.model:
by_model[e.model] += e.tokens_in + e.tokens_out
by_model[e.model] += total
if e.agent:
by_agent[e.agent] += e.tokens_in + e.tokens_out
by_agent[e.agent] += total
by_measurement_kind[e.measurement_kind] += total
by_source_provider[e.source_provider] += total
return TokenSummary(
scope=scope,
@@ -95,11 +278,18 @@ async def get_token_summary(
event_count=len(events),
by_model=dict(by_model),
by_agent=dict(by_agent),
by_measurement_kind=dict(by_measurement_kind),
by_source_provider=dict(by_source_provider),
)
@router.get("/by-repo/", response_model=list[RepoTokenSummary])
async def get_tokens_by_repo(
measurement_kind: str | None = None,
source_provider: str | None = None,
since: datetime | None = None,
until: datetime | None = None,
include_superseded: bool = Query(True),
session: AsyncSession = Depends(get_session),
) -> list[RepoTokenSummary]:
"""Aggregate token consumption per repo, resolving via the full graph.
@@ -112,7 +302,16 @@ async def get_tokens_by_repo(
Only events that resolve to a repo are included.
"""
# Fetch all events, workstreams, repos in three queries (avoids N+1)
events_result = await session.execute(select(TokenEvent))
events_result = await session.execute(
_filter_query(
select(TokenEvent),
measurement_kind=measurement_kind,
source_provider=source_provider,
since=since,
until=until,
include_superseded=include_superseded,
)
)
events = list(events_result.scalars().all())
ws_result = await session.execute(select(Workstream))
@@ -148,14 +347,19 @@ async def get_tokens_by_repo(
"event_count": 0,
"by_model": defaultdict(int),
"by_note": defaultdict(int),
"by_measurement_kind": defaultdict(int),
"by_source_provider": defaultdict(int),
}
g = groups[rid]
g["tokens_in"] += e.tokens_in
g["tokens_out"] += e.tokens_out
g["event_count"] += 1
total = _event_total(e)
if e.model:
g["by_model"][e.model] += e.tokens_in + e.tokens_out
g["by_note"][e.note or "unknown"] += e.tokens_in + e.tokens_out
g["by_model"][e.model] += total
g["by_note"][e.note or "unknown"] += total
g["by_measurement_kind"][e.measurement_kind] += total
g["by_source_provider"][e.source_provider] += total
return [
RepoTokenSummary(
@@ -166,6 +370,188 @@ async def get_tokens_by_repo(
]
@router.get("/aggregate/", response_model=TokenAggregateSummary)
async def get_token_aggregate(
    measurement_kind: str | None = None,
    source_provider: str | None = None,
    since: datetime | None = None,
    until: datetime | None = None,
    include_superseded: bool = Query(False),
    session: AsyncSession = Depends(get_session),
) -> TokenAggregateSummary:
    """Handle ``GET /aggregate/``: totals and per-scope breakdowns of token events.

    Events are selected through ``_filter_query`` (superseded events excluded
    unless requested) and then grouped in memory by repo, workstream, task and
    model. Each group list is sorted by descending total tokens.

    Returns:
        Overall token totals, event count, earliest/latest ``created_at``,
        latest ``ingested_at``, the four grouped row lists, and token totals
        keyed by measurement kind and by source provider.
    """
    selection = _filter_query(
        select(TokenEvent),
        measurement_kind=measurement_kind,
        source_provider=source_provider,
        since=since,
        until=until,
        include_superseded=include_superseded,
    )
    events = list((await session.execute(selection)).scalars().all())

    # Load the lookup tables up front so grouping needs no per-event queries.
    workstreams = (await session.execute(select(Workstream))).scalars().all()
    ws_map: dict[uuid.UUID, Workstream] = {ws.id: ws for ws in workstreams}
    tasks = (await session.execute(select(Task))).scalars().all()
    task_map: dict[uuid.UUID, Task] = {task.id: task for task in tasks}
    repos = (await session.execute(select(ManagedRepo))).scalars().all()
    repo_map: dict[uuid.UUID, ManagedRepo] = {repo.id: repo for repo in repos}

    def resolve_repo_id(e: TokenEvent) -> uuid.UUID | None:
        # Prefer the event's own repo, else walk task -> workstream -> repo.
        if e.repo_id:
            return e.repo_id
        candidate_ws = e.workstream_id
        if not candidate_ws and e.task_id and e.task_id in task_map:
            candidate_ws = task_map[e.task_id].workstream_id
        if not candidate_ws or candidate_ws not in ws_map:
            return None
        return ws_map[candidate_ws].repo_id

    def add(groups: dict[str, dict[str, Any]], key: str | None, label: str | None, e: TokenEvent) -> None:
        # Accumulate one event into the group row for ``key``; keyless events are skipped.
        if not key:
            return
        bucket = groups.setdefault(
            key,
            {
                "scope_id": key,
                "label": label,
                "tokens_in": 0,
                "tokens_out": 0,
                "event_count": 0,
                "by_measurement_kind": defaultdict(int),
                "by_source_provider": defaultdict(int),
            },
        )
        event_tokens = _event_total(e)
        bucket["tokens_in"] += e.tokens_in
        bucket["tokens_out"] += e.tokens_out
        bucket["event_count"] += 1
        bucket["by_measurement_kind"][e.measurement_kind] += event_tokens
        bucket["by_source_provider"][e.source_provider] += event_tokens

    by_repo: dict[str, dict[str, Any]] = {}
    by_workstream: dict[str, dict[str, Any]] = {}
    by_task: dict[str, dict[str, Any]] = {}
    by_model: dict[str, dict[str, Any]] = {}
    by_measurement_kind: dict[str, int] = defaultdict(int)
    by_source_provider: dict[str, int] = defaultdict(int)
    first_event_at = None
    last_event_at = None
    last_ingested_at = None
    tokens_in = 0
    tokens_out = 0

    for e in events:
        event_tokens = _event_total(e)
        tokens_in += e.tokens_in
        tokens_out += e.tokens_out
        by_measurement_kind[e.measurement_kind] += event_tokens
        by_source_provider[e.source_provider] += event_tokens

        if first_event_at is None or e.created_at < first_event_at:
            first_event_at = e.created_at
        if last_event_at is None or e.created_at > last_event_at:
            last_event_at = e.created_at
        if last_ingested_at is None or e.ingested_at > last_ingested_at:
            last_ingested_at = e.ingested_at

        rid = resolve_repo_id(e)
        repo = repo_map.get(rid) if rid else None
        add(by_repo, str(rid) if rid else None, repo.slug if repo else None, e)

        ws_id = e.workstream_id
        if not ws_id:
            owning_task = task_map.get(e.task_id)
            ws_id = owning_task.workstream_id if owning_task is not None else None
        ws = ws_map.get(ws_id) if ws_id else None
        add(by_workstream, str(ws_id) if ws_id else None, ws.title if ws else None, e)

        task = task_map.get(e.task_id) if e.task_id else None
        add(by_task, str(e.task_id) if e.task_id else None, task.title if task else None, e)

        model_name = e.model or "unknown"
        add(by_model, model_name, model_name, e)

    def rows(groups: dict[str, dict[str, Any]]) -> list[TokenAggregateRow]:
        # Convert accumulated group dicts into response rows, largest first.
        built = [
            TokenAggregateRow(
                **{
                    name: (dict(value) if isinstance(value, defaultdict) else value)
                    for name, value in group.items()
                },
                tokens_total=group["tokens_in"] + group["tokens_out"],
            )
            for group in groups.values()
        ]
        return sorted(built, key=lambda item: -item.tokens_total)

    return TokenAggregateSummary(
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        tokens_total=tokens_in + tokens_out,
        event_count=len(events),
        first_event_at=first_event_at,
        last_event_at=last_event_at,
        last_ingested_at=last_ingested_at,
        by_repo=rows(by_repo),
        by_workstream=rows(by_workstream),
        by_task=rows(by_task),
        by_model=rows(by_model),
        by_measurement_kind=dict(by_measurement_kind),
        by_source_provider=dict(by_source_provider),
    )
@router.get("/quality/", response_model=TokenQualitySummary)
async def get_token_quality(
    since: datetime | None = None,
    until: datetime | None = None,
    session: AsyncSession = Depends(get_session),
) -> TokenQualitySummary:
    """Handle ``GET /quality/``: data-quality counters for token events.

    Considers every event whose ``created_at`` falls in ``[since, until)``
    (superseded events included) and reports event counts per measurement
    kind and source provider, fallback events, measured events lacking
    attribution or a source id, source identities that occur more than once,
    and the latest ``ingested_at`` for Codex and Claude events.
    ``last_reconciliation_at`` is always returned as ``None``.
    """
    window = _filter_query(select(TokenEvent), since=since, until=until)
    events = list((await session.execute(window)).scalars().all())

    by_measurement_kind: dict[str, int] = defaultdict(int)
    by_source_provider: dict[str, int] = defaultdict(int)
    # Occurrences per (measurement_kind, source_provider, source_id) identity.
    occurrences: dict[tuple[str, str, str], int] = defaultdict(int)
    newest_codex = None
    newest_claude = None
    fallback_events = 0
    unattributed_measured = 0
    measured_without_source = 0

    for event in events:
        kind = event.measurement_kind
        provider = event.source_provider
        by_measurement_kind[kind] += 1
        by_source_provider[provider] += 1

        if event.source_id:
            occurrences[(kind, provider, event.source_id)] += 1
        if provider == "task_fallback" or event.note == "heuristic":
            fallback_events += 1

        if kind == "measured":
            if not (event.repo_id or event.workstream_id or event.task_id):
                unattributed_measured += 1
            if not event.source_id:
                measured_without_source += 1

        if provider == "codex_session":
            if newest_codex is None or event.ingested_at > newest_codex:
                newest_codex = event.ingested_at
        elif provider == "claude_transcript":
            if newest_claude is None or event.ingested_at > newest_claude:
                newest_claude = event.ingested_at

    # Number of distinct source identities that appear more than once.
    duplicated_identities = len([seen for seen in occurrences.values() if seen > 1])

    return TokenQualitySummary(
        event_count=len(events),
        measured_event_count=by_measurement_kind.get("measured", 0),
        estimated_event_count=by_measurement_kind.get("estimated", 0),
        allocated_event_count=by_measurement_kind.get("allocated", 0),
        superseded_event_count=by_measurement_kind.get("superseded", 0),
        fallback_event_count=fallback_events,
        unattributed_measured_event_count=unattributed_measured,
        missing_provenance_event_count=measured_without_source,
        duplicate_source_count=duplicated_identities,
        last_codex_ingested_at=newest_codex,
        last_claude_ingested_at=newest_claude,
        last_reconciliation_at=None,
        by_measurement_kind=dict(by_measurement_kind),
        by_source_provider=dict(by_source_provider),
    )
@router.patch("/{event_id}", response_model=TokenEventRead)
async def patch_token_event(
event_id: uuid.UUID,
@@ -175,7 +561,26 @@ async def patch_token_event(
event = await session.get(TokenEvent, event_id)
if event is None:
raise HTTPException(status_code=404, detail="Token event not found")
for field, value in body.model_dump(exclude_none=True).items():
data = body.model_dump(exclude_none=True)
if "note" in data or "measurement_kind" in data or "source_provider" in data:
merged = {
"tokens_in": data.get("tokens_in", event.tokens_in),
"tokens_out": data.get("tokens_out", event.tokens_out),
"note": data.get("note", event.note),
"agent": data.get("agent", event.agent),
"ref_id": data.get("ref_id", event.ref_id),
"session_id": data.get("session_id", event.session_id),
"measurement_kind": data.get("measurement_kind", event.measurement_kind),
"source_provider": data.get("source_provider", event.source_provider),
"source_id": data.get("source_id", event.source_id),
}
inferred = _apply_event_defaults({k: v for k, v in merged.items() if v is not None})
data.setdefault("measurement_kind", inferred["measurement_kind"])
data.setdefault("source_provider", inferred["source_provider"])
data.setdefault("confidence", inferred["confidence"])
if inferred.get("source_id"):
data.setdefault("source_id", inferred["source_id"])
for field, value in data.items():
setattr(event, field, value)
await session.commit()
await session.refresh(event)
@@ -203,26 +608,33 @@ async def list_token_events(
model: str | None = None,
agent: str | None = None,
note: str | None = None,
measurement_kind: str | None = None,
source_provider: str | None = None,
since: datetime | None = None,
until: datetime | None = None,
include_superseded: bool = Query(True),
unattributed: bool = False,
offset: int = Query(0, ge=0),
limit: int = Query(100, le=1000),
session: AsyncSession = Depends(get_session),
) -> list[TokenEvent]:
q = select(TokenEvent)
if task_id:
q = q.where(TokenEvent.task_id == task_id)
if workstream_id:
q = q.where(TokenEvent.workstream_id == workstream_id)
if repo_id:
q = q.where(TokenEvent.repo_id == repo_id)
if ref_type:
q = q.where(TokenEvent.ref_type == ref_type)
if ref_id:
q = q.where(TokenEvent.ref_id == ref_id)
if model:
q = q.where(TokenEvent.model == model)
if agent:
q = q.where(TokenEvent.agent == agent)
if note:
q = q.where(TokenEvent.note == note)
q = q.order_by(TokenEvent.created_at.desc()).limit(limit)
q = _filter_query(
select(TokenEvent),
task_id=task_id,
workstream_id=workstream_id,
repo_id=repo_id,
ref_type=ref_type,
ref_id=ref_id,
model=model,
agent=agent,
note=note,
measurement_kind=measurement_kind,
source_provider=source_provider,
since=since,
until=until,
include_superseded=include_superseded,
unattributed=unattributed,
)
q = q.order_by(TokenEvent.created_at.desc()).offset(offset).limit(limit)
result = await session.execute(q)
return list(result.scalars().all())