# Generated from coulomb/repo-seed (409 lines, 14 KiB, Python).
"""Temporal activity definitions for activity-core.
|
|
|
|
Activities run inside a Worker bound to 'orchestrator-tq'.
|
|
Each function is decorated with @activity.defn and executed by
|
|
RunActivityWorkflow via workflow.execute_activity().
|
|
|
|
DB access pattern: worker.py calls init_session_factory(url) once before
|
|
starting workers, which sets the module-level _session_factory used by
|
|
activities that need DB access.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
from temporalio import activity
|
|
from temporalio.exceptions import ApplicationError
|
|
|
|
from activity_core.db import make_engine
|
|
from activity_core.issue_sink import get_issue_sink
|
|
from activity_core.orm import ActivityDefinition as ActivityDefinitionRow
|
|
from activity_core.orm import ActivityRun, TaskInstance, TaskSpawnLog
|
|
from activity_core.rules import evaluate_condition
|
|
from activity_core.llm_client import get_llm_client
|
|
from activity_core.models import InstructionDef
|
|
from activity_core.report_sinks import persist_reports
|
|
from activity_core.rules.executor import execute_instruction_with_audit
|
|
|
|
|
|
# Shared async session factory for activities that touch Postgres.
# Set by init_session_factory(); while it is still None,
# _get_session_factory() raises RuntimeError.
_session_factory: async_sessionmaker[AsyncSession] | None = None
|
|
|
|
|
|
def init_session_factory(url: str) -> None:
    """Create the module-wide async session factory used by DB activities.

    worker.py must call this exactly once, before any worker is started, so
    that activities can obtain sessions through _get_session_factory().

    Args:
        url: database URL, passed straight to make_engine().
    """
    global _session_factory
    engine = make_engine(url)
    # expire_on_commit=False keeps ORM attributes readable after commit.
    _session_factory = async_sessionmaker(engine, expire_on_commit=False)
|
|
|
|
|
|
def _get_session_factory() -> async_sessionmaker[AsyncSession]:
    """Return the shared session factory set by init_session_factory().

    Raises:
        RuntimeError: init_session_factory() has not been called yet.
    """
    factory = _session_factory
    if factory is not None:
        return factory
    raise RuntimeError(
        "DB session factory not initialised — call init_session_factory() first"
    )
|
|
|
|
|
|
# ── Activities ─────────────────────────────────────────────────────────────────
|
|
|
|
@activity.defn
async def load_activity_definition(activity_id: str) -> dict:
    """Load an ActivityDefinition row from Postgres by ID.

    Returns a JSON-serialisable dict suitable for passing between
    Temporal workflow steps. ``id`` is returned as a string. The row's
    ``rules_json`` and ``instructions_json`` columns are returned under the
    keys ``"rules"`` and ``"instructions"``.

    Args:
        activity_id: primary key of the definition, as a UUID string.

    Raises:
        ApplicationError (non-retryable): activity_id is not a valid UUID,
            or no row exists for activity_id. Retrying cannot fix either
            case, so Temporal is told not to retry.
    """
    # Parse before touching the DB. A bare ValueError would be retried by
    # Temporal's default retry policy even though it can never succeed.
    try:
        definition_id = uuid.UUID(activity_id)
    except (TypeError, ValueError, AttributeError) as exc:
        raise ApplicationError(
            f"ActivityDefinition id {activity_id!r} is not a valid UUID",
            non_retryable=True,
        ) from exc

    Session = _get_session_factory()
    async with Session() as session:
        row = await session.scalar(
            select(ActivityDefinitionRow).where(
                ActivityDefinitionRow.id == definition_id
            )
        )

        if row is None:
            raise ApplicationError(
                f"ActivityDefinition {activity_id!r} not found",
                non_retryable=True,
            )

        # Build the result while the session is still open, so any lazily
        # loaded column can still be fetched.
        return {
            "id": str(row.id),
            "name": row.name,
            "enabled": row.enabled,
            "trigger_type": row.trigger_type,
            "trigger_config": row.trigger_config,
            "context_sources": row.context_sources,
            "task_templates": row.task_templates,
            "rules": row.rules_json,
            "instructions": row.instructions_json,
            "dedupe_key_strategy": row.dedupe_key_strategy,
            "version": row.version,
        }
|
|
|
|
|
|
@activity.defn
async def resolve_context(
    context_sources: list[dict],
    event_envelope_json: str | None = None,
) -> dict:
    """Resolve each context source and merge the results into a snapshot dict.

    Returns: {bind_key: resolved_value, ...}

    A source's bind key is its ``bind_to``, else its ``name``, else its
    ``type``. Any leading ``"context."`` namespace is stripped so the rule
    evaluator can find the key.

    Source types are dispatched via CONTEXT_RESOLVER_REGISTRY. An unknown
    type, or a resolver that raises, logs a warning and binds {}; it does not
    abort the run. The 'static' type is handled inline without a registry
    entry and binds ``config["value"]``. A missing or null ``config`` binds
    None.

    NOTE(review): event_envelope_json is accepted but not used. Resolvers are
    always called with None as their second argument. Confirm whether the
    event should be forwarded.
    """
    import activity_core.context_resolvers  # noqa: F401 — registers all adapters
    from activity_core.context_resolvers.base import CONTEXT_RESOLVER_REGISTRY

    snapshot: dict = {}
    for source in context_sources or []:
        # `or` (not a .get default) so explicit nulls are treated like
        # missing keys instead of crashing the whole activity.
        source_type = source.get("type") or ""
        query = source.get("query", "")
        params = source.get("params") or {}
        raw_bind = source.get("bind_to") or source.get("name") or source_type
        # Strip the 'context.' namespace prefix so the evaluator can find the key.
        bind_key = raw_bind.removeprefix("context.")

        if source_type == "static":
            snapshot[bind_key] = (source.get("config") or {}).get("value")
            continue

        resolver_cls = CONTEXT_RESOLVER_REGISTRY.get(source_type)
        if resolver_cls is None:
            activity.logger.warning(
                "Unknown context source type %r — binding {}",
                source_type,
            )
            snapshot[bind_key] = {}
            continue

        try:
            snapshot[bind_key] = resolver_cls().resolve(query, None, params)
        except Exception as exc:
            # Deliberate best-effort: one failing source must not abort the run.
            activity.logger.warning(
                "Context resolver %r failed — %s; binding {}",
                source_type,
                exc,
            )
            snapshot[bind_key] = {}

    return snapshot
|
|
|
|
|
|
@activity.defn
async def log_run(run_payload: dict) -> str:
    """Insert the ActivityRun row for this run and return its run_id.

    The write is ``INSERT ... ON CONFLICT (run_id) DO NOTHING``. When
    Temporal retries this activity with the same run_id, the existing row is
    left untouched and no duplicate is created.

    Args:
        run_payload: dict with keys
            run_id            str UUID (computed deterministically in the workflow)
            activity_id       str UUID
            scheduled_for     ISO-8601 str, or None / empty
            context_snapshot  dict
            tasks_spawned     int
            version_used      int

    Returns:
        run_id as a str UUID.
    """
    make_session = _get_session_factory()
    run_id = uuid.UUID(run_payload["run_id"])

    scheduled_raw = run_payload.get("scheduled_for")
    scheduled_for: datetime | None = (
        datetime.fromisoformat(scheduled_raw) if scheduled_raw else None
    )

    insert_stmt = pg_insert(ActivityRun).values(
        run_id=run_id,
        activity_id=uuid.UUID(run_payload["activity_id"]),
        scheduled_for=scheduled_for,
        # Time of this attempt. If a row with this run_id already exists,
        # the conflict clause keeps that row's fired_at instead.
        fired_at=datetime.now(tz=timezone.utc),
        context_snapshot=run_payload["context_snapshot"],
        tasks_spawned=run_payload["tasks_spawned"],
        version_used=run_payload["version_used"],
    )
    insert_stmt = insert_stmt.on_conflict_do_nothing(index_elements=["run_id"])

    async with make_session() as session, session.begin():
        await session.execute(insert_stmt)

    return str(run_id)
|
|
|
|
|
|
@activity.defn
async def persist_task_instance(task_payload: dict) -> str:
    """Insert a TaskInstance row and return its id.

    Uses ``INSERT ... ON CONFLICT (id) DO NOTHING``, so repeating the call
    with the same id (e.g. on a Temporal retry) does not add a second row.

    Args:
        task_payload: dict with keys
            id      str UUID (deterministic, computed in TaskExecutorWorkflow)
            run_id  str UUID
            type    str
            params  dict, optional (defaults to {})
            status  str, optional (defaults to "done" for the stub)

    Returns:
        The task instance id as a str UUID.
    """
    make_session = _get_session_factory()
    task_id = uuid.UUID(task_payload["id"])

    column_values = {
        "id": task_id,
        "run_id": uuid.UUID(task_payload["run_id"]),
        "type": task_payload["type"],
        "params": task_payload.get("params", {}),
        "status": task_payload.get("status", "done"),
    }
    insert_stmt = (
        pg_insert(TaskInstance)
        .values(**column_values)
        .on_conflict_do_nothing(index_elements=["id"])
    )

    async with make_session() as session, session.begin():
        await session.execute(insert_stmt)

    return str(task_id)
|
|
|
|
|
|
@activity.defn
async def evaluate_rules(payload: dict) -> list[dict]:
    """Evaluate each rule condition against the event and context.

    Returns the rule dicts whose condition evaluates truthy, in input order.
    A rule whose evaluation raises (UnsafeExpression or any other error) is
    logged and skipped; it never fails the activity.

    Expected keys in payload (a missing or null value is treated as empty):
        rules    list[dict] — RuleDef serialised dicts
        event    dict — EventEnvelope attributes (or empty for cron)
        context  dict — context snapshot from resolve_context
    """
    from types import SimpleNamespace

    from activity_core.rules.evaluator import UnsafeExpression

    # `or` rather than a .get() default: an explicit null would otherwise
    # raise TypeError here, outside the per-rule error handling below.
    rules = payload.get("rules") or []
    event_attrs = payload.get("event") or {}
    context = payload.get("context") or {}

    # evaluate_condition receives an envelope-like object whose
    # `.attributes` exposes each event field as an attribute. Payloads
    # arrive as JSON, so the keys are strings.
    event_obj = SimpleNamespace(attributes=SimpleNamespace(**event_attrs))

    matched: list[dict] = []
    for rule in rules:
        condition = rule.get("condition", "")
        try:
            if evaluate_condition(condition, event_obj, context):
                matched.append(rule)
        except UnsafeExpression as exc:
            activity.logger.warning("rule %r unsafe expression — skipping: %s", rule.get("id"), exc)
        except Exception as exc:
            activity.logger.warning("rule %r eval error — skipping: %s", rule.get("id"), exc)

    return matched
|
|
|
|
|
|
@activity.defn
async def evaluate_instructions(payload: dict) -> dict:
    """Evaluate instruction blocks and return task specs/reports with audit fields.

    Each raw instruction is validated with InstructionDef.model_validate().
    Invalid ones are logged and skipped. Valid ones are run through
    execute_instruction_with_audit() and the result is flattened into plain
    dicts.

    Returns:
        {"task_specs": [...], "reports": [...]}. Every entry carries the audit
        fields condition, prompt_hash, model, output_validated and
        review_required. Task specs are tagged source_type="instruction" and
        source_id=<instruction id>.

    Expected keys in payload (a missing or null value is treated as empty):
        instructions  list[dict] — InstructionDef serialised dicts
        event         dict — EventEnvelope attributes (or empty for cron)
        context       dict — context snapshot from resolve_context
    """
    from types import SimpleNamespace

    instructions = payload.get("instructions") or []
    if not instructions:
        # Nothing to evaluate, so don't construct an LLM client at all.
        return {"task_specs": [], "reports": []}

    event_attrs = payload.get("event") or {}
    context = payload.get("context") or {}
    llm_client = get_llm_client()

    # execute_instruction_with_audit receives an envelope-like object whose
    # `.attributes` exposes each event field as an attribute.
    event_obj = SimpleNamespace(attributes=SimpleNamespace(**event_attrs))

    task_specs: list[dict] = []
    reports: list[dict] = []
    for raw_instruction in instructions:
        try:
            instruction = InstructionDef.model_validate(raw_instruction)
        except Exception as exc:
            activity.logger.warning("instruction definition invalid — %s", exc)
            continue

        # NOTE(review): this is a synchronous call from an async activity.
        # Unlike the validation errors above, an exception here propagates
        # and fails the whole activity. Confirm both are intended.
        result = execute_instruction_with_audit(
            instruction,
            event_obj,
            context,
            llm_client,
        )

        # Audit fields shared by every report and task spec from this result.
        audit = {
            "condition": result.condition_matched,
            "prompt_hash": result.prompt_hash,
            "model": result.model,
            "output_validated": result.output_validated,
            "review_required": result.review_required,
        }

        if result.report is not None:
            reports.append({
                "instruction_id": instruction.id,
                "report": result.report,
                "sinks": instruction.report_sinks,
                **audit,
            })
        for spec in result.tasks:
            task_specs.append({
                "title": spec.title,
                "description": spec.description,
                "target_repo": spec.target_repo,
                "priority": spec.priority,
                "labels": spec.labels,
                "due_in_days": spec.due_in_days,
                "source_type": "instruction",
                "source_id": instruction.id,
                **audit,
            })

    return {"task_specs": task_specs, "reports": reports}
|
|
|
|
|
|
@activity.defn
async def persist_instruction_reports(payload: dict) -> list[dict]:
    """Write instruction report payloads to their configured sinks.

    Thin activity wrapper: it delegates entirely to
    report_sinks.persist_reports() and returns that call's result unchanged.
    """
    persisted = persist_reports(payload)
    return persisted
|
|
|
|
|
|
@activity.defn
async def emit_tasks(payload: dict) -> list[str]:
    """Emit TaskSpecs to the IssueSink and write task_spawn_log rows.

    Returns the external task ref IDs of the specs the sink accepted, in
    input order. A spec whose sink.emit() call fails is logged and skipped;
    the remaining specs are still emitted.

    Expected keys in payload:
        task_specs           list[dict] — task spec dicts (e.g. the
                             "task_specs" list returned by evaluate_instructions)
        activity_id          str — UUID of the ActivityDefinition
        triggering_event_id  str — event ID or workflow ID for cron
        run_id               str — UUID of the ActivityRun (not read here)

    Raises:
        ApplicationError (non-retryable): there are task specs but
            activity_id is not a valid UUID. This is checked before anything
            is emitted, so no external task is created without its log row.
    """
    from activity_core.rules.models import TaskSpec

    task_specs_raw = payload.get("task_specs") or []
    if not task_specs_raw:
        return []

    activity_id = payload.get("activity_id", "")
    triggering_event_id = payload.get("triggering_event_id", "")

    # Parse once, before any side effect. A bad id must not create an
    # external task that then cannot be logged, and it must not be reported
    # as a sink failure.
    try:
        activity_def_id = uuid.UUID(activity_id)
    except (TypeError, ValueError, AttributeError) as exc:
        raise ApplicationError(
            f"emit_tasks: activity_id {activity_id!r} is not a valid UUID",
            non_retryable=True,
        ) from exc

    sink = get_issue_sink()
    Session = _get_session_factory()

    refs: list[str] = []
    # NOTE(review): sink.emit() runs inside the DB transaction. If the commit
    # fails, or the activity is retried after a partial run, tasks already
    # emitted have no task_spawn_log rows and a retry emits them again.
    # Consider checking task_spawn_log (or a sink-side dedupe key) first.
    async with Session() as session:
        async with session.begin():
            for spec_dict in task_specs_raw:
                spec = TaskSpec(
                    title=spec_dict.get("title", ""),
                    description=spec_dict.get("description", ""),
                    target_repo=spec_dict.get("target_repo"),
                    priority=spec_dict.get("priority", "medium"),
                    labels=spec_dict.get("labels", []),
                    due_in_days=spec_dict.get("due_in_days"),
                    source_type=spec_dict.get("source_type", "rule"),
                    source_id=spec_dict.get("source_id", ""),
                    triggering_event_id=triggering_event_id,
                    activity_definition_id=activity_id,
                )

                # Best-effort only around the sink call and reading its result.
                try:
                    task_ref = sink.emit(spec).external_id
                except Exception as exc:
                    activity.logger.warning("emit_tasks: sink.emit failed — %s", exc)
                    continue

                refs.append(task_ref)
                session.add(
                    TaskSpawnLog(
                        activity_def_id=activity_def_id,
                        source_type=spec.source_type,
                        source_id=spec.source_id,
                        # NOTE(review): hard-coded; presumably this should be
                        # the definition's version. Confirm.
                        source_version="1",
                        triggering_event_id=triggering_event_id,
                        task_ref=task_ref,
                        condition_matched=spec_dict.get("condition"),
                        prompt_hash=spec_dict.get("prompt_hash"),
                        model=spec_dict.get("model"),
                        output_validated=spec_dict.get("output_validated"),
                        review_required=spec_dict.get("review_required"),
                    )
                )

    return refs
|