"""
Tests for BudgetTracker (FR-4) and LLMBudgetExceededError.
"""

import threading

import pytest

from llm_connect.models import BudgetTracker, RunConfig
from llm_connect.adapter import MockLLMAdapter
from llm_connect.exceptions import LLMBudgetExceededError, LLMError


class TestBudgetTracker:
    """Unit tests for BudgetTracker accounting, validation and concurrency."""

    def test_initial_state(self):
        t = BudgetTracker(total=1000)
        assert t.total == 1000
        assert t.spent == 0
        assert t.remaining() == 1000

    def test_consume_updates_spent(self):
        t = BudgetTracker(total=1000)
        t.consume(300)
        assert t.spent == 300
        assert t.remaining() == 700

    def test_consume_multiple_times(self):
        t = BudgetTracker(total=1000)
        t.consume(400)
        t.consume(400)
        assert t.spent == 800
        assert t.remaining() == 200

    def test_consume_exact_budget(self):
        t = BudgetTracker(total=100)
        t.consume(100)
        assert t.spent == 100
        assert t.remaining() == 0

    def test_consume_exceeds_budget_raises(self):
        t = BudgetTracker(total=100)
        t.consume(60)
        with pytest.raises(LLMBudgetExceededError):
            t.consume(50)

    def test_exceeded_error_carries_details(self):
        t = BudgetTracker(total=100)
        t.consume(80)
        with pytest.raises(LLMBudgetExceededError) as exc_info:
            t.consume(30)
        err = exc_info.value
        assert err.total == 100
        assert err.spent == 80
        assert err.requested == 30
        # A rejected consume must leave the tracker untouched.
        assert t.spent == 80

    def test_exceeded_error_is_subclass_of_llm_error(self):
        assert issubclass(LLMBudgetExceededError, LLMError)
        # Build the tracker outside pytest.raises so that only consume()
        # can satisfy the expectation.
        t = BudgetTracker(total=10)
        with pytest.raises(LLMError):
            t.consume(20)

    def test_remaining_never_negative(self):
        t = BudgetTracker(total=100)
        t.consume(100)
        assert t.remaining() == 0
        # An overspend attempt on an exhausted tracker is rejected and must
        # not drive remaining() below zero or change spent.
        with pytest.raises(LLMBudgetExceededError):
            t.consume(1)
        assert t.remaining() == 0
        assert t.spent == 100

    def test_invalid_total_raises(self):
        with pytest.raises(ValueError):
            BudgetTracker(total=0)
        with pytest.raises(ValueError):
            BudgetTracker(total=-1)

    def test_repr(self):
        t = BudgetTracker(total=500)
        t.consume(100)
        r = repr(t)
        assert "500" in r
        assert "100" in r

    def test_thread_safety(self):
        """Concurrent consume() calls must not corrupt state or allow overspend."""
        total = 1000
        cost = 100
        attempts = 15
        t = BudgetTracker(total=total)
        # Release every worker at the same moment to maximize contention.
        start_gate = threading.Barrier(attempts, timeout=5)
        results_lock = threading.Lock()
        rejected = []
        unexpected = []

        def consume_100():
            try:
                start_gate.wait()
                t.consume(cost)
            except LLMBudgetExceededError:
                with results_lock:
                    rejected.append(1)
            except Exception as exc:  # noqa: BLE001 - surfaced by the assert below
                # Exceptions raised in worker threads never reach pytest on
                # their own, so record them for the main thread to report.
                with results_lock:
                    unexpected.append(exc)

        threads = [threading.Thread(target=consume_100) for _ in range(attempts)]
        for th in threads:
            th.start()
        for th in threads:
            th.join()

        assert not unexpected, unexpected
        # Exactly total // cost (10) consumes fit in the budget; the other
        # 5 attempts must be rejected and must not be recorded as spend.
        assert t.spent == total
        assert len(rejected) == attempts - total // cost


class TestBudgetEnforcementInAdapter:
    """Integration tests: adapters charge and respect a shared BudgetTracker."""

    def test_single_call_consumes_budget(self):
        tracker = BudgetTracker(total=10000)
        config = RunConfig(budget_tracker=tracker)
        adapter = MockLLMAdapter(mock_response="hello world")
        adapter.execute_prompt("test prompt", config)
        assert tracker.spent > 0

    def test_exhausted_budget_raises_before_call(self):
        tracker = BudgetTracker(total=1)
        tracker.consume(1)  # exhaust it
        config = RunConfig(budget_tracker=tracker)
        adapter = MockLLMAdapter()
        with pytest.raises(LLMBudgetExceededError):
            adapter.execute_prompt("any prompt", config)
        # Adapter should not have been called
        assert adapter.call_count == 0

    def test_delegation_chain_shared_tracker(self):
        """A → B → C sharing the same tracker enforces the cap across all calls."""
        tracker = BudgetTracker(total=10000)
        config = RunConfig(budget_tracker=tracker)
        adapter = MockLLMAdapter(mock_response="response")
        adapter.execute_prompt("call A", config)
        adapter.execute_prompt("call B", config)
        adapter.execute_prompt("call C", config)
        assert adapter.call_count == 3
        assert tracker.spent > 0

    def test_budget_exceeded_mid_chain(self):
        """Chain stops when budget is exhausted between calls."""
        # NOTE(review): exact per-call token counts depend on MockLLMAdapter's
        # counting heuristic; these sizes are only chosen so that a handful
        # of calls overrun a 200-token budget.
        adapter = MockLLMAdapter(mock_response="r " * 50)
        tracker = BudgetTracker(total=200)
        config = RunConfig(budget_tracker=tracker)

        # First call succeeds and is charged to the tracker.
        adapter.execute_prompt("p " * 100, config)
        assert adapter.call_count == 1
        assert tracker.spent > 0

        # Repeating the call eventually exhausts the budget.
        with pytest.raises(LLMBudgetExceededError):
            for _ in range(10):
                adapter.execute_prompt("p " * 100, config)
        # The tracker never records spend beyond its cap.
        assert tracker.spent <= tracker.total

    def test_no_tracker_has_no_effect(self):
        """Adapters work normally when no budget_tracker is set."""
        config = RunConfig()  # no budget_tracker
        adapter = MockLLMAdapter()
        response = adapter.execute_prompt("hello", config)
        assert response.content == "Mock LLM response"