"""Tests for the Postgres-backed user-engine store adapter.

The adapter runs against an in-memory fake connection/cursor pair. The pair
recognizes SQL statements by table name and statement prefix. This lets the
shared store conformance suite run without a real Postgres server.
"""

import copy
import json
import unittest
from typing import Any, Iterable

from user_engine.adapters.postgres import PostgresUserEngineStore
from user_engine.migrations import LATEST_SCHEMA_VERSION
from user_engine.store_records import StoreRecord
from user_engine.testing.store_conformance import (
    assert_user_engine_store_conformance,
)


class PostgresStoreAdapterTests(unittest.TestCase):
    """Check ``PostgresUserEngineStore`` against the fake connection."""

    def test_fake_postgres_store_satisfies_store_conformance(self):
        assert_user_engine_store_conformance(
            self,
            lambda: PostgresUserEngineStore(_FakePostgresConnection()),
        )

    def test_ready_is_false_before_migration(self):
        store = PostgresUserEngineStore(_FakePostgresConnection())

        self.assertFalse(store.ready)
        self.assertIsNone(store.schema_version)


class _FakePostgresConnection:
    """In-memory stand-in for a Postgres connection.

    All state lives in plain Python containers. ``begin`` deep-copies that
    state and ``rollback`` restores the copy. A second ``begin`` overwrites
    the pending snapshot, so transactions do not nest.
    """

    def __init__(self) -> None:
        # Schema versions recorded by migration inserts.
        self.schema_versions: set[str] = set()
        # Records keyed by (record_type, record_key); re-inserting a key
        # overwrites the previous record.
        self.records: dict[tuple[str, str], StoreRecord] = {}
        # Decoded JSON payloads of audit / outbox inserts, in insert order.
        self.audit_payloads: list[dict[str, Any]] = []
        self.outbox_payloads: list[dict[str, Any]] = []
        # Deep copy of the four containers above, taken by begin();
        # None when no transaction is open.
        self._snapshot: tuple[
            set[str],
            dict[tuple[str, str], StoreRecord],
            list[dict[str, Any]],
            list[dict[str, Any]],
        ] | None = None

    def cursor(self) -> "_FakePostgresCursor":
        """Return a new cursor bound to this connection's state."""
        return _FakePostgresCursor(self)

    def begin(self) -> None:
        """Snapshot all state so a later ``rollback`` can restore it."""
        self._snapshot = (
            copy.deepcopy(self.schema_versions),
            copy.deepcopy(self.records),
            copy.deepcopy(self.audit_payloads),
            copy.deepcopy(self.outbox_payloads),
        )

    def commit(self) -> None:
        """Keep all changes made since ``begin`` by dropping the snapshot."""
        self._snapshot = None

    def rollback(self) -> None:
        """Restore the ``begin`` snapshot; a no-op if none is pending."""
        if self._snapshot is None:
            return
        (
            self.schema_versions,
            self.records,
            self.audit_payloads,
            self.outbox_payloads,
        ) = self._snapshot
        self._snapshot = None


class _FakePostgresCursor:
    """Cursor that emulates the store's SQL against a fake connection.

    SQL is not parsed. Statements are lower-cased and whitespace-normalized,
    then matched on table names and statement prefixes. Any unrecognized
    statement (e.g. DDL) is accepted and produces no rows.
    """

    # Columns (besides record_type) that user_engine_records selects may
    # filter on with a ``<column> = %s`` predicate.
    _RECORD_FILTER_COLUMNS = (
        "record_key",
        "tenant",
        "user_id",
        "application_id",
        "scope_type",
        "scope_id",
    )

    def __init__(self, connection: _FakePostgresConnection) -> None:
        self.connection = connection
        # Result rows of the most recent execute().
        self._rows: list[Any] = []

    def execute(self, sql: str, params: Iterable[Any] | None = None) -> None:
        """Apply ``sql`` with positional ``params`` to the fake connection.

        The resulting rows are stored for ``fetchone`` / ``fetchall``.
        """
        normalized = " ".join(sql.lower().split())
        values = tuple(params or ())

        if "insert into user_engine_schema_versions" in normalized:
            # NOTE(review): records LATEST_SCHEMA_VERSION regardless of the
            # bound params; confirm the adapter only ever writes that version.
            self.connection.schema_versions.add(LATEST_SCHEMA_VERSION)
            self._rows = []
            return

        if "from user_engine_schema_versions" in normalized:
            # Existence check: one row iff the version in params[0] is known.
            self._rows = [(1,)] if values[0] in self.connection.schema_versions else []
            return

        if normalized.startswith("insert into user_engine_records"):
            # Params 0-6 are the key/scope columns; param 7 is the JSON payload.
            payload = json.loads(values[7])
            record = StoreRecord(
                record_type=values[0],
                record_key=values[1],
                tenant=values[2],
                user_id=values[3],
                application_id=values[4],
                scope_type=values[5],
                scope_id=values[6],
                payload=payload,
            )
            self.connection.records[(record.record_type, record.record_key)] = record
            self._rows = []
            return

        if "from user_engine_records" in normalized:
            self._select_records(normalized, values)
            return

        if normalized.startswith("insert into user_engine_audit_records"):
            # Only the JSON payload (param 8) is retained.
            self.connection.audit_payloads.append(json.loads(values[8]))
            self._rows = []
            return

        if normalized.startswith("insert into user_engine_outbox_events"):
            # Only the JSON payload (param 5) is retained.
            self.connection.outbox_payloads.append(json.loads(values[5]))
            self._rows = []
            return

        if "from user_engine_audit_records" in normalized:
            if "count(*)" in normalized:
                self._rows = [(len(self.connection.audit_payloads),)]
            else:
                self._rows = [
                    (json.dumps(payload),) for payload in self.connection.audit_payloads
                ]
            return

        if "from user_engine_outbox_events" in normalized:
            if "count(*)" in normalized:
                self._rows = [(len(self.connection.outbox_payloads),)]
            else:
                self._rows = [
                    (json.dumps(payload),) for payload in self.connection.outbox_payloads
                ]
            return

        self._rows = []

    def fetchone(self) -> Any | None:
        """Return the first result row, or None when there are no rows."""
        return self._rows[0] if self._rows else None

    def fetchall(self) -> list[Any]:
        """Return all result rows as a new list.

        A copy is returned so callers cannot mutate the cursor's stored rows.
        """
        return list(self._rows)

    def close(self) -> None:
        """Do nothing; the fake holds no resources."""
        return None

    def _select_records(self, normalized: str, values: tuple[Any, ...]) -> None:
        """Emulate SELECTs against ``user_engine_records``.

        A ``group by record_type`` query yields ``(record_type, count)`` pairs
        over all stored records, sorted by type. Any other query yields the
        full rows of one record type, filtered by equality predicates and
        sorted by record_key.
        """
        if "group by record_type" in normalized:
            counts: dict[str, int] = {}
            for record_type, _record_key in self.connection.records:
                counts[record_type] = counts.get(record_type, 0) + 1
            self._rows = sorted(counts.items())
            return

        # NOTE(review): assumes record_type is always bound by the first
        # placeholder, as in the original fake.
        record_type = values[0]

        # Bind the remaining params to filter columns in the order their
        # "<column> = %s" predicates appear in the SQL. A fixed column order
        # would misbind values whenever the WHERE clause lists them differently.
        positions = {
            column: normalized.find(f"{column} = %s")
            for column in self._RECORD_FILTER_COLUMNS
        }
        filter_columns = sorted(
            (column for column, position in positions.items() if position >= 0),
            key=positions.__getitem__,
        )
        filters = dict(zip(filter_columns, values[1:]))

        rows = []
        for (stored_type, _key), record in self.connection.records.items():
            if stored_type != record_type:
                continue
            if any(getattr(record, column) != value for column, value in filters.items()):
                continue
            rows.append(
                (
                    record.record_type,
                    record.record_key,
                    record.tenant,
                    record.user_id,
                    record.application_id,
                    record.scope_type,
                    record.scope_id,
                    json.dumps(record.payload),
                )
            )
        self._rows = sorted(rows, key=lambda row: row[1])


if __name__ == "__main__":
    unittest.main()