Files
state-hub/api/services/write_idempotency.py

222 lines
8.3 KiB
Python

from __future__ import annotations
import hashlib
import json
import re
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any
from sqlalchemy import select
from starlette.responses import JSONResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from api.database import async_session_factory
from api.models.write_idempotency_key import WriteIdempotencyKey
# Header name (bytes) looked up in the ASGI scope headers to find the caller's
# idempotency key; the lookup itself is case-insensitive.
IDEMPOTENCY_HEADER = b"idempotency-key"
# Response header set to "true" on responses replayed from a stored record.
REPLAY_HEADER = "X-StateHub-Idempotency-Replay"
# HTTP status returned when a known key arrives with a different request hash.
CONFLICT_STATUS = 409
# Default number of days added to "now" to compute a stored key's expires_at.
DEFAULT_IDEMPOTENCY_TTL_DAYS = 14
@dataclass(frozen=True)
class WriteRouteRule:
    """One allowlist entry describing a write route.

    A rule pairs an upper-case HTTP method with a regular expression that must
    match the whole (trailing-slash-trimmed) request path.
    """

    # HTTP method, compared against the upper-cased request method.
    method: str
    # Regular expression applied with ``re.fullmatch`` to the normalized path.
    pattern: str
    # Free-form classification label carried by the rule (e.g. "append").
    route_class: str
    # Human-readable summary of what the route does.
    description: str

    def matches(self, method: str, path: str) -> bool:
        """Return True when *method* and *path* are covered by this rule.

        Trailing slashes are stripped from *path* before matching; a path that
        becomes empty is treated as ``"/"``. The method comparison ignores the
        case of the *method* argument.
        """
        trimmed = path.rstrip("/")
        candidate = trimmed if trimmed else "/"
        if self.method != method.upper():
            return False
        return re.fullmatch(self.pattern, candidate) is not None
# Allowlist of write routes; lookups scan it in order and stop at the first
# rule whose method and path pattern both match.
WRITE_ROUTE_RULES: tuple[WriteRouteRule, ...] = (
    # --- "append" routes -----------------------------------------------------
    WriteRouteRule(
        method="POST",
        pattern=r"/progress",
        route_class="append",
        description="append progress event",
    ),
    WriteRouteRule(
        method="POST",
        pattern=r"/messages",
        route_class="append",
        description="send agent message",
    ),
    WriteRouteRule(
        method="PATCH",
        pattern=r"/messages/[^/]+/read",
        route_class="append",
        description="mark known message read",
    ),
    WriteRouteRule(
        method="POST",
        pattern=r"/token-events",
        route_class="append",
        description="record token event",
    ),
    WriteRouteRule(
        method="POST",
        pattern=r"/token-events/upsert",
        route_class="append",
        description="upsert token event",
    ),
    WriteRouteRule(
        method="POST",
        pattern=r"/decisions",
        route_class="append",
        description="record decision",
    ),
    # --- "replace" routes ----------------------------------------------------
    WriteRouteRule(
        method="PATCH",
        pattern=r"/tasks/[^/]+",
        route_class="replace",
        description="update task",
    ),
    WriteRouteRule(
        method="POST",
        pattern=r"/tasks/bulk-status-sync",
        route_class="replace",
        description="bulk task status sync",
    ),
    WriteRouteRule(
        method="PATCH",
        pattern=r"/decisions/[^/]+",
        route_class="replace",
        description="update decision",
    ),
    WriteRouteRule(
        method="POST",
        pattern=r"/decisions/[^/]+/resolve",
        route_class="replace",
        description="resolve decision",
    ),
    WriteRouteRule(
        method="PATCH",
        pattern=r"/workplans/[^/]+",
        route_class="replace",
        description="update workplan",
    ),
    WriteRouteRule(
        method="PATCH",
        pattern=r"/workstreams/[^/]+",
        route_class="replace",
        description="update legacy workstream alias",
    ),
)
def route_rule_for(method: str, path: str) -> WriteRouteRule | None:
    """Return the first rule in ``WRITE_ROUTE_RULES`` matching the request.

    Rules are tried in declaration order; ``None`` is returned when no rule
    accepts the given *method* and *path*.
    """
    candidates = (entry for entry in WRITE_ROUTE_RULES if entry.matches(method, path))
    return next(candidates, None)
def route_class_for(method: str, path: str) -> str | None:
    """Return the ``route_class`` of the rule matching the request, if any.

    Returns ``None`` when ``route_rule_for`` finds no matching rule.
    """
    matched = route_rule_for(method, path)
    if matched is None:
        return None
    return matched.route_class
def canonical_request_hash(method: str, path: str, query_string: bytes, body: bytes) -> str:
    """Return a SHA-256 hex digest identifying a request's method, path, query and body.

    JSON bodies are canonicalized (sorted keys, compact separators) so that two
    payloads differing only in key order or whitespace hash identically. An
    empty body is represented as JSON ``null``. Any body that cannot be
    round-tripped through the JSON module is represented by its hex encoding.

    Args:
        method: HTTP method; upper-cased before hashing.
        path: Request path, used verbatim.
        query_string: Raw query string bytes; decoded as UTF-8 with replacement.
        body: Raw request body bytes.

    Returns:
        The 64-character lowercase hexadecimal SHA-256 digest.
    """
    try:
        parsed: Any = json.loads(body.decode("utf-8")) if body else None
        body_repr = json.dumps(parsed, sort_keys=True, separators=(",", ":"))
    except (ValueError, RecursionError):
        # ValueError covers UnicodeDecodeError and json.JSONDecodeError, and
        # also the plain ValueError raised for integer literals longer than
        # the interpreter's int-string digit limit. RecursionError is raised
        # for pathologically nested documents. The body is caller-controlled,
        # so none of these may escape: fall back to hashing the raw bytes.
        body_repr = body.hex()
    query = query_string.decode("utf-8", errors="replace")
    seed = f"{method.upper()}\n{path}\n{query}\n{body_repr}".encode("utf-8")
    return hashlib.sha256(seed).hexdigest()
def _header_value(headers: list[tuple[bytes, bytes]], name: bytes) -> str | None:
lname = name.lower()
for key, value in headers:
if key.lower() == lname:
return value.decode("utf-8", errors="replace")
return None
async def _send_json_response(response: JSONResponse, scope: Scope, receive: Receive, send: Send) -> None:
    """Emit *response* by awaiting it as an ASGI callable with the given scope, receive and send."""
    await response(scope, receive, send)
class WriteIdempotencyMiddleware:
    """Replay exact duplicate write requests carrying Idempotency-Key.

    The middleware is intentionally narrow: it only participates on the offline
    relay allowlist. Non-allowlisted routes keep their normal behavior even if a
    caller sends an Idempotency-Key header.

    For an allowlisted request that carries a key, the request body is buffered
    and hashed, then one of three things happens:

    * the key is already stored with the same request hash: the stored status
      and JSON body are replayed with the replay header set to ``"true"``;
    * the key is already stored with a different hash: a 409 JSON error is sent;
    * the key is unknown: the wrapped app handles the request, and a 2xx
      response whose body is empty or valid JSON is stored under the key.

    NOTE(review): ``expires_at`` is written on insert but is not consulted when
    looking a key up in this class; presumably expired rows are purged
    elsewhere — confirm.
    """

    def __init__(self, app: ASGIApp, *, ttl_days: int = DEFAULT_IDEMPOTENCY_TTL_DAYS) -> None:
        """Wrap *app*; *ttl_days* sets how far in the future stored keys expire."""
        self.app = app
        self.ttl_days = ttl_days

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point; see the class docstring for the decision flow."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        method = str(scope.get("method", "")).upper()
        path = str(scope.get("path", ""))
        rule = route_rule_for(method, path)
        headers = list(scope.get("headers") or [])
        key = _header_value(headers, IDEMPOTENCY_HEADER)
        if rule is None or not key:
            # Not an allowlisted write, or no key supplied: pass through untouched.
            await self.app(scope, receive, send)
            return

        body = await self._read_body(receive)
        if body is None:
            # The client went away before the body was complete. There is
            # nobody to answer, and forwarding a truncated body could apply a
            # partial write, so stop here.
            return

        request_hash = canonical_request_hash(method, path, scope.get("query_string", b""), body)
        source_host = _header_value(headers, b"x-statehub-source-host")
        source_agent = _header_value(headers, b"x-statehub-source-agent")

        stored = await self._load_and_touch(key)
        if stored is not None:
            stored_hash, stored_status, stored_body = stored
            if stored_hash != request_hash:
                response = JSONResponse(
                    status_code=CONFLICT_STATUS,
                    content={
                        "error": "Idempotency-Key was reused with a different request",
                        "idempotency_key": key,
                    },
                )
            else:
                response = JSONResponse(
                    status_code=stored_status,
                    content=stored_body,
                    headers={REPLAY_HEADER: "true"},
                )
            await _send_json_response(
                response,
                scope,
                self._receive_from_body(body, receive),
                send,
            )
            return

        start_message: Message | None = None
        body_parts: list[bytes] = []

        async def capture_send(message: Message) -> None:
            # Record the response while still forwarding every message to the client.
            nonlocal start_message
            if message["type"] == "http.response.start":
                start_message = message
            elif message["type"] == "http.response.body":
                body_parts.append(message.get("body", b""))
            await send(message)

        await self.app(scope, self._receive_from_body(body, receive), capture_send)

        if start_message is None:
            return
        status = int(start_message.get("status", 500))
        if not 200 <= status < 300:
            # Only successful responses are remembered for replay.
            return

        response_body_bytes = b"".join(body_parts)
        try:
            response_body = json.loads(response_body_bytes.decode("utf-8")) if response_body_bytes else None
        except (ValueError, RecursionError):
            # The response has already been sent; a body that cannot be parsed
            # as JSON is simply not stored. ValueError covers both
            # UnicodeDecodeError and json.JSONDecodeError.
            return

        async with async_session_factory() as session:
            existing = (await session.execute(
                select(WriteIdempotencyKey).where(WriteIdempotencyKey.key == key)
            )).scalar_one_or_none()
            if existing is not None:
                # Another request stored this key while the app was running.
                return
            now = datetime.now(tz=timezone.utc)
            session.add(WriteIdempotencyKey(
                key=key,
                method=method,
                path=path,
                route_class=rule.route_class,
                request_hash=request_hash,
                response_status=status,
                response_body=response_body,
                source_host=source_host,
                source_agent=source_agent,
                first_seen_at=now,
                last_seen_at=now,
                expires_at=now + timedelta(days=self.ttl_days),
            ))
            await session.commit()

    @staticmethod
    async def _load_and_touch(key: str) -> tuple[str, int, Any] | None:
        """Look up *key*, bump its ``last_seen_at``, and return what a replay needs.

        Returns ``(request_hash, response_status, response_body)`` for a stored
        key, or ``None`` when the key is unknown.
        """
        async with async_session_factory() as session:
            existing = (await session.execute(
                select(WriteIdempotencyKey).where(WriteIdempotencyKey.key == key)
            )).scalar_one_or_none()
            if existing is None:
                return None
            # Copy the values out BEFORE committing: if the session factory
            # expires instances on commit, touching these attributes afterwards
            # would require an implicit refresh, which AsyncSession cannot do.
            snapshot = (existing.request_hash, existing.response_status, existing.response_body)
            existing.last_seen_at = datetime.now(tz=timezone.utc)
            await session.commit()
            return snapshot

    @staticmethod
    async def _read_body(receive: Receive) -> bytes | None:
        """Drain the request body from *receive*.

        Returns the concatenated body bytes, or ``None`` if an
        ``http.disconnect`` message arrives before the final body chunk.
        """
        chunks: list[bytes] = []
        while True:
            message = await receive()
            kind = message["type"]
            if kind == "http.disconnect":
                # Must terminate here: once the client is gone, servers keep
                # answering receive() with http.disconnect, so skipping it
                # would loop forever without ever yielding.
                return None
            if kind != "http.request":
                continue
            chunks.append(message.get("body", b""))
            if not message.get("more_body", False):
                break
        return b"".join(chunks)

    @staticmethod
    def _receive_from_body(body: bytes, upstream: Receive | None = None) -> Receive:
        """Build a ``receive`` callable that replays an already-buffered body.

        The first call yields *body* as a single, final ``http.request``
        message. Later calls are delegated to *upstream* when one is given, so
        the downstream app can still observe ``http.disconnect``; without an
        upstream they yield an empty final ``http.request`` message.
        """
        delivered = False

        async def receive() -> Message:
            nonlocal delivered
            if not delivered:
                delivered = True
                return {"type": "http.request", "body": body, "more_body": False}
            if upstream is not None:
                return await upstream()
            return {"type": "http.request", "body": b"", "more_body": False}

        return receive