Files
markitect-main/markitect/llm/embedding_openai.py
tegwick 267368eb60 feat(llm): add embedding adapter with cache and similarity utils (S1.3)
Add OpenAI-compatible embedding support (works with both OpenAI and
OpenRouter), file-based embedding cache with content-digest invalidation,
and pure-Python cosine similarity utilities for downstream redundancy
detection.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-19 01:22:21 +01:00

126 lines
4.1 KiB
Python

"""
OpenAI-compatible embedding adapter.
Works with both OpenAI (``/v1/embeddings``) and OpenRouter
(``/api/v1/embeddings``) since they share the same API format.
The *provider* parameter determines the default base URL and
API key environment variable.
"""
import time
from typing import Optional, Dict, Any
from markitect.llm.embedding_adapter import EmbeddingAdapter
from markitect.llm.config import resolve_api_key, find_project_root
from markitect.llm._http import post_json
from markitect.llm.exceptions import (
LLMConfigurationError,
LLMAPIError,
LLMRateLimitError,
)
# Embedding model used when the caller does not name one explicitly.
_DEFAULT_MODEL = "text-embedding-3-small"

# Per-provider settings for OpenAI-compatible embedding endpoints:
#   api_base -- base URL the ``/embeddings`` path is appended to
#   env_var  -- environment variable consulted when resolving the API key
_PROVIDER_DEFAULTS: Dict[str, Dict[str, str]] = dict(
    openai=dict(
        api_base="https://api.openai.com/v1",
        env_var="OPENAI_API_KEY",
    ),
    openrouter=dict(
        api_base="https://openrouter.ai/api/v1",
        env_var="OPENROUTER_API_KEY",
    ),
)
class OpenAICompatibleEmbeddingAdapter(EmbeddingAdapter):
    """Embedding adapter for OpenAI-compatible endpoints.

    A single class handles both OpenAI and OpenRouter because they
    expose the same ``/embeddings`` endpoint format.

    Args:
        model: Embedding model name; defaults to ``_DEFAULT_MODEL``.
        api_key: Explicit API key. Passed to ``resolve_api_key`` together
            with the provider's environment variable and, when a project
            root is found, an ``apikey-<provider>.txt`` file in that root.
        api_base: Base URL of the API; defaults to the provider's URL.
            Any trailing ``/`` is stripped.
        provider: Key into ``_PROVIDER_DEFAULTS`` (``"openai"`` or
            ``"openrouter"``).
        max_retries: Extra attempts after the first one for rate-limit
            errors and server errors (status >= 500). Negative values are
            treated as 0.

    Raises:
        LLMConfigurationError: If *provider* is not a known provider.
    """

    def __init__(
        self,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        provider: str = "openai",
        max_retries: int = 3,
    ):
        if provider not in _PROVIDER_DEFAULTS:
            known = ", ".join(sorted(_PROVIDER_DEFAULTS))
            raise LLMConfigurationError(
                f"Unknown embedding provider {provider!r}. Choose from: {known}",
                context={"provider": provider},
            )
        defaults = _PROVIDER_DEFAULTS[provider]
        self._model = model or _DEFAULT_MODEL
        self._api_base = (api_base or defaults["api_base"]).rstrip("/")
        # Clamp so the retry loop always makes at least one attempt; a
        # negative value would skip the loop entirely and end in
        # ``raise None`` (TypeError) inside _post_with_retries.
        self._max_retries = max(0, max_retries)
        self._provider = provider

        # Resolve the API key from the explicit argument, the provider's
        # environment variable, and an optional key file in the project
        # root. Precedence between these is decided by resolve_api_key.
        # A missing key is not an error here; embed() checks it lazily.
        env_var = defaults["env_var"]
        root = find_project_root()
        key_file_paths = [root / f"apikey-{provider}.txt"] if root else []
        self._api_key = resolve_api_key(
            explicit=api_key,
            env_var=env_var,
            key_file_paths=key_file_paths,
        )

    def embed(self, texts: list[str]) -> list[list[float]]:
        """Embed texts via the OpenAI-compatible ``/embeddings`` endpoint.

        Args:
            texts: Strings to embed.

        Returns:
            One embedding vector per returned item, ordered by the
            response's ``index`` field (i.e. input order). An empty *texts*
            list yields ``[]`` without making a request.

        Raises:
            LLMConfigurationError: If no API key is configured.
            LLMAPIError: On HTTP errors after retries are exhausted.
        """
        if not self._api_key:
            raise LLMConfigurationError(
                "No API key configured for embedding adapter",
                context={"provider": self._provider},
            )
        if not texts:
            # Nothing to embed: skip the network round trip (an empty
            # ``input`` list is not guaranteed to be accepted by the API).
            return []

        url = f"{self._api_base}/embeddings"
        payload: Dict[str, Any] = {
            "model": self._model,
            "input": texts,
        }
        headers = {"Authorization": f"Bearer {self._api_key}"}
        data = self._post_with_retries(url, payload, headers)

        # Response: {"data": [{"embedding": [...], "index": 0}, ...]}
        # Sort by index to guarantee input order.
        items = sorted(data["data"], key=lambda d: d["index"])
        return [item["embedding"] for item in items]

    def validate(self) -> bool:
        """Return ``True`` if a non-empty API key is available.

        Uses the same truthiness test as :meth:`embed`, so an empty-string
        key is reported as invalid here instead of failing later in
        :meth:`embed`.
        """
        return bool(self._api_key)

    def _post_with_retries(
        self,
        url: str,
        payload: Dict[str, Any],
        headers: Dict[str, str],
    ) -> Dict[str, Any]:
        """POST *payload* as JSON via ``post_json``, retrying transient errors.

        ``LLMRateLimitError`` and ``LLMAPIError`` with ``status_code >= 500``
        are retried up to ``self._max_retries`` times, sleeping
        ``2 ** attempt`` seconds (1, 2, 4, ...) between attempts. Any other
        ``LLMAPIError`` is re-raised immediately. Once retries are exhausted,
        the last retryable error is raised.
        """
        last_exc: Optional[Exception] = None
        for attempt in range(self._max_retries + 1):
            try:
                return post_json(url, payload, headers)
            except LLMRateLimitError as exc:
                # Handled before LLMAPIError so rate limits are always
                # retried, even if this type subclasses LLMAPIError.
                last_exc = exc
            except LLMAPIError as exc:
                # Guard against a missing/None status_code: comparing None
                # with an int would raise TypeError and hide the API error.
                status = getattr(exc, "status_code", None)
                if status is None or status < 500:
                    raise
                last_exc = exc
            if attempt < self._max_retries:
                time.sleep(2 ** attempt)
        # The loop runs at least once (max_retries is clamped to >= 0) and
        # every iteration either returns, raises, or records last_exc, so
        # last_exc is set here.
        raise last_exc  # type: ignore[misc]