From 5c4f96e7aa63fc9cfafb4ed03242947f05b6db40 Mon Sep 17 00:00:00 2001 From: tegwick Date: Tue, 19 May 2026 20:55:35 +0200 Subject: [PATCH] Pass instruction depth config to llm-connect --- src/activity_core/llm_client.py | 23 ++++++--- src/activity_core/models.py | 4 ++ src/activity_core/rules/executor.py | 18 ++++++- tests/rules/test_executor.py | 72 ++++++++++++++++++++++++++-- tests/test_instruction_evaluation.py | 55 +++++++++++++++++++-- tests/test_llm_client.py | 52 ++++++++++++++++++++ 6 files changed, 208 insertions(+), 16 deletions(-) create mode 100644 tests/test_llm_client.py diff --git a/src/activity_core/llm_client.py b/src/activity_core/llm_client.py index 076c2c6..cc9d8cb 100644 --- a/src/activity_core/llm_client.py +++ b/src/activity_core/llm_client.py @@ -17,7 +17,12 @@ import httpx class DisabledLLMClient: """LLM client used when no llm-connect endpoint is configured.""" - def complete(self, prompt: str, model: str = "") -> str: # noqa: ARG002 + def complete( + self, + prompt: str, + model: str = "", + config: dict[str, Any] | None = None, + ) -> str: # noqa: ARG002 raise RuntimeError("LLM_CONNECT_URL is not configured") @@ -28,13 +33,19 @@ class LLMConnectClient: self.base_url = base_url.rstrip("/") self.timeout_seconds = timeout_seconds - def complete(self, prompt: str, model: str = "") -> str: + def complete( + self, + prompt: str, + model: str = "", + config: dict[str, Any] | None = None, + ) -> str: + run_config = dict(config or {}) + if model and "model_name" not in run_config: + run_config["model_name"] = model + run_config.setdefault("timeout_seconds", int(self.timeout_seconds)) payload: dict[str, Any] = { "prompt": prompt, - "config": { - "model_name": model, - "timeout_seconds": int(self.timeout_seconds), - }, + "config": run_config, } resp = httpx.post( f"{self.base_url}/execute", diff --git a/src/activity_core/models.py b/src/activity_core/models.py index 89c2525..a5d2700 100644 --- a/src/activity_core/models.py +++ b/src/activity_core/models.py @@ -109,6 +109,10 @@ class InstructionDef(BaseModel): description="Allowlist of event/context fields that may appear in the prompt template.", ) model: str = Field(description="LLM model identifier, e.g. 'claude-sonnet-4-6'.") + temperature: float | None = Field(default=None) + max_tokens: int | None = Field(default=None) + max_depth: int | None = Field(default=None) + model_params: dict[str, Any] = Field(default_factory=dict) prompt: str = Field(description="Prompt template with {field.path} placeholders.") output_schema: str = Field(description="Path to JSON Schema file for output validation.") review_required: bool = Field(default=False) diff --git a/src/activity_core/rules/executor.py b/src/activity_core/rules/executor.py index 7561601..7c6894c 100644 --- a/src/activity_core/rules/executor.py +++ b/src/activity_core/rules/executor.py @@ -144,15 +144,16 @@ def _execute( # Step 2 — render prompt (raises UntrustedFieldError on policy violation) rendered = _render_prompt(instr.prompt, instr.trusted_fields, event, context) prompt_hash = hashlib.sha256(rendered.encode()).hexdigest() + llm_config = _llm_run_config(instr) # Step 3 — call LLM - raw_output = llm_client.complete(rendered, model=instr.model) + raw_output = llm_client.complete(rendered, model=instr.model, config=llm_config) # Step 4 — validate and optionally retry task_specs, report, error = _validate_output(raw_output, instr) if error: retry_prompt = rendered + f"\n\nPrevious output was invalid: {error}\nPlease fix." - raw_output = llm_client.complete(retry_prompt, model=instr.model) + raw_output = llm_client.complete(retry_prompt, model=instr.model, config=llm_config) task_specs, report, error = _validate_output(raw_output, instr) if error: logger.warning( @@ -172,6 +173,19 @@ def _execute( ) +def _llm_run_config(instr: Any) -> dict[str, Any]: + """Build the llm-connect RunConfig payload from instruction metadata.""" + config: dict[str, Any] = {"model_name": instr.model} + for field in ("temperature", "max_tokens", "max_depth"): + value = getattr(instr, field, None) + if value is not None: + config[field] = value + model_params = getattr(instr, "model_params", None) + if model_params: + config["model_params"] = model_params + return config + + def _empty_result(instr: Any, prompt_hash: str | None = None) -> InstructionResult: return InstructionResult( tasks=[], diff --git a/tests/rules/test_executor.py b/tests/rules/test_executor.py index e21637a..558a68a 100644 --- a/tests/rules/test_executor.py +++ b/tests/rules/test_executor.py @@ -30,14 +30,24 @@ from activity_core.rules.executor import ( class _NullLLM: """Always returns an empty task list.""" - def complete(self, prompt: str, model: str = "") -> str: + def complete( + self, + prompt: str, + model: str = "", + config: dict | None = None, + ) -> str: return "[]" class _BadLLM: """Returns invalid JSON on every call.""" - def complete(self, prompt: str, model: str = "") -> str: + def complete( + self, + prompt: str, + model: str = "", + config: dict | None = None, + ) -> str: return "not valid json {" @@ -47,9 +57,16 @@ class _CountingLLM: def __init__(self, responses: list[str]) -> None: self._responses = list(responses) self.call_count = 0 + self.calls: list[dict | None] = [] - def complete(self, prompt: str, model: str = "") -> str: + def complete( + self, + prompt: str, + model: str = "", + config: dict | None = None, + ) -> str: self.call_count += 1 + self.calls.append(config) if self._responses: return self._responses.pop(0) return "[]" @@ -77,6 +94,10 @@ def _instr( model: str = "claude-sonnet-4-6", output_schema: str = "", review_required: bool = False, + temperature: float | None = None, + max_tokens: int | None = None, + max_depth: int | None = None, + model_params: dict[str, Any] | None = None, ) -> SimpleNamespace: return SimpleNamespace( id=id, @@ -84,6 +105,10 @@ def _instr( trusted_fields=trusted_fields or [], prompt=prompt, model=model, + temperature=temperature, + max_tokens=max_tokens, + max_depth=max_depth, + model_params=model_params or {}, output_schema=output_schema, review_required=review_required, ) @@ -225,6 +250,31 @@ def test_execute_instruction_with_audit_returns_metadata(): assert result.review_required is True +def test_execute_instruction_forwards_llm_connect_run_config(): + llm = _CountingLLM(["[]"]) + instr = _instr( + prompt="Check State Hub.", + trusted_fields=[], + model="custodian-triage-balanced", + temperature=0.2, + max_tokens=1200, + max_depth=2, + model_params={"reasoning_effort": "medium"}, + ) + + execute_instruction_with_audit(instr, _Event(), {}, llm) + + assert llm.calls == [ + { + "model_name": "custodian-triage-balanced", + "temperature": 0.2, + "max_tokens": 1200, + "max_depth": 2, + "model_params": {"reasoning_effort": "medium"}, + } + ] + + def test_execute_instruction_with_audit_accepts_report_payload(): report_data = { "summary": "State Hub has loose ends.", @@ -312,6 +362,22 @@ def test_review_required_field_on_instruction_def(): assert defn.review_required is True +def test_instruction_def_accepts_llm_connect_depth_config(): + defn = InstructionDef( + id="test", + trusted_fields=[], + model="custodian-triage-balanced", + temperature=0.2, + max_tokens=1200, + max_depth=2, + model_params={"reasoning_effort": "medium"}, + prompt="p", + output_schema="schema.json", + ) + assert defn.max_depth == 2 + assert defn.model_params == {"reasoning_effort": "medium"} + + def test_review_required_defaults_to_false(): defn = InstructionDef( id="test", diff --git a/tests/test_instruction_evaluation.py b/tests/test_instruction_evaluation.py index 5afb713..41c3e9e 100644 --- a/tests/test_instruction_evaluation.py +++ b/tests/test_instruction_evaluation.py @@ -10,10 +10,15 @@ from activity_core import activities class FakeLLMClient: def __init__(self, response: str) -> None: self.response = response - self.calls: list[tuple[str, str]] = [] + self.calls: list[tuple[str, str, dict | None]] = [] - def complete(self, prompt: str, model: str = "") -> str: - self.calls.append((prompt, model)) + def complete( + self, + prompt: str, + model: str = "", + config: dict | None = None, + ) -> str: + self.calls.append((prompt, model, config)) return self.response @@ -56,7 +61,9 @@ async def test_evaluate_instructions_returns_task_specs_with_audit(monkeypatch) assert spec["prompt_hash"] is not None assert len(spec["prompt_hash"]) == 64 assert result["reports"] == [] - assert llm.calls == [("Open tasks: 3", "test-model")] + assert llm.calls == [ + ("Open tasks: 3", "test-model", {"model_name": "test-model"}) + ] @pytest.mark.asyncio @@ -94,7 +101,12 @@ async def test_evaluate_instructions_returns_report_payload(monkeypatch) -> None @pytest.mark.asyncio async def test_evaluate_instructions_without_llm_client_returns_no_tasks(monkeypatch) -> None: class RaisingClient: - def complete(self, prompt: str, model: str = "") -> str: # noqa: ARG002 + def complete( + self, + prompt: str, + model: str = "", + config: dict | None = None, + ) -> str: # noqa: ARG002 raise RuntimeError("not configured") monkeypatch.setattr(activities, "get_llm_client", lambda: RaisingClient()) @@ -114,3 +126,36 @@ async def test_evaluate_instructions_without_llm_client_returns_no_tasks(monkeyp }) assert result == {"task_specs": [], "reports": []} + + +@pytest.mark.asyncio +async def test_evaluate_instructions_forwards_llm_connect_depth_config(monkeypatch) -> None: + llm = FakeLLMClient(json.dumps({"summary": "ok", "recommendations": []})) + monkeypatch.setattr(activities, "get_llm_client", lambda: llm) + + await activities.evaluate_instructions({ + "instructions": [ + { + "id": "daily-triage-report", + "trusted_fields": [], + "model": "custodian-triage-balanced", + "temperature": 0.2, + "max_tokens": 1200, + "max_depth": 2, + "model_params": {"reasoning_effort": "medium"}, + "prompt": "Run report.", + "output_schema": "schemas/daily-triage-report.json", + "review_required": False, + } + ], + "event": {}, + "context": {}, + }) + + assert llm.calls[0][2] == { + "model_name": "custodian-triage-balanced", + "temperature": 0.2, + "max_tokens": 1200, + "max_depth": 2, + "model_params": {"reasoning_effort": "medium"}, + } diff --git a/tests/test_llm_client.py b/tests/test_llm_client.py new file mode 100644 index 0000000..81a40bc --- /dev/null +++ b/tests/test_llm_client.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import httpx + +from activity_core.llm_client import LLMConnectClient + + +def test_llm_connect_client_forwards_run_config(monkeypatch) -> None: + captured: dict = {} + + class Response: + def raise_for_status(self) -> None: + pass + + def json(self) -> dict: + return {"content": '{"summary":"ok","recommendations":[]}'} + + def fake_post(url: str, json: dict, timeout: float) -> Response: + captured["url"] = url + captured["json"] = json + captured["timeout"] = timeout + return Response() + + monkeypatch.setattr(httpx, "post", fake_post) + + client = LLMConnectClient("http://llm-connect.local/", timeout_seconds=42) + result = client.complete( + "Prompt", + model="fallback-model", + config={ + "model_name": "custodian-triage-balanced", + "temperature": 0.2, + "max_tokens": 1200, + "max_depth": 2, + "model_params": {"reasoning_effort": "medium"}, + }, + ) + + assert result == '{"summary":"ok","recommendations":[]}' + assert captured["url"] == "http://llm-connect.local/execute" + assert captured["timeout"] == 42 + assert captured["json"] == { + "prompt": "Prompt", + "config": { + "model_name": "custodian-triage-balanced", + "temperature": 0.2, + "max_tokens": 1200, + "max_depth": 2, + "model_params": {"reasoning_effort": "medium"}, + "timeout_seconds": 42, + }, + }