Fixed and improved token tracking

This commit is contained in:
2026-05-23 13:59:05 +02:00
parent dd3279ea1a
commit c12091c2eb
29 changed files with 3549 additions and 278 deletions

View File

@@ -1,8 +1,10 @@
import uuid
from datetime import datetime
from sqlalchemy import DateTime, ForeignKey, Integer, Text, func
from sqlalchemy.dialects.postgresql import UUID
from typing import Any
from sqlalchemy import DateTime, Float, ForeignKey, Integer, Text, UniqueConstraint, func
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from api.models.base import Base, new_uuid
@@ -10,6 +12,14 @@ from api.models.base import Base, new_uuid
class TokenEvent(Base):
__tablename__ = "token_events"
__table_args__ = (
UniqueConstraint(
"measurement_kind",
"source_provider",
"source_id",
name="uq_token_events_source_identity",
),
)
id: Mapped[uuid.UUID] = mapped_column(
UUID(as_uuid=True), primary_key=True, default=new_uuid
@@ -31,6 +41,35 @@ class TokenEvent(Base):
ref_type: Mapped[str | None] = mapped_column(Text, nullable=True)
ref_id: Mapped[str | None] = mapped_column(Text, nullable=True)
note: Mapped[str | None] = mapped_column(Text, nullable=True)
measurement_kind: Mapped[str] = mapped_column(
Text, nullable=False, default="estimated", server_default="estimated", index=True
)
source_provider: Mapped[str] = mapped_column(
Text, nullable=False, default="manual", server_default="manual", index=True
)
source_id: Mapped[str | None] = mapped_column(Text, nullable=True, index=True)
source_path: Mapped[str | None] = mapped_column(Text, nullable=True)
source_created_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True, index=True
)
ingested_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False, index=True
)
parser_version: Mapped[str | None] = mapped_column(Text, nullable=True)
confidence: Mapped[float] = mapped_column(
Float, nullable=False, default=0.35, server_default="0.35"
)
cached_input_tokens: Mapped[int] = mapped_column(
Integer, nullable=False, default=0, server_default="0"
)
reasoning_output_tokens: Mapped[int] = mapped_column(
Integer, nullable=False, default=0, server_default="0"
)
raw_total_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
cost_estimated_usd: Mapped[float | None] = mapped_column(Float, nullable=True)
raw_metadata: Mapped[dict[str, Any]] = mapped_column(
JSONB, nullable=False, default=dict, server_default="{}"
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False, index=True
)

View File

@@ -75,23 +75,47 @@ async def update_task(
if task is None:
raise HTTPException(status_code=404, detail="Task not found")
previous_status = task.status.value
# Separate token fields from task fields
token_field_names = {"tokens_in", "tokens_out", "workplan_tokens_in", "workplan_tokens_out", "token_note", "model", "agent", "session_id"}
token_field_names = {
"tokens_in",
"tokens_out",
"workplan_tokens_in",
"workplan_tokens_out",
"token_note",
"model",
"agent",
"session_id",
"suppress_token_event",
}
update_data = body.model_dump(exclude_unset=True)
token_data = {k: update_data.pop(k) for k in list(update_data.keys()) if k in token_field_names}
suppress_token_event = bool(token_data.pop("suppress_token_event", False))
for field, value in update_data.items():
setattr(task, field, value)
await session.commit()
await session.refresh(task)
# Token event — three-tier logic, only when marking done
if update_data.get("status") == "done":
# Token event — three-tier logic, only for an intentional transition to done.
status_update = update_data.get("status")
new_status = status_update.value if hasattr(status_update, "value") else status_update
if (
new_status == "done"
and previous_status != "done"
and not suppress_token_event
):
if "tokens_in" in token_data and "tokens_out" in token_data:
# Tier 1: exact counts — default note "measured"; caller may override with token_note
tin = token_data["tokens_in"]
tout = token_data["tokens_out"]
tnote = token_data.get("token_note") or "measured"
measurement_kind = "measured"
source_provider = "manual"
confidence = 1.0
source_id = f"task:{task_id}:manual"
raw_metadata = {"input_source": "task_status_patch"}
elif "workplan_tokens_in" in token_data and "workplan_tokens_out" in token_data:
# Tier 2: prorate workplan total across task count
count_result = await session.execute(
@@ -101,9 +125,24 @@ async def update_task(
tin = token_data["workplan_tokens_in"] // task_count
tout = token_data["workplan_tokens_out"] // task_count
tnote = "workplan"
measurement_kind = "allocated"
source_provider = "manual"
confidence = 0.7
source_id = f"task:{task_id}:workplan-allocation"
raw_metadata = {
"allocation_method": "workplan_prorated",
"workplan_tokens_in": token_data["workplan_tokens_in"],
"workplan_tokens_out": token_data["workplan_tokens_out"],
"task_count": task_count,
}
else:
# Tier 3: heuristic fallback
tin, tout, tnote = 1000, 500, "heuristic"
measurement_kind = "estimated"
source_provider = "task_fallback"
confidence = 0.35
source_id = f"task:{task_id}:heuristic"
raw_metadata = {"estimation_method": "fixed_task_done_fallback"}
# Resolve repo_id via workstream
ws = await session.get(Workstream, task.workstream_id)
@@ -121,6 +160,12 @@ async def update_task(
ref_type="task",
ref_id=str(task_id),
note=tnote,
measurement_kind=measurement_kind,
source_provider=source_provider,
source_id=source_id,
confidence=confidence,
raw_total_tokens=tin + tout,
raw_metadata=raw_metadata,
)
session.add(event)
await session.commit()

View File

@@ -1,5 +1,7 @@
import uuid
from collections import defaultdict
from datetime import datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
@@ -10,18 +12,95 @@ from api.models.managed_repo import ManagedRepo
from api.models.task import Task
from api.models.token_event import TokenEvent
from api.models.workstream import Workstream
from api.schemas.token_event import RepoTokenSummary, TokenEventCreate, TokenEventPatch, TokenEventRead, TokenSummary
from api.schemas.token_event import (
RepoTokenSummary,
TokenAggregateRow,
TokenAggregateSummary,
TokenEventCreate,
TokenEventPatch,
TokenEventRead,
TokenQualitySummary,
TokenSummary,
)
router = APIRouter(prefix="/token-events", tags=["token-events"])
DEFAULT_CONFIDENCE = {
"measured": 1.0,
"allocated": 0.70,
"estimated": 0.35,
"superseded": 0.0,
}
@router.post("/", response_model=TokenEventRead, status_code=status.HTTP_201_CREATED)
async def create_token_event(
body: TokenEventCreate,
session: AsyncSession = Depends(get_session),
) -> TokenEvent:
data = body.model_dump()
SOURCE_PARSER_DEFAULTS = {
"codex_session": "codex-desktop-v1",
"claude_transcript": "claude-transcript-v1",
"llm_connect": "llm-connect-v1",
}
def _event_total(event: TokenEvent) -> int:
return event.tokens_in + event.tokens_out
def _infer_measurement_kind(data: dict[str, Any]) -> str:
if data.get("measurement_kind"):
return str(data["measurement_kind"])
note = data.get("note")
if note == "heuristic_superseded_by_codex_backfill":
return "superseded"
if note == "workplan":
return "allocated"
if note == "heuristic":
return "estimated"
if note == "measured" or str(note or "").startswith("backfill:codex-session"):
return "measured"
provider = data.get("source_provider")
if provider in {"codex_session", "claude_transcript", "llm_connect"}:
return "measured"
return "estimated"
def _infer_source_provider(data: dict[str, Any], measurement_kind: str) -> str:
if data.get("source_provider"):
return str(data["source_provider"])
note = data.get("note")
ref_id = str(data.get("ref_id") or "")
agent = str(data.get("agent") or "").lower()
if note == "heuristic":
return "task_fallback"
if ref_id.startswith("codex:") or str(note or "").startswith("backfill:codex-session"):
return "codex_session"
if measurement_kind == "measured" and "claude" in agent:
return "claude_transcript"
return "manual"
def _apply_event_defaults(data: dict[str, Any]) -> dict[str, Any]:
    """Fill in provenance and accounting defaults on an event payload.

    ``data`` is updated in place and also returned.  Values already present
    in ``data`` are kept; only missing keys receive defaults.
    """
    kind = _infer_measurement_kind(data)
    provider = _infer_source_provider(data, kind)
    data.update(measurement_kind=kind, source_provider=provider)

    # For parsed sources, fall back to the reference or session id as identity.
    parsed_providers = ("codex_session", "claude_transcript", "llm_connect")
    if provider in parsed_providers and not data.get("source_id"):
        fallback_id = data.get("ref_id") or data.get("session_id")
        if fallback_id:
            data["source_id"] = str(fallback_id)

    # An identified source with a known creation time inherits that timestamp.
    if data.get("source_id") and data.get("created_at") and not data.get("source_created_at"):
        data["source_created_at"] = data["created_at"]

    fallbacks = {
        "confidence": DEFAULT_CONFIDENCE.get(kind, 0.35),
        "cached_input_tokens": 0,
        "reasoning_output_tokens": 0,
        "raw_total_tokens": (data.get("tokens_in") or 0) + (data.get("tokens_out") or 0),
        "raw_metadata": {},
    }
    for key, fallback in fallbacks.items():
        data.setdefault(key, fallback)

    if provider in SOURCE_PARSER_DEFAULTS:
        data.setdefault("parser_version", SOURCE_PARSER_DEFAULTS[provider])
    return data
async def _populate_relationship_defaults(data: dict[str, Any], session: AsyncSession) -> dict[str, Any]:
# Auto-populate workstream_id from task if not provided
if data.get("task_id") and not data.get("workstream_id"):
task = await session.get(Task, data["task_id"])
@@ -33,6 +112,34 @@ async def create_token_event(
ws = await session.get(Workstream, data["workstream_id"])
if ws and ws.repo_id:
data["repo_id"] = ws.repo_id
return data
async def _find_source_event(data: dict[str, Any], session: AsyncSession) -> TokenEvent | None:
    """Look up the stored event sharing this payload's source identity.

    The identity is the (measurement_kind, source_provider, source_id)
    triple.  Payloads without a ``source_id`` never match and yield ``None``.
    """
    source_id = data.get("source_id")
    if source_id:
        identity = (
            TokenEvent.measurement_kind == data["measurement_kind"],
            TokenEvent.source_provider == data["source_provider"],
            TokenEvent.source_id == source_id,
        )
        lookup = await session.execute(select(TokenEvent).where(*identity))
        return lookup.scalar_one_or_none()
    return None
async def _create_or_upsert_event(data: dict[str, Any], session: AsyncSession) -> TokenEvent:
data = _apply_event_defaults(data)
data = await _populate_relationship_defaults(data, session)
existing = await _find_source_event(data, session)
if existing is not None:
for field, value in data.items():
setattr(existing, field, value)
await session.commit()
await session.refresh(existing)
return existing
event = TokenEvent(**data)
session.add(event)
@@ -41,6 +148,77 @@ async def create_token_event(
return event
def _filter_query(
q,
*,
task_id: uuid.UUID | None = None,
workstream_id: uuid.UUID | None = None,
repo_id: uuid.UUID | None = None,
ref_type: str | None = None,
ref_id: str | None = None,
model: str | None = None,
agent: str | None = None,
note: str | None = None,
measurement_kind: str | None = None,
source_provider: str | None = None,
since: datetime | None = None,
until: datetime | None = None,
include_superseded: bool = True,
unattributed: bool = False,
):
if task_id:
q = q.where(TokenEvent.task_id == task_id)
if workstream_id:
q = q.where(TokenEvent.workstream_id == workstream_id)
if repo_id:
q = q.where(TokenEvent.repo_id == repo_id)
if ref_type:
q = q.where(TokenEvent.ref_type == ref_type)
if ref_id:
q = q.where(TokenEvent.ref_id == ref_id)
if model:
q = q.where(TokenEvent.model == model)
if agent:
q = q.where(TokenEvent.agent == agent)
if note:
q = q.where(TokenEvent.note == note)
if measurement_kind:
q = q.where(TokenEvent.measurement_kind == measurement_kind)
if source_provider:
q = q.where(TokenEvent.source_provider == source_provider)
if since:
q = q.where(TokenEvent.created_at >= since)
if until:
q = q.where(TokenEvent.created_at < until)
if not include_superseded:
q = q.where(TokenEvent.measurement_kind != "superseded")
if unattributed:
q = q.where(
TokenEvent.repo_id.is_(None),
TokenEvent.workstream_id.is_(None),
TokenEvent.task_id.is_(None),
)
return q
@router.post("/", response_model=TokenEventRead, status_code=status.HTTP_201_CREATED)
async def create_token_event(
    body: TokenEventCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    # Fields left as None are dropped from the payload before it is handed
    # to the shared create-or-upsert helper.
    return await _create_or_upsert_event(body.model_dump(exclude_none=True), session)
@router.post("/upsert", response_model=TokenEventRead)
async def upsert_token_event(
    body: TokenEventCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    # Same payload handling as the plain POST route: None fields are dropped
    # and the shared create-or-upsert helper does the work.
    return await _create_or_upsert_event(body.model_dump(exclude_none=True), session)
@router.get("/summary/", response_model=TokenSummary)
async def get_token_summary(
scope: str = Query(..., description="task|workstream|repo|commit|release|session"),
@@ -80,11 +258,16 @@ async def get_token_summary(
by_model: dict[str, int] = defaultdict(int)
by_agent: dict[str, int] = defaultdict(int)
by_measurement_kind: dict[str, int] = defaultdict(int)
by_source_provider: dict[str, int] = defaultdict(int)
for e in events:
total = _event_total(e)
if e.model:
by_model[e.model] += e.tokens_in + e.tokens_out
by_model[e.model] += total
if e.agent:
by_agent[e.agent] += e.tokens_in + e.tokens_out
by_agent[e.agent] += total
by_measurement_kind[e.measurement_kind] += total
by_source_provider[e.source_provider] += total
return TokenSummary(
scope=scope,
@@ -95,11 +278,18 @@ async def get_token_summary(
event_count=len(events),
by_model=dict(by_model),
by_agent=dict(by_agent),
by_measurement_kind=dict(by_measurement_kind),
by_source_provider=dict(by_source_provider),
)
@router.get("/by-repo/", response_model=list[RepoTokenSummary])
async def get_tokens_by_repo(
measurement_kind: str | None = None,
source_provider: str | None = None,
since: datetime | None = None,
until: datetime | None = None,
include_superseded: bool = Query(True),
session: AsyncSession = Depends(get_session),
) -> list[RepoTokenSummary]:
"""Aggregate token consumption per repo, resolving via the full graph.
@@ -112,7 +302,16 @@ async def get_tokens_by_repo(
Only events that resolve to a repo are included.
"""
# Fetch all events, workstreams, repos in three queries (avoids N+1)
events_result = await session.execute(select(TokenEvent))
events_result = await session.execute(
_filter_query(
select(TokenEvent),
measurement_kind=measurement_kind,
source_provider=source_provider,
since=since,
until=until,
include_superseded=include_superseded,
)
)
events = list(events_result.scalars().all())
ws_result = await session.execute(select(Workstream))
@@ -148,14 +347,19 @@ async def get_tokens_by_repo(
"event_count": 0,
"by_model": defaultdict(int),
"by_note": defaultdict(int),
"by_measurement_kind": defaultdict(int),
"by_source_provider": defaultdict(int),
}
g = groups[rid]
g["tokens_in"] += e.tokens_in
g["tokens_out"] += e.tokens_out
g["event_count"] += 1
total = _event_total(e)
if e.model:
g["by_model"][e.model] += e.tokens_in + e.tokens_out
g["by_note"][e.note or "unknown"] += e.tokens_in + e.tokens_out
g["by_model"][e.model] += total
g["by_note"][e.note or "unknown"] += total
g["by_measurement_kind"][e.measurement_kind] += total
g["by_source_provider"][e.source_provider] += total
return [
RepoTokenSummary(
@@ -166,6 +370,188 @@ async def get_tokens_by_repo(
]
@router.get("/aggregate/", response_model=TokenAggregateSummary)
async def get_token_aggregate(
    measurement_kind: str | None = None,
    source_provider: str | None = None,
    since: datetime | None = None,
    until: datetime | None = None,
    include_superseded: bool = Query(False),
    session: AsyncSession = Depends(get_session),
) -> TokenAggregateSummary:
    # Summarise the filtered token events as overall totals plus per-repo,
    # per-workstream, per-task and per-model breakdowns.
    statement = _filter_query(
        select(TokenEvent),
        measurement_kind=measurement_kind,
        source_provider=source_provider,
        since=since,
        until=until,
        include_superseded=include_superseded,
    )
    events = list((await session.execute(statement)).scalars().all())

    async def load_by_id(model: Any) -> dict[uuid.UUID, Any]:
        # Fetch every row of ``model`` and index it by primary key.
        fetched = await session.execute(select(model))
        return {record.id: record for record in fetched.scalars().all()}

    workstreams = await load_by_id(Workstream)
    tasks = await load_by_id(Task)
    repos = await load_by_id(ManagedRepo)

    def owning_workstream_id(event: TokenEvent) -> uuid.UUID | None:
        # Prefer the event's own workstream, else the one of its task.
        if event.workstream_id:
            return event.workstream_id
        parent_task = tasks.get(event.task_id)
        return parent_task.workstream_id if parent_task is not None else None

    def owning_repo_id(event: TokenEvent) -> uuid.UUID | None:
        # Prefer the event's own repo, else the repo of its workstream.
        if event.repo_id:
            return event.repo_id
        ws_id = owning_workstream_id(event)
        if ws_id and ws_id in workstreams:
            return workstreams[ws_id].repo_id
        return None

    def tally(
        buckets: dict[str, dict[str, Any]],
        key: str | None,
        label: str | None,
        event: TokenEvent,
    ) -> None:
        # Add ``event`` to the bucket named ``key``; events without a key
        # do not belong to this breakdown and are ignored.
        if not key:
            return
        bucket = buckets.get(key)
        if bucket is None:
            bucket = buckets[key] = {
                "scope_id": key,
                "label": label,
                "tokens_in": 0,
                "tokens_out": 0,
                "event_count": 0,
                "by_measurement_kind": defaultdict(int),
                "by_source_provider": defaultdict(int),
            }
        event_total = _event_total(event)
        bucket["tokens_in"] += event.tokens_in
        bucket["tokens_out"] += event.tokens_out
        bucket["event_count"] += 1
        bucket["by_measurement_kind"][event.measurement_kind] += event_total
        bucket["by_source_provider"][event.source_provider] += event_total

    per_repo: dict[str, dict[str, Any]] = {}
    per_workstream: dict[str, dict[str, Any]] = {}
    per_task: dict[str, dict[str, Any]] = {}
    per_model: dict[str, dict[str, Any]] = {}
    kind_totals: dict[str, int] = defaultdict(int)
    provider_totals: dict[str, int] = defaultdict(int)
    first_event_at = None
    last_event_at = None
    last_ingested_at = None
    total_in = 0
    total_out = 0

    for event in events:
        event_total = _event_total(event)
        total_in += event.tokens_in
        total_out += event.tokens_out
        kind_totals[event.measurement_kind] += event_total
        provider_totals[event.source_provider] += event_total

        if first_event_at is None or event.created_at < first_event_at:
            first_event_at = event.created_at
        if last_event_at is None or event.created_at > last_event_at:
            last_event_at = event.created_at
        if last_ingested_at is None or event.ingested_at > last_ingested_at:
            last_ingested_at = event.ingested_at

        repo_id = owning_repo_id(event)
        repo = repos.get(repo_id) if repo_id else None
        tally(per_repo, str(repo_id) if repo_id else None, repo.slug if repo else None, event)

        ws_id = owning_workstream_id(event)
        workstream = workstreams.get(ws_id) if ws_id else None
        tally(
            per_workstream,
            str(ws_id) if ws_id else None,
            workstream.title if workstream else None,
            event,
        )

        task = tasks.get(event.task_id) if event.task_id else None
        tally(per_task, str(event.task_id) if event.task_id else None, task.title if task else None, event)

        model_name = event.model or "unknown"
        tally(per_model, model_name, model_name, event)

    def as_rows(buckets: dict[str, dict[str, Any]]) -> list[TokenAggregateRow]:
        # Convert buckets to response rows, largest token total first.
        built = [
            TokenAggregateRow(
                **{
                    name: (dict(value) if isinstance(value, defaultdict) else value)
                    for name, value in bucket.items()
                },
                tokens_total=bucket["tokens_in"] + bucket["tokens_out"],
            )
            for bucket in buckets.values()
        ]
        built.sort(key=lambda item: item.tokens_total, reverse=True)
        return built

    return TokenAggregateSummary(
        tokens_in=total_in,
        tokens_out=total_out,
        tokens_total=total_in + total_out,
        event_count=len(events),
        first_event_at=first_event_at,
        last_event_at=last_event_at,
        last_ingested_at=last_ingested_at,
        by_repo=as_rows(per_repo),
        by_workstream=as_rows(per_workstream),
        by_task=as_rows(per_task),
        by_model=as_rows(per_model),
        by_measurement_kind=dict(kind_totals),
        by_source_provider=dict(provider_totals),
    )
@router.get("/quality/", response_model=TokenQualitySummary)
async def get_token_quality(
    since: datetime | None = None,
    until: datetime | None = None,
    session: AsyncSession = Depends(get_session),
) -> TokenQualitySummary:
    # Report data-quality counters over the token events in the time window.
    # All per-kind and per-provider figures here are event counts, not tokens.
    statement = _filter_query(select(TokenEvent), since=since, until=until)
    events = list((await session.execute(statement)).scalars().all())

    kind_counts: dict[str, int] = defaultdict(int)
    provider_counts: dict[str, int] = defaultdict(int)
    identity_counts: dict[tuple[str, str, str], int] = defaultdict(int)
    # Newest ingestion timestamp seen for each tracked provider.
    newest_ingest: dict[str, datetime | None] = {
        "codex_session": None,
        "claude_transcript": None,
    }
    fallbacks = 0
    unattributed_measured = 0
    missing_provenance = 0

    for event in events:
        kind_counts[event.measurement_kind] += 1
        provider_counts[event.source_provider] += 1
        if event.source_id:
            identity = (event.measurement_kind, event.source_provider, event.source_id)
            identity_counts[identity] += 1
        if event.note == "heuristic" or event.source_provider == "task_fallback":
            fallbacks += 1
        if event.measurement_kind == "measured":
            if not (event.repo_id or event.workstream_id or event.task_id):
                unattributed_measured += 1
            if not event.source_id:
                missing_provenance += 1
        if event.source_provider in newest_ingest:
            previous = newest_ingest[event.source_provider]
            if previous is None or event.ingested_at > previous:
                newest_ingest[event.source_provider] = event.ingested_at

    # A source identity is a duplicate when more than one event carries it.
    duplicated_identities = [seen for seen in identity_counts.values() if seen > 1]

    return TokenQualitySummary(
        event_count=len(events),
        measured_event_count=kind_counts.get("measured", 0),
        estimated_event_count=kind_counts.get("estimated", 0),
        allocated_event_count=kind_counts.get("allocated", 0),
        superseded_event_count=kind_counts.get("superseded", 0),
        fallback_event_count=fallbacks,
        unattributed_measured_event_count=unattributed_measured,
        missing_provenance_event_count=missing_provenance,
        duplicate_source_count=len(duplicated_identities),
        last_codex_ingested_at=newest_ingest["codex_session"],
        last_claude_ingested_at=newest_ingest["claude_transcript"],
        last_reconciliation_at=None,
        by_measurement_kind=dict(kind_counts),
        by_source_provider=dict(provider_counts),
    )
@router.patch("/{event_id}", response_model=TokenEventRead)
async def patch_token_event(
event_id: uuid.UUID,
@@ -175,7 +561,26 @@ async def patch_token_event(
event = await session.get(TokenEvent, event_id)
if event is None:
raise HTTPException(status_code=404, detail="Token event not found")
for field, value in body.model_dump(exclude_none=True).items():
data = body.model_dump(exclude_none=True)
if "note" in data or "measurement_kind" in data or "source_provider" in data:
merged = {
"tokens_in": data.get("tokens_in", event.tokens_in),
"tokens_out": data.get("tokens_out", event.tokens_out),
"note": data.get("note", event.note),
"agent": data.get("agent", event.agent),
"ref_id": data.get("ref_id", event.ref_id),
"session_id": data.get("session_id", event.session_id),
"measurement_kind": data.get("measurement_kind", event.measurement_kind),
"source_provider": data.get("source_provider", event.source_provider),
"source_id": data.get("source_id", event.source_id),
}
inferred = _apply_event_defaults({k: v for k, v in merged.items() if v is not None})
data.setdefault("measurement_kind", inferred["measurement_kind"])
data.setdefault("source_provider", inferred["source_provider"])
data.setdefault("confidence", inferred["confidence"])
if inferred.get("source_id"):
data.setdefault("source_id", inferred["source_id"])
for field, value in data.items():
setattr(event, field, value)
await session.commit()
await session.refresh(event)
@@ -203,26 +608,33 @@ async def list_token_events(
model: str | None = None,
agent: str | None = None,
note: str | None = None,
measurement_kind: str | None = None,
source_provider: str | None = None,
since: datetime | None = None,
until: datetime | None = None,
include_superseded: bool = Query(True),
unattributed: bool = False,
offset: int = Query(0, ge=0),
limit: int = Query(100, le=1000),
session: AsyncSession = Depends(get_session),
) -> list[TokenEvent]:
q = select(TokenEvent)
if task_id:
q = q.where(TokenEvent.task_id == task_id)
if workstream_id:
q = q.where(TokenEvent.workstream_id == workstream_id)
if repo_id:
q = q.where(TokenEvent.repo_id == repo_id)
if ref_type:
q = q.where(TokenEvent.ref_type == ref_type)
if ref_id:
q = q.where(TokenEvent.ref_id == ref_id)
if model:
q = q.where(TokenEvent.model == model)
if agent:
q = q.where(TokenEvent.agent == agent)
if note:
q = q.where(TokenEvent.note == note)
q = q.order_by(TokenEvent.created_at.desc()).limit(limit)
q = _filter_query(
select(TokenEvent),
task_id=task_id,
workstream_id=workstream_id,
repo_id=repo_id,
ref_type=ref_type,
ref_id=ref_id,
model=model,
agent=agent,
note=note,
measurement_kind=measurement_kind,
source_provider=source_provider,
since=since,
until=until,
include_superseded=include_superseded,
unattributed=unattributed,
)
q = q.order_by(TokenEvent.created_at.desc()).offset(offset).limit(limit)
result = await session.execute(q)
return list(result.scalars().all())

View File

@@ -43,6 +43,7 @@ class TaskUpdate(BaseModel):
# 2. workplan_tokens_in + workplan_tokens_out → prorated across task count (note="workplan")
# 3. neither provided, status=done → heuristic 1000/500 (note="heuristic")
# token_note overrides the auto-assigned note for Tier 1 only (e.g. "userbased")
# suppress_token_event lets file/cache sync update status without recording usage.
tokens_in: int | None = None
tokens_out: int | None = None
workplan_tokens_in: int | None = None
@@ -51,6 +52,7 @@ class TaskUpdate(BaseModel):
model: str | None = None
agent: str | None = None
session_id: str | None = None
suppress_token_event: bool | None = None
@model_validator(mode="after")
def blocking_reason_required_when_blocked(self) -> Self:

View File

@@ -1,7 +1,8 @@
import uuid
from datetime import datetime
from typing import Any
from pydantic import BaseModel, ConfigDict, computed_field
from pydantic import BaseModel, ConfigDict, Field, computed_field
class TokenEventCreate(BaseModel):
@@ -16,6 +17,19 @@ class TokenEventCreate(BaseModel):
ref_type: str | None = None
ref_id: str | None = None
note: str | None = None
created_at: datetime | None = None
measurement_kind: str | None = None
source_provider: str | None = None
source_id: str | None = None
source_path: str | None = None
source_created_at: datetime | None = None
parser_version: str | None = None
confidence: float | None = None
cached_input_tokens: int | None = None
reasoning_output_tokens: int | None = None
raw_total_tokens: int | None = None
cost_estimated_usd: float | None = None
raw_metadata: dict[str, Any] | None = None
class TokenEventRead(BaseModel):
@@ -33,6 +47,19 @@ class TokenEventRead(BaseModel):
ref_type: str | None = None
ref_id: str | None = None
note: str | None = None
measurement_kind: str
source_provider: str
source_id: str | None = None
source_path: str | None = None
source_created_at: datetime | None = None
ingested_at: datetime
parser_version: str | None = None
confidence: float
cached_input_tokens: int
reasoning_output_tokens: int
raw_total_tokens: int | None = None
cost_estimated_usd: float | None = None
raw_metadata: dict[str, Any] = Field(default_factory=dict)
created_at: datetime
@computed_field
@@ -40,6 +67,11 @@ class TokenEventRead(BaseModel):
def tokens_total(self) -> int:
return self.tokens_in + self.tokens_out
@computed_field
@property
def token_evidence_total(self) -> int:
return (self.raw_total_tokens or self.tokens_in + self.tokens_out)
class TokenSummary(BaseModel):
scope: str
@@ -50,14 +82,36 @@ class TokenSummary(BaseModel):
event_count: int
by_model: dict[str, int]
by_agent: dict[str, int]
by_measurement_kind: dict[str, int] = Field(default_factory=dict)
by_source_provider: dict[str, int] = Field(default_factory=dict)
class TokenEventPatch(BaseModel):
    # Partial-update payload for a token event: every field is optional and
    # defaults to None.

    # Token counts.
    tokens_in: int | None = None
    tokens_out: int | None = None

    # Attribution links.
    task_id: uuid.UUID | None = None
    workstream_id: uuid.UUID | None = None
    repo_id: uuid.UUID | None = None
    session_id: str | None = None

    # Descriptive fields.
    note: str | None = None
    model: str | None = None
    agent: str | None = None
    ref_type: str | None = None
    ref_id: str | None = None
    created_at: datetime | None = None

    # Provenance of the measurement.
    measurement_kind: str | None = None
    source_provider: str | None = None
    source_id: str | None = None
    source_path: str | None = None
    source_created_at: datetime | None = None
    ingested_at: datetime | None = None
    parser_version: str | None = None
    confidence: float | None = None

    # Additional usage detail reported by the source.
    cached_input_tokens: int | None = None
    reasoning_output_tokens: int | None = None
    raw_total_tokens: int | None = None
    cost_estimated_usd: float | None = None
    raw_metadata: dict[str, Any] | None = None
class RepoTokenSummary(BaseModel):
@@ -69,3 +123,49 @@ class RepoTokenSummary(BaseModel):
event_count: int
by_model: dict[str, int]
by_note: dict[str, int]
by_measurement_kind: dict[str, int] = Field(default_factory=dict)
by_source_provider: dict[str, int] = Field(default_factory=dict)
class TokenAggregateRow(BaseModel):
    # One row of an aggregate breakdown (a single repo, workstream, task or
    # model), identified by ``scope_id`` with an optional display label.
    scope_id: str
    label: str | None = None

    # Token totals and number of events folded into this row.
    tokens_in: int
    tokens_out: int
    tokens_total: int
    event_count: int

    # Token totals of this row split by measurement kind / source provider.
    by_measurement_kind: dict[str, int] = Field(default_factory=dict)
    by_source_provider: dict[str, int] = Field(default_factory=dict)
class TokenAggregateSummary(BaseModel):
    # Response of the aggregate endpoint: overall totals, the time span of
    # the included events, and one breakdown list per grouping dimension.
    tokens_in: int
    tokens_out: int
    tokens_total: int
    event_count: int

    # Earliest / latest event creation time and latest ingestion time;
    # None when no event matched.
    first_event_at: datetime | None = None
    last_event_at: datetime | None = None
    last_ingested_at: datetime | None = None

    # Breakdown rows per grouping dimension.
    by_repo: list[TokenAggregateRow] = Field(default_factory=list)
    by_workstream: list[TokenAggregateRow] = Field(default_factory=list)
    by_task: list[TokenAggregateRow] = Field(default_factory=list)
    by_model: list[TokenAggregateRow] = Field(default_factory=list)

    # Overall token totals split by measurement kind / source provider.
    by_measurement_kind: dict[str, int] = Field(default_factory=dict)
    by_source_provider: dict[str, int] = Field(default_factory=dict)
class TokenQualitySummary(BaseModel):
    # Response of the quality endpoint: counters describing how trustworthy
    # the recorded token events are.

    # Event counts overall and per measurement kind.
    event_count: int
    measured_event_count: int
    estimated_event_count: int
    allocated_event_count: int
    superseded_event_count: int

    # Counters for events that need attention.
    fallback_event_count: int
    unattributed_measured_event_count: int
    missing_provenance_event_count: int
    duplicate_source_count: int

    # Most recent activity timestamps; None when nothing was recorded.
    last_codex_ingested_at: datetime | None = None
    last_claude_ingested_at: datetime | None = None
    last_reconciliation_at: datetime | None = None

    # Event counts split by measurement kind / source provider.
    by_measurement_kind: dict[str, int] = Field(default_factory=dict)
    by_source_provider: dict[str, int] = Field(default_factory=dict)

View File

@@ -0,0 +1,16 @@
"""Token source adapters for measured agent usage."""
from api.services.token_sources.base import TokenSourceRecord, parse_iso
from api.services.token_sources.codex import collect_codex_sessions, iter_codex_session_files, parse_codex_session
from api.services.token_sources.claude import collect_claude_transcripts, iter_claude_transcript_files, parse_claude_transcript
__all__ = [
"TokenSourceRecord",
"parse_iso",
"collect_codex_sessions",
"iter_codex_session_files",
"parse_codex_session",
"collect_claude_transcripts",
"iter_claude_transcript_files",
"parse_claude_transcript",
]

View File

@@ -0,0 +1,171 @@
from __future__ import annotations
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import Any
@dataclass(frozen=True)
class RepoRef:
    """Immutable description of a managed repo used for matching."""

    # Identity of the repo.
    repo_id: str
    slug: str
    # Optional hints a working directory can be matched against.
    local_path: str | None = None
    host_paths: dict[str, Any] | None = None
    remote_url: str | None = None
    git_fingerprint: str | None = None
@dataclass(frozen=True)
class RepoMatch:
    """Immutable result of resolving a working directory to a repo."""

    # The repo that was matched.
    repo_id: str
    slug: str
    # Name of the matching strategy and the confidence assigned to it.
    method: str
    confidence: float
def normalise_cwd(raw: str | None) -> str | None:
    """Convert a working-directory string to a forward-slash Linux-style path.

    Backslashes become forward slashes.  A leading WSL UNC share for the
    ``Ubuntu-24.04`` distro is removed, and a Windows drive path such as
    ``C:/x`` becomes ``/mnt/c/x``.  Any other value is returned with only the
    slash conversion applied.

    Args:
        raw: The path as reported by the source; may be ``None`` or empty.

    Returns:
        The normalised path, or ``None`` when ``raw`` is ``None`` or empty.
    """
    if not raw:
        return None
    value = raw.replace("\\", "/")

    wsl_roots = (
        "//wsl.localhost/Ubuntu-24.04",
        "//wsl$/Ubuntu-24.04",
    )
    for root in wsl_roots:
        if value == root:
            return "/"
        # Require a path separator after the share name so that a different
        # distro whose name merely starts with it (e.g. "Ubuntu-24.04-old")
        # is not truncated into a bogus relative path.
        if value.startswith(root + "/"):
            return value[len(root):]

    # Only a real drive letter marks a Windows path; "1:/x" is left alone.
    drive = value[0]
    if len(value) >= 3 and value[1:3] == ":/" and drive.isascii() and drive.isalpha():
        return f"/mnt/{drive.lower()}{value[2:]}"
    return value
def normalise_remote_url(raw: str | None) -> str | None:
    """Return a canonical, lower-case form of a git remote URL for comparison.

    Surrounding whitespace, trailing slashes and a trailing ``.git`` suffix
    are removed, and scp-style ``git@host:path`` remotes are rewritten as
    ``ssh://host/path``.

    Args:
        raw: The remote URL as configured; may be ``None`` or empty.

    Returns:
        The canonical URL, or ``None`` when ``raw`` is ``None`` or empty.
    """
    if not raw:
        return None
    # Lower-case and drop trailing slashes first so that "Repo.GIT" and
    # "repo.git/" lose their suffix just like "repo.git" does.
    value = raw.strip().lower().rstrip("/")
    if value.endswith(".git"):
        value = value[:-4]
    if value.startswith("git@") and ":" in value:
        host, path = value[4:].split(":", 1)
        value = f"ssh://{host}/{path}"
    return value.rstrip("/")
def repo_refs_from_api(repos: list[dict[str, Any]]) -> list[RepoRef]:
    """Build ``RepoRef`` objects from repo dictionaries.

    Entries lacking a truthy ``id`` or ``slug`` are skipped.  A ``host_paths``
    value that is not a dict is replaced by an empty dict.
    """

    def to_ref(entry: dict[str, Any]) -> RepoRef:
        host_paths = entry.get("host_paths")
        return RepoRef(
            repo_id=str(entry.get("id")),
            slug=str(entry.get("slug")),
            local_path=entry.get("local_path"),
            host_paths=host_paths if isinstance(host_paths, dict) else {},
            remote_url=entry.get("remote_url"),
            git_fingerprint=entry.get("git_fingerprint"),
        )

    return [to_ref(entry) for entry in repos if entry.get("id") and entry.get("slug")]
def _git(cwd: str, *args: str) -> str | None:
    """Run ``git`` with *args* in *cwd* and return the first line of its output.

    Args:
        cwd: Directory in which the command is executed.
        *args: Arguments passed to ``git``.

    Returns:
        The first line of the stripped stdout, or ``None`` when the command
        cannot be started, exceeds the 5-second timeout, exits non-zero, or
        prints nothing.
    """
    command = ["git", *args]
    try:
        completed = subprocess.run(
            command,
            cwd=cwd,
            check=False,
            capture_output=True,
            text=True,
            timeout=5,
        )
    except (OSError, subprocess.SubprocessError):
        # Missing git binary, unusable cwd, or timeout.
        return None

    if completed.returncode != 0:
        return None

    lines = completed.stdout.strip().splitlines()
    if not lines:
        return None
    return lines[0]
def git_fingerprint_for_path(cwd: str | None) -> str | None:
    """Return a fingerprint for the git repository containing *cwd*.

    The fingerprint is the first line printed by
    ``git rev-list --max-parents=0 HEAD``, run from the repository top level.

    Args:
        cwd: Working directory; it is normalised with ``normalise_cwd`` first.

    Returns:
        The fingerprint, or ``None`` when the path is missing, does not exist,
        is not inside a git work tree, or git fails.
    """
    location = normalise_cwd(cwd)
    if location and Path(location).exists():
        toplevel = _git(location, "rev-parse", "--show-toplevel")
        if toplevel:
            return _git(toplevel, "rev-list", "--max-parents=0", "HEAD")
    return None
def git_remote_for_path(cwd: str | None) -> str | None:
    """Return the ``origin`` remote URL of the git repository containing *cwd*.

    Args:
        cwd: Working directory; it is normalised with ``normalise_cwd`` first.

    Returns:
        The output of ``git remote get-url origin`` run from the repository
        top level, or ``None`` when the path is missing, does not exist, is
        not inside a git work tree, or git fails.
    """
    location = normalise_cwd(cwd)
    if location and Path(location).exists():
        toplevel = _git(location, "rev-parse", "--show-toplevel")
        if toplevel:
            return _git(toplevel, "remote", "get-url", "origin")
    return None
def _repo_paths(repo: RepoRef) -> list[str]:
    """List the normalised filesystem paths known for *repo*.

    Candidates are ``repo.local_path`` followed by the truthy values of
    ``repo.host_paths``. Empty values and the ``"(unknown)"`` placeholder are
    skipped; the rest are passed through ``normalise_cwd`` and stripped of
    trailing slashes.
    """
    candidates = [repo.local_path]
    candidates += [str(value) for value in (repo.host_paths or {}).values() if value]

    normalised = []
    for candidate in candidates:
        if not candidate or candidate == "(unknown)":
            continue
        resolved = normalise_cwd(str(candidate))
        if resolved:
            normalised.append(resolved.rstrip("/"))
    return normalised
def resolve_repo(cwd: str | None, repos: list[RepoRef]) -> RepoMatch | None:
    """Work out which known repository a working directory belongs to.

    Strategies are tried in this order and the first unambiguous hit wins:

    1. ``git_fingerprint`` (0.98): exactly one repo shares the fingerprint.
    2. ``git_fingerprint_remote`` (0.99): several repos share the fingerprint
       but exactly one of them also has the same normalised remote URL.
    3. ``remote_url`` (0.90): exactly one repo has the same normalised remote.
    4. ``path_exact_slug`` (0.85): a repo path equals the directory and the
       repo slug equals the directory name.
    5. ``path_exact`` (0.80): a repo path equals the directory.
    6. ``path_prefix`` (0.75): the directory lies below a repo path; the
       longest such path wins.

    Args:
        cwd: Working directory reported by the token source; may be ``None``.
        repos: Candidate repositories.

    Returns:
        A ``RepoMatch`` naming the repository and the strategy used, or
        ``None`` when nothing matches.
    """
    path = normalise_cwd(cwd)
    fingerprint = git_fingerprint_for_path(path)
    remote = normalise_remote_url(git_remote_for_path(path))

    if fingerprint:
        by_fingerprint = [repo for repo in repos if repo.git_fingerprint == fingerprint]
        if len(by_fingerprint) == 1:
            repo = by_fingerprint[0]
            return RepoMatch(repo.repo_id, repo.slug, "git_fingerprint", 0.98)
        if remote:
            # Use the remote to break a tie between repos sharing a fingerprint.
            narrowed = [
                repo for repo in by_fingerprint
                if normalise_remote_url(repo.remote_url) == remote
            ]
            if len(narrowed) == 1:
                repo = narrowed[0]
                return RepoMatch(repo.repo_id, repo.slug, "git_fingerprint_remote", 0.99)

    if remote:
        by_remote = [repo for repo in repos if normalise_remote_url(repo.remote_url) == remote]
        if len(by_remote) == 1:
            repo = by_remote[0]
            return RepoMatch(repo.repo_id, repo.slug, "remote_url", 0.90)

    if not path:
        return None

    # Repo paths come back without a trailing slash, so drop it here too;
    # otherwise "/repo/" could only ever be a prefix match, never an exact one.
    path = path.rstrip("/") or "/"

    path_matches: list[tuple[str, RepoRef]] = [
        (repo_path, repo)
        for repo in repos
        for repo_path in _repo_paths(repo)
        if path == repo_path or path.startswith(f"{repo_path}/")
    ]
    if not path_matches:
        return None

    exact = [repo for repo_path, repo in path_matches if repo_path == path]
    if exact:
        basename = Path(path).name
        for repo in exact:
            if repo.slug == basename:
                return RepoMatch(repo.repo_id, repo.slug, "path_exact_slug", 0.85)
        repo = exact[0]
        return RepoMatch(repo.repo_id, repo.slug, "path_exact", 0.80)

    # Longest containing path wins; max() keeps the first of equally long ones.
    _, repo = max(path_matches, key=lambda item: len(item[0]))
    return RepoMatch(repo.repo_id, repo.slug, "path_prefix", 0.75)

View File

@@ -0,0 +1,71 @@
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
def parse_iso(value: str) -> datetime:
    """Parse an ISO-8601 date or timestamp into a timezone-aware UTC datetime.

    Accepted shapes:

    * a timestamp with a ``T`` or space separator, with or without an offset;
    * a trailing ``Z``/``z`` as shorthand for ``+00:00``;
    * a bare date, which is read as midnight UTC.

    Values without an offset are assumed to be UTC.

    Args:
        value: The text to parse; surrounding whitespace is ignored.

    Returns:
        The parsed instant converted to UTC.

    Raises:
        ValueError: If the text is not a valid ISO date or timestamp.
    """
    raw = value.strip()
    if raw[-1:] in ("Z", "z"):
        raw = f"{raw[:-1]}+00:00"

    # A date-only value has no date/time separator. Check for a space as well
    # as "T" so "2024-01-02 03:04:05" is not mistaken for a bare date.
    if not any(separator in raw for separator in ("T", "t", " ")):
        raw = f"{raw}T00:00:00+00:00"

    parsed = datetime.fromisoformat(raw)
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed.astimezone(timezone.utc)
@dataclass
class TokenSourceRecord:
    """Token usage aggregated from one source item (e.g. one session file)."""

    # Where the measurement came from.
    source_provider: str
    source_id: str
    source_path: Path
    source_created_at: datetime | None

    # Context of the session that produced the usage.
    session_id: str | None = None
    cwd: str | None = None
    model: str | None = None
    agent: str | None = None

    # Token counters.
    tokens_in: int = 0
    tokens_out: int = 0
    cached_input_tokens: int = 0
    reasoning_output_tokens: int = 0
    raw_total_tokens: int | None = None

    # Parser provenance.
    parser_version: str | None = None
    confidence: float = 1.0
    raw_metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def tokens_total(self) -> int:
        """Sum of input and output tokens."""
        return self.tokens_in + self.tokens_out

    def to_token_event_payload(self, repo_id: str | None = None) -> dict[str, Any]:
        """Build the token-event dictionary for this record.

        Args:
            repo_id: Repository to attribute the usage to, if known.

        Returns:
            A dict marked ``measurement_kind="measured"``. ``raw_total_tokens``
            falls back to input plus output when it was not recorded, and
            both timestamp fields carry the ISO form of ``source_created_at``
            (or ``None``).
        """
        if self.source_created_at:
            created_iso = self.source_created_at.isoformat()
        else:
            created_iso = None

        if self.raw_total_tokens is None:
            total = self.tokens_in + self.tokens_out
        else:
            total = self.raw_total_tokens

        return dict(
            tokens_in=self.tokens_in,
            tokens_out=self.tokens_out,
            repo_id=repo_id,
            session_id=self.session_id,
            model=self.model,
            agent=self.agent,
            ref_type="session",
            ref_id=self.source_id,
            note=f"measured:{self.source_provider}",
            created_at=created_iso,
            measurement_kind="measured",
            source_provider=self.source_provider,
            source_id=self.source_id,
            source_path=str(self.source_path),
            source_created_at=created_iso,
            parser_version=self.parser_version,
            confidence=self.confidence,
            cached_input_tokens=self.cached_input_tokens,
            reasoning_output_tokens=self.reasoning_output_tokens,
            raw_total_tokens=total,
            raw_metadata=self.raw_metadata,
        )

View File

@@ -0,0 +1,120 @@
from __future__ import annotations
import json
from datetime import datetime
from pathlib import Path
from typing import Any
from api.services.token_sources.base import TokenSourceRecord, parse_iso
# Recorded as `parser_version` on the records built by parse_claude_transcript.
PARSER_VERSION = "claude-transcript-v1"
def iter_claude_transcript_files(claude_home: Path) -> list[Path]:
    """List every ``*.jsonl`` file below ``<claude_home>/projects``, sorted.

    Args:
        claude_home: Directory expected to contain a ``projects`` folder.

    Returns:
        The sorted paths, searched recursively, or an empty list when the
        ``projects`` directory does not exist.
    """
    projects_dir = claude_home / "projects"
    return sorted(projects_dir.rglob("*.jsonl")) if projects_dir.is_dir() else []
def _usage_from_entry(entry: dict[str, Any]) -> dict[str, Any]:
    """Return the usage dict of a transcript entry.

    ``entry["message"]["usage"]`` is preferred; ``entry["usage"]`` is the
    fallback. An empty dict is returned when neither is a dict.
    """
    message = entry.get("message")
    nested = message.get("usage") if isinstance(message, dict) else None
    if isinstance(nested, dict):
        return nested

    top_level = entry.get("usage")
    if isinstance(top_level, dict):
        return top_level
    return {}
def parse_claude_transcript(path: Path, since: datetime) -> TokenSourceRecord | None:
    """Aggregate the token usage recorded in one Claude transcript file.

    The file is read as JSON Lines. Session id, working directory and model
    are taken from the latest entry that carries them (the session id
    defaults to the file stem). Usage entries whose timestamp is earlier than
    ``since`` are skipped; entries without a usable timestamp are counted.

    Parsing is best-effort: lines that are not valid JSON, or are valid JSON
    but not an object, are counted in ``raw_metadata["malformed_lines"]`` and
    skipped, and an unparseable timestamp is treated as missing.

    Args:
        path: Transcript file to read.
        since: Lower bound for usage timestamps; compared against the aware
            UTC datetimes returned by ``parse_iso``.

    Returns:
        A ``TokenSourceRecord`` for the transcript, or ``None`` when the file
        cannot be opened or holds no non-zero usage in range.
    """
    session_id = path.stem
    cwd: str | None = None
    model: str | None = None
    first_at: datetime | None = None
    last_at: datetime | None = None
    tokens_in = tokens_out = 0
    cached_input_tokens = 0
    raw_total_tokens = 0
    usage_records = 0
    malformed_lines = 0

    try:
        handle = path.open("r", encoding="utf-8", errors="ignore")
    except OSError:
        return None

    with handle:
        for line in handle:
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                malformed_lines += 1
                continue
            if not isinstance(entry, dict):
                # Valid JSON but not an object (list, string, number, null):
                # there is nothing to read from it.
                malformed_lines += 1
                continue

            ts = entry.get("timestamp") or entry.get("created_at")
            parsed_ts: datetime | None = None
            if isinstance(ts, str):
                try:
                    parsed_ts = parse_iso(ts)
                except ValueError:
                    # One bad timestamp must not abort the whole transcript.
                    parsed_ts = None
            if parsed_ts:
                first_at = first_at or parsed_ts
                last_at = parsed_ts

            session_id = str(
                entry.get("session_id") or entry.get("conversation_id") or session_id
            )
            cwd = entry.get("cwd") or entry.get("project_cwd") or cwd
            model = entry.get("model") or model
            message = entry.get("message")
            if isinstance(message, dict):
                model = message.get("model") or model

            usage = _usage_from_entry(entry)
            if not usage:
                continue
            if parsed_ts is not None and parsed_ts < since:
                continue

            input_tokens = int(usage.get("input_tokens") or 0)
            cache_creation = int(usage.get("cache_creation_input_tokens") or 0)
            cache_read = int(usage.get("cache_read_input_tokens") or 0)
            output_tokens = int(usage.get("output_tokens") or 0)
            if not (input_tokens or output_tokens or cache_creation or cache_read):
                continue

            tokens_in += input_tokens
            tokens_out += output_tokens
            cached_input_tokens += cache_creation + cache_read
            raw_total_tokens += input_tokens + cache_creation + cache_read + output_tokens
            usage_records += 1

    if usage_records == 0 or tokens_in + tokens_out + cached_input_tokens == 0:
        return None

    return TokenSourceRecord(
        source_provider="claude_transcript",
        source_id=f"claude:{session_id}",
        source_path=path,
        source_created_at=last_at,
        session_id=session_id,
        cwd=cwd,
        model=model,
        agent="claude",
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        cached_input_tokens=cached_input_tokens,
        raw_total_tokens=raw_total_tokens or None,
        parser_version=PARSER_VERSION,
        confidence=1.0,
        raw_metadata={
            "started_at": first_at.isoformat() if first_at else None,
            "usage_records": usage_records,
            "malformed_lines": malformed_lines,
            "source_file_name": path.name,
        },
    )
def collect_claude_transcripts(claude_home: Path, since: datetime) -> list[TokenSourceRecord]:
    """Parse every Claude transcript and return one record per source id.

    When several files yield the same ``source_id``, the record with the
    larger ``tokens_total`` is kept (the earlier one wins a tie).

    Args:
        claude_home: Directory searched by ``iter_claude_transcript_files``.
        since: Lower bound passed to ``parse_claude_transcript``.

    Returns:
        The surviving records ordered by ``source_created_at``; records
        without one sort first.
    """
    best: dict[str, TokenSourceRecord] = {}
    for transcript in iter_claude_transcript_files(claude_home):
        record = parse_claude_transcript(transcript, since)
        if record is None:
            continue
        existing = best.get(record.source_id)
        if existing is not None and record.tokens_total <= existing.tokens_total:
            continue
        best[record.source_id] = record

    # Stand-in sort value for records without a timestamp.
    earliest = datetime.min.replace(tzinfo=since.tzinfo)

    def _created_at(item: TokenSourceRecord) -> datetime:
        return item.source_created_at or earliest

    return sorted(best.values(), key=_created_at)

View File

@@ -0,0 +1,124 @@
from __future__ import annotations
import json
from datetime import datetime
from pathlib import Path
from typing import Any
from api.services.token_sources.base import TokenSourceRecord, parse_iso
# Recorded as `parser_version` on the records built by parse_codex_session.
PARSER_VERSION = "codex-desktop-v1"
def iter_codex_session_files(codex_home: Path) -> list[Path]:
    """List Codex session files: live sessions first, then archived ones.

    Live sessions are ``*.jsonl`` files exactly three directories below
    ``<codex_home>/sessions``; archived ones sit directly inside
    ``<codex_home>/archived_sessions``. Each group is sorted on its own and a
    missing directory contributes nothing.
    """
    live_dir = codex_home / "sessions"
    archive_dir = codex_home / "archived_sessions"

    live = sorted(live_dir.glob("*/*/*/*.jsonl")) if live_dir.is_dir() else []
    archived = sorted(archive_dir.glob("*.jsonl")) if archive_dir.is_dir() else []
    return [*live, *archived]
def _parse_codex_timestamp(value: Any) -> datetime | None:
    """Return *value* parsed by ``parse_iso``, or ``None`` if it is not a valid ISO string."""
    if not isinstance(value, str):
        return None
    try:
        return parse_iso(value)
    except ValueError:
        return None


def parse_codex_session(path: Path, since: datetime) -> TokenSourceRecord | None:
    """Aggregate the token usage recorded in one Codex session file.

    The file is read as JSON Lines and three entry types are interpreted:

    * ``session_meta``: session id, working directory and start time;
    * ``turn_context``: working directory and model;
    * ``event_msg`` with payload type ``token_count``: the
      ``last_token_usage`` counters, summed for events timestamped at or
      after ``since``.

    Parsing is best-effort: lines that are not valid JSON, or are valid JSON
    but not an object, are counted in ``raw_metadata["malformed_lines"]`` and
    skipped; non-dict ``payload``/``info`` values are ignored; an unparseable
    timestamp is treated as missing.

    Args:
        path: Session file to read. Its stem, minus a ``rollout-`` prefix, is
            the session id unless a ``session_meta`` entry supplies one.
        since: Lower bound for token-count timestamps; compared against the
            aware UTC datetimes returned by ``parse_iso``.

    Returns:
        A ``TokenSourceRecord`` for the session, or ``None`` when the file
        cannot be opened or holds no non-zero usage in range.
    """
    session_id = path.stem.removeprefix("rollout-")
    started_at: datetime | None = None
    last_at: datetime | None = None
    cwd: str | None = None
    model: str | None = None
    tokens_in = tokens_out = 0
    cached_input_tokens = reasoning_output_tokens = 0
    raw_total_tokens = 0
    usage_records = 0
    malformed_lines = 0

    try:
        handle = path.open("r", encoding="utf-8", errors="ignore")
    except OSError:
        return None

    with handle:
        for line in handle:
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                malformed_lines += 1
                continue
            if not isinstance(entry, dict):
                # Valid JSON but not an object: nothing to read from it.
                malformed_lines += 1
                continue

            parsed_ts = _parse_codex_timestamp(entry.get("timestamp"))
            if parsed_ts:
                last_at = parsed_ts
                started_at = started_at or parsed_ts

            payload = entry.get("payload")
            if not isinstance(payload, dict):
                payload = {}
            entry_type = entry.get("type")

            if entry_type == "session_meta":
                meta_id = payload.get("id")
                if meta_id:
                    session_id = str(meta_id)
                cwd = payload.get("cwd") or cwd
                # The start time in the metadata overrides the first entry's.
                meta_started = _parse_codex_timestamp(payload.get("timestamp"))
                if meta_started is not None:
                    started_at = meta_started
            elif entry_type == "turn_context":
                cwd = payload.get("cwd") or cwd
                model = payload.get("model") or model
            elif entry_type == "event_msg" and payload.get("type") == "token_count":
                if parsed_ts is None or parsed_ts < since:
                    continue
                info = payload.get("info") or {}
                if not isinstance(info, dict):
                    continue
                last = info.get("last_token_usage") or {}
                if not isinstance(last, dict):
                    continue
                input_tokens = int(last.get("input_tokens") or 0)
                output_tokens = int(last.get("output_tokens") or 0)
                if input_tokens == 0 and output_tokens == 0:
                    continue
                tokens_in += input_tokens
                tokens_out += output_tokens
                cached_input_tokens += int(last.get("cached_input_tokens") or 0)
                reasoning_output_tokens += int(last.get("reasoning_output_tokens") or 0)
                # Use the reported total, else input plus output.
                raw_total_tokens += int(last.get("total_tokens") or input_tokens + output_tokens)
                usage_records += 1
                # last_at already equals parsed_ts here (set above).

    if usage_records == 0 or tokens_in + tokens_out == 0:
        return None

    return TokenSourceRecord(
        source_provider="codex_session",
        source_id=f"codex:{session_id}",
        source_path=path,
        source_created_at=last_at,
        session_id=session_id,
        cwd=cwd,
        model=model,
        agent="codex",
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        cached_input_tokens=cached_input_tokens,
        reasoning_output_tokens=reasoning_output_tokens,
        raw_total_tokens=raw_total_tokens or None,
        parser_version=PARSER_VERSION,
        confidence=1.0,
        raw_metadata={
            "started_at": started_at.isoformat() if started_at else None,
            "usage_records": usage_records,
            "malformed_lines": malformed_lines,
            "source_file_name": path.name,
        },
    )
def collect_codex_sessions(codex_home: Path, since: datetime) -> list[TokenSourceRecord]:
    """Parse every Codex session file and return one record per source id.

    When several files yield the same ``source_id``, the record with the
    larger ``tokens_total`` is kept (the earlier one wins a tie).

    Args:
        codex_home: Directory searched by ``iter_codex_session_files``.
        since: Lower bound passed to ``parse_codex_session``.

    Returns:
        The surviving records ordered by ``source_created_at``; records
        without one sort first.
    """
    best: dict[str, TokenSourceRecord] = {}
    for session_file in iter_codex_session_files(codex_home):
        record = parse_codex_session(session_file, since)
        if record is None:
            continue
        existing = best.get(record.source_id)
        if existing is not None and record.tokens_total <= existing.tokens_total:
            continue
        best[record.source_id] = record

    # Stand-in sort value for records without a timestamp.
    earliest = datetime.min.replace(tzinfo=since.tzinfo)

    def _created_at(item: TokenSourceRecord) -> datetime:
        return item.source_created_at or earliest

    return sorted(best.values(), key=_created_at)