generated from coulomb/repo-seed
Improvements and perspective from architecture review
This commit is contained in:
@@ -278,11 +278,9 @@ class DocumentFunctionRegistry:
|
||||
_validate_arguments(descriptor, args, kwargs)
|
||||
if descriptor.id == "data.get":
|
||||
output = context.variables.get(str(args[0]), kwargs.get("default", ""))
|
||||
raise _FunctionOutputReady(output)
|
||||
assert descriptor.implementation is not None
|
||||
output = descriptor.implementation(*args, **kwargs)
|
||||
except _FunctionOutputReady as ready:
|
||||
output = ready.output
|
||||
else:
|
||||
assert descriptor.implementation is not None
|
||||
output = descriptor.implementation(*args, **kwargs)
|
||||
except Exception as exc:
|
||||
return _call_error(call, "function.evaluation_failed", str(exc), context)
|
||||
|
||||
@@ -513,22 +511,34 @@ def validate_document_functions(
|
||||
diagnostics: list[Diagnostic] = []
|
||||
runs: list[DocumentFunctionRun] = []
|
||||
for call in parse_document_function_calls(text):
|
||||
if allowed_set and call.function_id not in allowed_set:
|
||||
diagnostics.append(_diagnostic(call, "function.not_allowed", f"Function `{call.function_id}` is not allowed."))
|
||||
if call.function_id in forbidden_set:
|
||||
diagnostics.append(_diagnostic(call, "function.forbidden", f"Function `{call.function_id}` is forbidden."))
|
||||
try:
|
||||
descriptor = registry.get(call.function_id)
|
||||
if descriptor.execution != "deterministic":
|
||||
for index, current in enumerate([call, *call.pipeline]):
|
||||
if allowed_set and current.function_id not in allowed_set:
|
||||
diagnostics.append(
|
||||
_diagnostic(
|
||||
call,
|
||||
"function.unstable",
|
||||
f"Function `{call.function_id}` is `{descriptor.execution}` and cannot run in deterministic contexts.",
|
||||
current,
|
||||
"function.not_allowed",
|
||||
f"Function `{current.function_id}` is not allowed.",
|
||||
)
|
||||
)
|
||||
except DocumentFunctionError as exc:
|
||||
diagnostics.append(_diagnostic(call, "function.unknown", str(exc)))
|
||||
if current.function_id in forbidden_set:
|
||||
diagnostics.append(
|
||||
_diagnostic(current, "function.forbidden", f"Function `{current.function_id}` is forbidden.")
|
||||
)
|
||||
try:
|
||||
descriptor = registry.get(current.function_id)
|
||||
if descriptor.execution != "deterministic":
|
||||
diagnostics.append(
|
||||
_diagnostic(
|
||||
current,
|
||||
"function.unstable",
|
||||
f"Function `{current.function_id}` is `{descriptor.execution}` and cannot run in deterministic contexts.",
|
||||
)
|
||||
)
|
||||
args = current.args if index == 0 else ["<pipeline-output>", *current.args]
|
||||
_validate_arguments(descriptor, args, current.kwargs)
|
||||
except DocumentFunctionError as exc:
|
||||
code = "function.unknown" if str(exc).startswith("Unknown document function") else "function.arguments"
|
||||
diagnostics.append(_diagnostic(current, code, str(exc)))
|
||||
runs.append(DocumentFunctionRun(call=call))
|
||||
return DocumentFunctionEvaluationResult(content=text, calls=runs, diagnostics=diagnostics)
|
||||
|
||||
@@ -541,7 +551,7 @@ def _parse_call_expression(
|
||||
line: int | None,
|
||||
body: str | None = None,
|
||||
) -> DocumentFunctionCall:
|
||||
pipeline_parts = [part.strip() for part in expression.split("|") if part.strip()]
|
||||
pipeline_parts = _split_pipeline_expression(expression)
|
||||
if not pipeline_parts:
|
||||
raise DocumentFunctionError("Document function call is empty.")
|
||||
first = _parse_single_call(pipeline_parts[0], raw=raw, inline=inline, line=line, body=body)
|
||||
@@ -627,6 +637,18 @@ def _validate_arguments(
|
||||
required = [parameter for parameter in descriptor.parameters if parameter.required and not parameter.variadic]
|
||||
positional = [parameter for parameter in descriptor.parameters if not parameter.variadic]
|
||||
variadic = next((parameter for parameter in descriptor.parameters if parameter.variadic), None)
|
||||
parameter_names = {parameter.name for parameter in descriptor.parameters}
|
||||
unknown = sorted(set(kwargs) - parameter_names)
|
||||
if unknown:
|
||||
raise DocumentFunctionError(
|
||||
f"Function `{descriptor.id}` received unknown named argument `{unknown[0]}`."
|
||||
)
|
||||
if variadic is None:
|
||||
for index, parameter in enumerate(positional[: len(args)]):
|
||||
if parameter.name in kwargs:
|
||||
raise DocumentFunctionError(
|
||||
f"Function `{descriptor.id}` received `{parameter.name}` both positionally and by name."
|
||||
)
|
||||
if len(args) > len(positional) and variadic is None:
|
||||
raise DocumentFunctionError(f"Function `{descriptor.id}` received too many positional arguments.")
|
||||
for index, parameter in enumerate(required):
|
||||
@@ -635,6 +657,42 @@ def _validate_arguments(
|
||||
raise DocumentFunctionError(f"Function `{descriptor.id}` requires `{parameter.name}`.")
|
||||
|
||||
|
||||
def _split_pipeline_expression(expression: str) -> list[str]:
|
||||
parts: list[str] = []
|
||||
current: list[str] = []
|
||||
quote: str | None = None
|
||||
escaped = False
|
||||
for char in expression:
|
||||
if escaped:
|
||||
current.append(char)
|
||||
escaped = False
|
||||
continue
|
||||
if char == "\\":
|
||||
current.append(char)
|
||||
escaped = True
|
||||
continue
|
||||
if char in {"'", '"'}:
|
||||
if quote == char:
|
||||
quote = None
|
||||
elif quote is None:
|
||||
quote = char
|
||||
current.append(char)
|
||||
continue
|
||||
if char == "|" and quote is None:
|
||||
part = "".join(current).strip()
|
||||
if part:
|
||||
parts.append(part)
|
||||
current = []
|
||||
continue
|
||||
current.append(char)
|
||||
if quote is not None:
|
||||
raise DocumentFunctionError("Invalid function pipeline: unterminated quote.")
|
||||
part = "".join(current).strip()
|
||||
if part:
|
||||
parts.append(part)
|
||||
return parts
|
||||
|
||||
|
||||
def _blocked_capabilities(
|
||||
descriptor: DocumentFunctionDescriptor,
|
||||
context: ProcessingContext,
|
||||
@@ -778,11 +836,6 @@ def _data_get(key: Any, default: Any = "", *, body: Any = None) -> Any:
|
||||
return body if body is not None else default if str(key).startswith("$") else key
|
||||
|
||||
|
||||
class _FunctionOutputReady(Exception):
|
||||
def __init__(self, output: Any) -> None:
|
||||
self.output = output
|
||||
|
||||
|
||||
def _drop_empty(data: dict[str, Any]) -> dict[str, Any]:
|
||||
return {
|
||||
key: value
|
||||
|
||||
Reference in New Issue
Block a user