generated from coulomb/repo-seed
154 lines
4.3 KiB
Python
154 lines
4.3 KiB
Python
"""Per-call diagnostics capture for server debug and audit modes."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import copy
|
|
import json
|
|
from contextlib import contextmanager
|
|
from contextvars import ContextVar
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Iterator, Mapping
|
|
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
|
|
|
|
|
|
# Query-string parameter names whose values redact_url masks; membership is
# tested against the lowercased parameter name.
_SECRET_QUERY_KEYS = {
    "key",
    "api_key",
    "apikey",
    "access_token",
    "token",
}

# Fragments that flag a header as sensitive in redact_headers; a header is
# masked when any fragment occurs anywhere inside its lowercased name.
_SECRET_HEADER_TOKENS = (
    "authorization",
    "api-key",
    "apikey",
    "token",
    "secret",
    "key",
)
|
|
|
|
|
|
@dataclass
class Diagnostics:
    """Container for what was sent to and received from a provider in one call.

    All fields start empty; the ``record_*`` helpers in this module fill them
    in while a capture is active.
    """

    # Snapshot of the outgoing request, or None if nothing was recorded.
    provider_request: dict[str, Any] | None = None
    # Snapshot of the provider's reply, or None if nothing was recorded.
    provider_response: dict[str, Any] | None = None
    # Ordered log of adapter steps; each instance gets its own list.
    adapter_transformations: list[dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return the three captured fields as a plain dict.

        The values are the stored objects themselves, not copies.
        """
        exported = ("provider_request", "provider_response", "adapter_transformations")
        return {name: getattr(self, name) for name in exported}
|
|
|
|
|
|
# Context-local slot for the Diagnostics currently being filled in; it holds
# None unless capture_diagnostics has installed an instance.
_CURRENT: ContextVar[Diagnostics | None] = ContextVar("llm_connect_diagnostics", default=None)
|
|
|
|
|
|
@contextmanager
def capture_diagnostics(enabled: bool = True) -> Iterator[Diagnostics | None]:
    """Open a diagnostics capture scope.

    With *enabled* true, a fresh ``Diagnostics`` is installed as the current
    one for the duration of the ``with`` body and yielded to the caller; the
    previous value is restored on exit, including when the body raises.
    With *enabled* false, ``None`` is yielded and the current value is left
    untouched.
    """
    if enabled:
        collector = Diagnostics()
        restore_token = _CURRENT.set(collector)
        try:
            yield collector
        finally:
            # Put back whatever was current before this scope was entered.
            _CURRENT.reset(restore_token)
    else:
        yield None
|
|
|
|
|
|
def diagnostics_enabled() -> bool:
    """Return True when a Diagnostics instance is installed in this context."""
    active = _CURRENT.get()
    return active is not None
|
|
|
|
|
|
def current_diagnostics() -> Diagnostics | None:
    """Return the Diagnostics installed in this context, or None if there is none."""
    active = _CURRENT.get()
    return active
|
|
|
|
|
|
def record_provider_request(
    *,
    url: str | None = None,
    payload: Any | None = None,
    headers: Mapping[str, Any] | None = None,
    command: list[str] | None = None,
) -> None:
    """Store a sanitized snapshot of the outgoing provider request.

    Does nothing when no diagnostics capture is active. Only arguments that
    are not ``None`` appear in the snapshot, which replaces any request
    recorded earlier on the same ``Diagnostics``:

    * ``url`` is passed through ``redact_url``.
    * ``payload`` is passed through ``json_safe``.
    * ``headers`` is passed through ``redact_headers`` and stored under
      ``"headers_redacted"``.
    * ``command`` is copied into a new list.
    """
    active = _CURRENT.get()
    if active is None:
        return

    # (snapshot key, raw argument, sanitizer) in the order the keys are emitted.
    # NOTE(review): ``command`` is stored verbatim; confirm callers never put
    # credentials in the argument vector.
    parts = (
        ("url", url, redact_url),
        ("payload", payload, json_safe),
        ("headers_redacted", headers, redact_headers),
        ("command", command, list),
    )
    active.provider_request = {
        name: sanitize(raw) for name, raw, sanitize in parts if raw is not None
    }
|
|
|
|
|
|
def record_provider_response(*, status: int | None = None, body: Any | None = None) -> None:
    """Store a snapshot of the provider's response.

    Does nothing when no diagnostics capture is active. ``status`` is stored
    as given and ``body`` is passed through ``json_safe``; an argument that is
    ``None`` is omitted. The snapshot replaces any response recorded earlier
    on the same ``Diagnostics``.
    """
    active = _CURRENT.get()
    if active is not None:
        snapshot: dict[str, Any] = {}
        if status is not None:
            snapshot["status"] = status
        if body is not None:
            snapshot["body"] = json_safe(body)
        active.provider_response = snapshot
|
|
|
|
|
|
def record_adapter_transformation(step: str, before: Any, after: Any) -> None:
    """Append one adapter step to the active diagnostics log.

    Does nothing when no diagnostics capture is active. *before* and *after*
    are each passed through ``json_safe`` before being stored, alongside the
    *step* label.
    """
    active = _CURRENT.get()
    if active is None:
        return

    entry = {
        "step": step,
        "before": json_safe(before),
        "after": json_safe(after),
    }
    active.adapter_transformations.append(entry)
|
|
|
|
|
|
def json_safe(value: Any) -> Any:
    """Return a JSON-serializable snapshot of *value* without mutating it.

    The snapshot is produced by a JSON round trip, so it shares no mutable
    state with *value* (tuples come back as lists, and so on).

    If *value* contains objects the ``json`` module cannot encode (bytes,
    sets, custom classes, ...), those objects alone are replaced by their
    ``repr()`` and the surrounding structure is kept. If even that fails
    (for example circular references or unsupported dict keys), the
    ``repr()`` of the whole value is returned. The result can therefore
    always be passed to ``json.dumps``.
    """
    try:
        return json.loads(json.dumps(value))
    except (TypeError, ValueError):
        pass

    try:
        # Deliberate stringification: a diagnostics snapshot only has to be
        # readable and serializable, not reconstructible.
        return json.loads(json.dumps(value, default=repr))
    except Exception:
        # Best effort: capturing diagnostics must not break on odd payloads.
        return repr(value)
|
|
|
|
|
|
def redact_headers(headers: Mapping[str, Any]) -> dict[str, Any]:
    """Return a copy of *headers* with sensitive values masked.

    Header names are converted to ``str``. A header counts as sensitive when
    its lowercased name contains any entry of ``_SECRET_HEADER_TOKENS``; its
    value is then replaced via ``_redact_header_value``. Every other value is
    passed through ``json_safe``.
    """
    sanitized: dict[str, Any] = {}
    for name, raw in headers.items():
        label = str(name)
        folded = label.lower()
        sensitive = any(marker in folded for marker in _SECRET_HEADER_TOKENS)
        sanitized[label] = _redact_header_value(raw) if sensitive else json_safe(raw)
    return sanitized
|
|
|
|
|
|
def redact_url(url: str) -> str:
    """Return *url* with embedded credentials masked.

    Two places are sanitized:

    * Userinfo in the network location (``user:password@host`` or
      ``token@host``) is replaced by ``<redacted>``; host and port are kept.
    * Query parameters whose lowercased name is in ``_SECRET_QUERY_KEYS``
      get the value ``<redacted>``.

    The query string is rebuilt with ``urlencode``, so the placeholder is
    percent-encoded in the query and other parameters may be re-encoded.
    Blank parameter values are preserved. Scheme, path and fragment are
    passed through unchanged.
    """
    parts = urlsplit(url)

    # Everything before the last "@" in the netloc is userinfo; it may be a
    # user:password pair or a bare token, so mask it as a whole.
    userinfo, _, host_port = parts.netloc.rpartition("@")
    netloc = f"<redacted>@{host_port}" if userinfo else parts.netloc

    query = [
        (key, "<redacted>" if key.lower() in _SECRET_QUERY_KEYS else value)
        for key, value in parse_qsl(parts.query, keep_blank_values=True)
    ]
    return urlunsplit((parts.scheme, netloc, parts.path, urlencode(query), parts.fragment))
|
|
|
|
|
|
def _redact_header_value(value: Any) -> str:
    """Mask a sensitive header value, keeping a leading scheme word if present.

    The value is converted to ``str``. When it contains a space, the text
    before the first space is kept and the rest is replaced (for example
    ``"Bearer abc"`` becomes ``"Bearer <redacted>"``); otherwise the whole
    value is replaced.
    """
    scheme, separator, _ = str(value).partition(" ")
    if separator:
        return f"{scheme} <redacted>"
    return "<redacted>"
|