"""Tests for the formula parser (KTD7 security + parsing). Test-first per U3 execution note: parser, security constraints, and cycle detection tests are written before the engine/recalc worker. """ from __future__ import annotations import pytest from agentkit.bitable.formula.parser import ( FormulaParseError, FormulaSecurityError, UnknownFunctionError, evaluate_ast, parse_formula, ) ALLOWED = {"SUM", "AVG", "COUNT", "MIN", "MAX", "ABS", "ROUND", "IF", "LEN", "CONCAT"} # --------------------------------------------------------------------------- # Parsing happy paths # --------------------------------------------------------------------------- def test_parse_simple_arithmetic() -> None: tree, mapping = parse_formula("=1+2*3", ALLOWED) assert mapping == {} result = evaluate_ast(tree, {}, {}) assert result == 7 def test_parse_strips_equals_prefix() -> None: tree1, _ = parse_formula("=1+1", ALLOWED) tree2, _ = parse_formula("1+1", ALLOWED) assert evaluate_ast(tree1, {}, {}) == evaluate_ast(tree2, {}, {}) == 2 def test_parse_field_reference() -> None: tree, mapping = parse_formula("={field_abc} + 1", ALLOWED) assert "field_abc" in mapping.values() # Safe name is prefixed with _f_ safe_name = next(k for k, v in mapping.items() if v == "field_abc") result = evaluate_ast(tree, {safe_name: 41}, {}) assert result == 42 def test_parse_uuid_field_reference() -> None: """Field IDs are UUIDs with hyphens — must be substituted to safe names.""" fid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890" tree, mapping = parse_formula(f"={{{fid}}} * 2", ALLOWED) # The mapping should have a safe name → original UUID assert fid in mapping.values() # Evaluate using the safe name (prefixed with _f_) safe_name = next(k for k, v in mapping.items() if v == fid) assert safe_name.startswith("_f_") result = evaluate_ast(tree, {safe_name: 21}, {}) assert result == 42 def test_parse_string_concatenation() -> None: tree, _ = parse_formula('="hello" + " " + "world"', ALLOWED) assert evaluate_ast(tree, {}, {}) == "hello world" def test_parse_conditional_ifexp() -> None: tree, _ = parse_formula("=1 if True else 2", ALLOWED) assert evaluate_ast(tree, {}, {}) == 1 def test_parse_comparison() -> None: tree, mapping = parse_formula("={f} > 5", ALLOWED) safe_name = next(k for k, v in mapping.items() if v == "f") assert evaluate_ast(tree, {safe_name: 10}, {}) is True assert evaluate_ast(tree, {safe_name: 3}, {}) is False def test_parse_boolean_ops() -> None: tree, _ = parse_formula("=True and False", ALLOWED) assert evaluate_ast(tree, {}, {}) is False tree2, _ = parse_formula("=True or False", ALLOWED) assert evaluate_ast(tree2, {}, {}) is True # --------------------------------------------------------------------------- # Function calls # --------------------------------------------------------------------------- def test_parse_function_call_sum() -> None: tree, mapping = parse_formula("=SUM({f1})", ALLOWED) safe_name = next(k for k, v in mapping.items() if v == "f1") result = evaluate_ast(tree, {safe_name: [1, 2, 3]}, {"SUM": sum}) assert result == 6 def test_parse_function_call_concat() -> None: tree, mapping = parse_formula('=CONCAT({f1}, "-", {f2})', ALLOWED) safe_f1 = next(k for k, v in mapping.items() if v == "f1") safe_f2 = next(k for k, v in mapping.items() if v == "f2") result = evaluate_ast( tree, {safe_f1: "a", safe_f2: "b"}, {"CONCAT": lambda *a: "".join(str(x) for x in a)} ) assert result == "a-b" def test_parse_nested_function_calls() -> None: tree, _ = parse_formula("=ABS(-5) + ROUND(3.7, 0)", ALLOWED) funcs = {"ABS": abs, "ROUND": round} result = evaluate_ast(tree, {}, funcs) assert result == 9 # 5 + 4 # --------------------------------------------------------------------------- # KTD7 Security — disallowed nodes # --------------------------------------------------------------------------- def test_security_rejects_attribute_access() -> None: """__import__('os') is rejected — it's a Call to an unregistered function. (Attribute access like os.system would be caught by the Attribute node check, but __import__ is caught earlier as an unknown function.)""" with pytest.raises((FormulaSecurityError, UnknownFunctionError)): parse_formula("=__import__('os')", ALLOWED) def test_security_rejects_attribute_chain() -> None: """Attribute access like ''.join([]) is rejected by the Attribute node check.""" with pytest.raises(FormulaSecurityError): parse_formula("=''.join([])", ALLOWED) def test_security_rejects_lambda() -> None: with pytest.raises(FormulaSecurityError): parse_formula("=(lambda: 1)()", ALLOWED) def test_security_rejects_subscript() -> None: with pytest.raises(FormulaSecurityError): parse_formula("=[1,2,3][0]", ALLOWED) def test_security_rejects_assignment() -> None: """Assignment is a statement, not an expression — rejected at parse stage.""" with pytest.raises((FormulaSecurityError, FormulaParseError)): parse_formula("=x = 1", ALLOWED) def test_unknown_function_rejected() -> None: with pytest.raises(UnknownFunctionError): parse_formula("=UNKNOWN(1)", ALLOWED) def test_eval_function_rejected_if_not_registered() -> None: """eval is not in the registry → UnknownFunctionError.""" with pytest.raises(UnknownFunctionError): parse_formula("=eval('1+1')", ALLOWED) # --------------------------------------------------------------------------- # Error paths # --------------------------------------------------------------------------- def test_parse_error_unbalanced_parens() -> None: with pytest.raises(FormulaParseError): parse_formula("=(1+2", ALLOWED) def test_parse_error_empty_formula() -> None: with pytest.raises(FormulaParseError): parse_formula("=", ALLOWED) def test_parse_error_empty_string() -> None: with pytest.raises(FormulaParseError): parse_formula("", ALLOWED) def test_evaluate_unknown_field_ref_raises() -> None: tree, _ = parse_formula("={nonexistent} + 1", ALLOWED) with pytest.raises(FormulaParseError, match="Unknown field reference"): evaluate_ast(tree, {}, {}) # --------------------------------------------------------------------------- # Mixed aggregate + row context # --------------------------------------------------------------------------- def test_mixed_aggregate_and_row_context() -> None: """={f1} + SUM({f2}) — row f1 + column f2 sum.""" tree, mapping = parse_formula("={f1} + SUM({f2})", ALLOWED) safe_f1 = next(k for k, v in mapping.items() if v == "f1") safe_f2 = next(k for k, v in mapping.items() if v == "f2") # f1 is a row value (scalar), f2 is a column value (list) result = evaluate_ast(tree, {safe_f1: 10, safe_f2: [1, 2, 3]}, {"SUM": sum}) assert result == 16 # 10 + 6