"""
Tests for the routing bridge that wraps an llm-connect RoutingPolicy as an
infospace-bench AssistedGenerationAdapter (IB-WP-0018 T02/T05).

All tests use mocked llm-connect ``LLMAdapter`` instances — no network.
"""

from __future__ import annotations

import json
import os
import subprocess
import sys
from typing import Any

import pytest

from llm_connect.adapter import LLMAdapter
from llm_connect.models import LLMResponse, RunConfig
from llm_connect.quality import QualityLedger, QualityObservation
from llm_connect.routing import (
    AdaptiveRoutingPolicy,
    RoutingPolicy,
    RoutingRule,
)

from infospace_bench.routing import (
    STAGE_TO_TASK_TYPE_DEFAULT,
    RoutingAssistedGenerationAdapter,
)
from infospace_bench.workflow import AssistedGenerationRequest


class _MockAdapter(LLMAdapter):
    """Test double: returns a configured ``LLMResponse`` and records calls."""

    def __init__(self, *, model: str, content: str = "ok", cost_per_call: float = 0.0) -> None:
        self.model = model
        self._content = content
        self._cost_per_call = cost_per_call
        # Every (prompt, config) pair passed to execute_prompt, in call order.
        self.calls: list[tuple[str, RunConfig]] = []

    def execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
        self.calls.append((prompt, config))
        return LLMResponse(
            content=self._content,
            model=self.model,
            usage={"prompt_tokens": 100, "completion_tokens": 50},
            finish_reason="stop",
            metadata={"cost_per_call_usd": self._cost_per_call},
        )

    def validate_config(self, config: RunConfig) -> bool:  # pragma: no cover - trivial
        return True


def _request(stage_id: str, prompt: str = "Hello.") -> AssistedGenerationRequest:
    """Build a minimal generation request for ``stage_id``."""
    return AssistedGenerationRequest(
        stage_id=stage_id,
        workflow_id="generic-source-entities",
        input_artifact_id="source/test.md",
        prompt=prompt,
    )


def _observation(
    task_type: str,
    adapter_id: str,
    *,
    cost_usd: float,
    quality_score: float,
    tokens_in: int = 100,
    tokens_out: int = 50,
    latency_ms: int = 10,
) -> QualityObservation:
    """Build a ledger observation whose ``model_id`` equals ``adapter_id``."""
    return QualityObservation(
        task_type=task_type,
        adapter_id=adapter_id,
        model_id=adapter_id,
        cost_usd=cost_usd,
        quality_score=quality_score,
        tokens_in=tokens_in,
        tokens_out=tokens_out,
        latency_ms=latency_ms,
    )


def _seed_cheap_vs_smart(ledger: QualityLedger) -> None:
    """Record history where both adapters clear a 0.8 floor but cheap-1 costs less."""
    for _ in range(3):
        ledger.append(_observation("extract-entities", "cheap-1", cost_usd=0.001, quality_score=0.9))
        ledger.append(_observation("extract-entities", "smart-1", cost_usd=0.01, quality_score=0.95))


def _subprocess_env() -> dict[str, str]:
    """Return a copy of the environment whose PYTHONPATH mirrors ``sys.path``.

    The CLI runs in a child interpreter that must import the same
    ``infospace_bench`` / ``llm_connect`` packages this test process already
    imported. Reusing ``sys.path`` avoids hard-coding machine-specific paths
    and uses ``os.pathsep`` so it also works off POSIX.
    """
    env = os.environ.copy()
    # dict.fromkeys de-duplicates while preserving order; "" (cwd) is dropped.
    entries = dict.fromkeys(entry for entry in sys.path if entry)
    env["PYTHONPATH"] = os.pathsep.join(entries)
    return env


def test_bridge_resolves_static_policy_per_stage() -> None:
    """Each stage is routed through its own rule and response details are surfaced."""
    cheap = _MockAdapter(model="cheap-1", content="# Cheap")
    smart = _MockAdapter(model="smart-1", content="# Smart")
    policy = RoutingPolicy(
        rules=[
            RoutingRule(task_type="summarize-source", prefer=cheap),
            RoutingRule(task_type="extract-entities", prefer=smart),
        ],
        default=cheap,
    )
    bridge = RoutingAssistedGenerationAdapter(policy=policy)

    summary = bridge.generate(_request("summarize-source", "Source A"))
    entities = bridge.generate(_request("extract-entities", "Source A"))

    assert summary.markdown == "# Cheap"
    assert summary.metadata["task_type"] == "summarize-source"
    assert summary.metadata["model"] == "cheap-1"
    assert summary.metadata["usage"]["prompt_tokens"] == 100
    assert entities.markdown == "# Smart"
    assert entities.metadata["model"] == "smart-1"
    assert len(cheap.calls) == 1
    assert len(smart.calls) == 1


def test_bridge_honours_stage_to_task_type_overrides() -> None:
    """A custom stage→task_type map lets several stages share one rule."""
    extraction = _MockAdapter(model="extraction-1")
    policy = RoutingPolicy(
        rules=[RoutingRule(task_type="extraction", prefer=extraction)],
    )
    bridge = RoutingAssistedGenerationAdapter(
        policy=policy,
        stage_to_task_type={
            "extract-entities": "extraction",
            "extract-relations": "extraction",
        },
    )

    bridge.generate(_request("extract-entities"))
    bridge.generate(_request("extract-relations"))

    assert len(extraction.calls) == 2


def test_bridge_default_task_type_map_covers_all_known_stages() -> None:
    """The default map lists every known stage and maps each to itself."""
    expected = {
        "summarize-source",
        "extract-entities",
        "extract-relations",
        "evaluate-entity",
        "synthesize-report",
    }
    assert set(STAGE_TO_TASK_TYPE_DEFAULT) == expected
    # Identity mapping by default
    for stage in expected:
        assert STAGE_TO_TASK_TYPE_DEFAULT[stage] == stage


def test_bridge_falls_through_to_stage_id_when_no_known_mapping() -> None:
    """An unmapped stage_id is used verbatim as the task_type."""
    custom_adapter = _MockAdapter(model="custom-1")
    policy = RoutingPolicy(
        rules=[RoutingRule(task_type="custom-stage", prefer=custom_adapter)],
    )
    bridge = RoutingAssistedGenerationAdapter(policy=policy)

    result = bridge.generate(_request("custom-stage"))

    assert result.markdown in {"# ok", "ok"}
    assert custom_adapter.calls, "custom stage_id should fall through to the same task_type"


def test_bridge_uses_adaptive_path_when_quality_floor_set(tmp_path) -> None:
    """With a quality floor, the cheapest adapter that clears it wins over the static rule."""
    cheap = _MockAdapter(model="cheap-1")
    smart = _MockAdapter(model="smart-1")
    ledger = QualityLedger(path=tmp_path / "quality.jsonl")
    # Cheap clears the floor; smart does too but at a higher cost.
    _seed_cheap_vs_smart(ledger)
    policy = AdaptiveRoutingPolicy(
        rules=[RoutingRule(task_type="extract-entities", prefer=smart)],
        default=cheap,
        ledger=ledger,
        adapters_by_id={"cheap-1": cheap, "smart-1": smart},
    )
    bridge = RoutingAssistedGenerationAdapter(policy=policy, quality_floor=0.8)

    bridge.generate(_request("extract-entities"))

    assert cheap.calls, "adaptive policy should pick the cheaper qualifying adapter"
    assert not smart.calls


def test_bridge_falls_back_to_static_when_quality_floor_unset(tmp_path) -> None:
    """Without a quality floor the static rule wins, even though the ledger favours cheap."""
    cheap = _MockAdapter(model="cheap-1")
    smart = _MockAdapter(model="smart-1")
    ledger = QualityLedger(path=tmp_path / "quality.jsonl")
    # Same history as the adaptive test. If the bridge took the adaptive path,
    # cheap would be chosen, so this makes `not cheap.calls` a real check.
    _seed_cheap_vs_smart(ledger)
    policy = AdaptiveRoutingPolicy(
        rules=[RoutingRule(task_type="extract-entities", prefer=smart)],
        default=cheap,
        ledger=ledger,
        adapters_by_id={"cheap-1": cheap, "smart-1": smart},
    )
    bridge = RoutingAssistedGenerationAdapter(policy=policy)  # no quality_floor

    bridge.generate(_request("extract-entities"))

    assert smart.calls, "without a quality_floor the bridge must use static routing"
    assert not cheap.calls


def test_bridge_preserves_response_metadata_and_provider_tag() -> None:
    """Provider metadata from the response is merged with the bridge's own keys."""
    adapter = _MockAdapter(model="cheap-1")
    adapter.execute_prompt = lambda prompt, config: LLMResponse(  # type: ignore[assignment]
        content="# ok",
        model="cheap-1",
        usage={"prompt_tokens": 7, "completion_tokens": 3},
        finish_reason="stop",
        metadata={"request_id": "req-42"},
    )
    policy = RoutingPolicy(rules=[RoutingRule(task_type="custom", prefer=adapter)])
    bridge = RoutingAssistedGenerationAdapter(policy=policy)

    result = bridge.generate(_request("custom"))

    assert result.metadata["request_id"] == "req-42"
    assert result.metadata["usage"] == {"prompt_tokens": 7, "completion_tokens": 3}
    assert result.metadata["task_type"] == "custom"
    assert result.metadata["adapter_id"].endswith(":cheap-1")
    assert result.provider == "mock"


def test_wrap_with_shadow_sampling_passes_candidate_through(tmp_path) -> None:
    """The candidate's response is returned and the baseline is graded into the ledger."""
    from llm_connect.grading import ExactMatchJudge, PairedGrader

    from infospace_bench.routing import wrap_with_shadow_sampling

    candidate = _MockAdapter(model="cheap-1", content="match")
    baseline = _MockAdapter(model="baseline-1", content="match")
    ledger = QualityLedger(path=tmp_path / "quality.jsonl")
    grader = PairedGrader(judge=ExactMatchJudge())

    shadow = wrap_with_shadow_sampling(
        candidate=candidate,
        baseline=baseline,
        grader=grader,
        ledger=ledger,
        task_type="extract-entities",
        shadow_rate=1.0,
        async_shadow=False,
    )
    config = RunConfig(model_name="cheap-1")
    response = shadow.execute_prompt("Hello.", config)

    assert response.content == "match"
    # Baseline ran in the shadow path; ledger now has at least one observation.
    assert baseline.calls, "baseline must have been called when shadow_rate=1.0"
    observations = ledger.by_task_type("extract-entities")
    assert observations, "shadow path should append at least one observation"


def test_wrap_with_shadow_sampling_isolates_baseline_failure(tmp_path) -> None:
    """A failing baseline is reported to on_shadow_error and never breaks the candidate."""
    from llm_connect.grading import ExactMatchJudge, PairedGrader

    from infospace_bench.routing import wrap_with_shadow_sampling

    candidate = _MockAdapter(model="cheap-1", content="ok")

    class _AngryBaseline(LLMAdapter):
        def execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
            raise RuntimeError("baseline outage")

        def validate_config(self, config: RunConfig) -> bool:
            return True

    seen_errors: list[Exception] = []
    shadow = wrap_with_shadow_sampling(
        candidate=candidate,
        baseline=_AngryBaseline(),
        grader=PairedGrader(judge=ExactMatchJudge()),
        ledger=QualityLedger(path=tmp_path / "quality.jsonl"),
        task_type="summarize-source",
        shadow_rate=1.0,
        async_shadow=False,
        on_shadow_error=seen_errors.append,
    )

    response = shadow.execute_prompt("Hello.", RunConfig(model_name="cheap-1"))

    assert response.content == "ok", "candidate response must survive baseline outage"
    assert seen_errors and "baseline outage" in str(seen_errors[0])


def test_summarise_quality_ledger_rolls_up_by_task_and_adapter(tmp_path) -> None:
    """Ledger rows are grouped per (task_type, adapter_id) with count and means."""
    from infospace_bench.routing import summarise_quality_ledger

    ledger_path = tmp_path / "quality.jsonl"
    ledger = QualityLedger(path=ledger_path)
    for quality in (0.9, 0.95, 0.85):
        ledger.append(
            _observation("extract-entities", "cheap-1", cost_usd=0.001, quality_score=quality)
        )
    ledger.append(
        _observation(
            "summarize-source",
            "cheaper-1",
            cost_usd=0.0001,
            quality_score=0.7,
            tokens_in=80,
            tokens_out=20,
            latency_ms=5,
        )
    )

    rows = summarise_quality_ledger(ledger_path)
    by_key = {(row["task_type"], row["adapter_id"]): row for row in rows}

    extract = by_key[("extract-entities", "cheap-1")]
    assert extract["observations"] == 3
    # Compare float means with a tolerance rather than exact equality.
    assert extract["mean_quality"] == pytest.approx(round((0.9 + 0.95 + 0.85) / 3, 4))
    assert extract["mean_cost_usd"] == pytest.approx(0.001)
    summarize = by_key[("summarize-source", "cheaper-1")]
    assert summarize["observations"] == 1


def test_collect_adapter_choices_rolls_up_per_stage(tmp_path) -> None:
    """Unit test: report helper aggregates adapter choices from artifact provenance."""
    from infospace_bench.generator import _collect_adapter_choices

    class _FakeArtifact:
        def __init__(self, kind: str, provenance: dict[str, Any]) -> None:
            self.kind = kind
            self.provenance = provenance

    artifacts = [
        _FakeArtifact(
            kind="entity",
            provenance={
                "stage_id": "extract-entities",
                "provider_metadata": {
                    "adapter_id": "_MockAdapter:cheap-1",
                    "task_type": "extract-entities",
                    "usage": {"prompt_tokens": 120, "completion_tokens": 40},
                },
            },
        ),
        _FakeArtifact(
            kind="entity",
            provenance={
                "stage_id": "extract-entities",
                "provider_metadata": {
                    "adapter_id": "_MockAdapter:cheap-1",
                    "task_type": "extract-entities",
                    "usage": {"prompt_tokens": 130, "completion_tokens": 50},
                },
            },
        ),
        _FakeArtifact(
            kind="relation",
            provenance={
                "stage_id": "extract-relations",
                "provider_metadata": {
                    "adapter_id": "_MockAdapter:smart-1",
                    "task_type": "extract-relations",
                    "usage": {"prompt_tokens": 200, "completion_tokens": 80},
                },
            },
        ),
        # Artifact without provider_metadata should be ignored.
        _FakeArtifact(kind="generated", provenance={"stage_id": "summarize-source"}),
    ]

    rows = _collect_adapter_choices(artifacts)
    by_key = {(row["stage_id"], row["adapter_id"]): row for row in rows}

    entities_row = by_key[("extract-entities", "_MockAdapter:cheap-1")]
    relations_row = by_key[("extract-relations", "_MockAdapter:smart-1")]
    assert entities_row["calls"] == 2
    assert entities_row["prompt_tokens"] == 250
    assert entities_row["completion_tokens"] == 90
    assert relations_row["calls"] == 1
    assert relations_row["task_type"] == "extract-relations"


def test_routing_ledger_cli(tmp_path) -> None:
    """``python -m infospace_bench routing ledger`` prints the ledger summary as JSON."""
    ledger_path = tmp_path / "quality.jsonl"
    ledger = QualityLedger(path=ledger_path)
    ledger.append(_observation("extract-entities", "cheap-1", cost_usd=0.001, quality_score=0.9))

    result = subprocess.run(
        [sys.executable, "-m", "infospace_bench", "routing", "ledger", str(ledger_path)],
        check=False,
        env=_subprocess_env(),
        text=True,
        capture_output=True,
        timeout=120,  # never let a hung CLI stall the whole suite
    )

    assert result.returncode == 0, result.stderr
    payload = json.loads(result.stdout)
    assert payload["ledger_path"] == str(ledger_path)
    assert payload["rows"] and payload["rows"][0]["task_type"] == "extract-entities"


def test_bridge_passes_estimated_cost_per_1k_through() -> None:
    """The bridge forwards its estimated_cost_per_1k setting to policy.resolve."""
    captured: dict[str, Any] = {}

    class _PolicyProbe(RoutingPolicy):
        def resolve(self, task_type, estimated_cost_per_1k=None):  # type: ignore[override]
            captured["task_type"] = task_type
            captured["estimated_cost_per_1k"] = estimated_cost_per_1k
            return _MockAdapter(model="x")

    bridge = RoutingAssistedGenerationAdapter(
        policy=_PolicyProbe(),
        estimated_cost_per_1k=0.5,
    )

    bridge.generate(_request("summarize-source"))

    assert captured["task_type"] == "summarize-source"
    assert captured["estimated_cost_per_1k"] == 0.5