generated from coulomb/repo-seed
Fixed and improved token tracking
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Integer, Text, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import DateTime, Float, ForeignKey, Integer, Text, UniqueConstraint, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from api.models.base import Base, new_uuid
|
||||
@@ -10,6 +12,14 @@ from api.models.base import Base, new_uuid
|
||||
|
||||
class TokenEvent(Base):
|
||||
__tablename__ = "token_events"
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"measurement_kind",
|
||||
"source_provider",
|
||||
"source_id",
|
||||
name="uq_token_events_source_identity",
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True), primary_key=True, default=new_uuid
|
||||
@@ -31,6 +41,35 @@ class TokenEvent(Base):
|
||||
ref_type: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
ref_id: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
note: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
measurement_kind: Mapped[str] = mapped_column(
|
||||
Text, nullable=False, default="estimated", server_default="estimated", index=True
|
||||
)
|
||||
source_provider: Mapped[str] = mapped_column(
|
||||
Text, nullable=False, default="manual", server_default="manual", index=True
|
||||
)
|
||||
source_id: Mapped[str | None] = mapped_column(Text, nullable=True, index=True)
|
||||
source_path: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
source_created_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True, index=True
|
||||
)
|
||||
ingested_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False, index=True
|
||||
)
|
||||
parser_version: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
confidence: Mapped[float] = mapped_column(
|
||||
Float, nullable=False, default=0.35, server_default="0.35"
|
||||
)
|
||||
cached_input_tokens: Mapped[int] = mapped_column(
|
||||
Integer, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
reasoning_output_tokens: Mapped[int] = mapped_column(
|
||||
Integer, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
raw_total_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
cost_estimated_usd: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
raw_metadata: Mapped[dict[str, Any]] = mapped_column(
|
||||
JSONB, nullable=False, default=dict, server_default="{}"
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False, index=True
|
||||
)
|
||||
|
||||
@@ -75,23 +75,47 @@ async def update_task(
|
||||
if task is None:
|
||||
raise HTTPException(status_code=404, detail="Task not found")
|
||||
|
||||
previous_status = task.status.value
|
||||
|
||||
# Separate token fields from task fields
|
||||
token_field_names = {"tokens_in", "tokens_out", "workplan_tokens_in", "workplan_tokens_out", "token_note", "model", "agent", "session_id"}
|
||||
token_field_names = {
|
||||
"tokens_in",
|
||||
"tokens_out",
|
||||
"workplan_tokens_in",
|
||||
"workplan_tokens_out",
|
||||
"token_note",
|
||||
"model",
|
||||
"agent",
|
||||
"session_id",
|
||||
"suppress_token_event",
|
||||
}
|
||||
update_data = body.model_dump(exclude_unset=True)
|
||||
token_data = {k: update_data.pop(k) for k in list(update_data.keys()) if k in token_field_names}
|
||||
suppress_token_event = bool(token_data.pop("suppress_token_event", False))
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(task, field, value)
|
||||
await session.commit()
|
||||
await session.refresh(task)
|
||||
|
||||
# Token event — three-tier logic, only when marking done
|
||||
if update_data.get("status") == "done":
|
||||
# Token event — three-tier logic, only for an intentional transition to done.
|
||||
status_update = update_data.get("status")
|
||||
new_status = status_update.value if hasattr(status_update, "value") else status_update
|
||||
if (
|
||||
new_status == "done"
|
||||
and previous_status != "done"
|
||||
and not suppress_token_event
|
||||
):
|
||||
if "tokens_in" in token_data and "tokens_out" in token_data:
|
||||
# Tier 1: exact counts — default note "measured"; caller may override with token_note
|
||||
tin = token_data["tokens_in"]
|
||||
tout = token_data["tokens_out"]
|
||||
tnote = token_data.get("token_note") or "measured"
|
||||
measurement_kind = "measured"
|
||||
source_provider = "manual"
|
||||
confidence = 1.0
|
||||
source_id = f"task:{task_id}:manual"
|
||||
raw_metadata = {"input_source": "task_status_patch"}
|
||||
elif "workplan_tokens_in" in token_data and "workplan_tokens_out" in token_data:
|
||||
# Tier 2: prorate workplan total across task count
|
||||
count_result = await session.execute(
|
||||
@@ -101,9 +125,24 @@ async def update_task(
|
||||
tin = token_data["workplan_tokens_in"] // task_count
|
||||
tout = token_data["workplan_tokens_out"] // task_count
|
||||
tnote = "workplan"
|
||||
measurement_kind = "allocated"
|
||||
source_provider = "manual"
|
||||
confidence = 0.7
|
||||
source_id = f"task:{task_id}:workplan-allocation"
|
||||
raw_metadata = {
|
||||
"allocation_method": "workplan_prorated",
|
||||
"workplan_tokens_in": token_data["workplan_tokens_in"],
|
||||
"workplan_tokens_out": token_data["workplan_tokens_out"],
|
||||
"task_count": task_count,
|
||||
}
|
||||
else:
|
||||
# Tier 3: heuristic fallback
|
||||
tin, tout, tnote = 1000, 500, "heuristic"
|
||||
measurement_kind = "estimated"
|
||||
source_provider = "task_fallback"
|
||||
confidence = 0.35
|
||||
source_id = f"task:{task_id}:heuristic"
|
||||
raw_metadata = {"estimation_method": "fixed_task_done_fallback"}
|
||||
|
||||
# Resolve repo_id via workstream
|
||||
ws = await session.get(Workstream, task.workstream_id)
|
||||
@@ -121,6 +160,12 @@ async def update_task(
|
||||
ref_type="task",
|
||||
ref_id=str(task_id),
|
||||
note=tnote,
|
||||
measurement_kind=measurement_kind,
|
||||
source_provider=source_provider,
|
||||
source_id=source_id,
|
||||
confidence=confidence,
|
||||
raw_total_tokens=tin + tout,
|
||||
raw_metadata=raw_metadata,
|
||||
)
|
||||
session.add(event)
|
||||
await session.commit()
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import select
|
||||
@@ -10,18 +12,95 @@ from api.models.managed_repo import ManagedRepo
|
||||
from api.models.task import Task
|
||||
from api.models.token_event import TokenEvent
|
||||
from api.models.workstream import Workstream
|
||||
from api.schemas.token_event import RepoTokenSummary, TokenEventCreate, TokenEventPatch, TokenEventRead, TokenSummary
|
||||
from api.schemas.token_event import (
|
||||
RepoTokenSummary,
|
||||
TokenAggregateRow,
|
||||
TokenAggregateSummary,
|
||||
TokenEventCreate,
|
||||
TokenEventPatch,
|
||||
TokenEventRead,
|
||||
TokenQualitySummary,
|
||||
TokenSummary,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/token-events", tags=["token-events"])
|
||||
|
||||
# Confidence score assumed for an event of a given measurement kind when the
# payload does not supply its own ``confidence`` value.
DEFAULT_CONFIDENCE: dict[str, float] = dict(
    measured=1.0,
    allocated=0.70,
    estimated=0.35,
    superseded=0.0,
)
|
||||
|
||||
@router.post("/", response_model=TokenEventRead, status_code=status.HTTP_201_CREATED)
|
||||
async def create_token_event(
|
||||
body: TokenEventCreate,
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> TokenEvent:
|
||||
data = body.model_dump()
|
||||
# Parser version recorded for an event when its source provider is one of
# these keys and the payload carries no ``parser_version`` of its own.
SOURCE_PARSER_DEFAULTS: dict[str, str] = dict(
    codex_session="codex-desktop-v1",
    claude_transcript="claude-transcript-v1",
    llm_connect="llm-connect-v1",
)
|
||||
|
||||
|
||||
def _event_total(event: TokenEvent) -> int:
    """Return the event's input and output token counts added together."""
    return sum((event.tokens_in, event.tokens_out))
|
||||
|
||||
|
||||
def _infer_measurement_kind(data: dict[str, Any]) -> str:
    """Work out the measurement kind for an event payload.

    Precedence:
      1. a truthy ``measurement_kind`` already in ``data`` (returned as ``str``);
      2. a recognised ``note`` value, or a note starting with
         ``backfill:codex-session``;
      3. a ``source_provider`` that is one of the session/transcript providers;
      4. otherwise ``"estimated"``.
    """
    explicit_kind = data.get("measurement_kind")
    if explicit_kind:
        return str(explicit_kind)

    note = data.get("note")
    # Exact note values and the kind each one implies. Compared with ``==``
    # so any note type is tolerated.
    note_to_kind = (
        ("heuristic_superseded_by_codex_backfill", "superseded"),
        ("workplan", "allocated"),
        ("heuristic", "estimated"),
        ("measured", "measured"),
    )
    for known_note, implied_kind in note_to_kind:
        if note == known_note:
            return implied_kind

    if str(note or "").startswith("backfill:codex-session"):
        return "measured"

    measured_providers = {"codex_session", "claude_transcript", "llm_connect"}
    if data.get("source_provider") in measured_providers:
        return "measured"
    return "estimated"
|
||||
|
||||
|
||||
def _infer_source_provider(data: dict[str, Any], measurement_kind: str) -> str:
    """Work out the source provider for an event payload.

    A truthy ``source_provider`` already in ``data`` wins (returned as
    ``str``). Otherwise the provider is derived from the note, the ``ref_id``
    prefix and the agent name, falling back to ``"manual"``.
    """
    declared = data.get("source_provider")
    if declared:
        return str(declared)

    note_value = data.get("note")
    reference = str(data.get("ref_id") or "")
    agent_name = str(data.get("agent") or "").lower()

    if note_value == "heuristic":
        return "task_fallback"

    looks_like_codex = reference.startswith("codex:") or str(note_value or "").startswith(
        "backfill:codex-session"
    )
    if looks_like_codex:
        return "codex_session"

    # Only measured events are attributed to a Claude transcript.
    if "claude" in agent_name and measurement_kind == "measured":
        return "claude_transcript"
    return "manual"
|
||||
|
||||
|
||||
def _apply_event_defaults(data: dict[str, Any]) -> dict[str, Any]:
    """Fill derived and fallback fields on an event payload, in place.

    Sets ``measurement_kind`` and ``source_provider`` from the inference
    helpers, derives ``source_id`` / ``source_created_at`` where possible, and
    adds defaults for any of the remaining provenance fields that are absent.

    Args:
        data: Event payload; mutated and also returned.

    Returns:
        The same ``data`` dict, for call chaining.
    """
    kind = _infer_measurement_kind(data)
    provider = _infer_source_provider(data, kind)
    data.update(measurement_kind=kind, source_provider=provider)

    # Session/transcript providers can reuse ref_id or session_id as their
    # source identity when none was given.
    session_backed = provider in {"codex_session", "claude_transcript", "llm_connect"}
    if session_backed and not data.get("source_id"):
        identity = data.get("ref_id") or data.get("session_id")
        if identity:
            data["source_id"] = str(identity)

    # An event with a source identity and a created_at but no source timestamp
    # inherits created_at as its source timestamp.
    lacks_source_timestamp = not data.get("source_created_at")
    if lacks_source_timestamp and data.get("created_at") and data.get("source_id"):
        data["source_created_at"] = data["created_at"]

    data.setdefault("confidence", DEFAULT_CONFIDENCE.get(kind, 0.35))
    for counter_field in ("cached_input_tokens", "reasoning_output_tokens"):
        data.setdefault(counter_field, 0)
    token_sum = sum(data.get(name) or 0 for name in ("tokens_in", "tokens_out"))
    data.setdefault("raw_total_tokens", token_sum)
    data.setdefault("raw_metadata", {})

    try:
        default_parser = SOURCE_PARSER_DEFAULTS[provider]
    except KeyError:
        pass  # provider has no default parser version
    else:
        data.setdefault("parser_version", default_parser)
    return data
|
||||
|
||||
|
||||
async def _populate_relationship_defaults(data: dict[str, Any], session: AsyncSession) -> dict[str, Any]:
|
||||
# Auto-populate workstream_id from task if not provided
|
||||
if data.get("task_id") and not data.get("workstream_id"):
|
||||
task = await session.get(Task, data["task_id"])
|
||||
@@ -33,6 +112,34 @@ async def create_token_event(
|
||||
ws = await session.get(Workstream, data["workstream_id"])
|
||||
if ws and ws.repo_id:
|
||||
data["repo_id"] = ws.repo_id
|
||||
return data
|
||||
|
||||
|
||||
async def _find_source_event(data: dict[str, Any], session: AsyncSession) -> TokenEvent | None:
    """Look up the stored event that shares this payload's source identity.

    The identity is the triple (``measurement_kind``, ``source_provider``,
    ``source_id``). Payloads without a truthy ``source_id`` never match.

    Returns:
        The matching ``TokenEvent``, or ``None`` when there is no ``source_id``
        or no row matches.
    """
    source_id = data.get("source_id")
    if not source_id:
        return None

    identity_filters = (
        TokenEvent.measurement_kind == data["measurement_kind"],
        TokenEvent.source_provider == data["source_provider"],
        TokenEvent.source_id == source_id,
    )
    lookup = select(TokenEvent).where(*identity_filters)
    rows = await session.execute(lookup)
    return rows.scalar_one_or_none()
|
||||
|
||||
|
||||
async def _create_or_upsert_event(data: dict[str, Any], session: AsyncSession) -> TokenEvent:
|
||||
data = _apply_event_defaults(data)
|
||||
data = await _populate_relationship_defaults(data, session)
|
||||
|
||||
existing = await _find_source_event(data, session)
|
||||
if existing is not None:
|
||||
for field, value in data.items():
|
||||
setattr(existing, field, value)
|
||||
await session.commit()
|
||||
await session.refresh(existing)
|
||||
return existing
|
||||
|
||||
event = TokenEvent(**data)
|
||||
session.add(event)
|
||||
@@ -41,6 +148,77 @@ async def create_token_event(
|
||||
return event
|
||||
|
||||
|
||||
def _filter_query(
|
||||
q,
|
||||
*,
|
||||
task_id: uuid.UUID | None = None,
|
||||
workstream_id: uuid.UUID | None = None,
|
||||
repo_id: uuid.UUID | None = None,
|
||||
ref_type: str | None = None,
|
||||
ref_id: str | None = None,
|
||||
model: str | None = None,
|
||||
agent: str | None = None,
|
||||
note: str | None = None,
|
||||
measurement_kind: str | None = None,
|
||||
source_provider: str | None = None,
|
||||
since: datetime | None = None,
|
||||
until: datetime | None = None,
|
||||
include_superseded: bool = True,
|
||||
unattributed: bool = False,
|
||||
):
|
||||
if task_id:
|
||||
q = q.where(TokenEvent.task_id == task_id)
|
||||
if workstream_id:
|
||||
q = q.where(TokenEvent.workstream_id == workstream_id)
|
||||
if repo_id:
|
||||
q = q.where(TokenEvent.repo_id == repo_id)
|
||||
if ref_type:
|
||||
q = q.where(TokenEvent.ref_type == ref_type)
|
||||
if ref_id:
|
||||
q = q.where(TokenEvent.ref_id == ref_id)
|
||||
if model:
|
||||
q = q.where(TokenEvent.model == model)
|
||||
if agent:
|
||||
q = q.where(TokenEvent.agent == agent)
|
||||
if note:
|
||||
q = q.where(TokenEvent.note == note)
|
||||
if measurement_kind:
|
||||
q = q.where(TokenEvent.measurement_kind == measurement_kind)
|
||||
if source_provider:
|
||||
q = q.where(TokenEvent.source_provider == source_provider)
|
||||
if since:
|
||||
q = q.where(TokenEvent.created_at >= since)
|
||||
if until:
|
||||
q = q.where(TokenEvent.created_at < until)
|
||||
if not include_superseded:
|
||||
q = q.where(TokenEvent.measurement_kind != "superseded")
|
||||
if unattributed:
|
||||
q = q.where(
|
||||
TokenEvent.repo_id.is_(None),
|
||||
TokenEvent.workstream_id.is_(None),
|
||||
TokenEvent.task_id.is_(None),
|
||||
)
|
||||
return q
|
||||
|
||||
|
||||
@router.post("/", response_model=TokenEventRead, status_code=status.HTTP_201_CREATED)
async def create_token_event(
    body: TokenEventCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Record a token event from the request body.

    Fields left as ``None`` are dropped from the payload before it is handed
    to ``_create_or_upsert_event``, which produces the returned event.
    """
    payload = body.model_dump(exclude_none=True)
    event = await _create_or_upsert_event(payload, session)
    return event
|
||||
|
||||
|
||||
@router.post("/upsert", response_model=TokenEventRead)
async def upsert_token_event(
    body: TokenEventCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Create or update a token event from the request body.

    Fields left as ``None`` are dropped from the payload before it is handed
    to ``_create_or_upsert_event``, which produces the returned event.
    """
    payload = body.model_dump(exclude_none=True)
    event = await _create_or_upsert_event(payload, session)
    return event
|
||||
|
||||
|
||||
@router.get("/summary/", response_model=TokenSummary)
|
||||
async def get_token_summary(
|
||||
scope: str = Query(..., description="task|workstream|repo|commit|release|session"),
|
||||
@@ -80,11 +258,16 @@ async def get_token_summary(
|
||||
|
||||
by_model: dict[str, int] = defaultdict(int)
|
||||
by_agent: dict[str, int] = defaultdict(int)
|
||||
by_measurement_kind: dict[str, int] = defaultdict(int)
|
||||
by_source_provider: dict[str, int] = defaultdict(int)
|
||||
for e in events:
|
||||
total = _event_total(e)
|
||||
if e.model:
|
||||
by_model[e.model] += e.tokens_in + e.tokens_out
|
||||
by_model[e.model] += total
|
||||
if e.agent:
|
||||
by_agent[e.agent] += e.tokens_in + e.tokens_out
|
||||
by_agent[e.agent] += total
|
||||
by_measurement_kind[e.measurement_kind] += total
|
||||
by_source_provider[e.source_provider] += total
|
||||
|
||||
return TokenSummary(
|
||||
scope=scope,
|
||||
@@ -95,11 +278,18 @@ async def get_token_summary(
|
||||
event_count=len(events),
|
||||
by_model=dict(by_model),
|
||||
by_agent=dict(by_agent),
|
||||
by_measurement_kind=dict(by_measurement_kind),
|
||||
by_source_provider=dict(by_source_provider),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/by-repo/", response_model=list[RepoTokenSummary])
|
||||
async def get_tokens_by_repo(
|
||||
measurement_kind: str | None = None,
|
||||
source_provider: str | None = None,
|
||||
since: datetime | None = None,
|
||||
until: datetime | None = None,
|
||||
include_superseded: bool = Query(True),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[RepoTokenSummary]:
|
||||
"""Aggregate token consumption per repo, resolving via the full graph.
|
||||
@@ -112,7 +302,16 @@ async def get_tokens_by_repo(
|
||||
Only events that resolve to a repo are included.
|
||||
"""
|
||||
# Fetch all events, workstreams, repos in three queries (avoids N+1)
|
||||
events_result = await session.execute(select(TokenEvent))
|
||||
events_result = await session.execute(
|
||||
_filter_query(
|
||||
select(TokenEvent),
|
||||
measurement_kind=measurement_kind,
|
||||
source_provider=source_provider,
|
||||
since=since,
|
||||
until=until,
|
||||
include_superseded=include_superseded,
|
||||
)
|
||||
)
|
||||
events = list(events_result.scalars().all())
|
||||
|
||||
ws_result = await session.execute(select(Workstream))
|
||||
@@ -148,14 +347,19 @@ async def get_tokens_by_repo(
|
||||
"event_count": 0,
|
||||
"by_model": defaultdict(int),
|
||||
"by_note": defaultdict(int),
|
||||
"by_measurement_kind": defaultdict(int),
|
||||
"by_source_provider": defaultdict(int),
|
||||
}
|
||||
g = groups[rid]
|
||||
g["tokens_in"] += e.tokens_in
|
||||
g["tokens_out"] += e.tokens_out
|
||||
g["event_count"] += 1
|
||||
total = _event_total(e)
|
||||
if e.model:
|
||||
g["by_model"][e.model] += e.tokens_in + e.tokens_out
|
||||
g["by_note"][e.note or "unknown"] += e.tokens_in + e.tokens_out
|
||||
g["by_model"][e.model] += total
|
||||
g["by_note"][e.note or "unknown"] += total
|
||||
g["by_measurement_kind"][e.measurement_kind] += total
|
||||
g["by_source_provider"][e.source_provider] += total
|
||||
|
||||
return [
|
||||
RepoTokenSummary(
|
||||
@@ -166,6 +370,188 @@ async def get_tokens_by_repo(
|
||||
]
|
||||
|
||||
|
||||
@router.get("/aggregate/", response_model=TokenAggregateSummary)
async def get_token_aggregate(
    measurement_kind: str | None = None,
    source_provider: str | None = None,
    since: datetime | None = None,
    until: datetime | None = None,
    include_superseded: bool = Query(False),
    session: AsyncSession = Depends(get_session),
) -> TokenAggregateSummary:
    """Summarise token usage overall and grouped by repo, workstream, task and model.

    Events are selected with ``_filter_query``. Every workstream, task and
    repo is then loaded once so the grouping runs in memory. Events that
    resolve to no repo, workstream or task are left out of that grouping but
    still count in the overall totals. Events without a model are grouped
    under ``"unknown"``.
    """
    event_query = _filter_query(
        select(TokenEvent),
        measurement_kind=measurement_kind,
        source_provider=source_provider,
        since=since,
        until=until,
        include_superseded=include_superseded,
    )
    events = list((await session.execute(event_query)).scalars().all())

    # Lookup tables keyed by primary key.
    workstream_rows = (await session.execute(select(Workstream))).scalars().all()
    ws_map: dict[uuid.UUID, Workstream] = {w.id: w for w in workstream_rows}

    task_rows = (await session.execute(select(Task))).scalars().all()
    task_map: dict[uuid.UUID, Task] = {t.id: t for t in task_rows}

    repo_rows = (await session.execute(select(ManagedRepo))).scalars().all()
    repo_map: dict[uuid.UUID, ManagedRepo] = {r.id: r for r in repo_rows}

    def repo_id_for(event: TokenEvent) -> uuid.UUID | None:
        """Resolve a repo id: event.repo_id, else via workstream, else via task's workstream."""
        direct = event.repo_id
        if direct:
            return direct
        stream_id = event.workstream_id
        if not stream_id and event.task_id and event.task_id in task_map:
            stream_id = task_map[event.task_id].workstream_id
        if stream_id and stream_id in ws_map:
            return ws_map[stream_id].repo_id
        return None

    def accumulate(
        buckets: dict[str, dict[str, Any]],
        key: str | None,
        label: str | None,
        event: TokenEvent,
    ) -> None:
        """Add one event to the bucket for ``key``; a falsy key is ignored."""
        if not key:
            return
        bucket = buckets.get(key)
        if bucket is None:
            bucket = buckets[key] = {
                "scope_id": key,
                "label": label,
                "tokens_in": 0,
                "tokens_out": 0,
                "event_count": 0,
                "by_measurement_kind": defaultdict(int),
                "by_source_provider": defaultdict(int),
            }
        combined = _event_total(event)
        bucket["tokens_in"] += event.tokens_in
        bucket["tokens_out"] += event.tokens_out
        bucket["event_count"] += 1
        bucket["by_measurement_kind"][event.measurement_kind] += combined
        bucket["by_source_provider"][event.source_provider] += combined

    def as_rows(buckets: dict[str, dict[str, Any]]) -> list[TokenAggregateRow]:
        """Convert buckets to response rows, largest token total first."""
        built = [
            TokenAggregateRow(
                **{
                    name: (dict(value) if isinstance(value, defaultdict) else value)
                    for name, value in bucket.items()
                },
                tokens_total=bucket["tokens_in"] + bucket["tokens_out"],
            )
            for bucket in buckets.values()
        ]
        # Stable sort: ties keep first-seen order.
        return sorted(built, key=lambda row: row.tokens_total, reverse=True)

    by_repo: dict[str, dict[str, Any]] = {}
    by_workstream: dict[str, dict[str, Any]] = {}
    by_task: dict[str, dict[str, Any]] = {}
    by_model: dict[str, dict[str, Any]] = {}
    by_measurement_kind: dict[str, int] = defaultdict(int)
    by_source_provider: dict[str, int] = defaultdict(int)

    first_event_at = None
    last_event_at = None
    last_ingested_at = None
    tokens_in = 0
    tokens_out = 0

    for event in events:
        combined = _event_total(event)
        tokens_in += event.tokens_in
        tokens_out += event.tokens_out
        by_measurement_kind[event.measurement_kind] += combined
        by_source_provider[event.source_provider] += combined

        # Track the time span covered by the selected events.
        if first_event_at is None or event.created_at < first_event_at:
            first_event_at = event.created_at
        if last_event_at is None or event.created_at > last_event_at:
            last_event_at = event.created_at
        if last_ingested_at is None or event.ingested_at > last_ingested_at:
            last_ingested_at = event.ingested_at

        rid = repo_id_for(event)
        repo = repo_map.get(rid) if rid else None
        accumulate(by_repo, str(rid) if rid else None, repo.slug if repo else None, event)

        stream_id = event.workstream_id
        if not stream_id:
            stream_id = task_map[event.task_id].workstream_id if event.task_id in task_map else None
        stream = ws_map.get(stream_id) if stream_id else None
        accumulate(
            by_workstream,
            str(stream_id) if stream_id else None,
            stream.title if stream else None,
            event,
        )

        task = task_map.get(event.task_id) if event.task_id else None
        accumulate(
            by_task,
            str(event.task_id) if event.task_id else None,
            task.title if task else None,
            event,
        )

        model_name = event.model or "unknown"
        accumulate(by_model, model_name, model_name, event)

    return TokenAggregateSummary(
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        tokens_total=tokens_in + tokens_out,
        event_count=len(events),
        first_event_at=first_event_at,
        last_event_at=last_event_at,
        last_ingested_at=last_ingested_at,
        by_repo=as_rows(by_repo),
        by_workstream=as_rows(by_workstream),
        by_task=as_rows(by_task),
        by_model=as_rows(by_model),
        by_measurement_kind=dict(by_measurement_kind),
        by_source_provider=dict(by_source_provider),
    )
|
||||
|
||||
|
||||
@router.get("/quality/", response_model=TokenQualitySummary)
async def get_token_quality(
    since: datetime | None = None,
    until: datetime | None = None,
    session: AsyncSession = Depends(get_session),
) -> TokenQualitySummary:
    """Report data-quality counters for token events in a time window.

    Counts events (not tokens) per measurement kind and source provider, plus
    fallback events, measured events with no repo/workstream/task, measured
    events with no ``source_id``, and source identities seen more than once.
    Also reports the newest ``ingested_at`` for the codex and claude providers.
    ``last_reconciliation_at`` is always returned as ``None``.
    """
    windowed = _filter_query(select(TokenEvent), since=since, until=until)
    events = list((await session.execute(windowed)).scalars().all())

    by_measurement_kind: dict[str, int] = defaultdict(int)
    by_source_provider: dict[str, int] = defaultdict(int)
    # (measurement_kind, source_provider, source_id) -> number of events
    source_counts: dict[tuple[str, str, str], int] = defaultdict(int)

    last_codex_ingested_at = None
    last_claude_ingested_at = None
    fallback_count = 0
    unattributed_measured_count = 0
    missing_provenance_count = 0

    for event in events:
        kind = event.measurement_kind
        provider = event.source_provider
        by_measurement_kind[kind] += 1
        by_source_provider[provider] += 1

        if event.source_id:
            source_counts[(kind, provider, event.source_id)] += 1

        if provider == "task_fallback" or event.note == "heuristic":
            fallback_count += 1

        if kind == "measured":
            has_owner = event.repo_id or event.workstream_id or event.task_id
            if not has_owner:
                unattributed_measured_count += 1
            if not event.source_id:
                missing_provenance_count += 1

        if provider == "codex_session":
            if last_codex_ingested_at is None or event.ingested_at > last_codex_ingested_at:
                last_codex_ingested_at = event.ingested_at
        if provider == "claude_transcript":
            if last_claude_ingested_at is None or event.ingested_at > last_claude_ingested_at:
                last_claude_ingested_at = event.ingested_at

    # Number of distinct source identities that occur on more than one event.
    duplicate_source_count = len([seen for seen in source_counts.values() if seen > 1])

    return TokenQualitySummary(
        event_count=len(events),
        measured_event_count=by_measurement_kind.get("measured", 0),
        estimated_event_count=by_measurement_kind.get("estimated", 0),
        allocated_event_count=by_measurement_kind.get("allocated", 0),
        superseded_event_count=by_measurement_kind.get("superseded", 0),
        fallback_event_count=fallback_count,
        unattributed_measured_event_count=unattributed_measured_count,
        missing_provenance_event_count=missing_provenance_count,
        duplicate_source_count=duplicate_source_count,
        last_codex_ingested_at=last_codex_ingested_at,
        last_claude_ingested_at=last_claude_ingested_at,
        last_reconciliation_at=None,
        by_measurement_kind=dict(by_measurement_kind),
        by_source_provider=dict(by_source_provider),
    )
|
||||
|
||||
|
||||
@router.patch("/{event_id}", response_model=TokenEventRead)
|
||||
async def patch_token_event(
|
||||
event_id: uuid.UUID,
|
||||
@@ -175,7 +561,26 @@ async def patch_token_event(
|
||||
event = await session.get(TokenEvent, event_id)
|
||||
if event is None:
|
||||
raise HTTPException(status_code=404, detail="Token event not found")
|
||||
for field, value in body.model_dump(exclude_none=True).items():
|
||||
data = body.model_dump(exclude_none=True)
|
||||
if "note" in data or "measurement_kind" in data or "source_provider" in data:
|
||||
merged = {
|
||||
"tokens_in": data.get("tokens_in", event.tokens_in),
|
||||
"tokens_out": data.get("tokens_out", event.tokens_out),
|
||||
"note": data.get("note", event.note),
|
||||
"agent": data.get("agent", event.agent),
|
||||
"ref_id": data.get("ref_id", event.ref_id),
|
||||
"session_id": data.get("session_id", event.session_id),
|
||||
"measurement_kind": data.get("measurement_kind", event.measurement_kind),
|
||||
"source_provider": data.get("source_provider", event.source_provider),
|
||||
"source_id": data.get("source_id", event.source_id),
|
||||
}
|
||||
inferred = _apply_event_defaults({k: v for k, v in merged.items() if v is not None})
|
||||
data.setdefault("measurement_kind", inferred["measurement_kind"])
|
||||
data.setdefault("source_provider", inferred["source_provider"])
|
||||
data.setdefault("confidence", inferred["confidence"])
|
||||
if inferred.get("source_id"):
|
||||
data.setdefault("source_id", inferred["source_id"])
|
||||
for field, value in data.items():
|
||||
setattr(event, field, value)
|
||||
await session.commit()
|
||||
await session.refresh(event)
|
||||
@@ -203,26 +608,33 @@ async def list_token_events(
|
||||
model: str | None = None,
|
||||
agent: str | None = None,
|
||||
note: str | None = None,
|
||||
measurement_kind: str | None = None,
|
||||
source_provider: str | None = None,
|
||||
since: datetime | None = None,
|
||||
until: datetime | None = None,
|
||||
include_superseded: bool = Query(True),
|
||||
unattributed: bool = False,
|
||||
offset: int = Query(0, ge=0),
|
||||
limit: int = Query(100, le=1000),
|
||||
session: AsyncSession = Depends(get_session),
|
||||
) -> list[TokenEvent]:
|
||||
q = select(TokenEvent)
|
||||
if task_id:
|
||||
q = q.where(TokenEvent.task_id == task_id)
|
||||
if workstream_id:
|
||||
q = q.where(TokenEvent.workstream_id == workstream_id)
|
||||
if repo_id:
|
||||
q = q.where(TokenEvent.repo_id == repo_id)
|
||||
if ref_type:
|
||||
q = q.where(TokenEvent.ref_type == ref_type)
|
||||
if ref_id:
|
||||
q = q.where(TokenEvent.ref_id == ref_id)
|
||||
if model:
|
||||
q = q.where(TokenEvent.model == model)
|
||||
if agent:
|
||||
q = q.where(TokenEvent.agent == agent)
|
||||
if note:
|
||||
q = q.where(TokenEvent.note == note)
|
||||
q = q.order_by(TokenEvent.created_at.desc()).limit(limit)
|
||||
q = _filter_query(
|
||||
select(TokenEvent),
|
||||
task_id=task_id,
|
||||
workstream_id=workstream_id,
|
||||
repo_id=repo_id,
|
||||
ref_type=ref_type,
|
||||
ref_id=ref_id,
|
||||
model=model,
|
||||
agent=agent,
|
||||
note=note,
|
||||
measurement_kind=measurement_kind,
|
||||
source_provider=source_provider,
|
||||
since=since,
|
||||
until=until,
|
||||
include_superseded=include_superseded,
|
||||
unattributed=unattributed,
|
||||
)
|
||||
q = q.order_by(TokenEvent.created_at.desc()).offset(offset).limit(limit)
|
||||
result = await session.execute(q)
|
||||
return list(result.scalars().all())
|
||||
|
||||
@@ -43,6 +43,7 @@ class TaskUpdate(BaseModel):
|
||||
# 2. workplan_tokens_in + workplan_tokens_out → prorated across task count (note="workplan")
|
||||
# 3. neither provided, status=done → heuristic 1000/500 (note="heuristic")
|
||||
# token_note overrides the auto-assigned note for Tier 1 only (e.g. "userbased")
|
||||
# suppress_token_event lets file/cache sync update status without recording usage.
|
||||
tokens_in: int | None = None
|
||||
tokens_out: int | None = None
|
||||
workplan_tokens_in: int | None = None
|
||||
@@ -51,6 +52,7 @@ class TaskUpdate(BaseModel):
|
||||
model: str | None = None
|
||||
agent: str | None = None
|
||||
session_id: str | None = None
|
||||
suppress_token_event: bool | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def blocking_reason_required_when_blocked(self) -> Self:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, computed_field
|
||||
from pydantic import BaseModel, ConfigDict, Field, computed_field
|
||||
|
||||
|
||||
class TokenEventCreate(BaseModel):
|
||||
@@ -16,6 +17,19 @@ class TokenEventCreate(BaseModel):
|
||||
ref_type: str | None = None
|
||||
ref_id: str | None = None
|
||||
note: str | None = None
|
||||
created_at: datetime | None = None
|
||||
measurement_kind: str | None = None
|
||||
source_provider: str | None = None
|
||||
source_id: str | None = None
|
||||
source_path: str | None = None
|
||||
source_created_at: datetime | None = None
|
||||
parser_version: str | None = None
|
||||
confidence: float | None = None
|
||||
cached_input_tokens: int | None = None
|
||||
reasoning_output_tokens: int | None = None
|
||||
raw_total_tokens: int | None = None
|
||||
cost_estimated_usd: float | None = None
|
||||
raw_metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class TokenEventRead(BaseModel):
|
||||
@@ -33,6 +47,19 @@ class TokenEventRead(BaseModel):
|
||||
ref_type: str | None = None
|
||||
ref_id: str | None = None
|
||||
note: str | None = None
|
||||
measurement_kind: str
|
||||
source_provider: str
|
||||
source_id: str | None = None
|
||||
source_path: str | None = None
|
||||
source_created_at: datetime | None = None
|
||||
ingested_at: datetime
|
||||
parser_version: str | None = None
|
||||
confidence: float
|
||||
cached_input_tokens: int
|
||||
reasoning_output_tokens: int
|
||||
raw_total_tokens: int | None = None
|
||||
cost_estimated_usd: float | None = None
|
||||
raw_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
created_at: datetime
|
||||
|
||||
@computed_field
|
||||
@@ -40,6 +67,11 @@ class TokenEventRead(BaseModel):
|
||||
def tokens_total(self) -> int:
|
||||
return self.tokens_in + self.tokens_out
|
||||
|
||||
@computed_field
@property
def token_evidence_total(self) -> int:
    """Return the total token count backed by the best available evidence.

    ``raw_total_tokens`` is used when it is set and non-zero; otherwise the
    total falls back to ``tokens_in + tokens_out``.
    """
    reported_total = self.raw_total_tokens
    if reported_total:
        return reported_total
    return self.tokens_in + self.tokens_out
|
||||
|
||||
|
||||
class TokenSummary(BaseModel):
|
||||
scope: str
|
||||
@@ -50,14 +82,36 @@ class TokenSummary(BaseModel):
|
||||
event_count: int
|
||||
by_model: dict[str, int]
|
||||
by_agent: dict[str, int]
|
||||
by_measurement_kind: dict[str, int] = Field(default_factory=dict)
|
||||
by_source_provider: dict[str, int] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TokenEventPatch(BaseModel):
    """Request body for patching a token event.

    Every field is optional and defaults to ``None``.
    NOTE(review): presumably a ``None`` field means "leave unchanged" --
    confirm against the handler that applies this patch.
    """

    # --- token counts -----------------------------------------------------
    tokens_in: int | None = None
    tokens_out: int | None = None

    # --- attribution ------------------------------------------------------
    task_id: uuid.UUID | None = None
    workstream_id: uuid.UUID | None = None
    repo_id: uuid.UUID | None = None
    session_id: str | None = None
    note: str | None = None
    model: str | None = None
    agent: str | None = None
    ref_type: str | None = None
    ref_id: str | None = None
    created_at: datetime | None = None

    # --- provenance -------------------------------------------------------
    measurement_kind: str | None = None
    source_provider: str | None = None
    source_id: str | None = None
    source_path: str | None = None
    source_created_at: datetime | None = None
    ingested_at: datetime | None = None
    parser_version: str | None = None
    confidence: float | None = None

    # --- detailed usage ---------------------------------------------------
    cached_input_tokens: int | None = None
    reasoning_output_tokens: int | None = None
    raw_total_tokens: int | None = None
    cost_estimated_usd: float | None = None
    raw_metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class RepoTokenSummary(BaseModel):
|
||||
@@ -69,3 +123,49 @@ class RepoTokenSummary(BaseModel):
|
||||
event_count: int
|
||||
by_model: dict[str, int]
|
||||
by_note: dict[str, int]
|
||||
by_measurement_kind: dict[str, int] = Field(default_factory=dict)
|
||||
by_source_provider: dict[str, int] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TokenAggregateRow(BaseModel):
    """One row of an aggregated token breakdown.

    ``scope_id`` identifies the thing being aggregated and ``label`` is an
    optional display name. The two ``by_*`` mappings default to empty dicts.
    """

    # Identity of the aggregated scope.
    scope_id: str
    label: str | None = None

    # Totals for the scope.
    tokens_in: int
    tokens_out: int
    tokens_total: int
    event_count: int

    # Breakdowns keyed by measurement kind / source provider.
    by_measurement_kind: dict[str, int] = Field(default_factory=dict)
    by_source_provider: dict[str, int] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TokenAggregateSummary(BaseModel):
    """Overall token totals plus per-dimension breakdowns.

    The timestamps are optional; every breakdown collection defaults to an
    empty list or dict so a summary with no events is still valid.
    """

    # Grand totals.
    tokens_in: int
    tokens_out: int
    tokens_total: int
    event_count: int

    # Time bounds of the events covered by this summary.
    first_event_at: datetime | None = None
    last_event_at: datetime | None = None
    last_ingested_at: datetime | None = None

    # Row-level breakdowns.
    by_repo: list[TokenAggregateRow] = Field(default_factory=list)
    by_workstream: list[TokenAggregateRow] = Field(default_factory=list)
    by_task: list[TokenAggregateRow] = Field(default_factory=list)
    by_model: list[TokenAggregateRow] = Field(default_factory=list)

    # Keyed breakdowns.
    by_measurement_kind: dict[str, int] = Field(default_factory=dict)
    by_source_provider: dict[str, int] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TokenQualitySummary(BaseModel):
    """Counters describing the provenance quality of token events.

    How each counter is computed is decided by the code that builds this
    model; this class only fixes the field names and types.
    """

    # Event counts by category.
    event_count: int
    measured_event_count: int
    estimated_event_count: int
    allocated_event_count: int
    superseded_event_count: int
    fallback_event_count: int

    # Data-quality problem counters.
    unattributed_measured_event_count: int
    missing_provenance_event_count: int
    duplicate_source_count: int

    # Most recent ingestion / reconciliation timestamps, when known.
    last_codex_ingested_at: datetime | None = None
    last_claude_ingested_at: datetime | None = None
    last_reconciliation_at: datetime | None = None

    # Keyed breakdowns; default to empty dicts.
    by_measurement_kind: dict[str, int] = Field(default_factory=dict)
    by_source_provider: dict[str, int] = Field(default_factory=dict)
|
||||
|
||||
16
api/services/token_sources/__init__.py
Normal file
16
api/services/token_sources/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Token source adapters for measured agent usage."""
|
||||
|
||||
from api.services.token_sources.base import TokenSourceRecord, parse_iso
|
||||
from api.services.token_sources.codex import collect_codex_sessions, iter_codex_session_files, parse_codex_session
|
||||
from api.services.token_sources.claude import collect_claude_transcripts, iter_claude_transcript_files, parse_claude_transcript
|
||||
|
||||
__all__ = [
|
||||
"TokenSourceRecord",
|
||||
"parse_iso",
|
||||
"collect_codex_sessions",
|
||||
"iter_codex_session_files",
|
||||
"parse_codex_session",
|
||||
"collect_claude_transcripts",
|
||||
"iter_claude_transcript_files",
|
||||
"parse_claude_transcript",
|
||||
]
|
||||
171
api/services/token_sources/attribution.py
Normal file
171
api/services/token_sources/attribution.py
Normal file
@@ -0,0 +1,171 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RepoRef:
|
||||
repo_id: str
|
||||
slug: str
|
||||
local_path: str | None = None
|
||||
host_paths: dict[str, Any] | None = None
|
||||
remote_url: str | None = None
|
||||
git_fingerprint: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RepoMatch:
    """Immutable result of attributing a working directory to a repository."""

    # Identity of the matched repository.
    repo_id: str
    slug: str
    # Name of the strategy that produced the match (e.g. "git_fingerprint").
    method: str
    # Score attached to the match by the strategy that produced it.
    confidence: float
|
||||
|
||||
|
||||
def normalise_cwd(raw: str | None) -> str | None:
|
||||
if not raw:
|
||||
return None
|
||||
value = raw.replace("\\", "/")
|
||||
prefixes = (
|
||||
"//wsl.localhost/Ubuntu-24.04",
|
||||
"//wsl$/Ubuntu-24.04",
|
||||
)
|
||||
for prefix in prefixes:
|
||||
if value.startswith(prefix):
|
||||
return value[len(prefix):] or "/"
|
||||
if len(value) >= 3 and value[1:3] == ":/":
|
||||
drive = value[0].lower()
|
||||
return f"/mnt/{drive}{value[2:]}"
|
||||
return value
|
||||
|
||||
|
||||
def normalise_remote_url(raw: str | None) -> str | None:
|
||||
if not raw:
|
||||
return None
|
||||
value = raw.strip()
|
||||
if value.endswith(".git"):
|
||||
value = value[:-4]
|
||||
if value.startswith("git@") and ":" in value:
|
||||
host, path = value[4:].split(":", 1)
|
||||
value = f"ssh://{host}/{path}"
|
||||
return value.lower().rstrip("/")
|
||||
|
||||
|
||||
def repo_refs_from_api(repos: list[dict[str, Any]]) -> list[RepoRef]:
    """Build ``RepoRef`` objects from repository dicts returned by the API.

    Entries lacking a truthy ``id`` or ``slug`` are dropped. ``host_paths``
    is kept only when it is a dict; anything else becomes an empty dict.
    """

    def _to_ref(record: dict[str, Any]) -> RepoRef:
        host_paths = record.get("host_paths")
        return RepoRef(
            repo_id=str(record.get("id")),
            slug=str(record.get("slug")),
            local_path=record.get("local_path"),
            host_paths=host_paths if isinstance(host_paths, dict) else {},
            remote_url=record.get("remote_url"),
            git_fingerprint=record.get("git_fingerprint"),
        )

    return [
        _to_ref(record)
        for record in repos
        if record.get("id") and record.get("slug")
    ]
|
||||
|
||||
|
||||
def _git(cwd: str, *args: str) -> str | None:
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["git", *args],
|
||||
cwd=cwd,
|
||||
check=False,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
except (OSError, subprocess.SubprocessError):
|
||||
return None
|
||||
if result.returncode != 0:
|
||||
return None
|
||||
value = result.stdout.strip().splitlines()
|
||||
return value[0] if value else None
|
||||
|
||||
|
||||
def git_fingerprint_for_path(cwd: str | None) -> str | None:
    """Return a root-commit hash for the git checkout containing ``cwd``.

    The directory is normalised first. ``None`` is returned when the path is
    empty, does not exist, is not inside a git work tree, or git fails. When
    ``git rev-list --max-parents=0 HEAD`` prints several roots, the first
    line is used.
    """
    location = normalise_cwd(cwd)
    if location and Path(location).exists():
        toplevel = _git(location, "rev-parse", "--show-toplevel")
        if toplevel:
            return _git(toplevel, "rev-list", "--max-parents=0", "HEAD")
    return None
|
||||
|
||||
|
||||
def git_remote_for_path(cwd: str | None) -> str | None:
    """Return the ``origin`` remote URL of the git checkout containing ``cwd``.

    The directory is normalised first. ``None`` is returned when the path is
    empty, does not exist, is not inside a git work tree, or git fails
    (for example because no ``origin`` remote is configured).
    """
    location = normalise_cwd(cwd)
    if location and Path(location).exists():
        toplevel = _git(location, "rev-parse", "--show-toplevel")
        if toplevel:
            return _git(toplevel, "remote", "get-url", "origin")
    return None
|
||||
|
||||
|
||||
def _repo_paths(repo: RepoRef) -> list[str]:
    """List the normalised filesystem paths recorded for ``repo``.

    ``local_path`` comes first, followed by every truthy value in
    ``host_paths``. Empty entries and the placeholder ``"(unknown)"`` are
    ignored; trailing slashes are stripped from the rest.
    """
    candidates = [
        repo.local_path,
        *(str(value) for value in (repo.host_paths or {}).values() if value),
    ]
    cleaned = []
    for candidate in candidates:
        if not candidate or candidate == "(unknown)":
            continue
        normalised = normalise_cwd(str(candidate))
        if normalised:
            cleaned.append(normalised.rstrip("/"))
    return cleaned
|
||||
|
||||
|
||||
def resolve_repo(cwd: str | None, repos: list[RepoRef]) -> RepoMatch | None:
    """Attribute a working directory to one of the known repositories.

    Strategies are tried from most to least specific, and the first one that
    yields a single unambiguous repository wins:

    1. ``git_fingerprint`` (0.98): exactly one repo shares the root commit.
    2. ``git_fingerprint_remote`` (0.99): several repos share the root
       commit and exactly one of them also has the same remote URL.
    3. ``remote_url`` (0.90): exactly one repo has the same remote URL.
    4. ``path_exact_slug`` (0.85) / ``path_exact`` (0.80): ``cwd`` equals a
       recorded repo path; a repo whose slug equals the directory name is
       preferred.
    5. ``path_prefix`` (0.75): ``cwd`` lies under a recorded repo path; the
       longest recorded path wins.

    Args:
        cwd: Working directory recorded for the session, in any spelling
            accepted by ``normalise_cwd``.
        repos: Candidate repositories.

    Returns:
        The best match, or ``None`` when nothing matches.
    """
    if not repos:
        # Nothing can match, so do not spawn any git subprocesses.
        return None

    path = normalise_cwd(cwd)
    if path:
        # Recorded repo paths have no trailing slash; strip it here too so
        # "/repo/" is still recognised as an exact match for "/repo".
        path = path.rstrip("/") or "/"

    # Normalise each repo's remote once; it is used by two strategies.
    repo_remotes = [(repo, normalise_remote_url(repo.remote_url)) for repo in repos]

    # Only ask git for facts that some repo could actually be matched on.
    fingerprint = (
        git_fingerprint_for_path(path)
        if any(repo.git_fingerprint for repo in repos)
        else None
    )
    remote = (
        normalise_remote_url(git_remote_for_path(path))
        if any(url for _, url in repo_remotes)
        else None
    )

    if fingerprint:
        same_history = [
            (repo, url) for repo, url in repo_remotes
            if repo.git_fingerprint == fingerprint
        ]
        if len(same_history) == 1:
            repo = same_history[0][0]
            return RepoMatch(repo.repo_id, repo.slug, "git_fingerprint", 0.98)
        if remote:
            # Several repos share the history: use the remote to pick one.
            narrowed = [repo for repo, url in same_history if url == remote]
            if len(narrowed) == 1:
                repo = narrowed[0]
                return RepoMatch(repo.repo_id, repo.slug, "git_fingerprint_remote", 0.99)

    if remote:
        same_remote = [repo for repo, url in repo_remotes if url == remote]
        if len(same_remote) == 1:
            repo = same_remote[0]
            return RepoMatch(repo.repo_id, repo.slug, "remote_url", 0.90)

    if not path:
        return None

    path_matches: list[tuple[str, RepoRef]] = [
        (repo_path, repo)
        for repo in repos
        for repo_path in _repo_paths(repo)
        if path == repo_path or path.startswith(f"{repo_path}/")
    ]
    if not path_matches:
        return None

    # Longest recorded path first, so the most specific prefix wins.
    path_matches.sort(key=lambda item: len(item[0]), reverse=True)
    exact = [item for item in path_matches if path == item[0]]
    if exact:
        basename = Path(path).name
        for _, repo in exact:
            if repo.slug == basename:
                return RepoMatch(repo.repo_id, repo.slug, "path_exact_slug", 0.85)
        repo = exact[0][1]
        return RepoMatch(repo.repo_id, repo.slug, "path_exact", 0.80)

    repo = path_matches[0][1]
    return RepoMatch(repo.repo_id, repo.slug, "path_prefix", 0.75)
|
||||
71
api/services/token_sources/base.py
Normal file
71
api/services/token_sources/base.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
def parse_iso(value: str) -> datetime:
    """Parse an ISO-8601 date or datetime string into an aware UTC datetime.

    A trailing ``Z`` is treated as ``+00:00``. A date-only value is read as
    midnight UTC. A datetime without an offset is assumed to be UTC. Both
    ``T`` and a space are accepted as the date/time separator.

    Args:
        value: The timestamp text; surrounding whitespace is ignored.

    Returns:
        The timestamp converted to UTC.

    Raises:
        ValueError: If ``value`` is not a valid ISO-8601 date or datetime.
    """
    raw = value.strip()
    if raw.endswith("Z"):
        raw = raw[:-1] + "+00:00"
    # Only a value with no time part gets the midnight suffix. Checking for
    # ":" keeps space-separated datetimes ("2024-01-02 03:04:05") parseable
    # instead of appending a second time component to them.
    if "T" not in raw and ":" not in raw:
        raw = f"{raw}T00:00:00+00:00"
    parsed = datetime.fromisoformat(raw)
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed.astimezone(timezone.utc)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenSourceRecord:
|
||||
source_provider: str
|
||||
source_id: str
|
||||
source_path: Path
|
||||
source_created_at: datetime | None
|
||||
session_id: str | None = None
|
||||
cwd: str | None = None
|
||||
model: str | None = None
|
||||
agent: str | None = None
|
||||
tokens_in: int = 0
|
||||
tokens_out: int = 0
|
||||
cached_input_tokens: int = 0
|
||||
reasoning_output_tokens: int = 0
|
||||
raw_total_tokens: int | None = None
|
||||
parser_version: str | None = None
|
||||
confidence: float = 1.0
|
||||
raw_metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@property
|
||||
def tokens_total(self) -> int:
|
||||
return self.tokens_in + self.tokens_out
|
||||
|
||||
def to_token_event_payload(self, repo_id: str | None = None) -> dict[str, Any]:
|
||||
raw_total = self.raw_total_tokens
|
||||
if raw_total is None:
|
||||
raw_total = self.tokens_in + self.tokens_out
|
||||
created_at = self.source_created_at.isoformat() if self.source_created_at else None
|
||||
return {
|
||||
"tokens_in": self.tokens_in,
|
||||
"tokens_out": self.tokens_out,
|
||||
"repo_id": repo_id,
|
||||
"session_id": self.session_id,
|
||||
"model": self.model,
|
||||
"agent": self.agent,
|
||||
"ref_type": "session",
|
||||
"ref_id": self.source_id,
|
||||
"note": f"measured:{self.source_provider}",
|
||||
"created_at": created_at,
|
||||
"measurement_kind": "measured",
|
||||
"source_provider": self.source_provider,
|
||||
"source_id": self.source_id,
|
||||
"source_path": str(self.source_path),
|
||||
"source_created_at": created_at,
|
||||
"parser_version": self.parser_version,
|
||||
"confidence": self.confidence,
|
||||
"cached_input_tokens": self.cached_input_tokens,
|
||||
"reasoning_output_tokens": self.reasoning_output_tokens,
|
||||
"raw_total_tokens": raw_total,
|
||||
"raw_metadata": self.raw_metadata,
|
||||
}
|
||||
120
api/services/token_sources/claude.py
Normal file
120
api/services/token_sources/claude.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from api.services.token_sources.base import TokenSourceRecord, parse_iso
|
||||
|
||||
PARSER_VERSION = "claude-transcript-v1"
|
||||
|
||||
|
||||
def iter_claude_transcript_files(claude_home: Path) -> list[Path]:
    """Return every ``*.jsonl`` file under ``<claude_home>/projects``, sorted.

    The search is recursive. An empty list is returned when the ``projects``
    directory does not exist.
    """
    projects_root = claude_home / "projects"
    return sorted(projects_root.glob("**/*.jsonl")) if projects_root.is_dir() else []
|
||||
|
||||
|
||||
def _usage_from_entry(entry: dict[str, Any]) -> dict[str, Any]:
|
||||
message = entry.get("message")
|
||||
if isinstance(message, dict) and isinstance(message.get("usage"), dict):
|
||||
return message["usage"]
|
||||
usage = entry.get("usage")
|
||||
return usage if isinstance(usage, dict) else {}
|
||||
|
||||
|
||||
def parse_claude_transcript(path: Path, since: datetime) -> TokenSourceRecord | None:
    """Aggregate the token usage recorded in one Claude transcript file.

    Each line is read as a JSON object. Usage entries whose timestamp is
    earlier than ``since`` are skipped; entries without a usable timestamp
    are counted. Session id, cwd and model are taken from the most recent
    entry that carries them, with the file stem as the initial session id.

    Lines that are not valid JSON, are not JSON objects, or carry
    non-numeric token counters are counted in ``malformed_lines`` and do not
    abort the parse. An unparseable timestamp is treated as missing.

    Args:
        path: Transcript file to read.
        since: Lower bound for usage timestamps. It is compared against the
            aware UTC values produced by ``parse_iso``.

    Returns:
        A record with the summed usage, or ``None`` when the file cannot be
        opened or contains no countable usage.
    """
    session_id = path.stem
    cwd: str | None = None
    model: str | None = None
    first_at: datetime | None = None
    last_at: datetime | None = None
    tokens_in = tokens_out = 0
    cached_input_tokens = 0
    raw_total_tokens = 0
    usage_records = 0
    malformed_lines = 0

    try:
        handle = path.open("r", encoding="utf-8", errors="ignore")
    except OSError:
        return None

    with handle:
        for line in handle:
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                malformed_lines += 1
                continue
            if not isinstance(entry, dict):
                # Valid JSON but not an object (e.g. a bare number or list).
                malformed_lines += 1
                continue

            ts = entry.get("timestamp") or entry.get("created_at")
            parsed_ts: datetime | None = None
            if isinstance(ts, str):
                try:
                    parsed_ts = parse_iso(ts)
                except ValueError:
                    # One bad timestamp must not discard the whole file.
                    parsed_ts = None
            if parsed_ts:
                first_at = first_at or parsed_ts
                last_at = parsed_ts

            session_id = str(entry.get("session_id") or entry.get("conversation_id") or session_id)
            cwd = entry.get("cwd") or entry.get("project_cwd") or cwd
            model = entry.get("model") or model
            message = entry.get("message")
            if isinstance(message, dict):
                model = message.get("model") or model

            usage = _usage_from_entry(entry)
            if not usage:
                continue
            if parsed_ts is not None and parsed_ts < since:
                continue

            try:
                input_tokens = int(usage.get("input_tokens") or 0)
                cache_creation = int(usage.get("cache_creation_input_tokens") or 0)
                cache_read = int(usage.get("cache_read_input_tokens") or 0)
                output_tokens = int(usage.get("output_tokens") or 0)
            except (TypeError, ValueError):
                malformed_lines += 1
                continue
            if input_tokens == 0 and output_tokens == 0 and cache_creation == 0 and cache_read == 0:
                continue
            tokens_in += input_tokens
            tokens_out += output_tokens
            cached_input_tokens += cache_creation + cache_read
            raw_total_tokens += input_tokens + cache_creation + cache_read + output_tokens
            usage_records += 1

    if usage_records == 0 or tokens_in + tokens_out + cached_input_tokens == 0:
        return None

    return TokenSourceRecord(
        source_provider="claude_transcript",
        source_id=f"claude:{session_id}",
        source_path=path,
        source_created_at=last_at,
        session_id=session_id,
        cwd=cwd,
        model=model,
        agent="claude",
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        cached_input_tokens=cached_input_tokens,
        raw_total_tokens=raw_total_tokens or None,
        parser_version=PARSER_VERSION,
        confidence=1.0,
        raw_metadata={
            "started_at": first_at.isoformat() if first_at else None,
            "usage_records": usage_records,
            "malformed_lines": malformed_lines,
            "source_file_name": path.name,
        },
    )
|
||||
|
||||
|
||||
def collect_claude_transcripts(claude_home: Path, since: datetime) -> list[TokenSourceRecord]:
    """Parse every Claude transcript and return one record per source id.

    When several files yield the same ``source_id``, the record with the
    larger ``tokens_total`` is kept (the first one wins ties). The result is
    sorted by ``source_created_at``; records without one sort first.
    """
    best: dict[str, TokenSourceRecord] = {}
    for transcript in iter_claude_transcript_files(claude_home):
        record = parse_claude_transcript(transcript, since)
        if record is None:
            continue
        incumbent = best.get(record.source_id)
        if incumbent is not None and record.tokens_total <= incumbent.tokens_total:
            continue
        best[record.source_id] = record

    # Sort key used for records that have no timestamp.
    earliest = datetime.min.replace(tzinfo=since.tzinfo)
    return sorted(best.values(), key=lambda record: record.source_created_at or earliest)
|
||||
124
api/services/token_sources/codex.py
Normal file
124
api/services/token_sources/codex.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from api.services.token_sources.base import TokenSourceRecord, parse_iso
|
||||
|
||||
PARSER_VERSION = "codex-desktop-v1"
|
||||
|
||||
|
||||
def iter_codex_session_files(codex_home: Path) -> list[Path]:
    """List Codex session files under ``codex_home``.

    Live sessions are matched three directory levels below ``sessions``
    (``sessions/*/*/*/*.jsonl``); archived ones sit directly inside
    ``archived_sessions``. Each group is sorted, and live sessions come
    before archived ones. Missing directories contribute nothing.
    """
    locations = (
        (codex_home / "sessions", "*/*/*/*.jsonl"),
        (codex_home / "archived_sessions", "*.jsonl"),
    )
    found: list[Path] = []
    for folder, pattern in locations:
        if folder.is_dir():
            found += sorted(folder.glob(pattern))
    return found
|
||||
|
||||
|
||||
def parse_codex_session(path: Path, since: datetime) -> TokenSourceRecord | None:
    """Aggregate the token usage recorded in one Codex session file.

    Each line is read as a JSON object and handled by its ``type``:

    * ``session_meta`` supplies the session id, cwd and start timestamp.
    * ``turn_context`` supplies the cwd and model.
    * ``event_msg`` with a ``token_count`` payload adds its
      ``last_token_usage`` counters to the running totals. Such an event is
      counted only when it has a timestamp that is not earlier than
      ``since``.

    The session id defaults to the file stem without its ``rollout-`` prefix.
    Lines that are not valid JSON, are not JSON objects, or carry
    non-numeric token counters are counted in ``malformed_lines`` and do not
    abort the parse. An unparseable timestamp is treated as missing.

    Args:
        path: Session file to read.
        since: Lower bound for usage timestamps. It is compared against the
            aware UTC values produced by ``parse_iso``.

    Returns:
        A record with the summed usage, or ``None`` when the file cannot be
        opened or contains no countable usage.
    """
    session_id = path.stem.removeprefix("rollout-")
    started_at: datetime | None = None
    last_at: datetime | None = None
    cwd: str | None = None
    model: str | None = None
    tokens_in = tokens_out = 0
    cached_input_tokens = reasoning_output_tokens = 0
    raw_total_tokens = 0
    usage_records = 0
    malformed_lines = 0

    try:
        handle = path.open("r", encoding="utf-8", errors="ignore")
    except OSError:
        return None

    with handle:
        for line in handle:
            try:
                entry = json.loads(line)
            except json.JSONDecodeError:
                malformed_lines += 1
                continue
            if not isinstance(entry, dict):
                # Valid JSON but not an object (e.g. a bare number or list).
                malformed_lines += 1
                continue

            ts = entry.get("timestamp")
            parsed_ts: datetime | None = None
            if isinstance(ts, str):
                try:
                    parsed_ts = parse_iso(ts)
                except ValueError:
                    # One bad timestamp must not discard the whole file.
                    parsed_ts = None
            if parsed_ts:
                last_at = parsed_ts
                started_at = started_at or parsed_ts

            payload = entry.get("payload")
            if not isinstance(payload, dict):
                payload = {}
            entry_type = entry.get("type")

            if entry_type == "session_meta":
                meta_id = payload.get("id")
                if meta_id:
                    session_id = str(meta_id)
                cwd = payload.get("cwd") or cwd
                meta_ts = payload.get("timestamp")
                if isinstance(meta_ts, str):
                    try:
                        started_at = parse_iso(meta_ts)
                    except ValueError:
                        pass  # keep the start time derived from entry timestamps
            elif entry_type == "turn_context":
                cwd = payload.get("cwd") or cwd
                model = payload.get("model") or model
            elif entry_type == "event_msg" and payload.get("type") == "token_count":
                if parsed_ts is None or parsed_ts < since:
                    continue
                info = payload.get("info")
                last = info.get("last_token_usage") if isinstance(info, dict) else None
                if not isinstance(last, dict):
                    continue
                try:
                    input_tokens = int(last.get("input_tokens") or 0)
                    output_tokens = int(last.get("output_tokens") or 0)
                    cached = int(last.get("cached_input_tokens") or 0)
                    reasoning = int(last.get("reasoning_output_tokens") or 0)
                    total = int(last.get("total_tokens") or input_tokens + output_tokens)
                except (TypeError, ValueError):
                    malformed_lines += 1
                    continue
                if input_tokens == 0 and output_tokens == 0:
                    continue
                tokens_in += input_tokens
                tokens_out += output_tokens
                cached_input_tokens += cached
                reasoning_output_tokens += reasoning
                raw_total_tokens += total
                usage_records += 1

    if usage_records == 0 or tokens_in + tokens_out == 0:
        return None

    return TokenSourceRecord(
        source_provider="codex_session",
        source_id=f"codex:{session_id}",
        source_path=path,
        source_created_at=last_at,
        session_id=session_id,
        cwd=cwd,
        model=model,
        agent="codex",
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        cached_input_tokens=cached_input_tokens,
        reasoning_output_tokens=reasoning_output_tokens,
        raw_total_tokens=raw_total_tokens or None,
        parser_version=PARSER_VERSION,
        confidence=1.0,
        raw_metadata={
            "started_at": started_at.isoformat() if started_at else None,
            "usage_records": usage_records,
            "malformed_lines": malformed_lines,
            "source_file_name": path.name,
        },
    )
|
||||
|
||||
|
||||
def collect_codex_sessions(codex_home: Path, since: datetime) -> list[TokenSourceRecord]:
    """Parse every Codex session file and return one record per source id.

    When several files yield the same ``source_id``, the record with the
    larger ``tokens_total`` is kept (the first one wins ties). The result is
    sorted by ``source_created_at``; records without one sort first.
    """
    best: dict[str, TokenSourceRecord] = {}
    for session_file in iter_codex_session_files(codex_home):
        record = parse_codex_session(session_file, since)
        if record is None:
            continue
        incumbent = best.get(record.source_id)
        if incumbent is not None and record.tokens_total <= incumbent.tokens_total:
            continue
        best[record.source_id] = record

    # Sort key used for records that have no timestamp.
    earliest = datetime.min.replace(tzinfo=since.tzinfo)
    return sorted(best.values(), key=lambda record: record.source_created_at or earliest)
|
||||
Reference in New Issue
Block a user