generated from coulomb/repo-seed
feat(statehub): add offline write buffer relay
This commit is contained in:
1
api/edge/__init__.py
Normal file
1
api/edge/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""State Hub edge relay and durable outbox helpers."""
|
||||
358
api/edge/outbox.py
Normal file
358
api/edge/outbox.py
Normal file
@@ -0,0 +1,358 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import stat
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from api.services.write_idempotency import route_class_for
|
||||
|
||||
DEFAULT_OUTBOX_PATH = Path(os.environ.get("STATEHUB_OUTBOX_PATH", "~/.statehub/edge-outbox.sqlite3")).expanduser()
|
||||
MAX_PAYLOAD_BYTES = 64 * 1024
|
||||
SECRET_FIELD_NAMES = {
|
||||
"authorization",
|
||||
"cookie",
|
||||
"set-cookie",
|
||||
"password",
|
||||
"passwd",
|
||||
"secret",
|
||||
"api_key",
|
||||
"apikey",
|
||||
"access_token",
|
||||
"refresh_token",
|
||||
"bearer_token",
|
||||
"client_secret",
|
||||
"private_key",
|
||||
"credential",
|
||||
"credentials",
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class OutboxEnvelope:
|
||||
id: str
|
||||
idempotency_key: str
|
||||
method: str
|
||||
path: str
|
||||
body: dict[str, Any] | list[Any] | None
|
||||
route_class: str
|
||||
source_agent: str | None
|
||||
source_host: str | None
|
||||
repo_slug: str | None
|
||||
session_id: str | None
|
||||
observed_revision: dict[str, Any] | None
|
||||
status: str
|
||||
attempt_count: int
|
||||
next_retry_at: str | None
|
||||
last_error: str | None
|
||||
response_status: int | None
|
||||
response_body: dict[str, Any] | list[Any] | str | None
|
||||
created_at: str
|
||||
updated_at: str
|
||||
acked_at: str | None
|
||||
|
||||
|
||||
class PayloadRejected(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def utcnow() -> str:
|
||||
return datetime.now(tz=timezone.utc).isoformat()
|
||||
|
||||
|
||||
def default_outbox_path() -> Path:
|
||||
return DEFAULT_OUTBOX_PATH
|
||||
|
||||
|
||||
def scrub_payload(value: Any) -> Any:
|
||||
if isinstance(value, dict):
|
||||
scrubbed: dict[str, Any] = {}
|
||||
for key, item in value.items():
|
||||
normalized = str(key).lower().replace("-", "_")
|
||||
if normalized in SECRET_FIELD_NAMES:
|
||||
scrubbed[key] = "[redacted]"
|
||||
else:
|
||||
scrubbed[key] = scrub_payload(item)
|
||||
return scrubbed
|
||||
if isinstance(value, list):
|
||||
return [scrub_payload(item) for item in value]
|
||||
return value
|
||||
|
||||
|
||||
def _json_loads(raw: str | None) -> Any:
|
||||
if raw is None:
|
||||
return None
|
||||
return json.loads(raw)
|
||||
|
||||
|
||||
def _json_dumps(value: Any) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return json.dumps(value, sort_keys=True, separators=(",", ":"))
|
||||
|
||||
|
||||
def _parse_dt(value: str | None) -> datetime | None:
|
||||
if not value:
|
||||
return None
|
||||
return datetime.fromisoformat(value)
|
||||
|
||||
|
||||
class OutboxStore:
|
||||
def __init__(self, path: str | Path | None = None) -> None:
|
||||
self.path = Path(path).expanduser() if path is not None else default_outbox_path()
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._init_db()
|
||||
self._chmod_private()
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
conn = sqlite3.connect(self.path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
def _init_db(self) -> None:
|
||||
with self._connect() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS outbox_envelopes (
|
||||
id TEXT PRIMARY KEY,
|
||||
idempotency_key TEXT NOT NULL UNIQUE,
|
||||
method TEXT NOT NULL,
|
||||
path TEXT NOT NULL,
|
||||
body_json TEXT,
|
||||
route_class TEXT NOT NULL,
|
||||
source_agent TEXT,
|
||||
source_host TEXT,
|
||||
repo_slug TEXT,
|
||||
session_id TEXT,
|
||||
observed_revision_json TEXT,
|
||||
status TEXT NOT NULL,
|
||||
attempt_count INTEGER NOT NULL DEFAULT 0,
|
||||
next_retry_at TEXT,
|
||||
last_error TEXT,
|
||||
response_status INTEGER,
|
||||
response_body_json TEXT,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
acked_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS ix_outbox_status ON outbox_envelopes(status)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS ix_outbox_next_retry ON outbox_envelopes(next_retry_at)")
|
||||
conn.commit()
|
||||
|
||||
def _chmod_private(self) -> None:
|
||||
try:
|
||||
os.chmod(self.path, stat.S_IRUSR | stat.S_IWUSR)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def enqueue(
|
||||
self,
|
||||
*,
|
||||
method: str,
|
||||
path: str,
|
||||
body: Any,
|
||||
idempotency_key: str | None = None,
|
||||
source_agent: str | None = None,
|
||||
source_host: str | None = None,
|
||||
repo_slug: str | None = None,
|
||||
session_id: str | None = None,
|
||||
observed_revision: dict[str, Any] | None = None,
|
||||
) -> OutboxEnvelope:
|
||||
route_class = route_class_for(method, path)
|
||||
if route_class is None:
|
||||
raise PayloadRejected(f"{method.upper()} {path} is not queueable")
|
||||
scrubbed = scrub_payload(body)
|
||||
encoded = _json_dumps(scrubbed)
|
||||
if encoded is not None and len(encoded.encode("utf-8")) > MAX_PAYLOAD_BYTES:
|
||||
raise PayloadRejected("payload exceeds offline outbox size limit")
|
||||
now = utcnow()
|
||||
envelope_id = str(uuid.uuid4())
|
||||
key = idempotency_key or f"statehub-edge:{envelope_id}"
|
||||
method_upper = method.upper()
|
||||
with self._connect() as conn:
|
||||
if route_class == "replace":
|
||||
conn.execute(
|
||||
"""
|
||||
UPDATE outbox_envelopes
|
||||
SET status = 'cancelled', updated_at = ?, last_error = ?
|
||||
WHERE status = 'queued'
|
||||
AND route_class = 'replace'
|
||||
AND method = ?
|
||||
AND path = ?
|
||||
""",
|
||||
(now, f"superseded by {envelope_id}", method_upper, path),
|
||||
)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO outbox_envelopes (
|
||||
id, idempotency_key, method, path, body_json, route_class,
|
||||
source_agent, source_host, repo_slug, session_id,
|
||||
observed_revision_json, status, created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'queued', ?, ?)
|
||||
""",
|
||||
(
|
||||
envelope_id,
|
||||
key,
|
||||
method_upper,
|
||||
path,
|
||||
encoded,
|
||||
route_class,
|
||||
source_agent,
|
||||
source_host,
|
||||
repo_slug,
|
||||
session_id,
|
||||
_json_dumps(observed_revision),
|
||||
now,
|
||||
now,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
return self.get(envelope_id)
|
||||
|
||||
def get(self, envelope_id: str) -> OutboxEnvelope:
|
||||
with self._connect() as conn:
|
||||
row = conn.execute("SELECT * FROM outbox_envelopes WHERE id = ?", (envelope_id,)).fetchone()
|
||||
if row is None:
|
||||
raise KeyError(envelope_id)
|
||||
return self._row_to_envelope(row)
|
||||
|
||||
def list(self, *, status: str | None = None, limit: int = 100) -> list[OutboxEnvelope]:
|
||||
with self._connect() as conn:
|
||||
if status:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM outbox_envelopes WHERE status = ? ORDER BY created_at LIMIT ?",
|
||||
(status, limit),
|
||||
).fetchall()
|
||||
else:
|
||||
rows = conn.execute(
|
||||
"SELECT * FROM outbox_envelopes ORDER BY created_at LIMIT ?",
|
||||
(limit,),
|
||||
).fetchall()
|
||||
return [self._row_to_envelope(row) for row in rows]
|
||||
|
||||
def due(self, *, limit: int = 50) -> list[OutboxEnvelope]:
|
||||
now = utcnow()
|
||||
with self._connect() as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT * FROM outbox_envelopes
|
||||
WHERE status = 'queued' AND (next_retry_at IS NULL OR next_retry_at <= ?)
|
||||
ORDER BY created_at
|
||||
LIMIT ?
|
||||
""",
|
||||
(now, limit),
|
||||
).fetchall()
|
||||
return [self._row_to_envelope(row) for row in rows]
|
||||
|
||||
def summary(self) -> dict[str, Any]:
|
||||
with self._connect() as conn:
|
||||
rows = conn.execute(
|
||||
"SELECT status, COUNT(*) AS count, MIN(created_at) AS oldest FROM outbox_envelopes GROUP BY status"
|
||||
).fetchall()
|
||||
by_status = {row["status"]: row["count"] for row in rows}
|
||||
oldest_pending = None
|
||||
for row in rows:
|
||||
if row["status"] in {"queued", "sending", "conflict"} and row["oldest"]:
|
||||
oldest_pending = min(filter(None, [oldest_pending, row["oldest"]])) if oldest_pending else row["oldest"]
|
||||
return {
|
||||
"path": str(self.path),
|
||||
"by_status": by_status,
|
||||
"pending_count": sum(by_status.get(status, 0) for status in ("queued", "sending")),
|
||||
"conflict_count": by_status.get("conflict", 0),
|
||||
"oldest_pending_at": oldest_pending,
|
||||
}
|
||||
|
||||
def mark_sending(self, envelope_id: str) -> None:
|
||||
self._update(envelope_id, status="sending", updated_at=utcnow())
|
||||
|
||||
def mark_acked(self, envelope_id: str, *, response_status: int, response_body: Any) -> None:
|
||||
now = utcnow()
|
||||
self._update(
|
||||
envelope_id,
|
||||
status="acked",
|
||||
response_status=response_status,
|
||||
response_body_json=_json_dumps(response_body),
|
||||
updated_at=now,
|
||||
acked_at=now,
|
||||
last_error=None,
|
||||
next_retry_at=None,
|
||||
)
|
||||
|
||||
def mark_conflict(self, envelope_id: str, *, response_status: int, response_body: Any) -> None:
|
||||
self._update(
|
||||
envelope_id,
|
||||
status="conflict",
|
||||
response_status=response_status,
|
||||
response_body_json=_json_dumps(response_body),
|
||||
updated_at=utcnow(),
|
||||
last_error="conflict",
|
||||
)
|
||||
|
||||
def mark_dead(self, envelope_id: str, *, error: str, response_status: int | None = None, response_body: Any = None) -> None:
|
||||
self._update(
|
||||
envelope_id,
|
||||
status="dead",
|
||||
response_status=response_status,
|
||||
response_body_json=_json_dumps(response_body),
|
||||
updated_at=utcnow(),
|
||||
last_error=error,
|
||||
)
|
||||
|
||||
def mark_retry(self, envelope_id: str, *, error: str, attempt_count: int) -> None:
|
||||
delay_seconds = min(3600, 2 ** min(attempt_count, 10))
|
||||
next_retry = datetime.now(tz=timezone.utc) + timedelta(seconds=delay_seconds)
|
||||
self._update(
|
||||
envelope_id,
|
||||
status="queued",
|
||||
attempt_count=attempt_count,
|
||||
next_retry_at=next_retry.isoformat(),
|
||||
updated_at=utcnow(),
|
||||
last_error=error[:500],
|
||||
)
|
||||
|
||||
def retry(self, envelope_id: str) -> None:
|
||||
self._update(envelope_id, status="queued", next_retry_at=None, updated_at=utcnow())
|
||||
|
||||
def cancel(self, envelope_id: str) -> None:
|
||||
self._update(envelope_id, status="cancelled", updated_at=utcnow())
|
||||
|
||||
def export(self, *, status: str | None = None, limit: int = 1000) -> list[dict[str, Any]]:
|
||||
return [envelope.__dict__ for envelope in self.list(status=status, limit=limit)]
|
||||
|
||||
def _update(self, envelope_id: str, **values: Any) -> None:
|
||||
assignments = ", ".join(f"{key} = ?" for key in values)
|
||||
params = list(values.values()) + [envelope_id]
|
||||
with self._connect() as conn:
|
||||
conn.execute(f"UPDATE outbox_envelopes SET {assignments} WHERE id = ?", params)
|
||||
conn.commit()
|
||||
|
||||
def _row_to_envelope(self, row: sqlite3.Row) -> OutboxEnvelope:
|
||||
return OutboxEnvelope(
|
||||
id=row["id"],
|
||||
idempotency_key=row["idempotency_key"],
|
||||
method=row["method"],
|
||||
path=row["path"],
|
||||
body=_json_loads(row["body_json"]),
|
||||
route_class=row["route_class"],
|
||||
source_agent=row["source_agent"],
|
||||
source_host=row["source_host"],
|
||||
repo_slug=row["repo_slug"],
|
||||
session_id=row["session_id"],
|
||||
observed_revision=_json_loads(row["observed_revision_json"]),
|
||||
status=row["status"],
|
||||
attempt_count=row["attempt_count"],
|
||||
next_retry_at=row["next_retry_at"],
|
||||
last_error=row["last_error"],
|
||||
response_status=row["response_status"],
|
||||
response_body=_json_loads(row["response_body_json"]),
|
||||
created_at=row["created_at"],
|
||||
updated_at=row["updated_at"],
|
||||
acked_at=row["acked_at"],
|
||||
)
|
||||
206
api/edge/relay.py
Normal file
206
api/edge/relay.py
Normal file
@@ -0,0 +1,206 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import socket
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from api.edge.outbox import OutboxEnvelope, OutboxStore, PayloadRejected, default_outbox_path
|
||||
from api.services.write_idempotency import route_class_for
|
||||
|
||||
HOP_BY_HOP_HEADERS = {
|
||||
"connection",
|
||||
"keep-alive",
|
||||
"proxy-authenticate",
|
||||
"proxy-authorization",
|
||||
"te",
|
||||
"trailer",
|
||||
"transfer-encoding",
|
||||
"upgrade",
|
||||
"content-encoding",
|
||||
"content-length",
|
||||
}
|
||||
|
||||
|
||||
def _safe_response_headers(headers: httpx.Headers) -> dict[str, str]:
|
||||
return {key: value for key, value in headers.items() if key.lower() not in HOP_BY_HOP_HEADERS}
|
||||
|
||||
|
||||
def _body_summary(response: httpx.Response) -> Any:
|
||||
try:
|
||||
return response.json()
|
||||
except ValueError:
|
||||
return {"text": response.text[:500]}
|
||||
|
||||
|
||||
def queued_receipt(envelope: OutboxEnvelope, upstream_error: str) -> dict[str, Any]:
|
||||
return {
|
||||
"queued": True,
|
||||
"outbox_id": envelope.id,
|
||||
"idempotency_key": envelope.idempotency_key,
|
||||
"upstream": "unreachable",
|
||||
"upstream_error": upstream_error,
|
||||
"route_class": envelope.route_class,
|
||||
}
|
||||
|
||||
|
||||
async def replay_pending(
|
||||
store: OutboxStore,
|
||||
*,
|
||||
upstream_url: str,
|
||||
limit: int = 50,
|
||||
timeout: float = 10.0,
|
||||
) -> dict[str, int]:
|
||||
counts = {"sent": 0, "acked": 0, "conflict": 0, "retry": 0, "dead": 0}
|
||||
async with httpx.AsyncClient(base_url=upstream_url.rstrip("/"), timeout=timeout) as client:
|
||||
for envelope in store.due(limit=limit):
|
||||
counts["sent"] += 1
|
||||
store.mark_sending(envelope.id)
|
||||
try:
|
||||
response = await client.request(
|
||||
envelope.method,
|
||||
envelope.path,
|
||||
json=envelope.body,
|
||||
headers={
|
||||
"Idempotency-Key": envelope.idempotency_key,
|
||||
"X-StateHub-Source-Agent": envelope.source_agent or "statehub-edge",
|
||||
"X-StateHub-Source-Host": envelope.source_host or socket.gethostname(),
|
||||
},
|
||||
)
|
||||
except httpx.HTTPError as exc:
|
||||
counts["retry"] += 1
|
||||
store.mark_retry(envelope.id, error=str(exc), attempt_count=envelope.attempt_count + 1)
|
||||
continue
|
||||
|
||||
response_body = _body_summary(response)
|
||||
if response.status_code == 409:
|
||||
counts["conflict"] += 1
|
||||
store.mark_conflict(envelope.id, response_status=response.status_code, response_body=response_body)
|
||||
elif 200 <= response.status_code < 300:
|
||||
counts["acked"] += 1
|
||||
store.mark_acked(envelope.id, response_status=response.status_code, response_body=response_body)
|
||||
elif response.status_code >= 500:
|
||||
counts["retry"] += 1
|
||||
store.mark_retry(
|
||||
envelope.id,
|
||||
error=f"HTTP {response.status_code}: {response.text[:300]}",
|
||||
attempt_count=envelope.attempt_count + 1,
|
||||
)
|
||||
else:
|
||||
counts["dead"] += 1
|
||||
store.mark_dead(
|
||||
envelope.id,
|
||||
error=f"HTTP {response.status_code}: not retryable",
|
||||
response_status=response.status_code,
|
||||
response_body=response_body,
|
||||
)
|
||||
return counts
|
||||
|
||||
|
||||
def create_app(
|
||||
*,
|
||||
upstream_url: str | None = None,
|
||||
outbox_path: str | None = None,
|
||||
timeout: float = 10.0,
|
||||
) -> FastAPI:
|
||||
upstream = (upstream_url or os.environ.get("STATEHUB_UPSTREAM_URL") or os.environ.get("API_BASE") or "http://127.0.0.1:8000").rstrip("/")
|
||||
store_path = outbox_path or default_outbox_path()
|
||||
store_instance: OutboxStore | None = None
|
||||
|
||||
def get_store() -> OutboxStore:
|
||||
nonlocal store_instance
|
||||
if store_instance is None:
|
||||
store_instance = OutboxStore(store_path)
|
||||
return store_instance
|
||||
|
||||
app = FastAPI(title="State Hub Edge Relay", version="0.1.0")
|
||||
|
||||
@app.get("/edge/health")
|
||||
async def edge_health() -> dict[str, Any]:
|
||||
reachable = False
|
||||
error = None
|
||||
try:
|
||||
async with httpx.AsyncClient(base_url=upstream, timeout=2.0) as client:
|
||||
response = await client.get("/state/health")
|
||||
reachable = response.status_code < 500
|
||||
except httpx.HTTPError as exc:
|
||||
error = str(exc)
|
||||
return {
|
||||
"status": "ok",
|
||||
"upstream": upstream,
|
||||
"upstream_reachable": reachable,
|
||||
"upstream_error": error,
|
||||
"outbox": get_store().summary(),
|
||||
}
|
||||
|
||||
@app.post("/edge/replay")
|
||||
async def edge_replay(limit: int = 50) -> dict[str, int]:
|
||||
return await replay_pending(get_store(), upstream_url=upstream, limit=limit, timeout=timeout)
|
||||
|
||||
@app.api_route("/{path:path}", methods=["GET", "POST", "PATCH", "PUT", "DELETE"])
|
||||
async def proxy(path: str, request: Request) -> Response:
|
||||
api_path = "/" + path
|
||||
body: Any = None
|
||||
if request.method in {"POST", "PATCH", "PUT"}:
|
||||
try:
|
||||
body = await request.json()
|
||||
except ValueError:
|
||||
body = None
|
||||
|
||||
headers = {}
|
||||
if idempotency_key := request.headers.get("idempotency-key"):
|
||||
headers["Idempotency-Key"] = idempotency_key
|
||||
if request.headers.get("content-type"):
|
||||
headers["Content-Type"] = request.headers["content-type"]
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(base_url=upstream, timeout=timeout) as client:
|
||||
response = await client.request(
|
||||
request.method,
|
||||
api_path,
|
||||
params=request.query_params,
|
||||
json=body if body is not None else None,
|
||||
headers=headers,
|
||||
)
|
||||
return Response(
|
||||
content=response.content,
|
||||
status_code=response.status_code,
|
||||
headers=_safe_response_headers(response.headers),
|
||||
media_type=response.headers.get("content-type"),
|
||||
)
|
||||
except httpx.HTTPError as exc:
|
||||
route_class = route_class_for(request.method, api_path)
|
||||
if route_class is None or request.method not in {"POST", "PATCH"}:
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content={
|
||||
"error": "upstream unreachable and route is not queueable",
|
||||
"method": request.method,
|
||||
"path": api_path,
|
||||
"upstream": upstream,
|
||||
"detail": str(exc),
|
||||
},
|
||||
)
|
||||
try:
|
||||
envelope = get_store().enqueue(
|
||||
method=request.method,
|
||||
path=api_path,
|
||||
body=body,
|
||||
idempotency_key=request.headers.get("idempotency-key"),
|
||||
source_agent=request.headers.get("x-statehub-source-agent"),
|
||||
source_host=request.headers.get("x-statehub-source-host") or socket.gethostname(),
|
||||
repo_slug=request.headers.get("x-statehub-repo-slug"),
|
||||
session_id=request.headers.get("x-statehub-session-id"),
|
||||
observed_revision=None,
|
||||
)
|
||||
except PayloadRejected as reject:
|
||||
return JSONResponse(status_code=422, content={"error": str(reject)})
|
||||
return JSONResponse(status_code=202, content=queued_receipt(envelope, str(exc)))
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
@@ -11,6 +11,7 @@ from starlette.responses import Response as StarletteResponse
|
||||
|
||||
from api.database import engine
|
||||
from api.events import shutdown_publisher
|
||||
from api.services.write_idempotency import WriteIdempotencyMiddleware
|
||||
from api.routers import decisions, extension_points, progress, state, tasks, technical_debt, topics, workstreams, workstream_dependencies
|
||||
from api.routers import domains, repos, contributions, sbom, policy, domain_goals, repo_goals, messages, capability_requests, tpsc, services
|
||||
from api.routers import token_events
|
||||
@@ -91,13 +92,14 @@ _default_dashboard_origins = [
|
||||
_cors_env = os.getenv("CORS_ORIGINS", ",".join(_default_dashboard_origins))
|
||||
_cors_origins = [o.strip() for o in _cors_env.split(",") if o.strip()]
|
||||
|
||||
app.add_middleware(WriteIdempotencyMiddleware)
|
||||
app.add_middleware(ETagMiddleware)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=_cors_origins,
|
||||
allow_methods=["GET", "POST", "PATCH", "DELETE", "PUT"],
|
||||
allow_headers=["Content-Type", "If-None-Match"],
|
||||
expose_headers=["ETag", "X-StateHub-Elapsed-Ms", "X-StateHub-Response-Bytes", "X-StateHub-Cache"],
|
||||
allow_headers=["Content-Type", "If-None-Match", "Idempotency-Key", "X-StateHub-Source-Agent", "X-StateHub-Source-Host"],
|
||||
expose_headers=["ETag", "X-StateHub-Elapsed-Ms", "X-StateHub-Response-Bytes", "X-StateHub-Cache", "X-StateHub-Idempotency-Replay"],
|
||||
)
|
||||
|
||||
app.include_router(domains.router)
|
||||
|
||||
@@ -33,6 +33,7 @@ from api.models.interface_change import InterfaceChange
|
||||
from api.models.workplan_launch_request import WorkplanLaunchRequest
|
||||
from api.models.fabric_graph import FabricGraphImport, FabricGraphNode, FabricGraphEdge
|
||||
from api.models.legacy_meter import LegacyInterface, LegacyInterfaceUsageBucket
|
||||
from api.models.write_idempotency_key import WriteIdempotencyKey
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
@@ -65,4 +66,5 @@ __all__ = [
|
||||
"WorkplanLaunchRequest",
|
||||
"FabricGraphImport", "FabricGraphNode", "FabricGraphEdge",
|
||||
"LegacyInterface", "LegacyInterfaceUsageBucket",
|
||||
"WriteIdempotencyKey",
|
||||
]
|
||||
32
api/models/write_idempotency_key.py
Normal file
32
api/models/write_idempotency_key.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import DateTime, Integer, String, Text, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from api.models.base import Base, new_uuid
|
||||
|
||||
|
||||
class WriteIdempotencyKey(Base):
|
||||
__tablename__ = "write_idempotency_keys"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("key", name="uq_write_idempotency_keys_key"),
|
||||
)
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=new_uuid)
|
||||
key: Mapped[str] = mapped_column(String(200), nullable=False, index=True)
|
||||
method: Mapped[str] = mapped_column(String(10), nullable=False)
|
||||
path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
route_class: Mapped[str] = mapped_column(String(30), nullable=False)
|
||||
request_hash: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||
response_status: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
response_body: Mapped[Any] = mapped_column(JSONB, nullable=True)
|
||||
source_host: Mapped[str | None] = mapped_column(String(200), nullable=True)
|
||||
source_agent: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
first_seen_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
last_seen_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True, index=True)
|
||||
221
api/services/write_idempotency.py
Normal file
221
api/services/write_idempotency.py
Normal file
@@ -0,0 +1,221 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
from api.database import async_session_factory
|
||||
from api.models.write_idempotency_key import WriteIdempotencyKey
|
||||
|
||||
IDEMPOTENCY_HEADER = b"idempotency-key"
|
||||
REPLAY_HEADER = "X-StateHub-Idempotency-Replay"
|
||||
CONFLICT_STATUS = 409
|
||||
DEFAULT_IDEMPOTENCY_TTL_DAYS = 14
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WriteRouteRule:
|
||||
method: str
|
||||
pattern: str
|
||||
route_class: str
|
||||
description: str
|
||||
|
||||
def matches(self, method: str, path: str) -> bool:
|
||||
normalized = path.rstrip("/") or "/"
|
||||
return self.method == method.upper() and re.fullmatch(self.pattern, normalized) is not None
|
||||
|
||||
|
||||
WRITE_ROUTE_RULES: tuple[WriteRouteRule, ...] = (
|
||||
WriteRouteRule("POST", r"/progress", "append", "append progress event"),
|
||||
WriteRouteRule("POST", r"/messages", "append", "send agent message"),
|
||||
WriteRouteRule("PATCH", r"/messages/[^/]+/read", "append", "mark known message read"),
|
||||
WriteRouteRule("POST", r"/token-events", "append", "record token event"),
|
||||
WriteRouteRule("POST", r"/token-events/upsert", "append", "upsert token event"),
|
||||
WriteRouteRule("POST", r"/decisions", "append", "record decision"),
|
||||
WriteRouteRule("PATCH", r"/tasks/[^/]+", "replace", "update task"),
|
||||
WriteRouteRule("POST", r"/tasks/bulk-status-sync", "replace", "bulk task status sync"),
|
||||
WriteRouteRule("PATCH", r"/decisions/[^/]+", "replace", "update decision"),
|
||||
WriteRouteRule("POST", r"/decisions/[^/]+/resolve", "replace", "resolve decision"),
|
||||
WriteRouteRule("PATCH", r"/workplans/[^/]+", "replace", "update workplan"),
|
||||
WriteRouteRule("PATCH", r"/workstreams/[^/]+", "replace", "update legacy workstream alias"),
|
||||
)
|
||||
|
||||
|
||||
def route_rule_for(method: str, path: str) -> WriteRouteRule | None:
|
||||
for rule in WRITE_ROUTE_RULES:
|
||||
if rule.matches(method, path):
|
||||
return rule
|
||||
return None
|
||||
|
||||
|
||||
def route_class_for(method: str, path: str) -> str | None:
|
||||
rule = route_rule_for(method, path)
|
||||
return rule.route_class if rule else None
|
||||
|
||||
|
||||
def canonical_request_hash(method: str, path: str, query_string: bytes, body: bytes) -> str:
|
||||
try:
|
||||
parsed: Any = json.loads(body.decode("utf-8")) if body else None
|
||||
body_repr = json.dumps(parsed, sort_keys=True, separators=(",", ":"))
|
||||
except (UnicodeDecodeError, json.JSONDecodeError):
|
||||
body_repr = body.hex()
|
||||
query = query_string.decode("utf-8", errors="replace")
|
||||
seed = f"{method.upper()}\n{path}\n{query}\n{body_repr}".encode("utf-8")
|
||||
return hashlib.sha256(seed).hexdigest()
|
||||
|
||||
|
||||
def _header_value(headers: list[tuple[bytes, bytes]], name: bytes) -> str | None:
|
||||
lname = name.lower()
|
||||
for key, value in headers:
|
||||
if key.lower() == lname:
|
||||
return value.decode("utf-8", errors="replace")
|
||||
return None
|
||||
|
||||
|
||||
async def _send_json_response(response: JSONResponse, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
await response(scope, receive, send)
|
||||
|
||||
|
||||
class WriteIdempotencyMiddleware:
|
||||
"""Replay exact duplicate write requests carrying Idempotency-Key.
|
||||
|
||||
The middleware is intentionally narrow: it only participates on the offline
|
||||
relay allowlist. Non-allowlisted routes keep their normal behavior even if a
|
||||
caller sends an Idempotency-Key header.
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp, *, ttl_days: int = DEFAULT_IDEMPOTENCY_TTL_DAYS) -> None:
|
||||
self.app = app
|
||||
self.ttl_days = ttl_days
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
method = str(scope.get("method", "")).upper()
|
||||
path = str(scope.get("path", ""))
|
||||
rule = route_rule_for(method, path)
|
||||
headers = list(scope.get("headers") or [])
|
||||
key = _header_value(headers, IDEMPOTENCY_HEADER)
|
||||
if rule is None or not key:
|
||||
await self.app(scope, receive, send)
|
||||
return
|
||||
|
||||
body = await self._read_body(receive)
|
||||
request_hash = canonical_request_hash(method, path, scope.get("query_string", b""), body)
|
||||
source_host = _header_value(headers, b"x-statehub-source-host")
|
||||
source_agent = _header_value(headers, b"x-statehub-source-agent")
|
||||
|
||||
async with async_session_factory() as session:
|
||||
existing = (await session.execute(
|
||||
select(WriteIdempotencyKey).where(WriteIdempotencyKey.key == key)
|
||||
)).scalar_one_or_none()
|
||||
if existing is not None:
|
||||
existing.last_seen_at = datetime.now(tz=timezone.utc)
|
||||
await session.commit()
|
||||
if existing.request_hash != request_hash:
|
||||
await _send_json_response(
|
||||
JSONResponse(
|
||||
status_code=CONFLICT_STATUS,
|
||||
content={
|
||||
"error": "Idempotency-Key was reused with a different request",
|
||||
"idempotency_key": key,
|
||||
},
|
||||
),
|
||||
scope,
|
||||
self._receive_from_body(body),
|
||||
send,
|
||||
)
|
||||
return
|
||||
await _send_json_response(
|
||||
JSONResponse(
|
||||
status_code=existing.response_status,
|
||||
content=existing.response_body,
|
||||
headers={REPLAY_HEADER: "true"},
|
||||
),
|
||||
scope,
|
||||
self._receive_from_body(body),
|
||||
send,
|
||||
)
|
||||
return
|
||||
|
||||
start_message: Message | None = None
|
||||
body_parts: list[bytes] = []
|
||||
|
||||
async def capture_send(message: Message) -> None:
|
||||
nonlocal start_message
|
||||
if message["type"] == "http.response.start":
|
||||
start_message = message
|
||||
elif message["type"] == "http.response.body":
|
||||
body_parts.append(message.get("body", b""))
|
||||
await send(message)
|
||||
|
||||
await self.app(scope, self._receive_from_body(body), capture_send)
|
||||
|
||||
if start_message is None:
|
||||
return
|
||||
status = int(start_message.get("status", 500))
|
||||
if status < 200 or status >= 300:
|
||||
return
|
||||
|
||||
response_body_bytes = b"".join(body_parts)
|
||||
try:
|
||||
response_body = json.loads(response_body_bytes.decode("utf-8")) if response_body_bytes else None
|
||||
except (UnicodeDecodeError, json.JSONDecodeError):
|
||||
return
|
||||
|
||||
async with async_session_factory() as session:
|
||||
existing = (await session.execute(
|
||||
select(WriteIdempotencyKey).where(WriteIdempotencyKey.key == key)
|
||||
)).scalar_one_or_none()
|
||||
if existing is not None:
|
||||
return
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
session.add(WriteIdempotencyKey(
|
||||
key=key,
|
||||
method=method,
|
||||
path=path,
|
||||
route_class=rule.route_class,
|
||||
request_hash=request_hash,
|
||||
response_status=status,
|
||||
response_body=response_body,
|
||||
source_host=source_host,
|
||||
source_agent=source_agent,
|
||||
first_seen_at=now,
|
||||
last_seen_at=now,
|
||||
expires_at=now + timedelta(days=self.ttl_days),
|
||||
))
|
||||
await session.commit()
|
||||
|
||||
@staticmethod
|
||||
async def _read_body(receive: Receive) -> bytes:
|
||||
chunks: list[bytes] = []
|
||||
while True:
|
||||
message = await receive()
|
||||
if message["type"] != "http.request":
|
||||
continue
|
||||
chunks.append(message.get("body", b""))
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
return b"".join(chunks)
|
||||
|
||||
@staticmethod
|
||||
def _receive_from_body(body: bytes) -> Receive:
|
||||
sent = False
|
||||
|
||||
async def receive() -> Message:
|
||||
nonlocal sent
|
||||
if sent:
|
||||
return {"type": "http.request", "body": b"", "more_body": False}
|
||||
sent = True
|
||||
return {"type": "http.request", "body": body, "more_body": False}
|
||||
|
||||
return receive
|
||||
Reference in New Issue
Block a user