feat: add postgres user engine store

This commit is contained in:
2026-06-16 07:14:37 +02:00
parent c494511a2e
commit 0d50ad294d
8 changed files with 939 additions and 3 deletions

View File

@@ -4,8 +4,10 @@ from user_engine.adapters.local import (
InMemoryUserEngineStore,
LocalAuthorizationCheckPort,
)
from user_engine.adapters.postgres import PostgresUserEngineStore
__all__ = [
"InMemoryUserEngineStore",
"LocalAuthorizationCheckPort",
"PostgresUserEngineStore",
]

View File

@@ -0,0 +1,626 @@
"""Postgres-backed store adapter.
The adapter is dependency-free: callers provide a DB-API or psycopg-like
connection object. Provider repositories remain responsible for creating,
pooling, securing, and observing those connections.
"""
from __future__ import annotations
import json
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Iterable, Iterator, Mapping, Protocol, cast
from user_engine.domain import (
Account,
AccessProfile,
ActiveAccessContext,
Application,
ApplicationBinding,
AuditRecord,
Catalog,
ExternalIdentity,
FamilyInvitation,
IdentityFactor,
Membership,
OnboardingJourney,
OutboxEvent,
PreparedAccount,
ProfileValue,
RegistrationSession,
TenantAccount,
User,
WelcomeProtocol,
)
from user_engine.migrations import LATEST_SCHEMA_VERSION, USER_ENGINE_RECORD_COUNT_KEYS
from user_engine.store_records import (
StoreRecord,
composite_record_key,
domain_record_from_store_record,
store_record_for,
)
_RECORD_COLUMNS = (
"record_type, record_key, tenant, user_id, application_id, "
"scope_type, scope_id, payload"
)
_RECORD_COUNT_KEY_BY_TYPE = {
"application_bindings": "bindings",
}
class PostgresCursor(Protocol):
def execute(self, sql: str, params: Iterable[Any] | None = None) -> Any:
"""Execute a SQL statement."""
def fetchone(self) -> Any | None:
"""Fetch one row from the previous query."""
def fetchall(self) -> Iterable[Any]:
"""Fetch all rows from the previous query."""
def close(self) -> Any:
"""Close the cursor."""
class PostgresConnection(Protocol):
def cursor(self) -> PostgresCursor:
"""Return a DB-API-like cursor."""
def commit(self) -> Any:
"""Commit the current transaction."""
def rollback(self) -> Any:
"""Roll back the current transaction."""
class PostgresUserEngineStore:
"""Postgres implementation of the `UserEngineStore` protocol."""
def __init__(self, connection: PostgresConnection) -> None:
self.connection = connection
@property
def schema_version(self) -> str | None:
return LATEST_SCHEMA_VERSION if self._has_latest_schema() else None
@property
def ready(self) -> bool:
return self.schema_version == LATEST_SCHEMA_VERSION
def migrate(self) -> None:
sql = _load_bootstrap_sql()
with self._cursor() as cursor:
cursor.execute(sql)
self.connection.commit()
@contextmanager
def transaction(self) -> Iterator[None]:
begin = getattr(self.connection, "begin", None)
if callable(begin):
begin()
try:
yield
except Exception:
self.connection.rollback()
raise
else:
self.connection.commit()
def save_user(self, user: User) -> None:
self._upsert_record(user)
def user(self, user_id: str) -> User | None:
return cast(User | None, self._get_record("users", user_id))
def save_account(self, account: Account) -> None:
self._upsert_record(account)
def user_account(self, user_id: str) -> Account | None:
return cast(Account | None, self._get_record("accounts", user_id))
def save_identity(self, identity: ExternalIdentity) -> None:
self._upsert_record(identity)
def find_identity(self, issuer: str, subject: str) -> ExternalIdentity | None:
key = composite_record_key(issuer, subject)
return cast(ExternalIdentity | None, self._get_record("external_identities", key))
def identities_for_user(self, user_id: str) -> tuple[ExternalIdentity, ...]:
return cast(
tuple[ExternalIdentity, ...],
self._query_records("external_identities", user_id=user_id),
)
def save_tenant_account(self, account: TenantAccount) -> None:
self._upsert_record(account)
def tenant_account(self, tenant: str, user_id: str) -> TenantAccount | None:
key = composite_record_key(tenant, user_id)
return cast(TenantAccount | None, self._get_record("tenant_accounts", key))
def save_membership(self, membership: Membership) -> None:
self._upsert_record(membership)
def memberships_for_user(
self, user_id: str, *, tenant: str | None = None
) -> tuple[Membership, ...]:
return cast(
tuple[Membership, ...],
self._query_records("memberships", user_id=user_id, tenant=tenant),
)
def memberships_for_tenant(self, tenant: str) -> tuple[Membership, ...]:
return cast(
tuple[Membership, ...],
self._query_records("memberships", tenant=tenant),
)
def save_application(self, application: Application) -> None:
self._upsert_record(application)
def application(self, application_id: str) -> Application | None:
return cast(Application | None, self._get_record("applications", application_id))
def save_binding(self, binding: ApplicationBinding) -> None:
self._upsert_record(binding)
def binding(self, application_id: str) -> ApplicationBinding | None:
return cast(
ApplicationBinding | None,
self._get_record("application_bindings", application_id),
)
def save_catalog(self, catalog: Catalog) -> None:
self._upsert_record(catalog)
def catalog(self, catalog_id: str) -> Catalog | None:
return cast(Catalog | None, self._get_record("catalogs", catalog_id))
def all_catalogs(self) -> tuple[Catalog, ...]:
return cast(tuple[Catalog, ...], self._query_records("catalogs"))
def save_family_invitation(self, invitation: FamilyInvitation) -> None:
self._upsert_record(invitation)
def family_invitation(self, invitation_id: str) -> FamilyInvitation | None:
return cast(
FamilyInvitation | None,
self._get_record("family_invitations", invitation_id),
)
def family_invitations_for_user(
self, user_id: str
) -> tuple[FamilyInvitation, ...]:
return cast(
tuple[FamilyInvitation, ...],
self._query_records("family_invitations", user_id=user_id),
)
def save_registration_session(self, session: RegistrationSession) -> None:
self._upsert_record(session)
def registration_session(
self, registration_id: str
) -> RegistrationSession | None:
return cast(
RegistrationSession | None,
self._get_record("registration_sessions", registration_id),
)
def all_registration_sessions(self) -> tuple[RegistrationSession, ...]:
return cast(
tuple[RegistrationSession, ...],
self._query_records("registration_sessions"),
)
def save_identity_factor(self, factor: IdentityFactor) -> None:
self._upsert_record(factor)
def identity_factor(self, factor_id: str) -> IdentityFactor | None:
return cast(
IdentityFactor | None,
self._get_record("identity_factors", factor_id),
)
def factors_for_registration(
self, registration_id: str
) -> tuple[IdentityFactor, ...]:
return cast(
tuple[IdentityFactor, ...],
self._query_records(
"identity_factors",
scope_type="registration",
scope_id=registration_id,
),
)
def factors_for_user(self, user_id: str) -> tuple[IdentityFactor, ...]:
return cast(
tuple[IdentityFactor, ...],
self._query_records("identity_factors", user_id=user_id),
)
def save_prepared_account(self, account: PreparedAccount) -> None:
self._upsert_record(account)
def prepared_account(self, prepared_account_id: str) -> PreparedAccount | None:
return cast(
PreparedAccount | None,
self._get_record("prepared_accounts", prepared_account_id),
)
def prepared_accounts_for_tenant(
self, tenant: str
) -> tuple[PreparedAccount, ...]:
return cast(
tuple[PreparedAccount, ...],
self._query_records("prepared_accounts", tenant=tenant),
)
def save_access_profile(self, profile: AccessProfile) -> None:
self._upsert_record(profile)
def access_profile(self, access_profile_id: str) -> AccessProfile | None:
return cast(
AccessProfile | None,
self._get_record("access_profiles", access_profile_id),
)
def access_profiles_for_tenant(self, tenant: str) -> tuple[AccessProfile, ...]:
return cast(
tuple[AccessProfile, ...],
self._query_records("access_profiles", tenant=tenant),
)
def save_active_access_context(self, context: ActiveAccessContext) -> None:
self._upsert_record(context)
def active_access_context(
self, user_id: str, tenant: str
) -> ActiveAccessContext | None:
key = composite_record_key(user_id, tenant)
return cast(
ActiveAccessContext | None,
self._get_record("active_access_contexts", key),
)
def active_access_contexts_for_tenant(
self, tenant: str
) -> tuple[ActiveAccessContext, ...]:
return cast(
tuple[ActiveAccessContext, ...],
self._query_records("active_access_contexts", tenant=tenant),
)
def save_welcome_protocol(self, protocol: WelcomeProtocol) -> None:
self._upsert_record(protocol)
def welcome_protocol(self, protocol_id: str) -> WelcomeProtocol | None:
return cast(
WelcomeProtocol | None,
self._get_record("welcome_protocols", protocol_id),
)
def welcome_protocols_for_tenant(
self, tenant: str
) -> tuple[WelcomeProtocol, ...]:
return cast(
tuple[WelcomeProtocol, ...],
self._query_records("welcome_protocols", tenant=tenant),
)
def save_onboarding_journey(self, journey: OnboardingJourney) -> None:
self._upsert_record(journey)
def onboarding_journey(self, journey_id: str) -> OnboardingJourney | None:
return cast(
OnboardingJourney | None,
self._get_record("onboarding_journeys", journey_id),
)
def onboarding_journeys_for_user(
self, user_id: str, *, tenant: str | None = None
) -> tuple[OnboardingJourney, ...]:
return cast(
tuple[OnboardingJourney, ...],
self._query_records("onboarding_journeys", user_id=user_id, tenant=tenant),
)
def onboarding_journeys_for_tenant(
self, tenant: str
) -> tuple[OnboardingJourney, ...]:
return cast(
tuple[OnboardingJourney, ...],
self._query_records("onboarding_journeys", tenant=tenant),
)
def save_profile_value(self, value: ProfileValue) -> None:
self._upsert_record(value)
def values_for_user(self, user_id: str) -> tuple[ProfileValue, ...]:
return cast(
tuple[ProfileValue, ...],
self._query_records("profile_values", user_id=user_id),
)
def append_audit(self, record: AuditRecord) -> None:
store_record = store_record_for(record)
with self._cursor() as cursor:
cursor.execute(
"""
INSERT INTO user_engine_audit_records (
audit_id, tenant, actor_issuer, actor_subject, action,
subject, correlation_id, summary, payload
)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s::jsonb)
ON CONFLICT (audit_id) DO NOTHING
""",
(
record.audit_id,
record.tenant,
record.actor.issuer,
record.actor.subject,
record.action,
record.subject,
record.correlation_id,
record.summary,
json.dumps(store_record.payload),
),
)
def audit_log(self) -> tuple[AuditRecord, ...]:
with self._cursor() as cursor:
cursor.execute(
"""
SELECT payload
FROM user_engine_audit_records
ORDER BY recorded_at, audit_id
"""
)
return tuple(
cast(AuditRecord, self._decode_payload_row("audit_records", row))
for row in cursor.fetchall()
)
def append_outbox(self, event: OutboxEvent) -> None:
store_record = store_record_for(event)
with self._cursor() as cursor:
cursor.execute(
"""
INSERT INTO user_engine_outbox_events (
event_id, tenant, event_type, aggregate_id, correlation_id,
payload, occurred_at
)
VALUES (%s, %s, %s, %s, %s, %s::jsonb, %s)
ON CONFLICT (event_id) DO NOTHING
""",
(
event.event_id,
event.tenant,
event.event_type,
event.aggregate_id,
event.correlation_id,
json.dumps(store_record.payload),
event.occurred_at,
),
)
def pending_outbox(self) -> tuple[OutboxEvent, ...]:
with self._cursor() as cursor:
cursor.execute(
"""
SELECT payload
FROM user_engine_outbox_events
WHERE claimed_at IS NULL AND delivered_at IS NULL
ORDER BY occurred_at, event_id
"""
)
return tuple(
cast(OutboxEvent, self._decode_payload_row("outbox_events", row))
for row in cursor.fetchall()
)
def record_counts(self) -> Mapping[str, int]:
counts = {key: 0 for key in USER_ENGINE_RECORD_COUNT_KEYS}
with self._cursor() as cursor:
cursor.execute(
"""
SELECT record_type, COUNT(*)
FROM user_engine_records
GROUP BY record_type
"""
)
for record_type, count in cursor.fetchall():
key = _RECORD_COUNT_KEY_BY_TYPE.get(record_type, record_type)
if key in counts:
counts[key] = int(count)
cursor.execute("SELECT COUNT(*) FROM user_engine_audit_records")
counts["audit_records"] = int(_first_column(cursor.fetchone()) or 0)
cursor.execute(
"""
SELECT COUNT(*)
FROM user_engine_outbox_events
WHERE claimed_at IS NULL AND delivered_at IS NULL
"""
)
counts["pending_outbox_events"] = int(
_first_column(cursor.fetchone()) or 0
)
return counts
def _upsert_record(self, value: Any) -> None:
record = store_record_for(value)
with self._cursor() as cursor:
cursor.execute(
f"""
INSERT INTO user_engine_records (
record_type, record_key, tenant, user_id, application_id,
scope_type, scope_id, payload
)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s::jsonb)
ON CONFLICT (record_type, record_key) DO UPDATE SET
tenant = EXCLUDED.tenant,
user_id = EXCLUDED.user_id,
application_id = EXCLUDED.application_id,
scope_type = EXCLUDED.scope_type,
scope_id = EXCLUDED.scope_id,
payload = EXCLUDED.payload,
updated_at = now()
""",
(
record.record_type,
record.record_key,
record.tenant,
record.user_id,
record.application_id,
record.scope_type,
record.scope_id,
json.dumps(record.payload),
),
)
def _get_record(self, record_type: str, record_key: str) -> Any | None:
with self._cursor() as cursor:
cursor.execute(
f"""
SELECT {_RECORD_COLUMNS}
FROM user_engine_records
WHERE record_type = %s AND record_key = %s
""",
(record_type, record_key),
)
row = cursor.fetchone()
if row is None:
return None
return domain_record_from_store_record(_store_record_from_row(row))
def _query_records(
self,
record_type: str,
*,
tenant: str | None = None,
user_id: str | None = None,
application_id: str | None = None,
scope_type: str | None = None,
scope_id: str | None = None,
) -> tuple[Any, ...]:
filters = {
"tenant": tenant,
"user_id": user_id,
"application_id": application_id,
"scope_type": scope_type,
"scope_id": scope_id,
}
clauses = ["record_type = %s"]
params: list[Any] = [record_type]
for column, value in filters.items():
if value is not None:
clauses.append(f"{column} = %s")
params.append(value)
where_clause = " AND ".join(clauses)
with self._cursor() as cursor:
cursor.execute(
f"""
SELECT {_RECORD_COLUMNS}
FROM user_engine_records
WHERE {where_clause}
ORDER BY record_key
""",
tuple(params),
)
rows = cursor.fetchall()
return tuple(
domain_record_from_store_record(_store_record_from_row(row))
for row in rows
)
def _decode_payload_row(self, record_type: str, row: Any) -> Any:
payload = _first_column(row)
if isinstance(payload, str):
payload = json.loads(payload)
record_key = str(cast(Mapping[str, Any], payload).get("event_id") or "")
if record_type == "audit_records":
record_key = str(cast(Mapping[str, Any], payload).get("audit_id") or "")
return domain_record_from_store_record(
StoreRecord(record_type=record_type, record_key=record_key, payload=payload)
)
def _has_latest_schema(self) -> bool:
with self._cursor() as cursor:
cursor.execute(
"""
SELECT 1
FROM user_engine_schema_versions
WHERE version = %s
""",
(LATEST_SCHEMA_VERSION,),
)
return cursor.fetchone() is not None
@contextmanager
def _cursor(self) -> Iterator[PostgresCursor]:
cursor = self.connection.cursor()
try:
yield cursor
finally:
close = getattr(cursor, "close", None)
if callable(close):
close()
def _store_record_from_row(row: Any) -> StoreRecord:
if isinstance(row, Mapping):
payload = row["payload"]
if isinstance(payload, str):
payload = json.loads(payload)
return StoreRecord(
record_type=str(row["record_type"]),
record_key=str(row["record_key"]),
tenant=cast(str | None, row.get("tenant")),
user_id=cast(str | None, row.get("user_id")),
application_id=cast(str | None, row.get("application_id")),
scope_type=cast(str | None, row.get("scope_type")),
scope_id=cast(str | None, row.get("scope_id")),
payload=payload,
)
(
record_type,
record_key,
tenant,
user_id,
application_id,
scope_type,
scope_id,
payload,
) = row
if isinstance(payload, str):
payload = json.loads(payload)
return StoreRecord(
record_type=str(record_type),
record_key=str(record_key),
tenant=tenant,
user_id=user_id,
application_id=application_id,
scope_type=scope_type,
scope_id=scope_id,
payload=payload,
)
def _first_column(row: Any) -> Any:
if row is None:
return None
if isinstance(row, Mapping):
return next(iter(row.values()))
return row[0]
def _load_bootstrap_sql() -> str:
repo_root = Path(__file__).resolve().parents[3]
return (repo_root / "migrations/postgres/0001_user_engine_store.sql").read_text(
encoding="utf-8"
)

View File

@@ -77,6 +77,11 @@ def store_record_for(value: Any) -> StoreRecord:
)
def composite_record_key(*parts: str | None) -> str:
"""Return the deterministic composite key used by durable store records."""
return _composite_key(*parts)
def domain_record_from_store_record(record: StoreRecord) -> Any:
"""Decode a durable-store record payload into its domain dataclass."""
codec = _CODECS_BY_RECORD_TYPE.get(record.record_type)
@@ -271,7 +276,11 @@ _CODECS = (
"identity_factors",
IdentityFactor,
lambda value: _single_key(value.factor_id),
lambda value: {"user_id": value.user_id},
lambda value: {
"user_id": value.user_id,
"scope_type": "registration" if value.registration_id else None,
"scope_id": value.registration_id,
},
),
StoreRecordCodec(
"prepared_accounts",