fischer-agentkit/tests/unit/bitable/test_formula_parser.py

200 lines
6.9 KiB
Python

"""Tests for the formula parser (KTD7 security + parsing).
Test-first per U3 execution note: parser, security constraints, and cycle
detection tests are written before the engine/recalc worker.
"""
from __future__ import annotations
import pytest
from agentkit.bitable.formula.parser import (
FormulaParseError,
FormulaSecurityError,
UnknownFunctionError,
evaluate_ast,
parse_formula,
)
# Function-name allowlist handed to parse_formula as its second argument by
# every test in this module.
ALLOWED = {
    "ABS",
    "AVG",
    "CONCAT",
    "COUNT",
    "IF",
    "LEN",
    "MAX",
    "MIN",
    "ROUND",
    "SUM",
}
# ---------------------------------------------------------------------------
# Parsing happy paths
# ---------------------------------------------------------------------------
def test_parse_simple_arithmetic() -> None:
    """Pure-literal arithmetic yields an empty field mapping and evaluates to 7."""
    ast_root, field_map = parse_formula("=1+2*3", ALLOWED)
    # No {field} references appear in the formula, so nothing should be mapped.
    assert field_map == {}
    # 1 + (2 * 3): multiplication binds tighter than addition.
    assert evaluate_ast(ast_root, {}, {}) == 7
def test_parse_strips_equals_prefix() -> None:
    """A formula evaluates identically with or without a leading ``=``."""
    prefixed_tree, _ = parse_formula("=1+1", ALLOWED)
    bare_tree, _ = parse_formula("1+1", ALLOWED)
    prefixed_value = evaluate_ast(prefixed_tree, {}, {})
    bare_value = evaluate_ast(bare_tree, {}, {})
    assert prefixed_value == bare_value == 2
def test_parse_field_reference() -> None:
    """A ``{field}`` reference is exposed through the mapping and resolves at eval."""
    ast_root, field_map = parse_formula("={field_abc} + 1", ALLOWED)
    assert "field_abc" in field_map.values()
    # The mapping is keyed by safe name (expected to carry the _f_ prefix) with
    # the original field id as value, so search by value to recover the key.
    alias = next(safe for safe, original in field_map.items() if original == "field_abc")
    assert evaluate_ast(ast_root, {alias: 41}, {}) == 42
def test_parse_uuid_field_reference() -> None:
    """Hyphenated UUID field IDs must be swapped for ``_f_``-prefixed safe names."""
    fid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
    formula = f"={{{fid}}} * 2"
    ast_root, field_map = parse_formula(formula, ALLOWED)

    # Mapping direction is safe name -> original UUID.
    assert fid in field_map.values()

    alias = next(safe for safe, original in field_map.items() if original == fid)
    assert alias.startswith("_f_")

    # Values are supplied to the evaluator under the safe name, not the UUID.
    assert evaluate_ast(ast_root, {alias: 21}, {}) == 42
def test_parse_string_concatenation() -> None:
    """String literals joined with ``+`` evaluate to the concatenated text."""
    ast_root, _ = parse_formula('="hello" + " " + "world"', ALLOWED)
    joined = evaluate_ast(ast_root, {}, {})
    assert joined == "hello world"
def test_parse_conditional_ifexp() -> None:
    """A Python-style conditional expression picks its true branch."""
    ast_root, _ = parse_formula("=1 if True else 2", ALLOWED)
    chosen = evaluate_ast(ast_root, {}, {})
    assert chosen == 1
def test_parse_comparison() -> None:
    """``{f} > 5`` evaluates to the bool singletons for values above and below 5."""
    ast_root, field_map = parse_formula("={f} > 5", ALLOWED)
    alias = next(safe for safe, original in field_map.items() if original == "f")

    above_threshold = evaluate_ast(ast_root, {alias: 10}, {})
    assert above_threshold is True

    below_threshold = evaluate_ast(ast_root, {alias: 3}, {})
    assert below_threshold is False
def test_parse_boolean_ops() -> None:
    """``and`` / ``or`` over boolean literals evaluate to the expected singleton."""
    cases = (
        ("=True and False", False),
        ("=True or False", True),
    )
    # Each formula is parsed and evaluated in turn, in the order listed.
    for formula, expected in cases:
        ast_root, _ = parse_formula(formula, ALLOWED)
        assert evaluate_ast(ast_root, {}, {}) is expected
# ---------------------------------------------------------------------------
# Function calls
# ---------------------------------------------------------------------------
def test_parse_function_call_sum() -> None:
    """``SUM({f1})`` calls the supplied SUM implementation on the field's value."""
    ast_root, field_map = parse_formula("=SUM({f1})", ALLOWED)
    alias = next(safe for safe, original in field_map.items() if original == "f1")

    field_values = {alias: [1, 2, 3]}
    functions = {"SUM": sum}
    assert evaluate_ast(ast_root, field_values, functions) == 6
def test_parse_function_call_concat() -> None:
    """``CONCAT`` receives two field values and a string literal, in order."""
    ast_root, field_map = parse_formula('=CONCAT({f1}, "-", {f2})', ALLOWED)

    def safe_name_for(field_id: str) -> str:
        # The mapping is safe name -> original id; find the key for this id.
        return next(safe for safe, original in field_map.items() if original == field_id)

    def concat(*parts):
        # Stringify every argument and glue them together with no separator.
        return "".join(str(part) for part in parts)

    first_alias = safe_name_for("f1")
    second_alias = safe_name_for("f2")
    field_values = {first_alias: "a", second_alias: "b"}

    assert evaluate_ast(ast_root, field_values, {"CONCAT": concat}) == "a-b"
def test_parse_nested_function_calls() -> None:
    """Two function calls combined in one expression evaluate and add up.

    NOTE(review): despite the test name, the calls here sit side by side
    (``ABS(...) + ROUND(...)``); neither is nested inside the other.
    """
    ast_root, _ = parse_formula("=ABS(-5) + ROUND(3.7, 0)", ALLOWED)
    functions = {"ABS": abs, "ROUND": round}
    # abs(-5) == 5 and round(3.7, 0) == 4.0, so the sum compares equal to 9.
    assert evaluate_ast(ast_root, {}, functions) == 9
# ---------------------------------------------------------------------------
# KTD7 Security — disallowed nodes
# ---------------------------------------------------------------------------
def test_security_rejects_attribute_access() -> None:
    """``__import__('os')`` must be rejected at parse time.

    Despite the test name, this formula contains no attribute access: it is a
    call to a name that is not in ALLOWED. Either FormulaSecurityError or
    UnknownFunctionError is accepted, since which check fires first is a
    parser implementation detail.
    """
    formula = "=__import__('os')"
    acceptable_errors = (FormulaSecurityError, UnknownFunctionError)
    with pytest.raises(acceptable_errors):
        parse_formula(formula, ALLOWED)
def test_security_rejects_attribute_chain() -> None:
    """A method call on a literal (``''.join([])``) raises FormulaSecurityError."""
    formula = "=''.join([])"
    with pytest.raises(FormulaSecurityError):
        parse_formula(formula, ALLOWED)
def test_security_rejects_lambda() -> None:
    """An immediately-invoked lambda raises FormulaSecurityError at parse time."""
    formula = "=(lambda: 1)()"
    with pytest.raises(FormulaSecurityError):
        parse_formula(formula, ALLOWED)
def test_security_rejects_subscript() -> None:
    """Indexing into a list literal raises FormulaSecurityError at parse time."""
    formula = "=[1,2,3][0]"
    with pytest.raises(FormulaSecurityError):
        parse_formula(formula, ALLOWED)
def test_security_rejects_assignment() -> None:
    """``x = 1`` is a statement, not an expression, and must not parse.

    Either a parse error or a security error is accepted.
    """
    formula = "=x = 1"
    acceptable_errors = (FormulaSecurityError, FormulaParseError)
    with pytest.raises(acceptable_errors):
        parse_formula(formula, ALLOWED)
def test_unknown_function_rejected() -> None:
    """Calling a name that is absent from ALLOWED raises UnknownFunctionError."""
    formula = "=UNKNOWN(1)"
    with pytest.raises(UnknownFunctionError):
        parse_formula(formula, ALLOWED)
def test_eval_function_rejected_if_not_registered() -> None:
    """``eval`` is not in ALLOWED, so calling it raises UnknownFunctionError."""
    formula = "=eval('1+1')"
    with pytest.raises(UnknownFunctionError):
        parse_formula(formula, ALLOWED)
# ---------------------------------------------------------------------------
# Error paths
# ---------------------------------------------------------------------------
def test_parse_error_unbalanced_parens() -> None:
    """An unclosed parenthesis raises FormulaParseError."""
    formula = "=(1+2"
    with pytest.raises(FormulaParseError):
        parse_formula(formula, ALLOWED)
def test_parse_error_empty_formula() -> None:
    """A lone ``=`` with nothing after it raises FormulaParseError."""
    formula = "="
    with pytest.raises(FormulaParseError):
        parse_formula(formula, ALLOWED)
def test_parse_error_empty_string() -> None:
    """The empty string raises FormulaParseError."""
    formula = ""
    with pytest.raises(FormulaParseError):
        parse_formula(formula, ALLOWED)
def test_evaluate_unknown_field_ref_raises() -> None:
    """Evaluating with no value for a referenced field raises FormulaParseError.

    Parsing itself succeeds; the error is expected from evaluate_ast and its
    message must match "Unknown field reference".
    """
    ast_root, _ = parse_formula("={nonexistent} + 1", ALLOWED)
    expected_message = "Unknown field reference"
    with pytest.raises(FormulaParseError, match=expected_message):
        evaluate_ast(ast_root, {}, {})
# ---------------------------------------------------------------------------
# Mixed aggregate + row context
# ---------------------------------------------------------------------------
def test_mixed_aggregate_and_row_context() -> None:
    """``{f1} + SUM({f2})`` adds a scalar field value to the sum of a list field."""
    ast_root, field_map = parse_formula("={f1} + SUM({f2})", ALLOWED)

    def safe_name_for(field_id: str) -> str:
        # The mapping is safe name -> original id; find the key for this id.
        return next(safe for safe, original in field_map.items() if original == field_id)

    row_alias = safe_name_for("f1")
    column_alias = safe_name_for("f2")

    # f1 is supplied as a single row value, f2 as a list of column values.
    field_values = {row_alias: 10, column_alias: [1, 2, 3]}
    total = evaluate_ast(ast_root, field_values, {"SUM": sum})
    assert total == 16  # 10 + (1 + 2 + 3)