"""Tests for the formula engine — DAG, cycle detection, evaluation.

Covers: topological sort, circular reference detection (including rollback
of the rejected formula), aggregate vs row context, formula-to-formula
dependencies, the built-in function library, and edge cases such as missing
field values, division by zero, and UUID-style field IDs.
"""

from __future__ import annotations

import pytest

from agentkit.bitable.formula.engine import (
    CircularReferenceError,
    FormulaEngine,
)


# ---------------------------------------------------------------------------
# Basic evaluation
# ---------------------------------------------------------------------------


def test_engine_evaluate_simple_arithmetic() -> None:
    """Multiplication binds tighter than addition: ``=1+2*3`` yields 7."""
    formula_engine = FormulaEngine()
    formula_engine.add_formula("calc", "=1+2*3")
    assert formula_engine.evaluate("calc", row_values={}) == 7


def test_engine_evaluate_row_reference() -> None:
    """``{f1}`` and ``{f2}`` resolve to the current row's values and are added."""
    row = {"f1": 10, "f2": 20}
    formula_engine = FormulaEngine()
    formula_engine.add_formula("sum", "={f1} + {f2}")
    assert formula_engine.evaluate("sum", row_values=row) == 30


def test_engine_evaluate_aggregate_sum() -> None:
    """``SUM({f1})`` totals every value in the f1 column."""
    columns = {"f1": [1, 2, 3]}
    formula_engine = FormulaEngine()
    formula_engine.add_formula("total", "=SUM({f1})")
    total = formula_engine.evaluate("total", row_values={}, column_values=columns)
    assert total == 6


def test_engine_evaluate_aggregate_avg() -> None:
    """``AVG({f1})`` returns the arithmetic mean of the f1 column."""
    columns = {"f1": [10, 20, 30]}
    formula_engine = FormulaEngine()
    formula_engine.add_formula("avg", "=AVG({f1})")
    mean = formula_engine.evaluate("avg", row_values={}, column_values=columns)
    assert mean == 20.0


def test_engine_evaluate_aggregate_count() -> None:
    """``COUNT({f1})`` counts only non-empty cells; ``None`` and ``""`` are skipped."""
    columns = {"f1": [1, None, 3, "", 5]}
    formula_engine = FormulaEngine()
    formula_engine.add_formula("cnt", "=COUNT({f1})")
    count = formula_engine.evaluate("cnt", row_values={}, column_values=columns)
    # Five cells, two of them empty (None and "") -> three counted.
    assert count == 3


def test_engine_evaluate_mixed_aggregate_and_row() -> None:
    """A row reference and a column aggregate can share one formula.

    ``{f1}`` reads the current row while ``SUM({f2})`` reads the whole column.
    """
    row = {"f1": 10}
    columns = {"f2": [1, 2, 3]}
    formula_engine = FormulaEngine()
    formula_engine.add_formula("mixed", "={f1} + SUM({f2})")
    combined = formula_engine.evaluate("mixed", row_values=row, column_values=columns)
    # Row value 10 plus column total 1 + 2 + 3 = 6.
    assert combined == 16


def test_engine_evaluate_concat() -> None:
    """``CONCAT`` joins row values and string literals in argument order."""
    row = {"f1": "a", "f2": "b"}
    formula_engine = FormulaEngine()
    formula_engine.add_formula("label", '=CONCAT({f1}, "-", {f2})')
    assert formula_engine.evaluate("label", row_values=row) == "a-b"


def test_engine_evaluate_if_function() -> None:
    """``IF`` returns its second argument when the condition holds, else its third."""
    formula_engine = FormulaEngine()
    formula_engine.add_formula("size", '=IF({f1} > 5, "big", "small")')
    for value, expected in ((10, "big"), (3, "small")):
        assert formula_engine.evaluate("size", row_values={"f1": value}) == expected


def test_engine_evaluate_min_max() -> None:
    """``MIN``/``MAX`` aggregate over the whole f1 column, duplicates included."""
    formula_engine = FormulaEngine()
    formula_engine.add_formula("mn", "=MIN({f1})")
    formula_engine.add_formula("mx", "=MAX({f1})")
    column_data = {"f1": [3, 1, 4, 1, 5, 9, 2, 6]}
    # Row values and column values are passed positionally here.
    for formula_id, expected in (("mn", 1), ("mx", 9)):
        assert formula_engine.evaluate(formula_id, {}, column_data) == expected


# ---------------------------------------------------------------------------
# DAG: dependencies and dependents
# ---------------------------------------------------------------------------


def test_engine_get_dependencies() -> None:
    """Referenced fields become dependencies; each field lists the formula as a dependent."""
    formula_engine = FormulaEngine()
    formula_engine.add_formula("c", "={a} + {b}")
    assert formula_engine.get_dependencies("c") == {"a", "b"}
    for field_id in ("a", "b"):
        assert formula_engine.get_dependents(field_id) == {"c"}


def test_engine_topological_order() -> None:
    """Dependencies sort before dependents: c needs b, b needs a, so a < b < c."""
    formula_engine = FormulaEngine()
    # Register in reverse dependency order so the sort cannot rely on insertion order.
    for formula_id, expression in (("c", "={b} + 1"), ("b", "={a} + 1"), ("a", "=1")):
        formula_engine.add_formula(formula_id, expression)
    order = formula_engine.topological_order()
    assert order.index("a") < order.index("b") < order.index("c")


def test_engine_evaluate_all_for_record() -> None:
    """Chained formulas resolve in one pass: a = 5, b = a + 1, c = b + 1 gives 5, 6, 7."""
    formula_engine = FormulaEngine()
    formula_engine.add_formula("a", "=5")
    formula_engine.add_formula("b", "={a} + 1")
    formula_engine.add_formula("c", "={b} + 1")
    results = formula_engine.evaluate_all_for_record(row_values={})
    expected = {"a": 5, "b": 6, "c": 7}
    for formula_id, value in expected.items():
        assert results[formula_id] == value


# ---------------------------------------------------------------------------
# Circular reference detection
# ---------------------------------------------------------------------------


def test_circular_reference_detected() -> None:
    """Closing a two-formula loop (f1 -> f2 -> f1) raises ``CircularReferenceError``."""
    formula_engine = FormulaEngine()
    formula_engine.add_formula("f1", "={f2} + 1")
    closing_expression = "={f1} + 1"
    with pytest.raises(CircularReferenceError):
        formula_engine.add_formula("f2", closing_expression)


def test_circular_reference_rollback() -> None:
    """A formula rejected for creating a cycle is not left behind in the engine."""
    formula_engine = FormulaEngine()
    formula_engine.add_formula("f1", "={f2} + 1")
    with pytest.raises(CircularReferenceError):
        formula_engine.add_formula("f2", "={f1} + 1")
    # White-box check on private state: the rejected id must be absent from
    # both the formula registry and the dependency graph.
    assert "f2" not in formula_engine._formulas
    assert "f2" not in formula_engine._dag


def test_self_reference_detected() -> None:
    """A formula that references itself (f1 = f1 + 1) is rejected as circular."""
    formula_engine = FormulaEngine()
    self_referencing = "={f1} + 1"
    with pytest.raises(CircularReferenceError):
        formula_engine.add_formula("f1", self_referencing)


def test_remove_formula_breaks_cycle() -> None:
    """Once f1 is removed, f2 can be registered and evaluated on its own."""
    formula_engine = FormulaEngine()
    formula_engine.add_formula("f1", "={f2} + 1")

    # While f1 exists, defining f2 in terms of f1 would close a loop.
    with pytest.raises(CircularReferenceError):
        formula_engine.add_formula("f2", "={f1} + 1")

    formula_engine.remove_formula("f1")
    formula_engine.add_formula("f2", "=42")
    assert formula_engine.evaluate("f2", {}) == 42


# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------


def test_evaluate_missing_field_value_is_none() -> None:
    """A field absent from the row is treated as ``None``, so ``+ 1`` raises ``TypeError``."""
    formula_engine = FormulaEngine()
    formula_engine.add_formula("calc", "={missing_field} + 1")
    empty_row: dict = {}
    # The referenced field has no value in the row at all.
    with pytest.raises(TypeError):
        formula_engine.evaluate("calc", row_values=empty_row)


def test_aggregate_ignores_none_and_empty() -> None:
    """``SUM`` skips ``None`` and empty-string cells instead of failing on them."""
    columns = {"f1": [1, None, 2, "", 3]}
    formula_engine = FormulaEngine()
    formula_engine.add_formula("total", "=SUM({f1})")
    total = formula_engine.evaluate("total", row_values={}, column_values=columns)
    assert total == 6


def test_division_by_zero_returns_error_in_evaluate_all() -> None:
    """``evaluate_all_for_record`` reports division by zero as an ``__error`` entry instead of raising."""
    formula_engine = FormulaEngine()
    formula_engine.add_formula("calc", "={f1} / 0")
    outcome = formula_engine.evaluate_all_for_record(row_values={"f1": 10})["calc"]
    assert "__error" in outcome


def test_engine_with_uuid_field_ids() -> None:
    """Hyphenated field IDs (UUIDs) are accepted inside ``{...}`` references."""
    field_id = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
    expression = "={" + field_id + "} * 2"
    formula_engine = FormulaEngine()
    formula_engine.add_formula("calc", expression)
    assert formula_engine.evaluate("calc", row_values={field_id: 21}) == 42