"""Temporal activity definitions for activity-core. Activities run inside a Worker bound to 'orchestrator-tq'. Each function is decorated with @activity.defn and executed by RunActivityWorkflow via workflow.execute_activity(). DB access pattern: worker.py calls init_session_factory(url) once before starting workers, which sets the module-level _session_factory used by activities that need DB access. """ from __future__ import annotations import uuid from datetime import datetime, timezone from sqlalchemy import select from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from temporalio import activity from temporalio.exceptions import ApplicationError from activity_core.db import make_engine from activity_core.orm import ActivityDefinition as ActivityDefinitionRow from activity_core.orm import ActivityRun, TaskInstance _session_factory: async_sessionmaker[AsyncSession] | None = None def init_session_factory(url: str) -> None: """Initialise the shared DB session factory. Must be called once from worker.py before workers are started. """ global _session_factory _session_factory = async_sessionmaker(make_engine(url), expire_on_commit=False) def _get_session_factory() -> async_sessionmaker[AsyncSession]: if _session_factory is None: raise RuntimeError( "DB session factory not initialised — call init_session_factory() first" ) return _session_factory # ── Activities ───────────────────────────────────────────────────────────────── @activity.defn async def load_activity_definition(activity_id: str) -> dict: """Load an ActivityDefinition row from Postgres by ID. Returns a JSON-serialisable dict suitable for passing between Temporal workflow steps. Raises: ApplicationError (non-retryable): if no row exists for activity_id. """ Session = _get_session_factory() async with Session() as session: row = await session.scalar( select(ActivityDefinitionRow).where( ActivityDefinitionRow.id == uuid.UUID(activity_id) ) ) if row is None: raise ApplicationError( f"ActivityDefinition {activity_id!r} not found", non_retryable=True, ) return { "id": str(row.id), "name": row.name, "enabled": row.enabled, "trigger_type": row.trigger_type, "trigger_config": row.trigger_config, "context_sources": row.context_sources, "task_templates": row.task_templates, "dedupe_key_strategy": row.dedupe_key_strategy, "version": row.version, } @activity.defn async def resolve_context(context_sources: list[dict]) -> dict: """Resolve each context source and merge into a snapshot dict. Returns: {source.name: resolved_value, ...} Supported source types: static — returns config["value"] directly http_get — not yet implemented db_query — not yet implemented """ snapshot: dict = {} for source in context_sources: name = source["name"] source_type = source["type"] config = source.get("config", {}) if source_type == "static": snapshot[name] = config.get("value") elif source_type in ("http_get", "db_query"): raise ApplicationError( f"Context source type {source_type!r} is not yet implemented", non_retryable=True, ) else: raise ApplicationError( f"Unknown context source type {source_type!r}", non_retryable=True, ) return snapshot @activity.defn async def log_run(run_payload: dict) -> str: """Persist an ActivityRun record to Postgres and return its run_id. Idempotent: uses INSERT … ON CONFLICT (run_id) DO NOTHING so Temporal activity retries do not produce duplicate rows. Expected keys in run_payload: run_id (str UUID — computed deterministically in workflow) activity_id (str UUID) scheduled_for (ISO-8601 str or None) context_snapshot (dict) tasks_spawned (int) version_used (int) Returns: run_id as a str UUID. """ Session = _get_session_factory() run_id = uuid.UUID(run_payload["run_id"]) scheduled_for: datetime | None = None if run_payload.get("scheduled_for"): scheduled_for = datetime.fromisoformat(run_payload["scheduled_for"]) stmt = ( pg_insert(ActivityRun) .values( run_id=run_id, activity_id=uuid.UUID(run_payload["activity_id"]), scheduled_for=scheduled_for, fired_at=datetime.now(tz=timezone.utc), context_snapshot=run_payload["context_snapshot"], tasks_spawned=run_payload["tasks_spawned"], version_used=run_payload["version_used"], ) .on_conflict_do_nothing(index_elements=["run_id"]) ) async with Session() as session: async with session.begin(): await session.execute(stmt) return str(run_id) @activity.defn async def persist_task_instance(task_payload: dict) -> str: """Write a TaskInstance row and return its id. Idempotent: uses INSERT … ON CONFLICT (id) DO NOTHING. Expected keys in task_payload: id (str UUID — deterministic, computed in TaskExecutorWorkflow) run_id (str UUID) type (str) params (dict) status (str, default "done" for stub) Returns: task instance id as a str UUID. """ Session = _get_session_factory() task_id = uuid.UUID(task_payload["id"]) stmt = ( pg_insert(TaskInstance) .values( id=task_id, run_id=uuid.UUID(task_payload["run_id"]), type=task_payload["type"], params=task_payload.get("params", {}), status=task_payload.get("status", "done"), ) .on_conflict_do_nothing(index_elements=["id"]) ) async with Session() as session: async with session.begin(): await session.execute(stmt) return str(task_id)