generated from coulomb/repo-seed
143 lines
5.1 KiB
Python
143 lines
5.1 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import time
|
|
import urllib.error
|
|
import urllib.request
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable
|
|
|
|
from .errors import InfospaceError
|
|
from .workflow import AssistedGenerationRequest, AssistedGenerationResult
|
|
|
|
# Default chat-completions URL; an adapter instance may override it via its
# ``endpoint`` field.
OPENROUTER_ENDPOINT = "https://openrouter.ai/api/v1/chat/completions"

# Pluggable HTTP transport: invoked as ``transport(payload, headers, endpoint)``
# with the JSON request body, the request headers and the target URL, and
# expected to return the decoded JSON response object.
Transport = Callable[
    [dict[str, Any], dict[str, str], str],  # payload, headers, endpoint
    dict[str, Any],  # decoded response object
]
|
|
|
|
|
|
@dataclass(frozen=True)
class OpenRouterAssistedGenerationAdapter:
    """Assisted-generation provider backed by the OpenRouter chat-completions API.

    Each ``generate`` call sends one system + user message pair and returns the
    assistant's reply as Markdown. Transient failures (network errors,
    timeouts, HTTP 408/425/429 and 5xx, unexpected response shapes) are retried
    with exponential backoff up to ``retry_limit`` times; any other HTTP error
    status fails immediately because repeating the request cannot succeed.

    Attributes:
        model: OpenRouter model identifier; must be non-empty.
        api_key: API key; when empty, ``OPENROUTER_API_KEY`` from the
            environment is used instead.
        endpoint: URL the chat-completions request is POSTed to.
        transport: Optional replacement for the built-in urllib transport,
            called as ``transport(payload, headers, endpoint)`` and expected to
            return the decoded JSON response object.
        retry_limit: Maximum number of retries after the first attempt.
        timeout_seconds: Per-request timeout passed to ``urlopen`` by the
            built-in transport.
    """

    model: str
    api_key: str = ""
    endpoint: str = OPENROUTER_ENDPOINT
    transport: Transport | None = None
    retry_limit: int = 2
    timeout_seconds: float = 60.0

    # These two are deliberately unannotated, so @dataclass treats them as plain
    # class constants rather than fields.
    # Client-side HTTP statuses that are worth retrying (request timeout,
    # too early, rate limited); every 5xx status is retried as well.
    _RETRYABLE_CLIENT_STATUSES = frozenset({408, 425, 429})
    # Maximum number of characters of an HTTP error body copied into details.
    _ERROR_BODY_LIMIT = 500

    def __post_init__(self) -> None:
        """Resolve the API key from the environment and validate settings.

        Raises:
            InfospaceError: ``missing_openrouter_api_key`` when neither
                ``api_key`` nor ``OPENROUTER_API_KEY`` is set, or
                ``missing_openrouter_model`` when ``model`` is empty.
        """
        key = self.api_key or os.environ.get("OPENROUTER_API_KEY", "")
        if not key:
            raise InfospaceError(
                "missing_openrouter_api_key",
                "OPENROUTER_API_KEY is required for the OpenRouter provider",
                {"env": "OPENROUTER_API_KEY"},
            )
        # The dataclass is frozen, so bypass its __setattr__ to store the
        # resolved key.
        object.__setattr__(self, "api_key", key)
        if not self.model:
            raise InfospaceError(
                "missing_openrouter_model",
                "OpenRouter provider requires an explicit model",
                {"option": "--model"},
            )

    def generate(
        self,
        request: AssistedGenerationRequest,
    ) -> AssistedGenerationResult:
        """Send ``request.prompt`` to OpenRouter and return the Markdown reply.

        Args:
            request: Supplies the user prompt plus the workflow, stage and input
                artifact ids, which are forwarded as request metadata.

        Returns:
            An ``AssistedGenerationResult`` with ``provider="openrouter"`` and
            metadata holding the model, response id, usage, retry count and
            wall-clock duration in seconds.

        Raises:
            InfospaceError: ``empty_openrouter_response`` when the reply has no
                content; any ``InfospaceError`` raised by the transport (neither
                is retried); or ``openrouter_request_failed`` when a
                non-retryable HTTP error occurs or the retries are exhausted.
                The last underlying exception is chained as ``__cause__``.
        """
        payload = self._build_payload(request)
        headers = self._build_headers()
        started = time.monotonic()
        retry_count = 0
        last_error = ""
        last_exc: Exception | None = None
        while True:
            status: int | None = None
            retryable = True
            try:
                response = (
                    self.transport(payload, headers, self.endpoint)
                    if self.transport is not None
                    else self._default_transport(payload, headers, self.endpoint)
                )
                choice = (response.get("choices") or [{}])[0]
                message = choice.get("message") or {}
                markdown = str(message.get("content") or "")
                if not markdown:
                    raise InfospaceError(
                        "empty_openrouter_response",
                        "OpenRouter returned an empty assistant response",
                        {"model": self.model, "response_id": response.get("id")},
                    )
                return AssistedGenerationResult(
                    markdown=markdown,
                    provider="openrouter",
                    metadata={
                        "model": self.model,
                        "request_id": str(response.get("id") or ""),
                        "usage": response.get("usage") or {},
                        "retry_count": retry_count,
                        "duration_seconds": round(time.monotonic() - started, 3),
                    },
                )
            except urllib.error.HTTPError as exc:
                # HTTPError subclasses URLError, so it must be handled first.
                last_exc = exc
                status = exc.code
                retryable = self._is_retryable_status(exc.code)
                last_error = self._describe_http_error(exc)
            except (urllib.error.URLError, TimeoutError) as exc:
                last_exc = exc
                last_error = str(exc)
            except InfospaceError:
                raise
            except Exception as exc:  # pragma: no cover - defensive provider boundary
                last_exc = exc
                last_error = str(exc)

            if not retryable or retry_count >= self.retry_limit:
                details: dict[str, Any] = {
                    "model": self.model,
                    "retry_count": retry_count,
                    "error": last_error,
                }
                if status is not None:
                    details["status"] = status
                raise InfospaceError(
                    "openrouter_request_failed",
                    (
                        "OpenRouter request failed after bounded retries"
                        if retryable
                        else "OpenRouter request failed with a non-retryable error"
                    ),
                    details,
                ) from last_exc

            retry_count += 1
            # Exponential backoff: 2s, 4s, 8s, then capped at 8s.
            time.sleep(min(2**retry_count, 8))

    def _build_payload(self, request: AssistedGenerationRequest) -> dict[str, Any]:
        """Return the chat-completions request body for ``request``."""
        return {
            "model": self.model,
            "messages": [
                {
                    "role": "system",
                    "content": (
                        "Return concise, valid Markdown only. Preserve explicit "
                        "contracts requested in the user prompt."
                    ),
                },
                {"role": "user", "content": request.prompt},
            ],
            "metadata": {
                "workflow_id": request.workflow_id,
                "stage_id": request.stage_id,
                "input_artifact_id": request.input_artifact_id,
            },
        }

    def _build_headers(self) -> dict[str, str]:
        """Return the bearer-auth, JSON content-type and app-identification headers."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "HTTP-Referer": "https://github.com/markitect/infospace-bench",
            "X-Title": "infospace-bench",
        }

    @classmethod
    def _is_retryable_status(cls, status: int | None) -> bool:
        """Return True when an HTTP error with ``status`` may succeed on retry.

        An unknown status is treated as retryable, matching the behavior for
        network-level failures.
        """
        return (
            status is None
            or status >= 500
            or status in cls._RETRYABLE_CLIENT_STATUSES
        )

    @classmethod
    def _describe_http_error(cls, exc: urllib.error.HTTPError) -> str:
        """Return ``str(exc)`` plus a truncated copy of the error response body.

        The body usually carries the provider's explanation of the failure. It
        is read best-effort, and the error response is closed afterwards so it
        does not stay open across retries.
        """
        body = ""
        try:
            body = exc.read().decode("utf-8", errors="replace").strip()
        except Exception:  # diagnostic only; never mask the HTTP error itself
            body = ""
        try:
            exc.close()
        except Exception:  # closing is best-effort cleanup
            pass
        if not body:
            return str(exc)
        return f"{exc}: {body[: cls._ERROR_BODY_LIMIT]}"

    def _default_transport(
        self,
        payload: dict[str, Any],
        headers: dict[str, str],
        endpoint: str,
    ) -> dict[str, Any]:
        """POST ``payload`` as UTF-8 JSON to ``endpoint`` and decode the reply.

        Raises:
            urllib.error.HTTPError, urllib.error.URLError, TimeoutError: on
                transport failures raised by ``urlopen``.
            json.JSONDecodeError: when the response body is not valid JSON.
            InfospaceError: ``invalid_openrouter_response`` when the JSON is
                not an object.
        """
        http_request = urllib.request.Request(
            endpoint,
            data=json.dumps(payload).encode("utf-8"),
            headers=headers,
            method="POST",
        )
        with urllib.request.urlopen(http_request, timeout=self.timeout_seconds) as response:
            data = response.read().decode("utf-8")
        parsed = json.loads(data)
        if not isinstance(parsed, dict):
            raise InfospaceError(
                "invalid_openrouter_response",
                "OpenRouter returned a non-object JSON response",
                {"model": self.model},
            )
        return parsed
|