Files
user-engine/tests/test_postgres_store_adapter.py

181 lines
6.1 KiB
Python

import copy
import json
import unittest
from typing import Any, Iterable
from user_engine.adapters.postgres import PostgresUserEngineStore
from user_engine.migrations import LATEST_SCHEMA_VERSION
from user_engine.store_records import StoreRecord
from user_engine.testing.store_conformance import (
assert_user_engine_store_conformance,
)
class PostgresStoreAdapterTests(unittest.TestCase):
    """Tests for PostgresUserEngineStore backed by an in-memory fake connection."""

    def test_fake_postgres_store_satisfies_store_conformance(self):
        """The adapter passes the shared store conformance suite on the fake."""

        def make_store():
            # Each conformance case gets a brand-new, empty fake database.
            return PostgresUserEngineStore(_FakePostgresConnection())

        assert_user_engine_store_conformance(self, make_store)

    def test_ready_is_false_before_migration(self):
        """A store on an un-migrated connection is not ready and has no version."""
        fresh_store = PostgresUserEngineStore(_FakePostgresConnection())
        self.assertFalse(fresh_store.ready)
        self.assertIsNone(fresh_store.schema_version)
class _FakePostgresConnection:
    """In-memory stand-in for the connection object the store adapter uses.

    Every table the fake cursor understands is a plain Python container on
    this object, and cursors mutate those containers directly. ``begin``
    deep-copies each container, ``rollback`` swaps the copies back in, and
    ``commit`` discards them. Only one snapshot is held at a time: calling
    ``begin`` again replaces the previous one.
    """

    def __init__(self) -> None:
        # Versions recorded by inserts into user_engine_schema_versions.
        self.schema_versions: set[str] = set()
        # user_engine_records rows, keyed by (record_type, record_key).
        self.records: dict[tuple[str, str], StoreRecord] = {}
        # Decoded JSON payloads from audit / outbox inserts, in insertion order.
        self.audit_payloads: list[dict[str, Any]] = []
        self.outbox_payloads: list[dict[str, Any]] = []
        # State captured by begin(); None whenever no transaction is open.
        self._snapshot: tuple[
            set[str],
            dict[tuple[str, str], StoreRecord],
            list[dict[str, Any]],
            list[dict[str, Any]],
        ] | None = None

    def cursor(self) -> "_FakePostgresCursor":
        """Open a new cursor over this connection's tables."""
        return _FakePostgresCursor(self)

    def begin(self) -> None:
        """Open a transaction by snapshotting every table."""
        # map() calls copy.deepcopy once per table, so each copy gets its
        # own memo, just like four separate deepcopy calls.
        versions_copy, records_copy, audit_copy, outbox_copy = map(
            copy.deepcopy,
            (
                self.schema_versions,
                self.records,
                self.audit_payloads,
                self.outbox_payloads,
            ),
        )
        self._snapshot = (versions_copy, records_copy, audit_copy, outbox_copy)

    def commit(self) -> None:
        """Close the transaction, keeping every change made since ``begin``."""
        self._snapshot = None

    def rollback(self) -> None:
        """Restore the tables captured by ``begin``; a no-op with no open transaction."""
        saved, self._snapshot = self._snapshot, None
        if saved is not None:
            (
                self.schema_versions,
                self.records,
                self.audit_payloads,
                self.outbox_payloads,
            ) = saved
class _FakePostgresCursor:
    """Minimal cursor that interprets the adapter's SQL against a fake connection.

    Statements are recognised by substring checks on a lower-cased,
    whitespace-collapsed copy of the SQL text. Only the statement shapes
    handled in ``execute`` are emulated; any other statement (DDL, for
    example) succeeds and produces no rows. Writes go straight to the
    containers on ``connection``; transactions are handled by the connection.
    """

    # Columns that may be bound as ``<column> = %s`` in a user_engine_records
    # query. Parameters are matched to columns in the order their
    # placeholders appear in the SQL text.
    _RECORD_FILTER_COLUMNS = (
        "record_type",
        "record_key",
        "tenant",
        "user_id",
        "application_id",
        "scope_type",
        "scope_id",
    )

    def __init__(self, connection: "_FakePostgresConnection") -> None:
        self.connection = connection
        # Result set of the most recent execute() and the index of the next
        # row to hand out. fetchone()/fetchall() consume rows the way a
        # DB-API cursor does instead of replaying the first row forever.
        self._rows: list[Any] = []
        self._next_row = 0

    def execute(self, sql: str, params: Iterable[Any] | None = None) -> None:
        """Run ``sql`` against the fake tables and replace the result set.

        Args:
            sql: Statement text using ``%s`` positional placeholders.
            params: Positional parameters for the placeholders, if any.

        Raises:
            IndexError: if a recognised statement lacks the parameters it reads.
        """
        normalized = " ".join(sql.lower().split())
        values = tuple(params or ())
        # Every statement starts a fresh, unread result set.
        self._rows = []
        self._next_row = 0

        if "insert into user_engine_schema_versions" in normalized:
            # The inserted parameters are ignored: the fake always records
            # LATEST_SCHEMA_VERSION as applied.
            self.connection.schema_versions.add(LATEST_SCHEMA_VERSION)
        elif "from user_engine_schema_versions" in normalized:
            # Existence check: one row iff the requested version (first
            # parameter) has been recorded.
            if values[0] in self.connection.schema_versions:
                self._rows = [(1,)]
        elif normalized.startswith("insert into user_engine_records"):
            self._upsert_record(values)
        elif "from user_engine_records" in normalized:
            self._select_records(normalized, values)
        elif normalized.startswith("insert into user_engine_audit_records"):
            self.connection.audit_payloads.append(json.loads(values[8]))
        elif normalized.startswith("insert into user_engine_outbox_events"):
            self.connection.outbox_payloads.append(json.loads(values[5]))
        elif "from user_engine_audit_records" in normalized:
            self._rows = self._payload_rows(
                normalized, self.connection.audit_payloads
            )
        elif "from user_engine_outbox_events" in normalized:
            self._rows = self._payload_rows(
                normalized, self.connection.outbox_payloads
            )

    def fetchone(self) -> Any | None:
        """Return the next unread row of the result set, or None when exhausted."""
        if self._next_row >= len(self._rows):
            return None
        row = self._rows[self._next_row]
        self._next_row += 1
        return row

    def fetchall(self) -> list[Any]:
        """Return all unread rows as a new list and mark them as read."""
        remaining = self._rows[self._next_row:]
        self._next_row = len(self._rows)
        return remaining

    def close(self) -> None:
        """No-op: the fake cursor holds no resources."""
        return None

    def _upsert_record(self, values: tuple[Any, ...]) -> None:
        """Insert, or replace by (record_type, record_key), one record.

        Expects parameters in the order record_type, record_key, tenant,
        user_id, application_id, scope_type, scope_id, payload JSON.
        """
        payload = json.loads(values[7])
        record = StoreRecord(
            record_type=values[0],
            record_key=values[1],
            tenant=values[2],
            user_id=values[3],
            application_id=values[4],
            scope_type=values[5],
            scope_id=values[6],
            payload=payload,
        )
        self.connection.records[(record.record_type, record.record_key)] = record

    @staticmethod
    def _payload_rows(
        normalized: str, payloads: list[dict[str, Any]]
    ) -> list[tuple[Any, ...]]:
        """Build rows for a SELECT on an append-only payload table.

        A ``count(*)`` query yields a single ``(count,)`` row. Anything else
        yields every payload re-encoded as JSON, in insertion order. WHERE
        clauses and parameters are not interpreted.
        """
        if "count(*)" in normalized:
            return [(len(payloads),)]
        return [(json.dumps(payload),) for payload in payloads]

    def _select_records(self, normalized: str, values: tuple[Any, ...]) -> None:
        """Emulate SELECT statements on user_engine_records.

        A ``group by record_type`` query yields ``(record_type, count)`` rows
        sorted by type. Otherwise each ``<column> = %s`` condition found in
        the SQL is bound to the parameter at the same placeholder position.
        Records matching every bound condition are returned as full column
        tuples (payload JSON-encoded), sorted by record_key.
        """
        if "group by record_type" in normalized:
            counts: dict[str, int] = {}
            for record_type, _record_key in self.connection.records:
                counts[record_type] = counts.get(record_type, 0) + 1
            self._rows = sorted(counts.items())
            return

        # Bind parameters to columns in placeholder order rather than a fixed
        # column order, so the WHERE clause may list its conditions in any
        # order without silently mis-binding values.
        located = sorted(
            (normalized.find(f"{column} = %s"), column)
            for column in self._RECORD_FILTER_COLUMNS
            if f"{column} = %s" in normalized
        )
        filters = dict(zip((column for _position, column in located), values))

        rows = [
            (
                record.record_type,
                record.record_key,
                record.tenant,
                record.user_id,
                record.application_id,
                record.scope_type,
                record.scope_id,
                json.dumps(record.payload),
            )
            for record in self.connection.records.values()
            if all(
                getattr(record, column) == value
                for column, value in filters.items()
            )
        ]
        self._rows = sorted(rows, key=lambda row: row[1])
if __name__ == "__main__":
    # Run this module's tests when executed directly as a script.
    unittest.main(module="__main__")