generated from coulomb/repo-seed
T03 — wrap_with_shadow_sampling() helper in routing.py: builds a llm-connect ShadowingAdapter around any candidate LLMAdapter with a caller-supplied baseline, grader, and QualityLedger. async_shadow=True by default so production load is not doubled; on_shadow_error escape hatch keeps caller logs informed when a baseline outage swallows the shadow path. The returned adapter is still an LLMAdapter so it slots into a RoutingPolicy rule without further code change. T04 — generation report enrichment plus a small CLI helper: - _collect_adapter_choices walks artifact provenance, groups by (stage_id, adapter_id), and surfaces calls + prompt/completion tokens per (stage, adapter) pair in a new ## Per-stage adapter choices section. Runs that did not go through the bridge have no provider_metadata.adapter_id and emit an empty list, so fixture-only reports stay terse. - summarise_quality_ledger() rolls a llm-connect QualityLedger up by (task_type, adapter_id) with mean quality, mean cost, observations, and cumulative tokens. - infospace-bench routing ledger <path> CLI prints the rollup as JSON. Five new tests cover shadow happy-path, shadow failure isolation, ledger rollup, the routing CLI, and the report's adapter-choice aggregation. Closes IB-WP-0018: T01-T05 are all done and the workplan status flips from blocked to done now that LLM-WP-0004's primitives have shipped. 144 tests pass, 1 skipped (the OpenRouter live smoke, gated as before). Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
427 lines
15 KiB
Python
427 lines
15 KiB
Python
"""
Tests for the routing bridge that wraps an llm-connect RoutingPolicy as
an infospace-bench AssistedGenerationAdapter (IB-WP-0018 T02/T05).

All tests use mocked llm-connect ``LLMAdapter`` instances — no network.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
import pytest
|
|
|
|
from llm_connect.adapter import LLMAdapter
|
|
from llm_connect.models import LLMResponse, RunConfig
|
|
from llm_connect.routing import (
|
|
AdaptiveRoutingPolicy,
|
|
RoutingPolicy,
|
|
RoutingRule,
|
|
)
|
|
from llm_connect.quality import QualityLedger, QualityObservation
|
|
|
|
from infospace_bench.routing import (
|
|
STAGE_TO_TASK_TYPE_DEFAULT,
|
|
RoutingAssistedGenerationAdapter,
|
|
)
|
|
from infospace_bench.workflow import AssistedGenerationRequest
|
|
|
|
|
|
class _MockAdapter(LLMAdapter):
    """In-memory stand-in for an llm-connect adapter.

    Every ``execute_prompt`` call is recorded in ``calls`` and answered with
    a canned ``LLMResponse`` built from the constructor arguments.
    """

    def __init__(self, *, model: str, content: str = "ok", cost_per_call: float = 0.0) -> None:
        # (prompt, config) pairs in the order they were received.
        self.calls: list[tuple[str, RunConfig]] = []
        self.model = model
        self._content = content
        self._cost_per_call = cost_per_call

    def execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
        """Record the call and return the configured response."""
        self.calls.append((prompt, config))
        token_usage = {"prompt_tokens": 100, "completion_tokens": 50}
        extra = {"cost_per_call_usd": self._cost_per_call}
        return LLMResponse(
            content=self._content,
            model=self.model,
            usage=token_usage,
            finish_reason="stop",
            metadata=extra,
        )

    def validate_config(self, config: RunConfig) -> bool:  # pragma: no cover - trivial
        """Accept every configuration."""
        return True
|
|
|
|
|
|
def _request(stage_id: str, prompt: str = "Hello.") -> AssistedGenerationRequest:
    """Build a request for *stage_id* using the fixed test workflow and source artifact."""
    fields = {
        "stage_id": stage_id,
        "workflow_id": "generic-source-entities",
        "input_artifact_id": "source/test.md",
        "prompt": prompt,
    }
    return AssistedGenerationRequest(**fields)
|
|
|
|
|
|
def test_bridge_resolves_static_policy_per_stage() -> None:
    """Each stage is served by the adapter its static routing rule prefers."""
    cheap = _MockAdapter(model="cheap-1", content="# Cheap")
    smart = _MockAdapter(model="smart-1", content="# Smart")
    rules = [
        RoutingRule(task_type="summarize-source", prefer=cheap),
        RoutingRule(task_type="extract-entities", prefer=smart),
    ]
    bridge = RoutingAssistedGenerationAdapter(
        policy=RoutingPolicy(rules=rules, default=cheap),
    )

    summary = bridge.generate(_request("summarize-source", "Source A"))
    entities = bridge.generate(_request("extract-entities", "Source A"))

    # The summary stage went to the cheap adapter and carries its metadata.
    assert summary.markdown == "# Cheap"
    assert summary.metadata["task_type"] == "summarize-source"
    assert summary.metadata["model"] == "cheap-1"
    assert summary.metadata["usage"]["prompt_tokens"] == 100

    # The entity stage went to the smart adapter.
    assert entities.markdown == "# Smart"
    assert entities.metadata["model"] == "smart-1"

    # Exactly one call reached each adapter.
    assert len(cheap.calls) == 1
    assert len(smart.calls) == 1
|
|
|
|
|
|
def test_bridge_honours_stage_to_task_type_overrides() -> None:
    """A caller-supplied stage map can fold several stages onto one task type."""
    extraction = _MockAdapter(model="extraction-1")
    stages = ("extract-entities", "extract-relations")
    bridge = RoutingAssistedGenerationAdapter(
        policy=RoutingPolicy(
            rules=[RoutingRule(task_type="extraction", prefer=extraction)],
        ),
        # Both extraction stages map onto the single "extraction" task type.
        stage_to_task_type=dict.fromkeys(stages, "extraction"),
    )

    for stage_id in stages:
        bridge.generate(_request(stage_id))

    assert len(extraction.calls) == 2
|
|
|
|
|
|
def test_bridge_default_task_type_map_covers_all_known_stages() -> None:
    """The default stage map has exactly the known stages, each mapped to itself."""
    known_stages = {
        "summarize-source",
        "extract-entities",
        "extract-relations",
        "evaluate-entity",
        "synthesize-report",
    }
    assert set(STAGE_TO_TASK_TYPE_DEFAULT) == known_stages
    # Identity mapping by default
    for stage_id in known_stages:
        assert STAGE_TO_TASK_TYPE_DEFAULT[stage_id] == stage_id
|
|
|
|
|
|
def test_bridge_falls_through_to_stage_id_when_no_known_mapping() -> None:
    """A stage id missing from the stage map is used verbatim as the task type."""
    custom_adapter = _MockAdapter(model="custom-1")
    rule = RoutingRule(task_type="custom-stage", prefer=custom_adapter)
    bridge = RoutingAssistedGenerationAdapter(policy=RoutingPolicy(rules=[rule]))

    result = bridge.generate(_request("custom-stage"))

    # Either form of the mock's default content is accepted.
    assert result.markdown in ("# ok", "ok")
    assert custom_adapter.calls, "custom stage_id should fall through to the same task_type"
|
|
|
|
|
|
def test_bridge_uses_adaptive_path_when_quality_floor_set(tmp_path) -> None:
    """With a ``quality_floor`` set, the cheap adapter is called instead of the rule's preferred one."""
    cheap = _MockAdapter(model="cheap-1")
    smart = _MockAdapter(model="smart-1")
    ledger = QualityLedger(path=tmp_path / "quality.jsonl")

    # Cheap clears the floor; smart does too but at a higher cost.
    # NOTE(review): both appends are assumed to sit inside the loop (three
    # observations per adapter) — confirm against the original indentation.
    history = (("cheap-1", 0.001, 0.9), ("smart-1", 0.01, 0.95))
    for _ in range(3):
        for adapter_id, cost_usd, quality_score in history:
            ledger.append(
                QualityObservation(
                    task_type="extract-entities",
                    adapter_id=adapter_id,
                    model_id=adapter_id,
                    cost_usd=cost_usd,
                    quality_score=quality_score,
                    tokens_in=100,
                    tokens_out=50,
                    latency_ms=10,
                )
            )

    # The static rule prefers the smart adapter; the ledger is what lets the
    # adaptive path choose differently.
    adaptive_policy = AdaptiveRoutingPolicy(
        rules=[RoutingRule(task_type="extract-entities", prefer=smart)],
        default=cheap,
        ledger=ledger,
        adapters_by_id={"cheap-1": cheap, "smart-1": smart},
    )
    bridge = RoutingAssistedGenerationAdapter(policy=adaptive_policy, quality_floor=0.8)

    bridge.generate(_request("extract-entities"))

    assert cheap.calls, "adaptive policy should pick the cheaper qualifying adapter"
    assert not smart.calls
|
|
|
|
|
|
def test_bridge_falls_back_to_static_when_quality_floor_unset(tmp_path) -> None:
    """Without a ``quality_floor`` the bridge resolves through the static rules.

    The cheap adapter is registered as the policy default and in
    ``adapters_by_id`` so that it is a genuine routing candidate; if it were
    not wired into the policy, the ``not cheap.calls`` assertion could never
    fail and would verify nothing.
    """
    cheap = _MockAdapter(model="cheap-1")
    smart = _MockAdapter(model="smart-1")
    ledger = QualityLedger(path=tmp_path / "quality.jsonl")

    policy = AdaptiveRoutingPolicy(
        rules=[RoutingRule(task_type="extract-entities", prefer=smart)],
        default=cheap,
        ledger=ledger,
        adapters_by_id={"cheap-1": cheap, "smart-1": smart},
    )
    bridge = RoutingAssistedGenerationAdapter(policy=policy)  # no quality_floor

    bridge.generate(_request("extract-entities"))

    # The static rule for this task type prefers the smart adapter.
    assert smart.calls, "without a quality_floor the bridge must use static routing"
    assert not cheap.calls
|
|
|
|
|
|
def test_bridge_preserves_response_metadata_and_provider_tag() -> None:
    """Adapter response metadata is carried onto the result along with routing tags."""
    adapter = _MockAdapter(model="cheap-1")

    def _canned_response(prompt, config):
        # Replaces the mock's default response with one carrying a request id.
        return LLMResponse(
            content="# ok",
            model="cheap-1",
            usage={"prompt_tokens": 7, "completion_tokens": 3},
            finish_reason="stop",
            metadata={"request_id": "req-42"},
        )

    adapter.execute_prompt = _canned_response  # type: ignore[assignment]
    rule = RoutingRule(task_type="custom", prefer=adapter)
    bridge = RoutingAssistedGenerationAdapter(policy=RoutingPolicy(rules=[rule]))

    result = bridge.generate(_request("custom"))

    assert result.metadata["request_id"] == "req-42"
    assert result.metadata["usage"] == {"prompt_tokens": 7, "completion_tokens": 3}
    assert result.metadata["task_type"] == "custom"
    assert result.metadata["adapter_id"].endswith(":cheap-1")
    assert result.provider == "mock"
|
|
|
|
|
|
def test_wrap_with_shadow_sampling_passes_candidate_through(tmp_path) -> None:
    """The wrapper returns the candidate's response and logs a shadow observation."""
    from llm_connect.grading import ExactMatchJudge, PairedGrader
    from infospace_bench.routing import wrap_with_shadow_sampling

    ledger = QualityLedger(path=tmp_path / "quality.jsonl")
    candidate = _MockAdapter(model="cheap-1", content="match")
    baseline = _MockAdapter(model="baseline-1", content="match")

    # shadow_rate=1.0 with async_shadow=False, so the assertions below can
    # inspect the baseline and the ledger right after the call.
    shadow = wrap_with_shadow_sampling(
        candidate=candidate,
        baseline=baseline,
        grader=PairedGrader(judge=ExactMatchJudge()),
        ledger=ledger,
        task_type="extract-entities",
        shadow_rate=1.0,
        async_shadow=False,
    )

    response = shadow.execute_prompt("Hello.", RunConfig(model_name="cheap-1"))

    assert response.content == "match"
    # Baseline ran in the shadow path; ledger now has one observation.
    assert baseline.calls, "baseline must have been called when shadow_rate=1.0"
    recorded = ledger.by_task_type("extract-entities")
    assert recorded, "shadow path should append at least one observation"
|
|
|
|
|
|
def test_wrap_with_shadow_sampling_isolates_baseline_failure(tmp_path) -> None:
    """A raising baseline is reported via ``on_shadow_error`` without losing the candidate response."""
    from llm_connect.grading import ExactMatchJudge, PairedGrader
    from infospace_bench.routing import wrap_with_shadow_sampling

    class _AngryBaseline(LLMAdapter):
        """Baseline double whose every prompt execution raises."""

        def execute_prompt(self, prompt, config):
            raise RuntimeError("baseline outage")

        def validate_config(self, config):
            return True

    seen_errors: list[Exception] = []
    candidate = _MockAdapter(model="cheap-1", content="ok")
    wrapper_kwargs = {
        "candidate": candidate,
        "baseline": _AngryBaseline(),
        "grader": PairedGrader(judge=ExactMatchJudge()),
        "ledger": QualityLedger(path=tmp_path / "quality.jsonl"),
        "task_type": "summarize-source",
        "shadow_rate": 1.0,
        "async_shadow": False,
        # Collect whatever the wrapper reports instead of letting it vanish.
        "on_shadow_error": seen_errors.append,
    }
    shadow = wrap_with_shadow_sampling(**wrapper_kwargs)

    response = shadow.execute_prompt("Hello.", RunConfig(model_name="cheap-1"))

    assert response.content == "ok", "candidate response must survive baseline outage"
    assert seen_errors and "baseline outage" in str(seen_errors[0])
|
|
|
|
|
|
def test_summarise_quality_ledger_rolls_up_by_task_and_adapter(tmp_path) -> None:
    """The rollup yields one row per (task_type, adapter_id) with counts and means."""
    from infospace_bench.routing import summarise_quality_ledger

    ledger_path = tmp_path / "quality.jsonl"
    ledger = QualityLedger(path=ledger_path)

    # Three observations for the (extract-entities, cheap-1) pair ...
    qualities = (0.9, 0.95, 0.85)
    for score in qualities:
        ledger.append(
            QualityObservation(
                task_type="extract-entities",
                adapter_id="cheap-1",
                model_id="cheap-1",
                cost_usd=0.001,
                quality_score=score,
                tokens_in=100,
                tokens_out=50,
                latency_ms=10,
            )
        )
    # ... and a single one for (summarize-source, cheaper-1).
    lone_observation = QualityObservation(
        task_type="summarize-source",
        adapter_id="cheaper-1",
        model_id="cheaper-1",
        cost_usd=0.0001,
        quality_score=0.7,
        tokens_in=80,
        tokens_out=20,
        latency_ms=5,
    )
    ledger.append(lone_observation)

    rows = summarise_quality_ledger(ledger_path)

    index = {}
    for row in rows:
        index[(row["task_type"], row["adapter_id"])] = row

    extract = index[("extract-entities", "cheap-1")]
    assert extract["observations"] == 3
    assert extract["mean_quality"] == round((0.9 + 0.95 + 0.85) / 3, 4)
    assert extract["mean_cost_usd"] == 0.001

    summarize = index[("summarize-source", "cheaper-1")]
    assert summarize["observations"] == 1
|
|
|
|
|
|
def test_collect_adapter_choices_rolls_up_per_stage(tmp_path) -> None:
    """Unit test: the report helper groups adapter usage by (stage_id, adapter_id)."""
    from infospace_bench.generator import _collect_adapter_choices

    class _FakeArtifact:
        """Minimal artifact exposing only ``kind`` and ``provenance``."""

        def __init__(self, kind: str, provenance: dict) -> None:
            self.kind = kind
            self.provenance = provenance

    def _routed(kind, stage_id, adapter_id, prompt_tokens, completion_tokens):
        # An artifact whose provenance carries provider metadata.
        return _FakeArtifact(
            kind=kind,
            provenance={
                "stage_id": stage_id,
                "provider_metadata": {
                    "adapter_id": adapter_id,
                    "task_type": stage_id,
                    "usage": {
                        "prompt_tokens": prompt_tokens,
                        "completion_tokens": completion_tokens,
                    },
                },
            },
        )

    artifacts = [
        _routed("entity", "extract-entities", "_MockAdapter:cheap-1", 120, 40),
        _routed("entity", "extract-entities", "_MockAdapter:cheap-1", 130, 50),
        _routed("relation", "extract-relations", "_MockAdapter:smart-1", 200, 80),
        # Artifact without provider_metadata should be ignored.
        _FakeArtifact(kind="generated", provenance={"stage_id": "summarize-source"}),
    ]

    rows = _collect_adapter_choices(artifacts)

    index = {}
    for row in rows:
        index[(row["stage_id"], row["adapter_id"])] = row
    entities_row = index[("extract-entities", "_MockAdapter:cheap-1")]
    relations_row = index[("extract-relations", "_MockAdapter:smart-1")]

    # The two entity artifacts collapse into one row with summed token counts.
    assert entities_row["calls"] == 2
    assert entities_row["prompt_tokens"] == 250
    assert entities_row["completion_tokens"] == 90
    assert relations_row["calls"] == 1
    assert relations_row["task_type"] == "extract-relations"
|
|
|
|
|
|
def test_routing_ledger_cli(tmp_path) -> None:
    """``python -m infospace_bench routing ledger <path>`` prints the rollup as JSON.

    The command runs in a child interpreter. The child's ``PYTHONPATH`` is
    derived from this process's ``sys.path`` instead of hard-coded,
    machine-specific directories, so the test works on any checkout where
    the test module itself is importable.
    """
    import json as _json
    import os as _os
    import subprocess as _sub
    import sys as _sys

    ledger_path = tmp_path / "quality.jsonl"
    ledger = QualityLedger(path=ledger_path)
    ledger.append(
        QualityObservation(
            task_type="extract-entities",
            adapter_id="cheap-1",
            model_id="cheap-1",
            cost_usd=0.001,
            quality_score=0.9,
            tokens_in=100,
            tokens_out=50,
            latency_ms=10,
        )
    )

    env = _os.environ.copy()
    # Mirror the parent's import path. Empty entries (meaning "current
    # directory") are dropped; os.pathsep keeps the value valid on every
    # platform.
    env["PYTHONPATH"] = _os.pathsep.join(entry for entry in _sys.path if entry)
    result = _sub.run(
        [_sys.executable, "-m", "infospace_bench", "routing", "ledger", str(ledger_path)],
        check=False,
        env=env,
        text=True,
        capture_output=True,
    )

    # Surface the child's stderr when the command fails.
    assert result.returncode == 0, result.stderr
    payload = _json.loads(result.stdout)
    assert payload["ledger_path"] == str(ledger_path)
    assert payload["rows"] and payload["rows"][0]["task_type"] == "extract-entities"
|
|
|
|
|
|
def test_bridge_passes_estimated_cost_per_1k_through() -> None:
    """The bridge forwards its ``estimated_cost_per_1k`` to ``policy.resolve``."""
    seen: dict[str, Any] = {}

    class _PolicyProbe(RoutingPolicy):
        """Policy double that records the arguments passed to ``resolve``."""

        def resolve(self, task_type, estimated_cost_per_1k=None):  # type: ignore[override]
            seen.update(
                task_type=task_type,
                estimated_cost_per_1k=estimated_cost_per_1k,
            )
            return _MockAdapter(model="x")

    probe = _PolicyProbe()
    bridge = RoutingAssistedGenerationAdapter(policy=probe, estimated_cost_per_1k=0.5)
    bridge.generate(_request("summarize-source"))

    assert seen["task_type"] == "summarize-source"
    assert seen["estimated_cost_per_1k"] == 0.5
|