generated from coulomb/repo-seed
215 lines
7.6 KiB
Python
215 lines
7.6 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Protocol
|
|
|
|
from repo_registry.core.models import ContentChunk, Repository
|
|
|
|
|
|
class LLMExtractionError(ValueError):
    """Raised when LLM-based extraction cannot proceed.

    Used in this module for replies that are not valid JSON, replies that
    lack an ``abilities`` list, entries missing a required string field, and
    for a missing ``llm-connect`` installation. Being a ``ValueError``
    subclass, it is also caught by handlers written for ``ValueError``.
    """
|
|
|
|
|
|
class LLMResponseLike(Protocol):
    """Structural type for the object returned by an adapter call.

    Any object exposing a string ``content`` attribute satisfies it.
    """

    # Raw reply text; LLMCandidateExtractor.extract hands it to parse_response.
    content: str
|
|
|
|
|
|
class LLMAdapterLike(Protocol):
    """Structural type for adapters that can execute a single prompt."""

    def execute_prompt(self, prompt: str, config: Any) -> LLMResponseLike:
        """Run ``prompt`` with ``config`` and return the adapter's response.

        This is a protocol stub: implementations provide the real behaviour.
        """
        ...
|
|
|
|
|
|
@dataclass(frozen=True)
class ExtractedEvidence:
    """One piece of evidence attached to an extracted capability.

    Instances are built by ``LLMCandidateExtractor._evidence`` from the
    ``evidence`` entries of the LLM reply. ``frozen=True`` prevents attribute
    reassignment; the list held in ``source_paths`` is still a mutable list.
    """

    # Evidence category as reported by the LLM (the prompt's example value is
    # "documentation").
    type: str
    # What the evidence points at, as free text from the LLM.
    reference: str
    # Strength label; "medium" when the reply omits it.
    strength: str = "medium"
    # Paths the LLM cited for this evidence.
    source_paths: list[str] = field(default_factory=list)
|
|
|
|
|
|
@dataclass(frozen=True)
class ExtractedFeature:
    """A named feature reported under an extracted capability.

    Instances are built by ``LLMCandidateExtractor._feature`` from the
    ``features`` entries of the LLM reply. ``frozen=True`` prevents attribute
    reassignment; the list held in ``source_paths`` is still a mutable list.
    """

    # Feature name (required in the reply).
    name: str
    # Feature category as free text from the LLM (required in the reply).
    type: str
    # Where the feature lives, as free text; empty when not supplied.
    location: str = ""
    # Paths the LLM cited for this feature.
    source_paths: list[str] = field(default_factory=list)
|
|
|
|
|
|
@dataclass(frozen=True)
class ExtractedCapability:
    """A capability reported under an extracted ability.

    Instances are built by ``LLMCandidateExtractor._capability`` from the
    ``capabilities`` entries of the LLM reply. ``frozen=True`` prevents
    attribute reassignment; the list-valued fields are still mutable lists.
    """

    # Capability name (required in the reply).
    name: str
    # Free-text description; empty when not supplied.
    description: str = ""
    # Input and output labels as listed by the LLM.
    inputs: list[str] = field(default_factory=list)
    outputs: list[str] = field(default_factory=list)
    # Nested feature and evidence records parsed from the same entry.
    features: list[ExtractedFeature] = field(default_factory=list)
    evidence: list[ExtractedEvidence] = field(default_factory=list)
    # Paths the LLM cited for this capability.
    source_paths: list[str] = field(default_factory=list)
|
|
|
|
|
|
@dataclass(frozen=True)
class ExtractedAbility:
    """A top-level ability parsed from the LLM reply.

    Instances are built by ``LLMCandidateExtractor._ability`` from the
    ``abilities`` entries of the reply. ``frozen=True`` prevents attribute
    reassignment; the list-valued fields are still mutable lists.
    """

    # Ability name (required in the reply).
    name: str
    # Free-text description; empty when not supplied.
    description: str = ""
    # Capabilities nested under this ability.
    capabilities: list[ExtractedCapability] = field(default_factory=list)
    # Paths the LLM cited for this ability.
    source_paths: list[str] = field(default_factory=list)
|
|
|
|
|
|
class LLMCandidateExtractor:
    """Structured candidate extraction over llm-connect-style adapters.

    The extractor renders a prompt from a repository and its content chunks,
    sends it through ``adapter.execute_prompt`` and converts the JSON reply
    into ``ExtractedAbility`` objects. Malformed replies are reported as
    ``LLMExtractionError``; entries of the wrong JSON type inside list fields
    are skipped.
    """

    # Upper bound on how many chunks are rendered into one prompt. Override on
    # a subclass or an instance to change the cap.
    MAX_PROMPT_CHUNKS = 12

    def __init__(self, adapter: LLMAdapterLike, run_config: Any | None = None) -> None:
        """Store the adapter and the run configuration.

        Args:
            adapter: Object exposing ``execute_prompt(prompt, config)``.
            run_config: Second argument for ``execute_prompt``. Any falsy
                value is replaced by ``_default_run_config()``, which itself
                may be ``None`` when ``llm_connect`` is not importable.
        """
        self.adapter = adapter
        self.run_config = run_config or self._default_run_config()

    def extract(
        self,
        repository: Repository,
        chunks: list[ContentChunk],
    ) -> list[ExtractedAbility]:
        """Build the prompt, execute it and parse the adapter's reply.

        Raises:
            LLMExtractionError: If the reply cannot be parsed (see
                ``parse_response``).
        """
        prompt = self.build_prompt(repository, chunks)
        response = self.adapter.execute_prompt(prompt, self.run_config)
        return self.parse_response(response.content)

    def build_prompt(self, repository: Repository, chunks: list[ContentChunk]) -> str:
        """Render the extraction prompt.

        The prompt holds the instructions, the expected JSON shape, the
        repository name and description, and at most ``MAX_PROMPT_CHUNKS``
        chunks, each introduced by a ``Source: path:start-end (kind)`` line.
        """
        chunk_text = "\n\n".join(
            (
                f"Source: {chunk.path}:{chunk.start_line}-{chunk.end_line} "
                f"({chunk.kind})\n{chunk.text}"
            )
            for chunk in chunks[: self.MAX_PROMPT_CHUNKS]
        )
        return (
            "Extract a conservative, source-linked repository ability map.\n"
            "Return strict JSON only with this shape:\n"
            "{\n"
            ' "abilities": [\n'
            " {\n"
            ' "name": "...",\n'
            ' "description": "...",\n'
            ' "source_paths": ["README.md"],\n'
            ' "capabilities": [\n'
            " {\n"
            ' "name": "...",\n'
            ' "description": "...",\n'
            ' "inputs": ["..."],\n'
            ' "outputs": ["..."],\n'
            ' "source_paths": ["..."],\n'
            ' "features": [{"name": "...", "type": "...", "location": "...", "source_paths": ["..."]}],\n'
            ' "evidence": [{"type": "documentation", "reference": "...", "strength": "medium", "source_paths": ["..."]}]\n'
            " }\n"
            " ]\n"
            " }\n"
            " ]\n"
            "}\n"
            "Do not invent unsupported claims. If sources are weak, keep names generic.\n\n"
            f"Repository: {repository.name}\n"
            f"Description: {repository.description or ''}\n\n"
            f"{chunk_text}\n"
        )

    def parse_response(self, content: str) -> list[ExtractedAbility]:
        """Parse the raw LLM reply into abilities.

        Args:
            content: Reply text, optionally wrapped in a Markdown code fence.

        Returns:
            One ``ExtractedAbility`` per JSON object in the ``abilities``
            list; non-object entries are skipped.

        Raises:
            LLMExtractionError: If the text is not valid JSON, if the
                top-level value is not an object with an ``abilities`` list,
                or if an entry lacks a required string field.
        """
        try:
            payload = json.loads(self._json_text(content))
        except json.JSONDecodeError as exc:
            raise LLMExtractionError(f"LLM response was not valid JSON: {exc}") from exc
        # Valid JSON can also be an array, string or number; those have no
        # ``.get`` and must be reported as a malformed reply, not crash.
        abilities = payload.get("abilities") if isinstance(payload, dict) else None
        if not isinstance(abilities, list):
            raise LLMExtractionError("LLM response must contain an abilities list")
        return [self._ability(item) for item in self._dict_list(abilities)]

    def _ability(self, item: dict[str, Any]) -> ExtractedAbility:
        """Convert one ``abilities`` entry; ``name`` is required."""
        return ExtractedAbility(
            name=self._required_str(item, "name"),
            description=self._optional_str(item, "description"),
            source_paths=self._str_list(item.get("source_paths")),
            capabilities=[
                self._capability(capability)
                for capability in self._dict_list(item.get("capabilities"))
            ],
        )

    def _capability(self, item: dict[str, Any]) -> ExtractedCapability:
        """Convert one ``capabilities`` entry; ``name`` is required."""
        return ExtractedCapability(
            name=self._required_str(item, "name"),
            description=self._optional_str(item, "description"),
            inputs=self._str_list(item.get("inputs")),
            outputs=self._str_list(item.get("outputs")),
            source_paths=self._str_list(item.get("source_paths")),
            features=[
                self._feature(feature)
                for feature in self._dict_list(item.get("features"))
            ],
            evidence=[
                self._evidence(evidence)
                for evidence in self._dict_list(item.get("evidence"))
            ],
        )

    def _feature(self, item: dict[str, Any]) -> ExtractedFeature:
        """Convert one ``features`` entry; ``name`` and ``type`` are required."""
        return ExtractedFeature(
            name=self._required_str(item, "name"),
            type=self._required_str(item, "type"),
            location=self._optional_str(item, "location"),
            source_paths=self._str_list(item.get("source_paths")),
        )

    def _evidence(self, item: dict[str, Any]) -> ExtractedEvidence:
        """Convert one ``evidence`` entry; ``type`` and ``reference`` are required."""
        return ExtractedEvidence(
            type=self._required_str(item, "type"),
            reference=self._required_str(item, "reference"),
            # A missing, blank or non-string strength falls back to "medium".
            strength=self._optional_str(item, "strength") or "medium",
            source_paths=self._str_list(item.get("source_paths")),
        )

    def _json_text(self, content: str) -> str:
        """Return ``content`` stripped, without a surrounding code fence.

        Only a reply that *starts* with a fence marker is unwrapped: its first
        line (the opening fence, with any language tag) is dropped, and so is
        its last line when that line is a closing fence.
        """
        stripped = content.strip()
        if stripped.startswith("```"):
            lines = stripped.splitlines()
            if lines and lines[0].startswith("```"):
                lines = lines[1:]
            if lines and lines[-1].startswith("```"):
                lines = lines[:-1]
            return "\n".join(lines).strip()
        return stripped

    def _required_str(self, item: dict[str, Any], key: str) -> str:
        """Return ``item[key]`` stripped.

        Raises:
            LLMExtractionError: If the value is missing, not a string, or
                blank after stripping.
        """
        value = item.get(key)
        if not isinstance(value, str) or not value.strip():
            raise LLMExtractionError(f"Missing required string field: {key}")
        return value.strip()

    def _optional_str(self, item: dict[str, Any], key: str) -> str:
        """Return ``item[key]`` stripped, or ``""`` if absent or not a string."""
        value = item.get(key, "")
        return value.strip() if isinstance(value, str) else ""

    def _str_list(self, value: Any) -> list[str]:
        """Return the non-blank strings of ``value``, stripped.

        Anything that is not a list yields ``[]``; non-string and blank
        elements are dropped.
        """
        if not isinstance(value, list):
            return []
        return [item.strip() for item in value if isinstance(item, str) and item.strip()]

    def _dict_list(self, value: Any) -> list[dict[str, Any]]:
        """Return the JSON-object elements of ``value``.

        Anything that is not a list (a missing key, ``null``, a number, ...)
        yields ``[]``, so a sloppy reply cannot make iteration fail.
        """
        if not isinstance(value, list):
            return []
        return [entry for entry in value if isinstance(entry, dict)]

    def _default_run_config(self) -> Any:
        """Return the default ``llm_connect.RunConfig``, if available.

        Returns:
            ``RunConfig(temperature=0.1, max_tokens=2000)``, or ``None`` when
            the ``llm_connect`` module cannot be found.
        """
        try:
            from llm_connect import RunConfig
        except ModuleNotFoundError:
            return None
        return RunConfig(temperature=0.1, max_tokens=2000)
|
|
|
|
|
|
def create_llm_connect_adapter(
    provider: str,
    model: str | None = None,
    **kwargs: Any,
) -> LLMAdapterLike:
    """Create an adapter through ``llm_connect.create_adapter``.

    ``llm_connect`` is imported lazily so this module can be loaded without
    it being installed.

    Args:
        provider: Forwarded as the first positional argument.
        model: Forwarded as the ``model`` keyword argument.
        **kwargs: Forwarded unchanged.

    Returns:
        Whatever ``llm_connect.create_adapter`` returns.

    Raises:
        LLMExtractionError: If importing ``llm_connect`` raises
            ``ModuleNotFoundError``.
    """
    try:
        from llm_connect import create_adapter as build_adapter
    except ModuleNotFoundError as missing:
        hint = (
            "llm-connect is not installed. Install the sibling project with "
            "`python -m pip install -e ../llm-connect`."
        )
        raise LLMExtractionError(hint) from missing
    else:
        return build_adapter(provider, model=model, **kwargs)
|