feat(llm): add embedding adapter with cache and similarity utils (S1.3)
Add OpenAI-compatible embedding support (works with both OpenAI and OpenRouter), file-based embedding cache with content-digest invalidation, and pure-Python cosine similarity utilities for downstream redundancy detection. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
125
markitect/llm/embedding_openai.py
Normal file
125
markitect/llm/embedding_openai.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
OpenAI-compatible embedding adapter.
|
||||
|
||||
Works with both OpenAI (``/v1/embeddings``) and OpenRouter
|
||||
(``/api/v1/embeddings``) since they share the same API format.
|
||||
The *provider* parameter determines the default base URL and
|
||||
API key environment variable.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from markitect.llm.embedding_adapter import EmbeddingAdapter
|
||||
from markitect.llm.config import resolve_api_key, find_project_root
|
||||
from markitect.llm._http import post_json
|
||||
from markitect.llm.exceptions import (
|
||||
LLMConfigurationError,
|
||||
LLMAPIError,
|
||||
LLMRateLimitError,
|
||||
)
|
||||
|
||||
_DEFAULT_MODEL = "text-embedding-3-small"
|
||||
|
||||
_PROVIDER_DEFAULTS: Dict[str, Dict[str, str]] = {
|
||||
"openai": {
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"env_var": "OPENAI_API_KEY",
|
||||
},
|
||||
"openrouter": {
|
||||
"api_base": "https://openrouter.ai/api/v1",
|
||||
"env_var": "OPENROUTER_API_KEY",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class OpenAICompatibleEmbeddingAdapter(EmbeddingAdapter):
|
||||
"""Embedding adapter for OpenAI-compatible endpoints.
|
||||
|
||||
A single class handles both OpenAI and OpenRouter because they
|
||||
expose the same ``/embeddings`` endpoint format.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
provider: str = "openai",
|
||||
max_retries: int = 3,
|
||||
):
|
||||
if provider not in _PROVIDER_DEFAULTS:
|
||||
known = ", ".join(sorted(_PROVIDER_DEFAULTS))
|
||||
raise LLMConfigurationError(
|
||||
f"Unknown embedding provider {provider!r}. Choose from: {known}",
|
||||
context={"provider": provider},
|
||||
)
|
||||
|
||||
defaults = _PROVIDER_DEFAULTS[provider]
|
||||
self._model = model or _DEFAULT_MODEL
|
||||
self._api_base = (api_base or defaults["api_base"]).rstrip("/")
|
||||
self._max_retries = max_retries
|
||||
self._provider = provider
|
||||
|
||||
# Resolve API key
|
||||
env_var = defaults["env_var"]
|
||||
root = find_project_root()
|
||||
key_file_paths = [root / f"apikey-{provider}.txt"] if root else []
|
||||
self._api_key = resolve_api_key(
|
||||
explicit=api_key,
|
||||
env_var=env_var,
|
||||
key_file_paths=key_file_paths,
|
||||
)
|
||||
|
||||
def embed(self, texts: list[str]) -> list[list[float]]:
|
||||
"""Embed texts via the OpenAI-compatible ``/embeddings`` endpoint.
|
||||
|
||||
Raises:
|
||||
LLMConfigurationError: If no API key is configured.
|
||||
LLMAPIError: On HTTP errors after retries are exhausted.
|
||||
"""
|
||||
if not self._api_key:
|
||||
raise LLMConfigurationError(
|
||||
"No API key configured for embedding adapter",
|
||||
context={"provider": self._provider},
|
||||
)
|
||||
|
||||
url = f"{self._api_base}/embeddings"
|
||||
payload: Dict[str, Any] = {
|
||||
"model": self._model,
|
||||
"input": texts,
|
||||
}
|
||||
headers = {"Authorization": f"Bearer {self._api_key}"}
|
||||
|
||||
data = self._post_with_retries(url, payload, headers)
|
||||
|
||||
# Response: {"data": [{"embedding": [...], "index": 0}, ...]}
|
||||
# Sort by index to guarantee input order.
|
||||
items = sorted(data["data"], key=lambda d: d["index"])
|
||||
return [item["embedding"] for item in items]
|
||||
|
||||
def validate(self) -> bool:
|
||||
"""Return ``True`` if an API key is available."""
|
||||
return self._api_key is not None
|
||||
|
||||
def _post_with_retries(
|
||||
self,
|
||||
url: str,
|
||||
payload: Dict[str, Any],
|
||||
headers: Dict[str, str],
|
||||
) -> Dict[str, Any]:
|
||||
last_exc: Optional[Exception] = None
|
||||
for attempt in range(self._max_retries + 1):
|
||||
try:
|
||||
return post_json(url, payload, headers)
|
||||
except LLMRateLimitError as exc:
|
||||
last_exc = exc
|
||||
if attempt < self._max_retries:
|
||||
time.sleep(2 ** attempt)
|
||||
except LLMAPIError as exc:
|
||||
if exc.status_code >= 500 and attempt < self._max_retries:
|
||||
last_exc = exc
|
||||
time.sleep(2 ** attempt)
|
||||
else:
|
||||
raise
|
||||
raise last_exc # type: ignore[misc]
|
||||
Reference in New Issue
Block a user