Files
llm-connect/llm_connect/grading.py
tegwick c4ad4bb9f2
Some checks failed
CI / test (3.10) (push) Has been cancelled
CI / test (3.11) (push) Has been cancelled
CI / test (3.12) (push) Has been cancelled
Add adaptive cost-quality routing primitives
2026-05-17 21:32:27 +02:00

240 lines
7.7 KiB
Python

"""Baseline grading primitives for adaptive routing.
Graders compare a candidate adapter response against a caller-chosen baseline.
They produce normalised quality scores that can be recorded in a
``QualityLedger`` and consumed later by adaptive routing policy.
"""
from __future__ import annotations

import json
import math
import re
from dataclasses import dataclass, field, replace
from typing import Any, Protocol

from llm_connect.adapter import LLMAdapter
from llm_connect.embedding_adapter import EmbeddingAdapter
from llm_connect.models import LLMResponse, RunConfig
from llm_connect.similarity import cosine_similarity


def _validate_score(value: float) -> float:
if not isinstance(value, (int, float)):
raise ValueError("quality_score must be a number between 0 and 1")
score = float(value)
if not 0 <= score <= 1:
raise ValueError("quality_score must be between 0 and 1")
return score
def _normalise_text(text: str) -> str:
return " ".join(text.strip().split())
@dataclass(frozen=True)
class GradingResult:
    """Immutable outcome of judging a candidate response against a baseline.

    Attributes:
        quality_score: Score checked by ``_validate_score`` and stored as the
            value that function returns.
        notes: Free-form explanation; coerced to ``str`` on construction.
        grader_id: Identifier of the grader that produced the result.
        baseline_response: The baseline response that was compared against.
        candidate_response: The candidate response that was scored.
    """

    quality_score: float
    notes: str
    grader_id: str
    baseline_response: LLMResponse
    candidate_response: LLMResponse

    def __post_init__(self) -> None:
        """Check ``grader_id`` and normalise ``quality_score`` and ``notes``.

        Raises:
            ValueError: If the string form of ``grader_id`` is blank, or if
                ``_validate_score`` rejects ``quality_score``.
        """
        grader_label = str(self.grader_id)
        if grader_label.strip() == "":
            raise ValueError("grader_id must be a non-empty string")

        # The dataclass is frozen, so normalised values have to be written
        # through object.__setattr__ instead of plain attribute assignment.
        store = object.__setattr__
        store(self, "quality_score", _validate_score(self.quality_score))
        store(self, "notes", str(self.notes))
class Judge(Protocol):
    """Structural interface for objects that score a candidate against a baseline."""

    # Identifier of the judging strategy; the judges defined in this module
    # copy it into the ``grader_id`` of every result they return.
    grader_id: str

    def judge(
        self,
        baseline_response: LLMResponse,
        candidate_response: LLMResponse,
        *,
        prompt: str,
        run_config: RunConfig,
    ) -> GradingResult:
        """Score ``candidate_response`` relative to ``baseline_response``.

        Args:
            baseline_response: Response treated as the reference.
            candidate_response: Response being evaluated.
            prompt: The prompt associated with the two responses (keyword-only).
            run_config: Run configuration made available to the judge
                (keyword-only).

        Returns:
            A ``GradingResult`` carrying the quality score.
        """
        ...
class BaselineGrader(Protocol):
    """Structural interface for graders that run two adapters and judge the pair."""

    def grade(
        self,
        baseline_adapter: LLMAdapter,
        candidate_adapter: LLMAdapter,
        prompt: str,
        run_config: RunConfig,
    ) -> GradingResult:
        """Grade the candidate adapter against the baseline adapter.

        Args:
            baseline_adapter: Adapter that supplies the reference response.
            candidate_adapter: Adapter whose response is being evaluated.
            prompt: Prompt to grade on.
            run_config: Run configuration for the grading run.

        Returns:
            A structured ``GradingResult``.
        """
        ...
@dataclass
class ExactMatchJudge:
    """Judge awarding 1.0 for identical response text (after optional normalisation), else 0.0.

    Attributes:
        normalize_whitespace: When true, both texts go through
            ``_normalise_text`` before being compared.
        case_sensitive: When false, both texts are case-folded before being
            compared.
        grader_id: Identifier copied into every ``GradingResult``.
    """

    normalize_whitespace: bool = True
    case_sensitive: bool = True
    grader_id: str = "exact-match"

    def judge(
        self,
        baseline_response: LLMResponse,
        candidate_response: LLMResponse,
        *,
        prompt: str,
        run_config: RunConfig,
    ) -> GradingResult:
        """Compare the two responses' ``content`` for equality.

        ``prompt`` and ``run_config`` are accepted to satisfy the ``Judge``
        interface and are not used by this judge.

        Returns:
            A ``GradingResult`` scored 1.0 on a match and 0.0 otherwise.
        """
        # Baseline first, candidate second; both get the same transforms.
        texts = [baseline_response.content, candidate_response.content]
        if self.normalize_whitespace:
            texts = [_normalise_text(item) for item in texts]
        if not self.case_sensitive:
            texts = [item.casefold() for item in texts]

        reference, attempt = texts
        if reference == attempt:
            score, verdict = 1.0, "exact match"
        else:
            score, verdict = 0.0, "candidate content differs from baseline"

        return GradingResult(
            quality_score=score,
            notes=verdict,
            grader_id=self.grader_id,
            baseline_response=baseline_response,
            candidate_response=candidate_response,
        )
@dataclass
class EmbeddingSimilarityJudge:
    """Judge that maps cosine similarity between response embeddings to 0..1.

    Attributes:
        embedding_adapter: Adapter whose ``embed`` method is called with the
            two response texts (baseline first, candidate second).
        grader_id: Identifier copied into every ``GradingResult``.
    """

    embedding_adapter: EmbeddingAdapter
    grader_id: str = "embedding-similarity"

    def judge(
        self,
        baseline_response: LLMResponse,
        candidate_response: LLMResponse,
        *,
        prompt: str,
        run_config: RunConfig,
    ) -> GradingResult:
        """Embed both responses and score the candidate by cosine similarity.

        ``prompt`` and ``run_config`` are accepted to satisfy the ``Judge``
        interface and are not used by this judge.

        Returns:
            A ``GradingResult`` whose score is the similarity clamped to
            ``[0, 1]`` and whose notes record the unclamped value.

        Raises:
            ValueError: If the adapter does not return exactly two embeddings,
                or if the computed similarity is NaN.
        """
        embeddings = self.embedding_adapter.embed(
            [baseline_response.content, candidate_response.content]
        )
        if len(embeddings) != 2:
            raise ValueError("EmbeddingSimilarityJudge expected exactly two embeddings")
        raw_similarity = cosine_similarity(embeddings[0], embeddings[1])
        quality_score = self._score_from_similarity(raw_similarity)
        return GradingResult(
            quality_score=quality_score,
            notes=f"cosine similarity {raw_similarity:.4f}",
            grader_id=self.grader_id,
            baseline_response=baseline_response,
            candidate_response=candidate_response,
        )

    @staticmethod
    def _score_from_similarity(raw_similarity: float) -> float:
        """Clamp a cosine similarity to ``[0, 1]``, rejecting NaN.

        Raises:
            ValueError: If ``raw_similarity`` is NaN.
        """
        # min(1.0, nan) evaluates to 1.0 because every comparison with NaN is
        # false, so an unchecked NaN would be recorded as a perfect score.
        if math.isnan(raw_similarity):
            raise ValueError("EmbeddingSimilarityJudge received a NaN cosine similarity")
        return max(0.0, min(1.0, raw_similarity))
@dataclass
class LLMJudge:
    """LLM-as-judge wrapper using a fixed rubric prompt and JSON response.

    Attributes:
        judge_adapter: Adapter used to execute the judging prompt.
        rubric: Instruction text placed at the top of the judging prompt.
        grader_id: Identifier copied into every ``GradingResult``.
        seed: Value used as the ``seed`` model parameter when the run config
            does not already define one; ``None`` leaves the parameters alone.
    """

    judge_adapter: LLMAdapter
    rubric: str = (
        "Compare the candidate response to the baseline response. "
        "Return JSON only with keys quality_score and notes. "
        "quality_score must be a number from 0 to 1."
    )
    grader_id: str = "llm-judge"
    seed: int | None = 0

    def judge(
        self,
        baseline_response: LLMResponse,
        candidate_response: LLMResponse,
        *,
        prompt: str,
        run_config: RunConfig,
    ) -> GradingResult:
        """Ask the judge adapter to score the candidate against the baseline.

        Args:
            baseline_response: Response treated as the reference.
            candidate_response: Response being evaluated.
            prompt: Original prompt, embedded in the judging prompt.
            run_config: Configuration the judging configuration is derived from.

        Returns:
            A ``GradingResult`` built from the judge's JSON reply.

        Raises:
            ValueError: If the reply contains no usable JSON object or an
                invalid ``quality_score``.
        """
        judge_prompt = self._build_prompt(prompt, baseline_response, candidate_response)
        judge_config = self._judge_config(run_config)
        response = self.judge_adapter.execute_prompt(judge_prompt, judge_config)
        parsed = self._parse_judge_response(response.content)
        return GradingResult(
            quality_score=parsed["quality_score"],
            notes=parsed["notes"],
            grader_id=self.grader_id,
            baseline_response=baseline_response,
            candidate_response=candidate_response,
        )

    def _judge_config(self, run_config: RunConfig) -> RunConfig:
        """Derive the configuration used for the judging call.

        Copies ``model_params``, adds ``seed`` only when the caller has not
        set one, forces ``temperature`` to 0.0 and clears ``budget_tracker``.
        The incoming ``run_config`` is not modified.
        """
        params: dict[str, Any] = dict(run_config.model_params)
        if self.seed is not None:
            params.setdefault("seed", self.seed)
        return replace(run_config, temperature=0.0, model_params=params, budget_tracker=None)

    def _build_prompt(
        self,
        prompt: str,
        baseline_response: LLMResponse,
        candidate_response: LLMResponse,
    ) -> str:
        """Assemble the rubric, original prompt and both responses into one prompt."""
        return (
            f"{self.rubric}\n\n"
            f"Original prompt:\n{prompt}\n\n"
            f"Baseline response:\n{baseline_response.content}\n\n"
            f"Candidate response:\n{candidate_response.content}\n"
        )

    def _parse_judge_response(self, content: str) -> dict[str, Any]:
        """Extract ``quality_score`` and ``notes`` from the judge's reply.

        Returns:
            A dict with a validated ``quality_score`` and ``notes`` as ``str``
            (empty when the key is absent).

        Raises:
            ValueError: If no JSON can be found or parsed, if the JSON is not
                an object, or if ``quality_score`` is invalid.
        """
        data = self._extract_json(content)
        if not isinstance(data, dict):
            raise ValueError("LLMJudge response JSON must be an object")
        return {
            "quality_score": _validate_score(data.get("quality_score")),
            "notes": str(data.get("notes", "")),
        }

    @staticmethod
    def _extract_json(content: str) -> Any:
        """Return the JSON value carried by ``content``.

        The whole string is tried first. If that fails, each ``{`` is tried in
        turn as the start of an embedded object and the first one that decodes
        is returned, so stray braces in surrounding prose or code fences do
        not hide a valid object.

        Raises:
            ValueError: If ``content`` has no ``{`` followed by a ``}``, or if
                no embedded object can be decoded.
        """
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            pass

        start = content.find("{")
        if start == -1 or content.find("}", start) == -1:
            raise ValueError("LLMJudge response did not contain JSON")

        decoder = json.JSONDecoder()
        last_error: json.JSONDecodeError | None = None
        while start != -1:
            try:
                # raw_decode stops at the end of the first complete value, so
                # trailing text after the object is tolerated.
                data, _ = decoder.raw_decode(content, start)
            except json.JSONDecodeError as exc:
                last_error = exc
                start = content.find("{", start + 1)
            else:
                return data
        raise ValueError("LLMJudge response JSON could not be parsed") from last_error
@dataclass
class PairedGrader:
    """Grader that runs one prompt on two adapters and hands both replies to a judge.

    Attributes:
        judge: Comparison strategy; a fresh ``ExactMatchJudge`` is created
            when none is supplied.
    """

    judge: Judge = field(default_factory=ExactMatchJudge)

    def grade(
        self,
        baseline_adapter: LLMAdapter,
        candidate_adapter: LLMAdapter,
        prompt: str,
        run_config: RunConfig,
    ) -> GradingResult:
        """Execute ``prompt`` on both adapters and return the judge's verdict.

        Args:
            baseline_adapter: Adapter producing the reference response.
            candidate_adapter: Adapter producing the response under evaluation.
            prompt: Prompt sent unchanged to both adapters.
            run_config: Configuration passed to both adapters and to the judge.

        Returns:
            The ``GradingResult`` produced by ``self.judge``.
        """
        # Baseline is executed first, then the candidate, with the same config.
        responses = [
            adapter.execute_prompt(prompt, run_config)
            for adapter in (baseline_adapter, candidate_adapter)
        ]
        return self.judge.judge(*responses, prompt=prompt, run_config=run_config)