Files
activity-core/src/activity_core/activities.py
2026-05-19 18:36:58 +02:00

409 lines
14 KiB
Python

"""Temporal activity definitions for activity-core.
Activities run inside a Worker bound to 'orchestrator-tq'.
Each function is decorated with @activity.defn and executed by
RunActivityWorkflow via workflow.execute_activity().
DB access pattern: worker.py calls init_session_factory(url) once before
starting workers, which sets the module-level _session_factory used by
activities that need DB access.
"""
from __future__ import annotations
import uuid
from datetime import datetime, timezone
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from temporalio import activity
from temporalio.exceptions import ApplicationError
from activity_core.db import make_engine
from activity_core.issue_sink import get_issue_sink
from activity_core.orm import ActivityDefinition as ActivityDefinitionRow
from activity_core.orm import ActivityRun, TaskInstance, TaskSpawnLog
from activity_core.rules import evaluate_condition
from activity_core.llm_client import get_llm_client
from activity_core.models import InstructionDef
from activity_core.report_sinks import persist_reports
from activity_core.rules.executor import execute_instruction_with_audit
# Process-wide async session factory shared by the DB-backed activities below.
# It stays None until worker.py calls init_session_factory().
_session_factory: async_sessionmaker[AsyncSession] | None = None


def init_session_factory(url: str) -> None:
    """Build the shared async session factory for the database at ``url``.

    worker.py is expected to call this once, before any worker starts, so that
    activities can obtain sessions through _get_session_factory(). Sessions
    created by the factory use ``expire_on_commit=False``, so loaded attribute
    values stay readable after a commit. Calling this again simply replaces the
    factory; the previously created engine is not disposed here.
    """
    global _session_factory
    engine = make_engine(url)
    _session_factory = async_sessionmaker(engine, expire_on_commit=False)
def _get_session_factory() -> async_sessionmaker[AsyncSession]:
    """Return the shared session factory set up by init_session_factory().

    Raises:
        RuntimeError: if init_session_factory() has not been called yet.
    """
    factory = _session_factory
    if factory is not None:
        return factory
    raise RuntimeError(
        "DB session factory not initialised — call init_session_factory() first"
    )
# ── Activities ─────────────────────────────────────────────────────────────────
@activity.defn
async def load_activity_definition(activity_id: str) -> dict:
    """Load an ActivityDefinition row from Postgres by ID.

    Returns a JSON-serialisable dict suitable for passing between
    Temporal workflow steps. The stored ``rules_json`` / ``instructions_json``
    columns are exposed under the ``rules`` and ``instructions`` keys.

    Raises:
        ApplicationError (non-retryable): if ``activity_id`` is not a valid
            UUID, or if no row exists for it.
    """
    # Validate before touching the DB. A bare ValueError from uuid.UUID() would
    # be treated as retryable by Temporal even though it can never succeed.
    try:
        activity_uuid = uuid.UUID(str(activity_id))
    except ValueError as exc:
        raise ApplicationError(
            f"ActivityDefinition id {activity_id!r} is not a valid UUID",
            non_retryable=True,
        ) from exc

    Session = _get_session_factory()
    async with Session() as session:
        row = await session.scalar(
            select(ActivityDefinitionRow).where(
                ActivityDefinitionRow.id == activity_uuid
            )
        )
        if row is None:
            raise ApplicationError(
                f"ActivityDefinition {activity_id!r} not found",
                non_retryable=True,
            )
        # Build the result while the session is still open so every column
        # read happens on an attached instance.
        return {
            "id": str(row.id),
            "name": row.name,
            "enabled": row.enabled,
            "trigger_type": row.trigger_type,
            "trigger_config": row.trigger_config,
            "context_sources": row.context_sources,
            "task_templates": row.task_templates,
            "rules": row.rules_json,
            "instructions": row.instructions_json,
            "dedupe_key_strategy": row.dedupe_key_strategy,
            "version": row.version,
        }
@activity.defn
async def resolve_context(
    context_sources: list[dict],
    event_envelope_json: str | None = None,
) -> dict:
    """Resolve each context source and merge the results into a snapshot dict.

    Returns:
        ``{bind_key: resolved_value, ...}``. The bind key comes from
        ``bind_to``, else ``name``, else the source type, with a leading
        ``"context."`` namespace stripped so the evaluator can find the key.

    Source types are dispatched via CONTEXT_RESOLVER_REGISTRY. The ``static``
    type is handled inline and binds ``config["value"]``. An unknown type, or a
    resolver that raises, logs a warning and binds ``{}``; neither aborts the
    run. A ``None`` source list or a ``None`` ``config`` is treated as empty
    instead of crashing the activity.
    """
    import activity_core.context_resolvers  # noqa: F401 — registers all adapters
    from activity_core.context_resolvers.base import CONTEXT_RESOLVER_REGISTRY

    snapshot: dict = {}
    # NOTE(review): event_envelope_json is accepted but never forwarded; every
    # resolver receives None as its second argument. Confirm this is intended.
    for source in context_sources or []:
        source_type = source.get("type", "")
        query = source.get("query", "")
        params = source.get("params") or {}
        raw_bind = source.get("bind_to") or source.get("name") or source_type
        # removeprefix is already a no-op when the prefix is absent.
        bind_key = raw_bind.removeprefix("context.")

        if source_type == "static":
            # `or {}` also covers an explicit "config": null.
            snapshot[bind_key] = (source.get("config") or {}).get("value")
            continue

        resolver_cls = CONTEXT_RESOLVER_REGISTRY.get(source_type)
        if resolver_cls is None:
            activity.logger.warning(
                "Unknown context source type %r — binding {}",
                source_type,
            )
            snapshot[bind_key] = {}
            continue

        try:
            snapshot[bind_key] = resolver_cls().resolve(query, None, params)
        except Exception as exc:
            # Deliberate best-effort: a failing source degrades to {}.
            activity.logger.warning(
                "Context resolver %r failed — %s; binding {}",
                source_type,
                exc,
            )
            snapshot[bind_key] = {}
    return snapshot
@activity.defn
async def log_run(run_payload: dict) -> str:
    """Record one ActivityRun in Postgres and hand back its run_id.

    The write is ``INSERT … ON CONFLICT (run_id) DO NOTHING``. If Temporal
    retries this activity with the same run_id, the existing row is left alone
    rather than duplicated. ``fired_at`` is stamped with the current UTC time of
    the attempt that performs the insert.

    run_payload keys:
        run_id            str UUID (computed deterministically in the workflow)
        activity_id       str UUID
        scheduled_for     ISO-8601 str, or None / absent
        context_snapshot  dict
        tasks_spawned     int
        version_used      int

    Returns:
        The run_id as a str UUID.
    """
    session_factory = _get_session_factory()
    run_id = uuid.UUID(run_payload["run_id"])

    raw_scheduled = run_payload.get("scheduled_for")
    scheduled_for: datetime | None = (
        datetime.fromisoformat(raw_scheduled) if raw_scheduled else None
    )

    insert_stmt = pg_insert(ActivityRun).values(
        run_id=run_id,
        activity_id=uuid.UUID(run_payload["activity_id"]),
        scheduled_for=scheduled_for,
        fired_at=datetime.now(tz=timezone.utc),
        context_snapshot=run_payload["context_snapshot"],
        tasks_spawned=run_payload["tasks_spawned"],
        version_used=run_payload["version_used"],
    )
    insert_stmt = insert_stmt.on_conflict_do_nothing(index_elements=["run_id"])

    async with session_factory() as session, session.begin():
        await session.execute(insert_stmt)
    return str(run_id)
@activity.defn
async def persist_task_instance(task_payload: dict) -> str:
    """Insert a TaskInstance row (unless it already exists) and return its id.

    Uses ``INSERT … ON CONFLICT (id) DO NOTHING``, so a retried call with the
    same deterministic id leaves the existing row untouched.

    task_payload keys:
        id       str UUID (deterministic, computed in TaskExecutorWorkflow)
        run_id   str UUID
        type     str
        params   dict, defaults to {}
        status   str, defaults to "done" (stub)

    Returns:
        The task instance id as a str UUID.
    """
    session_factory = _get_session_factory()
    task_id = uuid.UUID(task_payload["id"])

    columns = {
        "id": task_id,
        "run_id": uuid.UUID(task_payload["run_id"]),
        "type": task_payload["type"],
        "params": task_payload.get("params", {}),
        "status": task_payload.get("status", "done"),
    }
    insert_stmt = (
        pg_insert(TaskInstance)
        .values(**columns)
        .on_conflict_do_nothing(index_elements=["id"])
    )

    async with session_factory() as session, session.begin():
        await session.execute(insert_stmt)
    return str(task_id)
@activity.defn
async def evaluate_rules(payload: dict) -> list[dict]:
    """Evaluate each rule condition against the event and context.

    Returns the rule dicts whose condition evaluated truthy, in input order.
    A rule whose evaluation raises UnsafeExpression, or any other error, is
    logged and skipped.

    Expected keys in payload:
        rules    list[dict] — RuleDef serialised dicts
        event    dict — EventEnvelope attributes (or empty for cron)
        context  dict — context snapshot from resolve_context

    A missing *or None* value for any key is treated as empty. For example, a
    definition whose stored rules are NULL yields no matches instead of
    failing the activity with a (retryable) TypeError.
    """
    from activity_core.rules.evaluator import UnsafeExpression

    rules = payload.get("rules") or []
    event_attrs = payload.get("event") or {}
    context = payload.get("context") or {}

    class _DictObj:
        # Exposes a dict's top-level keys as attributes (nested dicts stay dicts).
        def __init__(self, d: dict) -> None:
            self.__dict__.update(d)

    class _Env:
        # Minimal event stand-in whose `.attributes` mirrors the event fields;
        # presumably the shape evaluate_condition expects — confirm against it.
        def __init__(self, attrs: dict) -> None:
            self.attributes = _DictObj(attrs)

    event_obj = _Env(event_attrs)
    matched: list[dict] = []
    for rule in rules:
        condition = rule.get("condition", "")
        try:
            if evaluate_condition(condition, event_obj, context):
                matched.append(rule)
        except UnsafeExpression as exc:
            activity.logger.warning("rule %r unsafe expression — skipping: %s", rule.get("id"), exc)
        except Exception as exc:
            activity.logger.warning("rule %r eval error — skipping: %s", rule.get("id"), exc)
    return matched
@activity.defn
async def evaluate_instructions(payload: dict) -> dict:
    """Evaluate instruction blocks and return task specs/reports with audit fields.

    Expected keys in payload:
        instructions  list[dict] — InstructionDef serialised dicts
        event         dict — EventEnvelope attributes (or empty for cron)
        context       dict — context snapshot from resolve_context

    A missing or None value for any key is treated as empty.

    Returns:
        ``{"task_specs": [...], "reports": [...]}``. Each entry carries the audit
        fields condition, prompt_hash, model, output_validated and
        review_required, taken from that instruction's execution result. Task
        specs are tagged ``source_type="instruction"`` with the instruction id
        as ``source_id``. Instructions that fail InstructionDef validation are
        logged and skipped.
    """
    instructions = payload.get("instructions") or []
    event_attrs = payload.get("event") or {}
    context = payload.get("context") or {}
    llm_client = get_llm_client()

    class _DictObj:
        # Exposes a dict's top-level keys as attributes (nested dicts stay dicts).
        def __init__(self, d: dict) -> None:
            self.__dict__.update(d)

    class _Env:
        # Minimal event stand-in whose `.attributes` mirrors the event fields.
        def __init__(self, attrs: dict) -> None:
            self.attributes = _DictObj(attrs)

    event_obj = _Env(event_attrs)
    task_specs: list[dict] = []
    reports: list[dict] = []
    for raw_instruction in instructions:
        try:
            instruction = InstructionDef.model_validate(raw_instruction)
        except Exception as exc:
            activity.logger.warning("instruction definition invalid — %s", exc)
            continue

        # NOTE(review): this synchronous call runs on the activity's event loop.
        # If it performs blocking LLM I/O it stalls the worker; consider
        # asyncio.to_thread or a threaded (sync) activity — confirm with llm_client.
        result = execute_instruction_with_audit(
            instruction,
            event_obj,
            context,
            llm_client,
        )
        # Shared audit trail stamped onto every report and task spec below.
        audit = {
            "condition": result.condition_matched,
            "prompt_hash": result.prompt_hash,
            "model": result.model,
            "output_validated": result.output_validated,
            "review_required": result.review_required,
        }

        if result.report is not None:
            reports.append({
                "instruction_id": instruction.id,
                "report": result.report,
                "sinks": instruction.report_sinks,
                **audit,
            })
        task_specs.extend(
            {
                "title": spec.title,
                "description": spec.description,
                "target_repo": spec.target_repo,
                "priority": spec.priority,
                "labels": spec.labels,
                "due_in_days": spec.due_in_days,
                "source_type": "instruction",
                "source_id": instruction.id,
                **audit,
            }
            for spec in result.tasks
        )
    return {"task_specs": task_specs, "reports": reports}
@activity.defn
async def persist_instruction_reports(payload: dict) -> list[dict]:
    """Write instruction report payloads to their configured sinks.

    All sink selection and writing is delegated to
    activity_core.report_sinks.persist_reports. This activity only exposes that
    call to Temporal and returns its result unchanged.
    """
    persisted: list[dict] = persist_reports(payload)
    return persisted
@activity.defn
async def emit_tasks(payload: dict) -> list[str]:
    """Emit TaskSpecs to the IssueSink and write one task_spawn_log row per emit.

    Returns the external task ref IDs of the specs that were emitted
    successfully, in input order.

    Expected keys in payload:
        task_specs           list[dict] — task spec dicts (rule actions, or the
                             ``task_specs`` returned by evaluate_instructions)
        activity_id          str — UUID of the ActivityDefinition
        triggering_event_id  str — event ID or workflow ID for cron
        run_id               str — UUID of the ActivityRun (not read here)

    Raises:
        ApplicationError (non-retryable): if there are specs to emit but
            ``activity_id`` is not a valid UUID. This is checked before anything
            is emitted, so a bad ID cannot create external issues that have no
            matching task_spawn_log row.

    A spec whose ``sink.emit`` call raises is logged and skipped; the remaining
    specs are still emitted. Log rows are committed in one transaction when the
    loop finishes.
    """
    from activity_core.rules.models import TaskSpec

    task_specs_raw = payload.get("task_specs") or []
    activity_id = payload.get("activity_id", "")
    triggering_event_id = payload.get("triggering_event_id", "")
    sink = get_issue_sink()
    Session = _get_session_factory()
    if not task_specs_raw:
        return []

    # Parse once, up front: previously this happened after sink.emit(), so an
    # invalid ID created the external issue and then silently dropped its log row.
    try:
        activity_def_id = uuid.UUID(str(activity_id))
    except ValueError as exc:
        raise ApplicationError(
            f"emit_tasks: activity_id {activity_id!r} is not a valid UUID",
            non_retryable=True,
        ) from exc

    # NOTE(review): nothing here dedupes before emitting. If this activity is
    # retried after some emits succeeded (e.g. the commit fails), those specs are
    # emitted again unless the IssueSink dedupes on its side.
    refs: list[str] = []
    async with Session() as session:
        async with session.begin():
            for spec_dict in task_specs_raw:
                spec = TaskSpec(
                    title=spec_dict.get("title", ""),
                    description=spec_dict.get("description", ""),
                    target_repo=spec_dict.get("target_repo"),
                    priority=spec_dict.get("priority", "medium"),
                    labels=spec_dict.get("labels", []),
                    due_in_days=spec_dict.get("due_in_days"),
                    source_type=spec_dict.get("source_type", "rule"),
                    source_id=spec_dict.get("source_id", ""),
                    triggering_event_id=triggering_event_id,
                    activity_definition_id=activity_id,
                )
                try:
                    ref = sink.emit(spec)
                except Exception as exc:
                    # Deliberate best-effort: one failing emit must not block the rest.
                    activity.logger.warning("emit_tasks: sink.emit failed — %s", exc)
                    continue

                refs.append(ref.external_id)
                session.add(
                    TaskSpawnLog(
                        activity_def_id=activity_def_id,
                        source_type=spec.source_type,
                        source_id=spec.source_id,
                        # NOTE(review): hard-coded; the definition's version is
                        # not part of this payload — confirm intended.
                        source_version="1",
                        triggering_event_id=triggering_event_id,
                        task_ref=ref.external_id,
                        condition_matched=spec_dict.get("condition"),
                        prompt_hash=spec_dict.get("prompt_hash"),
                        model=spec_dict.get("model"),
                        output_validated=spec_dict.get("output_validated"),
                        review_required=spec_dict.get("review_required"),
                    )
                )
    return refs