Add OpenAI-compatible embedding support (works with both OpenAI and OpenRouter), file-based embedding cache with content-digest invalidation, and pure-Python cosine similarity utilities for downstream redundancy detection. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
126 lines
4.1 KiB
Python
126 lines
4.1 KiB
Python
"""
|
|
OpenAI-compatible embedding adapter.

Works with both OpenAI (``/v1/embeddings``) and OpenRouter
(``/api/v1/embeddings``) since they share the same API format.
The *provider* parameter determines the default base URL and
API key environment variable.
|
|
"""
|
|
|
|
import time
|
|
from typing import Optional, Dict, Any
|
|
|
|
from markitect.llm.embedding_adapter import EmbeddingAdapter
|
|
from markitect.llm.config import resolve_api_key, find_project_root
|
|
from markitect.llm._http import post_json
|
|
from markitect.llm.exceptions import (
|
|
LLMConfigurationError,
|
|
LLMAPIError,
|
|
LLMRateLimitError,
|
|
)
|
|
|
|
_DEFAULT_MODEL = "text-embedding-3-small"
|
|
|
|
_PROVIDER_DEFAULTS: Dict[str, Dict[str, str]] = {
|
|
"openai": {
|
|
"api_base": "https://api.openai.com/v1",
|
|
"env_var": "OPENAI_API_KEY",
|
|
},
|
|
"openrouter": {
|
|
"api_base": "https://openrouter.ai/api/v1",
|
|
"env_var": "OPENROUTER_API_KEY",
|
|
},
|
|
}
|
|
|
|
|
|
class OpenAICompatibleEmbeddingAdapter(EmbeddingAdapter):
    """Embedding adapter for OpenAI-compatible endpoints.

    A single class handles both OpenAI and OpenRouter because they
    expose the same ``/embeddings`` endpoint format.
    """

    def __init__(
        self,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        provider: str = "openai",
        max_retries: int = 3,
    ):
        """Configure the adapter for one provider.

        Args:
            model: Embedding model name; ``_DEFAULT_MODEL`` is used when
                omitted or empty.
            api_key: Explicit API key. It is handed to
                ``resolve_api_key`` together with the provider's
                environment variable name and, when a project root is
                found, the ``apikey-<provider>.txt`` file in that root.
            api_base: Base URL override; the provider default is used
                when omitted. Trailing slashes are stripped.
            provider: A key of ``_PROVIDER_DEFAULTS``.
            max_retries: Number of retries after the first attempt for
                rate-limit and 5xx errors. Values below zero behave
                like zero.

        Raises:
            LLMConfigurationError: If *provider* is not a known provider.
        """
        if provider not in _PROVIDER_DEFAULTS:
            known = ", ".join(sorted(_PROVIDER_DEFAULTS))
            raise LLMConfigurationError(
                f"Unknown embedding provider {provider!r}. Choose from: {known}",
                context={"provider": provider},
            )

        defaults = _PROVIDER_DEFAULTS[provider]
        self._model = model or _DEFAULT_MODEL
        self._api_base = (api_base or defaults["api_base"]).rstrip("/")
        self._max_retries = max_retries
        self._provider = provider

        # Resolve API key
        env_var = defaults["env_var"]
        root = find_project_root()
        key_file_paths = [root / f"apikey-{provider}.txt"] if root else []
        self._api_key = resolve_api_key(
            explicit=api_key,
            env_var=env_var,
            key_file_paths=key_file_paths,
        )

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed texts via the OpenAI-compatible ``/embeddings`` endpoint.

        Args:
            texts: Texts to embed.

        Returns:
            One embedding per input text, in input order. An empty
            input yields an empty list without contacting the API.

        Raises:
            LLMConfigurationError: If no API key is configured.
            LLMRateLimitError: If rate limiting persists after retries.
            LLMAPIError: On HTTP errors after retries are exhausted.
        """
        if not self._api_key:
            raise LLMConfigurationError(
                "No API key configured for embedding adapter",
                context={"provider": self._provider},
            )

        # Nothing to embed: skip the round trip, since a request with an
        # empty input list can only come back as an error.
        if not texts:
            return []

        url = f"{self._api_base}/embeddings"
        payload: Dict[str, Any] = {
            "model": self._model,
            "input": texts,
        }
        headers = {"Authorization": f"Bearer {self._api_key}"}

        data = self._post_with_retries(url, payload, headers)

        # Response: {"data": [{"embedding": [...], "index": 0}, ...]}
        # Sort by index to guarantee input order.
        items = sorted(data["data"], key=lambda d: d["index"])
        return [item["embedding"] for item in items]

    def validate(self) -> bool:
        """Return ``True`` if a usable (non-empty) API key is available.

        Uses the same truthiness test as :meth:`embed`, so a ``True``
        result means :meth:`embed` will not fail its API-key check.
        """
        return bool(self._api_key)

    def _post_with_retries(
        self,
        url: str,
        payload: Dict[str, Any],
        headers: Dict[str, str],
    ) -> Dict[str, Any]:
        """POST *payload* as JSON, retrying transient failures.

        Rate-limit errors and API errors with a 5xx status are retried
        up to ``max_retries`` times, sleeping ``2 ** attempt`` seconds
        (1, 2, 4, ...) between attempts. Any other API error, or the
        error from the final attempt, propagates to the caller.

        Args:
            url: Full endpoint URL.
            payload: JSON-serialisable request body.
            headers: HTTP headers to send.

        Returns:
            The value returned by ``post_json``.

        Raises:
            LLMRateLimitError: If the last attempt was rate limited.
            LLMAPIError: On a non-retryable error, or a 5xx error on
                the last attempt.
        """
        # Clamp so that at least one attempt is always made; a negative
        # setting would otherwise leave nothing to return or raise.
        retries = max(0, self._max_retries)
        attempt = 0
        while True:
            try:
                return post_json(url, payload, headers)
            except LLMRateLimitError:
                if attempt >= retries:
                    raise
            except LLMAPIError as exc:
                # Guard against a missing or None status so the real
                # API error surfaces instead of a TypeError raised by
                # the comparison.
                status = getattr(exc, "status_code", None)
                is_server_error = isinstance(status, int) and status >= 500
                if not is_server_error or attempt >= retries:
                    raise
            # Exponential backoff before the next attempt.
            time.sleep(2 ** attempt)
            attempt += 1
|