feat(statehub): add offline write buffer relay

This commit is contained in:
2026-06-25 13:44:27 +02:00
parent 63f0398304
commit b536741539
21 changed files with 1963 additions and 25 deletions

1
api/edge/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""State Hub edge relay and durable outbox helpers."""

358
api/edge/outbox.py Normal file
View File

@@ -0,0 +1,358 @@
from __future__ import annotations
import json
import os
import sqlite3
import stat
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any
from api.services.write_idempotency import route_class_for
DEFAULT_OUTBOX_PATH = Path(os.environ.get("STATEHUB_OUTBOX_PATH", "~/.statehub/edge-outbox.sqlite3")).expanduser()
MAX_PAYLOAD_BYTES = 64 * 1024
SECRET_FIELD_NAMES = {
"authorization",
"cookie",
"set-cookie",
"password",
"passwd",
"secret",
"api_key",
"apikey",
"access_token",
"refresh_token",
"bearer_token",
"client_secret",
"private_key",
"credential",
"credentials",
}
@dataclass(frozen=True)
class OutboxEnvelope:
id: str
idempotency_key: str
method: str
path: str
body: dict[str, Any] | list[Any] | None
route_class: str
source_agent: str | None
source_host: str | None
repo_slug: str | None
session_id: str | None
observed_revision: dict[str, Any] | None
status: str
attempt_count: int
next_retry_at: str | None
last_error: str | None
response_status: int | None
response_body: dict[str, Any] | list[Any] | str | None
created_at: str
updated_at: str
acked_at: str | None
class PayloadRejected(ValueError):
pass
def utcnow() -> str:
return datetime.now(tz=timezone.utc).isoformat()
def default_outbox_path() -> Path:
return DEFAULT_OUTBOX_PATH
def scrub_payload(value: Any) -> Any:
if isinstance(value, dict):
scrubbed: dict[str, Any] = {}
for key, item in value.items():
normalized = str(key).lower().replace("-", "_")
if normalized in SECRET_FIELD_NAMES:
scrubbed[key] = "[redacted]"
else:
scrubbed[key] = scrub_payload(item)
return scrubbed
if isinstance(value, list):
return [scrub_payload(item) for item in value]
return value
def _json_loads(raw: str | None) -> Any:
if raw is None:
return None
return json.loads(raw)
def _json_dumps(value: Any) -> str | None:
if value is None:
return None
return json.dumps(value, sort_keys=True, separators=(",", ":"))
def _parse_dt(value: str | None) -> datetime | None:
if not value:
return None
return datetime.fromisoformat(value)
class OutboxStore:
def __init__(self, path: str | Path | None = None) -> None:
self.path = Path(path).expanduser() if path is not None else default_outbox_path()
self.path.parent.mkdir(parents=True, exist_ok=True)
self._init_db()
self._chmod_private()
def _connect(self) -> sqlite3.Connection:
conn = sqlite3.connect(self.path)
conn.row_factory = sqlite3.Row
return conn
def _init_db(self) -> None:
with self._connect() as conn:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS outbox_envelopes (
id TEXT PRIMARY KEY,
idempotency_key TEXT NOT NULL UNIQUE,
method TEXT NOT NULL,
path TEXT NOT NULL,
body_json TEXT,
route_class TEXT NOT NULL,
source_agent TEXT,
source_host TEXT,
repo_slug TEXT,
session_id TEXT,
observed_revision_json TEXT,
status TEXT NOT NULL,
attempt_count INTEGER NOT NULL DEFAULT 0,
next_retry_at TEXT,
last_error TEXT,
response_status INTEGER,
response_body_json TEXT,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
acked_at TEXT
)
"""
)
conn.execute("CREATE INDEX IF NOT EXISTS ix_outbox_status ON outbox_envelopes(status)")
conn.execute("CREATE INDEX IF NOT EXISTS ix_outbox_next_retry ON outbox_envelopes(next_retry_at)")
conn.commit()
def _chmod_private(self) -> None:
try:
os.chmod(self.path, stat.S_IRUSR | stat.S_IWUSR)
except OSError:
pass
def enqueue(
self,
*,
method: str,
path: str,
body: Any,
idempotency_key: str | None = None,
source_agent: str | None = None,
source_host: str | None = None,
repo_slug: str | None = None,
session_id: str | None = None,
observed_revision: dict[str, Any] | None = None,
) -> OutboxEnvelope:
route_class = route_class_for(method, path)
if route_class is None:
raise PayloadRejected(f"{method.upper()} {path} is not queueable")
scrubbed = scrub_payload(body)
encoded = _json_dumps(scrubbed)
if encoded is not None and len(encoded.encode("utf-8")) > MAX_PAYLOAD_BYTES:
raise PayloadRejected("payload exceeds offline outbox size limit")
now = utcnow()
envelope_id = str(uuid.uuid4())
key = idempotency_key or f"statehub-edge:{envelope_id}"
method_upper = method.upper()
with self._connect() as conn:
if route_class == "replace":
conn.execute(
"""
UPDATE outbox_envelopes
SET status = 'cancelled', updated_at = ?, last_error = ?
WHERE status = 'queued'
AND route_class = 'replace'
AND method = ?
AND path = ?
""",
(now, f"superseded by {envelope_id}", method_upper, path),
)
conn.execute(
"""
INSERT INTO outbox_envelopes (
id, idempotency_key, method, path, body_json, route_class,
source_agent, source_host, repo_slug, session_id,
observed_revision_json, status, created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'queued', ?, ?)
""",
(
envelope_id,
key,
method_upper,
path,
encoded,
route_class,
source_agent,
source_host,
repo_slug,
session_id,
_json_dumps(observed_revision),
now,
now,
),
)
conn.commit()
return self.get(envelope_id)
def get(self, envelope_id: str) -> OutboxEnvelope:
with self._connect() as conn:
row = conn.execute("SELECT * FROM outbox_envelopes WHERE id = ?", (envelope_id,)).fetchone()
if row is None:
raise KeyError(envelope_id)
return self._row_to_envelope(row)
def list(self, *, status: str | None = None, limit: int = 100) -> list[OutboxEnvelope]:
with self._connect() as conn:
if status:
rows = conn.execute(
"SELECT * FROM outbox_envelopes WHERE status = ? ORDER BY created_at LIMIT ?",
(status, limit),
).fetchall()
else:
rows = conn.execute(
"SELECT * FROM outbox_envelopes ORDER BY created_at LIMIT ?",
(limit,),
).fetchall()
return [self._row_to_envelope(row) for row in rows]
def due(self, *, limit: int = 50) -> list[OutboxEnvelope]:
now = utcnow()
with self._connect() as conn:
rows = conn.execute(
"""
SELECT * FROM outbox_envelopes
WHERE status = 'queued' AND (next_retry_at IS NULL OR next_retry_at <= ?)
ORDER BY created_at
LIMIT ?
""",
(now, limit),
).fetchall()
return [self._row_to_envelope(row) for row in rows]
def summary(self) -> dict[str, Any]:
with self._connect() as conn:
rows = conn.execute(
"SELECT status, COUNT(*) AS count, MIN(created_at) AS oldest FROM outbox_envelopes GROUP BY status"
).fetchall()
by_status = {row["status"]: row["count"] for row in rows}
oldest_pending = None
for row in rows:
if row["status"] in {"queued", "sending", "conflict"} and row["oldest"]:
oldest_pending = min(filter(None, [oldest_pending, row["oldest"]])) if oldest_pending else row["oldest"]
return {
"path": str(self.path),
"by_status": by_status,
"pending_count": sum(by_status.get(status, 0) for status in ("queued", "sending")),
"conflict_count": by_status.get("conflict", 0),
"oldest_pending_at": oldest_pending,
}
def mark_sending(self, envelope_id: str) -> None:
self._update(envelope_id, status="sending", updated_at=utcnow())
def mark_acked(self, envelope_id: str, *, response_status: int, response_body: Any) -> None:
now = utcnow()
self._update(
envelope_id,
status="acked",
response_status=response_status,
response_body_json=_json_dumps(response_body),
updated_at=now,
acked_at=now,
last_error=None,
next_retry_at=None,
)
def mark_conflict(self, envelope_id: str, *, response_status: int, response_body: Any) -> None:
self._update(
envelope_id,
status="conflict",
response_status=response_status,
response_body_json=_json_dumps(response_body),
updated_at=utcnow(),
last_error="conflict",
)
def mark_dead(self, envelope_id: str, *, error: str, response_status: int | None = None, response_body: Any = None) -> None:
self._update(
envelope_id,
status="dead",
response_status=response_status,
response_body_json=_json_dumps(response_body),
updated_at=utcnow(),
last_error=error,
)
def mark_retry(self, envelope_id: str, *, error: str, attempt_count: int) -> None:
delay_seconds = min(3600, 2 ** min(attempt_count, 10))
next_retry = datetime.now(tz=timezone.utc) + timedelta(seconds=delay_seconds)
self._update(
envelope_id,
status="queued",
attempt_count=attempt_count,
next_retry_at=next_retry.isoformat(),
updated_at=utcnow(),
last_error=error[:500],
)
def retry(self, envelope_id: str) -> None:
self._update(envelope_id, status="queued", next_retry_at=None, updated_at=utcnow())
def cancel(self, envelope_id: str) -> None:
self._update(envelope_id, status="cancelled", updated_at=utcnow())
def export(self, *, status: str | None = None, limit: int = 1000) -> list[dict[str, Any]]:
return [envelope.__dict__ for envelope in self.list(status=status, limit=limit)]
def _update(self, envelope_id: str, **values: Any) -> None:
assignments = ", ".join(f"{key} = ?" for key in values)
params = list(values.values()) + [envelope_id]
with self._connect() as conn:
conn.execute(f"UPDATE outbox_envelopes SET {assignments} WHERE id = ?", params)
conn.commit()
def _row_to_envelope(self, row: sqlite3.Row) -> OutboxEnvelope:
return OutboxEnvelope(
id=row["id"],
idempotency_key=row["idempotency_key"],
method=row["method"],
path=row["path"],
body=_json_loads(row["body_json"]),
route_class=row["route_class"],
source_agent=row["source_agent"],
source_host=row["source_host"],
repo_slug=row["repo_slug"],
session_id=row["session_id"],
observed_revision=_json_loads(row["observed_revision_json"]),
status=row["status"],
attempt_count=row["attempt_count"],
next_retry_at=row["next_retry_at"],
last_error=row["last_error"],
response_status=row["response_status"],
response_body=_json_loads(row["response_body_json"]),
created_at=row["created_at"],
updated_at=row["updated_at"],
acked_at=row["acked_at"],
)

206
api/edge/relay.py Normal file
View File

@@ -0,0 +1,206 @@
from __future__ import annotations
import os
import socket
from typing import Any
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response
from api.edge.outbox import OutboxEnvelope, OutboxStore, PayloadRejected, default_outbox_path
from api.services.write_idempotency import route_class_for
HOP_BY_HOP_HEADERS = {
"connection",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"te",
"trailer",
"transfer-encoding",
"upgrade",
"content-encoding",
"content-length",
}
def _safe_response_headers(headers: httpx.Headers) -> dict[str, str]:
return {key: value for key, value in headers.items() if key.lower() not in HOP_BY_HOP_HEADERS}
def _body_summary(response: httpx.Response) -> Any:
try:
return response.json()
except ValueError:
return {"text": response.text[:500]}
def queued_receipt(envelope: OutboxEnvelope, upstream_error: str) -> dict[str, Any]:
return {
"queued": True,
"outbox_id": envelope.id,
"idempotency_key": envelope.idempotency_key,
"upstream": "unreachable",
"upstream_error": upstream_error,
"route_class": envelope.route_class,
}
async def replay_pending(
store: OutboxStore,
*,
upstream_url: str,
limit: int = 50,
timeout: float = 10.0,
) -> dict[str, int]:
counts = {"sent": 0, "acked": 0, "conflict": 0, "retry": 0, "dead": 0}
async with httpx.AsyncClient(base_url=upstream_url.rstrip("/"), timeout=timeout) as client:
for envelope in store.due(limit=limit):
counts["sent"] += 1
store.mark_sending(envelope.id)
try:
response = await client.request(
envelope.method,
envelope.path,
json=envelope.body,
headers={
"Idempotency-Key": envelope.idempotency_key,
"X-StateHub-Source-Agent": envelope.source_agent or "statehub-edge",
"X-StateHub-Source-Host": envelope.source_host or socket.gethostname(),
},
)
except httpx.HTTPError as exc:
counts["retry"] += 1
store.mark_retry(envelope.id, error=str(exc), attempt_count=envelope.attempt_count + 1)
continue
response_body = _body_summary(response)
if response.status_code == 409:
counts["conflict"] += 1
store.mark_conflict(envelope.id, response_status=response.status_code, response_body=response_body)
elif 200 <= response.status_code < 300:
counts["acked"] += 1
store.mark_acked(envelope.id, response_status=response.status_code, response_body=response_body)
elif response.status_code >= 500:
counts["retry"] += 1
store.mark_retry(
envelope.id,
error=f"HTTP {response.status_code}: {response.text[:300]}",
attempt_count=envelope.attempt_count + 1,
)
else:
counts["dead"] += 1
store.mark_dead(
envelope.id,
error=f"HTTP {response.status_code}: not retryable",
response_status=response.status_code,
response_body=response_body,
)
return counts
def create_app(
*,
upstream_url: str | None = None,
outbox_path: str | None = None,
timeout: float = 10.0,
) -> FastAPI:
upstream = (upstream_url or os.environ.get("STATEHUB_UPSTREAM_URL") or os.environ.get("API_BASE") or "http://127.0.0.1:8000").rstrip("/")
store_path = outbox_path or default_outbox_path()
store_instance: OutboxStore | None = None
def get_store() -> OutboxStore:
nonlocal store_instance
if store_instance is None:
store_instance = OutboxStore(store_path)
return store_instance
app = FastAPI(title="State Hub Edge Relay", version="0.1.0")
@app.get("/edge/health")
async def edge_health() -> dict[str, Any]:
reachable = False
error = None
try:
async with httpx.AsyncClient(base_url=upstream, timeout=2.0) as client:
response = await client.get("/state/health")
reachable = response.status_code < 500
except httpx.HTTPError as exc:
error = str(exc)
return {
"status": "ok",
"upstream": upstream,
"upstream_reachable": reachable,
"upstream_error": error,
"outbox": get_store().summary(),
}
@app.post("/edge/replay")
async def edge_replay(limit: int = 50) -> dict[str, int]:
return await replay_pending(get_store(), upstream_url=upstream, limit=limit, timeout=timeout)
@app.api_route("/{path:path}", methods=["GET", "POST", "PATCH", "PUT", "DELETE"])
async def proxy(path: str, request: Request) -> Response:
api_path = "/" + path
body: Any = None
if request.method in {"POST", "PATCH", "PUT"}:
try:
body = await request.json()
except ValueError:
body = None
headers = {}
if idempotency_key := request.headers.get("idempotency-key"):
headers["Idempotency-Key"] = idempotency_key
if request.headers.get("content-type"):
headers["Content-Type"] = request.headers["content-type"]
try:
async with httpx.AsyncClient(base_url=upstream, timeout=timeout) as client:
response = await client.request(
request.method,
api_path,
params=request.query_params,
json=body if body is not None else None,
headers=headers,
)
return Response(
content=response.content,
status_code=response.status_code,
headers=_safe_response_headers(response.headers),
media_type=response.headers.get("content-type"),
)
except httpx.HTTPError as exc:
route_class = route_class_for(request.method, api_path)
if route_class is None or request.method not in {"POST", "PATCH"}:
return JSONResponse(
status_code=503,
content={
"error": "upstream unreachable and route is not queueable",
"method": request.method,
"path": api_path,
"upstream": upstream,
"detail": str(exc),
},
)
try:
envelope = get_store().enqueue(
method=request.method,
path=api_path,
body=body,
idempotency_key=request.headers.get("idempotency-key"),
source_agent=request.headers.get("x-statehub-source-agent"),
source_host=request.headers.get("x-statehub-source-host") or socket.gethostname(),
repo_slug=request.headers.get("x-statehub-repo-slug"),
session_id=request.headers.get("x-statehub-session-id"),
observed_revision=None,
)
except PayloadRejected as reject:
return JSONResponse(status_code=422, content={"error": str(reject)})
return JSONResponse(status_code=202, content=queued_receipt(envelope, str(exc)))
return app
app = create_app()

View File

@@ -11,6 +11,7 @@ from starlette.responses import Response as StarletteResponse
from api.database import engine
from api.events import shutdown_publisher
from api.services.write_idempotency import WriteIdempotencyMiddleware
from api.routers import decisions, extension_points, progress, state, tasks, technical_debt, topics, workstreams, workstream_dependencies
from api.routers import domains, repos, contributions, sbom, policy, domain_goals, repo_goals, messages, capability_requests, tpsc, services
from api.routers import token_events
@@ -91,13 +92,14 @@ _default_dashboard_origins = [
_cors_env = os.getenv("CORS_ORIGINS", ",".join(_default_dashboard_origins))
_cors_origins = [o.strip() for o in _cors_env.split(",") if o.strip()]
app.add_middleware(WriteIdempotencyMiddleware)
app.add_middleware(ETagMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=_cors_origins,
allow_methods=["GET", "POST", "PATCH", "DELETE", "PUT"],
allow_headers=["Content-Type", "If-None-Match"],
expose_headers=["ETag", "X-StateHub-Elapsed-Ms", "X-StateHub-Response-Bytes", "X-StateHub-Cache"],
allow_headers=["Content-Type", "If-None-Match", "Idempotency-Key", "X-StateHub-Source-Agent", "X-StateHub-Source-Host"],
expose_headers=["ETag", "X-StateHub-Elapsed-Ms", "X-StateHub-Response-Bytes", "X-StateHub-Cache", "X-StateHub-Idempotency-Replay"],
)
app.include_router(domains.router)

View File

@@ -33,6 +33,7 @@ from api.models.interface_change import InterfaceChange
from api.models.workplan_launch_request import WorkplanLaunchRequest
from api.models.fabric_graph import FabricGraphImport, FabricGraphNode, FabricGraphEdge
from api.models.legacy_meter import LegacyInterface, LegacyInterfaceUsageBucket
from api.models.write_idempotency_key import WriteIdempotencyKey
__all__ = [
"Base",
@@ -65,4 +66,5 @@ __all__ = [
"WorkplanLaunchRequest",
"FabricGraphImport", "FabricGraphNode", "FabricGraphEdge",
"LegacyInterface", "LegacyInterfaceUsageBucket",
"WriteIdempotencyKey",
]

View File

@@ -0,0 +1,32 @@
from __future__ import annotations
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import DateTime, Integer, String, Text, UniqueConstraint
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column
from api.models.base import Base, new_uuid
class WriteIdempotencyKey(Base):
__tablename__ = "write_idempotency_keys"
__table_args__ = (
UniqueConstraint("key", name="uq_write_idempotency_keys_key"),
)
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=new_uuid)
key: Mapped[str] = mapped_column(String(200), nullable=False, index=True)
method: Mapped[str] = mapped_column(String(10), nullable=False)
path: Mapped[str] = mapped_column(Text, nullable=False)
route_class: Mapped[str] = mapped_column(String(30), nullable=False)
request_hash: Mapped[str] = mapped_column(String(64), nullable=False)
response_status: Mapped[int] = mapped_column(Integer, nullable=False)
response_body: Mapped[Any] = mapped_column(JSONB, nullable=True)
source_host: Mapped[str | None] = mapped_column(String(200), nullable=True)
source_agent: Mapped[str | None] = mapped_column(String(100), nullable=True)
first_seen_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
last_seen_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True, index=True)

View File

@@ -0,0 +1,221 @@
from __future__ import annotations
import hashlib
import json
import re
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any
from sqlalchemy import select
from starlette.responses import JSONResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from api.database import async_session_factory
from api.models.write_idempotency_key import WriteIdempotencyKey
IDEMPOTENCY_HEADER = b"idempotency-key"
REPLAY_HEADER = "X-StateHub-Idempotency-Replay"
CONFLICT_STATUS = 409
DEFAULT_IDEMPOTENCY_TTL_DAYS = 14
@dataclass(frozen=True)
class WriteRouteRule:
method: str
pattern: str
route_class: str
description: str
def matches(self, method: str, path: str) -> bool:
normalized = path.rstrip("/") or "/"
return self.method == method.upper() and re.fullmatch(self.pattern, normalized) is not None
WRITE_ROUTE_RULES: tuple[WriteRouteRule, ...] = (
WriteRouteRule("POST", r"/progress", "append", "append progress event"),
WriteRouteRule("POST", r"/messages", "append", "send agent message"),
WriteRouteRule("PATCH", r"/messages/[^/]+/read", "append", "mark known message read"),
WriteRouteRule("POST", r"/token-events", "append", "record token event"),
WriteRouteRule("POST", r"/token-events/upsert", "append", "upsert token event"),
WriteRouteRule("POST", r"/decisions", "append", "record decision"),
WriteRouteRule("PATCH", r"/tasks/[^/]+", "replace", "update task"),
WriteRouteRule("POST", r"/tasks/bulk-status-sync", "replace", "bulk task status sync"),
WriteRouteRule("PATCH", r"/decisions/[^/]+", "replace", "update decision"),
WriteRouteRule("POST", r"/decisions/[^/]+/resolve", "replace", "resolve decision"),
WriteRouteRule("PATCH", r"/workplans/[^/]+", "replace", "update workplan"),
WriteRouteRule("PATCH", r"/workstreams/[^/]+", "replace", "update legacy workstream alias"),
)
def route_rule_for(method: str, path: str) -> WriteRouteRule | None:
for rule in WRITE_ROUTE_RULES:
if rule.matches(method, path):
return rule
return None
def route_class_for(method: str, path: str) -> str | None:
rule = route_rule_for(method, path)
return rule.route_class if rule else None
def canonical_request_hash(method: str, path: str, query_string: bytes, body: bytes) -> str:
try:
parsed: Any = json.loads(body.decode("utf-8")) if body else None
body_repr = json.dumps(parsed, sort_keys=True, separators=(",", ":"))
except (UnicodeDecodeError, json.JSONDecodeError):
body_repr = body.hex()
query = query_string.decode("utf-8", errors="replace")
seed = f"{method.upper()}\n{path}\n{query}\n{body_repr}".encode("utf-8")
return hashlib.sha256(seed).hexdigest()
def _header_value(headers: list[tuple[bytes, bytes]], name: bytes) -> str | None:
lname = name.lower()
for key, value in headers:
if key.lower() == lname:
return value.decode("utf-8", errors="replace")
return None
async def _send_json_response(response: JSONResponse, scope: Scope, receive: Receive, send: Send) -> None:
await response(scope, receive, send)
class WriteIdempotencyMiddleware:
"""Replay exact duplicate write requests carrying Idempotency-Key.
The middleware is intentionally narrow: it only participates on the offline
relay allowlist. Non-allowlisted routes keep their normal behavior even if a
caller sends an Idempotency-Key header.
"""
def __init__(self, app: ASGIApp, *, ttl_days: int = DEFAULT_IDEMPOTENCY_TTL_DAYS) -> None:
self.app = app
self.ttl_days = ttl_days
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
await self.app(scope, receive, send)
return
method = str(scope.get("method", "")).upper()
path = str(scope.get("path", ""))
rule = route_rule_for(method, path)
headers = list(scope.get("headers") or [])
key = _header_value(headers, IDEMPOTENCY_HEADER)
if rule is None or not key:
await self.app(scope, receive, send)
return
body = await self._read_body(receive)
request_hash = canonical_request_hash(method, path, scope.get("query_string", b""), body)
source_host = _header_value(headers, b"x-statehub-source-host")
source_agent = _header_value(headers, b"x-statehub-source-agent")
async with async_session_factory() as session:
existing = (await session.execute(
select(WriteIdempotencyKey).where(WriteIdempotencyKey.key == key)
)).scalar_one_or_none()
if existing is not None:
existing.last_seen_at = datetime.now(tz=timezone.utc)
await session.commit()
if existing.request_hash != request_hash:
await _send_json_response(
JSONResponse(
status_code=CONFLICT_STATUS,
content={
"error": "Idempotency-Key was reused with a different request",
"idempotency_key": key,
},
),
scope,
self._receive_from_body(body),
send,
)
return
await _send_json_response(
JSONResponse(
status_code=existing.response_status,
content=existing.response_body,
headers={REPLAY_HEADER: "true"},
),
scope,
self._receive_from_body(body),
send,
)
return
start_message: Message | None = None
body_parts: list[bytes] = []
async def capture_send(message: Message) -> None:
nonlocal start_message
if message["type"] == "http.response.start":
start_message = message
elif message["type"] == "http.response.body":
body_parts.append(message.get("body", b""))
await send(message)
await self.app(scope, self._receive_from_body(body), capture_send)
if start_message is None:
return
status = int(start_message.get("status", 500))
if status < 200 or status >= 300:
return
response_body_bytes = b"".join(body_parts)
try:
response_body = json.loads(response_body_bytes.decode("utf-8")) if response_body_bytes else None
except (UnicodeDecodeError, json.JSONDecodeError):
return
async with async_session_factory() as session:
existing = (await session.execute(
select(WriteIdempotencyKey).where(WriteIdempotencyKey.key == key)
)).scalar_one_or_none()
if existing is not None:
return
now = datetime.now(tz=timezone.utc)
session.add(WriteIdempotencyKey(
key=key,
method=method,
path=path,
route_class=rule.route_class,
request_hash=request_hash,
response_status=status,
response_body=response_body,
source_host=source_host,
source_agent=source_agent,
first_seen_at=now,
last_seen_at=now,
expires_at=now + timedelta(days=self.ttl_days),
))
await session.commit()
@staticmethod
async def _read_body(receive: Receive) -> bytes:
chunks: list[bytes] = []
while True:
message = await receive()
if message["type"] != "http.request":
continue
chunks.append(message.get("body", b""))
if not message.get("more_body", False):
break
return b"".join(chunks)
@staticmethod
def _receive_from_body(body: bytes) -> Receive:
sent = False
async def receive() -> Message:
nonlocal sent
if sent:
return {"type": "http.request", "body": b"", "more_body": False}
sent = True
return {"type": "http.request", "body": body, "more_body": False}
return receive