generated from coulomb/repo-seed
641 lines
24 KiB
Python
641 lines
24 KiB
Python
import uuid
|
|
from collections import defaultdict
|
|
from datetime import datetime
|
|
from typing import Any
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from api.database import get_session
|
|
from api.models.managed_repo import ManagedRepo
|
|
from api.models.task import Task
|
|
from api.models.token_event import TokenEvent
|
|
from api.models.workstream import Workstream
|
|
from api.schemas.token_event import (
|
|
RepoTokenSummary,
|
|
TokenAggregateRow,
|
|
TokenAggregateSummary,
|
|
TokenEventCreate,
|
|
TokenEventPatch,
|
|
TokenEventRead,
|
|
TokenQualitySummary,
|
|
TokenSummary,
|
|
)
|
|
|
|
# Every route below is registered on this router, under the /token-events prefix.
router = APIRouter(prefix="/token-events", tags=["token-events"])
|
|
|
|
# Default ``confidence`` given to an event that does not supply one, keyed by
# its measurement_kind.
DEFAULT_CONFIDENCE = dict(
    measured=1.0,
    allocated=0.70,
    estimated=0.35,
    superseded=0.0,
)

# Default ``parser_version`` recorded for events from each known
# transcript/session source provider.
SOURCE_PARSER_DEFAULTS = dict(
    codex_session="codex-desktop-v1",
    claude_transcript="claude-transcript-v1",
    llm_connect="llm-connect-v1",
)
|
|
|
|
|
|
def _event_total(event: TokenEvent) -> int:
    """Return the event's combined input and output token count."""
    consumed = event.tokens_in
    produced = event.tokens_out
    return consumed + produced
|
|
|
|
|
|
def _infer_measurement_kind(data: dict[str, Any]) -> str:
|
|
if data.get("measurement_kind"):
|
|
return str(data["measurement_kind"])
|
|
note = data.get("note")
|
|
if note == "heuristic_superseded_by_codex_backfill":
|
|
return "superseded"
|
|
if note == "workplan":
|
|
return "allocated"
|
|
if note == "heuristic":
|
|
return "estimated"
|
|
if note == "measured" or str(note or "").startswith("backfill:codex-session"):
|
|
return "measured"
|
|
provider = data.get("source_provider")
|
|
if provider in {"codex_session", "claude_transcript", "llm_connect"}:
|
|
return "measured"
|
|
return "estimated"
|
|
|
|
|
|
def _infer_source_provider(data: dict[str, Any], measurement_kind: str) -> str:
|
|
if data.get("source_provider"):
|
|
return str(data["source_provider"])
|
|
note = data.get("note")
|
|
ref_id = str(data.get("ref_id") or "")
|
|
agent = str(data.get("agent") or "").lower()
|
|
if note == "heuristic":
|
|
return "task_fallback"
|
|
if ref_id.startswith("codex:") or str(note or "").startswith("backfill:codex-session"):
|
|
return "codex_session"
|
|
if measurement_kind == "measured" and "claude" in agent:
|
|
return "claude_transcript"
|
|
return "manual"
|
|
|
|
|
|
def _apply_event_defaults(data: dict[str, Any]) -> dict[str, Any]:
    """Fill in provenance and accounting defaults on an event payload.

    Mutates *data* in place and returns it:
      * ``measurement_kind`` and ``source_provider`` are always (re)written
        with the inferred values.
      * For the transcript/session providers, a missing ``source_id`` falls
        back to ``ref_id`` and then ``session_id`` (stringified).
      * ``source_created_at`` copies ``created_at`` when unset and a
        ``source_id`` is present.
      * ``confidence``, token counters, ``raw_total_tokens`` (in + out),
        ``raw_metadata`` and, for known providers, ``parser_version`` are
        only set when absent.
    """
    kind = _infer_measurement_kind(data)
    provider = _infer_source_provider(data, kind)
    data["measurement_kind"] = kind
    data["source_provider"] = provider

    measured_providers = {"codex_session", "claude_transcript", "llm_connect"}
    if not data.get("source_id") and provider in measured_providers:
        fallback_id = data.get("ref_id") or data.get("session_id")
        if fallback_id:
            data["source_id"] = str(fallback_id)

    if not data.get("source_created_at") and data.get("created_at") and data.get("source_id"):
        data["source_created_at"] = data["created_at"]

    # Insertion order of these keys is kept identical to the order in which
    # they are defaulted here.
    fallbacks = (
        ("confidence", DEFAULT_CONFIDENCE.get(kind, 0.35)),
        ("cached_input_tokens", 0),
        ("reasoning_output_tokens", 0),
        ("raw_total_tokens", (data.get("tokens_in") or 0) + (data.get("tokens_out") or 0)),
        ("raw_metadata", {}),
    )
    for key, value in fallbacks:
        data.setdefault(key, value)

    if provider in SOURCE_PARSER_DEFAULTS:
        data.setdefault("parser_version", SOURCE_PARSER_DEFAULTS[provider])
    return data
|
|
|
|
|
|
async def _populate_relationship_defaults(data: dict[str, Any], session: AsyncSession) -> dict[str, Any]:
    """Derive missing ``workstream_id`` / ``repo_id`` from the event's parent links.

    * No ``workstream_id`` but a ``task_id``: copy the task's workstream.
    * No ``repo_id`` but a (possibly just-derived) ``workstream_id``: copy the
      workstream's repo.

    A key is only written when the parent row exists and actually carries the
    value. Mutates *data* in place and returns it.
    """
    if data.get("task_id") and not data.get("workstream_id"):
        task = await session.get(Task, data["task_id"])
        # Skip tasks without a workstream: writing an explicit None here would,
        # on the upsert path, overwrite an existing event's workstream_id.
        if task and task.workstream_id:
            data["workstream_id"] = task.workstream_id

    if data.get("workstream_id") and not data.get("repo_id"):
        ws = await session.get(Workstream, data["workstream_id"])
        if ws and ws.repo_id:
            data["repo_id"] = ws.repo_id
    return data
|
|
|
|
|
|
async def _find_source_event(data: dict[str, Any], session: AsyncSession) -> TokenEvent | None:
    """Return an existing event with the same source identity as *data*, if any.

    Two events share a source when ``measurement_kind``, ``source_provider``
    and ``source_id`` all match. Payloads without a ``source_id`` are never
    matched, so this returns None for them.

    Duplicate rows for one source key can already exist (``/quality/``
    reports them as ``duplicate_source_count``). Using ``scalar_one_or_none()``
    would then raise ``MultipleResultsFound`` on every later create/upsert
    for that source. Instead, the most recently ingested match is returned.
    """
    source_id = data.get("source_id")
    if not source_id:
        return None
    result = await session.execute(
        select(TokenEvent)
        .where(
            TokenEvent.measurement_kind == data["measurement_kind"],
            TokenEvent.source_provider == data["source_provider"],
            TokenEvent.source_id == source_id,
        )
        .order_by(TokenEvent.ingested_at.desc())
        .limit(1)
    )
    return result.scalars().first()
|
|
|
|
|
|
async def _create_or_upsert_event(data: dict[str, Any], session: AsyncSession) -> TokenEvent:
    """Insert a token event, or update the existing row sharing its source key.

    *data* is completed in place by ``_apply_event_defaults`` and
    ``_populate_relationship_defaults``. If ``_find_source_event`` locates a
    matching row, every key in the completed payload is assigned onto it;
    otherwise a new ``TokenEvent`` is added. The session is committed and the
    persisted row is refreshed and returned.
    """
    prepared = await _populate_relationship_defaults(_apply_event_defaults(data), session)
    target = await _find_source_event(prepared, session)
    if target is None:
        target = TokenEvent(**prepared)
        session.add(target)
    else:
        for field, value in prepared.items():
            setattr(target, field, value)
    await session.commit()
    await session.refresh(target)
    return target
|
|
|
|
|
|
def _filter_query(
    q,
    *,
    task_id: uuid.UUID | None = None,
    workstream_id: uuid.UUID | None = None,
    repo_id: uuid.UUID | None = None,
    ref_type: str | None = None,
    ref_id: str | None = None,
    model: str | None = None,
    agent: str | None = None,
    note: str | None = None,
    measurement_kind: str | None = None,
    source_provider: str | None = None,
    since: datetime | None = None,
    until: datetime | None = None,
    include_superseded: bool = True,
    unattributed: bool = False,
):
    """Narrow a ``select(TokenEvent)`` statement by optional filters.

    Each equality filter is applied only when its value is truthy, so None
    (and empty strings) mean "don't filter". ``since`` is inclusive and
    ``until`` exclusive on ``created_at``. ``include_superseded=False`` adds
    ``measurement_kind != "superseded"``; ``unattributed=True`` keeps only
    rows with no repo, workstream, or task link. Returns the narrowed
    statement.
    """
    equality_filters = (
        (TokenEvent.task_id, task_id),
        (TokenEvent.workstream_id, workstream_id),
        (TokenEvent.repo_id, repo_id),
        (TokenEvent.ref_type, ref_type),
        (TokenEvent.ref_id, ref_id),
        (TokenEvent.model, model),
        (TokenEvent.agent, agent),
        (TokenEvent.note, note),
        (TokenEvent.measurement_kind, measurement_kind),
        (TokenEvent.source_provider, source_provider),
    )
    for column, wanted in equality_filters:
        if wanted:
            q = q.where(column == wanted)

    if since:
        q = q.where(TokenEvent.created_at >= since)
    if until:
        q = q.where(TokenEvent.created_at < until)
    if not include_superseded:
        q = q.where(TokenEvent.measurement_kind != "superseded")
    if unattributed:
        q = q.where(
            TokenEvent.repo_id.is_(None),
            TokenEvent.workstream_id.is_(None),
            TokenEvent.task_id.is_(None),
        )
    return q
|
|
|
|
|
|
@router.post("/", response_model=TokenEventRead, status_code=status.HTTP_201_CREATED)
async def create_token_event(
    body: TokenEventCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Record a token event.

    This shares the upsert path. When the payload resolves to a
    ``source_id`` and a row with the same (measurement_kind, source_provider,
    source_id) exists, that row is updated and returned instead of inserting
    a duplicate. The status code is 201 either way.
    """
    payload = body.model_dump(exclude_none=True)
    event = await _create_or_upsert_event(payload, session)
    return event
|
|
|
|
|
|
@router.post("/upsert", response_model=TokenEventRead)
async def upsert_token_event(
    body: TokenEventCreate,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Insert a token event, or update the one sharing its source key.

    Same logic as ``POST /``, but responds with the default 200 status.
    """
    payload = body.model_dump(exclude_none=True)
    event = await _create_or_upsert_event(payload, session)
    return event
|
|
|
|
|
|
@router.get("/summary/", response_model=TokenSummary)
async def get_token_summary(
    scope: str = Query(..., description="task|workstream|repo|commit|release|session"),
    id: str = Query(..., description="FK value or ref_id depending on scope"),
    session: AsyncSession = Depends(get_session),
    include_superseded: bool = Query(True),
) -> TokenSummary:
    """Summarise token usage for one task, workstream, repo, or ref.

    ``scope`` selects how ``id`` is interpreted:
      * ``task`` / ``workstream`` / ``repo``: ``id`` must be a UUID and is
        matched against the event's own foreign-key column. There is no graph
        resolution, unlike ``/by-repo/``.
      * ``commit`` / ``release`` / ``session``: matches
        ``ref_type == scope`` and ``ref_id == id``.

    ``include_superseded`` defaults to True, which preserves the original
    behaviour. Pass False to drop events whose measurement_kind is
    "superseded", matching ``/aggregate/``'s default.

    Raises:
        HTTPException(422): ``id`` is not a UUID for a UUID scope, or
            ``scope`` is unknown.
    """
    uuid_columns = {
        "task": TokenEvent.task_id,
        "workstream": TokenEvent.workstream_id,
        "repo": TokenEvent.repo_id,
    }

    q = select(TokenEvent)
    if scope in uuid_columns:
        try:
            uid = uuid.UUID(id)
        except ValueError:
            raise HTTPException(
                status_code=422, detail=f"id must be a valid UUID for scope={scope}"
            ) from None
        q = q.where(uuid_columns[scope] == uid)
    elif scope in ("commit", "release", "session"):
        q = q.where(TokenEvent.ref_type == scope, TokenEvent.ref_id == id)
    else:
        raise HTTPException(status_code=422, detail=f"Unknown scope: {scope!r}")
    q = _filter_query(q, include_superseded=include_superseded)

    result = await session.execute(q)
    events = list(result.scalars().all())

    tokens_in = sum(e.tokens_in for e in events)
    tokens_out = sum(e.tokens_out for e in events)

    # Breakdowns are weighted by total tokens (in + out), not event counts.
    by_model: dict[str, int] = defaultdict(int)
    by_agent: dict[str, int] = defaultdict(int)
    by_measurement_kind: dict[str, int] = defaultdict(int)
    by_source_provider: dict[str, int] = defaultdict(int)
    for e in events:
        total = _event_total(e)
        if e.model:
            by_model[e.model] += total
        if e.agent:
            by_agent[e.agent] += total
        by_measurement_kind[e.measurement_kind] += total
        by_source_provider[e.source_provider] += total

    return TokenSummary(
        scope=scope,
        scope_id=id,
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        tokens_total=tokens_in + tokens_out,
        event_count=len(events),
        by_model=dict(by_model),
        by_agent=dict(by_agent),
        by_measurement_kind=dict(by_measurement_kind),
        by_source_provider=dict(by_source_provider),
    )
|
|
|
|
|
|
@router.get("/by-repo/", response_model=list[RepoTokenSummary])
async def get_tokens_by_repo(
    measurement_kind: str | None = None,
    source_provider: str | None = None,
    since: datetime | None = None,
    until: datetime | None = None,
    include_superseded: bool = Query(True),
    session: AsyncSession = Depends(get_session),
) -> list[RepoTokenSummary]:
    """Aggregate token consumption per repo, resolving via the full graph.

    Each event is attributed by the first rule that applies:
      1. the event's own ``repo_id``;
      2. otherwise the ``repo_id`` of the event's workstream;
      3. otherwise, only when the event has no ``workstream_id``, the
         ``repo_id`` of its task's workstream.

    Events that resolve to no repo, or to an id with no ManagedRepo row, are
    left out. Per-model/note/kind/provider breakdowns are weighted by total
    tokens. Rows are ordered by total tokens, largest first; ties keep
    first-seen order.
    """
    statement = _filter_query(
        select(TokenEvent),
        measurement_kind=measurement_kind,
        source_provider=source_provider,
        since=since,
        until=until,
        include_superseded=include_superseded,
    )
    events = list((await session.execute(statement)).scalars().all())

    # Load the workstream, task and repo tables in full (one query per table,
    # four queries in total with the events) so per-event resolution below
    # needs no extra round-trips.
    workstream_rows = (await session.execute(select(Workstream))).scalars().all()
    workstreams: dict[uuid.UUID, Workstream] = {w.id: w for w in workstream_rows}
    task_rows = (await session.execute(select(Task))).scalars().all()
    tasks: dict[uuid.UUID, Task] = {t.id: t for t in task_rows}
    repo_rows = (await session.execute(select(ManagedRepo))).scalars().all()
    repos: dict[uuid.UUID, ManagedRepo] = {r.id: r for r in repo_rows}

    def owning_repo(event: TokenEvent) -> uuid.UUID | None:
        if event.repo_id:
            return event.repo_id
        workstream_id = event.workstream_id
        if not workstream_id and event.task_id:
            parent_task = tasks.get(event.task_id)
            if parent_task is not None:
                workstream_id = parent_task.workstream_id
        if not workstream_id:
            return None
        workstream = workstreams.get(workstream_id)
        return None if workstream is None else workstream.repo_id

    buckets: dict[uuid.UUID, dict[str, Any]] = {}
    for event in events:
        repo_key = owning_repo(event)
        if not repo_key or repo_key not in repos:
            continue
        bucket = buckets.get(repo_key)
        if bucket is None:
            bucket = {
                "repo_id": repo_key,
                "repo_slug": repos[repo_key].slug,
                "tokens_in": 0,
                "tokens_out": 0,
                "event_count": 0,
                "by_model": defaultdict(int),
                "by_note": defaultdict(int),
                "by_measurement_kind": defaultdict(int),
                "by_source_provider": defaultdict(int),
            }
            buckets[repo_key] = bucket
        subtotal = _event_total(event)
        bucket["tokens_in"] += event.tokens_in
        bucket["tokens_out"] += event.tokens_out
        bucket["event_count"] += 1
        if event.model:
            bucket["by_model"][event.model] += subtotal
        bucket["by_note"][event.note or "unknown"] += subtotal
        bucket["by_measurement_kind"][event.measurement_kind] += subtotal
        bucket["by_source_provider"][event.source_provider] += subtotal

    def grand_total(bucket: dict[str, Any]) -> int:
        return bucket["tokens_in"] + bucket["tokens_out"]

    summaries: list[RepoTokenSummary] = []
    # reverse=True keeps sort stability, so ties stay in first-seen order.
    for bucket in sorted(buckets.values(), key=grand_total, reverse=True):
        fields = {k: dict(v) if isinstance(v, defaultdict) else v for k, v in bucket.items()}
        summaries.append(RepoTokenSummary(**fields, tokens_total=grand_total(bucket)))
    return summaries
|
|
|
|
|
|
@router.get("/aggregate/", response_model=TokenAggregateSummary)
async def get_token_aggregate(
    measurement_kind: str | None = None,
    source_provider: str | None = None,
    since: datetime | None = None,
    until: datetime | None = None,
    include_superseded: bool = Query(False),
    session: AsyncSession = Depends(get_session),
) -> TokenAggregateSummary:
    """Roll matching token events up into totals plus per-dimension rows.

    Superseded events are excluded unless ``include_superseded`` is set.
    Besides overall totals and first/last event and ingest timestamps, rows
    are produced:
      * by repo: resolved as in ``/by-repo/``. Repo ids without a ManagedRepo
        row are still counted, with a None label.
      * by workstream: the event's workstream, else its task's workstream.
      * by task.
      * by model: a missing model is grouped under "unknown".

    Events with no key for a dimension are skipped for that dimension only.
    Each row list is ordered by total tokens, largest first; ties keep
    first-seen order.
    """
    statement = _filter_query(
        select(TokenEvent),
        measurement_kind=measurement_kind,
        source_provider=source_provider,
        since=since,
        until=until,
        include_superseded=include_superseded,
    )
    events = list((await session.execute(statement)).scalars().all())

    workstream_rows = (await session.execute(select(Workstream))).scalars().all()
    workstreams: dict[uuid.UUID, Workstream] = {w.id: w for w in workstream_rows}
    task_rows = (await session.execute(select(Task))).scalars().all()
    tasks: dict[uuid.UUID, Task] = {t.id: t for t in task_rows}
    repo_rows = (await session.execute(select(ManagedRepo))).scalars().all()
    repos: dict[uuid.UUID, ManagedRepo] = {r.id: r for r in repo_rows}

    def owning_repo(event: TokenEvent) -> uuid.UUID | None:
        if event.repo_id:
            return event.repo_id
        workstream_id = event.workstream_id
        if not workstream_id and event.task_id:
            parent_task = tasks.get(event.task_id)
            if parent_task is not None:
                workstream_id = parent_task.workstream_id
        if not workstream_id:
            return None
        workstream = workstreams.get(workstream_id)
        return None if workstream is None else workstream.repo_id

    def tally(
        table: dict[str, dict[str, Any]], key: str | None, label: str | None, event: TokenEvent
    ) -> None:
        # A falsy key means "this event has nothing to group by here".
        if not key:
            return
        row = table.get(key)
        if row is None:
            row = {
                "scope_id": key,
                "label": label,
                "tokens_in": 0,
                "tokens_out": 0,
                "event_count": 0,
                "by_measurement_kind": defaultdict(int),
                "by_source_provider": defaultdict(int),
            }
            table[key] = row
        subtotal = _event_total(event)
        row["tokens_in"] += event.tokens_in
        row["tokens_out"] += event.tokens_out
        row["event_count"] += 1
        row["by_measurement_kind"][event.measurement_kind] += subtotal
        row["by_source_provider"][event.source_provider] += subtotal

    repo_table: dict[str, dict[str, Any]] = {}
    workstream_table: dict[str, dict[str, Any]] = {}
    task_table: dict[str, dict[str, Any]] = {}
    model_table: dict[str, dict[str, Any]] = {}
    kind_totals: dict[str, int] = defaultdict(int)
    provider_totals: dict[str, int] = defaultdict(int)

    first_event_at = None
    last_event_at = None
    last_ingested_at = None
    tokens_in = 0
    tokens_out = 0
    for event in events:
        subtotal = _event_total(event)
        tokens_in += event.tokens_in
        tokens_out += event.tokens_out
        kind_totals[event.measurement_kind] += subtotal
        provider_totals[event.source_provider] += subtotal

        stamp = event.created_at
        if first_event_at is None or stamp < first_event_at:
            first_event_at = stamp
        if last_event_at is None or stamp > last_event_at:
            last_event_at = stamp
        if last_ingested_at is None or event.ingested_at > last_ingested_at:
            last_ingested_at = event.ingested_at

        repo_key = owning_repo(event)
        repo = repos.get(repo_key) if repo_key else None
        tally(
            repo_table,
            str(repo_key) if repo_key else None,
            repo.slug if repo is not None else None,
            event,
        )

        workstream_key = event.workstream_id
        if not workstream_key:
            parent_task = tasks.get(event.task_id)
            workstream_key = parent_task.workstream_id if parent_task is not None else None
        workstream = workstreams.get(workstream_key) if workstream_key else None
        tally(
            workstream_table,
            str(workstream_key) if workstream_key else None,
            workstream.title if workstream is not None else None,
            event,
        )

        task = tasks.get(event.task_id) if event.task_id else None
        tally(
            task_table,
            str(event.task_id) if event.task_id else None,
            task.title if task is not None else None,
            event,
        )

        model_key = event.model or "unknown"
        tally(model_table, model_key, model_key, event)

    def as_rows(table: dict[str, dict[str, Any]]) -> list[TokenAggregateRow]:
        built = [
            TokenAggregateRow(
                **{k: dict(v) if isinstance(v, defaultdict) else v for k, v in row.items()},
                tokens_total=row["tokens_in"] + row["tokens_out"],
            )
            for row in table.values()
        ]
        # reverse=True keeps sort stability, so ties stay in first-seen order.
        return sorted(built, key=lambda item: item.tokens_total, reverse=True)

    return TokenAggregateSummary(
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        tokens_total=tokens_in + tokens_out,
        event_count=len(events),
        first_event_at=first_event_at,
        last_event_at=last_event_at,
        last_ingested_at=last_ingested_at,
        by_repo=as_rows(repo_table),
        by_workstream=as_rows(workstream_table),
        by_task=as_rows(task_table),
        by_model=as_rows(model_table),
        by_measurement_kind=dict(kind_totals),
        by_source_provider=dict(provider_totals),
    )
|
|
|
|
|
|
@router.get("/quality/", response_model=TokenQualitySummary)
async def get_token_quality(
    since: datetime | None = None,
    until: datetime | None = None,
    session: AsyncSession = Depends(get_session),
) -> TokenQualitySummary:
    """Report data-quality indicators for token events in a time window.

    Every count and both ``by_*`` maps count events, not tokens, and include
    superseded events. The indicators are:
      * fallback: source_provider "task_fallback" or note "heuristic";
      * unattributed measured: measured events with no repo, workstream or task;
      * missing provenance: measured events without a source_id;
      * duplicate sources: (kind, provider, source_id) keys seen more than once;
      * the latest ``ingested_at`` for codex_session and claude_transcript.

    ``last_reconciliation_at`` is always None here.
    """
    statement = _filter_query(select(TokenEvent), since=since, until=until)
    events = list((await session.execute(statement)).scalars().all())

    kind_counts: dict[str, int] = defaultdict(int)
    provider_counts: dict[str, int] = defaultdict(int)
    source_hits: dict[tuple[str, str, str], int] = defaultdict(int)
    latest_ingest: dict[str, datetime | None] = {"codex_session": None, "claude_transcript": None}
    fallback = 0
    unattributed_measured = 0
    missing_provenance = 0

    for event in events:
        kind = event.measurement_kind
        provider = event.source_provider
        kind_counts[kind] += 1
        provider_counts[provider] += 1
        if event.source_id:
            source_hits[(kind, provider, event.source_id)] += 1
        if provider == "task_fallback" or event.note == "heuristic":
            fallback += 1
        if kind == "measured":
            if not (event.repo_id or event.workstream_id or event.task_id):
                unattributed_measured += 1
            if not event.source_id:
                missing_provenance += 1
        if provider in latest_ingest:
            previous = latest_ingest[provider]
            if previous is None or event.ingested_at > previous:
                latest_ingest[provider] = event.ingested_at

    duplicate_sources = sum(1 for hits in source_hits.values() if hits > 1)
    return TokenQualitySummary(
        event_count=len(events),
        measured_event_count=kind_counts.get("measured", 0),
        estimated_event_count=kind_counts.get("estimated", 0),
        allocated_event_count=kind_counts.get("allocated", 0),
        superseded_event_count=kind_counts.get("superseded", 0),
        fallback_event_count=fallback,
        unattributed_measured_event_count=unattributed_measured,
        missing_provenance_event_count=missing_provenance,
        duplicate_source_count=duplicate_sources,
        last_codex_ingested_at=latest_ingest["codex_session"],
        last_claude_ingested_at=latest_ingest["claude_transcript"],
        last_reconciliation_at=None,
        by_measurement_kind=dict(kind_counts),
        by_source_provider=dict(provider_counts),
    )
|
|
|
|
|
|
@router.patch("/{event_id}", response_model=TokenEventRead)
async def patch_token_event(
    event_id: uuid.UUID,
    body: TokenEventPatch,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Partially update a token event; None-valued patch fields are ignored.

    When the patch touches ``note``, ``measurement_kind`` or
    ``source_provider``, the patched and stored values are merged and
    re-run through ``_apply_event_defaults``. Any of measurement_kind,
    source_provider and source_id that the patch did not set explicitly
    are then taken from that inference.

    Confidence is only reset to the new kind's default when the measurement
    kind actually changes. Previously, any note/provider edit silently
    replaced a stored confidence with the default.

    Raises:
        HTTPException(404): no event with ``event_id`` exists.
    """
    event = await session.get(TokenEvent, event_id)
    if event is None:
        raise HTTPException(status_code=404, detail="Token event not found")
    data = body.model_dump(exclude_none=True)
    if "note" in data or "measurement_kind" in data or "source_provider" in data:
        merged = {
            "tokens_in": data.get("tokens_in", event.tokens_in),
            "tokens_out": data.get("tokens_out", event.tokens_out),
            "note": data.get("note", event.note),
            "agent": data.get("agent", event.agent),
            "ref_id": data.get("ref_id", event.ref_id),
            "session_id": data.get("session_id", event.session_id),
            "measurement_kind": data.get("measurement_kind", event.measurement_kind),
            "source_provider": data.get("source_provider", event.source_provider),
            "source_id": data.get("source_id", event.source_id),
        }
        inferred = _apply_event_defaults({k: v for k, v in merged.items() if v is not None})
        data.setdefault("measurement_kind", inferred["measurement_kind"])
        data.setdefault("source_provider", inferred["source_provider"])
        # The inferred confidence is always the kind's default (the stored
        # confidence is not part of `merged`), so only apply it on a kind change.
        if inferred["measurement_kind"] != event.measurement_kind:
            data.setdefault("confidence", inferred["confidence"])
        if inferred.get("source_id"):
            data.setdefault("source_id", inferred["source_id"])
    for field, value in data.items():
        setattr(event, field, value)
    await session.commit()
    await session.refresh(event)
    return event
|
|
|
|
|
|
@router.get("/{event_id}", response_model=TokenEventRead)
async def get_token_event(
    event_id: uuid.UUID,
    session: AsyncSession = Depends(get_session),
) -> TokenEvent:
    """Return one token event by primary key.

    Raises:
        HTTPException(404): no event with ``event_id`` exists.
    """
    found = await session.get(TokenEvent, event_id)
    if found is not None:
        return found
    raise HTTPException(status_code=404, detail="Token event not found")
|
|
|
|
|
|
@router.get("/", response_model=list[TokenEventRead])
async def list_token_events(
    task_id: uuid.UUID | None = None,
    workstream_id: uuid.UUID | None = None,
    repo_id: uuid.UUID | None = None,
    ref_type: str | None = None,
    ref_id: str | None = None,
    model: str | None = None,
    agent: str | None = None,
    note: str | None = None,
    measurement_kind: str | None = None,
    source_provider: str | None = None,
    since: datetime | None = None,
    until: datetime | None = None,
    include_superseded: bool = Query(True),
    unattributed: bool = False,
    offset: int = Query(0, ge=0),
    # ge=0: a negative LIMIT is an error on PostgreSQL and means "no limit"
    # on SQLite (bypassing the 1000 cap); reject it with a 422 instead.
    limit: int = Query(100, ge=0, le=1000),
    session: AsyncSession = Depends(get_session),
) -> list[TokenEvent]:
    """List token events, newest ``created_at`` first, with filters and paging.

    All filter parameters are forwarded to ``_filter_query``; ``offset`` and
    ``limit`` (0..1000) page through the ordered result.
    """
    q = _filter_query(
        select(TokenEvent),
        task_id=task_id,
        workstream_id=workstream_id,
        repo_id=repo_id,
        ref_type=ref_type,
        ref_id=ref_id,
        model=model,
        agent=agent,
        note=note,
        measurement_kind=measurement_kind,
        source_provider=source_provider,
        since=since,
        until=until,
        include_superseded=include_superseded,
        unattributed=unattributed,
    )
    q = q.order_by(TokenEvent.created_at.desc()).offset(offset).limit(limit)
    result = await session.execute(q)
    return list(result.scalars().all())
|