generated from coulomb/repo-seed
155 lines
4.4 KiB
Python
155 lines
4.4 KiB
Python
"""Provider payload helpers for translating ``RunConfig.model_params``."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
from typing import Any
|
|
|
|
from llm_connect._diagnostics import (
|
|
diagnostics_enabled,
|
|
json_safe,
|
|
record_adapter_transformation,
|
|
)
|
|
|
|
|
|
# Top-level OpenAI Chat Completions request fields that are copied verbatim
# from ``model_params`` (listed alphabetically for easy scanning). Any other
# key is provider-specific and must be translated (e.g. ``json_schema`` ->
# ``response_format``) or dropped. A blind merge is avoided on purpose,
# because OpenAI-compatible providers commonly reject unknown top-level
# fields with HTTP 400.
# NOTE(review): ``temperature`` and ``max_tokens`` are not listed here;
# presumably the adapter sets them from dedicated RunConfig fields — confirm.
OPENAI_CHAT_PASSTHROUGH_FIELDS = frozenset(
    (
        "frequency_penalty",
        "logit_bias",
        "logprobs",
        "n",
        "parallel_tool_calls",
        "presence_penalty",
        "response_format",
        "seed",
        "stop",
        "stream",
        "tool_choice",
        "tools",
        "top_logprobs",
        "top_p",
        "user",
    )
)
|
|
|
|
|
|
# model_params keys that merge_openai_chat_model_params never forwards as-is:
# ``json_schema`` is translated into ``response_format`` separately, and the
# remaining keys are Claude/llm-connect-only knobs.
# NOTE(review): OpenAI's reasoning models accept a top-level
# ``reasoning_effort`` on Chat Completions; dropping it here is presumably
# deliberate for other OpenAI-compatible providers — confirm.
DROPPED_NON_OPENAI_FIELDS = frozenset(
    ("claude_cli_path", "json_schema", "max_depth", "reasoning_effort")
)
|
|
|
|
|
|
# model_params keys that merge_gemini_model_params copies unchanged onto the
# top level of the ``generateContent`` payload (camelCase spelling, exactly
# as it appears in the request body).
# NOTE(review): ``tools`` is also an OpenAI passthrough key, but the two APIs
# use different tool formats — confirm callers never share one model_params
# dict across providers.
GEMINI_TOP_LEVEL_FIELDS = frozenset(
    ("cachedContent", "safetySettings", "systemInstruction", "toolConfig", "tools")
)
|
|
|
|
|
|
# Field names accepted inside ``generationConfig``. merge_gemini_model_params
# first maps each model_params key through GEMINI_GENERATION_CONFIG_ALIASES
# and copies the value only if the resulting name is listed here.
GEMINI_GENERATION_CONFIG_FIELDS = frozenset(
    (
        "candidateCount",
        "maxOutputTokens",
        "responseMimeType",
        "responseSchema",
        "stopSequences",
        "temperature",
        "topK",
        "topP",
    )
)
|
|
|
|
|
|
# snake_case spellings accepted in model_params, mapped to their camelCase
# ``generationConfig`` names. ``temperature`` needs no entry because it is
# spelled identically in both conventions.
# NOTE(review): OpenAI-style names such as ``stop`` or ``max_tokens`` have no
# alias here, so merge_gemini_model_params drops them — confirm intended.
GEMINI_GENERATION_CONFIG_ALIASES = dict(
    candidate_count="candidateCount",
    stop_sequences="stopSequences",
    max_output_tokens="maxOutputTokens",
    top_p="topP",
    top_k="topK",
    response_mime_type="responseMimeType",
    response_schema="responseSchema",
)
|
|
|
|
|
|
def merge_openai_chat_model_params(payload: dict[str, Any], model_params: dict[str, Any]) -> None:
    """Merge ``model_params`` into an OpenAI Chat Completions-style payload, in place.

    * Keys listed in ``OPENAI_CHAT_PASSTHROUGH_FIELDS`` are copied verbatim.
    * ``json_schema`` (a dict, or a JSON string that decodes to one) is
      translated into a non-strict ``response_format`` of type
      ``json_schema``. This happens whenever the payload does not already
      hold a non-``None`` ``response_format``, whether from the caller's
      ``model_params`` or from earlier payload construction.
    * Every other key, including the Claude/llm-connect-only knobs in
      ``DROPPED_NON_OPENAI_FIELDS``, is dropped. Blind merges are avoided
      because OpenAI-compatible providers commonly reject unknown top-level
      fields.

    When diagnostics are enabled, the payload before and after the merge is
    passed to ``record_adapter_transformation``.
    """
    before = json_safe(payload) if diagnostics_enabled() else None

    for key, value in model_params.items():
        if key in DROPPED_NON_OPENAI_FIELDS:
            continue
        if key in OPENAI_CHAT_PASSTHROUGH_FIELDS:
            payload[key] = value

    # Derive response_format only after the passthrough copy. Previously the
    # derived value was written first, so an explicit
    # ``response_format: None`` in model_params (treated as "not set") then
    # overwrote it with null and the schema was silently lost.
    # A non-None response_format from the caller or the payload still wins.
    schema = _coerce_json_schema(model_params.get("json_schema"))
    if schema is not None and payload.get("response_format") is None:
        payload["response_format"] = {
            "type": "json_schema",
            "json_schema": {
                "name": "structured_output",
                "schema": schema,
                "strict": False,
            },
        }

    if before is not None:
        record_adapter_transformation("merge_model_params.openai_chat", before, payload)
|
|
|
|
|
|
def merge_gemini_model_params(payload: dict[str, Any], model_params: dict[str, Any]) -> None:
    """Fold ``model_params`` into a Gemini ``generateContent`` payload, in place.

    Values inside ``generationConfig`` are applied in this order, each step
    able to overwrite the previous ones:

    1. whatever ``payload["generationConfig"]`` already holds (an empty dict
       is created when it is missing);
    2. ``json_schema`` -> ``responseMimeType="application/json"`` plus
       ``responseSchema``, skipped when a ``responseSchema`` already exists;
    3. a dict given as ``model_params["generationConfig"]``, merged wholesale;
    4. individual keys (snake_case aliases or camelCase names) that resolve
       into ``GEMINI_GENERATION_CONFIG_FIELDS``.

    Keys in ``GEMINI_TOP_LEVEL_FIELDS`` are set on the payload itself; all
    other keys are ignored. When diagnostics are enabled, the payload before
    and after the merge is passed to ``record_adapter_transformation``.
    """
    # Snapshot before setdefault, so the diagnostic diff shows any
    # generationConfig this function creates.
    snapshot = json_safe(payload) if diagnostics_enabled() else None
    gen_cfg = payload.setdefault("generationConfig", {})

    coerced = _coerce_json_schema(model_params.get("json_schema"))
    if coerced is not None and "responseSchema" not in gen_cfg:
        gen_cfg["responseMimeType"] = "application/json"
        gen_cfg["responseSchema"] = coerced

    override = model_params.get("generationConfig")
    if isinstance(override, dict):
        gen_cfg.update(override)

    # Keys handled above or not meaningful for Gemini. They are listed
    # explicitly even though the lookups below would not match them either.
    handled_or_unsupported = {"json_schema", "generationConfig", "reasoning_effort", "max_depth"}
    for name, setting in model_params.items():
        if name in handled_or_unsupported:
            continue
        if name in GEMINI_TOP_LEVEL_FIELDS:
            payload[name] = setting
        else:
            target = GEMINI_GENERATION_CONFIG_ALIASES.get(name, name)
            if target in GEMINI_GENERATION_CONFIG_FIELDS:
                gen_cfg[target] = setting

    if snapshot is None:
        return
    record_adapter_transformation("merge_model_params.gemini", snapshot, payload)
|
|
|
|
|
|
def _coerce_json_schema(schema: Any) -> dict[str, Any] | None:
|
|
if isinstance(schema, str):
|
|
try:
|
|
schema = json.loads(schema)
|
|
except (TypeError, ValueError):
|
|
return None
|
|
if isinstance(schema, dict):
|
|
return schema
|
|
return None
|