generated from coulomb/repo-seed
feat(memory): add permission aware retrieval
This commit is contained in:
@@ -136,6 +136,9 @@ from .services import (
|
||||
ContextEntityQueryResult,
|
||||
LexicalIndexRefreshResult,
|
||||
MemoryGraphImportSummary,
|
||||
MemoryQueryRequest,
|
||||
MemoryRetrievalItem,
|
||||
MemoryRetrievalResult,
|
||||
MemoryRuntimeService,
|
||||
RelationshipChangeResult,
|
||||
RelationshipQueryItem,
|
||||
@@ -261,6 +264,9 @@ __all__ = [
|
||||
"MemoryGraphRepository",
|
||||
"MemoryNodeRecord",
|
||||
"MemoryProfileRecord",
|
||||
"MemoryQueryRequest",
|
||||
"MemoryRetrievalItem",
|
||||
"MemoryRetrievalResult",
|
||||
"MemoryRuntimeService",
|
||||
"MemorySourceSpan",
|
||||
"NormalizedDocument",
|
||||
|
||||
@@ -7,7 +7,13 @@ from .asset_service import (
|
||||
)
|
||||
from .content_service import RepresentationContentResult, RepresentationContentStream, RepresentationContentService
|
||||
from .ingestion_service import AssetIngestionResult, AssetIngestionService
|
||||
from .memory_service import MemoryGraphImportSummary, MemoryRuntimeService
|
||||
from .memory_service import (
|
||||
MemoryGraphImportSummary,
|
||||
MemoryQueryRequest,
|
||||
MemoryRetrievalItem,
|
||||
MemoryRetrievalResult,
|
||||
MemoryRuntimeService,
|
||||
)
|
||||
from .retrieval_service import (
|
||||
AssetQueryItem,
|
||||
AssetQueryRequest,
|
||||
@@ -55,6 +61,9 @@ __all__ = [
|
||||
"ContextEntityQueryResult",
|
||||
"LexicalIndexRefreshResult",
|
||||
"MemoryGraphImportSummary",
|
||||
"MemoryQueryRequest",
|
||||
"MemoryRetrievalItem",
|
||||
"MemoryRetrievalResult",
|
||||
"MemoryRuntimeService",
|
||||
"RelationshipChangeResult",
|
||||
"RepresentationContentResult",
|
||||
|
||||
@@ -9,10 +9,13 @@ from kontextual_engine.core import (
|
||||
AuditEvent,
|
||||
AuditOutcome,
|
||||
MemoryGraphImportResult,
|
||||
MemoryEdgeRecord,
|
||||
MemoryNodeRecord,
|
||||
OperationContext,
|
||||
PolicyDecision,
|
||||
)
|
||||
from kontextual_engine.errors import ValidationError
|
||||
from kontextual_engine.ports import MemoryGraphRepository
|
||||
from kontextual_engine.errors import Diagnostic, ValidationError
|
||||
from kontextual_engine.ports import AllowAllPolicyGateway, MemoryGraphRepository, PolicyGateway
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -41,9 +44,118 @@ class MemoryGraphImportSummary:
|
||||
return data
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MemoryQueryRequest:
|
||||
graph_id: str | None = None
|
||||
node_ids: tuple[str, ...] = ()
|
||||
kinds: tuple[str, ...] = ()
|
||||
text_contains: str | None = None
|
||||
include_edges: bool = True
|
||||
limit: int = 50
|
||||
offset: int = 0
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
object.__setattr__(self, "node_ids", tuple(self.node_ids))
|
||||
object.__setattr__(self, "kinds", tuple(self.kinds))
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"graph_id": self.graph_id,
|
||||
"node_ids": list(self.node_ids),
|
||||
"kinds": list(self.kinds),
|
||||
"text_contains": self.text_contains,
|
||||
"include_edges": self.include_edges,
|
||||
"limit": self.limit,
|
||||
"offset": self.offset,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MemoryRetrievalItem:
    """One retrieved memory node with its edges and the policy decision attached to it."""

    node: MemoryNodeRecord
    edges: tuple[MemoryEdgeRecord, ...] = ()
    policy_decision: PolicyDecision | None = None

    def __post_init__(self) -> None:
        # Frozen dataclass: store the edges as a tuple via object.__setattr__.
        edge_tuple = tuple(self.edges)
        object.__setattr__(self, "edges", edge_tuple)

    def to_context_item(self) -> dict[str, Any]:
        """Return the node as a context-item dict (text, spans, provenance, graph metadata)."""
        record = self.node
        context_item: dict[str, Any] = {
            "node_id": record.node_id,
            "contract_node_id": record.contract_node_id,
            "kind": record.kind,
            "text": record.text,
        }
        context_item["source_spans"] = [span.to_dict() for span in record.source_spans]
        # Each provenance entry is shallow-copied into a fresh dict.
        context_item["provenance"] = [dict(entry) for entry in record.provenance]
        context_item["metadata"] = {
            "graph_id": record.graph_id,
            "contract_graph_id": record.contract_graph_id,
            "edges": [edge.to_dict() for edge in self.edges],
        }
        return context_item

    def to_dict(self) -> dict[str, Any]:
        """Return node, edges, context item and policy decision as plain dicts."""
        payload: dict[str, Any] = {
            "node": self.node.to_dict(),
            "edges": [edge.to_dict() for edge in self.edges],
            "context_item": self.to_context_item(),
        }
        if self.policy_decision:
            payload["policy_decision"] = self.policy_decision.to_dict()
        else:
            payload["policy_decision"] = None
        return payload
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class MemoryRetrievalResult:
    """Outcome of a memory query: one page of items plus diagnostics, audit event and metadata."""

    request: MemoryQueryRequest
    correlation_id: str
    total: int
    items: tuple[MemoryRetrievalItem, ...] = ()
    diagnostics: tuple[Diagnostic, ...] = ()
    audit_event: AuditEvent | None = None
    metadata: dict[str, Any] = field(default_factory=dict)
    success: bool = True

    def __post_init__(self) -> None:
        # Frozen dataclass: coerce both sequences to tuples via object.__setattr__.
        object.__setattr__(self, "items", tuple(self.items))
        object.__setattr__(self, "diagnostics", tuple(self.diagnostics))

    @property
    def result_count(self) -> int:
        """Number of items carried by this result."""
        return len(self.items)

    @property
    def next_offset(self) -> int | None:
        """Offset just past this page, or ``None`` once it reaches ``total``."""
        candidate = self.request.offset + self.result_count
        if candidate < self.total:
            return candidate
        return None

    @property
    def context_items(self) -> tuple[dict[str, Any], ...]:
        """Context-item dicts for every item, in item order."""
        return tuple([entry.to_context_item() for entry in self.items])

    def to_dict(self) -> dict[str, Any]:
        """Return the whole result, including the originating query, as plain data."""
        payload: dict[str, Any] = {
            "query": self.request.to_dict(),
            "correlation_id": self.correlation_id,
            "success": self.success,
            "total": self.total,
            "result_count": self.result_count,
            "next_offset": self.next_offset,
            # Shallow copy of the metadata dict.
            "metadata": dict(self.metadata),
        }
        payload["results"] = [entry.to_dict() for entry in self.items]
        payload["context_items"] = list(self.context_items)
        payload["diagnostics"] = [note.to_dict() for note in self.diagnostics]
        if self.audit_event:
            payload["audit_event"] = self.audit_event.to_dict()
        else:
            payload["audit_event"] = None
        return payload
|
||||
|
||||
|
||||
class MemoryRuntimeService:
|
||||
def __init__(self, repository: MemoryGraphRepository) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
repository: MemoryGraphRepository,
|
||||
*,
|
||||
policy_gateway: PolicyGateway | None = None,
|
||||
) -> None:
|
||||
self.repository = repository
|
||||
self.policy_gateway = policy_gateway or AllowAllPolicyGateway()
|
||||
|
||||
def import_markitect_graph(
|
||||
self,
|
||||
@@ -98,3 +210,190 @@ class MemoryRuntimeService:
|
||||
"intent": imported.intent,
|
||||
},
|
||||
)
|
||||
|
||||
def query_memory(
    self,
    request: MemoryQueryRequest,
    context: OperationContext,
) -> MemoryRetrievalResult:
    """Return the memory nodes matching ``request`` that the policy gateway allows.

    The query runs in stages:

    1. Validate ``request``; any diagnostic yields an unsuccessful result.
    2. Authorize the query scope (action ``"memory.query"``); a denial
       yields an unsuccessful result carrying a ``DENIED`` audit event.
    3. Load nodes from the repository and filter by node id, kind and
       case-folded text.
    4. Authorize each remaining node (action ``"memory.node.retrieve"``);
       denied nodes are dropped and reported as diagnostics.
    5. Page the allowed nodes and, if requested, attach edges whose two
       endpoints are both allowed nodes.

    Args:
        request: Filter and paging parameters.
        context: Operation context passed to the policy gateway and the
            audit event; also supplies the correlation id.

    Returns:
        The page of items, with ``total`` counting all allowed nodes.
    """
    diagnostics = _validate_query(request)
    if diagnostics:
        return MemoryRetrievalResult(
            request=request,
            correlation_id=context.correlation_id,
            total=0,
            diagnostics=tuple(diagnostics),
            success=False,
        )

    scope_resource = f"memory-graph:{request.graph_id or '*'}"
    scope_decision = self._authorize(
        context,
        "memory.query",
        scope_resource,
        resource_metadata={"query": request.to_dict()},
    )
    if not scope_decision.allowed:
        audit_event = AuditEvent.from_context(
            "memory.query",
            scope_resource,
            AuditOutcome.DENIED,
            context,
            policy_decision=scope_decision,
            details={"query": request.to_dict()},
        )
        return MemoryRetrievalResult(
            request=request,
            correlation_id=context.correlation_id,
            total=0,
            diagnostics=(_permission_denied_diagnostic(scope_decision),),
            audit_event=audit_event,
            success=False,
            metadata={"policy_enforced": True, "permission_filtered_count": 0},
        )

    nodes = self.repository.list_memory_nodes(graph_id=request.graph_id)
    if request.node_ids:
        wanted = set(request.node_ids)
        # A requested id may match either the node id or the contract node id.
        nodes = [node for node in nodes if node.node_id in wanted or node.contract_node_id in wanted]
    if request.kinds:
        wanted_kinds = set(request.kinds)
        nodes = [node for node in nodes if node.kind in wanted_kinds]
    if request.text_contains:
        needle = request.text_contains.casefold()
        nodes = [node for node in nodes if needle in node.text.casefold()]

    denied_count = 0
    item_diagnostics: list[Diagnostic] = []
    allowed_nodes: list[tuple[MemoryNodeRecord, PolicyDecision]] = []
    for node in nodes:
        decision = self._authorize(
            context,
            "memory.node.retrieve",
            node.resource_id,
            resource_metadata=_node_policy_metadata(node),
        )
        if decision.allowed:
            allowed_nodes.append((node, decision))
        else:
            denied_count += 1
            item_diagnostics.append(_permission_denied_diagnostic(decision, node=node))

    # Edge filtering considers every allowed node, not only those on the page.
    allowed_node_ids = {str(node.node_id) for node, _ in allowed_nodes}
    total = len(allowed_nodes)
    page_nodes = allowed_nodes[request.offset : request.offset + request.limit]

    # Fetch each graph's edges at most once, and only for nodes on the
    # requested page, instead of one repository call per allowed node.
    # Assumes graph_id is hashable (presumably a string) - TODO confirm.
    edges_by_graph: dict[Any, list[MemoryEdgeRecord]] = {}
    items: list[MemoryRetrievalItem] = []
    for node, decision in page_nodes:
        edges: tuple[MemoryEdgeRecord, ...] = ()
        if request.include_edges:
            if node.graph_id not in edges_by_graph:
                edges_by_graph[node.graph_id] = list(
                    self.repository.list_memory_edges(graph_id=node.graph_id)
                )
            edges = tuple(
                edge
                for edge in edges_by_graph[node.graph_id]
                if (
                    edge.source_node_id == node.node_id
                    or edge.target_node_id == node.node_id
                )
                and edge.source_node_id in allowed_node_ids
                and edge.target_node_id in allowed_node_ids
            )
        items.append(MemoryRetrievalItem(node=node, edges=edges, policy_decision=decision))

    page = tuple(items)
    outcome = AuditOutcome.PARTIAL if denied_count else AuditOutcome.SUCCESS
    audit_event = AuditEvent.from_context(
        "memory.query",
        scope_resource,
        outcome,
        context,
        policy_decision=scope_decision,
        details={
            "query": request.to_dict(),
            "matched_count": len(nodes),
            "permission_filtered_count": denied_count,
            "result_count": len(page),
        },
    )
    return MemoryRetrievalResult(
        request=request,
        correlation_id=context.correlation_id,
        total=total,
        items=page,
        diagnostics=tuple(item_diagnostics),
        audit_event=audit_event,
        metadata={
            "zero_result": total == 0,
            "policy_enforced": True,
            "permission_filtered_count": denied_count,
            "context_assembly": "memory-node-source-spans",
        },
    )
|
||||
|
||||
def _authorize(
|
||||
self,
|
||||
context: OperationContext,
|
||||
action: str,
|
||||
resource: str,
|
||||
*,
|
||||
resource_metadata: dict[str, Any] | None = None,
|
||||
) -> PolicyDecision:
|
||||
try:
|
||||
return self.policy_gateway.authorize(
|
||||
context,
|
||||
action,
|
||||
resource,
|
||||
resource_metadata=resource_metadata,
|
||||
)
|
||||
except Exception as exc:
|
||||
return PolicyDecision.fail_closed(
|
||||
context.actor.id,
|
||||
action,
|
||||
resource,
|
||||
reason=str(exc) or "Memory policy gateway failed",
|
||||
context={
|
||||
"gateway_error": type(exc).__name__,
|
||||
"resource_metadata": resource_metadata or {},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _validate_query(request: MemoryQueryRequest) -> list[Diagnostic]:
    """Return an error diagnostic for each invalid paging value; empty when valid."""
    checks = (
        (request.limit < 1, "memory.query.limit_invalid", "limit must be at least 1"),
        (request.offset < 0, "memory.query.offset_invalid", "offset must not be negative"),
    )
    return [Diagnostic("error", code, message) for failed, code, message in checks if failed]
|
||||
|
||||
|
||||
def _node_policy_metadata(node: MemoryNodeRecord) -> dict[str, Any]:
    """Build the metadata dict that describes ``node`` for a policy check.

    Identifier fields are copied as-is, the ``namespace``, ``policy`` and
    ``metadata`` mappings are shallow-copied, and source spans are rendered
    through their ``to_dict``.
    """
    described: dict[str, Any] = {}
    for attr in ("node_id", "contract_node_id", "graph_id", "contract_graph_id", "kind"):
        described[attr] = getattr(node, attr)
    for attr in ("namespace", "policy", "metadata"):
        # dict(...) yields a fresh shallow copy of each mapping.
        described[attr] = dict(getattr(node, attr))
    described["source_spans"] = [span.to_dict() for span in node.source_spans]
    return described
|
||||
|
||||
|
||||
def _permission_denied_diagnostic(
    decision: PolicyDecision,
    *,
    node: MemoryNodeRecord | None = None,
) -> Diagnostic:
    """Build the ``memory.permission_denied`` warning for a denied policy decision.

    Args:
        decision: The policy decision being reported; its dict form and
            ``resource`` go into the diagnostic details.
        node: The node the decision applies to, if any. When given, its
            ``node_id``, ``contract_node_id`` and ``kind`` are added to the
            details.

    Returns:
        A warning diagnostic whose message is ``decision.reason``, or a
        fixed fallback text when the reason is empty.
    """
    details: dict[str, Any] = {
        "policy_decision": decision.to_dict(),
        "resource": decision.resource,
    }
    # Test for presence explicitly rather than relying on the truthiness of
    # the record object.
    if node is not None:
        details["node_id"] = node.node_id
        details["contract_node_id"] = node.contract_node_id
        details["kind"] = node.kind
    return Diagnostic(
        "warning",
        "memory.permission_denied",
        decision.reason or "Memory retrieval denied by policy.",
        details=details,
    )
|
||||
|
||||
Reference in New Issue
Block a user