fischer-agentkit/src/agentkit/bitable/formula/parser.py

312 lines
10 KiB
Python

"""Formula parser — converts formula strings to safe Python AST.
KTD7 security: uses ``ast.parse`` then a restricted ``NodeVisitor`` that
only allows whitelist nodes. **Never** uses ``eval()`` / ``exec()``.
Formula syntax:
- Starts with ``=`` (stripped before parsing).
- Field references: ``{field_id}`` — rewritten to ``_f_<field_id>`` identifiers (hyphens become underscores), which parse as ``Name`` nodes.
- Arithmetic: ``+``, ``-``, ``*``, ``/``, ``%``, ``**``.
- Comparison: ``==``, ``!=``, ``<``, ``>``, ``<=``, ``>=``.
- Boolean: ``and``, ``or``, ``not``.
- String concat: ``+`` on strings, or ``CONCAT(...)``.
- Function calls: ``SUM(...)``, ``AVG(...)``, etc. (registered functions only).
- Conditional: ``IF(cond, a, b)`` or Python ``a if cond else b``.
- Literals: numbers, strings (single/double quotes).
Examples::
=1+2*3 → 7
=SUM({f1}) → aggregate sum of column f1
={f1} + {f2} → row-level sum of fields f1 and f2
=CONCAT({f1}, "-") → string concatenation
"""
from __future__ import annotations
import ast
import re
from typing import Any
from agentkit.bitable.formula.functions import FUNCTION_REGISTRY
# ── Exceptions ────────────────────────────────────────────
class FormulaParseError(Exception):
    """Signal that a formula could not be turned into a usable expression.

    Within this module it is raised for an empty formula, for a syntax
    error reported by ``ast.parse``, and during evaluation when a name in
    the expression has no entry in the supplied field values.
    """
class FormulaSecurityError(Exception):
    """Signal that a formula uses a construct outside the KTD7 whitelist.

    Raised both while validating a parsed tree and, defensively, while
    evaluating one if an unexpected node or operator type is encountered.
    """
class UnknownFunctionError(Exception):
    """Signal that a formula calls a function name that is not available.

    Raised during validation when the name is not among the allowed
    functions, and during evaluation when it is missing from the supplied
    function table.
    """
# ── Field reference substitution ──────────────────────────
# A field reference is a brace-wrapped ID made of letters, digits, "_" or "-"
# (this covers UUIDs as well as plain alphanumeric IDs).
_FIELD_REF_RE = re.compile(r"\{([a-zA-Z0-9_-]+)\}")


def _substitute_field_refs(formula: str) -> tuple[str, dict[str, str]]:
    """Rewrite every ``{field_id}`` token into a valid Python identifier.

    A field ID may begin with a digit or contain hyphens, neither of which
    is legal in a Python name, so each reference becomes ``_f_`` followed by
    the ID with ``-`` replaced by ``_``.

    The rewrite is purely textual: a ``{...}`` token inside a quoted string
    literal is rewritten as well. Two IDs that differ only in ``-`` versus
    ``_`` produce the same identifier, and the mapping keeps whichever one
    appears last in the formula.

    Args:
        formula: Formula text (without any leading ``=`` handling applied here).

    Returns:
        ``(rewritten_formula, mapping)`` where ``mapping`` goes from each
        generated identifier back to the original field ID.
    """
    safe_to_original: dict[str, str] = {}
    pieces: list[str] = []
    cursor = 0
    for ref in _FIELD_REF_RE.finditer(formula):
        original_id = ref.group(1)
        # Prefix + hyphen replacement yields a syntactically valid identifier.
        identifier = "_f_" + original_id.replace("-", "_")
        safe_to_original[identifier] = original_id
        # Copy the untouched text before this reference, then the replacement.
        pieces.append(formula[cursor:ref.start()])
        pieces.append(identifier)
        cursor = ref.end()
    # Whatever follows the last reference is carried over unchanged.
    pieces.append(formula[cursor:])
    return "".join(pieces), safe_to_original
# ── AST whitelist (KTD7) ──────────────────────────────────
# Exact node types a formula tree may contain; anything else is rejected.
_ALLOWED_NODES: frozenset[type[ast.AST]] = frozenset(
    # Expression structure and leaves
    (
        ast.Expression,
        ast.BinOp,
        ast.UnaryOp,
        ast.BoolOp,
        ast.Compare,
        ast.Call,
        ast.Name,
        ast.Constant,
        ast.IfExp,
        ast.Load,
    )
    # Arithmetic operators
    + (ast.Add, ast.Sub, ast.Mult, ast.Div, ast.Mod, ast.Pow)
    # Unary operators
    + (ast.USub, ast.UAdd, ast.Not)
    # Boolean operators
    + (ast.And, ast.Or)
    # Comparison operators
    + (ast.Eq, ast.NotEq, ast.Lt, ast.Gt, ast.LtE, ast.GtE)
)
class _SecurityVisitor(ast.NodeVisitor):
    """Walk a parsed formula and refuse anything outside the whitelist (KTD7).

    Every node whose exact type is not in ``_ALLOWED_NODES`` aborts the walk
    with ``FormulaSecurityError``. Calls are additionally required to target
    a plain name that is present in the allowed-function set.
    """

    def __init__(self, allowed_functions: set[str]) -> None:
        # Names that may appear as the callee of a Call node.
        self._allowed_functions = allowed_functions

    def generic_visit(self, node: ast.AST) -> None:
        """Check ``node`` against the whitelist, then descend into its children."""
        if type(node) in _ALLOWED_NODES:
            super().generic_visit(node)
            return
        raise FormulaSecurityError(
            f"Disallowed AST node: {type(node).__name__}. Formula contains unsafe constructs."
        )

    def visit_Call(self, node: ast.Call) -> None:
        """Validate the callee of a call before checking the rest of the node."""
        callee = node.func
        # Anything other than a bare name (attribute access, subscripts,
        # nested calls, lambdas, ...) is refused outright.
        if not isinstance(callee, ast.Name):
            raise FormulaSecurityError(
                "Only direct function calls by name are allowed. "
                "Method calls and attribute access are forbidden."
            )
        if callee.id in self._allowed_functions:
            # Callee accepted: whitelist-check the Call node and its children.
            self.generic_visit(node)
            return
        raise UnknownFunctionError(
            f"Unknown function: '{callee.id}'. Allowed: {sorted(self._allowed_functions)}"
        )
# ── Public API ────────────────────────────────────────────
def parse_formula(
    formula: str, allowed_functions: set[str] | None = None
) -> tuple[ast.Expression, dict[str, str]]:
    """Parse a formula string into a validated (whitelisted) AST.

    Args:
        formula: Formula string, optionally starting with ``=``. Whitespace
            around the formula and directly after the ``=`` is ignored.
        allowed_functions: Function names the formula may call. If None,
            every name registered in ``FUNCTION_REGISTRY`` is accepted.

    Returns:
        Tuple of (AST expression, field_ref_mapping) where
        field_ref_mapping maps safe Python identifiers to original field IDs.

    Raises:
        FormulaParseError: The formula is empty or cannot be parsed.
        FormulaSecurityError: Formula contains disallowed AST nodes.
        UnknownFunctionError: Formula calls a function that is not allowed.
    """
    expr = formula.strip()
    if expr.startswith("="):
        # Strip again after removing "=": ``ast.parse`` treats leading
        # whitespace as indentation, so "= 1 + 2" would otherwise fail with
        # "unexpected indent".
        expr = expr[1:].strip()
    if not expr:
        raise FormulaParseError("Empty formula")

    # Substitute field references {field_id} → safe_name
    substituted, field_mapping = _substitute_field_refs(expr)

    try:
        tree = ast.parse(substituted, mode="eval")
    except (SyntaxError, ValueError) as e:
        # ValueError covers sources the compiler refuses outright (e.g. null
        # bytes on older Python versions); report it like any other bad input.
        raise FormulaParseError(f"Syntax error in formula: {e}") from e

    # Security check. With no explicit set, fall back to all registered functions.
    if allowed_functions is None:
        allowed = set(FUNCTION_REGISTRY.keys())
    else:
        allowed = allowed_functions
    _SecurityVisitor(allowed).visit(tree)

    return tree, field_mapping  # type: ignore[return-value]
def evaluate_ast(
    tree: ast.Expression,
    field_values: dict[str, Any],
    functions: dict[str, Any],
) -> Any:
    """Compute the value of a parsed formula.

    No ``eval()`` is involved: the expression tree is walked node by node by
    ``_eval_node``. Names in the tree are looked up in ``field_values`` and
    called functions are looked up in ``functions``.

    Args:
        tree: Expression tree, as returned by :func:`parse_formula`.
        field_values: Mapping of field safe-name → value (scalar or list for aggregates).
        functions: Mapping of function name → callable.

    Returns:
        The computed value.
    """
    root_expression = tree.body
    return _eval_node(root_expression, field_values, functions)
def _eval_node(node: ast.AST, fields: dict[str, Any], functions: dict[str, Any]) -> Any:
    """Recursively evaluate one node of a formula expression tree.

    Args:
        node: The AST node to evaluate.
        fields: Mapping of name (as it appears in the tree) → value.
        functions: Mapping of function name → callable.

    Returns:
        The value of the sub-expression rooted at ``node``.

    Raises:
        FormulaParseError: A name in the tree has no entry in ``fields``.
        FormulaSecurityError: The tree contains a node, operator or call form
            that is not supported.
        UnknownFunctionError: A called function is missing from ``functions``.
    """
    if isinstance(node, ast.Constant):
        return node.value

    if isinstance(node, ast.Name):
        if node.id not in fields:
            raise FormulaParseError(f"Unknown field reference: {node.id}")
        return fields[node.id]

    if isinstance(node, ast.BinOp):
        left = _eval_node(node.left, fields, functions)
        right = _eval_node(node.right, fields, functions)
        return _apply_binop(node.op, left, right)

    if isinstance(node, ast.UnaryOp):
        operand = _eval_node(node.operand, fields, functions)
        if isinstance(node.op, ast.USub):
            return -operand
        if isinstance(node.op, ast.UAdd):
            return +operand
        if isinstance(node.op, ast.Not):
            return not operand
        raise FormulaSecurityError(f"Disallowed unary op: {type(node.op).__name__}")

    if isinstance(node, ast.BoolOp):
        # Python semantics: ``and`` yields the first falsy operand, ``or`` the
        # first truthy one; if none qualifies the last operand is the result.
        if isinstance(node.op, ast.And):
            stop_on_truthy = False
        elif isinstance(node.op, ast.Or):
            stop_on_truthy = True
        else:
            raise FormulaSecurityError(f"Disallowed bool op: {type(node.op).__name__}")
        result: Any = None
        for operand_node in node.values:
            # Operands are evaluated lazily so that a guard such as
            # ``x != 0 and 10 / x > 1`` never evaluates the guarded part.
            result = _eval_node(operand_node, fields, functions)
            if bool(result) is stop_on_truthy:
                break
        return result

    if isinstance(node, ast.Compare):
        # Chained comparison (a < b < c): every adjacent pair must hold.
        left = _eval_node(node.left, fields, functions)
        for op, comparator in zip(node.ops, node.comparators):
            right = _eval_node(comparator, fields, functions)
            if not _apply_compare(op, left, right):
                return False
            left = right
        return True

    if isinstance(node, ast.IfExp):
        # Only the selected branch is evaluated.
        if _eval_node(node.test, fields, functions):
            return _eval_node(node.body, fields, functions)
        return _eval_node(node.orelse, fields, functions)

    if isinstance(node, ast.Call):
        if not isinstance(node.func, ast.Name):
            raise FormulaSecurityError("Only named function calls allowed")
        if node.keywords:
            # Keyword arguments are not part of the formula syntax; refuse
            # them instead of silently dropping them.
            raise FormulaSecurityError("Keyword arguments are not allowed in function calls")
        func_name = node.func.id
        if func_name not in functions:
            raise UnknownFunctionError(f"Unknown function: {func_name}")
        args = [_eval_node(a, fields, functions) for a in node.args]
        return functions[func_name](*args)

    raise FormulaSecurityError(f"Disallowed node during evaluation: {type(node).__name__}")
def _apply_binop(op: ast.AST, left: Any, right: Any) -> Any:
    """Compute ``left <op> right`` for one arithmetic operator node.

    ``+`` doubles as string concatenation: when either operand is a ``str``
    both operands are formatted into a single string rather than added.

    Args:
        op: Operator node taken from a ``BinOp``.
        left: Already-evaluated left operand.
        right: Already-evaluated right operand.

    Returns:
        The result of applying the operator.

    Raises:
        FormulaSecurityError: ``op`` is not one of the supported operators.
    """
    if isinstance(op, ast.Add):
        either_is_text = isinstance(left, str) or isinstance(right, str)
        return f"{left}{right}" if either_is_text else left + right

    # Remaining operators map one-to-one onto the Python operator.
    arithmetic = (
        (ast.Sub, lambda a, b: a - b),
        (ast.Mult, lambda a, b: a * b),
        (ast.Div, lambda a, b: a / b),
        (ast.Mod, lambda a, b: a % b),
        (ast.Pow, lambda a, b: a**b),
    )
    for op_type, compute in arithmetic:
        if isinstance(op, op_type):
            return compute(left, right)

    raise FormulaSecurityError(f"Disallowed binary op: {type(op).__name__}")
def _apply_compare(op: ast.AST, left: Any, right: Any) -> bool:
    """Compute ``left <op> right`` for one comparison operator node.

    Args:
        op: Operator node taken from a ``Compare``.
        left: Already-evaluated left operand.
        right: Already-evaluated right operand.

    Returns:
        Whatever the underlying Python comparison returns.

    Raises:
        FormulaSecurityError: ``op`` is not one of the supported comparisons.
    """
    comparisons = (
        (ast.Eq, lambda a, b: a == b),
        (ast.NotEq, lambda a, b: a != b),
        (ast.Lt, lambda a, b: a < b),
        (ast.Gt, lambda a, b: a > b),
        (ast.LtE, lambda a, b: a <= b),
        (ast.GtE, lambda a, b: a >= b),
    )
    for op_type, compare in comparisons:
        if isinstance(op, op_type):
            return compare(left, right)

    raise FormulaSecurityError(f"Disallowed compare op: {type(op).__name__}")