Pass instruction depth config to llm-connect

This commit is contained in:
2026-05-19 20:55:35 +02:00
parent 1ff8b14d1b
commit 5c4f96e7aa
6 changed files with 208 additions and 16 deletions

View File

@@ -17,7 +17,12 @@ import httpx
class DisabledLLMClient:
"""LLM client used when no llm-connect endpoint is configured."""
def complete(self, prompt: str, model: str = "") -> str: # noqa: ARG002
def complete(
self,
prompt: str,
model: str = "",
config: dict[str, Any] | None = None,
) -> str: # noqa: ARG002
raise RuntimeError("LLM_CONNECT_URL is not configured")
@@ -28,13 +33,19 @@ class LLMConnectClient:
self.base_url = base_url.rstrip("/")
self.timeout_seconds = timeout_seconds
def complete(self, prompt: str, model: str = "") -> str:
def complete(
self,
prompt: str,
model: str = "",
config: dict[str, Any] | None = None,
) -> str:
run_config = dict(config or {})
if model and "model_name" not in run_config:
run_config["model_name"] = model
run_config.setdefault("timeout_seconds", int(self.timeout_seconds))
payload: dict[str, Any] = {
"prompt": prompt,
"config": {
"model_name": model,
"timeout_seconds": int(self.timeout_seconds),
},
"config": run_config,
}
resp = httpx.post(
f"{self.base_url}/execute",

View File

@@ -109,6 +109,10 @@ class InstructionDef(BaseModel):
description="Allowlist of event/context fields that may appear in the prompt template.",
)
model: str = Field(description="LLM model identifier, e.g. 'claude-sonnet-4-6'.")
temperature: float | None = Field(default=None)
max_tokens: int | None = Field(default=None)
max_depth: int | None = Field(default=None)
model_params: dict[str, Any] = Field(default_factory=dict)
prompt: str = Field(description="Prompt template with {field.path} placeholders.")
output_schema: str = Field(description="Path to JSON Schema file for output validation.")
review_required: bool = Field(default=False)

View File

@@ -144,15 +144,16 @@ def _execute(
# Step 2 — render prompt (raises UntrustedFieldError on policy violation)
rendered = _render_prompt(instr.prompt, instr.trusted_fields, event, context)
prompt_hash = hashlib.sha256(rendered.encode()).hexdigest()
llm_config = _llm_run_config(instr)
# Step 3 — call LLM
raw_output = llm_client.complete(rendered, model=instr.model)
raw_output = llm_client.complete(rendered, model=instr.model, config=llm_config)
# Step 4 — validate and optionally retry
task_specs, report, error = _validate_output(raw_output, instr)
if error:
retry_prompt = rendered + f"\n\nPrevious output was invalid: {error}\nPlease fix."
raw_output = llm_client.complete(retry_prompt, model=instr.model)
raw_output = llm_client.complete(retry_prompt, model=instr.model, config=llm_config)
task_specs, report, error = _validate_output(raw_output, instr)
if error:
logger.warning(
@@ -172,6 +173,19 @@ def _execute(
)
def _llm_run_config(instr: Any) -> dict[str, Any]:
"""Build the llm-connect RunConfig payload from instruction metadata."""
config: dict[str, Any] = {"model_name": instr.model}
for field in ("temperature", "max_tokens", "max_depth"):
value = getattr(instr, field, None)
if value is not None:
config[field] = value
model_params = getattr(instr, "model_params", None)
if model_params:
config["model_params"] = model_params
return config
def _empty_result(instr: Any, prompt_hash: str | None = None) -> InstructionResult:
return InstructionResult(
tasks=[],

View File

@@ -30,14 +30,24 @@ from activity_core.rules.executor import (
class _NullLLM:
"""Always returns an empty task list."""
def complete(self, prompt: str, model: str = "") -> str:
def complete(
self,
prompt: str,
model: str = "",
config: dict | None = None,
) -> str:
return "[]"
class _BadLLM:
"""Returns invalid JSON on every call."""
def complete(self, prompt: str, model: str = "") -> str:
def complete(
self,
prompt: str,
model: str = "",
config: dict | None = None,
) -> str:
return "not valid json {"
@@ -47,9 +57,16 @@ class _CountingLLM:
def __init__(self, responses: list[str]) -> None:
self._responses = list(responses)
self.call_count = 0
self.calls: list[dict | None] = []
def complete(self, prompt: str, model: str = "") -> str:
def complete(
self,
prompt: str,
model: str = "",
config: dict | None = None,
) -> str:
self.call_count += 1
self.calls.append(config)
if self._responses:
return self._responses.pop(0)
return "[]"
@@ -77,6 +94,10 @@ def _instr(
model: str = "claude-sonnet-4-6",
output_schema: str = "",
review_required: bool = False,
temperature: float | None = None,
max_tokens: int | None = None,
max_depth: int | None = None,
model_params: dict[str, Any] | None = None,
) -> SimpleNamespace:
return SimpleNamespace(
id=id,
@@ -84,6 +105,10 @@ def _instr(
trusted_fields=trusted_fields or [],
prompt=prompt,
model=model,
temperature=temperature,
max_tokens=max_tokens,
max_depth=max_depth,
model_params=model_params or {},
output_schema=output_schema,
review_required=review_required,
)
@@ -225,6 +250,31 @@ def test_execute_instruction_with_audit_returns_metadata():
assert result.review_required is True
def test_execute_instruction_forwards_llm_connect_run_config():
llm = _CountingLLM(["[]"])
instr = _instr(
prompt="Check State Hub.",
trusted_fields=[],
model="custodian-triage-balanced",
temperature=0.2,
max_tokens=1200,
max_depth=2,
model_params={"reasoning_effort": "medium"},
)
execute_instruction_with_audit(instr, _Event(), {}, llm)
assert llm.calls == [
{
"model_name": "custodian-triage-balanced",
"temperature": 0.2,
"max_tokens": 1200,
"max_depth": 2,
"model_params": {"reasoning_effort": "medium"},
}
]
def test_execute_instruction_with_audit_accepts_report_payload():
report_data = {
"summary": "State Hub has loose ends.",
@@ -312,6 +362,22 @@ def test_review_required_field_on_instruction_def():
assert defn.review_required is True
def test_instruction_def_accepts_llm_connect_depth_config():
defn = InstructionDef(
id="test",
trusted_fields=[],
model="custodian-triage-balanced",
temperature=0.2,
max_tokens=1200,
max_depth=2,
model_params={"reasoning_effort": "medium"},
prompt="p",
output_schema="schema.json",
)
assert defn.max_depth == 2
assert defn.model_params == {"reasoning_effort": "medium"}
def test_review_required_defaults_to_false():
defn = InstructionDef(
id="test",

View File

@@ -10,10 +10,15 @@ from activity_core import activities
class FakeLLMClient:
def __init__(self, response: str) -> None:
self.response = response
self.calls: list[tuple[str, str]] = []
self.calls: list[tuple[str, str, dict | None]] = []
def complete(self, prompt: str, model: str = "") -> str:
self.calls.append((prompt, model))
def complete(
self,
prompt: str,
model: str = "",
config: dict | None = None,
) -> str:
self.calls.append((prompt, model, config))
return self.response
@@ -56,7 +61,9 @@ async def test_evaluate_instructions_returns_task_specs_with_audit(monkeypatch)
assert spec["prompt_hash"] is not None
assert len(spec["prompt_hash"]) == 64
assert result["reports"] == []
assert llm.calls == [("Open tasks: 3", "test-model")]
assert llm.calls == [
("Open tasks: 3", "test-model", {"model_name": "test-model"})
]
@pytest.mark.asyncio
@@ -94,7 +101,12 @@ async def test_evaluate_instructions_returns_report_payload(monkeypatch) -> None
@pytest.mark.asyncio
async def test_evaluate_instructions_without_llm_client_returns_no_tasks(monkeypatch) -> None:
class RaisingClient:
def complete(self, prompt: str, model: str = "") -> str: # noqa: ARG002
def complete(
self,
prompt: str,
model: str = "",
config: dict | None = None,
) -> str: # noqa: ARG002
raise RuntimeError("not configured")
monkeypatch.setattr(activities, "get_llm_client", lambda: RaisingClient())
@@ -114,3 +126,36 @@ async def test_evaluate_instructions_without_llm_client_returns_no_tasks(monkeyp
})
assert result == {"task_specs": [], "reports": []}
@pytest.mark.asyncio
async def test_evaluate_instructions_forwards_llm_connect_depth_config(monkeypatch) -> None:
llm = FakeLLMClient(json.dumps({"summary": "ok", "recommendations": []}))
monkeypatch.setattr(activities, "get_llm_client", lambda: llm)
await activities.evaluate_instructions({
"instructions": [
{
"id": "daily-triage-report",
"trusted_fields": [],
"model": "custodian-triage-balanced",
"temperature": 0.2,
"max_tokens": 1200,
"max_depth": 2,
"model_params": {"reasoning_effort": "medium"},
"prompt": "Run report.",
"output_schema": "schemas/daily-triage-report.json",
"review_required": False,
}
],
"event": {},
"context": {},
})
assert llm.calls[0][2] == {
"model_name": "custodian-triage-balanced",
"temperature": 0.2,
"max_tokens": 1200,
"max_depth": 2,
"model_params": {"reasoning_effort": "medium"},
}

52
tests/test_llm_client.py Normal file
View File

@@ -0,0 +1,52 @@
from __future__ import annotations
import httpx
from activity_core.llm_client import LLMConnectClient
def test_llm_connect_client_forwards_run_config(monkeypatch) -> None:
captured: dict = {}
class Response:
def raise_for_status(self) -> None:
pass
def json(self) -> dict:
return {"content": '{"summary":"ok","recommendations":[]}'}
def fake_post(url: str, json: dict, timeout: float) -> Response:
captured["url"] = url
captured["json"] = json
captured["timeout"] = timeout
return Response()
monkeypatch.setattr(httpx, "post", fake_post)
client = LLMConnectClient("http://llm-connect.local/", timeout_seconds=42)
result = client.complete(
"Prompt",
model="fallback-model",
config={
"model_name": "custodian-triage-balanced",
"temperature": 0.2,
"max_tokens": 1200,
"max_depth": 2,
"model_params": {"reasoning_effort": "medium"},
},
)
assert result == '{"summary":"ok","recommendations":[]}'
assert captured["url"] == "http://llm-connect.local/execute"
assert captured["timeout"] == 42
assert captured["json"] == {
"prompt": "Prompt",
"config": {
"model_name": "custodian-triage-balanced",
"temperature": 0.2,
"max_tokens": 1200,
"max_depth": 2,
"model_params": {"reasoning_effort": "medium"},
"timeout_seconds": 42,
},
}