generated from coulomb/repo-seed
feat: add postgres user engine store
This commit is contained in:
180
tests/test_postgres_store_adapter.py
Normal file
180
tests/test_postgres_store_adapter.py
Normal file
@@ -0,0 +1,180 @@
|
||||
import copy
|
||||
import json
|
||||
import unittest
|
||||
from typing import Any, Iterable
|
||||
|
||||
from user_engine.adapters.postgres import PostgresUserEngineStore
|
||||
from user_engine.migrations import LATEST_SCHEMA_VERSION
|
||||
from user_engine.store_records import StoreRecord
|
||||
from user_engine.testing.store_conformance import (
|
||||
assert_user_engine_store_conformance,
|
||||
)
|
||||
|
||||
|
||||
class PostgresStoreAdapterTests(unittest.TestCase):
|
||||
def test_fake_postgres_store_satisfies_store_conformance(self):
|
||||
assert_user_engine_store_conformance(
|
||||
self,
|
||||
lambda: PostgresUserEngineStore(_FakePostgresConnection()),
|
||||
)
|
||||
|
||||
def test_ready_is_false_before_migration(self):
|
||||
store = PostgresUserEngineStore(_FakePostgresConnection())
|
||||
|
||||
self.assertFalse(store.ready)
|
||||
self.assertIsNone(store.schema_version)
|
||||
|
||||
|
||||
class _FakePostgresConnection:
|
||||
def __init__(self) -> None:
|
||||
self.schema_versions: set[str] = set()
|
||||
self.records: dict[tuple[str, str], StoreRecord] = {}
|
||||
self.audit_payloads: list[dict[str, Any]] = []
|
||||
self.outbox_payloads: list[dict[str, Any]] = []
|
||||
self._snapshot: tuple[
|
||||
set[str],
|
||||
dict[tuple[str, str], StoreRecord],
|
||||
list[dict[str, Any]],
|
||||
list[dict[str, Any]],
|
||||
] | None = None
|
||||
|
||||
def cursor(self) -> "_FakePostgresCursor":
|
||||
return _FakePostgresCursor(self)
|
||||
|
||||
def begin(self) -> None:
|
||||
self._snapshot = (
|
||||
copy.deepcopy(self.schema_versions),
|
||||
copy.deepcopy(self.records),
|
||||
copy.deepcopy(self.audit_payloads),
|
||||
copy.deepcopy(self.outbox_payloads),
|
||||
)
|
||||
|
||||
def commit(self) -> None:
|
||||
self._snapshot = None
|
||||
|
||||
def rollback(self) -> None:
|
||||
if self._snapshot is None:
|
||||
return
|
||||
(
|
||||
self.schema_versions,
|
||||
self.records,
|
||||
self.audit_payloads,
|
||||
self.outbox_payloads,
|
||||
) = self._snapshot
|
||||
self._snapshot = None
|
||||
|
||||
|
||||
class _FakePostgresCursor:
|
||||
def __init__(self, connection: _FakePostgresConnection) -> None:
|
||||
self.connection = connection
|
||||
self._rows: list[Any] = []
|
||||
|
||||
def execute(self, sql: str, params: Iterable[Any] | None = None) -> None:
|
||||
normalized = " ".join(sql.lower().split())
|
||||
values = tuple(params or ())
|
||||
|
||||
if "insert into user_engine_schema_versions" in normalized:
|
||||
self.connection.schema_versions.add(LATEST_SCHEMA_VERSION)
|
||||
self._rows = []
|
||||
return
|
||||
if "from user_engine_schema_versions" in normalized:
|
||||
self._rows = [(1,)] if values[0] in self.connection.schema_versions else []
|
||||
return
|
||||
if normalized.startswith("insert into user_engine_records"):
|
||||
payload = json.loads(values[7])
|
||||
record = StoreRecord(
|
||||
record_type=values[0],
|
||||
record_key=values[1],
|
||||
tenant=values[2],
|
||||
user_id=values[3],
|
||||
application_id=values[4],
|
||||
scope_type=values[5],
|
||||
scope_id=values[6],
|
||||
payload=payload,
|
||||
)
|
||||
self.connection.records[(record.record_type, record.record_key)] = record
|
||||
self._rows = []
|
||||
return
|
||||
if "from user_engine_records" in normalized:
|
||||
self._select_records(normalized, values)
|
||||
return
|
||||
if normalized.startswith("insert into user_engine_audit_records"):
|
||||
self.connection.audit_payloads.append(json.loads(values[8]))
|
||||
self._rows = []
|
||||
return
|
||||
if normalized.startswith("insert into user_engine_outbox_events"):
|
||||
self.connection.outbox_payloads.append(json.loads(values[5]))
|
||||
self._rows = []
|
||||
return
|
||||
if "from user_engine_audit_records" in normalized:
|
||||
if "count(*)" in normalized:
|
||||
self._rows = [(len(self.connection.audit_payloads),)]
|
||||
else:
|
||||
self._rows = [
|
||||
(json.dumps(payload),) for payload in self.connection.audit_payloads
|
||||
]
|
||||
return
|
||||
if "from user_engine_outbox_events" in normalized:
|
||||
if "count(*)" in normalized:
|
||||
self._rows = [(len(self.connection.outbox_payloads),)]
|
||||
else:
|
||||
self._rows = [
|
||||
(json.dumps(payload),) for payload in self.connection.outbox_payloads
|
||||
]
|
||||
return
|
||||
self._rows = []
|
||||
|
||||
def fetchone(self) -> Any | None:
|
||||
return self._rows[0] if self._rows else None
|
||||
|
||||
def fetchall(self) -> list[Any]:
|
||||
return self._rows
|
||||
|
||||
def close(self) -> None:
|
||||
return None
|
||||
|
||||
def _select_records(self, normalized: str, values: tuple[Any, ...]) -> None:
|
||||
if "group by record_type" in normalized:
|
||||
counts: dict[str, int] = {}
|
||||
for record_type, _record_key in self.connection.records:
|
||||
counts[record_type] = counts.get(record_type, 0) + 1
|
||||
self._rows = sorted(counts.items())
|
||||
return
|
||||
|
||||
record_type = values[0]
|
||||
filter_columns = [
|
||||
column
|
||||
for column in (
|
||||
"record_key",
|
||||
"tenant",
|
||||
"user_id",
|
||||
"application_id",
|
||||
"scope_type",
|
||||
"scope_id",
|
||||
)
|
||||
if f"{column} = %s" in normalized
|
||||
]
|
||||
filters = dict(zip(filter_columns, values[1:]))
|
||||
rows = []
|
||||
for (stored_type, _key), record in self.connection.records.items():
|
||||
if stored_type != record_type:
|
||||
continue
|
||||
if any(getattr(record, column) != value for column, value in filters.items()):
|
||||
continue
|
||||
rows.append(
|
||||
(
|
||||
record.record_type,
|
||||
record.record_key,
|
||||
record.tenant,
|
||||
record.user_id,
|
||||
record.application_id,
|
||||
record.scope_type,
|
||||
record.scope_id,
|
||||
json.dumps(record.payload),
|
||||
)
|
||||
)
|
||||
self._rows = sorted(rows, key=lambda row: row[1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user