generated from coulomb/repo-seed
Pass instruction depth config to llm-connect
This commit is contained in:
@@ -17,7 +17,12 @@ import httpx
|
||||
class DisabledLLMClient:
|
||||
"""LLM client used when no llm-connect endpoint is configured."""
|
||||
|
||||
def complete(self, prompt: str, model: str = "") -> str: # noqa: ARG002
|
||||
def complete(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str = "",
|
||||
config: dict[str, Any] | None = None,
|
||||
) -> str: # noqa: ARG002
|
||||
raise RuntimeError("LLM_CONNECT_URL is not configured")
|
||||
|
||||
|
||||
@@ -28,13 +33,19 @@ class LLMConnectClient:
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.timeout_seconds = timeout_seconds
|
||||
|
||||
def complete(self, prompt: str, model: str = "") -> str:
|
||||
def complete(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str = "",
|
||||
config: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
run_config = dict(config or {})
|
||||
if model and "model_name" not in run_config:
|
||||
run_config["model_name"] = model
|
||||
run_config.setdefault("timeout_seconds", int(self.timeout_seconds))
|
||||
payload: dict[str, Any] = {
|
||||
"prompt": prompt,
|
||||
"config": {
|
||||
"model_name": model,
|
||||
"timeout_seconds": int(self.timeout_seconds),
|
||||
},
|
||||
"config": run_config,
|
||||
}
|
||||
resp = httpx.post(
|
||||
f"{self.base_url}/execute",
|
||||
|
||||
@@ -109,6 +109,10 @@ class InstructionDef(BaseModel):
|
||||
description="Allowlist of event/context fields that may appear in the prompt template.",
|
||||
)
|
||||
model: str = Field(description="LLM model identifier, e.g. 'claude-sonnet-4-6'.")
|
||||
temperature: float | None = Field(default=None)
|
||||
max_tokens: int | None = Field(default=None)
|
||||
max_depth: int | None = Field(default=None)
|
||||
model_params: dict[str, Any] = Field(default_factory=dict)
|
||||
prompt: str = Field(description="Prompt template with {field.path} placeholders.")
|
||||
output_schema: str = Field(description="Path to JSON Schema file for output validation.")
|
||||
review_required: bool = Field(default=False)
|
||||
|
||||
@@ -144,15 +144,16 @@ def _execute(
|
||||
# Step 2 — render prompt (raises UntrustedFieldError on policy violation)
|
||||
rendered = _render_prompt(instr.prompt, instr.trusted_fields, event, context)
|
||||
prompt_hash = hashlib.sha256(rendered.encode()).hexdigest()
|
||||
llm_config = _llm_run_config(instr)
|
||||
|
||||
# Step 3 — call LLM
|
||||
raw_output = llm_client.complete(rendered, model=instr.model)
|
||||
raw_output = llm_client.complete(rendered, model=instr.model, config=llm_config)
|
||||
|
||||
# Step 4 — validate and optionally retry
|
||||
task_specs, report, error = _validate_output(raw_output, instr)
|
||||
if error:
|
||||
retry_prompt = rendered + f"\n\nPrevious output was invalid: {error}\nPlease fix."
|
||||
raw_output = llm_client.complete(retry_prompt, model=instr.model)
|
||||
raw_output = llm_client.complete(retry_prompt, model=instr.model, config=llm_config)
|
||||
task_specs, report, error = _validate_output(raw_output, instr)
|
||||
if error:
|
||||
logger.warning(
|
||||
@@ -172,6 +173,19 @@ def _execute(
|
||||
)
|
||||
|
||||
|
||||
def _llm_run_config(instr: Any) -> dict[str, Any]:
|
||||
"""Build the llm-connect RunConfig payload from instruction metadata."""
|
||||
config: dict[str, Any] = {"model_name": instr.model}
|
||||
for field in ("temperature", "max_tokens", "max_depth"):
|
||||
value = getattr(instr, field, None)
|
||||
if value is not None:
|
||||
config[field] = value
|
||||
model_params = getattr(instr, "model_params", None)
|
||||
if model_params:
|
||||
config["model_params"] = model_params
|
||||
return config
|
||||
|
||||
|
||||
def _empty_result(instr: Any, prompt_hash: str | None = None) -> InstructionResult:
|
||||
return InstructionResult(
|
||||
tasks=[],
|
||||
|
||||
Reference in New Issue
Block a user