"""Revision-gated cache for ``GET /state/summary`` with stale-while-revalidate.""" from __future__ import annotations import asyncio import logging from dataclasses import dataclass from datetime import datetime, timezone from collections.abc import Awaitable, Callable from typing import Literal from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import noload from api.models.capability_request import CapabilityRequest from api.models.contribution import Contribution from api.models.decision import Decision from api.models.domain import Domain from api.models.extension_point import ExtensionPoint from api.models.managed_repo import ManagedRepo from api.models.progress_event import ProgressEvent from api.models.sbom_snapshot import SBOMSnapshot from api.models.task import Task from api.models.technical_debt import TechnicalDebt from api.models.topic import Topic from api.models.workplan import Workplan from api.models.workplan_dependency import WorkplanDependency from api.schemas.progress_event import ProgressEventRead from api.schemas.state import StateSummary logger = logging.getLogger(__name__) _EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc) _MAX_STALE_AGE_SECONDS = 300.0 InvalidateScope = Literal["all", "core", "progress"] CacheStatus = Literal["hit-revision", "stale", "miss", "progress-section"] BuildSummaryFn = Callable[[AsyncSession], Awaitable[StateSummary]] # Tables feeding the stable (non-progress) summary core. _CORE_TABLES: tuple[tuple[str, type], ...] = ( ("topics", Topic), ("workplans", Workplan), ("tasks", Task), ("decisions", Decision), ("workplan_dependencies", WorkplanDependency), ("managed_repos", ManagedRepo), ("contributions", Contribution), ("capability_requests", CapabilityRequest), ("domains", Domain), ("extension_points", ExtensionPoint), ("technical_debt", TechnicalDebt), ) @dataclass(frozen=True) class SummaryRevision: """Cheap fingerprints of hub data that affect ``/state/summary``.""" core: datetime progress: datetime | None sbom: datetime | None def core_fingerprint(self) -> str: return _fingerprint(self.core, self.sbom) def progress_fingerprint(self) -> str: return self.progress.isoformat() if self.progress else "" def combined_fingerprint(self) -> str: return f"{self.core_fingerprint()}|{self.progress_fingerprint()}" def _fingerprint(*parts: datetime | None) -> str: normalized = [ (part or _EPOCH).astimezone(timezone.utc).isoformat() for part in parts ] return "|".join(normalized) async def fetch_summary_revision(session: AsyncSession) -> SummaryRevision: """Return per-section revision watermarks (indexed MAX scans).""" core_parts: list[datetime] = [] for _name, model in _CORE_TABLES: value = ( await session.execute(select(func.max(model.updated_at))) ).scalar_one_or_none() if value is not None: core_parts.append(value) sbom_at = ( await session.execute(select(func.max(SBOMSnapshot.snapshot_at))) ).scalar_one_or_none() progress_at = ( await session.execute(select(func.max(ProgressEvent.created_at))) ).scalar_one_or_none() core = max(core_parts, default=_EPOCH) if sbom_at is not None and sbom_at > core: core = sbom_at return SummaryRevision(core=core, progress=progress_at, sbom=sbom_at) async def fetch_recent_progress(session: AsyncSession, *, limit: int = 20) -> list[ProgressEventRead]: rows = await session.execute( select(ProgressEvent) .options(noload("*")) .order_by(ProgressEvent.created_at.desc()) .limit(limit) ) return [ProgressEventRead.model_validate(event) for event in rows.scalars().all()] def merge_summary(core: StateSummary, recent_progress: list[ProgressEventRead]) -> StateSummary: return core.model_copy(update={"recent_progress": recent_progress}) @dataclass class _CacheEntry: summary: StateSummary core_revision: str progress_revision: str built_at: float class SummaryCache: def __init__(self) -> None: self._entry: _CacheEntry | None = None self._refresh_task: asyncio.Task | None = None self.last_error: str | None = None self._build_fn: BuildSummaryFn | None = None def configure(self, build_fn: BuildSummaryFn) -> None: self._build_fn = build_fn def reset(self) -> None: self._entry = None self.last_error = None if self._refresh_task is not None and not self._refresh_task.done(): self._refresh_task.cancel() self._refresh_task = None def invalidate(self, scope: InvalidateScope = "all") -> None: if scope == "all" or self._entry is None: self.reset() return if scope == "core": self.reset() elif scope == "progress": self._entry.progress_revision = "__invalid__" def store(self, summary: StateSummary, revision: SummaryRevision) -> None: import time self._entry = _CacheEntry( summary=summary, core_revision=revision.core_fingerprint(), progress_revision=revision.progress_fingerprint(), built_at=time.monotonic(), ) self.last_error = None def _entry_age(self) -> float | None: import time if self._entry is None: return None return time.monotonic() - self._entry.built_at def _entry_matches(self, revision: SummaryRevision) -> tuple[bool, bool]: if self._entry is None: return False, False core_match = self._entry.core_revision == revision.core_fingerprint() progress_match = self._entry.progress_revision == revision.progress_fingerprint() return core_match, progress_match def resolve( self, revision: SummaryRevision, *, force_refresh: bool, ) -> tuple[CacheStatus, StateSummary | None]: import time if force_refresh: return "miss", None if self._entry is None: return "miss", None age = self._entry_age() if age is not None and age > _MAX_STALE_AGE_SECONDS: return "miss", None core_match, progress_match = self._entry_matches(revision) if core_match and progress_match: return "hit-revision", self._entry.summary if core_match and not progress_match: return "progress-section", self._entry.summary # Core changed — serve stale full snapshot while refreshing. return "stale", self._entry.summary def schedule_refresh(self, revision: SummaryRevision) -> None: if self._build_fn is None: return if self._refresh_task is not None and not self._refresh_task.done(): return self._refresh_task = asyncio.create_task( self._refresh_background(revision), name="summary-cache-refresh", ) async def _refresh_background(self, revision: SummaryRevision) -> None: from api.database import async_session_factory if self._build_fn is None: return try: async with async_session_factory() as session: current = await fetch_summary_revision(session) summary = await self._build_fn(session) self.store(summary, current) except Exception as exc: self.last_error = str(exc) logger.exception("summary cache background refresh failed") _summary_cache = SummaryCache() def get_summary_cache() -> SummaryCache: return _summary_cache def invalidate_summary_cache(scope: InvalidateScope = "all") -> None: _summary_cache.invalidate(scope) def reset_summary_cache_for_tests() -> None: _summary_cache.reset() _INVALIDATION_REGISTERED = False def register_summary_cache_invalidation() -> None: """Clear summary cache when ORM rows that affect summary are written.""" global _INVALIDATION_REGISTERED if _INVALIDATION_REGISTERED: return _INVALIDATION_REGISTERED = True from sqlalchemy import event def _invalidate_core(*_args: object, **_kwargs: object) -> None: invalidate_summary_cache("core") def _invalidate_progress(*_args: object, **_kwargs: object) -> None: invalidate_summary_cache("progress") for _name, model in _CORE_TABLES: event.listen(model, "after_insert", _invalidate_core) event.listen(model, "after_update", _invalidate_core) event.listen(model, "after_delete", _invalidate_core) event.listen(SBOMSnapshot, "after_insert", _invalidate_core) event.listen(SBOMSnapshot, "after_delete", _invalidate_core) event.listen(ProgressEvent, "after_insert", _invalidate_progress) async def apply_progress_section( session: AsyncSession, summary: StateSummary, revision: SummaryRevision, ) -> StateSummary: recent = await fetch_recent_progress(session) merged = merge_summary(summary, recent) cache = get_summary_cache() if cache._entry is not None and cache._entry.core_revision == revision.core_fingerprint(): cache._entry.summary = merged cache._entry.progress_revision = revision.progress_fingerprint() else: cache.store(merged, revision) return merged