"""Temporal activity definitions for activity-core. Activities run inside a Worker bound to 'orchestrator-tq'. Each function is decorated with @activity.defn and executed by RunActivityWorkflow via workflow.execute_activity(). DB access pattern: worker.py calls init_session_factory(url) once before starting workers, which sets the module-level _session_factory used by activities that need DB access. """ from __future__ import annotations import uuid from datetime import datetime, timezone from sqlalchemy import select from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from temporalio import activity from temporalio.exceptions import ApplicationError from activity_core.db import make_engine from activity_core.issue_sink import get_issue_sink from activity_core.orm import ActivityDefinition as ActivityDefinitionRow from activity_core.orm import ActivityRun, TaskInstance, TaskSpawnLog from activity_core.rules import evaluate_condition from activity_core.llm_client import get_llm_client from activity_core.models import InstructionDef from activity_core.report_sinks import persist_reports from activity_core.rules.executor import execute_instruction_with_audit _session_factory: async_sessionmaker[AsyncSession] | None = None def init_session_factory(url: str) -> None: """Initialise the shared DB session factory. Must be called once from worker.py before workers are started. """ global _session_factory _session_factory = async_sessionmaker(make_engine(url), expire_on_commit=False) def _get_session_factory() -> async_sessionmaker[AsyncSession]: if _session_factory is None: raise RuntimeError( "DB session factory not initialised — call init_session_factory() first" ) return _session_factory # ── Activities ───────────────────────────────────────────────────────────────── @activity.defn async def load_activity_definition(activity_id: str) -> dict: """Load an ActivityDefinition row from Postgres by ID. Returns a JSON-serialisable dict suitable for passing between Temporal workflow steps. Raises: ApplicationError (non-retryable): if no row exists for activity_id. """ Session = _get_session_factory() async with Session() as session: row = await session.scalar( select(ActivityDefinitionRow).where( ActivityDefinitionRow.id == uuid.UUID(activity_id) ) ) if row is None: raise ApplicationError( f"ActivityDefinition {activity_id!r} not found", non_retryable=True, ) return { "id": str(row.id), "name": row.name, "enabled": row.enabled, "trigger_type": row.trigger_type, "trigger_config": row.trigger_config, "context_sources": row.context_sources, "task_templates": row.task_templates, "rules": row.rules_json, "instructions": row.instructions_json, "dedupe_key_strategy": row.dedupe_key_strategy, "version": row.version, } @activity.defn async def resolve_context( context_sources: list[dict], event_envelope_json: str | None = None, ) -> dict: """Resolve each context source and merge into a snapshot dict. Returns: {bind_key: resolved_value, ...} Source types are dispatched via CONTEXT_RESOLVER_REGISTRY. A resolver that raises logs a warning and binds {} — it does not abort the run. The 'static' type is handled inline without a registry entry. """ import activity_core.context_resolvers # noqa: F401 — registers all adapters from activity_core.context_resolvers.base import CONTEXT_RESOLVER_REGISTRY snapshot: dict = {} for source in context_sources: source_type = source.get("type", "") query = source.get("query", "") params = source.get("params") or {} raw_bind = source.get("bind_to") or source.get("name") or source_type # Strip the 'context.' namespace prefix so evaluator can find the key. bind_key = raw_bind.removeprefix("context.") if raw_bind.startswith("context.") else raw_bind if source_type == "static": snapshot[bind_key] = source.get("config", {}).get("value") continue resolver_cls = CONTEXT_RESOLVER_REGISTRY.get(source_type) if resolver_cls is None: activity.logger.warning( "Unknown context source type %r — binding {}", source_type, ) snapshot[bind_key] = {} continue try: snapshot[bind_key] = resolver_cls().resolve(query, None, params) except Exception as exc: activity.logger.warning( "Context resolver %r failed — %s; binding {}", source_type, exc, ) snapshot[bind_key] = {} return snapshot @activity.defn async def log_run(run_payload: dict) -> str: """Persist an ActivityRun record to Postgres and return its run_id. Idempotent: uses INSERT … ON CONFLICT (run_id) DO NOTHING so Temporal activity retries do not produce duplicate rows. Expected keys in run_payload: run_id (str UUID — computed deterministically in workflow) activity_id (str UUID) scheduled_for (ISO-8601 str or None) context_snapshot (dict) tasks_spawned (int) version_used (int) Returns: run_id as a str UUID. """ Session = _get_session_factory() run_id = uuid.UUID(run_payload["run_id"]) scheduled_for: datetime | None = None if run_payload.get("scheduled_for"): scheduled_for = datetime.fromisoformat(run_payload["scheduled_for"]) stmt = ( pg_insert(ActivityRun) .values( run_id=run_id, activity_id=uuid.UUID(run_payload["activity_id"]), scheduled_for=scheduled_for, fired_at=datetime.now(tz=timezone.utc), context_snapshot=run_payload["context_snapshot"], tasks_spawned=run_payload["tasks_spawned"], version_used=run_payload["version_used"], ) .on_conflict_do_nothing(index_elements=["run_id"]) ) async with Session() as session: async with session.begin(): await session.execute(stmt) return str(run_id) @activity.defn async def persist_task_instance(task_payload: dict) -> str: """Write a TaskInstance row and return its id. Idempotent: uses INSERT … ON CONFLICT (id) DO NOTHING. Expected keys in task_payload: id (str UUID — deterministic, computed in TaskExecutorWorkflow) run_id (str UUID) type (str) params (dict) status (str, default "done" for stub) Returns: task instance id as a str UUID. """ Session = _get_session_factory() task_id = uuid.UUID(task_payload["id"]) stmt = ( pg_insert(TaskInstance) .values( id=task_id, run_id=uuid.UUID(task_payload["run_id"]), type=task_payload["type"], params=task_payload.get("params", {}), status=task_payload.get("status", "done"), ) .on_conflict_do_nothing(index_elements=["id"]) ) async with Session() as session: async with session.begin(): await session.execute(stmt) return str(task_id) @activity.defn async def evaluate_rules(payload: dict) -> list[dict]: """Evaluate each rule condition against the event and context. Returns the list of matching rule dicts (those whose condition is True). Rules that raise UnsafeExpression or any other error are skipped and logged. Expected keys in payload: rules list[dict] — RuleDef serialised dicts event dict — EventEnvelope attributes (or empty for cron) context dict — context snapshot from resolve_context """ from activity_core.rules.evaluator import UnsafeExpression rules = payload.get("rules", []) event_attrs = payload.get("event", {}) context = payload.get("context", {}) # Build a simple object whose attributes mirror event fields for the evaluator. class _Env: def __init__(self, attrs: dict) -> None: self.attributes = _DictObj(attrs) class _DictObj: def __init__(self, d: dict) -> None: self.__dict__.update(d) event_obj = _Env(event_attrs) matched: list[dict] = [] for rule in rules: condition = rule.get("condition", "") try: if evaluate_condition(condition, event_obj, context): matched.append(rule) except UnsafeExpression as exc: activity.logger.warning("rule %r unsafe expression — skipping: %s", rule.get("id"), exc) except Exception as exc: activity.logger.warning("rule %r eval error — skipping: %s", rule.get("id"), exc) return matched @activity.defn async def evaluate_instructions(payload: dict) -> dict: """Evaluate instruction blocks and return task specs/reports with audit fields. Expected keys in payload: instructions list[dict] — InstructionDef serialised dicts event dict — EventEnvelope attributes (or empty for cron) context dict — context snapshot from resolve_context """ instructions = payload.get("instructions", []) event_attrs = payload.get("event", {}) context = payload.get("context", {}) llm_client = get_llm_client() class _Env: def __init__(self, attrs: dict) -> None: self.attributes = _DictObj(attrs) class _DictObj: def __init__(self, d: dict) -> None: self.__dict__.update(d) event_obj = _Env(event_attrs) task_specs: list[dict] = [] reports: list[dict] = [] for raw_instruction in instructions: try: instruction = InstructionDef.model_validate(raw_instruction) except Exception as exc: activity.logger.warning("instruction definition invalid — %s", exc) continue result = execute_instruction_with_audit( instruction, event_obj, context, llm_client, ) if result.report is not None: reports.append({ "instruction_id": instruction.id, "report": result.report, "sinks": instruction.report_sinks, "condition": result.condition_matched, "prompt_hash": result.prompt_hash, "model": result.model, "output_validated": result.output_validated, "review_required": result.review_required, }) for spec in result.tasks: task_specs.append({ "title": spec.title, "description": spec.description, "target_repo": spec.target_repo, "priority": spec.priority, "labels": spec.labels, "due_in_days": spec.due_in_days, "source_type": "instruction", "source_id": instruction.id, "condition": result.condition_matched, "prompt_hash": result.prompt_hash, "model": result.model, "output_validated": result.output_validated, "review_required": result.review_required, }) return {"task_specs": task_specs, "reports": reports} @activity.defn async def persist_instruction_reports(payload: dict) -> list[dict]: """Persist report payloads to deterministic configured sinks.""" return persist_reports(payload) @activity.defn async def emit_tasks(payload: dict) -> list[str]: """Emit TaskSpecs to IssueSink and write task_spawn_log rows. Returns list of external task ref IDs. Expected keys in payload: task_specs list[dict] — from evaluate_rules matched actions activity_id str — UUID of the ActivityDefinition triggering_event_id str — event ID or workflow ID for cron run_id str — UUID of the ActivityRun """ from activity_core.rules.models import TaskSpec task_specs_raw = payload.get("task_specs", []) activity_id = payload.get("activity_id", "") triggering_event_id = payload.get("triggering_event_id", "") sink = get_issue_sink() Session = _get_session_factory() refs: list[str] = [] async with Session() as session: async with session.begin(): for spec_dict in task_specs_raw: spec = TaskSpec( title=spec_dict.get("title", ""), description=spec_dict.get("description", ""), target_repo=spec_dict.get("target_repo"), priority=spec_dict.get("priority", "medium"), labels=spec_dict.get("labels", []), due_in_days=spec_dict.get("due_in_days"), source_type=spec_dict.get("source_type", "rule"), source_id=spec_dict.get("source_id", ""), triggering_event_id=triggering_event_id, activity_definition_id=activity_id, ) try: ref = sink.emit(spec) refs.append(ref.external_id) log_row = TaskSpawnLog( activity_def_id=uuid.UUID(activity_id), source_type=spec.source_type, source_id=spec.source_id, source_version="1", triggering_event_id=triggering_event_id, task_ref=ref.external_id, condition_matched=spec_dict.get("condition"), prompt_hash=spec_dict.get("prompt_hash"), model=spec_dict.get("model"), output_validated=spec_dict.get("output_validated"), review_required=spec_dict.get("review_required"), ) session.add(log_row) except Exception as exc: activity.logger.warning("emit_tasks: sink.emit failed — %s", exc) return refs