generated from coulomb/repo-seed
181 lines
6.1 KiB
Python
181 lines
6.1 KiB
Python
import copy
|
|
import json
|
|
import unittest
|
|
from typing import Any, Iterable
|
|
|
|
from user_engine.adapters.postgres import PostgresUserEngineStore
|
|
from user_engine.migrations import LATEST_SCHEMA_VERSION
|
|
from user_engine.store_records import StoreRecord
|
|
from user_engine.testing.store_conformance import (
|
|
assert_user_engine_store_conformance,
|
|
)
|
|
|
|
|
|
class PostgresStoreAdapterTests(unittest.TestCase):
    """Checks for PostgresUserEngineStore backed by an in-memory fake connection."""

    @staticmethod
    def _new_store():
        """Build a store over a brand-new, empty fake connection."""
        return PostgresUserEngineStore(_FakePostgresConnection())

    def test_fake_postgres_store_satisfies_store_conformance(self):
        # The conformance helper receives a zero-argument factory rather than
        # a store instance (presumably so it can build fresh stores as needed).
        assert_user_engine_store_conformance(self, self._new_store)

    def test_ready_is_false_before_migration(self):
        fresh = self._new_store()
        # Nothing has been migrated on a new connection, so the store must
        # report not-ready and expose no schema version.
        self.assertFalse(fresh.ready)
        self.assertIsNone(fresh.schema_version)
|
|
|
|
|
|
class _FakePostgresConnection:
    """In-memory stand-in for the connection object PostgresUserEngineStore uses.

    Each "table" is a plain Python container. ``begin`` deep-copies all of
    them, ``rollback`` swaps that copy back in, and ``commit`` forgets it.
    """

    def __init__(self) -> None:
        # Schema versions recorded by the migration INSERT.
        self.schema_versions: set[str] = set()
        # Stored records keyed by (record_type, record_key).
        self.records: dict[tuple[str, str], StoreRecord] = {}
        # Decoded JSON payloads appended by the audit / outbox INSERTs.
        self.audit_payloads: list[dict[str, Any]] = []
        self.outbox_payloads: list[dict[str, Any]] = []
        # Deep copy of the four containers above taken by ``begin``;
        # ``None`` when no transaction is open.
        self._snapshot: tuple[
            set[str],
            dict[tuple[str, str], StoreRecord],
            list[dict[str, Any]],
            list[dict[str, Any]],
        ] | None = None

    def cursor(self) -> "_FakePostgresCursor":
        """Open a cursor that reads and writes this connection's containers."""
        return _FakePostgresCursor(self)

    def begin(self) -> None:
        """Start a transaction by deep-copying every container."""
        live_state = (
            self.schema_versions,
            self.records,
            self.audit_payloads,
            self.outbox_payloads,
        )
        # Each container is copied independently, in the order listed above.
        self._snapshot = tuple(copy.deepcopy(part) for part in live_state)

    def commit(self) -> None:
        """Keep the transaction's changes by discarding the snapshot."""
        self._snapshot = None

    def rollback(self) -> None:
        """Restore the state captured by ``begin``; a no-op outside a transaction."""
        saved = self._snapshot
        if saved is not None:
            (
                self.schema_versions,
                self.records,
                self.audit_payloads,
                self.outbox_payloads,
            ) = saved
            self._snapshot = None
|
|
|
|
|
|
class _FakePostgresCursor:
|
|
def __init__(self, connection: _FakePostgresConnection) -> None:
|
|
self.connection = connection
|
|
self._rows: list[Any] = []
|
|
|
|
def execute(self, sql: str, params: Iterable[Any] | None = None) -> None:
|
|
normalized = " ".join(sql.lower().split())
|
|
values = tuple(params or ())
|
|
|
|
if "insert into user_engine_schema_versions" in normalized:
|
|
self.connection.schema_versions.add(LATEST_SCHEMA_VERSION)
|
|
self._rows = []
|
|
return
|
|
if "from user_engine_schema_versions" in normalized:
|
|
self._rows = [(1,)] if values[0] in self.connection.schema_versions else []
|
|
return
|
|
if normalized.startswith("insert into user_engine_records"):
|
|
payload = json.loads(values[7])
|
|
record = StoreRecord(
|
|
record_type=values[0],
|
|
record_key=values[1],
|
|
tenant=values[2],
|
|
user_id=values[3],
|
|
application_id=values[4],
|
|
scope_type=values[5],
|
|
scope_id=values[6],
|
|
payload=payload,
|
|
)
|
|
self.connection.records[(record.record_type, record.record_key)] = record
|
|
self._rows = []
|
|
return
|
|
if "from user_engine_records" in normalized:
|
|
self._select_records(normalized, values)
|
|
return
|
|
if normalized.startswith("insert into user_engine_audit_records"):
|
|
self.connection.audit_payloads.append(json.loads(values[8]))
|
|
self._rows = []
|
|
return
|
|
if normalized.startswith("insert into user_engine_outbox_events"):
|
|
self.connection.outbox_payloads.append(json.loads(values[5]))
|
|
self._rows = []
|
|
return
|
|
if "from user_engine_audit_records" in normalized:
|
|
if "count(*)" in normalized:
|
|
self._rows = [(len(self.connection.audit_payloads),)]
|
|
else:
|
|
self._rows = [
|
|
(json.dumps(payload),) for payload in self.connection.audit_payloads
|
|
]
|
|
return
|
|
if "from user_engine_outbox_events" in normalized:
|
|
if "count(*)" in normalized:
|
|
self._rows = [(len(self.connection.outbox_payloads),)]
|
|
else:
|
|
self._rows = [
|
|
(json.dumps(payload),) for payload in self.connection.outbox_payloads
|
|
]
|
|
return
|
|
self._rows = []
|
|
|
|
def fetchone(self) -> Any | None:
|
|
return self._rows[0] if self._rows else None
|
|
|
|
def fetchall(self) -> list[Any]:
|
|
return self._rows
|
|
|
|
def close(self) -> None:
|
|
return None
|
|
|
|
def _select_records(self, normalized: str, values: tuple[Any, ...]) -> None:
|
|
if "group by record_type" in normalized:
|
|
counts: dict[str, int] = {}
|
|
for record_type, _record_key in self.connection.records:
|
|
counts[record_type] = counts.get(record_type, 0) + 1
|
|
self._rows = sorted(counts.items())
|
|
return
|
|
|
|
record_type = values[0]
|
|
filter_columns = [
|
|
column
|
|
for column in (
|
|
"record_key",
|
|
"tenant",
|
|
"user_id",
|
|
"application_id",
|
|
"scope_type",
|
|
"scope_id",
|
|
)
|
|
if f"{column} = %s" in normalized
|
|
]
|
|
filters = dict(zip(filter_columns, values[1:]))
|
|
rows = []
|
|
for (stored_type, _key), record in self.connection.records.items():
|
|
if stored_type != record_type:
|
|
continue
|
|
if any(getattr(record, column) != value for column, value in filters.items()):
|
|
continue
|
|
rows.append(
|
|
(
|
|
record.record_type,
|
|
record.record_key,
|
|
record.tenant,
|
|
record.user_id,
|
|
record.application_id,
|
|
record.scope_type,
|
|
record.scope_id,
|
|
json.dumps(record.payload),
|
|
)
|
|
)
|
|
self._rows = sorted(rows, key=lambda row: row[1])
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly with unittest's command-line runner.
    unittest.main()
|