312 lines
10 KiB
Python
312 lines
10 KiB
Python
"""Formula parser — converts formula strings to safe Python AST.
|
|
|
|
KTD7 security: uses ``ast.parse`` then a restricted ``NodeVisitor`` that
|
|
only allows whitelist nodes. **Never** uses ``eval()`` / ``exec()``.
|
|
|
|
Formula syntax:
|
|
- Starts with ``=`` (stripped before parsing).
|
|
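- Field references: ``{field_id}`` — converted to ``Name`` nodes whose identifier is ``_f_`` plus the field ID with hyphens replaced by underscores (e.g. ``{a-b}`` → ``_f_a_b``).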
|
|
- Arithmetic: ``+``, ``-``, ``*``, ``/``, ``%``, ``**``.
|
|
- Comparison: ``==``, ``!=``, ``<``, ``>``, ``<=``, ``>=``.
|
|
- Boolean: ``and``, ``or``, ``not``.
|
|
- String concat: ``+`` on strings, or ``CONCAT(...)``.
|
|
- Function calls: ``SUM(...)``, ``AVG(...)``, etc. (registered functions only).
|
|
- Conditional: ``IF(cond, a, b)`` or Python ``a if cond else b``.
|
|
- Literals: numbers, strings (single/double quotes).
|
|
|
|
Examples::
|
|
=1+2*3 → 7
|
|
=SUM({f1}) → aggregate sum of column f1
|
|
={f1} + {f2} → row-level sum of fields f1 and f2
|
|
=CONCAT({f1}, "-") → string concatenation
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import ast
|
|
import re
|
|
from typing import Any
|
|
|
|
from agentkit.bitable.formula.functions import FUNCTION_REGISTRY
|
|
|
|
# ── Exceptions ────────────────────────────────────────────
|
|
|
|
|
|
class FormulaError(Exception):
    """Common base class for the formula-specific exceptions of this module.

    Catch this to handle parse, security and unknown-function failures in one
    ``except`` clause instead of listing each subclass.
    """


class FormulaParseError(FormulaError):
    """Raised when a formula string cannot be parsed."""


class FormulaSecurityError(FormulaError):
    """Raised when a formula AST contains a disallowed node (KTD7)."""


class UnknownFunctionError(FormulaError):
    """Raised when a formula calls a function not in the registry."""
|
|
|
|
|
|
# ── Field reference substitution ──────────────────────────
|
|
|
|
# A field reference is a brace-wrapped run of ASCII letters, digits, "_" or "-"
# (field IDs are UUIDs or alphanumeric); group 1 captures the bare ID.
_FIELD_REF_RE = re.compile(
    r"\{([a-zA-Z0-9_-]+)\}",
)
|
|
|
|
|
|
# Either a quoted string literal (group 1, copied through unchanged) or a
# ``{field_id}`` reference (group 2). Matching literals first keeps braces that
# are part of literal text (e.g. ``"{f1}"``) from being rewritten.
_FIELD_REF_OR_STRING_RE = re.compile(
    r'("(?:[^"\\]|\\.)*"'  # double-quoted string literal
    r"|'(?:[^'\\]|\\.)*')"  # single-quoted string literal
    r"|\{([a-zA-Z0-9_-]+)\}",  # {field_id}
    re.DOTALL,
)


def _substitute_field_refs(formula: str) -> tuple[str, dict[str, str]]:
    """Replace ``{field_id}`` with ``_f_<safe_name>`` (a Python Name node).

    Field IDs are UUIDs that may start with a digit, which is invalid in Python
    identifiers. We prefix with ``_f_`` and replace hyphens with underscores.
    If two distinct IDs sanitize to the same name (e.g. ``a-b`` and ``a_b``),
    later ones receive a numeric suffix (``_f_a_b_1``) so they never alias.
    References inside quoted string literals are left as literal text.

    Args:
        formula: Formula text with the leading ``=`` already removed.

    Returns:
        ``(substituted_formula, mapping)`` where ``mapping`` maps each safe
        identifier back to its original field ID, so the engine can map back
        to real field IDs.
    """
    mapping: dict[str, str] = {}  # safe_name -> field_id
    assigned: dict[str, str] = {}  # field_id -> safe_name (reuse on repeats)

    def _replace(match: re.Match[str]) -> str:
        literal = match.group(1)
        if literal is not None:
            return literal
        field_id = match.group(2)
        existing = assigned.get(field_id)
        if existing is not None:
            return existing
        # Convert UUID-style field_id to a valid Python identifier.
        base = "_f_" + field_id.replace("-", "_")
        safe_name = base
        suffix = 1
        while safe_name in mapping:  # taken by a different field ID
            safe_name = f"{base}_{suffix}"
            suffix += 1
        mapping[safe_name] = field_id
        assigned[field_id] = safe_name
        return safe_name

    result = _FIELD_REF_OR_STRING_RE.sub(_replace, formula)
    return result, mapping
|
|
|
|
|
|
# ── AST whitelist (KTD7) ──────────────────────────────────
|
|
|
|
# Whitelist of AST node types a formula may contain (KTD7). The security
# visitor compares each node's exact type against this set.
_ALLOWED_NODES: frozenset[type[ast.AST]] = frozenset(
    (
        # Expression structure
        ast.Expression,
        ast.BinOp,
        ast.UnaryOp,
        ast.BoolOp,
        ast.Compare,
        ast.Call,
        ast.Name,
        ast.Constant,
        ast.IfExp,
        # Expression context: names are only ever read
        ast.Load,
        # Arithmetic operators: + - * / % **
        ast.Add,
        ast.Sub,
        ast.Mult,
        ast.Div,
        ast.Mod,
        ast.Pow,
        # Unary operators: - + not
        ast.USub,
        ast.UAdd,
        ast.Not,
        # Boolean operators: and / or
        ast.And,
        ast.Or,
        # Comparison operators: == != < > <= >=
        ast.Eq,
        ast.NotEq,
        ast.Lt,
        ast.Gt,
        ast.LtE,
        ast.GtE,
    )
)
|
|
|
|
|
|
class _SecurityVisitor(ast.NodeVisitor):
    """Reject any formula AST that strays outside the KTD7 whitelist.

    Every node's exact type must appear in ``_ALLOWED_NODES``; ``Call`` nodes
    must additionally target a bare name contained in ``allowed_functions``.
    Checking stops at the first violation, which is raised as an exception.
    """

    def __init__(self, allowed_functions: set[str]) -> None:
        # Names that visit_Call accepts as call targets.
        self._allowed_functions = allowed_functions

    def generic_visit(self, node: ast.AST) -> None:  # noqa: D401
        # NodeVisitor.visit falls back here for every node lacking a dedicated
        # visit_* method (visit_Call also calls it), so this is the type gate.
        node_type = type(node)
        if node_type in _ALLOWED_NODES:
            super().generic_visit(node)  # descend into the children
            return
        raise FormulaSecurityError(
            f"Disallowed AST node: {node_type.__name__}. Formula contains unsafe constructs."
        )

    def visit_Call(self, node: ast.Call) -> None:
        if isinstance(node.func, ast.Name):
            if node.func.id in self._allowed_functions:
                # The call itself and its arguments still go through the gate.
                self.generic_visit(node)
                return
            raise UnknownFunctionError(
                f"Unknown function: '{node.func.id}'. Allowed: {sorted(self._allowed_functions)}"
            )
        # Attribute access (obj.method), subscripts, nested calls, ...
        raise FormulaSecurityError(
            "Only direct function calls by name are allowed. "
            "Method calls and attribute access are forbidden."
        )
|
|
|
|
|
|
# ── Public API ────────────────────────────────────────────
|
|
|
|
|
|
def parse_formula(
    formula: str, allowed_functions: set[str] | None = None
) -> tuple[ast.Expression, dict[str, str]]:
    """Parse a formula string into a safe AST.

    Args:
        formula: Formula string, optionally starting with ``=``. Whitespace
            around the formula and directly after the ``=`` is ignored.
        allowed_functions: Set of function names the formula may call. If
            None, every name registered in ``FUNCTION_REGISTRY`` is allowed
            (used for syntax-only validation).

    Returns:
        Tuple of (AST expression, field_ref_mapping) where
        field_ref_mapping maps safe Python identifiers to original field IDs.

    Raises:
        FormulaParseError: Empty formula, syntax error, or input the Python
            parser or the validator cannot handle (e.g. NUL bytes or
            pathologically deep nesting).
        FormulaSecurityError: Formula contains disallowed AST nodes.
        UnknownFunctionError: Formula calls an unregistered function.
    """
    expr = formula.strip()
    if expr.startswith("="):
        # Strip again: "= 1 + 2" would otherwise reach ast.parse with a leading
        # space, which eval-mode compilation rejects as an unexpected indent.
        expr = expr[1:].strip()

    if not expr:
        raise FormulaParseError("Empty formula")

    # Substitute field references {field_id} → safe_name
    substituted, field_mapping = _substitute_field_refs(expr)

    try:
        tree = ast.parse(substituted, mode="eval")
    except SyntaxError as e:
        raise FormulaParseError(f"Syntax error in formula: {e}") from e
    except (ValueError, RecursionError) as e:
        # e.g. NUL bytes (ValueError on older Pythons) or extreme nesting.
        raise FormulaParseError(f"Invalid formula: {e}") from e

    # Security check (KTD7). When allowed_functions is None, fall back to every
    # registered function so only syntax and node safety are validated.
    if allowed_functions is None:
        allowed = set(FUNCTION_REGISTRY.keys())
    else:
        allowed = allowed_functions
    try:
        _SecurityVisitor(allowed).visit(tree)
    except RecursionError as e:
        # The visitor is recursive, so very deep trees can exhaust the stack.
        raise FormulaParseError("Formula is too deeply nested") from e

    return tree, field_mapping  # type: ignore[return-value]
|
|
|
|
|
|
def evaluate_ast(
    tree: ast.Expression,
    field_values: dict[str, Any],
    functions: dict[str, Any],
) -> Any:
    """Compute the value of a formula tree produced by :func:`parse_formula`.

    Evaluation is an explicit recursive walk over the whitelisted node types;
    no ``eval()`` / ``exec()`` is involved.

    Args:
        tree: Parsed AST from :func:`parse_formula`.
        field_values: Safe field name → value; ``Name`` nodes are looked up here.
        functions: Function name → callable; ``Call`` nodes are dispatched here.

    Returns:
        Whatever the formula expression evaluates to.
    """
    root = tree.body  # the single expression wrapped by ast.Expression
    return _eval_node(root, field_values, functions)
|
|
|
|
|
|
def _eval_node(node: ast.AST, fields: dict[str, Any], functions: dict[str, Any]) -> Any:
    """Recursively evaluate one whitelisted AST node.

    Args:
        node: Node to evaluate.
        fields: Safe field name → value; resolves ``Name`` nodes.
        functions: Function name → callable; resolves ``Call`` nodes.

    Returns:
        The value of ``node``.

    Raises:
        FormulaParseError: A ``Name`` is missing from ``fields``.
        UnknownFunctionError: A call targets a name missing from ``functions``.
        FormulaSecurityError: An unsupported node or operator, a call whose
            target is not a bare name, or a call with keyword arguments.
    """
    if isinstance(node, ast.Constant):
        return node.value

    if isinstance(node, ast.Name):
        if node.id not in fields:
            raise FormulaParseError(f"Unknown field reference: {node.id}")
        return fields[node.id]

    if isinstance(node, ast.BinOp):
        left = _eval_node(node.left, fields, functions)
        right = _eval_node(node.right, fields, functions)
        return _apply_binop(node.op, left, right)

    if isinstance(node, ast.UnaryOp):
        operand = _eval_node(node.operand, fields, functions)
        if isinstance(node.op, ast.USub):
            return -operand
        if isinstance(node.op, ast.UAdd):
            return +operand
        if isinstance(node.op, ast.Not):
            return not operand
        raise FormulaSecurityError(f"Disallowed unary op: {type(node.op).__name__}")

    if isinstance(node, ast.BoolOp):
        # Match Python semantics: evaluate operands lazily, left to right, and
        # return the operand that decided the outcome (not a coerced bool).
        # Laziness lets guards such as ``{f1} != 0 and 10 / {f1} > 1`` work.
        if isinstance(node.op, ast.And):
            result = True
            for operand_node in node.values:
                result = _eval_node(operand_node, fields, functions)
                if not result:
                    return result
            return result
        if isinstance(node.op, ast.Or):
            result = False
            for operand_node in node.values:
                result = _eval_node(operand_node, fields, functions)
                if result:
                    return result
            return result
        raise FormulaSecurityError(f"Disallowed bool op: {type(node.op).__name__}")

    if isinstance(node, ast.Compare):
        # Chained comparison (a < b < c): each comparator is evaluated only if
        # every earlier link held, mirroring Python's short-circuiting.
        left = _eval_node(node.left, fields, functions)
        for op, comparator in zip(node.ops, node.comparators):
            right = _eval_node(comparator, fields, functions)
            if not _apply_compare(op, left, right):
                return False
            left = right
        return True

    if isinstance(node, ast.IfExp):
        test = _eval_node(node.test, fields, functions)
        if test:
            return _eval_node(node.body, fields, functions)
        return _eval_node(node.orelse, fields, functions)

    if isinstance(node, ast.Call):
        if not isinstance(node.func, ast.Name):
            raise FormulaSecurityError("Only named function calls allowed")
        if node.keywords:
            # The security visitor rejects ast.keyword nodes; this guards trees
            # built by other means instead of silently dropping the arguments.
            raise FormulaSecurityError("Keyword arguments are not allowed in function calls")
        func_name = node.func.id
        if func_name not in functions:
            raise UnknownFunctionError(f"Unknown function: {func_name}")
        args = [_eval_node(a, fields, functions) for a in node.args]
        return functions[func_name](*args)

    raise FormulaSecurityError(f"Disallowed node during evaluation: {type(node).__name__}")
|
|
|
|
|
|
# Upper bound on the bit length of an ``int ** int`` result. Python ints are
# unbounded, so without a cap a formula such as ``=9**9**9`` would pin the CPU
# and exhaust memory while evaluating untrusted input.
_MAX_POW_BITS = 100_000


def _apply_binop(op: ast.AST, left: Any, right: Any) -> Any:
    """Apply a whitelisted binary operator to two already-evaluated operands.

    ``+`` concatenates the ``str()`` forms when either operand is a string and
    adds otherwise; the remaining operators defer to Python's own semantics
    (so e.g. ``ZeroDivisionError`` / ``TypeError`` propagate unchanged).

    Raises:
        OverflowError: ``**`` on two ints whose result would exceed
            ``_MAX_POW_BITS`` bits.
        FormulaSecurityError: ``op`` is not a supported operator.
    """
    if isinstance(op, ast.Add):
        # String concat or numeric addition
        if isinstance(left, str) or isinstance(right, str):
            return f"{left}{right}"
        return left + right
    if isinstance(op, ast.Sub):
        return left - right
    if isinstance(op, ast.Mult):
        return left * right
    if isinstance(op, ast.Div):
        return left / right
    if isinstance(op, ast.Mod):
        return left % right
    if isinstance(op, ast.Pow):
        if isinstance(left, int) and isinstance(right, int) and right > 0:
            # |left| >= 2**(bit_length - 1), so this is a lower bound on the
            # result's bit length; bases -1, 0 and 1 give <= 0 and always pass.
            min_bits = (abs(left).bit_length() - 1) * right
            if min_bits > _MAX_POW_BITS:
                raise OverflowError(
                    f"Exponentiation result too large (over {_MAX_POW_BITS} bits)"
                )
        return left**right
    raise FormulaSecurityError(f"Disallowed binary op: {type(op).__name__}")
|
|
|
|
|
|
def _apply_compare(op: ast.AST, left: Any, right: Any) -> bool:
    """Evaluate ``left <op> right`` for one whitelisted comparison operator.

    The first entry of the table whose operator class ``op`` is an instance of
    decides the comparison; Python's own comparison semantics apply.

    Raises:
        FormulaSecurityError: ``op`` is not one of the six supported operators.
    """
    comparisons = (
        (ast.Eq, lambda a, b: a == b),
        (ast.NotEq, lambda a, b: a != b),
        (ast.Lt, lambda a, b: a < b),
        (ast.Gt, lambda a, b: a > b),
        (ast.LtE, lambda a, b: a <= b),
        (ast.GtE, lambda a, b: a >= b),
    )
    for op_type, compare in comparisons:
        if isinstance(op, op_type):
            return compare(left, right)
    raise FormulaSecurityError(f"Disallowed compare op: {type(op).__name__}")
|