"""Temporal activity definitions for activity-core. Activities run inside a Worker bound to 'orchestrator-tq'. Each function is decorated with @activity.defn and executed by RunActivityWorkflow via workflow.execute_activity(). DB access pattern: worker.py calls init_session_factory(url) once before starting workers, which sets the module-level _session_factory used by activities that need DB access. """ from __future__ import annotations import json import uuid from datetime import datetime, timezone from typing import Any from sqlalchemy import select from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from temporalio import activity from temporalio.exceptions import ApplicationError from activity_core.db import make_engine from activity_core.issue_sink import get_issue_sink from activity_core.orm import ActivityDefinition as ActivityDefinitionRow from activity_core.orm import ActivityRun, TaskInstance, TaskSpawnLog from activity_core.llm_client import get_llm_client from activity_core.models import InstructionDef from activity_core.ops_evidence_sinks import persist_ops_inventory_evidence from activity_core.report_sinks import persist_reports from activity_core.rules.actions import expand_rule_actions from activity_core.rules.executor import execute_instruction_with_audit _session_factory: async_sessionmaker[AsyncSession] | None = None def init_session_factory(url: str) -> None: """Initialise the shared DB session factory. Must be called once from worker.py before workers are started. """ global _session_factory _session_factory = async_sessionmaker(make_engine(url), expire_on_commit=False) def _get_session_factory() -> async_sessionmaker[AsyncSession]: if _session_factory is None: raise RuntimeError( "DB session factory not initialised — call init_session_factory() first" ) return _session_factory def _bind_resolver_result(bind_key: str, result: Any) -> Any: """Unwrap single-key resolver payloads when the key matches bind_key. Resolvers such as ``discover_kaizen_projects`` return ``{"projects": [...]}`` while definitions bind to ``context.projects`` and iterate ``for_each: context.projects``. Multi-key summaries (e.g. repo SBOM bulk) stay intact. """ if isinstance(result, dict) and len(result) == 1 and bind_key in result: return result[bind_key] return result def _parse_event_envelope(event_envelope_json: str | None) -> dict[str, Any] | None: """Parse an event envelope JSON string for context resolvers.""" if not event_envelope_json: return None try: payload = json.loads(event_envelope_json) except (TypeError, json.JSONDecodeError) as exc: activity.logger.warning("Invalid event envelope JSON - %s", exc) return None if not isinstance(payload, dict): activity.logger.warning( "Invalid event envelope JSON - expected object, got %s", type(payload).__name__, ) return None return payload # ── Activities ───────────────────────────────────────────────────────────────── @activity.defn async def load_activity_definition(activity_id: str) -> dict: """Load an ActivityDefinition row from Postgres by ID. Returns a JSON-serialisable dict suitable for passing between Temporal workflow steps. Raises: ApplicationError (non-retryable): if no row exists for activity_id. """ Session = _get_session_factory() async with Session() as session: row = await session.scalar( select(ActivityDefinitionRow).where( ActivityDefinitionRow.id == uuid.UUID(activity_id) ) ) if row is None: raise ApplicationError( f"ActivityDefinition {activity_id!r} not found", non_retryable=True, ) return { "id": str(row.id), "name": row.name, "enabled": row.enabled, "trigger_type": row.trigger_type, "trigger_config": row.trigger_config, "context_sources": row.context_sources, "task_templates": row.task_templates, "rules": row.rules_json, "instructions": row.instructions_json, "dedupe_key_strategy": row.dedupe_key_strategy, "version": row.version, } @activity.defn async def resolve_context( context_sources: list[dict], event_envelope_json: str | None = None, ) -> dict: """Resolve each context source and merge into a snapshot dict. Returns: {bind_key: resolved_value, ...} Source types are dispatched via CONTEXT_RESOLVER_REGISTRY. A resolver that raises logs a warning and binds {} unless the context source is marked required, in which case the activity fails visibly. The 'static' type is handled inline without a registry entry. """ import activity_core.context_resolvers # noqa: F401 — registers all adapters from activity_core.context_resolvers.base import CONTEXT_RESOLVER_REGISTRY snapshot: dict = {} event_envelope = _parse_event_envelope(event_envelope_json) for source in context_sources: source_type = source.get("type", "") query = source.get("query", "") params = source.get("params") or {} required = bool(source.get("required") or params.get("required", False)) resolver_params = dict(params) resolver_params["required"] = required raw_bind = source.get("bind_to") or source.get("name") or source_type # Strip the 'context.' namespace prefix so evaluator can find the key. bind_key = raw_bind.removeprefix("context.") if raw_bind.startswith("context.") else raw_bind if source_type == "static": snapshot[bind_key] = source.get("config", {}).get("value") continue resolver_cls = CONTEXT_RESOLVER_REGISTRY.get(source_type) if resolver_cls is None: if required: raise ApplicationError( f"Required context source type {source_type!r} is not registered", non_retryable=True, ) activity.logger.warning( "Unknown context source type %r — binding {}", source_type, ) snapshot[bind_key] = {} continue try: resolved = resolver_cls().resolve(query, event_envelope, resolver_params) snapshot[bind_key] = _bind_resolver_result(bind_key, resolved) except Exception as exc: if required: raise ApplicationError( f"Required context resolver {source_type!r}/{query!r} failed: {exc}" ) from exc activity.logger.warning( "Context resolver %r failed — %s; binding {}", source_type, exc, ) snapshot[bind_key] = {} return snapshot @activity.defn async def log_run(run_payload: dict) -> str: """Persist an ActivityRun record to Postgres and return its run_id. Idempotent: uses INSERT … ON CONFLICT (run_id) DO NOTHING so Temporal activity retries do not produce duplicate rows. Expected keys in run_payload: run_id (str UUID — computed deterministically in workflow) activity_id (str UUID) scheduled_for (ISO-8601 str or None) context_snapshot (dict) tasks_spawned (int) version_used (int) Returns: run_id as a str UUID. """ Session = _get_session_factory() run_id = uuid.UUID(run_payload["run_id"]) scheduled_for: datetime | None = None if run_payload.get("scheduled_for"): scheduled_for = datetime.fromisoformat(run_payload["scheduled_for"]) stmt = ( pg_insert(ActivityRun) .values( run_id=run_id, activity_id=uuid.UUID(run_payload["activity_id"]), scheduled_for=scheduled_for, fired_at=datetime.now(tz=timezone.utc), context_snapshot=run_payload["context_snapshot"], tasks_spawned=run_payload["tasks_spawned"], version_used=run_payload["version_used"], ) .on_conflict_do_nothing(index_elements=["run_id"]) ) async with Session() as session: async with session.begin(): await session.execute(stmt) return str(run_id) @activity.defn async def persist_task_instance(task_payload: dict) -> str: """Write a TaskInstance row and return its id. Idempotent: uses INSERT … ON CONFLICT (id) DO NOTHING. Expected keys in task_payload: id (str UUID — deterministic, computed in TaskExecutorWorkflow) run_id (str UUID) type (str) params (dict) status (str, default "done" for stub) Returns: task instance id as a str UUID. """ Session = _get_session_factory() task_id = uuid.UUID(task_payload["id"]) stmt = ( pg_insert(TaskInstance) .values( id=task_id, run_id=uuid.UUID(task_payload["run_id"]), type=task_payload["type"], params=task_payload.get("params", {}), status=task_payload.get("status", "done"), ) .on_conflict_do_nothing(index_elements=["id"]) ) async with Session() as session: async with session.begin(): await session.execute(stmt) return str(task_id) @activity.defn async def evaluate_rules(payload: dict) -> list[dict]: """Evaluate rules and render matching actions as task specs. Rules that raise UnsafeExpression or any other error are skipped and logged. Expected keys in payload: rules list[dict] — RuleDef serialised dicts event dict — EventEnvelope attributes (or empty for cron) context dict — context snapshot from resolve_context """ from activity_core.rules.evaluator import UnsafeExpression rules = payload.get("rules", []) event_attrs = payload.get("event", {}) context = payload.get("context", {}) # Build a simple object whose attributes mirror event fields for the evaluator. class _Env: def __init__(self, attrs: dict) -> None: self.attributes = _DictObj(attrs) class _DictObj: def __init__(self, d: dict) -> None: self.__dict__.update(d) event_obj = _Env(event_attrs) task_specs: list[dict] = [] for rule in rules: try: task_specs.extend(expand_rule_actions([rule], event_obj, context)) except UnsafeExpression as exc: activity.logger.warning("rule %r unsafe expression — skipping: %s", rule.get("id"), exc) except Exception as exc: activity.logger.warning("rule %r eval error — skipping: %s", rule.get("id"), exc) return task_specs @activity.defn async def evaluate_instructions(payload: dict) -> dict: """Evaluate instruction blocks and return task specs/reports with audit fields. Expected keys in payload: instructions list[dict] — InstructionDef serialised dicts event dict — EventEnvelope attributes (or empty for cron) context dict — context snapshot from resolve_context """ instructions = payload.get("instructions", []) event_attrs = payload.get("event", {}) context = payload.get("context", {}) llm_client = get_llm_client() class _Env: def __init__(self, attrs: dict) -> None: self.attributes = _DictObj(attrs) class _DictObj: def __init__(self, d: dict) -> None: self.__dict__.update(d) event_obj = _Env(event_attrs) task_specs: list[dict] = [] reports: list[dict] = [] for raw_instruction in instructions: try: instruction = InstructionDef.model_validate(raw_instruction) except Exception as exc: activity.logger.warning("instruction definition invalid — %s", exc) continue result = execute_instruction_with_audit( instruction, event_obj, context, llm_client, ) if result.report is not None: reports.append({ "instruction_id": instruction.id, "report": result.report, "sinks": instruction.report_sinks, "condition": result.condition_matched, "prompt_hash": result.prompt_hash, "model": result.model, "output_validated": result.output_validated, "review_required": result.review_required, "validation_error": result.validation_error, }) for spec in result.tasks: task_specs.append({ "title": spec.title, "description": spec.description, "target_repo": spec.target_repo, "priority": spec.priority, "labels": spec.labels, "due_in_days": spec.due_in_days, "source_type": "instruction", "source_id": instruction.id, "condition": result.condition_matched, "prompt_hash": result.prompt_hash, "model": result.model, "output_validated": result.output_validated, "review_required": result.review_required, }) return {"task_specs": task_specs, "reports": reports} @activity.defn async def persist_instruction_reports(payload: dict) -> list[dict]: """Persist report payloads to deterministic configured sinks.""" return persist_reports(payload) @activity.defn async def persist_ops_evidence(payload: dict) -> list[dict]: """Persist compact deterministic ops inventory evidence.""" return persist_ops_inventory_evidence(payload) @activity.defn async def emit_tasks(payload: dict) -> list[str]: """Emit TaskSpecs to IssueSink and write task_spawn_log rows. Returns list of external task ref IDs. Expected keys in payload: task_specs list[dict] — from evaluate_rules matched actions activity_id str — UUID of the ActivityDefinition triggering_event_id str — event ID or workflow ID for cron run_id str — UUID of the ActivityRun """ from activity_core.rules.models import TaskSpec task_specs_raw = payload.get("task_specs", []) activity_id = payload.get("activity_id", "") triggering_event_id = payload.get("triggering_event_id", "") sink = get_issue_sink() Session = _get_session_factory() refs: list[str] = [] errors: list[str] = [] async with Session() as session: async with session.begin(): for spec_dict in task_specs_raw: spec = TaskSpec( title=spec_dict.get("title", ""), description=spec_dict.get("description", ""), target_repo=spec_dict.get("target_repo"), priority=spec_dict.get("priority", "medium"), labels=spec_dict.get("labels", []), due_in_days=spec_dict.get("due_in_days"), source_type=spec_dict.get("source_type", "rule"), source_id=spec_dict.get("source_id", ""), triggering_event_id=triggering_event_id, activity_definition_id=activity_id, ) try: ref = sink.emit(spec) refs.append(ref.external_id) log_row = TaskSpawnLog( activity_def_id=uuid.UUID(activity_id), source_type=spec.source_type, source_id=spec.source_id, source_version="1", triggering_event_id=triggering_event_id, task_ref=ref.external_id, condition_matched=spec_dict.get("condition"), prompt_hash=spec_dict.get("prompt_hash"), model=spec_dict.get("model"), output_validated=spec_dict.get("output_validated"), review_required=spec_dict.get("review_required"), ) session.add(log_row) except Exception as exc: message = f"{spec.source_type}:{spec.source_id}: {exc}" errors.append(message) activity.logger.warning("emit_tasks: sink.emit failed — %s", exc) if errors: raise RuntimeError(f"task emission sink failure: {errors!r}") return refs