""" Tests for the routing bridge that wraps an llm-connect RoutingPolicy as an infospace-bench AssistedGenerationAdapter (IB-WP-0018 T02/T05). All tests use mocked llm-connect ``LLMAdapter`` instances — no network. """ from __future__ import annotations from typing import Any import pytest from llm_connect.adapter import LLMAdapter from llm_connect.models import LLMResponse, RunConfig from llm_connect.routing import ( AdaptiveRoutingPolicy, RoutingPolicy, RoutingRule, ) from llm_connect.quality import QualityLedger, QualityObservation from infospace_bench.routing import ( STAGE_TO_TASK_TYPE_DEFAULT, RoutingAssistedGenerationAdapter, ) from infospace_bench.workflow import AssistedGenerationRequest class _MockAdapter(LLMAdapter): """Test double: returns a configured ``LLMResponse`` and records calls.""" def __init__(self, *, model: str, content: str = "ok", cost_per_call: float = 0.0) -> None: self.model = model self._content = content self._cost_per_call = cost_per_call self.calls: list[tuple[str, RunConfig]] = [] def execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse: self.calls.append((prompt, config)) return LLMResponse( content=self._content, model=self.model, usage={"prompt_tokens": 100, "completion_tokens": 50}, finish_reason="stop", metadata={"cost_per_call_usd": self._cost_per_call}, ) def validate_config(self, config: RunConfig) -> bool: # pragma: no cover - trivial return True def _request(stage_id: str, prompt: str = "Hello.") -> AssistedGenerationRequest: return AssistedGenerationRequest( stage_id=stage_id, workflow_id="generic-source-entities", input_artifact_id="source/test.md", prompt=prompt, ) def test_bridge_resolves_static_policy_per_stage() -> None: cheap = _MockAdapter(model="cheap-1", content="# Cheap") smart = _MockAdapter(model="smart-1", content="# Smart") policy = RoutingPolicy( rules=[ RoutingRule(task_type="summarize-source", prefer=cheap), RoutingRule(task_type="extract-entities", prefer=smart), ], default=cheap, ) bridge = RoutingAssistedGenerationAdapter(policy=policy) summary = bridge.generate(_request("summarize-source", "Source A")) entities = bridge.generate(_request("extract-entities", "Source A")) assert summary.markdown == "# Cheap" assert summary.metadata["task_type"] == "summarize-source" assert summary.metadata["model"] == "cheap-1" assert summary.metadata["usage"]["prompt_tokens"] == 100 assert entities.markdown == "# Smart" assert entities.metadata["model"] == "smart-1" assert len(cheap.calls) == 1 assert len(smart.calls) == 1 def test_bridge_honours_stage_to_task_type_overrides() -> None: extraction = _MockAdapter(model="extraction-1") policy = RoutingPolicy( rules=[RoutingRule(task_type="extraction", prefer=extraction)], ) bridge = RoutingAssistedGenerationAdapter( policy=policy, stage_to_task_type={ "extract-entities": "extraction", "extract-relations": "extraction", }, ) bridge.generate(_request("extract-entities")) bridge.generate(_request("extract-relations")) assert len(extraction.calls) == 2 def test_bridge_default_task_type_map_covers_all_known_stages() -> None: expected = { "summarize-source", "extract-entities", "extract-relations", "evaluate-entity", "synthesize-report", } assert set(STAGE_TO_TASK_TYPE_DEFAULT) == expected # Identity mapping by default for stage in expected: assert STAGE_TO_TASK_TYPE_DEFAULT[stage] == stage def test_bridge_falls_through_to_stage_id_when_no_known_mapping() -> None: custom_adapter = _MockAdapter(model="custom-1") policy = RoutingPolicy( rules=[RoutingRule(task_type="custom-stage", prefer=custom_adapter)], ) bridge = RoutingAssistedGenerationAdapter(policy=policy) result = bridge.generate(_request("custom-stage")) assert result.markdown == "# ok" or result.markdown == "ok" assert custom_adapter.calls, "custom stage_id should fall through to the same task_type" def test_bridge_uses_adaptive_path_when_quality_floor_set(tmp_path) -> None: cheap = _MockAdapter(model="cheap-1") smart = _MockAdapter(model="smart-1") ledger = QualityLedger(path=tmp_path / "quality.jsonl") # Cheap clears the floor; smart does too but at a higher cost. for _ in range(3): ledger.append( QualityObservation( task_type="extract-entities", adapter_id="cheap-1", model_id="cheap-1", cost_usd=0.001, quality_score=0.9, tokens_in=100, tokens_out=50, latency_ms=10, ) ) ledger.append( QualityObservation( task_type="extract-entities", adapter_id="smart-1", model_id="smart-1", cost_usd=0.01, quality_score=0.95, tokens_in=100, tokens_out=50, latency_ms=10, ) ) policy = AdaptiveRoutingPolicy( rules=[RoutingRule(task_type="extract-entities", prefer=smart)], default=cheap, ledger=ledger, adapters_by_id={"cheap-1": cheap, "smart-1": smart}, ) bridge = RoutingAssistedGenerationAdapter(policy=policy, quality_floor=0.8) bridge.generate(_request("extract-entities")) assert cheap.calls, "adaptive policy should pick the cheaper qualifying adapter" assert not smart.calls def test_bridge_falls_back_to_static_when_quality_floor_unset(tmp_path) -> None: cheap = _MockAdapter(model="cheap-1") smart = _MockAdapter(model="smart-1") ledger = QualityLedger(path=tmp_path / "quality.jsonl") policy = AdaptiveRoutingPolicy( rules=[RoutingRule(task_type="extract-entities", prefer=smart)], ledger=ledger, ) bridge = RoutingAssistedGenerationAdapter(policy=policy) # no quality_floor bridge.generate(_request("extract-entities")) assert smart.calls, "without a quality_floor the bridge must use static routing" assert not cheap.calls def test_bridge_preserves_response_metadata_and_provider_tag() -> None: adapter = _MockAdapter(model="cheap-1") adapter.execute_prompt = lambda prompt, config: LLMResponse( # type: ignore[assignment] content="# ok", model="cheap-1", usage={"prompt_tokens": 7, "completion_tokens": 3}, finish_reason="stop", metadata={"request_id": "req-42"}, ) policy = RoutingPolicy(rules=[RoutingRule(task_type="custom", prefer=adapter)]) bridge = RoutingAssistedGenerationAdapter(policy=policy) result = bridge.generate(_request("custom")) assert result.metadata["request_id"] == "req-42" assert result.metadata["usage"] == {"prompt_tokens": 7, "completion_tokens": 3} assert result.metadata["task_type"] == "custom" assert result.metadata["adapter_id"].endswith(":cheap-1") assert result.provider == "mock" def test_bridge_passes_estimated_cost_per_1k_through() -> None: captured: dict[str, Any] = {} class _PolicyProbe(RoutingPolicy): def resolve(self, task_type, estimated_cost_per_1k=None): # type: ignore[override] captured["task_type"] = task_type captured["estimated_cost_per_1k"] = estimated_cost_per_1k return _MockAdapter(model="x") bridge = RoutingAssistedGenerationAdapter( policy=_PolicyProbe(), estimated_cost_per_1k=0.5, ) bridge.generate(_request("summarize-source")) assert captured["task_type"] == "summarize-source" assert captured["estimated_cost_per_1k"] == 0.5