from __future__ import annotations import hashlib import json import re from dataclasses import dataclass from datetime import datetime, timedelta, timezone from typing import Any from sqlalchemy import select from starlette.responses import JSONResponse from starlette.types import ASGIApp, Message, Receive, Scope, Send from api.database import async_session_factory from api.models.write_idempotency_key import WriteIdempotencyKey IDEMPOTENCY_HEADER = b"idempotency-key" REPLAY_HEADER = "X-StateHub-Idempotency-Replay" CONFLICT_STATUS = 409 DEFAULT_IDEMPOTENCY_TTL_DAYS = 14 @dataclass(frozen=True) class WriteRouteRule: method: str pattern: str route_class: str description: str def matches(self, method: str, path: str) -> bool: normalized = path.rstrip("/") or "/" return self.method == method.upper() and re.fullmatch(self.pattern, normalized) is not None WRITE_ROUTE_RULES: tuple[WriteRouteRule, ...] = ( WriteRouteRule("POST", r"/progress", "append", "append progress event"), WriteRouteRule("POST", r"/messages", "append", "send agent message"), WriteRouteRule("PATCH", r"/messages/[^/]+/read", "append", "mark known message read"), WriteRouteRule("POST", r"/token-events", "append", "record token event"), WriteRouteRule("POST", r"/token-events/upsert", "append", "upsert token event"), WriteRouteRule("POST", r"/decisions", "append", "record decision"), WriteRouteRule("PATCH", r"/tasks/[^/]+", "replace", "update task"), WriteRouteRule("POST", r"/tasks/bulk-status-sync", "replace", "bulk task status sync"), WriteRouteRule("PATCH", r"/decisions/[^/]+", "replace", "update decision"), WriteRouteRule("POST", r"/decisions/[^/]+/resolve", "replace", "resolve decision"), WriteRouteRule("PATCH", r"/workplans/[^/]+", "replace", "update workplan"), WriteRouteRule("PATCH", r"/workstreams/[^/]+", "replace", "update legacy workstream alias"), ) def route_rule_for(method: str, path: str) -> WriteRouteRule | None: for rule in WRITE_ROUTE_RULES: if rule.matches(method, path): return rule return None def route_class_for(method: str, path: str) -> str | None: rule = route_rule_for(method, path) return rule.route_class if rule else None def canonical_request_hash(method: str, path: str, query_string: bytes, body: bytes) -> str: try: parsed: Any = json.loads(body.decode("utf-8")) if body else None body_repr = json.dumps(parsed, sort_keys=True, separators=(",", ":")) except (UnicodeDecodeError, json.JSONDecodeError): body_repr = body.hex() query = query_string.decode("utf-8", errors="replace") seed = f"{method.upper()}\n{path}\n{query}\n{body_repr}".encode("utf-8") return hashlib.sha256(seed).hexdigest() def _header_value(headers: list[tuple[bytes, bytes]], name: bytes) -> str | None: lname = name.lower() for key, value in headers: if key.lower() == lname: return value.decode("utf-8", errors="replace") return None async def _send_json_response(response: JSONResponse, scope: Scope, receive: Receive, send: Send) -> None: await response(scope, receive, send) class WriteIdempotencyMiddleware: """Replay exact duplicate write requests carrying Idempotency-Key. The middleware is intentionally narrow: it only participates on the offline relay allowlist. Non-allowlisted routes keep their normal behavior even if a caller sends an Idempotency-Key header. """ def __init__(self, app: ASGIApp, *, ttl_days: int = DEFAULT_IDEMPOTENCY_TTL_DAYS) -> None: self.app = app self.ttl_days = ttl_days async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != "http": await self.app(scope, receive, send) return method = str(scope.get("method", "")).upper() path = str(scope.get("path", "")) rule = route_rule_for(method, path) headers = list(scope.get("headers") or []) key = _header_value(headers, IDEMPOTENCY_HEADER) if rule is None or not key: await self.app(scope, receive, send) return body = await self._read_body(receive) request_hash = canonical_request_hash(method, path, scope.get("query_string", b""), body) source_host = _header_value(headers, b"x-statehub-source-host") source_agent = _header_value(headers, b"x-statehub-source-agent") async with async_session_factory() as session: existing = (await session.execute( select(WriteIdempotencyKey).where(WriteIdempotencyKey.key == key) )).scalar_one_or_none() if existing is not None: existing.last_seen_at = datetime.now(tz=timezone.utc) await session.commit() if existing.request_hash != request_hash: await _send_json_response( JSONResponse( status_code=CONFLICT_STATUS, content={ "error": "Idempotency-Key was reused with a different request", "idempotency_key": key, }, ), scope, self._receive_from_body(body), send, ) return await _send_json_response( JSONResponse( status_code=existing.response_status, content=existing.response_body, headers={REPLAY_HEADER: "true"}, ), scope, self._receive_from_body(body), send, ) return start_message: Message | None = None body_parts: list[bytes] = [] async def capture_send(message: Message) -> None: nonlocal start_message if message["type"] == "http.response.start": start_message = message elif message["type"] == "http.response.body": body_parts.append(message.get("body", b"")) await send(message) await self.app(scope, self._receive_from_body(body), capture_send) if start_message is None: return status = int(start_message.get("status", 500)) if status < 200 or status >= 300: return response_body_bytes = b"".join(body_parts) try: response_body = json.loads(response_body_bytes.decode("utf-8")) if response_body_bytes else None except (UnicodeDecodeError, json.JSONDecodeError): return async with async_session_factory() as session: existing = (await session.execute( select(WriteIdempotencyKey).where(WriteIdempotencyKey.key == key) )).scalar_one_or_none() if existing is not None: return now = datetime.now(tz=timezone.utc) session.add(WriteIdempotencyKey( key=key, method=method, path=path, route_class=rule.route_class, request_hash=request_hash, response_status=status, response_body=response_body, source_host=source_host, source_agent=source_agent, first_seen_at=now, last_seen_at=now, expires_at=now + timedelta(days=self.ttl_days), )) await session.commit() @staticmethod async def _read_body(receive: Receive) -> bytes: chunks: list[bytes] = [] while True: message = await receive() if message["type"] != "http.request": continue chunks.append(message.get("body", b"")) if not message.get("more_body", False): break return b"".join(chunks) @staticmethod def _receive_from_body(body: bytes) -> Receive: sent = False async def receive() -> Message: nonlocal sent if sent: return {"type": "http.request", "body": b"", "more_body": False} sent = True return {"type": "http.request", "body": body, "more_body": False} return receive