"""Formula parser — converts formula strings to safe Python AST. KTD7 security: uses ``ast.parse`` then a restricted ``NodeVisitor`` that only allows whitelist nodes. **Never** uses ``eval()`` / ``exec()``. Formula syntax: - Starts with ``=`` (stripped before parsing). - Field references: ``{field_id}`` — converted to ``Name(field_id)`` nodes. - Arithmetic: ``+``, ``-``, ``*``, ``/``, ``%``, ``**``. - Comparison: ``==``, ``!=``, ``<``, ``>``, ``<=``, ``>=``. - Boolean: ``and``, ``or``, ``not``. - String concat: ``+`` on strings, or ``CONCAT(...)``. - Function calls: ``SUM(...)``, ``AVG(...)``, etc. (registered functions only). - Conditional: ``IF(cond, a, b)`` or Python ``a if cond else b``. - Literals: numbers, strings (single/double quotes). Examples:: =1+2*3 → 7 =SUM({f1}) → aggregate sum of column f1 ={f1} + {f2} → row-level sum of fields f1 and f2 =CONCAT({f1}, "-") → string concatenation """ from __future__ import annotations import ast import re from typing import Any from agentkit.bitable.formula.functions import FUNCTION_REGISTRY # ── Exceptions ──────────────────────────────────────────── class FormulaParseError(Exception): """Raised when a formula string cannot be parsed.""" class FormulaSecurityError(Exception): """Raised when a formula AST contains a disallowed node (KTD7).""" class UnknownFunctionError(Exception): """Raised when a formula calls a function not in the registry.""" # ── Field reference substitution ────────────────────────── # Match {field_id} — field IDs are UUIDs or alphanumeric. _FIELD_REF_RE = re.compile(r"\{([a-zA-Z0-9_-]+)\}") def _substitute_field_refs(formula: str) -> tuple[str, dict[str, str]]: """Replace ``{field_id}`` with ``_f_`` (a Python Name node). Field IDs are UUIDs that may start with a digit, which is invalid in Python identifiers. We prefix with ``_f_`` and replace hyphens with underscores. A reverse mapping is returned so the engine can map back to real field IDs. """ mapping: dict[str, str] = {} def _replace(match: re.Match[str]) -> str: field_id = match.group(1) # Convert UUID-style field_id to a valid Python identifier safe_name = "_f_" + field_id.replace("-", "_") mapping[safe_name] = field_id return safe_name result = _FIELD_REF_RE.sub(_replace, formula) return result, mapping # ── AST whitelist (KTD7) ────────────────────────────────── _ALLOWED_NODES: frozenset[type[ast.AST]] = frozenset( { ast.Expression, ast.BinOp, ast.UnaryOp, ast.BoolOp, ast.Compare, ast.Call, ast.Name, ast.Constant, ast.IfExp, ast.Load, ast.Add, ast.Sub, ast.Mult, ast.Div, ast.Mod, ast.Pow, ast.USub, ast.UAdd, ast.Not, ast.And, ast.Or, ast.Eq, ast.NotEq, ast.Lt, ast.Gt, ast.LtE, ast.GtE, } ) class _SecurityVisitor(ast.NodeVisitor): """Visit AST nodes, reject any not in the whitelist (KTD7).""" def __init__(self, allowed_functions: set[str]) -> None: self._allowed_functions = allowed_functions def generic_visit(self, node: ast.AST) -> None: # noqa: D401 node_type = type(node) if node_type not in _ALLOWED_NODES: raise FormulaSecurityError( f"Disallowed AST node: {node_type.__name__}. Formula contains unsafe constructs." ) super().generic_visit(node) def visit_Call(self, node: ast.Call) -> None: # Check that the function being called is a registered Name if not isinstance(node.func, ast.Name): raise FormulaSecurityError( "Only direct function calls by name are allowed. " "Method calls and attribute access are forbidden." ) if node.func.id not in self._allowed_functions: raise UnknownFunctionError( f"Unknown function: '{node.func.id}'. Allowed: {sorted(self._allowed_functions)}" ) self.generic_visit(node) # ── Public API ──────────────────────────────────────────── def parse_formula( formula: str, allowed_functions: set[str] | None = None ) -> tuple[ast.Expression, dict[str, str]]: """Parse a formula string into a safe AST. Args: formula: Formula string, optionally starting with ``=``. allowed_functions: Set of registered function names. If None, all functions are allowed (used for syntax-only validation). Returns: Tuple of (AST expression, field_ref_mapping) where field_ref_mapping maps safe Python identifiers to original field IDs. Raises: FormulaParseError: Syntax error in formula. FormulaSecurityError: Formula contains disallowed AST nodes. UnknownFunctionError: Formula calls an unregistered function. """ expr = formula.strip() if expr.startswith("="): expr = expr[1:] if not expr: raise FormulaParseError("Empty formula") # Substitute field references {field_id} → safe_name substituted, field_mapping = _substitute_field_refs(expr) try: tree = ast.parse(substituted, mode="eval") except SyntaxError as e: raise FormulaParseError(f"Syntax error in formula: {e}") from e # Security check # When allowed_functions is None, use all registered functions (syntax-only validation) if allowed_functions is None: allowed = set(FUNCTION_REGISTRY.keys()) else: allowed = allowed_functions visitor = _SecurityVisitor(allowed) visitor.visit(tree) return tree, field_mapping # type: ignore[return-value] def evaluate_ast( tree: ast.Expression, field_values: dict[str, Any], functions: dict[str, Any], ) -> Any: """Evaluate a parsed formula AST against field values and functions. This is NOT ``eval()`` — it's a manual AST walker that only processes whitelist nodes. Field references (Name nodes) are resolved from ``field_values``; function calls from ``functions``. Args: tree: Parsed AST from :func:`parse_formula`. field_values: Mapping of field safe-name → value (scalar or list for aggregates). functions: Mapping of function name → callable. Returns: The computed value. """ return _eval_node(tree.body, field_values, functions) def _eval_node(node: ast.AST, fields: dict[str, Any], functions: dict[str, Any]) -> Any: """Recursively evaluate an AST node.""" if isinstance(node, ast.Constant): return node.value if isinstance(node, ast.Name): if node.id not in fields: raise FormulaParseError(f"Unknown field reference: {node.id}") return fields[node.id] if isinstance(node, ast.BinOp): left = _eval_node(node.left, fields, functions) right = _eval_node(node.right, fields, functions) return _apply_binop(node.op, left, right) if isinstance(node, ast.UnaryOp): operand = _eval_node(node.operand, fields, functions) if isinstance(node.op, ast.USub): return -operand if isinstance(node.op, ast.UAdd): return +operand if isinstance(node.op, ast.Not): return not operand raise FormulaSecurityError(f"Disallowed unary op: {type(node.op).__name__}") if isinstance(node, ast.BoolOp): values = [_eval_node(v, fields, functions) for v in node.values] if isinstance(node.op, ast.And): result = True for v in values: result = result and v if not result: return result return result if isinstance(node.op, ast.Or): result = False for v in values: result = result or v if not result: return result return result raise FormulaSecurityError(f"Disallowed bool op: {type(node.op).__name__}") if isinstance(node, ast.Compare): left = _eval_node(node.left, fields, functions) for op, comparator in zip(node.ops, node.comparators): right = _eval_node(comparator, fields, functions) if not _apply_compare(op, left, right): return False left = right return True if isinstance(node, ast.IfExp): test = _eval_node(node.test, fields, functions) if test: return _eval_node(node.body, fields, functions) return _eval_node(node.orelse, fields, functions) if isinstance(node, ast.Call): if not isinstance(node.func, ast.Name): raise FormulaSecurityError("Only named function calls allowed") func_name = node.func.id if func_name not in functions: raise UnknownFunctionError(f"Unknown function: {func_name}") args = [_eval_node(a, fields, functions) for a in node.args] return functions[func_name](*args) raise FormulaSecurityError(f"Disallowed node during evaluation: {type(node).__name__}") def _apply_binop(op: ast.AST, left: Any, right: Any) -> Any: """Apply a binary operator.""" if isinstance(op, ast.Add): # String concat or numeric addition if isinstance(left, str) or isinstance(right, str): return f"{left}{right}" return left + right if isinstance(op, ast.Sub): return left - right if isinstance(op, ast.Mult): return left * right if isinstance(op, ast.Div): return left / right if isinstance(op, ast.Mod): return left % right if isinstance(op, ast.Pow): return left**right raise FormulaSecurityError(f"Disallowed binary op: {type(op).__name__}") def _apply_compare(op: ast.AST, left: Any, right: Any) -> bool: """Apply a comparison operator.""" if isinstance(op, ast.Eq): return left == right if isinstance(op, ast.NotEq): return left != right if isinstance(op, ast.Lt): return left < right if isinstance(op, ast.Gt): return left > right if isinstance(op, ast.LtE): return left <= right if isinstance(op, ast.GtE): return left >= right raise FormulaSecurityError(f"Disallowed compare op: {type(op).__name__}")