# Generated from coulomb/repo-seed (274 lines, 9.3 KiB, Python).
"""Tests for LLMServer HTTP serve mode (FR-1)."""

import contextlib
import json
import threading
import time
import urllib.error
import urllib.request
from concurrent.futures import ThreadPoolExecutor

import pytest

from llm_connect._diagnostics import (
    record_adapter_transformation,
    record_provider_request,
    record_provider_response,
)
from llm_connect.adapter import ErrorLLMAdapter, MockLLMAdapter
from llm_connect.models import LLMResponse, RunConfig
from llm_connect.server import LLMServer


@pytest.fixture()
def server():
    """Provide a running LLMServer backed by a canned mock adapter.

    The mock answers every prompt with ``"hello world"``. The server is
    started on a free port (``port=0``) before the test and stopped during
    fixture teardown.
    """
    canned = MockLLMAdapter(mock_response="hello world")
    running = LLMServer(adapter=canned, port=0)
    running.start()
    yield running
    # Teardown: pytest resumes the generator here after the test body.
    running.stop()


# Upper bound on how long a helper request may block. Without it, a server
# that accepts a connection but never answers would hang the test run.
_DEFAULT_TIMEOUT = 10.0


def _request(req: urllib.request.Request, timeout: float) -> tuple[int, dict]:
    """Send *req* and return ``(status_code, decoded_json_body)``.

    ``urlopen`` raises ``urllib.error.HTTPError`` for non-2xx responses. That
    error is caught here so callers can assert on error statuses and on the
    JSON error payload in the same way as on successes.

    Raises:
        OSError: on connection failures or when *timeout* elapses
            (``URLError`` and ``TimeoutError`` are both ``OSError`` subclasses).
        json.JSONDecodeError: if the response body is not valid JSON.
    """
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return resp.status, json.loads(resp.read())
    except urllib.error.HTTPError as exc:
        return exc.code, json.loads(exc.read())


def _get(url: str, *, timeout: float = _DEFAULT_TIMEOUT) -> tuple[int, dict]:
    """GET *url* and return ``(status_code, decoded_json_body)``."""
    return _request(urllib.request.Request(url), timeout)


def _post(url: str, body: dict, *, timeout: float = _DEFAULT_TIMEOUT) -> tuple[int, dict]:
    """POST *body* as JSON to *url* and return ``(status_code, decoded_json_body)``."""
    req = urllib.request.Request(
        url,
        data=json.dumps(body).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    return _request(req, timeout)


class DiagnosticLLMAdapter(MockLLMAdapter):
    """Mock adapter that also emits provider and adapter diagnostics.

    Each call does the following, in order:

    - Records a fake provider request. Its ``Authorization`` header carries a
      bearer token so that redaction can be checked.
    - Delegates to the mock for the actual response.
    - Tags the response metadata.
    - Records a provider response and one adapter transformation step.
    """

    def execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
        outgoing = {"prompt": prompt, "model": config.model_name}
        record_provider_request(
            url="https://provider.example/v1/chat",
            payload=outgoing,
            headers={"Authorization": "Bearer secret-token"},
        )

        result = super().execute_prompt(prompt, config)
        meta = result.metadata
        meta["provider"] = "diagnostic"
        meta["response_id"] = "diag-response"

        record_provider_response(
            status=200,
            body={"id": "diag-response", "content": result.content},
        )
        record_adapter_transformation(
            "diagnostic_transform",
            {"before": prompt},
            {"after": result.content},
        )
        return result


class BarrierLLMAdapter(MockLLMAdapter):
    """Mock adapter whose calls only complete once two of them overlap.

    Every ``execute_prompt`` call first waits on a shared two-party barrier.
    If the server handles requests one at a time, the wait times out and
    raises ``threading.BrokenBarrierError`` instead of returning.
    """

    _PARTIES = 2          # number of concurrent calls needed to release the barrier
    _WAIT_SECONDS = 2.0   # how long one call waits for its partner

    def __init__(self):
        super().__init__(mock_response="parallel")
        self._barrier = threading.Barrier(self._PARTIES)

    def execute_prompt(self, prompt: str, config: RunConfig) -> LLMResponse:
        self._barrier.wait(timeout=self._WAIT_SECONDS)
        return super().execute_prompt(prompt, config)


class TestHealth:
    """GET routing: the health probe and unknown paths."""

    def test_health_returns_200(self, server):
        health_url = f"http://127.0.0.1:{server.port}/health"
        code, payload = _get(health_url)
        assert code == 200
        assert payload["status"] == "ok"

    def test_unknown_get_returns_404(self, server):
        code, _payload = _get(f"http://127.0.0.1:{server.port}/nope")
        assert code == 404


@contextlib.contextmanager
def _running_server(adapter):
    """Start an ``LLMServer`` for *adapter* on a free port; always stop it.

    Yields the started server. ``stop()`` runs even if the ``with`` body
    raises, so a failed request or assertion cannot leave a server running.
    """
    srv = LLMServer(adapter=adapter, port=0)
    srv.start()
    try:
        yield srv
    finally:
        srv.stop()


class TestExecute:
    """POST /execute: validation, config forwarding, diagnostics, auditing
    and concurrent request handling."""

    def test_post_execute_round_trip(self, server):
        status, body = _post(
            f"http://127.0.0.1:{server.port}/execute",
            {"prompt": "say hello"},
        )
        assert status == 200
        assert body["content"] == "hello world"
        assert body["finish_reason"] == "stop"
        # Diagnostics are opt-in (?debug=1 or LLM_CONNECT_DEBUG); absent by default.
        assert "debug" not in body

    def test_response_includes_usage(self, server):
        status, body = _post(
            f"http://127.0.0.1:{server.port}/execute",
            {"prompt": "count tokens"},
        )
        assert status == 200
        assert "usage" in body
        assert body["usage"]["total_tokens"] > 0

    def test_missing_prompt_returns_400(self, server):
        status, body = _post(
            f"http://127.0.0.1:{server.port}/execute",
            {"config": {}},
        )
        assert status == 400
        assert "prompt" in body["error"]

    def test_invalid_json_returns_400(self, server):
        # _post always JSON-encodes its body, so the malformed request is built by hand.
        req = urllib.request.Request(
            f"http://127.0.0.1:{server.port}/execute",
            data=b"not json",
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        try:
            # Bounded wait: a server that never answers fails the test instead of hanging it.
            with urllib.request.urlopen(req, timeout=10.0) as resp:
                status, _body = resp.status, json.loads(resp.read())
        except urllib.error.HTTPError as exc:
            status, _body = exc.code, json.loads(exc.read())
        assert status == 400

    def test_unknown_post_path_returns_404(self, server):
        status, _body = _post(
            f"http://127.0.0.1:{server.port}/wrong",
            {"prompt": "hi"},
        )
        assert status == 404

    def test_adapter_error_returns_500(self):
        with _running_server(ErrorLLMAdapter("boom")) as s:
            status, body = _post(
                f"http://127.0.0.1:{s.port}/execute",
                {"prompt": "hello"},
            )
        assert status == 500
        assert "boom" in body["error"]

    def test_config_fields_forwarded(self):
        """Config fields in request body reach the adapter via RunConfig."""
        adapter = MockLLMAdapter(mock_response="x")
        with _running_server(adapter) as s:
            status, _body = _post(
                f"http://127.0.0.1:{s.port}/execute",
                {
                    "prompt": "hi",
                    "config": {
                        "model_name": "gpt-3.5-turbo",
                        "max_tokens": 100,
                        "max_depth": 2,
                        "model_params": {"reasoning_effort": "medium"},
                    },
                },
            )
        assert status == 200
        cfg = adapter.last_config
        assert cfg.model_name == "gpt-3.5-turbo"
        assert cfg.max_tokens == 100
        assert cfg.max_depth == 2
        assert cfg.model_params == {"reasoning_effort": "medium"}

    def test_config_must_be_object(self, server):
        status, body = _post(
            f"http://127.0.0.1:{server.port}/execute",
            {"prompt": "hi", "config": "not an object"},
        )
        assert status == 400
        assert "config" in body["error"]

    def test_debug_query_returns_diagnostics(self):
        with _running_server(DiagnosticLLMAdapter(mock_response="debug body")) as s:
            status, body = _post(
                f"http://127.0.0.1:{s.port}/execute?debug=1",
                {"prompt": "inspect", "config": {"model_name": "diagnostic-model"}},
            )

        assert status == 200
        assert body["content"] == "debug body"
        debug = body["debug"]
        assert debug["provider_request"]["payload"] == {
            "prompt": "inspect",
            "model": "diagnostic-model",
        }
        # The adapter sent a real-looking bearer token; it must not leak back out.
        assert debug["provider_request"]["headers_redacted"]["Authorization"] == "Bearer <redacted>"
        assert debug["provider_response"]["status"] == 200
        assert debug["adapter_transformations"][0]["step"] == "diagnostic_transform"

    def test_debug_env_returns_diagnostics(self, monkeypatch):
        # Set before the server is built so it is visible for the whole request.
        monkeypatch.setenv("LLM_CONNECT_DEBUG", "1")
        with _running_server(DiagnosticLLMAdapter(mock_response="debug body")) as s:
            status, body = _post(
                f"http://127.0.0.1:{s.port}/execute",
                {"prompt": "inspect"},
            )

        assert status == 200
        assert "debug" in body

    def test_audit_dir_records_replayable_call(self, monkeypatch, tmp_path):
        monkeypatch.setenv("LLM_CONNECT_AUDIT_DIR", str(tmp_path))
        with _running_server(DiagnosticLLMAdapter(mock_response="audit body")) as s:
            status, body = _post(
                f"http://127.0.0.1:{s.port}/execute",
                {"prompt": "audit me", "config": {"model_name": "audit-model"}},
            )

        assert status == 200
        # Auditing is independent of debug output in the HTTP response.
        assert "debug" not in body
        # One call should yield exactly one JSON record in the audit directory.
        files = list(tmp_path.glob("*.json"))
        assert len(files) == 1
        record = json.loads(files[0].read_text(encoding="utf-8"))
        assert record["prompt"] == "audit me"
        assert record["config"]["model_name"] == "audit-model"
        assert record["parsed_content"] == "audit body"
        assert record["provider_request"]["headers_redacted"]["Authorization"] == "Bearer <redacted>"
        assert record["provider_response"]["body"]["id"] == "diag-response"
        assert record["latency_seconds"] >= 0

    def test_execute_requests_run_concurrently(self):
        with _running_server(BarrierLLMAdapter()) as s:
            url = f"http://127.0.0.1:{s.port}/execute"
            start = time.monotonic()
            with ThreadPoolExecutor(max_workers=2) as pool:
                futures = [
                    pool.submit(_post, url, {"prompt": f"request {idx}"})
                    for idx in range(2)
                ]
                results = [future.result(timeout=3.0) for future in futures]
            elapsed = time.monotonic() - start

        assert [status for status, _body in results] == [200, 200]
        # BarrierLLMAdapter makes each call wait (up to 2 s) for a second
        # in-flight call. Only a server that handles both requests at once
        # finishes well under that.
        assert elapsed < 1.5