generated from coulomb/repo-seed
feat(statehub): add offline write buffer relay
This commit is contained in:
221
api/services/write_idempotency.py
Normal file
221
api/services/write_idempotency.py
Normal file
@@ -0,0 +1,221 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||||
|
||||
from api.database import async_session_factory
|
||||
from api.models.write_idempotency_key import WriteIdempotencyKey
|
||||
|
||||
# Request header that opts a write into idempotent handling. Kept as lower-case
# bytes because it is compared against the raw ASGI header list.
IDEMPOTENCY_HEADER: bytes = b"idempotency-key"

# Response header attached (with the value "true") when a stored response is replayed.
REPLAY_HEADER: str = "X-StateHub-Idempotency-Replay"

# Status returned when a known key arrives with a different request hash.
CONFLICT_STATUS: int = 409

# Default lifetime, in days, used to compute a stored key's expiry timestamp.
DEFAULT_IDEMPOTENCY_TTL_DAYS: int = 14
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class WriteRouteRule:
    """Immutable description of one write route eligible for idempotent replay.

    ``pattern`` is a regular expression that must match the *entire* request
    path (after trailing slashes are trimmed); ``method`` is the HTTP verb the
    rule applies to.
    """

    method: str  # HTTP verb, compared against the upper-cased request method
    pattern: str  # regex applied with ``re.fullmatch`` to the normalized path
    route_class: str  # free-form classification label for the route
    description: str  # human-readable summary of what the route does

    def matches(self, method: str, path: str) -> bool:
        """Return True when ``method``/``path`` fall under this rule.

        The method comparison ignores the caller's letter case. Trailing
        slashes on ``path`` are ignored; a path made only of slashes (or an
        empty path) is treated as ``"/"``.
        """
        trimmed = path.rstrip("/")
        candidate = trimmed if trimmed else "/"
        if self.method != method.upper():
            return False
        return re.fullmatch(self.pattern, candidate) is not None
|
||||
|
||||
|
||||
# Allowlist of write routes that take part in idempotent replay. Each row is
# (method, path regex, route class, description); order is preserved because
# lookups return the first matching rule.
WRITE_ROUTE_RULES: tuple[WriteRouteRule, ...] = tuple(
    WriteRouteRule(verb, path_regex, kind, summary)
    for verb, path_regex, kind, summary in (
        # Routes classified as "append".
        ("POST", r"/progress", "append", "append progress event"),
        ("POST", r"/messages", "append", "send agent message"),
        ("PATCH", r"/messages/[^/]+/read", "append", "mark known message read"),
        ("POST", r"/token-events", "append", "record token event"),
        ("POST", r"/token-events/upsert", "append", "upsert token event"),
        ("POST", r"/decisions", "append", "record decision"),
        # Routes classified as "replace".
        ("PATCH", r"/tasks/[^/]+", "replace", "update task"),
        ("POST", r"/tasks/bulk-status-sync", "replace", "bulk task status sync"),
        ("PATCH", r"/decisions/[^/]+", "replace", "update decision"),
        ("POST", r"/decisions/[^/]+/resolve", "replace", "resolve decision"),
        ("PATCH", r"/workplans/[^/]+", "replace", "update workplan"),
        ("PATCH", r"/workstreams/[^/]+", "replace", "update legacy workstream alias"),
    )
)
|
||||
|
||||
|
||||
def route_rule_for(method: str, path: str) -> WriteRouteRule | None:
    """Return the first rule in ``WRITE_ROUTE_RULES`` matching the request.

    Rules are tried in declaration order; ``None`` means the route is not on
    the allowlist.
    """
    candidates = (entry for entry in WRITE_ROUTE_RULES if entry.matches(method, path))
    return next(candidates, None)
|
||||
|
||||
|
||||
def route_class_for(method: str, path: str) -> str | None:
    """Return the ``route_class`` of the matching allowlist rule, or ``None``."""
    matched = route_rule_for(method, path)
    if matched is None:
        return None
    return matched.route_class
|
||||
|
||||
|
||||
def canonical_request_hash(method: str, path: str, query_string: bytes, body: bytes) -> str:
    """Return a SHA-256 hex digest identifying a request's method, path, query and body.

    A body that parses as UTF-8 JSON is re-serialized with sorted keys and
    compact separators, so key order and insignificant whitespace do not
    change the digest. An empty body is canonicalized like JSON ``null``.
    Any other body contributes its hexadecimal byte dump instead.
    """
    try:
        payload = None
        if body:
            payload = json.loads(body.decode("utf-8"))
        canonical_body = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    except (UnicodeDecodeError, json.JSONDecodeError):
        # Not JSON (or not UTF-8): fall back to an exact byte-level representation.
        canonical_body = body.hex()

    decoded_query = query_string.decode("utf-8", errors="replace")
    material = "{}\n{}\n{}\n{}".format(method.upper(), path, decoded_query, canonical_body)
    digest = hashlib.sha256(material.encode("utf-8"))
    return digest.hexdigest()
|
||||
|
||||
|
||||
def _header_value(headers: list[tuple[bytes, bytes]], name: bytes) -> str | None:
    """Return the decoded value of the first header named ``name``, or ``None``.

    Header names are compared case-insensitively. The value is decoded as
    UTF-8, with undecodable bytes replaced rather than raising.
    """
    wanted = name.lower()
    # Lazy generator: only the first matching value is ever decoded.
    decoded_matches = (
        raw_value.decode("utf-8", errors="replace")
        for raw_name, raw_value in headers
        if raw_name.lower() == wanted
    )
    return next(decoded_matches, None)
|
||||
|
||||
|
||||
async def _send_json_response(response: JSONResponse, scope: Scope, receive: Receive, send: Send) -> None:
    """Emit ``response`` on the connection by invoking it as an ASGI application."""
    pending = response(scope, receive, send)
    await pending
|
||||
|
||||
|
||||
class WriteIdempotencyMiddleware:
    """Replay exact duplicate write requests carrying Idempotency-Key.

    The middleware is intentionally narrow: it only participates on the offline
    relay allowlist. Non-allowlisted routes keep their normal behavior even if a
    caller sends an Idempotency-Key header.

    For an allowlisted request that carries a key:

    * if the key is stored with the same request hash, the stored status and
      JSON body are replayed with the ``REPLAY_HEADER`` header set to "true";
    * if the key is stored with a different request hash, a
      ``CONFLICT_STATUS`` JSON error is returned;
    * otherwise the request is forwarded to the wrapped app, and a 2xx
      response whose body is JSON (or empty) is stored for later replays.
    """

    def __init__(self, app: ASGIApp, *, ttl_days: int = DEFAULT_IDEMPOTENCY_TTL_DAYS) -> None:
        """Wrap ``app``; ``ttl_days`` sets how far in the future ``expires_at`` is stamped."""
        self.app = app
        self.ttl_days = ttl_days

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Handle one ASGI connection, replaying or recording idempotent writes."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        method = str(scope.get("method", "")).upper()
        path = str(scope.get("path", ""))
        rule = route_rule_for(method, path)
        headers = list(scope.get("headers") or [])
        key = _header_value(headers, IDEMPOTENCY_HEADER)
        if rule is None or not key:
            # Not allowlisted, or no (non-empty) key: stay completely out of the way.
            await self.app(scope, receive, send)
            return

        body = await self._read_body(receive)
        if body is None:
            # The client went away before the body was complete; there is
            # nobody left to answer and a partial body must not be processed.
            return

        request_hash = canonical_request_hash(method, path, scope.get("query_string", b""), body)
        source_host = _header_value(headers, b"x-statehub-source-host")
        source_agent = _header_value(headers, b"x-statehub-source-agent")

        early_response: JSONResponse | None = None
        async with async_session_factory() as session:
            existing = (await session.execute(
                select(WriteIdempotencyKey).where(WriteIdempotencyKey.key == key)
            )).scalar_one_or_none()
            if existing is not None:
                # Snapshot the columns BEFORE committing: a commit may expire
                # loaded attributes, and an async session cannot lazily reload
                # them on plain attribute access.
                stored_hash = existing.request_hash
                stored_status = existing.response_status
                stored_body = existing.response_body
                existing.last_seen_at = datetime.now(tz=timezone.utc)
                await session.commit()
                if stored_hash != request_hash:
                    early_response = JSONResponse(
                        status_code=CONFLICT_STATUS,
                        content={
                            "error": "Idempotency-Key was reused with a different request",
                            "idempotency_key": key,
                        },
                    )
                else:
                    early_response = JSONResponse(
                        status_code=stored_status,
                        content=stored_body,
                        headers={REPLAY_HEADER: "true"},
                    )

        if early_response is not None:
            # Sent after the session is closed so the database connection is
            # not held while writing to a possibly slow client.
            await _send_json_response(
                early_response,
                scope,
                self._receive_from_body(body, receive),
                send,
            )
            return

        start_message: Message | None = None
        body_parts: list[bytes] = []

        async def capture_send(message: Message) -> None:
            # Tee the response: remember it for storage while passing every
            # message through to the real ``send`` unchanged.
            nonlocal start_message
            if message["type"] == "http.response.start":
                start_message = message
            elif message["type"] == "http.response.body":
                body_parts.append(message.get("body", b""))
            await send(message)

        await self.app(scope, self._receive_from_body(body, receive), capture_send)

        if start_message is None:
            return
        status = int(start_message.get("status", 500))
        if status < 200 or status >= 300:
            # Only successful writes are recorded; failures may be retried for real.
            return

        response_body_bytes = b"".join(body_parts)
        try:
            response_body = json.loads(response_body_bytes.decode("utf-8")) if response_body_bytes else None
        except (UnicodeDecodeError, json.JSONDecodeError):
            # Replays are emitted as JSON, so non-JSON responses are not stored.
            return

        async with async_session_factory() as session:
            # Re-check: another request with the same key may have been stored
            # while the wrapped app was running.
            # NOTE(review): two concurrent first-time requests can still both
            # pass this check; presumably a unique constraint on ``key`` makes
            # the losing commit fail - confirm against the model definition.
            existing = (await session.execute(
                select(WriteIdempotencyKey).where(WriteIdempotencyKey.key == key)
            )).scalar_one_or_none()
            if existing is not None:
                return
            now = datetime.now(tz=timezone.utc)
            session.add(WriteIdempotencyKey(
                key=key,
                method=method,
                path=path,
                route_class=rule.route_class,
                request_hash=request_hash,
                response_status=status,
                response_body=response_body,
                source_host=source_host,
                source_agent=source_agent,
                first_seen_at=now,
                last_seen_at=now,
                expires_at=now + timedelta(days=self.ttl_days),
            ))
            await session.commit()

    @staticmethod
    async def _read_body(receive: Receive) -> bytes | None:
        """Drain the request body from ``receive`` and return it as one bytes object.

        Returns ``None`` if an ``http.disconnect`` message arrives before the
        final body chunk. Without that check the loop would never terminate
        against a server that keeps reporting the disconnect on every call.
        """
        chunks: list[bytes] = []
        while True:
            message = await receive()
            message_type = message["type"]
            if message_type == "http.disconnect":
                return None
            if message_type != "http.request":
                continue
            chunks.append(message.get("body", b""))
            if not message.get("more_body", False):
                break
        return b"".join(chunks)

    @staticmethod
    def _receive_from_body(body: bytes, upstream: Receive | None = None) -> Receive:
        """Build a ``receive`` callable that replays an already-read body.

        The first call yields ``body`` as a single, final ``http.request``
        message. Later calls are delegated to ``upstream`` when one is given,
        so the wrapped app still observes the server's own messages (such as
        ``http.disconnect``) instead of an endless stream of empty bodies.
        Without ``upstream`` every later call returns an empty final
        ``http.request`` message.
        """
        sent = False

        async def receive() -> Message:
            nonlocal sent
            if not sent:
                sent = True
                return {"type": "http.request", "body": body, "more_body": False}
            if upstream is not None:
                return await upstream()
            return {"type": "http.request", "body": b"", "more_body": False}

        return receive
|
||||
Reference in New Issue
Block a user