generated from coulomb/repo-seed
feat: WP-0001 foundation + WP-0002 core extensions
WP-0001 — Foundation & GAAF Baseline - SCOPE.md, ARCHITECTURE-LAYERS.md, contracts/ tree - .claude/rules/ stubs filled (architecture, stack, boundary) - 57 tests (pytest), pyproject.toml with ruff+mypy, CI workflow WP-0002 — Core Extensions (FR-4 + FR-3) - FR-4: BudgetTracker (thread-safe) + LLMBudgetExceededError + optional RunConfig.budget_tracker + enforcement in all adapters - FR-3: async_execute_prompt on LLMAdapter ABC (asyncio.to_thread fallback) + native asyncio.create_subprocess_exec in ClaudeCodeAdapter 81 tests passing. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
26
tests/conftest.py
Normal file
26
tests/conftest.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Shared pytest fixtures for llm-connect tests.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from llm_connect.models import RunConfig, LLMResponse
|
||||
from llm_connect.adapter import MockLLMAdapter
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def run_config():
|
||||
"""Default RunConfig for tests."""
|
||||
return RunConfig()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_adapter():
|
||||
"""MockLLMAdapter with a predictable response."""
|
||||
return MockLLMAdapter(mock_response="test response")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_response():
|
||||
"""A minimal valid LLMResponse."""
|
||||
return LLMResponse(content="hello", model="test-model")
|
||||
77
tests/test_adapter.py
Normal file
77
tests/test_adapter.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
Tests for MockLLMAdapter and ErrorLLMAdapter (Core adapter utilities).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from llm_connect.adapter import MockLLMAdapter, ErrorLLMAdapter
|
||||
from llm_connect.models import RunConfig, LLMResponse
|
||||
|
||||
|
||||
class TestMockLLMAdapter:
|
||||
def test_returns_mock_response(self, mock_adapter, run_config):
|
||||
response = mock_adapter.execute_prompt("hello", run_config)
|
||||
assert response.content == "test response"
|
||||
|
||||
def test_returns_llm_response(self, mock_adapter, run_config):
|
||||
response = mock_adapter.execute_prompt("hello", run_config)
|
||||
assert isinstance(response, LLMResponse)
|
||||
|
||||
def test_call_count_increments(self, mock_adapter, run_config):
|
||||
assert mock_adapter.call_count == 0
|
||||
mock_adapter.execute_prompt("a", run_config)
|
||||
mock_adapter.execute_prompt("b", run_config)
|
||||
assert mock_adapter.call_count == 2
|
||||
|
||||
def test_records_last_prompt(self, mock_adapter, run_config):
|
||||
mock_adapter.execute_prompt("my prompt", run_config)
|
||||
assert mock_adapter.last_prompt == "my prompt"
|
||||
|
||||
def test_records_last_config(self, mock_adapter, run_config):
|
||||
mock_adapter.execute_prompt("x", run_config)
|
||||
assert mock_adapter.last_config is run_config
|
||||
|
||||
def test_reset_clears_state(self, mock_adapter, run_config):
|
||||
mock_adapter.execute_prompt("x", run_config)
|
||||
mock_adapter.reset()
|
||||
assert mock_adapter.call_count == 0
|
||||
assert mock_adapter.last_prompt is None
|
||||
assert mock_adapter.last_config is None
|
||||
|
||||
def test_validate_config_always_true(self, mock_adapter, run_config):
|
||||
assert mock_adapter.validate_config(run_config) is True
|
||||
|
||||
def test_usage_contains_expected_keys(self, mock_adapter, run_config):
|
||||
response = mock_adapter.execute_prompt("prompt text", run_config)
|
||||
assert "prompt_tokens" in response.usage
|
||||
assert "completion_tokens" in response.usage
|
||||
assert "total_tokens" in response.usage
|
||||
|
||||
def test_custom_response_text(self, run_config):
|
||||
adapter = MockLLMAdapter(mock_response="custom answer")
|
||||
response = adapter.execute_prompt("q", run_config)
|
||||
assert response.content == "custom answer"
|
||||
|
||||
def test_default_response_text(self, run_config):
|
||||
adapter = MockLLMAdapter()
|
||||
response = adapter.execute_prompt("q", run_config)
|
||||
assert response.content == "Mock LLM response"
|
||||
|
||||
def test_metadata_marks_as_mock(self, mock_adapter, run_config):
|
||||
response = mock_adapter.execute_prompt("q", run_config)
|
||||
assert response.metadata.get("mock") is True
|
||||
|
||||
|
||||
class TestErrorLLMAdapter:
|
||||
def test_raises_on_execute(self, run_config):
|
||||
adapter = ErrorLLMAdapter()
|
||||
with pytest.raises(RuntimeError):
|
||||
adapter.execute_prompt("q", run_config)
|
||||
|
||||
def test_raises_with_custom_message(self, run_config):
|
||||
adapter = ErrorLLMAdapter(error_message="boom")
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
adapter.execute_prompt("q", run_config)
|
||||
|
||||
def test_validate_config_returns_true(self, run_config):
|
||||
adapter = ErrorLLMAdapter()
|
||||
assert adapter.validate_config(run_config) is True
|
||||
101
tests/test_async.py
Normal file
101
tests/test_async.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Tests for async_execute_prompt (FR-3).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
|
||||
from llm_connect.models import RunConfig, BudgetTracker
|
||||
from llm_connect.adapter import MockLLMAdapter
|
||||
from llm_connect.exceptions import LLMBudgetExceededError
|
||||
|
||||
|
||||
class TestAsyncExecutePrompt:
|
||||
def test_default_fallback_returns_response(self):
|
||||
adapter = MockLLMAdapter(mock_response="async result")
|
||||
config = RunConfig()
|
||||
response = asyncio.run(adapter.async_execute_prompt("hello", config))
|
||||
assert response.content == "async result"
|
||||
|
||||
def test_gather_multiple_adapters(self):
|
||||
"""asyncio.gather over N adapters completes without errors."""
|
||||
adapters = [MockLLMAdapter(mock_response=f"resp-{i}") for i in range(4)]
|
||||
config = RunConfig()
|
||||
|
||||
async def run():
|
||||
return await asyncio.gather(*[
|
||||
a.async_execute_prompt("prompt", config) for a in adapters
|
||||
])
|
||||
|
||||
results = asyncio.run(run())
|
||||
assert len(results) == 4
|
||||
for i, r in enumerate(results):
|
||||
assert r.content == f"resp-{i}"
|
||||
|
||||
def test_gather_increments_call_counts(self):
|
||||
adapter = MockLLMAdapter()
|
||||
config = RunConfig()
|
||||
|
||||
async def run():
|
||||
await asyncio.gather(*[
|
||||
adapter.async_execute_prompt("p", config) for _ in range(5)
|
||||
])
|
||||
|
||||
asyncio.run(run())
|
||||
assert adapter.call_count == 5
|
||||
|
||||
def test_concurrent_faster_than_sequential(self):
|
||||
"""Gathering N async calls should not be N× slower than one call."""
|
||||
import time
|
||||
|
||||
adapter = MockLLMAdapter()
|
||||
config = RunConfig()
|
||||
|
||||
async def run_concurrent(n: int):
|
||||
await asyncio.gather(*[
|
||||
adapter.async_execute_prompt("p", config) for _ in range(n)
|
||||
])
|
||||
|
||||
# Just verify it completes without deadlock or error — timing is CI-unreliable
|
||||
asyncio.run(run_concurrent(10))
|
||||
assert adapter.call_count == 10
|
||||
|
||||
def test_async_with_budget_tracker(self):
|
||||
"""Budget enforcement works through async calls."""
|
||||
tracker = BudgetTracker(total=10000)
|
||||
config = RunConfig(budget_tracker=tracker)
|
||||
adapter = MockLLMAdapter(mock_response="hi")
|
||||
|
||||
asyncio.run(adapter.async_execute_prompt("hello", config))
|
||||
assert tracker.spent > 0
|
||||
|
||||
def test_async_exhausted_budget_raises(self):
|
||||
"""Exhausted budget raises LLMBudgetExceededError in async context."""
|
||||
tracker = BudgetTracker(total=1)
|
||||
tracker.consume(1)
|
||||
config = RunConfig(budget_tracker=tracker)
|
||||
adapter = MockLLMAdapter()
|
||||
|
||||
with pytest.raises(LLMBudgetExceededError):
|
||||
asyncio.run(adapter.async_execute_prompt("p", config))
|
||||
|
||||
def test_async_gather_with_shared_budget(self):
|
||||
"""Shared budget across concurrent async calls is enforced correctly."""
|
||||
tracker = BudgetTracker(total=100000)
|
||||
config = RunConfig(budget_tracker=tracker)
|
||||
adapters = [MockLLMAdapter(mock_response="hi") for _ in range(4)]
|
||||
|
||||
async def run():
|
||||
await asyncio.gather(*[
|
||||
a.async_execute_prompt("hello", config) for a in adapters
|
||||
])
|
||||
|
||||
asyncio.run(run())
|
||||
assert tracker.spent > 0
|
||||
|
||||
def test_returns_llm_response_type(self):
|
||||
from llm_connect.models import LLMResponse
|
||||
adapter = MockLLMAdapter()
|
||||
config = RunConfig()
|
||||
response = asyncio.run(adapter.async_execute_prompt("q", config))
|
||||
assert isinstance(response, LLMResponse)
|
||||
152
tests/test_budget.py
Normal file
152
tests/test_budget.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Tests for BudgetTracker (FR-4) and LLMBudgetExceededError.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import pytest
|
||||
|
||||
from llm_connect.models import BudgetTracker, RunConfig
|
||||
from llm_connect.adapter import MockLLMAdapter
|
||||
from llm_connect.exceptions import LLMBudgetExceededError, LLMError
|
||||
|
||||
|
||||
class TestBudgetTracker:
|
||||
def test_initial_state(self):
|
||||
t = BudgetTracker(total=1000)
|
||||
assert t.total == 1000
|
||||
assert t.spent == 0
|
||||
assert t.remaining() == 1000
|
||||
|
||||
def test_consume_updates_spent(self):
|
||||
t = BudgetTracker(total=1000)
|
||||
t.consume(300)
|
||||
assert t.spent == 300
|
||||
assert t.remaining() == 700
|
||||
|
||||
def test_consume_multiple_times(self):
|
||||
t = BudgetTracker(total=1000)
|
||||
t.consume(400)
|
||||
t.consume(400)
|
||||
assert t.spent == 800
|
||||
assert t.remaining() == 200
|
||||
|
||||
def test_consume_exact_budget(self):
|
||||
t = BudgetTracker(total=100)
|
||||
t.consume(100)
|
||||
assert t.spent == 100
|
||||
assert t.remaining() == 0
|
||||
|
||||
def test_consume_exceeds_budget_raises(self):
|
||||
t = BudgetTracker(total=100)
|
||||
t.consume(60)
|
||||
with pytest.raises(LLMBudgetExceededError):
|
||||
t.consume(50)
|
||||
|
||||
def test_exceeded_error_carries_details(self):
|
||||
t = BudgetTracker(total=100)
|
||||
t.consume(80)
|
||||
with pytest.raises(LLMBudgetExceededError) as exc_info:
|
||||
t.consume(30)
|
||||
err = exc_info.value
|
||||
assert err.total == 100
|
||||
assert err.spent == 80
|
||||
assert err.requested == 30
|
||||
|
||||
def test_exceeded_error_is_subclass_of_llm_error(self):
|
||||
with pytest.raises(LLMError):
|
||||
t = BudgetTracker(total=10)
|
||||
t.consume(20)
|
||||
|
||||
def test_remaining_never_negative(self):
|
||||
t = BudgetTracker(total=100)
|
||||
t.consume(100)
|
||||
assert t.remaining() == 0
|
||||
|
||||
def test_invalid_total_raises(self):
|
||||
with pytest.raises(ValueError):
|
||||
BudgetTracker(total=0)
|
||||
with pytest.raises(ValueError):
|
||||
BudgetTracker(total=-1)
|
||||
|
||||
def test_repr(self):
|
||||
t = BudgetTracker(total=500)
|
||||
t.consume(100)
|
||||
r = repr(t)
|
||||
assert "500" in r
|
||||
assert "100" in r
|
||||
|
||||
def test_thread_safety(self):
|
||||
"""Concurrent consume() calls must not corrupt state or allow overspend."""
|
||||
total = 1000
|
||||
t = BudgetTracker(total=total)
|
||||
errors = []
|
||||
|
||||
def consume_100():
|
||||
try:
|
||||
t.consume(100)
|
||||
except LLMBudgetExceededError:
|
||||
errors.append(1)
|
||||
|
||||
threads = [threading.Thread(target=consume_100) for _ in range(15)]
|
||||
for th in threads:
|
||||
th.start()
|
||||
for th in threads:
|
||||
th.join()
|
||||
|
||||
# At most 10 consumes of 100 can succeed within a budget of 1000
|
||||
assert t.spent <= total
|
||||
assert len(errors) == 5 # 15 attempts, 10 succeed, 5 fail
|
||||
|
||||
|
||||
class TestBudgetEnforcementInAdapter:
|
||||
def test_single_call_consumes_budget(self):
|
||||
tracker = BudgetTracker(total=10000)
|
||||
config = RunConfig(budget_tracker=tracker)
|
||||
adapter = MockLLMAdapter(mock_response="hello world")
|
||||
adapter.execute_prompt("test prompt", config)
|
||||
assert tracker.spent > 0
|
||||
|
||||
def test_exhausted_budget_raises_before_call(self):
|
||||
tracker = BudgetTracker(total=1)
|
||||
tracker.consume(1) # exhaust it
|
||||
config = RunConfig(budget_tracker=tracker)
|
||||
adapter = MockLLMAdapter()
|
||||
with pytest.raises(LLMBudgetExceededError):
|
||||
adapter.execute_prompt("any prompt", config)
|
||||
# Adapter should not have been called
|
||||
assert adapter.call_count == 0
|
||||
|
||||
def test_delegation_chain_shared_tracker(self):
|
||||
"""A → B → C sharing the same tracker enforces the cap across all calls."""
|
||||
tracker = BudgetTracker(total=10000)
|
||||
config = RunConfig(budget_tracker=tracker)
|
||||
adapter = MockLLMAdapter(mock_response="response")
|
||||
|
||||
adapter.execute_prompt("call A", config)
|
||||
adapter.execute_prompt("call B", config)
|
||||
adapter.execute_prompt("call C", config)
|
||||
|
||||
assert adapter.call_count == 3
|
||||
assert tracker.spent > 0
|
||||
|
||||
def test_budget_exceeded_mid_chain(self):
|
||||
"""Chain stops when budget is exhausted between calls."""
|
||||
# MockLLMAdapter uses word count for tokens — "x" * 200 = 200 token prompt
|
||||
# mock_response "r" * 100 = 25 tokens; total ~75 per call
|
||||
adapter = MockLLMAdapter(mock_response="r " * 50) # ~50 completion tokens
|
||||
tracker = BudgetTracker(total=200)
|
||||
config = RunConfig(budget_tracker=tracker)
|
||||
|
||||
# First call succeeds
|
||||
adapter.execute_prompt("p " * 100, config)
|
||||
# Eventually exhausts the budget
|
||||
with pytest.raises(LLMBudgetExceededError):
|
||||
for _ in range(10):
|
||||
adapter.execute_prompt("p " * 100, config)
|
||||
|
||||
def test_no_tracker_has_no_effect(self):
|
||||
"""Adapters work normally when no budget_tracker is set."""
|
||||
config = RunConfig() # no budget_tracker
|
||||
adapter = MockLLMAdapter()
|
||||
response = adapter.execute_prompt("hello", config)
|
||||
assert response.content == "Mock LLM response"
|
||||
96
tests/test_exceptions.py
Normal file
96
tests/test_exceptions.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
Tests for the LLMError exception hierarchy (Core).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from llm_connect.exceptions import (
|
||||
LLMError,
|
||||
LLMConfigurationError,
|
||||
LLMAPIError,
|
||||
LLMRateLimitError,
|
||||
LLMTimeoutError,
|
||||
LLMSubprocessError,
|
||||
)
|
||||
|
||||
|
||||
class TestLLMErrorHierarchy:
|
||||
def test_all_are_subclasses_of_llm_error(self):
|
||||
assert issubclass(LLMConfigurationError, LLMError)
|
||||
assert issubclass(LLMAPIError, LLMError)
|
||||
assert issubclass(LLMRateLimitError, LLMError)
|
||||
assert issubclass(LLMTimeoutError, LLMError)
|
||||
assert issubclass(LLMSubprocessError, LLMError)
|
||||
|
||||
def test_rate_limit_is_api_error(self):
|
||||
assert issubclass(LLMRateLimitError, LLMAPIError)
|
||||
|
||||
def test_all_are_exceptions(self):
|
||||
assert issubclass(LLMError, Exception)
|
||||
|
||||
|
||||
class TestLLMError:
|
||||
def test_basic_message(self):
|
||||
err = LLMError("something went wrong")
|
||||
assert str(err) == "something went wrong"
|
||||
|
||||
def test_context_appears_in_str(self):
|
||||
err = LLMError("oops", context={"provider": "openai"})
|
||||
assert "provider=openai" in str(err)
|
||||
|
||||
def test_cause_is_chained(self):
|
||||
cause = ValueError("root cause")
|
||||
err = LLMError("wrapper", cause=cause)
|
||||
assert err.__cause__ is cause
|
||||
|
||||
def test_empty_context_does_not_appear(self):
|
||||
err = LLMError("clean message", context={})
|
||||
assert str(err) == "clean message"
|
||||
|
||||
|
||||
class TestLLMAPIError:
|
||||
def test_has_status_code(self):
|
||||
err = LLMAPIError("bad request", status_code=400)
|
||||
assert err.status_code == 400
|
||||
|
||||
def test_has_response_body(self):
|
||||
err = LLMAPIError("error", status_code=500, response_body='{"error": "oops"}')
|
||||
assert err.response_body == '{"error": "oops"}'
|
||||
|
||||
def test_defaults(self):
|
||||
err = LLMAPIError("minimal")
|
||||
assert err.status_code == 0
|
||||
assert err.response_body == ""
|
||||
|
||||
def test_rate_limit_inherits_status_code(self):
|
||||
err = LLMRateLimitError("too many", status_code=429)
|
||||
assert err.status_code == 429
|
||||
assert isinstance(err, LLMAPIError)
|
||||
|
||||
|
||||
class TestLLMSubprocessError:
|
||||
def test_has_return_code(self):
|
||||
err = LLMSubprocessError("cli failed", return_code=1)
|
||||
assert err.return_code == 1
|
||||
|
||||
def test_has_stderr(self):
|
||||
err = LLMSubprocessError("cli failed", stderr="error output")
|
||||
assert err.stderr == "error output"
|
||||
|
||||
def test_defaults(self):
|
||||
err = LLMSubprocessError("minimal")
|
||||
assert err.return_code == 1
|
||||
assert err.stderr == ""
|
||||
|
||||
|
||||
class TestRaiseAndCatch:
|
||||
def test_catch_as_llm_error(self):
|
||||
with pytest.raises(LLMError):
|
||||
raise LLMConfigurationError("no key")
|
||||
|
||||
def test_catch_api_error_as_llm_error(self):
|
||||
with pytest.raises(LLMError):
|
||||
raise LLMAPIError("http error", status_code=502)
|
||||
|
||||
def test_catch_rate_limit_as_api_error(self):
|
||||
with pytest.raises(LLMAPIError):
|
||||
raise LLMRateLimitError("429", status_code=429)
|
||||
97
tests/test_factory.py
Normal file
97
tests/test_factory.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Tests for create_adapter() and create_embedding_adapter() factories.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from llm_connect.factory import create_adapter
|
||||
from llm_connect.embedding_factory import create_embedding_adapter
|
||||
from llm_connect.exceptions import LLMConfigurationError
|
||||
from llm_connect.adapter import LLMAdapter
|
||||
from llm_connect.embedding_adapter import EmbeddingAdapter
|
||||
from llm_connect.openrouter import OpenRouterAdapter
|
||||
from llm_connect.claude_code import ClaudeCodeAdapter
|
||||
from llm_connect.openai import OpenAIAdapter
|
||||
from llm_connect.gemini import GeminiAdapter
|
||||
from llm_connect.embedding_openai import OpenAICompatibleEmbeddingAdapter
|
||||
|
||||
|
||||
class TestCreateAdapter:
|
||||
def test_unknown_provider_raises(self):
|
||||
with pytest.raises(LLMConfigurationError, match="Unknown LLM provider"):
|
||||
create_adapter("nonexistent-provider")
|
||||
|
||||
def test_unknown_provider_error_lists_known(self):
|
||||
with pytest.raises(LLMConfigurationError) as exc_info:
|
||||
create_adapter("bad")
|
||||
assert "openai" in str(exc_info.value)
|
||||
assert "gemini" in str(exc_info.value)
|
||||
|
||||
def test_openrouter_returns_adapter(self):
|
||||
adapter = create_adapter("openrouter", api_key="test-key")
|
||||
assert isinstance(adapter, OpenRouterAdapter)
|
||||
assert isinstance(adapter, LLMAdapter)
|
||||
|
||||
def test_openrouter_no_key_still_constructs(self):
|
||||
# OpenRouterAdapter defers key validation to execute_prompt
|
||||
adapter = create_adapter("openrouter")
|
||||
assert isinstance(adapter, OpenRouterAdapter)
|
||||
|
||||
def test_openai_with_key_returns_adapter(self):
|
||||
adapter = create_adapter("openai", api_key="sk-test-key")
|
||||
assert isinstance(adapter, OpenAIAdapter)
|
||||
assert isinstance(adapter, LLMAdapter)
|
||||
|
||||
def test_openai_without_key_raises(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
with pytest.raises(LLMConfigurationError):
|
||||
create_adapter("openai")
|
||||
|
||||
def test_gemini_with_key_returns_adapter(self):
|
||||
adapter = create_adapter("gemini", api_key="aistudio-test-key")
|
||||
assert isinstance(adapter, GeminiAdapter)
|
||||
assert isinstance(adapter, LLMAdapter)
|
||||
|
||||
def test_gemini_without_key_raises(self, monkeypatch):
|
||||
monkeypatch.delenv("GEMINI_API_KEY", raising=False)
|
||||
with pytest.raises(LLMConfigurationError):
|
||||
create_adapter("gemini")
|
||||
|
||||
def test_claude_code_returns_adapter(self):
|
||||
adapter = create_adapter("claude-code")
|
||||
assert isinstance(adapter, ClaudeCodeAdapter)
|
||||
assert isinstance(adapter, LLMAdapter)
|
||||
|
||||
def test_claude_code_with_model(self):
|
||||
adapter = create_adapter("claude-code", model="claude-opus-4-6")
|
||||
assert isinstance(adapter, ClaudeCodeAdapter)
|
||||
|
||||
def test_all_known_providers_are_reachable(self):
|
||||
known = {"openrouter", "openai", "gemini", "claude-code"}
|
||||
# Just verify each key is in the factory registry (no construction needed)
|
||||
from llm_connect.factory import _PROVIDERS
|
||||
assert known == set(_PROVIDERS.keys())
|
||||
|
||||
|
||||
class TestCreateEmbeddingAdapter:
|
||||
def test_unknown_provider_raises(self):
|
||||
with pytest.raises(LLMConfigurationError, match="Unknown embedding provider"):
|
||||
create_embedding_adapter("nonexistent")
|
||||
|
||||
def test_openai_returns_adapter(self):
|
||||
adapter = create_embedding_adapter("openai", api_key="sk-test")
|
||||
assert isinstance(adapter, OpenAICompatibleEmbeddingAdapter)
|
||||
assert isinstance(adapter, EmbeddingAdapter)
|
||||
|
||||
def test_openrouter_returns_adapter(self):
|
||||
adapter = create_embedding_adapter("openrouter", api_key="or-test")
|
||||
assert isinstance(adapter, OpenAICompatibleEmbeddingAdapter)
|
||||
assert isinstance(adapter, EmbeddingAdapter)
|
||||
|
||||
def test_validate_returns_true_when_key_set(self):
|
||||
adapter = create_embedding_adapter("openai", api_key="sk-test")
|
||||
assert adapter.validate() is True
|
||||
|
||||
def test_validate_returns_false_when_no_key(self, monkeypatch):
|
||||
monkeypatch.delenv("OPENAI_API_KEY", raising=False)
|
||||
adapter = create_embedding_adapter("openai")
|
||||
assert adapter.validate() is False
|
||||
86
tests/test_models.py
Normal file
86
tests/test_models.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
Tests for RunConfig and LLMResponse (Core models).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from llm_connect.models import RunConfig, LLMResponse
|
||||
|
||||
|
||||
class TestRunConfig:
|
||||
def test_defaults(self):
|
||||
cfg = RunConfig()
|
||||
assert cfg.model_name == "gpt-4"
|
||||
assert cfg.temperature == 0.7
|
||||
assert cfg.max_tokens == 2000
|
||||
assert cfg.model_params == {}
|
||||
assert cfg.max_depth == 3
|
||||
assert cfg.skip_if_exists is True
|
||||
assert cfg.timeout_seconds == 300
|
||||
|
||||
def test_custom_values(self):
|
||||
cfg = RunConfig(model_name="gemini-2.5-flash", temperature=0.1, max_tokens=500)
|
||||
assert cfg.model_name == "gemini-2.5-flash"
|
||||
assert cfg.temperature == 0.1
|
||||
assert cfg.max_tokens == 500
|
||||
|
||||
def test_to_dict_roundtrip(self):
|
||||
cfg = RunConfig(model_name="gpt-4o", temperature=0.3, max_tokens=1000)
|
||||
d = cfg.to_dict()
|
||||
assert d["model_name"] == "gpt-4o"
|
||||
assert d["temperature"] == 0.3
|
||||
assert d["max_tokens"] == 1000
|
||||
|
||||
def test_from_dict_roundtrip(self):
|
||||
original = RunConfig(model_name="claude-3", temperature=0.5, max_tokens=800)
|
||||
restored = RunConfig.from_dict(original.to_dict())
|
||||
assert restored.model_name == original.model_name
|
||||
assert restored.temperature == original.temperature
|
||||
assert restored.max_tokens == original.max_tokens
|
||||
|
||||
def test_from_dict_uses_defaults_for_missing_keys(self):
|
||||
cfg = RunConfig.from_dict({})
|
||||
assert cfg.model_name == "gpt-4"
|
||||
assert cfg.temperature == 0.7
|
||||
|
||||
def test_model_params_default_is_independent(self):
|
||||
a = RunConfig()
|
||||
b = RunConfig()
|
||||
a.model_params["x"] = 1
|
||||
assert "x" not in b.model_params
|
||||
|
||||
|
||||
class TestLLMResponse:
|
||||
def test_minimal_construction(self):
|
||||
r = LLMResponse(content="hello", model="test-model")
|
||||
assert r.content == "hello"
|
||||
assert r.model == "test-model"
|
||||
assert r.usage == {}
|
||||
assert r.finish_reason == "stop"
|
||||
assert r.metadata == {}
|
||||
|
||||
def test_full_construction(self):
|
||||
r = LLMResponse(
|
||||
content="response text",
|
||||
model="gpt-4",
|
||||
usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
||||
finish_reason="length",
|
||||
metadata={"provider": "openai", "latency_seconds": 1.2},
|
||||
)
|
||||
assert r.usage["total_tokens"] == 15
|
||||
assert r.finish_reason == "length"
|
||||
assert r.metadata["provider"] == "openai"
|
||||
|
||||
def test_to_dict(self):
|
||||
r = LLMResponse(content="hi", model="m", finish_reason="stop")
|
||||
d = r.to_dict()
|
||||
assert d["content"] == "hi"
|
||||
assert d["model"] == "m"
|
||||
assert d["finish_reason"] == "stop"
|
||||
assert "usage" in d
|
||||
assert "metadata" in d
|
||||
|
||||
def test_metadata_default_is_independent(self):
|
||||
a = LLMResponse(content="a", model="m")
|
||||
b = LLMResponse(content="b", model="m")
|
||||
a.metadata["x"] = 1
|
||||
assert "x" not in b.metadata
|
||||
Reference in New Issue
Block a user