generated from coulomb/repo-seed
222 lines
8.3 KiB
Python
222 lines
8.3 KiB
Python
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import json
|
|
import re
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Any
|
|
|
|
from sqlalchemy import select
|
|
from starlette.responses import JSONResponse
|
|
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
|
|
|
from api.database import async_session_factory
|
|
from api.models.write_idempotency_key import WriteIdempotencyKey
|
|
|
|
# Header name (lower-case bytes) looked up in the raw ASGI header pairs to find
# the caller-supplied idempotency key.
IDEMPOTENCY_HEADER = b"idempotency-key"
# Response header set to "true" when a stored response is replayed.
REPLAY_HEADER = "X-StateHub-Idempotency-Replay"
# HTTP status returned when a key is reused with a different request hash.
CONFLICT_STATUS = 409
# Default number of days added to the first-seen time to compute a stored
# key's ``expires_at``.
DEFAULT_IDEMPOTENCY_TTL_DAYS = 14
|
|
|
|
|
|
@dataclass(frozen=True)
class WriteRouteRule:
    """An allowlist entry pairing an HTTP method with a path regex."""

    # HTTP verb; compared against the upper-cased request method.
    method: str
    # Regular expression that must match the entire normalized path.
    pattern: str
    # Classification label attached to the rule (e.g. "append" or "replace").
    route_class: str
    # Short human-readable summary of the route.
    description: str

    def matches(self, method: str, path: str) -> bool:
        """Return True when *method* and *path* are covered by this rule.

        Trailing slashes are stripped from *path* before matching (a path made
        only of slashes, or an empty path, becomes ``"/"``), the method is
        compared case-insensitively, and the pattern must match the whole
        normalized path.
        """
        trimmed = path.rstrip("/")
        candidate = trimmed if trimmed else "/"
        if method.upper() != self.method:
            return False
        return re.fullmatch(self.pattern, candidate) is not None
|
|
|
|
|
|
# Allowlist of write routes, as (method, path regex, route class, description).
# Each pattern is matched against the whole request path after trailing
# slashes are stripped, and rules are tried in the order listed here.
WRITE_ROUTE_RULES: tuple[WriteRouteRule, ...] = (
    # Routes classified as "append".
    WriteRouteRule("POST", r"/progress", "append", "append progress event"),
    WriteRouteRule("POST", r"/messages", "append", "send agent message"),
    WriteRouteRule("PATCH", r"/messages/[^/]+/read", "append", "mark known message read"),
    WriteRouteRule("POST", r"/token-events", "append", "record token event"),
    WriteRouteRule("POST", r"/token-events/upsert", "append", "upsert token event"),
    WriteRouteRule("POST", r"/decisions", "append", "record decision"),
    # Routes classified as "replace".
    WriteRouteRule("PATCH", r"/tasks/[^/]+", "replace", "update task"),
    WriteRouteRule("POST", r"/tasks/bulk-status-sync", "replace", "bulk task status sync"),
    WriteRouteRule("PATCH", r"/decisions/[^/]+", "replace", "update decision"),
    WriteRouteRule("POST", r"/decisions/[^/]+/resolve", "replace", "resolve decision"),
    WriteRouteRule("PATCH", r"/workplans/[^/]+", "replace", "update workplan"),
    WriteRouteRule("PATCH", r"/workstreams/[^/]+", "replace", "update legacy workstream alias"),
)
|
|
|
|
|
|
def route_rule_for(method: str, path: str) -> WriteRouteRule | None:
    """Return the first entry of ``WRITE_ROUTE_RULES`` matching the request.

    Rules are tested in declaration order; ``None`` is returned when no rule
    matches *method* and *path*.
    """
    candidates = (entry for entry in WRITE_ROUTE_RULES if entry.matches(method, path))
    return next(candidates, None)
|
|
|
|
|
|
def route_class_for(method: str, path: str) -> str | None:
    """Return the ``route_class`` of the matching rule, or ``None`` if none matches."""
    matched = route_rule_for(method, path)
    if not matched:
        return None
    return matched.route_class
|
|
|
|
|
|
def canonical_request_hash(method: str, path: str, query_string: bytes, body: bytes) -> str:
    """Return a SHA-256 hex digest identifying a request's content.

    The digest covers the upper-cased method, the path, the query string and a
    canonical form of the body, joined by newlines.

    Args:
        method: HTTP method; upper-cased before hashing.
        path: Request path, hashed as given (no normalization is applied here).
        query_string: Raw query string bytes; undecodable bytes are replaced.
        body: Raw request body.

    Returns:
        The 64-character hexadecimal SHA-256 digest.

    A body that parses as JSON is re-serialized with sorted keys and compact
    separators, so key order and whitespace do not affect the digest. An empty
    body is hashed as JSON ``null``. Any body that cannot be canonicalized is
    hashed by its hex encoding instead.
    """
    try:
        parsed: Any = json.loads(body.decode("utf-8")) if body else None
        body_repr = json.dumps(parsed, sort_keys=True, separators=(",", ":"))
    except (ValueError, RecursionError):
        # ValueError covers UnicodeDecodeError and json.JSONDecodeError as well
        # as the plain ValueError raised for integer literals beyond the
        # interpreter's digit limit; RecursionError covers pathologically
        # nested documents. The body is untrusted, so none of these may escape.
        body_repr = body.hex()
    query = query_string.decode("utf-8", errors="replace")
    seed = f"{method.upper()}\n{path}\n{query}\n{body_repr}".encode("utf-8")
    return hashlib.sha256(seed).hexdigest()
|
|
|
|
|
|
def _header_value(headers: list[tuple[bytes, bytes]], name: bytes) -> str | None:
|
|
lname = name.lower()
|
|
for key, value in headers:
|
|
if key.lower() == lname:
|
|
return value.decode("utf-8", errors="replace")
|
|
return None
|
|
|
|
|
|
async def _send_json_response(response: JSONResponse, scope: Scope, receive: Receive, send: Send) -> None:
    """Send *response* by invoking it as an ASGI callable on this connection."""
    pending = response(scope, receive, send)
    await pending
|
|
|
|
|
|
class WriteIdempotencyMiddleware:
    """Replay exact duplicate write requests carrying Idempotency-Key.

    The middleware is intentionally narrow: it only participates on the offline
    relay allowlist. Non-allowlisted routes keep their normal behavior even if a
    caller sends an Idempotency-Key header.

    For an allowlisted request that carries a key:

    1. The request body is buffered and hashed together with the method, path
       and query string.
    2. If the key is already stored, the stored response is replayed with the
       replay header set, or a 409 is returned when the hash differs.
    3. Otherwise the wrapped app handles the request; a 2xx response whose
       body is empty or valid JSON is stored under the key.

    NOTE(review): ``expires_at`` is written when a key is stored but is not
    consulted on lookup here - presumably expired rows are pruned elsewhere;
    confirm.
    """

    def __init__(self, app: ASGIApp, *, ttl_days: int = DEFAULT_IDEMPOTENCY_TTL_DAYS) -> None:
        """Wrap *app*; *ttl_days* sets how far ahead ``expires_at`` is stamped."""
        self.app = app
        self.ttl_days = ttl_days

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """ASGI entry point; see the class docstring for the request flow."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        method = str(scope.get("method", "")).upper()
        path = str(scope.get("path", ""))
        rule = route_rule_for(method, path)
        headers = list(scope.get("headers") or [])
        key = _header_value(headers, IDEMPOTENCY_HEADER)
        if rule is None or not key:
            # Not allowlisted, or no key supplied: pass through untouched.
            await self.app(scope, receive, send)
            return

        body = await self._read_body(receive)
        if body is None:
            # The client disconnected before the body was complete; there is
            # no full request to hash, execute or answer.
            return
        request_hash = canonical_request_hash(method, path, scope.get("query_string", b""), body)
        source_host = _header_value(headers, b"x-statehub-source-host")
        source_agent = _header_value(headers, b"x-statehub-source-agent")

        async with async_session_factory() as session:
            existing = (await session.execute(
                select(WriteIdempotencyKey).where(WriteIdempotencyKey.key == key)
            )).scalar_one_or_none()
            if existing is not None:
                # Snapshot the stored values before committing: if the session
                # expires attributes on commit, reading them afterwards would
                # need a refresh that an async session cannot do implicitly.
                stored_hash = existing.request_hash
                stored_status = existing.response_status
                stored_body = existing.response_body
                existing.last_seen_at = datetime.now(tz=timezone.utc)
                await session.commit()
                if stored_hash != request_hash:
                    response = JSONResponse(
                        status_code=CONFLICT_STATUS,
                        content={
                            "error": "Idempotency-Key was reused with a different request",
                            "idempotency_key": key,
                        },
                    )
                else:
                    response = JSONResponse(
                        status_code=stored_status,
                        content=stored_body,
                        headers={REPLAY_HEADER: "true"},
                    )
                await _send_json_response(
                    response,
                    scope,
                    self._receive_from_body(body, receive),
                    send,
                )
                return

        start_message: Message | None = None
        body_parts: list[bytes] = []

        async def capture_send(message: Message) -> None:
            # Record the response while relaying every message unchanged.
            nonlocal start_message
            if message["type"] == "http.response.start":
                start_message = message
            elif message["type"] == "http.response.body":
                body_parts.append(message.get("body", b""))
            await send(message)

        # The body was already consumed above, so the app gets a receive that
        # replays it and then falls through to the real channel.
        await self.app(scope, self._receive_from_body(body, receive), capture_send)

        if start_message is None:
            return
        status = int(start_message.get("status", 500))
        if status < 200 or status >= 300:
            # Only successful responses are stored for replay.
            return

        response_body_bytes = b"".join(body_parts)
        try:
            response_body = json.loads(response_body_bytes.decode("utf-8")) if response_body_bytes else None
        except (UnicodeDecodeError, json.JSONDecodeError):
            # Non-JSON responses are not stored.
            return

        async with async_session_factory() as session:
            existing = (await session.execute(
                select(WriteIdempotencyKey).where(WriteIdempotencyKey.key == key)
            )).scalar_one_or_none()
            if existing is not None:
                # The key was stored while this request was being handled.
                return
            now = datetime.now(tz=timezone.utc)
            session.add(WriteIdempotencyKey(
                key=key,
                method=method,
                path=path,
                route_class=rule.route_class,
                request_hash=request_hash,
                response_status=status,
                response_body=response_body,
                source_host=source_host,
                source_agent=source_agent,
                first_seen_at=now,
                last_seen_at=now,
                expires_at=now + timedelta(days=self.ttl_days),
            ))
            await session.commit()

    @staticmethod
    async def _read_body(receive: Receive) -> bytes | None:
        """Drain the request body from *receive*.

        Returns the concatenated body once a message without ``more_body``
        arrives, or ``None`` if ``http.disconnect`` is received first. Other
        message types are skipped.
        """
        chunks: list[bytes] = []
        while True:
            message = await receive()
            message_type = message["type"]
            if message_type == "http.disconnect":
                # Stop here: a closed connection keeps yielding this event, so
                # skipping it would never terminate.
                return None
            if message_type != "http.request":
                continue
            chunks.append(message.get("body", b""))
            if not message.get("more_body", False):
                return b"".join(chunks)

    @staticmethod
    def _receive_from_body(body: bytes, upstream: Receive | None = None) -> Receive:
        """Build a receive callable that replays an already-buffered body.

        The first call returns *body* as a single complete ``http.request``
        message. Later calls delegate to *upstream* when one is given, so the
        consumer still observes events such as ``http.disconnect``; without
        *upstream* they return an empty ``http.request`` message.
        """
        sent = False

        async def receive() -> Message:
            nonlocal sent
            if not sent:
                sent = True
                return {"type": "http.request", "body": body, "more_body": False}
            if upstream is not None:
                return await upstream()
            return {"type": "http.request", "body": b"", "more_body": False}

        return receive
|