Files
railiance-fabric/railiance_fabric/llm_extraction.py

704 lines
26 KiB
Python

from __future__ import annotations

import json
import math
import re
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Iterable

from jsonschema import ValidationError

from .canon import edge_canon_mapping, evidence_state_for, node_canon_mapping
from .discovery import (
    attribute_stable_key,
    discovery_stable_key,
    relationship_stable_key,
    replacement_scope_id,
    short_fingerprint,
    source_fingerprint,
)
from .schema_validation import draft202012_validator
# Version tag for the prompt text built by build_llm_prompt. It is recorded in the
# run config's model_params, in candidate provenance, in scan["llm_budget"] and in
# the LLM source-anchor ref. Presumably it should be bumped whenever the prompt
# wording changes.
PROMPT_VERSION = "repo-evidence-v1"
# Extractor identity stamped into candidate provenance; also feeds the LLM
# replacement-scope id (see _llm_scope).
EXTRACTOR_ID = "llm-connect-repo-evidence"
# Version of this extractor, stamped into candidate provenance.
EXTRACTOR_VERSION = "0.1.0"
@dataclass(frozen=True)
class LLMExtractionConfig:
    """Settings for one LLM enrichment pass over a discovery snapshot.

    Attributes:
        provider: llm-connect provider name passed to ``create_adapter``; also
            recorded in provenance and in ``scan["llm_budget"]``.
        model: requested model name (adapter, run config and scan budget).
        temperature: sampling temperature forwarded to the run config.
        max_tokens: completion token limit forwarded to the run config.
        min_confidence: candidates whose clamped confidence is below this
            threshold become ``llm_low_confidence`` review artifacts.
        max_evidence_items: cap applied separately to evidence nodes and to
            attributes placed in the prompt bundle.
        api_key: provider credential handed to ``create_adapter``. It is
            excluded from ``repr`` so it cannot leak into logs or tracebacks.
    """

    provider: str = "mock"
    model: str = "mock"
    temperature: float = 0.0
    max_tokens: int = 1500
    min_confidence: float = 0.6
    max_evidence_items: int = 14
    api_key: str | None = field(default=None, repr=False)
@dataclass(frozen=True)
class LocalRunConfig:
    """Stand-in run configuration used when ``llm_connect.RunConfig`` cannot be imported.

    ``create_run_config`` fills it with the same keyword fields it would pass
    to ``RunConfig``.
    """

    model_name: str
    temperature: float
    max_tokens: int
    # Extra provider parameters; create_run_config sets "response_format" and "prompt_version".
    model_params: dict[str, object]
class LLMExtractionError(RuntimeError):
    """Raised by this module when llm-connect is unavailable or the LLM response is unusable.

    Raised for: llm-connect not importable, non-text response content, output
    that is not valid JSON, or JSON that is not an object.
    """
    pass
def augment_snapshot_with_llm(
    snapshot: dict[str, Any],
    *,
    config: LLMExtractionConfig | None = None,
    adapter: object | None = None,
) -> dict[str, Any]:
    """Return a copy of ``snapshot`` enriched with schema-gated LLM candidates.

    ``snapshot`` itself is never modified; all writes go to JSON round-trip
    copies. Three failure modes do not raise:

    - an adapter/provider failure;
    - output that is not a JSON object;
    - merged candidates that fail discovery-schema validation.

    Each of these instead returns an un-enriched copy carrying one extra review
    artifact. In every outcome ``scan`` is marked as LLM-enabled.

    Args:
        snapshot: deterministic discovery snapshot to enrich.
        config: extraction settings; defaults to ``LLMExtractionConfig()``.
        adapter: object exposing ``execute_prompt(prompt, run_config)``; when
            omitted one is created through llm-connect.
    """
    config = config or LLMExtractionConfig()
    augmented = _copy_json(snapshot)
    # Keep pre-existing artifacts; ``or []`` tolerates an explicit null.
    artifacts: list[dict[str, object]] = list(augmented.get("review_artifacts") or [])
    bundle = build_evidence_bundle(snapshot, max_items=config.max_evidence_items)
    bundle_hash = short_fingerprint(bundle, length=16)
    prompt = build_llm_prompt(bundle)

    try:
        llm_adapter = adapter or create_llm_adapter(config)
        run_config = create_run_config(config)
        response = llm_adapter.execute_prompt(prompt, run_config)
        raw_output = _response_content(response)
    except Exception as exc:  # provider boundary: any failure becomes a review artifact
        return _augment_with_failure(
            augmented,
            artifacts,
            review_artifact(
                artifact_type="llm_execution_error",
                message=f"LLM extraction failed: {exc}",
                payload={"provider": config.provider, "model": config.model},
            ),
            config,
        )

    try:
        parsed = parse_llm_json(raw_output)
    except Exception as exc:  # also covers e.g. RecursionError on pathologically nested JSON
        return _augment_with_failure(
            augmented,
            artifacts,
            review_artifact(
                artifact_type="llm_output_invalid",
                message=f"LLM output is not a valid structured extraction: {exc}",
                payload={"raw_output": raw_output},
            ),
            config,
        )

    response_model = str(getattr(response, "model", config.model) or config.model)
    candidates, rejected = project_llm_output(
        parsed,
        snapshot,
        bundle,
        config=config,
        model=response_model,
        usage=_dict_attribute(response, "usage"),
        metadata=_dict_attribute(response, "metadata"),
        bundle_hash=bundle_hash,
    )
    artifacts.extend(rejected)

    candidate_snapshot = _copy_json(augmented)
    _merge_candidates(candidate_snapshot, candidates)
    if artifacts:
        candidate_snapshot["review_artifacts"] = artifacts
    _mark_llm_scan_metadata(candidate_snapshot, config)
    # NOTE(review): this relative path resolves against the process working
    # directory unless draft202012_validator re-bases it — confirm.
    schema_path = Path("schemas") / "discovery-snapshot.schema.yaml"
    try:
        draft202012_validator(schema_path).validate(candidate_snapshot)
    except ValidationError as exc:
        return _augment_with_failure(
            augmented,
            artifacts,
            review_artifact(
                artifact_type="llm_output_invalid",
                message=f"LLM candidates did not validate against discovery schema: {exc.message}",
                payload={"parsed_output": parsed},
            ),
            config,
        )
    return candidate_snapshot


def _augment_with_failure(
    augmented: dict[str, Any],
    artifacts: list[dict[str, object]],
    artifact: dict[str, object],
    config: LLMExtractionConfig,
) -> dict[str, Any]:
    """Set ``augmented["review_artifacts"]`` to ``artifacts`` plus ``artifact``, mark LLM scan metadata, return it."""
    augmented["review_artifacts"] = [*artifacts, artifact]
    _mark_llm_scan_metadata(augmented, config)
    return augmented


def _dict_attribute(response: object, name: str) -> dict[str, object]:
    """Return ``response.<name>`` when it is a dict, otherwise an empty dict."""
    value = getattr(response, name, None)
    return value if isinstance(value, dict) else {}
def create_llm_adapter(config: LLMExtractionConfig) -> object:
    """Build an llm-connect adapter for ``config.provider`` / ``config.model``.

    Raises:
        LLMExtractionError: when the optional ``llm_connect`` package is missing.
    """
    # llm_connect is optional, so it is imported lazily at call time.
    try:
        from llm_connect import create_adapter as make_adapter
    except ModuleNotFoundError as missing:
        raise LLMExtractionError("llm-connect is not importable") from missing
    return make_adapter(config.provider, model=config.model, api_key=config.api_key)
def create_run_config(config: LLMExtractionConfig) -> object:
    """Build the run configuration handed to ``execute_prompt``.

    Uses ``llm_connect.RunConfig`` when the package is importable and falls
    back to the local ``LocalRunConfig`` otherwise; both receive identical
    keyword fields.
    """
    settings = {
        "model_name": config.model,
        "temperature": config.temperature,
        "max_tokens": config.max_tokens,
        "model_params": {
            "response_format": "json_object",
            "prompt_version": PROMPT_VERSION,
        },
    }
    try:
        from llm_connect import RunConfig as run_config_type
    except ModuleNotFoundError:
        run_config_type = LocalRunConfig
    return run_config_type(**settings)
def build_evidence_bundle(snapshot: dict[str, Any], *, max_items: int = 14) -> dict[str, object]:
    """Condense ``snapshot`` into the evidence bundle embedded in the LLM prompt.

    Node candidates are ranked by ``_node_evidence_score`` (highest first, ties
    broken by stable key) and the top ``max_items`` are kept. Attribute
    candidates whose name ends in ``_title`` or ``_present`` are kept in input
    order, also capped at ``max_items``.
    """
    raw_candidates = snapshot.get("candidates")
    candidates = raw_candidates if isinstance(raw_candidates, dict) else {}
    raw_nodes = candidates.get("nodes")
    nodes = raw_nodes if isinstance(raw_nodes, list) else []
    raw_attributes = candidates.get("attributes")
    attributes = raw_attributes if isinstance(raw_attributes, list) else []

    ranked = sorted(
        ((_node_evidence_score(node), _bundle_node(node)) for node in nodes if isinstance(node, dict)),
        key=lambda pair: (-pair[0], str(pair[1].get("id", ""))),
    )
    title_like = [
        _bundle_attribute(attr)
        for attr in attributes
        if isinstance(attr, dict) and str(attr.get("name", "")).endswith(("_title", "_present"))
    ]
    scan = snapshot.get("scan") or {}
    return {
        "repo": snapshot.get("source", {}),
        "scan": {
            "run_id": scan.get("run_id", ""),
            "profile": scan.get("profile", ""),
        },
        "evidence": [summary for _, summary in ranked[:max_items]],
        "attributes": title_like[:max_items],
    }
def build_llm_prompt(bundle: dict[str, object]) -> str:
    """Render the extraction prompt: fixed instructions followed by ``bundle`` as sorted, indented JSON."""
    # NOTE(review): _resolve_entity_key also honours source_kind/target_kind/
    # entity_kind, and evidence_refs are matched against bundle item "id"s, but
    # the prompt states neither. Changing this text presumably warrants a
    # PROMPT_VERSION bump.
    instructions = (
        "You are enriching a Railiance Fabric discovery snapshot.",
        "Use only the JSON evidence bundle below. Do not invent facts.",
        "Return strict JSON with this shape:",
        '{"nodes":[],"edges":[],"attributes":[]}',
        "Node fields: kind, label, confidence, evidence_refs, rationale, aliases, attributes.",
        "Edge fields: edge_type, source_label or source_key, target_label or target_key, confidence, evidence_refs, rationale.",
        "Attribute fields: entity_label or entity_key, name, value, confidence, evidence_refs, rationale.",
        "Use confidence from 0 to 1. Low confidence or uncertainty is acceptable; it will be reviewed.",
        "Evidence bundle:",
    )
    rendered_bundle = json.dumps(bundle, indent=2, sort_keys=True)
    return "\n".join((*instructions, rendered_bundle))
def parse_llm_json(content: str) -> dict[str, object]:
    """Decode model output (optionally wrapped in a Markdown code fence) into a JSON object.

    Raises:
        LLMExtractionError: if the text is not valid JSON or is not a JSON object.
    """
    candidate_text = _strip_code_fence(content.strip())
    try:
        decoded = json.loads(candidate_text)
    except json.JSONDecodeError as decode_error:
        raise LLMExtractionError(f"LLM output is not valid JSON: {decode_error}") from decode_error
    if isinstance(decoded, dict):
        return decoded
    raise LLMExtractionError("LLM output must be a JSON object")
def project_llm_output(
    output: dict[str, object],
    snapshot: dict[str, Any],
    bundle: dict[str, object],
    *,
    config: LLMExtractionConfig,
    model: str,
    usage: dict[str, object],
    metadata: dict[str, object],
    bundle_hash: str,
) -> tuple[dict[str, list[dict[str, object]]], list[dict[str, object]]]:
    """Turn parsed LLM output into ``needs_review`` discovery candidates.

    Nodes are projected first so that edges and attributes in the same output
    can refer to them by label, alias or stable key. Every candidate is
    tagged with one additive replacement scope derived from ``bundle_hash``.
    Its ``source_anchors`` hold the anchors of any ``evidence_refs`` that name
    bundle items, plus the shared LLM anchor.

    Returns:
        ``(candidates, artifacts)``. ``candidates`` holds ``nodes``, ``edges``,
        ``attributes`` and ``replacement_scopes`` lists. ``artifacts`` has one
        review artifact per item that was below ``config.min_confidence``,
        malformed, or not resolvable to a known entity.
    """
    repo_slug = str((snapshot.get("source") or {}).get("repo_slug") or "repo")
    run_id = str((snapshot.get("scan") or {}).get("run_id") or "")
    scope = _llm_scope(repo_slug, bundle_hash)
    scope_id = scope["id"]
    llm_anchor = _llm_anchor(run_id, bundle_hash)
    provenance_base: dict[str, object] = {
        "extractor_id": EXTRACTOR_ID,
        "extractor_version": EXTRACTOR_VERSION,
        "method": "llm",
        "origin": "llm",
        "prompt_version": PROMPT_VERSION,
        "provider": config.provider,
        "model": model,
        "usage": usage,
    }
    if metadata:
        # Default rationale; a candidate's own rationale replaces it (see _llm_provenance).
        provenance_base["rationale"] = f"metadata={json.dumps(metadata, sort_keys=True, default=str)}"
    evidence_index = _evidence_index(bundle)
    entity_index = _entity_index(snapshot)
    candidates: dict[str, list[dict[str, object]]] = {"nodes": [], "edges": [], "attributes": []}
    artifacts: list[dict[str, object]] = []

    for raw_node in _object_list(output.get("nodes")):
        confidence = _confidence(raw_node.get("confidence"))
        if confidence < config.min_confidence:
            artifacts.append(_low_confidence_artifact(raw_node, confidence))
            continue
        label = str(raw_node.get("label") or "").strip()
        # A missing, blank or whitespace-only kind falls back to the default.
        kind = str(raw_node.get("kind") or "").strip() or "DiscoveredEntity"
        if not label:
            artifacts.append(_invalid_candidate_artifact("LLM node is missing label", raw_node))
            continue
        stable_key = discovery_stable_key(repo_slug, kind, label)
        # De-duplicated so the label is not listed twice when the model already included it.
        aliases = _unique_strings([*_strings(raw_node.get("aliases")), label])
        # The label mapping overwrites any existing entry, so later edges naming
        # this label bind to the LLM node.
        entity_index[_entity_lookup_key(label, kind)] = stable_key
        entity_index[_entity_lookup_key(label, "")] = stable_key
        # Key/alias lookups mirror _entity_index but never displace an existing mapping.
        entity_index.setdefault(stable_key, stable_key)
        for alias in aliases:
            entity_index.setdefault(_entity_lookup_key(alias, kind), stable_key)
            entity_index.setdefault(_entity_lookup_key(alias, ""), stable_key)
        source_anchors = _anchors_for_refs(raw_node.get("evidence_refs"), evidence_index, llm_anchor)
        canon_mapping = node_canon_mapping(kind)
        candidates["nodes"].append(
            {
                "stable_key": stable_key,
                "kind": kind,
                "label": label,
                "repo": repo_slug,
                "canon_category": canon_mapping.category,
                "canon_anchor": canon_mapping.canon_anchor,
                "mapping_fit": canon_mapping.fit,
                "evidence_state": _llm_evidence_state(confidence),
                "aliases": aliases,
                "attributes": _json_object(raw_node.get("attributes")),
                "origin": "llm",
                "review_state": "needs_review",
                "status": "active",
                "confidence": confidence,
                "replacement_scope": scope_id,
                "provenance": [_llm_provenance(provenance_base, raw_node)],
                "source_anchors": source_anchors,
            }
        )

    for raw_edge in _object_list(output.get("edges")):
        confidence = _confidence(raw_edge.get("confidence"))
        if confidence < config.min_confidence:
            artifacts.append(_low_confidence_artifact(raw_edge, confidence))
            continue
        edge_type = str(raw_edge.get("edge_type") or "").strip()
        source_key = _resolve_entity_key(raw_edge, "source", entity_index)
        target_key = _resolve_entity_key(raw_edge, "target", entity_index)
        if not edge_type or not source_key or not target_key:
            artifacts.append(_unresolved_candidate_artifact("LLM edge endpoint could not be resolved", raw_edge))
            continue
        source_anchors = _anchors_for_refs(raw_edge.get("evidence_refs"), evidence_index, llm_anchor)
        canon_mapping = edge_canon_mapping(edge_type)
        candidates["edges"].append(
            {
                "stable_key": relationship_stable_key(source_key, edge_type, target_key, evidence_scope=scope_id),
                "edge_type": edge_type,
                "canonical_type": canon_mapping.canonical_type,
                "canon_anchor": canon_mapping.canon_anchor,
                "mapping_fit": canon_mapping.fit,
                "display_only": canon_mapping.display_only,
                "evidence_state": _llm_evidence_state(confidence),
                "source_key": source_key,
                "target_key": target_key,
                "attributes": _json_object(raw_edge.get("attributes")),
                "origin": "llm",
                "review_state": "needs_review",
                "status": "active",
                "confidence": confidence,
                "replacement_scope": scope_id,
                "provenance": [_llm_provenance(provenance_base, raw_edge)],
                "source_anchors": source_anchors,
            }
        )

    for raw_attribute in _object_list(output.get("attributes")):
        confidence = _confidence(raw_attribute.get("confidence"))
        if confidence < config.min_confidence:
            artifacts.append(_low_confidence_artifact(raw_attribute, confidence))
            continue
        entity_key = _resolve_entity_key(raw_attribute, "entity", entity_index)
        name = str(raw_attribute.get("name") or "").strip()
        if not entity_key or not name:
            artifacts.append(_unresolved_candidate_artifact("LLM attribute target could not be resolved", raw_attribute))
            continue
        source_anchors = _anchors_for_refs(raw_attribute.get("evidence_refs"), evidence_index, llm_anchor)
        candidates["attributes"].append(
            {
                "stable_key": attribute_stable_key(entity_key, name),
                "entity_key": entity_key,
                "name": name,
                "value": _json_value(raw_attribute.get("value")),
                "origin": "llm",
                "review_state": "needs_review",
                "confidence": confidence,
                "replacement_scope": scope_id,
                "provenance": [_llm_provenance(provenance_base, raw_attribute)],
                "source_anchors": source_anchors,
            }
        )

    candidates["replacement_scopes"] = [scope]
    return candidates, artifacts


def _llm_provenance(base: dict[str, object], raw: dict[str, object]) -> dict[str, object]:
    """Copy ``base``; if ``raw`` has a non-blank ``rationale``, it replaces the copy's rationale."""
    provenance = {**base}
    rationale = str(raw.get("rationale") or "").strip()
    if rationale:
        provenance["rationale"] = rationale
    return provenance


def _llm_evidence_state(confidence: float) -> object:
    """Evidence state for an LLM-originated candidate awaiting review at ``confidence``."""
    return evidence_state_for(
        origin="llm",
        source_kind="llm",
        review_state="needs_review",
        confidence=confidence,
    )
def review_artifact(
    *,
    artifact_type: str,
    message: str,
    payload: dict[str, object] | None = None,
    evidence_refs: Iterable[str] = (),
) -> dict[str, object]:
    """Build a review artifact dict with origin ``llm`` and a content-derived id.

    The id is ``review:`` plus a 20-character fingerprint of the body. Because
    ``created_at`` is part of the body, identical artifacts produced at
    different seconds get different ids.
    """
    body = {
        "artifact_type": artifact_type,
        "message": message,
        "payload": payload or {},
        "evidence_refs": list(evidence_refs),
        "created_at": _utc_now(),
    }
    artifact_id = f"review:{short_fingerprint(body, length=20)}"
    return {"id": artifact_id, "origin": "llm", **body}
def _mark_llm_scan_metadata(snapshot: dict[str, Any], config: LLMExtractionConfig) -> None:
    """Record in ``snapshot["scan"]`` that an LLM pass ran, and with which budget.

    Mutates ``snapshot`` in place. A missing or non-dict ``scan`` entry (e.g.
    an explicit null, which other readers here already tolerate via
    ``snapshot.get("scan") or {}``) is replaced with a fresh dict instead of
    crashing.
    """
    scan = snapshot.get("scan")
    if not isinstance(scan, dict):
        scan = {}
        snapshot["scan"] = scan
    scan["llm_enabled"] = True
    scan["deterministic_only"] = False
    scan["llm_budget"] = {
        "provider": config.provider,
        "model": config.model,
        "max_tokens": config.max_tokens,
        "min_confidence": config.min_confidence,
        "prompt_version": PROMPT_VERSION,
    }
def _merge_candidates(snapshot: dict[str, Any], candidates: dict[str, list[dict[str, object]]]) -> None:
existing_scopes = {
str(scope.get("id")): scope
for scope in snapshot.setdefault("replacement_scopes", [])
if isinstance(scope, dict)
}
for scope in candidates.get("replacement_scopes", []):
existing_scopes[str(scope["id"])] = scope
snapshot["replacement_scopes"] = [existing_scopes[key] for key in sorted(existing_scopes)]
snapshot_candidates = snapshot.setdefault("candidates", {"nodes": [], "edges": [], "attributes": []})
for collection in ("nodes", "edges", "attributes"):
existing = {
str(item.get("stable_key")): item
for item in snapshot_candidates.setdefault(collection, [])
if isinstance(item, dict)
}
for incoming in candidates.get(collection, []):
key = str(incoming.get("stable_key"))
existing[key] = _merge_candidate(existing.get(key), incoming)
snapshot_candidates[collection] = [existing[key] for key in sorted(existing)]
def _merge_candidate(existing: dict[str, object] | None, incoming: dict[str, object]) -> dict[str, object]:
if existing is None:
return incoming
merged = {**existing}
for field in ("aliases", "provenance", "source_anchors"):
values = [*list(existing.get(field, [])), *list(incoming.get(field, []))]
if values:
merged[field] = _unique_json(values) if field != "aliases" else _unique_strings(values)
if isinstance(existing.get("attributes"), dict) or isinstance(incoming.get("attributes"), dict):
merged["attributes"] = {
**(existing.get("attributes") if isinstance(existing.get("attributes"), dict) else {}),
**(incoming.get("attributes") if isinstance(incoming.get("attributes"), dict) else {}),
}
if isinstance(existing.get("confidence"), (int, float)) and isinstance(incoming.get("confidence"), (int, float)):
merged["confidence"] = max(float(existing["confidence"]), float(incoming["confidence"]))
return merged
def _llm_scope(repo_slug: str, bundle_hash: str) -> dict[str, object]:
    """Additive replacement scope for one LLM pass; the bundle hash stands in for a source path."""
    scope_id = replacement_scope_id(repo_slug, EXTRACTOR_ID, "llm", source_path=bundle_hash)
    return dict(
        id=scope_id,
        extractor_id=EXTRACTOR_ID,
        source_kind="llm",
        source_path=bundle_hash,
        mode="additive",
        description="LLM-assisted extraction over deterministic evidence bundle.",
    )
def _llm_anchor(run_id: str, bundle_hash: str) -> dict[str, object]:
    """Source anchor pointing at this prompt version, run and evidence bundle.

    The fingerprint is computed over the ``source_kind``/``ref`` pair before it
    is added.
    """
    base = dict(source_kind="llm", ref=f"{PROMPT_VERSION}:{run_id}:{bundle_hash}")
    return {**base, "fingerprint": source_fingerprint(base)}
def _evidence_index(bundle: dict[str, object]) -> dict[str, list[dict[str, object]]]:
index: dict[str, list[dict[str, object]]] = {}
for item in list(bundle.get("evidence", [])) + list(bundle.get("attributes", [])):
if not isinstance(item, dict):
continue
item_id = str(item.get("id") or "")
anchors = item.get("source_anchors")
if item_id and isinstance(anchors, list):
index[item_id] = [anchor for anchor in anchors if isinstance(anchor, dict)]
return index
def _entity_index(snapshot: dict[str, Any]) -> dict[str, str]:
index: dict[str, str] = {}
candidates = snapshot.get("candidates") if isinstance(snapshot.get("candidates"), dict) else {}
for node in candidates.get("nodes", []):
if not isinstance(node, dict):
continue
stable_key = str(node.get("stable_key") or "")
kind = str(node.get("kind") or "")
label = str(node.get("label") or "")
if stable_key:
index[stable_key] = stable_key
if label and stable_key:
index[_entity_lookup_key(label, kind)] = stable_key
index[_entity_lookup_key(label, "")] = stable_key
for alias in _strings(node.get("aliases")):
index[_entity_lookup_key(alias, kind)] = stable_key
index[_entity_lookup_key(alias, "")] = stable_key
graph_id = str(node.get("graph_id") or "")
if graph_id and stable_key:
index[_entity_lookup_key(graph_id, kind)] = stable_key
index[_entity_lookup_key(graph_id, "")] = stable_key
return index
def _entity_lookup_key(label: str, kind: str) -> str:
return f"{kind.strip().lower()}::{label.strip().lower()}"
def _resolve_entity_key(raw: dict[str, object], role: str, entity_index: dict[str, str]) -> str:
explicit = str(raw.get(f"{role}_key") or "").strip()
if explicit:
return entity_index.get(explicit, explicit if explicit.startswith("discovery:") else "")
label = str(raw.get(f"{role}_label") or "").strip()
kind = str(raw.get(f"{role}_kind") or "").strip()
if not label:
return ""
return entity_index.get(_entity_lookup_key(label, kind), entity_index.get(_entity_lookup_key(label, ""), ""))
def _anchors_for_refs(
    refs: object,
    evidence_index: dict[str, list[dict[str, object]]],
    fallback: dict[str, object],
) -> list[dict[str, object]]:
    """Collect the anchors of every known evidence ref, then ``fallback``, without JSON-equal duplicates."""
    collected = [anchor for ref in _strings(refs) for anchor in evidence_index.get(ref, [])]
    collected.append(fallback)
    return _unique_json(collected)
def _node_evidence_score(node: dict[str, object]) -> int:
kind = str(node.get("kind") or "")
score = 1
if node.get("origin") == "repo_declaration":
score += 10
if kind in {"ServiceDeclaration", "CapabilityDeclaration", "InterfaceDeclaration", "Library"}:
score += 8
if kind in {"DeploymentService", "ContainerBuild", "ScoreWorkload"} or kind.startswith("Kubernetes"):
score += 5
if kind in {"Repository", "ExternalLibrary", "Lockfile", "ServiceConfig"}:
score += 2
return score
def _bundle_node(node: dict[str, object]) -> dict[str, object]:
    """Summarise a deterministic node candidate for the prompt bundle.

    The ``id`` is the node's stable key. Text fields default to ``""`` and
    attributes are compacted.
    """
    summary: dict[str, object] = {"id": str(node.get("stable_key") or "")}
    for field_name in ("kind", "label", "graph_id", "origin", "review_state"):
        summary[field_name] = node.get(field_name) or ""
    summary["attributes"] = _compact_attributes(node.get("attributes"))
    anchors = node.get("source_anchors")
    summary["source_anchors"] = anchors if isinstance(anchors, list) else []
    return summary
def _bundle_attribute(attribute: dict[str, object]) -> dict[str, object]:
return {
"id": str(attribute.get("stable_key") or ""),
"entity_key": attribute.get("entity_key") or "",
"name": attribute.get("name") or "",
"value": attribute.get("value"),
"source_anchors": attribute.get("source_anchors") if isinstance(attribute.get("source_anchors"), list) else [],
}
def _compact_attributes(value: object) -> dict[str, object]:
if not isinstance(value, dict):
return {}
compact: dict[str, object] = {}
for key, item in value.items():
if key in {"metadata", "spec"}:
continue
compact[str(key)] = _json_value(item)
return compact
def _low_confidence_artifact(raw: dict[str, object], confidence: float) -> dict[str, object]:
    """Wrap a below-threshold LLM candidate as an ``llm_low_confidence`` review artifact."""
    refs = _strings(raw.get("evidence_refs"))
    return review_artifact(
        artifact_type="llm_low_confidence",
        message=f"LLM candidate below confidence threshold: {confidence:.2f}",
        payload={"candidate": raw, "confidence": confidence},
        evidence_refs=refs,
    )
def _invalid_candidate_artifact(message: str, raw: dict[str, object]) -> dict[str, object]:
    """Wrap a malformed LLM candidate as an ``llm_output_invalid`` review artifact."""
    refs = _strings(raw.get("evidence_refs"))
    payload = {"candidate": raw}
    return review_artifact(artifact_type="llm_output_invalid", message=message, payload=payload, evidence_refs=refs)
def _unresolved_candidate_artifact(message: str, raw: dict[str, object]) -> dict[str, object]:
    """Wrap an LLM candidate whose entity reference(s) did not resolve as an ``llm_candidate_unresolved`` artifact."""
    refs = _strings(raw.get("evidence_refs"))
    payload = {"candidate": raw}
    return review_artifact(artifact_type="llm_candidate_unresolved", message=message, payload=payload, evidence_refs=refs)
def _object_list(value: object) -> list[dict[str, object]]:
if not isinstance(value, list):
return []
return [item for item in value if isinstance(item, dict)]
def _confidence(value: object) -> float:
if isinstance(value, (int, float)):
return max(0.0, min(1.0, float(value)))
return 0.0
def _strings(value: object) -> list[str]:
if isinstance(value, str):
values = [value]
elif isinstance(value, list):
values = value
else:
values = []
result: list[str] = []
seen: set[str] = set()
for item in values:
text = str(item or "").strip()
if not text or text in seen:
continue
seen.add(text)
result.append(text)
return result
def _json_object(value: object) -> dict[str, object]:
if not isinstance(value, dict):
return {}
return {str(key): _json_value(item) for key, item in value.items()}
def _json_value(value: object) -> object:
if value is None or isinstance(value, (str, int, float, bool)):
return value
if isinstance(value, list):
return [_json_value(item) for item in value]
if isinstance(value, tuple):
return [_json_value(item) for item in value]
if isinstance(value, dict):
return {str(key): _json_value(item) for key, item in value.items()}
return str(value)
def _unique_strings(values: Iterable[object]) -> list[str]:
seen: set[str] = set()
result: list[str] = []
for value in values:
text = str(value or "").strip()
if not text or text in seen:
continue
seen.add(text)
result.append(text)
return result
def _unique_json(values: Iterable[object]) -> list[object]:
seen: set[str] = set()
result: list[object] = []
for value in values:
key = json.dumps(value, sort_keys=True, default=str)
if key in seen:
continue
seen.add(key)
result.append(value)
return result
def _response_content(response: object) -> str:
content = getattr(response, "content", "")
if not isinstance(content, str):
raise LLMExtractionError("LLM response content must be text")
return content
def _strip_code_fence(text: str) -> str:
match = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", text, flags=re.DOTALL)
return match.group(1) if match else text
def _copy_json(value: dict[str, Any]) -> dict[str, Any]:
return json.loads(json.dumps(value, default=str))
def _utc_now() -> str:
return datetime.now(timezone.utc).replace(microsecond=0).isoformat().replace("+00:00", "Z")