feat(llm): add embedding adapter with cache and similarity utils (S1.3)
Add OpenAI-compatible embedding support (works with both OpenAI and OpenRouter), file-based embedding cache with content-digest invalidation, and pure-Python cosine similarity utilities for downstream redundancy detection. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -26,6 +26,15 @@ from markitect.llm.exceptions import (
|
||||
LLMTimeoutError,
|
||||
LLMSubprocessError,
|
||||
)
|
||||
from markitect.llm.embedding_adapter import EmbeddingAdapter
|
||||
from markitect.llm.embedding_openai import OpenAICompatibleEmbeddingAdapter
|
||||
from markitect.llm.embedding_cache import EmbeddingCache
|
||||
from markitect.llm.embedding_factory import create_embedding_adapter
|
||||
from markitect.llm.similarity import (
|
||||
cosine_similarity,
|
||||
similarity_matrix,
|
||||
find_similar_pairs,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_adapter",
|
||||
@@ -41,4 +50,11 @@ __all__ = [
|
||||
"LLMRateLimitError",
|
||||
"LLMTimeoutError",
|
||||
"LLMSubprocessError",
|
||||
"EmbeddingAdapter",
|
||||
"OpenAICompatibleEmbeddingAdapter",
|
||||
"EmbeddingCache",
|
||||
"create_embedding_adapter",
|
||||
"cosine_similarity",
|
||||
"similarity_matrix",
|
||||
"find_similar_pairs",
|
||||
]
|
||||
|
||||
34
markitect/llm/embedding_adapter.py
Normal file
34
markitect/llm/embedding_adapter.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
Abstract base class for embedding adapters.
|
||||
|
||||
Embedding adapters convert text into float vectors. This is a separate
|
||||
hierarchy from :class:`LLMAdapter` (text generation) because the API
|
||||
contract is fundamentally different: text in, float vectors out.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class EmbeddingAdapter(ABC):
|
||||
"""Base class for all embedding adapters."""
|
||||
|
||||
@abstractmethod
|
||||
def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed a batch of texts into vectors.
|
||||
|
||||
Args:
|
||||
texts: One or more strings to embed.
|
||||
|
||||
Returns:
|
||||
A list of embedding vectors, one per input text,
|
||||
in the same order as *texts*.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def validate(self) -> bool:
|
||||
"""Check that the adapter is configured correctly.
|
||||
|
||||
Returns:
|
||||
``True`` if the adapter has a valid configuration
|
||||
(e.g. API key present), ``False`` otherwise.
|
||||
"""
|
||||
64
markitect/llm/embedding_cache.py
Normal file
64
markitect/llm/embedding_cache.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
File-based embedding cache.
|
||||
|
||||
Stores embedding vectors in a single JSON file keyed by entity slug.
|
||||
Each entry includes a content digest so stale embeddings are
|
||||
automatically invalidated when entity content changes.
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class EmbeddingCache:
|
||||
"""Persistent cache for embedding vectors.
|
||||
|
||||
Structure on disk (``embeddings.json``)::
|
||||
|
||||
{
|
||||
"division-of-labour": {"digest": "abc123", "vector": [0.1, ...]},
|
||||
...
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, cache_dir: Path):
|
||||
self._path = cache_dir / "embeddings.json"
|
||||
self._data: dict[str, dict] = {}
|
||||
self._hits = 0
|
||||
self._misses = 0
|
||||
self._load()
|
||||
|
||||
def get(self, slug: str, content_digest: str) -> Optional[list[float]]:
|
||||
"""Return the cached vector if *content_digest* matches, else ``None``."""
|
||||
entry = self._data.get(slug)
|
||||
if entry is not None and entry.get("digest") == content_digest:
|
||||
self._hits += 1
|
||||
return entry["vector"]
|
||||
self._misses += 1
|
||||
return None
|
||||
|
||||
def put(self, slug: str, content_digest: str, vector: list[float]) -> None:
|
||||
"""Store or overwrite the embedding for *slug*."""
|
||||
self._data[slug] = {"digest": content_digest, "vector": vector}
|
||||
|
||||
def save(self) -> None:
|
||||
"""Write cache to disk."""
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._path.write_text(json.dumps(self._data, separators=(",", ":")))
|
||||
|
||||
def stats(self) -> dict:
|
||||
"""Return cache statistics."""
|
||||
return {
|
||||
"entries": len(self._data),
|
||||
"hits": self._hits,
|
||||
"misses": self._misses,
|
||||
}
|
||||
|
||||
def _load(self) -> None:
|
||||
"""Read cache from disk if it exists."""
|
||||
if self._path.is_file():
|
||||
try:
|
||||
self._data = json.loads(self._path.read_text())
|
||||
except (json.JSONDecodeError, OSError):
|
||||
self._data = {}
|
||||
50
markitect/llm/embedding_factory.py
Normal file
50
markitect/llm/embedding_factory.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Factory for creating embedding adapters by provider name.
|
||||
"""
|
||||
|
||||
from typing import Optional, Any
|
||||
|
||||
from markitect.llm.embedding_adapter import EmbeddingAdapter
|
||||
from markitect.llm.exceptions import LLMConfigurationError
|
||||
|
||||
_EMBEDDING_PROVIDERS = {
|
||||
"openai": "markitect.llm.embedding_openai.OpenAICompatibleEmbeddingAdapter",
|
||||
"openrouter": "markitect.llm.embedding_openai.OpenAICompatibleEmbeddingAdapter",
|
||||
}
|
||||
|
||||
|
||||
def create_embedding_adapter(
|
||||
provider: str = "openai",
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> EmbeddingAdapter:
|
||||
"""Instantiate an :class:`EmbeddingAdapter` for the given *provider*.
|
||||
|
||||
Args:
|
||||
provider: ``"openai"`` or ``"openrouter"``.
|
||||
model: Embedding model name (e.g. ``"text-embedding-3-small"``).
|
||||
api_key: Explicit API key.
|
||||
**kwargs: Extra keyword arguments forwarded to the adapter.
|
||||
|
||||
Returns:
|
||||
A ready-to-use :class:`EmbeddingAdapter` instance.
|
||||
|
||||
Raises:
|
||||
LLMConfigurationError: If *provider* is not recognised.
|
||||
"""
|
||||
if provider not in _EMBEDDING_PROVIDERS:
|
||||
known = ", ".join(sorted(_EMBEDDING_PROVIDERS))
|
||||
raise LLMConfigurationError(
|
||||
f"Unknown embedding provider {provider!r}. Choose from: {known}",
|
||||
context={"provider": provider},
|
||||
)
|
||||
|
||||
# Lazy import
|
||||
fqn = _EMBEDDING_PROVIDERS[provider]
|
||||
module_path, class_name = fqn.rsplit(".", 1)
|
||||
import importlib
|
||||
mod = importlib.import_module(module_path)
|
||||
cls = getattr(mod, class_name)
|
||||
|
||||
return cls(model=model, api_key=api_key, provider=provider, **kwargs)
|
||||
125
markitect/llm/embedding_openai.py
Normal file
125
markitect/llm/embedding_openai.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
OpenAI-compatible embedding adapter.
|
||||
|
||||
Works with both OpenAI (``/v1/embeddings``) and OpenRouter
|
||||
(``/api/v1/embeddings``) since they share the same API format.
|
||||
The *provider* parameter determines the default base URL and
|
||||
API key environment variable.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from markitect.llm.embedding_adapter import EmbeddingAdapter
|
||||
from markitect.llm.config import resolve_api_key, find_project_root
|
||||
from markitect.llm._http import post_json
|
||||
from markitect.llm.exceptions import (
|
||||
LLMConfigurationError,
|
||||
LLMAPIError,
|
||||
LLMRateLimitError,
|
||||
)
|
||||
|
||||
_DEFAULT_MODEL = "text-embedding-3-small"
|
||||
|
||||
_PROVIDER_DEFAULTS: Dict[str, Dict[str, str]] = {
|
||||
"openai": {
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"env_var": "OPENAI_API_KEY",
|
||||
},
|
||||
"openrouter": {
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"env_var": "OPENROUTER_API_KEY",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class OpenAICompatibleEmbeddingAdapter(EmbeddingAdapter):
|
||||
"""Embedding adapter for OpenAI-compatible endpoints.
|
||||
|
||||
A single class handles both OpenAI and OpenRouter because they
|
||||
expose the same ``/embeddings`` endpoint format.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
provider: str = "openai",
|
||||
max_retries: int = 3,
|
||||
):
|
||||
if provider not in _PROVIDER_DEFAULTS:
|
||||
known = ", ".join(sorted(_PROVIDER_DEFAULTS))
|
||||
raise LLMConfigurationError(
|
||||
f"Unknown embedding provider {provider!r}. Choose from: {known}",
|
||||
context={"provider": provider},
|
||||
)
|
||||
|
||||
defaults = _PROVIDER_DEFAULTS[provider]
|
||||
self._model = model or _DEFAULT_MODEL
|
||||
self._api_base = (api_base or defaults["api_base"]).rstrip("/")
|
||||
self._max_retries = max_retries
|
||||
self._provider = provider
|
||||
|
||||
# Resolve API key
|
||||
env_var = defaults["env_var"]
|
||||
root = find_project_root()
|
||||
key_file_paths = [root / f"apikey-{provider}.txt"] if root else []
|
||||
self._api_key = resolve_api_key(
|
||||
explicit=api_key,
|
||||
env_var=env_var,
|
||||
key_file_paths=key_file_paths,
|
||||
)
|
||||
|
||||
def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed texts via the OpenAI-compatible ``/embeddings`` endpoint.
|
||||
|
||||
Raises:
|
||||
LLMConfigurationError: If no API key is configured.
|
||||
LLMAPIError: On HTTP errors after retries are exhausted.
|
||||
"""
|
||||
if not self._api_key:
|
||||
raise LLMConfigurationError(
|
||||
"No API key configured for embedding adapter",
|
||||
context={"provider": self._provider},
|
||||
)
|
||||
|
||||
url = f"{self._api_base}/embeddings"
|
||||
payload: Dict[str, Any] = {
|
||||
"model": self._model,
|
||||
"input": texts,
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {self._api_key}"}
|
||||
|
||||
data = self._post_with_retries(url, payload, headers)
|
||||
|
||||
# Response: {"data": [{"embedding": [...], "index": 0}, ...]}
|
||||
# Sort by index to guarantee input order.
|
||||
items = sorted(data["data"], key=lambda d: d["index"])
|
||||
return [item["embedding"] for item in items]
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""Return ``True`` if an API key is available."""
|
||||
return self._api_key is not None
|
||||
|
||||
def _post_with_retries(
|
||||
self,
|
||||
url: str,
|
||||
payload: Dict[str, Any],
|
||||
headers: Dict[str, str],
|
||||
) -> Dict[str, Any]:
|
||||
last_exc: Optional[Exception] = None
|
||||
for attempt in range(self._max_retries + 1):
|
||||
try:
|
||||
return post_json(url, payload, headers)
|
||||
except LLMRateLimitError as exc:
|
||||
last_exc = exc
|
||||
if attempt < self._max_retries:
|
||||
time.sleep(2 ** attempt)
|
||||
except LLMAPIError as exc:
|
||||
if exc.status_code >= 500 and attempt < self._max_retries:
|
||||
last_exc = exc
|
||||
time.sleep(2 ** attempt)
|
||||
else:
|
||||
raise
|
||||
raise last_exc # type: ignore[misc]
|
||||
64
markitect/llm/similarity.py
Normal file
64
markitect/llm/similarity.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Pure-Python vector similarity utilities.
|
||||
|
||||
No external dependencies — uses :mod:`math` only. Sufficient for the
|
||||
current entity scale (~100s). numpy can be substituted later if needed.
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
|
||||
def cosine_similarity(a: list[float], b: list[float]) -> float:
|
||||
"""Cosine similarity between two vectors.
|
||||
|
||||
Returns a float in [-1, 1]. Returns 0.0 if either vector has
|
||||
zero magnitude (to avoid division by zero).
|
||||
"""
|
||||
dot = sum(x * y for x, y in zip(a, b))
|
||||
mag_a = math.sqrt(sum(x * x for x in a))
|
||||
mag_b = math.sqrt(sum(x * x for x in b))
|
||||
if mag_a == 0.0 or mag_b == 0.0:
|
||||
return 0.0
|
||||
return dot / (mag_a * mag_b)
|
||||
|
||||
|
||||
def similarity_matrix(embeddings: list[list[float]]) -> list[list[float]]:
|
||||
"""Build an NxN cosine similarity matrix.
|
||||
|
||||
``matrix[i][j]`` is the cosine similarity between
|
||||
``embeddings[i]`` and ``embeddings[j]``.
|
||||
"""
|
||||
n = len(embeddings)
|
||||
mat: list[list[float]] = [[0.0] * n for _ in range(n)]
|
||||
for i in range(n):
|
||||
mat[i][i] = 1.0
|
||||
for j in range(i + 1, n):
|
||||
sim = cosine_similarity(embeddings[i], embeddings[j])
|
||||
mat[i][j] = sim
|
||||
mat[j][i] = sim
|
||||
return mat
|
||||
|
||||
|
||||
def find_similar_pairs(
|
||||
embeddings: dict[str, list[float]],
|
||||
threshold: float = 0.80,
|
||||
) -> list[tuple[str, str, float]]:
|
||||
"""Find all pairs with cosine similarity >= *threshold*.
|
||||
|
||||
Args:
|
||||
embeddings: Mapping of slug → embedding vector.
|
||||
threshold: Minimum similarity to include (default 0.80).
|
||||
|
||||
Returns:
|
||||
List of ``(slug_a, slug_b, similarity)`` tuples sorted by
|
||||
similarity descending.
|
||||
"""
|
||||
slugs = sorted(embeddings)
|
||||
pairs: list[tuple[str, str, float]] = []
|
||||
for i, slug_a in enumerate(slugs):
|
||||
for slug_b in slugs[i + 1:]:
|
||||
sim = cosine_similarity(embeddings[slug_a], embeddings[slug_b])
|
||||
if sim >= threshold:
|
||||
pairs.append((slug_a, slug_b, sim))
|
||||
pairs.sort(key=lambda t: t[2], reverse=True)
|
||||
return pairs
|
||||
235
tests/unit/llm/test_embeddings.py
Normal file
235
tests/unit/llm/test_embeddings.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""Tests for embedding adapter, cache, similarity, and factory."""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from markitect.llm.similarity import (
|
||||
cosine_similarity,
|
||||
similarity_matrix,
|
||||
find_similar_pairs,
|
||||
)
|
||||
from markitect.llm.embedding_cache import EmbeddingCache
|
||||
from markitect.llm.embedding_openai import OpenAICompatibleEmbeddingAdapter
|
||||
from markitect.llm.embedding_factory import create_embedding_adapter
|
||||
from markitect.llm.exceptions import LLMConfigurationError, LLMRateLimitError
|
||||
|
||||
|
||||
# ── Similarity math ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCosineSimilarity:
|
||||
def test_identical_vectors(self):
|
||||
v = [1.0, 2.0, 3.0]
|
||||
assert cosine_similarity(v, v) == pytest.approx(1.0)
|
||||
|
||||
def test_orthogonal_vectors(self):
|
||||
a = [1.0, 0.0, 0.0]
|
||||
b = [0.0, 1.0, 0.0]
|
||||
assert cosine_similarity(a, b) == pytest.approx(0.0)
|
||||
|
||||
def test_opposite_vectors(self):
|
||||
a = [1.0, 0.0]
|
||||
b = [-1.0, 0.0]
|
||||
assert cosine_similarity(a, b) == pytest.approx(-1.0)
|
||||
|
||||
def test_zero_vector(self):
|
||||
assert cosine_similarity([0.0, 0.0], [1.0, 2.0]) == 0.0
|
||||
|
||||
|
||||
class TestSimilarityMatrix:
|
||||
def test_diagonal_is_one(self):
|
||||
vecs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
|
||||
mat = similarity_matrix(vecs)
|
||||
for i in range(len(vecs)):
|
||||
assert mat[i][i] == pytest.approx(1.0)
|
||||
|
||||
def test_symmetric(self):
|
||||
vecs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
|
||||
mat = similarity_matrix(vecs)
|
||||
n = len(vecs)
|
||||
for i in range(n):
|
||||
for j in range(n):
|
||||
assert mat[i][j] == pytest.approx(mat[j][i])
|
||||
|
||||
|
||||
class TestFindSimilarPairs:
|
||||
def test_threshold_filters(self):
|
||||
emb = {
|
||||
"a": [1.0, 0.0],
|
||||
"b": [0.0, 1.0],
|
||||
"c": [1.0, 0.01], # very similar to "a"
|
||||
}
|
||||
pairs = find_similar_pairs(emb, threshold=0.90)
|
||||
slugs_in_pairs = {(s1, s2) for s1, s2, _ in pairs}
|
||||
assert ("a", "c") in slugs_in_pairs
|
||||
# a-b are orthogonal, should not appear
|
||||
assert ("a", "b") not in slugs_in_pairs
|
||||
|
||||
def test_sorted_descending(self):
|
||||
emb = {
|
||||
"x": [1.0, 0.0, 0.0],
|
||||
"y": [0.9, 0.1, 0.0],
|
||||
"z": [0.95, 0.05, 0.0],
|
||||
}
|
||||
pairs = find_similar_pairs(emb, threshold=0.0)
|
||||
sims = [s for _, _, s in pairs]
|
||||
assert sims == sorted(sims, reverse=True)
|
||||
|
||||
def test_empty_embeddings(self):
|
||||
assert find_similar_pairs({}) == []
|
||||
|
||||
def test_single_embedding(self):
|
||||
assert find_similar_pairs({"only": [1.0, 0.0]}) == []
|
||||
|
||||
|
||||
# ── Embedding cache ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEmbeddingCache:
|
||||
def test_put_get_roundtrip(self, tmp_path: Path):
|
||||
cache = EmbeddingCache(tmp_path)
|
||||
cache.put("division-of-labour", "abc123", [0.1, 0.2, 0.3])
|
||||
assert cache.get("division-of-labour", "abc123") == [0.1, 0.2, 0.3]
|
||||
|
||||
def test_wrong_digest_returns_none(self, tmp_path: Path):
|
||||
cache = EmbeddingCache(tmp_path)
|
||||
cache.put("slug", "digest-v1", [1.0])
|
||||
assert cache.get("slug", "digest-v2") is None
|
||||
|
||||
def test_missing_slug_returns_none(self, tmp_path: Path):
|
||||
cache = EmbeddingCache(tmp_path)
|
||||
assert cache.get("nonexistent", "any") is None
|
||||
|
||||
def test_save_load_persists(self, tmp_path: Path):
|
||||
cache = EmbeddingCache(tmp_path)
|
||||
cache.put("slug-a", "d1", [0.5, 0.6])
|
||||
cache.save()
|
||||
|
||||
cache2 = EmbeddingCache(tmp_path)
|
||||
assert cache2.get("slug-a", "d1") == [0.5, 0.6]
|
||||
|
||||
def test_stats_tracks_hits_and_misses(self, tmp_path: Path):
|
||||
cache = EmbeddingCache(tmp_path)
|
||||
cache.put("s", "d", [1.0])
|
||||
cache.get("s", "d") # hit
|
||||
cache.get("s", "wrong") # miss
|
||||
cache.get("missing", "x") # miss
|
||||
s = cache.stats()
|
||||
assert s["entries"] == 1
|
||||
assert s["hits"] == 1
|
||||
assert s["misses"] == 2
|
||||
|
||||
|
||||
# ── Adapter (mocked HTTP) ──────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_embedding_response(vectors):
|
||||
"""Build a mock API response for the /embeddings endpoint."""
|
||||
return {
|
||||
"data": [
|
||||
{"embedding": vec, "index": i}
|
||||
for i, vec in enumerate(vectors)
|
||||
],
|
||||
"usage": {"prompt_tokens": 5, "total_tokens": 5},
|
||||
}
|
||||
|
||||
|
||||
class TestOpenAICompatibleEmbeddingAdapter:
|
||||
def _adapter(self, **kwargs):
|
||||
defaults = {"api_key": "sk-test", "provider": "openai"}
|
||||
defaults.update(kwargs)
|
||||
return OpenAICompatibleEmbeddingAdapter(**defaults)
|
||||
|
||||
@mock.patch("markitect.llm.embedding_openai.post_json")
|
||||
def test_embed_returns_vectors_in_order(self, mock_post):
|
||||
# Return indices out of order to verify sorting
|
||||
mock_post.return_value = {
|
||||
"data": [
|
||||
{"embedding": [0.2, 0.3], "index": 1},
|
||||
{"embedding": [0.1, 0.2], "index": 0},
|
||||
],
|
||||
"usage": {},
|
||||
}
|
||||
adapter = self._adapter()
|
||||
result = adapter.embed(["text1", "text2"])
|
||||
assert result == [[0.1, 0.2], [0.2, 0.3]]
|
||||
|
||||
@mock.patch("markitect.llm.embedding_openai.post_json")
|
||||
def test_embed_payload_structure(self, mock_post):
|
||||
mock_post.return_value = _make_embedding_response([[0.1]])
|
||||
adapter = self._adapter(model="text-embedding-3-large")
|
||||
adapter.embed(["hello"])
|
||||
|
||||
call_args = mock_post.call_args
|
||||
url = call_args[0][0]
|
||||
payload = call_args[0][1]
|
||||
assert url == "https://api.openai.com/v1/embeddings"
|
||||
assert payload["model"] == "text-embedding-3-large"
|
||||
assert payload["input"] == ["hello"]
|
||||
|
||||
def test_embed_raises_without_api_key(self):
|
||||
adapter = OpenAICompatibleEmbeddingAdapter(api_key=None, provider="openai")
|
||||
adapter._api_key = None
|
||||
with pytest.raises(LLMConfigurationError):
|
||||
adapter.embed(["test"])
|
||||
|
||||
def test_validate_true_with_key(self):
|
||||
adapter = self._adapter()
|
||||
assert adapter.validate() is True
|
||||
|
||||
def test_validate_false_without_key(self):
|
||||
adapter = OpenAICompatibleEmbeddingAdapter(api_key=None, provider="openai")
|
||||
adapter._api_key = None
|
||||
assert adapter.validate() is False
|
||||
|
||||
@mock.patch("markitect.llm.embedding_openai.post_json")
|
||||
@mock.patch("markitect.llm.embedding_openai.time.sleep")
|
||||
def test_retry_on_429(self, mock_sleep, mock_post):
|
||||
mock_post.side_effect = [
|
||||
LLMRateLimitError("rate limited", status_code=429),
|
||||
_make_embedding_response([[0.1, 0.2]]),
|
||||
]
|
||||
adapter = self._adapter(max_retries=2)
|
||||
result = adapter.embed(["test"])
|
||||
assert result == [[0.1, 0.2]]
|
||||
assert mock_sleep.call_count == 1
|
||||
|
||||
def test_openai_provider_base_url(self):
|
||||
adapter = self._adapter(provider="openai")
|
||||
assert adapter._api_base == "https://api.openai.com/v1"
|
||||
|
||||
def test_openrouter_provider_base_url(self):
|
||||
adapter = self._adapter(provider="openrouter")
|
||||
assert adapter._api_base == "https://openrouter.ai/api/v1"
|
||||
|
||||
def test_unknown_provider_raises(self):
|
||||
with pytest.raises(LLMConfigurationError):
|
||||
OpenAICompatibleEmbeddingAdapter(api_key="sk-test", provider="unknown")
|
||||
|
||||
|
||||
# ── Factory ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCreateEmbeddingAdapter:
|
||||
def test_openai_provider(self):
|
||||
adapter = create_embedding_adapter("openai", api_key="sk-test")
|
||||
assert isinstance(adapter, OpenAICompatibleEmbeddingAdapter)
|
||||
assert adapter._provider == "openai"
|
||||
|
||||
def test_openrouter_provider(self):
|
||||
adapter = create_embedding_adapter("openrouter", api_key="sk-test")
|
||||
assert isinstance(adapter, OpenAICompatibleEmbeddingAdapter)
|
||||
assert adapter._provider == "openrouter"
|
||||
|
||||
def test_unknown_provider_raises(self):
|
||||
with pytest.raises(LLMConfigurationError) as exc_info:
|
||||
create_embedding_adapter("unknown")
|
||||
assert "unknown" in str(exc_info.value)
|
||||
|
||||
def test_model_passed_through(self):
|
||||
adapter = create_embedding_adapter(
|
||||
"openai", model="text-embedding-3-large", api_key="sk-test"
|
||||
)
|
||||
assert adapter._model == "text-embedding-3-large"
|
||||
Reference in New Issue
Block a user