Files
activity-core/src/activity_core/activities.py

431 lines
15 KiB
Python

"""Temporal activity definitions for activity-core.
Activities run inside a Worker bound to 'orchestrator-tq'.
Each function is decorated with @activity.defn and executed by
RunActivityWorkflow via workflow.execute_activity().
DB access pattern: worker.py calls init_session_factory(url) once before
starting workers, which sets the module-level _session_factory used by
activities that need DB access.
"""
from __future__ import annotations
import uuid
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from temporalio import activity
from temporalio.exceptions import ApplicationError
from activity_core.db import make_engine
from activity_core.issue_sink import get_issue_sink
from activity_core.orm import ActivityDefinition as ActivityDefinitionRow
from activity_core.orm import ActivityRun, TaskInstance, TaskSpawnLog
from activity_core.llm_client import get_llm_client
from activity_core.models import InstructionDef
from activity_core.ops_evidence_sinks import persist_ops_inventory_evidence
from activity_core.report_sinks import persist_reports
from activity_core.rules.actions import expand_rule_actions
from activity_core.rules.executor import execute_instruction_with_audit
# Module-level async session factory shared by the DB-backed activities.
# Stays None until init_session_factory() assigns it; read via _get_session_factory().
_session_factory: async_sessionmaker[AsyncSession] | None = None
def init_session_factory(url: str) -> None:
    """Create the module-wide async DB session factory for *url*.

    worker.py is expected to call this exactly once, before any worker is
    started, so that activities needing DB access can obtain sessions.
    """
    global _session_factory
    engine = make_engine(url)
    # expire_on_commit=False is passed through to the session factory unchanged.
    _session_factory = async_sessionmaker(engine, expire_on_commit=False)
def _get_session_factory() -> async_sessionmaker[AsyncSession]:
    """Return the shared session factory.

    Raises:
        RuntimeError: if init_session_factory() has not been called yet.
    """
    factory = _session_factory
    if factory is not None:
        return factory
    raise RuntimeError(
        "DB session factory not initialised — call init_session_factory() first"
    )
# ── Activities ─────────────────────────────────────────────────────────────────
@activity.defn
async def load_activity_definition(activity_id: str) -> dict:
    """Load an ActivityDefinition row from Postgres by ID.

    Returns a JSON-serialisable dict suitable for passing between
    Temporal workflow steps.

    Args:
        activity_id: String form of the ActivityDefinition UUID.

    Returns:
        Dict with the keys id, name, enabled, trigger_type, trigger_config,
        context_sources, task_templates, rules, instructions,
        dedupe_key_strategy and version.

    Raises:
        ApplicationError (non-retryable): if activity_id is not a valid UUID,
            or if no row exists for activity_id.
    """
    Session = _get_session_factory()
    try:
        definition_id = uuid.UUID(activity_id)
    except (AttributeError, TypeError, ValueError) as exc:
        # A malformed ID can never succeed on a later attempt, so report it
        # as a permanent failure instead of letting a raw ValueError be retried.
        raise ApplicationError(
            f"ActivityDefinition id {activity_id!r} is not a valid UUID",
            non_retryable=True,
        ) from exc

    async with Session() as session:
        row = await session.scalar(
            select(ActivityDefinitionRow).where(
                ActivityDefinitionRow.id == definition_id
            )
        )
        if row is None:
            raise ApplicationError(
                f"ActivityDefinition {activity_id!r} not found",
                non_retryable=True,
            )
        # Build the result while the session is still open.
        return {
            "id": str(row.id),
            "name": row.name,
            "enabled": row.enabled,
            "trigger_type": row.trigger_type,
            "trigger_config": row.trigger_config,
            "context_sources": row.context_sources,
            "task_templates": row.task_templates,
            "rules": row.rules_json,
            "instructions": row.instructions_json,
            "dedupe_key_strategy": row.dedupe_key_strategy,
            "version": row.version,
        }
@activity.defn
async def resolve_context(
    context_sources: list[dict],
    event_envelope_json: str | None = None,
) -> dict:
    """Resolve each context source and merge into a snapshot dict.

    Source types are dispatched via CONTEXT_RESOLVER_REGISTRY; the 'static'
    type is handled inline without a registry entry. An unknown type, or a
    resolver that raises, logs a warning and binds {} unless the context
    source is marked required (top-level "required" or params["required"]),
    in which case the activity fails visibly.

    Args:
        context_sources: Source descriptors. The keys read here are "type",
            "query", "params", "required", "bind_to", "name" and "config".
            None is treated as an empty list.
        event_envelope_json: Accepted but not used by this function.

    Returns:
        {bind_key: resolved_value, ...} where bind_key is "bind_to" (falling
        back to "name", then the source type) without a leading "context.".

    Raises:
        ApplicationError: flagged non-retryable when a required source type
            is not registered; raised without that flag when a required
            resolver fails.
    """
    import activity_core.context_resolvers  # noqa: F401 — registers all adapters
    from activity_core.context_resolvers.base import CONTEXT_RESOLVER_REGISTRY

    # NOTE(review): event_envelope_json is never forwarded — resolvers receive
    # None as their second argument. Confirm whether that is intentional.
    snapshot: dict = {}
    for source in context_sources or []:
        source_type = source.get("type", "")
        query = source.get("query", "")
        params = source.get("params") or {}
        required = bool(source.get("required") or params.get("required", False))
        raw_bind = source.get("bind_to") or source.get("name") or source_type
        # Strip the 'context.' namespace prefix so evaluator can find the key.
        # removeprefix() is already a no-op when the prefix is absent.
        bind_key = raw_bind.removeprefix("context.")

        if source_type == "static":
            # `or {}` mirrors the params handling: a null "config" binds None
            # instead of raising AttributeError.
            snapshot[bind_key] = (source.get("config") or {}).get("value")
            continue

        resolver_cls = CONTEXT_RESOLVER_REGISTRY.get(source_type)
        if resolver_cls is None:
            if required:
                raise ApplicationError(
                    f"Required context source type {source_type!r} is not registered",
                    non_retryable=True,
                )
            activity.logger.warning(
                "Unknown context source type %r — binding {}",
                source_type,
            )
            snapshot[bind_key] = {}
            continue

        try:
            snapshot[bind_key] = resolver_cls().resolve(query, None, params)
        except Exception as exc:
            if required:
                raise ApplicationError(
                    f"Required context resolver {source_type!r}/{query!r} failed: {exc}"
                ) from exc
            # Optional sources degrade to an empty binding rather than
            # failing the whole activity.
            activity.logger.warning(
                "Context resolver %r failed — %s; binding {}",
                source_type,
                exc,
            )
            snapshot[bind_key] = {}
    return snapshot
@activity.defn
async def log_run(run_payload: dict) -> str:
    """Write one ActivityRun row to Postgres and hand back its run_id.

    The statement is INSERT … ON CONFLICT (run_id) DO NOTHING, so a retried
    activity does not create a second row for the same run.

    Expected keys in run_payload:
        run_id            (str UUID — computed deterministically in workflow)
        activity_id       (str UUID)
        scheduled_for     (ISO-8601 str or None)
        context_snapshot  (dict)
        tasks_spawned     (int)
        version_used      (int)

    Returns:
        run_id as a str UUID.
    """
    Session = _get_session_factory()
    run_id = uuid.UUID(run_payload["run_id"])

    # A missing or empty scheduled_for is stored as NULL.
    raw_scheduled_for = run_payload.get("scheduled_for")
    scheduled_for: datetime | None = (
        datetime.fromisoformat(raw_scheduled_for) if raw_scheduled_for else None
    )

    insert_stmt = pg_insert(ActivityRun)
    row_values = {
        "run_id": run_id,
        "activity_id": uuid.UUID(run_payload["activity_id"]),
        "scheduled_for": scheduled_for,
        # fired_at is the wall-clock time of this insert attempt, in UTC.
        "fired_at": datetime.now(tz=timezone.utc),
        "context_snapshot": run_payload["context_snapshot"],
        "tasks_spawned": run_payload["tasks_spawned"],
        "version_used": run_payload["version_used"],
    }
    stmt = insert_stmt.values(**row_values).on_conflict_do_nothing(
        index_elements=["run_id"]
    )

    async with Session() as session, session.begin():
        await session.execute(stmt)
    return str(run_id)
@activity.defn
async def persist_task_instance(task_payload: dict) -> str:
    """Insert a TaskInstance row and return its id.

    The statement is INSERT … ON CONFLICT (id) DO NOTHING, so repeating the
    call with the same id leaves the existing row untouched.

    Expected keys in task_payload:
        id      (str UUID — deterministic, computed in TaskExecutorWorkflow)
        run_id  (str UUID)
        type    (str)
        params  (dict, defaults to {})
        status  (str, default "done" for stub)

    Returns:
        task instance id as a str UUID.
    """
    Session = _get_session_factory()
    task_id = uuid.UUID(task_payload["id"])

    insert_stmt = pg_insert(TaskInstance)
    row_values = {
        "id": task_id,
        "run_id": uuid.UUID(task_payload["run_id"]),
        "type": task_payload["type"],
        "params": task_payload.get("params", {}),
        "status": task_payload.get("status", "done"),
    }
    stmt = insert_stmt.values(**row_values).on_conflict_do_nothing(
        index_elements=["id"]
    )

    async with Session() as session, session.begin():
        await session.execute(stmt)
    return str(task_id)
@activity.defn
async def evaluate_rules(payload: dict) -> list[dict]:
    """Evaluate rules and render matching actions as task specs.

    Each rule is expanded on its own so that one failing rule cannot take
    the others down: rules that raise UnsafeExpression or any other error
    are skipped and logged.

    Expected keys in payload:
        rules    list[dict]  — RuleDef serialised dicts
        event    dict        — EventEnvelope attributes (or empty for cron)
        context  dict        — context snapshot from resolve_context
    Missing or null values are treated as empty.

    Returns:
        The task specs produced by expand_rule_actions for every rule that
        evaluated without error, in rule order.
    """
    from activity_core.rules.evaluator import UnsafeExpression

    # `or` (rather than a .get default) also covers keys that are present
    # with a null value, which would otherwise raise TypeError below.
    rules = payload.get("rules") or []
    event_attrs = payload.get("event") or {}
    context = payload.get("context") or {}

    # Build a simple object whose attributes mirror event fields for the evaluator.
    class _DictObj:
        def __init__(self, d: dict) -> None:
            self.__dict__.update(d)

    class _Env:
        def __init__(self, attrs: dict) -> None:
            self.attributes = _DictObj(attrs)

    event_obj = _Env(event_attrs)
    task_specs: list[dict] = []
    for rule in rules:
        # Resolve the id up front so the handlers below cannot raise
        # themselves when a rule entry is not a dict.
        rule_id = rule.get("id") if isinstance(rule, dict) else None
        try:
            task_specs.extend(expand_rule_actions([rule], event_obj, context))
        except UnsafeExpression as exc:
            activity.logger.warning(
                "rule %r unsafe expression — skipping: %s", rule_id, exc
            )
        except Exception as exc:
            activity.logger.warning("rule %r eval error — skipping: %s", rule_id, exc)
    return task_specs
@activity.defn
async def evaluate_instructions(payload: dict) -> dict:
    """Evaluate instruction blocks and return task specs/reports with audit fields.

    Instruction dicts that fail InstructionDef validation are logged and
    skipped; every valid instruction is run through
    execute_instruction_with_audit.

    Expected keys in payload:
        instructions  list[dict]  — InstructionDef serialised dicts
        event         dict        — EventEnvelope attributes (or empty for cron)
        context       dict        — context snapshot from resolve_context
    Missing or null values are treated as empty.

    Returns:
        {"task_specs": [...], "reports": [...]}. Every entry carries the audit
        fields condition, prompt_hash, model, output_validated and
        review_required; reports additionally carry validation_error.
    """
    # `or` (rather than a .get default) also covers keys that are present
    # with a null value, which would otherwise raise TypeError below.
    instructions = payload.get("instructions") or []
    event_attrs = payload.get("event") or {}
    context = payload.get("context") or {}
    llm_client = get_llm_client()

    # Build a simple object whose attributes mirror event fields.
    class _DictObj:
        def __init__(self, d: dict) -> None:
            self.__dict__.update(d)

    class _Env:
        def __init__(self, attrs: dict) -> None:
            self.attributes = _DictObj(attrs)

    event_obj = _Env(event_attrs)
    task_specs: list[dict] = []
    reports: list[dict] = []
    for raw_instruction in instructions:
        try:
            instruction = InstructionDef.model_validate(raw_instruction)
        except Exception as exc:
            activity.logger.warning("instruction definition invalid — %s", exc)
            continue

        result = execute_instruction_with_audit(
            instruction,
            event_obj,
            context,
            llm_client,
        )

        # Audit fields shared by the report and by every task spec that this
        # instruction produced.
        audit = {
            "condition": result.condition_matched,
            "prompt_hash": result.prompt_hash,
            "model": result.model,
            "output_validated": result.output_validated,
            "review_required": result.review_required,
        }

        if result.report is not None:
            reports.append({
                "instruction_id": instruction.id,
                "report": result.report,
                "sinks": instruction.report_sinks,
                **audit,
                "validation_error": result.validation_error,
            })

        task_specs.extend(
            {
                "title": spec.title,
                "description": spec.description,
                "target_repo": spec.target_repo,
                "priority": spec.priority,
                "labels": spec.labels,
                "due_in_days": spec.due_in_days,
                "source_type": "instruction",
                "source_id": instruction.id,
                **audit,
            }
            for spec in result.tasks
        )
    return {"task_specs": task_specs, "reports": reports}
@activity.defn
async def persist_instruction_reports(payload: dict) -> list[dict]:
    """Pass the report payload to persist_reports and return what it returns."""
    sink_results = persist_reports(payload)
    return sink_results
@activity.defn
async def persist_ops_evidence(payload: dict) -> list[dict]:
    """Pass the payload to persist_ops_inventory_evidence and return its result."""
    evidence_results = persist_ops_inventory_evidence(payload)
    return evidence_results
@activity.defn
async def emit_tasks(payload: dict) -> list[str]:
    """Emit TaskSpecs to IssueSink and write task_spawn_log rows.

    Each spec is emitted independently: a failure for one spec is recorded
    and the remaining specs are still attempted.

    Expected keys in payload:
        task_specs           list[dict]  — from evaluate_rules matched actions
        activity_id          str         — UUID of the ActivityDefinition
        triggering_event_id  str         — event ID or workflow ID for cron
        run_id               str         — UUID of the ActivityRun
                                           (not read by this function)

    Returns:
        List of external task ref IDs, one per successfully emitted spec.

    Raises:
        ApplicationError (non-retryable): if there are specs to emit but
            activity_id is not a valid UUID. Raised before anything is
            emitted, so no external task is created.
        RuntimeError: if emitting or logging any spec failed. Spawn-log rows
            for the specs that succeeded are committed before this is raised.
    """
    from activity_core.rules.models import TaskSpec

    task_specs_raw = payload.get("task_specs") or []
    activity_id = payload.get("activity_id", "")
    triggering_event_id = payload.get("triggering_event_id", "")
    sink = get_issue_sink()
    Session = _get_session_factory()

    # Validate the activity id once, before any external side effect. Parsing
    # it per spec after sink.emit() meant a bad id still created every
    # external task and then failed retryably, re-emitting on each retry.
    activity_uuid: uuid.UUID | None = None
    if task_specs_raw:
        try:
            activity_uuid = uuid.UUID(activity_id)
        except (AttributeError, TypeError, ValueError) as exc:
            raise ApplicationError(
                f"emit_tasks: activity_id {activity_id!r} is not a valid UUID",
                non_retryable=True,
            ) from exc

    refs: list[str] = []
    errors: list[str] = []
    async with Session() as session:
        async with session.begin():
            for spec_dict in task_specs_raw:
                spec = TaskSpec(
                    title=spec_dict.get("title", ""),
                    description=spec_dict.get("description", ""),
                    target_repo=spec_dict.get("target_repo"),
                    priority=spec_dict.get("priority", "medium"),
                    labels=spec_dict.get("labels", []),
                    due_in_days=spec_dict.get("due_in_days"),
                    source_type=spec_dict.get("source_type", "rule"),
                    source_id=spec_dict.get("source_id", ""),
                    triggering_event_id=triggering_event_id,
                    activity_definition_id=activity_id,
                )
                try:
                    ref = sink.emit(spec)
                    refs.append(ref.external_id)
                    # Audit row linking the external task back to its source.
                    log_row = TaskSpawnLog(
                        activity_def_id=activity_uuid,
                        source_type=spec.source_type,
                        source_id=spec.source_id,
                        source_version="1",
                        triggering_event_id=triggering_event_id,
                        task_ref=ref.external_id,
                        condition_matched=spec_dict.get("condition"),
                        prompt_hash=spec_dict.get("prompt_hash"),
                        model=spec_dict.get("model"),
                        output_validated=spec_dict.get("output_validated"),
                        review_required=spec_dict.get("review_required"),
                    )
                    session.add(log_row)
                except Exception as exc:
                    # Collect the failure and keep going; it is reported
                    # after the loop.
                    message = f"{spec.source_type}:{spec.source_id}: {exc}"
                    errors.append(message)
                    activity.logger.warning("emit_tasks: sink.emit failed — %s", exc)

    # Raised only after the transaction block has exited, so the spawn-log
    # rows for successful emissions are committed first.
    if errors:
        raise RuntimeError(f"task emission sink failure: {errors!r}")
    return refs