200 lines
6.9 KiB
Python
200 lines
6.9 KiB
Python
"""Tests for the formula parser (KTD7 security + parsing).
|
|
|
|
Test-first per U3 execution note: parser, security constraints, and cycle
|
|
detection tests are written before the engine/recalc worker.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
|
|
from agentkit.bitable.formula.parser import (
|
|
FormulaParseError,
|
|
FormulaSecurityError,
|
|
UnknownFunctionError,
|
|
evaluate_ast,
|
|
parse_formula,
|
|
)
|
|
|
|
# Function names passed to parse_formula as the permitted set in every test below.
ALLOWED = {
    "SUM",
    "AVG",
    "COUNT",
    "MIN",
    "MAX",
    "ABS",
    "ROUND",
    "IF",
    "LEN",
    "CONCAT",
}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Parsing happy paths
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_parse_simple_arithmetic() -> None:
    """Literal-only arithmetic yields an empty field mapping and evaluates to 7."""
    parsed = parse_formula("=1+2*3", ALLOWED)
    ast_tree, field_map = parsed
    # No {field} references in the formula, so nothing should be mapped.
    assert field_map == {}
    value = evaluate_ast(ast_tree, {}, {})
    assert value == 7
|
|
|
|
|
|
def test_parse_strips_equals_prefix() -> None:
    """The same expression evaluates to 2 with or without a leading ``=``."""
    with_prefix, _ = parse_formula("=1+1", ALLOWED)
    without_prefix, _ = parse_formula("1+1", ALLOWED)
    prefixed_value = evaluate_ast(with_prefix, {}, {})
    bare_value = evaluate_ast(without_prefix, {}, {})
    assert prefixed_value == bare_value
    assert bare_value == 2
|
|
|
|
|
|
def test_parse_field_reference() -> None:
    """A ``{field_abc}`` reference is recorded in the mapping and read via its safe name."""
    ast_tree, field_map = parse_formula("={field_abc} + 1", ALLOWED)
    assert "field_abc" in field_map.values()
    # The mapping is keyed by the parser-generated safe name; find the key
    # whose value is the original field id.
    alias = next(safe for safe, original in field_map.items() if original == "field_abc")
    value = evaluate_ast(ast_tree, {alias: 41}, {})
    assert value == 42
|
|
|
|
|
|
def test_parse_uuid_field_reference() -> None:
    """Hyphenated UUID field IDs are mapped to ``_f_``-prefixed safe names."""
    field_id = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
    formula = "={" + field_id + "} * 2"
    ast_tree, field_map = parse_formula(formula, ALLOWED)
    # The original UUID must appear as a value in the safe-name mapping.
    assert field_id in field_map.values()
    # Recover the safe name and check its prefix before evaluating with it.
    alias = next(safe for safe, original in field_map.items() if original == field_id)
    assert alias.startswith("_f_")
    value = evaluate_ast(ast_tree, {alias: 21}, {})
    assert value == 42
|
|
|
|
|
|
def test_parse_string_concatenation() -> None:
    """String literals joined with ``+`` evaluate to the concatenated text."""
    formula = '="hello" + " " + "world"'
    ast_tree, _ = parse_formula(formula, ALLOWED)
    value = evaluate_ast(ast_tree, {}, {})
    assert value == "hello world"
|
|
|
|
|
|
def test_parse_conditional_ifexp() -> None:
    """A conditional expression with a true condition evaluates to its first branch."""
    formula = "=1 if True else 2"
    ast_tree, _ = parse_formula(formula, ALLOWED)
    value = evaluate_ast(ast_tree, {}, {})
    assert value == 1
|
|
|
|
|
|
def test_parse_comparison() -> None:
    """``{f} > 5`` evaluates to the bool singletons for values above and below 5."""
    ast_tree, field_map = parse_formula("={f} > 5", ALLOWED)
    alias = next(safe for safe, original in field_map.items() if original == "f")
    above = evaluate_ast(ast_tree, {alias: 10}, {})
    assert above is True
    below = evaluate_ast(ast_tree, {alias: 3}, {})
    assert below is False
|
|
|
|
|
|
def test_parse_boolean_ops() -> None:
    """``and`` / ``or`` over boolean literals evaluate to the expected bool singleton."""
    cases = (
        ("=True and False", False),
        ("=True or False", True),
    )
    for formula, expected in cases:
        ast_tree, _ = parse_formula(formula, ALLOWED)
        assert evaluate_ast(ast_tree, {}, {}) is expected
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Function calls
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_parse_function_call_sum() -> None:
    """``SUM({f1})`` evaluates to 6 when ``SUM`` is bound to the builtin ``sum``."""
    ast_tree, field_map = parse_formula("=SUM({f1})", ALLOWED)
    alias = next(safe for safe, original in field_map.items() if original == "f1")
    field_values = {alias: [1, 2, 3]}
    functions = {"SUM": sum}
    assert evaluate_ast(ast_tree, field_values, functions) == 6
|
|
|
|
|
|
def test_parse_function_call_concat() -> None:
    """``CONCAT`` over two fields and a literal separator evaluates to ``"a-b"``."""

    def concat(*parts):
        # Stringify every argument and join them with no separator.
        return "".join(str(part) for part in parts)

    ast_tree, field_map = parse_formula('=CONCAT({f1}, "-", {f2})', ALLOWED)
    first_alias = next(safe for safe, original in field_map.items() if original == "f1")
    second_alias = next(safe for safe, original in field_map.items() if original == "f2")
    field_values = {first_alias: "a", second_alias: "b"}
    value = evaluate_ast(ast_tree, field_values, {"CONCAT": concat})
    assert value == "a-b"
|
|
|
|
|
|
def test_parse_nested_function_calls() -> None:
    """Two function calls combined with ``+`` evaluate using the supplied callables."""
    ast_tree, _ = parse_formula("=ABS(-5) + ROUND(3.7, 0)", ALLOWED)
    value = evaluate_ast(ast_tree, {}, {"ABS": abs, "ROUND": round})
    # abs(-5) is 5 and round(3.7, 0) is 4.0, so the sum compares equal to 9.
    assert value == 9
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# KTD7 Security — disallowed nodes
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_security_rejects_attribute_access() -> None:
    """``__import__('os')`` must be refused when the formula is parsed.

    Either error type is accepted: the formula is a call to a name outside
    ``ALLOWED``, so it may surface as an unknown function rather than as a
    security violation.
    """
    accepted_errors = (FormulaSecurityError, UnknownFunctionError)
    with pytest.raises(accepted_errors):
        parse_formula("=__import__('os')", ALLOWED)
|
|
|
|
|
|
def test_security_rejects_attribute_chain() -> None:
    """A method call on a literal (``''.join([])``) must raise FormulaSecurityError."""
    formula = "=''.join([])"
    with pytest.raises(FormulaSecurityError):
        parse_formula(formula, ALLOWED)
|
|
|
|
|
|
def test_security_rejects_lambda() -> None:
    """An immediately-invoked lambda in a formula must raise FormulaSecurityError."""
    formula = "=(lambda: 1)()"
    with pytest.raises(FormulaSecurityError):
        parse_formula(formula, ALLOWED)
|
|
|
|
|
|
def test_security_rejects_subscript() -> None:
    """Indexing into a list literal must raise FormulaSecurityError."""
    formula = "=[1,2,3][0]"
    with pytest.raises(FormulaSecurityError):
        parse_formula(formula, ALLOWED)
|
|
|
|
|
|
def test_security_rejects_assignment() -> None:
    """An assignment statement is not a valid formula.

    Either a security error or a parse error is accepted as the rejection.
    """
    accepted_errors = (FormulaSecurityError, FormulaParseError)
    with pytest.raises(accepted_errors):
        parse_formula("=x = 1", ALLOWED)
|
|
|
|
|
|
def test_unknown_function_rejected() -> None:
    """Calling a function name outside ``ALLOWED`` raises UnknownFunctionError."""
    formula = "=UNKNOWN(1)"
    with pytest.raises(UnknownFunctionError):
        parse_formula(formula, ALLOWED)
|
|
|
|
|
|
def test_eval_function_rejected_if_not_registered() -> None:
    """``eval`` is absent from ``ALLOWED``, so calling it raises UnknownFunctionError."""
    formula = "=eval('1+1')"
    with pytest.raises(UnknownFunctionError):
        parse_formula(formula, ALLOWED)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Error paths
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_parse_error_unbalanced_parens() -> None:
    """An unclosed parenthesis raises FormulaParseError."""
    formula = "=(1+2"
    with pytest.raises(FormulaParseError):
        parse_formula(formula, ALLOWED)
|
|
|
|
|
|
def test_parse_error_empty_formula() -> None:
    """A formula consisting only of ``=`` raises FormulaParseError."""
    formula = "="
    with pytest.raises(FormulaParseError):
        parse_formula(formula, ALLOWED)
|
|
|
|
|
|
def test_parse_error_empty_string() -> None:
    """An empty string raises FormulaParseError."""
    formula = ""
    with pytest.raises(FormulaParseError):
        parse_formula(formula, ALLOWED)
|
|
|
|
|
|
def test_evaluate_unknown_field_ref_raises() -> None:
    """Evaluating with no value for a referenced field raises FormulaParseError."""
    formula = "={nonexistent} + 1"
    ast_tree, _ = parse_formula(formula, ALLOWED)
    # Parsing succeeds; the failure is expected only at evaluation time,
    # when the empty value dict cannot satisfy the reference.
    with pytest.raises(FormulaParseError, match="Unknown field reference"):
        evaluate_ast(ast_tree, {}, {})
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Mixed aggregate + row context
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_mixed_aggregate_and_row_context() -> None:
    """``{f1} + SUM({f2})`` adds a scalar field value to the sum of a list-valued field."""
    ast_tree, field_map = parse_formula("={f1} + SUM({f2})", ALLOWED)
    row_alias = next(safe for safe, original in field_map.items() if original == "f1")
    column_alias = next(safe for safe, original in field_map.items() if original == "f2")
    # f1 supplies a single row value; f2 supplies a list of column values.
    field_values = {row_alias: 10, column_alias: [1, 2, 3]}
    value = evaluate_ast(ast_tree, field_values, {"SUM": sum})
    assert value == 16  # 10 + (1 + 2 + 3)
|