generated from coulomb/repo-seed
294 lines
11 KiB
Python
294 lines
11 KiB
Python
"""Named runtime profiles for server-mode adapter dispatch."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import threading
|
|
from dataclasses import dataclass, field, replace
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Mapping
|
|
|
|
from llm_connect.adapter import LLMAdapter
|
|
from llm_connect.exceptions import LLMConfigurationError
|
|
from llm_connect.factory import create_adapter
|
|
from llm_connect.models import LLMResponse, RunConfig
|
|
|
|
# Name of the built-in custodian triage profile.
CUSTODIAN_TRIAGE_BALANCED: str = "custodian-triage-balanced"

# Last-resort provider/model for that profile, used by
# default_runtime_profiles when neither the env vars nor explicit
# arguments supply a value.
DEFAULT_CUSTODIAN_TRIAGE_PROVIDER: str = "openrouter"
DEFAULT_CUSTODIAN_TRIAGE_MODEL: str = "anthropic/claude-sonnet-4"

# Stock RunConfig values. RuntimeProfile.resolve_config treats a request
# field equal to one of these as "left at default" and substitutes the
# profile's value.
_RUN_CONFIG_DEFAULTS: RunConfig = RunConfig()
|
|
|
|
|
|
@dataclass(frozen=True)
class RuntimeProfile:
    """A named routing target: provider, model, and default call settings.

    ``frozen=True`` makes instances immutable once constructed.
    """

    # Profile identifier; the built-in and JSON loaders key profiles by it.
    name: str
    # Provider identifier handed to the adapter factory.
    provider: str
    # Model identifier; always replaces the request's ``model_name``.
    model: str
    # Per-profile defaults consulted by ``resolve_config``.
    config: RunConfig = field(default_factory=RunConfig)

    def resolve_config(self, request_config: RunConfig) -> RunConfig:
        """Return a copy of *request_config* with this profile's defaults merged in.

        ``RunConfig`` fields carry concrete defaults rather than an "unset"
        marker, so the merge is heuristic:

        * ``model_name`` is always this profile's ``model``.
        * ``temperature``, ``max_tokens``, ``max_depth`` and
          ``timeout_seconds`` keep the request's value unless it equals the
          stock ``RunConfig()`` default. In that case the profile's value is
          used instead. A request that explicitly asks for the stock default
          is therefore indistinguishable from one that left the field alone.
        * ``model_params`` are shallow-merged, with request keys winning over
          profile keys.
        * Every other field is taken from the request unchanged.

        Neither ``request_config`` nor ``self.config`` is mutated.
        """
        # Start from the profile's params so request keys overwrite them.
        params = dict(self.config.model_params or {})
        params.update(request_config.model_params or {})

        scalar_overrides = {
            attr: _profile_default_if_unchanged(
                getattr(request_config, attr),
                getattr(_RUN_CONFIG_DEFAULTS, attr),
                getattr(self.config, attr),
            )
            for attr in ("temperature", "max_tokens", "max_depth", "timeout_seconds")
        }
        return replace(
            request_config,
            model_name=self.model,
            model_params=params,
            **scalar_overrides,
        )
|
|
|
|
|
|
class ProfiledLLMAdapter(LLMAdapter):
    """Adapter wrapper that dispatches named profile requests to adapters.

    A request whose ``config.model_name`` is a key of ``profiles`` is served
    by an adapter for that profile's ``(provider, model)`` pair, using the
    profile-merged config. Adapters are built lazily by ``adapter_factory``
    and cached per pair.

    Any other model name is forwarded unchanged to ``default_adapter``, with
    two exceptions that raise ``LLMConfigurationError`` instead:

    * ``strict_profiles`` is true, or
    * the name starts with one of ``profile_prefixes``.
    """

    def __init__(
        self,
        default_adapter: LLMAdapter,
        profiles: Mapping[str, RuntimeProfile],
        *,
        adapter_factory: Callable[[str, str], LLMAdapter] | None = None,
        strict_profiles: bool = False,
        profile_prefixes: tuple[str, ...] = ("custodian-",),
    ) -> None:
        """Create the dispatcher.

        Args:
            default_adapter: Adapter used for model names that are not profiles.
            profiles: Mapping from profile name to profile. It is copied, so
                later changes to the caller's mapping have no effect.
            adapter_factory: ``(provider, model) -> adapter`` builder. Defaults
                to ``_default_adapter_factory``.
            strict_profiles: If true, any model name that is not a configured
                profile is rejected.
            profile_prefixes: Model-name prefixes reserved for profiles. An
                unknown name with one of these prefixes is rejected rather
                than forwarded to ``default_adapter``.
        """
        self.default_adapter = default_adapter
        self.profiles = dict(profiles)
        self.adapter_factory = adapter_factory or _default_adapter_factory
        self.strict_profiles = strict_profiles
        # str.startswith only accepts a str or a tuple of str. Normalise so a
        # list (or other iterable) of prefixes works, while keeping a single
        # prefix string working as it did before.
        if isinstance(profile_prefixes, str):
            profile_prefixes = (profile_prefixes,)
        self.profile_prefixes = tuple(profile_prefixes)
        # Lazily built adapters keyed by (provider, model). Profiles that route
        # to the same pair share one adapter.
        self._adapters: dict[tuple[str, str], LLMAdapter] = {}
        self._lock = threading.Lock()

    def execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
        """Run *prompt*, routing through a profile when ``config.model_name`` names one.

        Raises:
            LLMConfigurationError: if the model name is rejected as an unknown
                profile (see ``_resolve_profile``).
        """
        profile = self._resolve_profile(config.model_name)
        if profile is None:
            return self.default_adapter.execute_prompt(prompt, config)

        adapter = self._adapter_for(profile)
        resolved_config = profile.resolve_config(config)
        response = adapter.execute_prompt(prompt, resolved_config)
        return self._annotate_response(response, profile)

    async def async_execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
        """Async counterpart of ``execute_prompt`` with identical routing.

        Raises:
            LLMConfigurationError: if the model name is rejected as an unknown
                profile (see ``_resolve_profile``).
        """
        profile = self._resolve_profile(config.model_name)
        if profile is None:
            return await self.default_adapter.async_execute_prompt(prompt, config)

        # NOTE: the first request for a (provider, model) pair builds its
        # adapter synchronously here, while holding a threading lock. That
        # blocks the event loop for the duration of adapter_factory.
        adapter = self._adapter_for(profile)
        resolved_config = profile.resolve_config(config)
        response = await adapter.async_execute_prompt(prompt, resolved_config)
        return self._annotate_response(response, profile)

    def validate_config(self, config: RunConfig) -> bool:
        """Validate *config* with the adapter that would serve it.

        For a profile request, the profile-merged config is validated by the
        profile's adapter.

        Raises:
            LLMConfigurationError: if the model name is rejected as an unknown
                profile.
        """
        profile = self._resolve_profile(config.model_name)
        if profile is None:
            return self.default_adapter.validate_config(config)
        return self._adapter_for(profile).validate_config(profile.resolve_config(config))

    @staticmethod
    def _annotate_response(response: LLMResponse, profile: RuntimeProfile) -> LLMResponse:
        """Record which profile served *response* and return it.

        ``setdefault`` is used so keys already set by the underlying adapter
        are never overwritten.
        """
        response.metadata.setdefault("profile", profile.name)
        response.metadata.setdefault("profile_provider", profile.provider)
        response.metadata.setdefault("profile_model", profile.model)
        return response

    def _resolve_profile(self, model_name: str) -> RuntimeProfile | None:
        """Return the profile for *model_name*, or ``None`` to use the default adapter.

        Raises:
            LLMConfigurationError: if *model_name* is not a configured profile
                and either ``strict_profiles`` is set or the name starts with
                a reserved profile prefix.
        """
        profile = self.profiles.get(model_name)
        if profile is not None:
            return profile

        if self.strict_profiles or model_name.startswith(self.profile_prefixes):
            known = ", ".join(sorted(self.profiles)) or "(none configured)"
            raise LLMConfigurationError(
                f"Unknown LLM runtime profile {model_name!r}. Known profiles: {known}",
                context={"profile": model_name},
            )
        return None

    def _adapter_for(self, profile: RuntimeProfile) -> LLMAdapter:
        """Return the cached adapter for *profile*'s (provider, model), building it if needed.

        The factory runs while the lock is held, so each pair is built at most
        once even under concurrent first use. If the factory raises, nothing
        is cached and the next call retries.
        """
        key = (profile.provider, profile.model)
        with self._lock:
            adapter = self._adapters.get(key)
            if adapter is None:
                adapter = self.adapter_factory(profile.provider, profile.model)
                self._adapters[key] = adapter
            return adapter
|
|
|
|
|
|
def default_runtime_profiles(
    *,
    provider: str | None = None,
    model: str | None = None,
) -> dict[str, RuntimeProfile]:
    """Return built-in runtime profiles, with env/config overrides applied.

    Provider and model for the built-in ``custodian-triage-balanced`` profile
    are resolved in this order of precedence:

    1. The ``LLM_CONNECT_CUSTODIAN_TRIAGE_PROVIDER`` and
       ``LLM_CONNECT_CUSTODIAN_TRIAGE_MODEL`` environment variables.
    2. The *provider* and *model* arguments.
    3. ``DEFAULT_CUSTODIAN_TRIAGE_PROVIDER`` and
       ``DEFAULT_CUSTODIAN_TRIAGE_MODEL``.

    Each generation default (temperature, max tokens, max depth, timeout,
    reasoning effort) reads its own ``LLM_CONNECT_CUSTODIAN_TRIAGE_*``
    variable. Unset or empty variables fall back to the literal defaults
    below. Profiles from ``load_runtime_profiles_from_env`` are merged last
    and replace any built-in profile with the same name.

    Args:
        provider: Provider fallback used when the provider env var is unset or empty.
        model: Model fallback used when the model env var is unset or empty.

    Returns:
        A new dict mapping profile name to ``RuntimeProfile``.

    Raises:
        LLMConfigurationError: if a numeric env var does not parse, or the
            JSON profile configuration is invalid.
    """
    triage_provider = (
        os.environ.get("LLM_CONNECT_CUSTODIAN_TRIAGE_PROVIDER")
        or provider
        or DEFAULT_CUSTODIAN_TRIAGE_PROVIDER
    )
    triage_model = (
        os.environ.get("LLM_CONNECT_CUSTODIAN_TRIAGE_MODEL")
        or model
        or DEFAULT_CUSTODIAN_TRIAGE_MODEL
    )
    profiles = {
        CUSTODIAN_TRIAGE_BALANCED: RuntimeProfile(
            name=CUSTODIAN_TRIAGE_BALANCED,
            provider=triage_provider,
            model=triage_model,
            config=RunConfig(
                model_name=triage_model,
                temperature=_float_env("LLM_CONNECT_CUSTODIAN_TRIAGE_TEMPERATURE", 0.2),
                max_tokens=_int_env("LLM_CONNECT_CUSTODIAN_TRIAGE_MAX_TOKENS", 1800),
                max_depth=_int_env("LLM_CONNECT_CUSTODIAN_TRIAGE_MAX_DEPTH", 2),
                timeout_seconds=_int_env(
                    "LLM_CONNECT_CUSTODIAN_TRIAGE_TIMEOUT_SECONDS", 300
                ),
                model_params={
                    # Use `or` rather than a .get() default, so an empty value
                    # falls back to "medium". This matches the numeric and
                    # provider/model lookups.
                    "reasoning_effort": (
                        os.environ.get("LLM_CONNECT_CUSTODIAN_TRIAGE_REASONING_EFFORT")
                        or "medium"
                    ),
                },
            ),
        )
    }
    # JSON/file-defined profiles win over built-ins with the same name.
    profiles.update(load_runtime_profiles_from_env())
    return profiles
|
|
|
|
|
|
def load_runtime_profiles_from_env() -> dict[str, RuntimeProfile]:
    """Load optional profile overrides from JSON env/file config.

    The JSON comes from one of two sources; setting both is an error:

    * ``LLM_CONNECT_PROFILES_JSON``: inline JSON text.
    * ``LLM_CONNECT_PROFILE_FILE``: path to a UTF-8 JSON file. A leading
      byte-order mark is tolerated.

    The document is either ``{"profiles": {name: spec, ...}}`` or the bare
    ``{name: spec, ...}`` mapping. Each spec is parsed by
    ``_profile_from_mapping``. An absent, empty, or whitespace-only source
    yields ``{}``.

    Returns:
        Mapping from profile name to ``RuntimeProfile``.

    Raises:
        LLMConfigurationError: if both sources are set, the file cannot be
            read or decoded, the text is not valid JSON, the top level is not
            an object, or a profile spec is invalid.
    """
    raw = os.environ.get("LLM_CONNECT_PROFILES_JSON")
    path = os.environ.get("LLM_CONNECT_PROFILE_FILE")
    if raw and path:
        raise LLMConfigurationError(
            "Set only one of LLM_CONNECT_PROFILES_JSON or LLM_CONNECT_PROFILE_FILE",
            context={"config": "runtime_profiles"},
        )
    if path:
        try:
            # utf-8-sig decodes plain UTF-8 identically, but also strips a
            # leading BOM that json.loads would otherwise reject.
            raw = Path(path).read_text(encoding="utf-8-sig")
        except (OSError, UnicodeDecodeError) as exc:
            # UnicodeDecodeError is a ValueError, not an OSError. Without it
            # here, a non-UTF-8 file would escape as a raw decode error.
            raise LLMConfigurationError(
                f"Could not read LLM runtime profile file {path!r}",
                cause=exc,
                context={"config": "runtime_profiles"},
            ) from exc
    # Treat whitespace-only input (e.g. a file holding just a newline) like
    # empty input, instead of failing it as invalid JSON.
    if not raw or not raw.strip():
        return {}

    try:
        data = json.loads(raw)
    except json.JSONDecodeError as exc:
        raise LLMConfigurationError(
            "LLM runtime profile config must be valid JSON",
            cause=exc,
            context={"config": "runtime_profiles"},
        ) from exc

    # Accept both the {"profiles": {...}} wrapper and the bare mapping.
    profiles_data = data.get("profiles", data) if isinstance(data, dict) else None
    if not isinstance(profiles_data, dict):
        raise LLMConfigurationError(
            "LLM runtime profile config must be an object keyed by profile name",
            context={"config": "runtime_profiles"},
        )

    return {
        name: _profile_from_mapping(name, value)
        for name, value in profiles_data.items()
    }
|
|
|
|
|
|
def _profile_from_mapping(name: str, value: Any) -> RuntimeProfile:
    """Build a ``RuntimeProfile`` from one parsed JSON profile spec.

    Expected shape::

        {"provider": "<non-empty str>", "model": "<non-empty str>", "config": {...}}

    ``config`` is optional; when it is missing or ``null`` there are no
    overrides. Otherwise it is passed to ``RunConfig.from_dict``, with
    ``model_name`` pinned to the profile's ``model``.

    Args:
        name: Profile name; used for the result and in error messages.
        value: The decoded JSON value for this profile.

    Raises:
        LLMConfigurationError: if *value* is not an object, ``provider`` or
            ``model`` is missing, empty, or not a string, or ``config`` is
            present but not an object.
    """
    if not isinstance(value, dict):
        raise LLMConfigurationError(
            f"Runtime profile {name!r} must be an object",
            context={"profile": name},
        )
    provider = value.get("provider")
    model = value.get("model")
    if not isinstance(provider, str) or not provider:
        raise LLMConfigurationError(
            f"Runtime profile {name!r} requires a provider",
            context={"profile": name},
        )
    if not isinstance(model, str) or not model:
        raise LLMConfigurationError(
            f"Runtime profile {name!r} requires a model",
            context={"profile": name},
        )
    config_data = value.get("config")
    if config_data is None:
        # Treat an explicit "config": null the same as omitting the key.
        config_data = {}
    if not isinstance(config_data, dict):
        raise LLMConfigurationError(
            f"Runtime profile {name!r} config must be an object",
            context={"profile": name},
        )
    # Put model_name last so a "model_name" inside config cannot make the
    # stored config disagree with the profile's model. resolve_config always
    # sends the profile's model, so such a value would only be misleading.
    config = RunConfig.from_dict({**config_data, "model_name": model})
    return RuntimeProfile(name=name, provider=provider, model=model, config=config)
|
|
|
|
|
|
def _default_adapter_factory(provider: str, model: str) -> LLMAdapter:
    """Build an adapter for *provider* and *model* via ``create_adapter``.

    ``ProfiledLLMAdapter`` uses this when no ``adapter_factory`` is given.
    """
    adapter = create_adapter(provider, model=model)
    return adapter
|
|
|
|
|
|
def _profile_default_if_unchanged(value: Any, default: Any, profile_value: Any) -> Any:
|
|
return profile_value if value == default else value
|
|
|
|
|
|
def _int_env(name: str, default: int) -> int:
|
|
value = os.environ.get(name)
|
|
if value is None or value == "":
|
|
return default
|
|
try:
|
|
return int(value)
|
|
except ValueError as exc:
|
|
raise LLMConfigurationError(
|
|
f"{name} must be an integer",
|
|
cause=exc,
|
|
context={"env": name},
|
|
) from exc
|
|
|
|
|
|
def _float_env(name: str, default: float) -> float:
|
|
value = os.environ.get(name)
|
|
if value is None or value == "":
|
|
return default
|
|
try:
|
|
return float(value)
|
|
except ValueError as exc:
|
|
raise LLMConfigurationError(
|
|
f"{name} must be a number",
|
|
cause=exc,
|
|
context={"env": name},
|
|
) from exc
|