# fischer-agentkit/tests/unit/test_pipeline_state.py
#
# 662 lines
# 28 KiB
# Python

"""Unit tests for Pipeline execution state persistence."""
from __future__ import annotations
import asyncio
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.orchestrator.pipeline_engine import PipelineEngine
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage, StageStatus
from agentkit.orchestrator.pipeline_state import (
PipelineStateMemory,
PipelineStatePG,
PipelineStateRedis,
PipelineStateManager,
)
# ═══════════════════════════════════════════════════════════════
# PipelineStateMemory
# ═══════════════════════════════════════════════════════════════
class TestPipelineStateMemory:
    """Tests for in-memory pipeline state storage.

    Each test receives its own ``PipelineStateMemory`` from the ``store``
    fixture, so executions created in one test are not visible to another.
    """

    @pytest.fixture
    def store(self) -> PipelineStateMemory:
        """Return a fresh, empty in-memory store."""
        return PipelineStateMemory()

    @pytest.mark.asyncio
    async def test_create_execution(self, store: PipelineStateMemory) -> None:
        """A new execution is running and positioned on its first step."""
        eid = await store.create_execution(
            pipeline_name="test_pipeline",
            steps=["step_a", "step_b"],
            input_data={"key": "value"},
        )
        assert eid is not None

        state = await store.get_execution(eid)
        assert state is not None
        assert state["pipeline_name"] == "test_pipeline"
        assert state["status"] == "running"
        assert state["current_step"] == "step_a"
        assert state["completed_steps"] == []
        assert state["input_data"] == {"key": "value"}

    @pytest.mark.asyncio
    async def test_update_step_completed(self, store: PipelineStateMemory) -> None:
        """A completed step is recorded together with its output."""
        eid = await store.create_execution("p", ["s1", "s2"])
        await store.update_step(eid, "s1", "completed", output={"result": 42})

        state = await store.get_execution(eid)
        assert "s1" in state["completed_steps"]
        assert state["step_results"]["s1"] == {"result": 42}

    @pytest.mark.asyncio
    async def test_update_step_failed(self, store: PipelineStateMemory) -> None:
        """A failed step stores its error on the execution."""
        eid = await store.create_execution("p", ["s1"])
        await store.update_step(eid, "s1", "failed", error="boom")

        state = await store.get_execution(eid)
        assert state["error_message"] == "boom"

    @pytest.mark.asyncio
    async def test_complete_execution(self, store: PipelineStateMemory) -> None:
        """Completing sets status, final output and a completion timestamp."""
        eid = await store.create_execution("p", ["s1"])
        await store.complete_execution(eid, final_output={"done": True})

        state = await store.get_execution(eid)
        assert state["status"] == "completed"
        assert state["final_output"] == {"done": True}
        assert state["completed_at"] is not None

    @pytest.mark.asyncio
    async def test_fail_execution(self, store: PipelineStateMemory) -> None:
        """Failing records the step name and the reason in the error message."""
        eid = await store.create_execution("p", ["s1"])
        await store.fail_execution(eid, "s1", "timeout")

        state = await store.get_execution(eid)
        assert state["status"] == "failed"
        assert "s1" in state["error_message"]
        assert "timeout" in state["error_message"]
        assert state["completed_at"] is not None

    @pytest.mark.asyncio
    async def test_get_execution_not_found(self, store: PipelineStateMemory) -> None:
        """Looking up an unknown id yields ``None``."""
        result = await store.get_execution("nonexistent")
        assert result is None

    @pytest.mark.asyncio
    async def test_list_executions(self, store: PipelineStateMemory) -> None:
        """Listing returns every execution and honours the status filter."""
        eid1 = await store.create_execution("p1", ["s1"])
        eid2 = await store.create_execution("p2", ["s2"])
        await store.complete_execution(eid1)

        # List all
        all_execs = await store.list_executions()
        assert len(all_execs) == 2
        assert {entry["id"] for entry in all_execs} == {eid1, eid2}

        # Filter by status
        completed = await store.list_executions(status="completed")
        assert len(completed) == 1
        assert completed[0]["id"] == eid1

    @pytest.mark.asyncio
    async def test_list_executions_pagination(self, store: PipelineStateMemory) -> None:
        """``limit``/``offset`` return distinct, correctly sized pages."""
        for i in range(5):
            await store.create_execution(f"p{i}", ["s1"])

        page1 = await store.list_executions(limit=2, offset=0)
        page2 = await store.list_executions(limit=2, offset=2)
        assert len(page1) == 2
        assert len(page2) == 2

        # Length checks alone would still pass if ``offset`` were ignored,
        # so also require that the two pages do not share any execution.
        page1_ids = {entry["id"] for entry in page1}
        page2_ids = {entry["id"] for entry in page2}
        assert page1_ids.isdisjoint(page2_ids)

    @pytest.mark.asyncio
    async def test_get_step_history(self, store: PipelineStateMemory) -> None:
        """Step history lists updates in the order they were applied."""
        eid = await store.create_execution("p", ["s1", "s2"])
        await store.update_step(eid, "s1", "completed", output={"r": 1})
        await store.update_step(eid, "s2", "failed", error="err")

        history = await store.get_step_history(eid)
        assert len(history) == 2
        assert history[0]["step_name"] == "s1"
        assert history[0]["status"] == "completed"
        assert history[1]["step_name"] == "s2"
        assert history[1]["status"] == "failed"

    @pytest.mark.asyncio
    async def test_update_step_nonexistent_execution(self, store: PipelineStateMemory) -> None:
        """Updating an unknown execution is tolerated and creates nothing."""
        # Should not raise, just log warning
        await store.update_step("nonexistent", "s1", "completed")
        # The update must not have conjured an execution into existence.
        assert await store.get_execution("nonexistent") is None

    @pytest.mark.asyncio
    async def test_create_execution_with_tenant(self, store: PipelineStateMemory) -> None:
        """The tenant id passed at creation is stored on the execution."""
        eid = await store.create_execution("p", ["s1"], tenant_id="tenant_123")
        state = await store.get_execution(eid)
        assert state["tenant_id"] == "tenant_123"
# ═══════════════════════════════════════════════════════════════
# PipelineStateRedis
# ═══════════════════════════════════════════════════════════════
class TestPipelineStateRedis:
    """Tests for Redis-backed pipeline state storage (using mocks)."""

    @pytest.fixture
    def mock_redis(self):
        """Create a mock Redis client."""
        # Redis pipeline: set/zadd are synchronous (return self for chaining),
        # execute is async
        redis_pipeline = MagicMock()
        redis_pipeline.set = MagicMock(return_value=redis_pipeline)
        redis_pipeline.zadd = MagicMock(return_value=redis_pipeline)
        redis_pipeline.execute = AsyncMock(return_value=[True, 1])

        client = AsyncMock()
        client.get = AsyncMock(return_value=None)
        client.set = AsyncMock(return_value=True)
        client.zadd = AsyncMock(return_value=1)
        client.zrevrange = AsyncMock(return_value=[])
        client.mget = AsyncMock(return_value=[])
        client.pipeline = MagicMock(return_value=redis_pipeline)
        return client

    @pytest.fixture
    def store(self, mock_redis) -> PipelineStateRedis:
        """Create a PipelineStateRedis with mocked Redis."""
        redis_store = PipelineStateRedis(redis_url="redis://localhost:6379/0")
        # Pre-inject the mock Redis client
        redis_store._redis = mock_redis
        return redis_store

    @pytest.mark.asyncio
    async def test_create_execution_writes_to_redis(self, store: PipelineStateRedis, mock_redis):
        """Creation stores a JSON document under the exec key with a 7-day TTL."""
        execution_id = await store.create_execution("test_pipeline", ["s1", "s2"])
        assert execution_id is not None

        # Redis pipeline should have been used (pipe.set + pipe.zadd + pipe.execute)
        redis_pipeline = mock_redis.pipeline.return_value
        redis_pipeline.set.assert_called_once()
        positional, keyword = redis_pipeline.set.call_args
        assert positional[0].startswith("agentkit:pipeline:exec:")

        # Verify the stored data is valid JSON
        payload = json.loads(positional[1])
        assert payload["pipeline_name"] == "test_pipeline"
        assert payload["status"] == "running"

        # Verify TTL was set (7 days)
        assert keyword.get("ex") == 7 * 24 * 3600

    @pytest.mark.asyncio
    async def test_create_execution_adds_to_sorted_set(self, store: PipelineStateRedis, mock_redis):
        """Creation issues exactly one ZADD through the Redis pipeline."""
        await store.create_execution("p", ["s1"])
        # ZADD should have been called via pipeline
        redis_pipeline = mock_redis.pipeline.return_value
        redis_pipeline.zadd.assert_called_once()

    @pytest.mark.asyncio
    async def test_update_step_writes_to_redis(self, store: PipelineStateRedis, mock_redis):
        """A step update results in a single SET on the client."""
        execution_id = await store.create_execution("p", ["s1"])
        mock_redis.set.reset_mock()
        await store.update_step(execution_id, "s1", "completed", output={"r": 1})
        mock_redis.set.assert_called_once()

    @pytest.mark.asyncio
    async def test_complete_execution_writes_to_redis(self, store: PipelineStateRedis, mock_redis):
        """Completing an execution results in a single SET on the client."""
        execution_id = await store.create_execution("p", ["s1"])
        mock_redis.set.reset_mock()
        await store.complete_execution(execution_id, final_output={"done": True})
        mock_redis.set.assert_called_once()

    @pytest.mark.asyncio
    async def test_fail_execution_writes_to_redis(self, store: PipelineStateRedis, mock_redis):
        """Failing an execution results in a single SET on the client."""
        execution_id = await store.create_execution("p", ["s1"])
        mock_redis.set.reset_mock()
        await store.fail_execution(execution_id, "s1", "error")
        mock_redis.set.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_execution_from_redis(self, store: PipelineStateRedis, mock_redis):
        """A JSON document returned by GET is decoded into the execution state."""
        execution_id = await store.create_execution("p", ["s1"])
        # Simulate Redis returning data
        snapshot = await store._fallback.get_execution(execution_id)
        mock_redis.get.return_value = json.dumps(snapshot)

        fetched = await store.get_execution(execution_id)
        assert fetched is not None
        assert fetched["pipeline_name"] == "p"

    @pytest.mark.asyncio
    async def test_get_execution_redis_miss_falls_back_to_memory(self, store: PipelineStateRedis, mock_redis):
        """When GET returns nothing the execution is still found."""
        execution_id = await store.create_execution("p", ["s1"])
        # Redis returns None (miss)
        mock_redis.get.return_value = None

        # Should still find it in memory fallback
        fetched = await store.get_execution(execution_id)
        assert fetched is not None
        assert fetched["pipeline_name"] == "p"

    @pytest.mark.asyncio
    async def test_list_executions_from_sorted_set(self, store: PipelineStateRedis, mock_redis):
        """Listing decodes the documents returned for the sorted-set ids."""
        execution_id = await store.create_execution("p", ["s1"])
        snapshot = await store._fallback.get_execution(execution_id)
        mock_redis.zrevrange.return_value = [execution_id]
        mock_redis.mget.return_value = [json.dumps(snapshot)]

        listed = await store.list_executions()
        assert len(listed) == 1
        assert listed[0]["pipeline_name"] == "p"

    @pytest.mark.asyncio
    async def test_fallback_on_redis_failure(self, mock_redis):
        """A failing Redis client flips the store into fallback mode."""
        failing_store = PipelineStateRedis(redis_url="redis://localhost:6379/0")
        # Make Redis initialization fail
        mock_redis.ping = AsyncMock(side_effect=Exception("connection refused"))
        failing_store._redis = mock_redis
        # Force a Redis operation to fail
        mock_redis.set = AsyncMock(side_effect=Exception("connection refused"))
        mock_redis.pipeline = MagicMock(side_effect=Exception("connection refused"))

        # Should fall back to memory
        execution_id = await failing_store.create_execution("p", ["s1"])
        assert execution_id is not None
        assert failing_store.using_fallback is True

    @pytest.mark.asyncio
    async def test_health_check(self, store: PipelineStateRedis, mock_redis):
        """Health is True while PING succeeds and False once it raises."""
        mock_redis.ping = AsyncMock(return_value=True)
        assert await store.health_check() is True

        mock_redis.ping = AsyncMock(side_effect=Exception("fail"))
        assert await store.health_check() is False
# ═══════════════════════════════════════════════════════════════
# PipelineStatePG
# ═══════════════════════════════════════════════════════════════
class TestPipelineStatePG:
    """Tests for PostgreSQL cold persistence (using mocks).

    The private ``_make_*`` helpers build the mock session, session factory
    and model instance that the individual tests would otherwise repeat.
    """

    @staticmethod
    def _make_session_factory(session: AsyncMock) -> MagicMock:
        """Return a factory whose ``async with factory() as s`` binds *session*."""
        factory = MagicMock()
        factory.return_value.__aenter__ = AsyncMock(return_value=session)
        # A falsy __aexit__ result means exceptions are not suppressed.
        factory.return_value.__aexit__ = AsyncMock(return_value=False)
        return factory

    @staticmethod
    def _make_write_session(merge_side_effect: Exception | None = None) -> AsyncMock:
        """Return a mock session with ``merge``/``commit``; ``merge`` may raise."""
        session = AsyncMock()
        session.merge = AsyncMock(side_effect=merge_side_effect)
        session.commit = AsyncMock()
        return session

    @staticmethod
    def _make_read_session(result: MagicMock) -> AsyncMock:
        """Return a mock session whose ``execute`` resolves to *result*."""
        session = AsyncMock()
        session.execute = AsyncMock(return_value=result)
        return session

    @staticmethod
    def _make_execution_model():
        """Build the completed ``test_pipeline`` model row used by read tests."""
        # Imported lazily, as in the original tests, so that collecting this
        # module does not require the models module to be importable.
        from agentkit.orchestrator.pipeline_models import PipelineExecutionModel

        return PipelineExecutionModel(
            id="test-id",
            pipeline_name="test_pipeline",
            status="completed",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )

    @pytest.mark.asyncio
    async def test_no_op_when_session_factory_is_none(self) -> None:
        """Without a session factory the store is disabled and every call is a no-op."""
        pg = PipelineStatePG(session_factory=None)
        assert pg.enabled is False

        # All methods should be no-op
        await pg.persist_execution({"id": "1", "pipeline_name": "p", "status": "completed"})
        await pg.persist_step_history("1", [])
        result = await pg.query_executions()
        assert result == []
        result = await pg.get_execution("1")
        assert result is None

    @pytest.mark.asyncio
    async def test_persist_execution(self) -> None:
        """Persisting an execution merges and commits exactly once."""
        mock_session = self._make_write_session()
        pg = PipelineStatePG(session_factory=self._make_session_factory(mock_session))
        assert pg.enabled is True

        timestamp = datetime.now(timezone.utc).isoformat()
        state = {
            "id": "test-id-123",
            "pipeline_name": "test_pipeline",
            "status": "completed",
            "current_step": None,
            "completed_steps": ["s1"],
            "step_results": {"s1": {"r": 1}},
            "input_data": {"key": "val"},
            "final_output": {"done": True},
            "error_message": None,
            "tenant_id": None,
            "created_at": timestamp,
            "updated_at": timestamp,
            "completed_at": timestamp,
        }
        await pg.persist_execution(state)

        mock_session.merge.assert_called_once()
        mock_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_persist_step_history(self) -> None:
        """Persisting a one-entry step history merges and commits exactly once."""
        mock_session = self._make_write_session()
        pg = PipelineStatePG(session_factory=self._make_session_factory(mock_session))

        timestamp = datetime.now(timezone.utc).isoformat()
        steps = [
            {
                "id": "step-id-1",
                "step_name": "s1",
                "status": "completed",
                "output_data": {"r": 1},
                "error_message": None,
                "duration_ms": 100,
                "started_at": timestamp,
                "completed_at": timestamp,
            }
        ]
        await pg.persist_step_history("exec-1", steps)

        mock_session.merge.assert_called_once()
        mock_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_persist_execution_handles_error(self) -> None:
        """An exception raised by ``merge`` must not propagate to the caller."""
        mock_session = self._make_write_session(merge_side_effect=Exception("DB error"))
        pg = PipelineStatePG(session_factory=self._make_session_factory(mock_session))

        timestamp = datetime.now(timezone.utc).isoformat()
        # Should not raise
        await pg.persist_execution({
            "id": "1",
            "pipeline_name": "p",
            "status": "completed",
            "created_at": timestamp,
            "updated_at": timestamp,
        })

    @pytest.mark.asyncio
    async def test_query_executions(self) -> None:
        """Rows returned by the query are converted to dict-like results."""
        # Create a mock model instance
        model = self._make_execution_model()
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [model]
        mock_session = self._make_read_session(mock_result)
        pg = PipelineStatePG(session_factory=self._make_session_factory(mock_session))

        results = await pg.query_executions(pipeline_name="test_pipeline")
        assert len(results) == 1
        assert results[0]["pipeline_name"] == "test_pipeline"

    @pytest.mark.asyncio
    async def test_get_execution_found(self) -> None:
        """A matching row is returned with its id."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = self._make_execution_model()
        mock_session = self._make_read_session(mock_result)
        pg = PipelineStatePG(session_factory=self._make_session_factory(mock_session))

        result = await pg.get_execution("test-id")
        assert result is not None
        assert result["id"] == "test-id"

    @pytest.mark.asyncio
    async def test_get_execution_not_found(self) -> None:
        """No matching row yields ``None``."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session = self._make_read_session(mock_result)
        pg = PipelineStatePG(session_factory=self._make_session_factory(mock_session))

        result = await pg.get_execution("nonexistent")
        assert result is None
# ═══════════════════════════════════════════════════════════════
# PipelineStateManager
# ═══════════════════════════════════════════════════════════════
class TestPipelineStateManager:
    """Tests for the unified state manager.

    The private ``_make_*`` helpers build the mock session and session
    factory shared by the cold-store tests.
    """

    @staticmethod
    def _make_session_factory(session: AsyncMock) -> MagicMock:
        """Return a factory whose ``async with factory() as s`` binds *session*."""
        factory = MagicMock()
        factory.return_value.__aenter__ = AsyncMock(return_value=session)
        # A falsy __aexit__ result means exceptions are not suppressed.
        factory.return_value.__aexit__ = AsyncMock(return_value=False)
        return factory

    @staticmethod
    def _make_write_session() -> AsyncMock:
        """Return a mock session exposing awaitable ``merge`` and ``commit``."""
        session = AsyncMock()
        session.merge = AsyncMock()
        session.commit = AsyncMock()
        return session

    @pytest.fixture
    def manager(self) -> PipelineStateManager:
        """Create a manager with memory-only backend."""
        return PipelineStateManager(redis_url=None, session_factory=None)

    @pytest.mark.asyncio
    async def test_create_and_get_execution(self, manager: PipelineStateManager) -> None:
        """A created execution can be read back in the running state."""
        eid = await manager.create_execution("p", ["s1"], input_data={"k": "v"})
        state = await manager.get_execution(eid)
        assert state is not None
        assert state["pipeline_name"] == "p"
        assert state["status"] == "running"

    @pytest.mark.asyncio
    async def test_update_step(self, manager: PipelineStateManager) -> None:
        """A completed step shows up in ``completed_steps``."""
        eid = await manager.create_execution("p", ["s1"])
        await manager.update_step(eid, "s1", "completed", output={"r": 1})
        state = await manager.get_execution(eid)
        assert "s1" in state["completed_steps"]

    @pytest.mark.asyncio
    async def test_complete_persists_to_cold(self) -> None:
        """Test that completing an execution triggers PG persist."""
        mock_session = self._make_write_session()
        manager = PipelineStateManager(
            redis_url=None,
            session_factory=self._make_session_factory(mock_session),
        )

        eid = await manager.create_execution("p", ["s1"])
        await manager.update_step(eid, "s1", "completed", output={"r": 1})
        await manager.complete_execution(eid, final_output={"done": True})

        # PG persist should have been called
        mock_session.merge.assert_called()

    @pytest.mark.asyncio
    async def test_fail_persists_to_cold(self) -> None:
        """Test that failing an execution triggers PG persist."""
        mock_session = self._make_write_session()
        manager = PipelineStateManager(
            redis_url=None,
            session_factory=self._make_session_factory(mock_session),
        )

        eid = await manager.create_execution("p", ["s1"])
        await manager.fail_execution(eid, "s1", "error")

        mock_session.merge.assert_called()

    @pytest.mark.asyncio
    async def test_get_execution_pg_fallback(self) -> None:
        """Test Redis miss falls back to PG."""
        from agentkit.orchestrator.pipeline_models import PipelineExecutionModel

        model = PipelineExecutionModel(
            id="pg-exec-id",
            pipeline_name="pg_pipeline",
            status="completed",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = model
        mock_session = AsyncMock()
        mock_session.execute = AsyncMock(return_value=mock_result)
        manager = PipelineStateManager(
            redis_url=None,
            session_factory=self._make_session_factory(mock_session),
        )

        # This execution_id doesn't exist in hot store, should fall back to PG
        result = await manager.get_execution("pg-exec-id")
        assert result is not None
        assert result["pipeline_name"] == "pg_pipeline"

    @pytest.mark.asyncio
    async def test_list_executions_hot_first(self, manager: PipelineStateManager) -> None:
        """Listing returns the execution that was just created."""
        eid = await manager.create_execution("p", ["s1"])
        results = await manager.list_executions()
        assert len(results) == 1
        assert results[0]["id"] == eid

    @pytest.mark.asyncio
    async def test_health_check_memory_only(self, manager: PipelineStateManager) -> None:
        """Without a session factory the cold store reports unhealthy."""
        health = await manager.health_check()
        assert health["hot"] is True
        assert health["cold"] is False

    @pytest.mark.asyncio
    async def test_health_check_with_pg(self) -> None:
        """With a session factory configured both stores report healthy."""
        mock_factory = MagicMock()
        manager = PipelineStateManager(redis_url=None, session_factory=mock_factory)
        health = await manager.health_check()
        assert health["hot"] is True
        assert health["cold"] is True
# ═══════════════════════════════════════════════════════════════
# PipelineEngine with state persistence
# ═══════════════════════════════════════════════════════════════
class TestPipelineEngineWithState:
    """Tests for PipelineEngine integration with state persistence."""

    @pytest.fixture
    def pipeline(self) -> Pipeline:
        """Return a two-stage pipeline in which ``step_b`` depends on ``step_a``."""
        return Pipeline(
            name="test_pipeline",
            version="1.0",
            description="Test pipeline",
            stages=[
                PipelineStage(name="step_a", agent="agent1", action="do_a"),
                PipelineStage(name="step_b", agent="agent2", action="do_b", depends_on=["step_a"]),
            ],
        )

    @pytest.mark.asyncio
    async def test_engine_without_state_backward_compatible(self, pipeline: Pipeline) -> None:
        """Engine without state_manager should work as before."""
        engine = PipelineEngine(dispatcher=None)
        result = await engine.execute(pipeline)
        assert result.status == StageStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_engine_with_state_creates_execution(self, pipeline: Pipeline) -> None:
        """Engine with state_manager should create execution state."""
        state_manager = PipelineStateManager(redis_url=None, session_factory=None)
        engine = PipelineEngine(dispatcher=None, state_manager=state_manager)

        result = await engine.execute(pipeline)
        assert result.status == StageStatus.COMPLETED

        # Check that execution was created in state store
        executions = await state_manager.list_executions()
        assert len(executions) == 1
        assert executions[0]["status"] == "completed"
        assert executions[0]["pipeline_name"] == "test_pipeline"

    @pytest.mark.asyncio
    async def test_engine_with_state_updates_steps(self, pipeline: Pipeline) -> None:
        """Engine should update step state after each stage."""
        state_manager = PipelineStateManager(redis_url=None, session_factory=None)
        engine = PipelineEngine(dispatcher=None, state_manager=state_manager)

        await engine.execute(pipeline)

        executions = await state_manager.list_executions()
        exec_state = executions[0]
        # Both steps should be completed
        assert "step_a" in exec_state["completed_steps"]
        assert "step_b" in exec_state["completed_steps"]

    @pytest.mark.asyncio
    async def test_engine_with_state_on_failure(self) -> None:
        """Engine should persist failure state when a stage fails."""
        pipeline = Pipeline(
            name="fail_pipeline",
            version="1.0",
            description="Pipeline that fails",
            stages=[
                PipelineStage(name="bad_step", agent="agent1", action="fail"),
            ],
        )
        # Create a dispatcher that raises
        mock_dispatcher = AsyncMock()
        mock_dispatcher.dispatch = AsyncMock(side_effect=Exception("boom"))
        state_manager = PipelineStateManager(redis_url=None, session_factory=None)
        engine = PipelineEngine(dispatcher=mock_dispatcher, state_manager=state_manager)

        result = await engine.execute(pipeline)
        assert result.status == StageStatus.FAILED

        # Check state was persisted
        executions = await state_manager.list_executions()
        assert len(executions) == 1
        assert executions[0]["status"] == "failed"

    @pytest.mark.asyncio
    async def test_engine_state_survives_check(self, pipeline: Pipeline) -> None:
        """Verify state can be retrieved after execution."""
        state_manager = PipelineStateManager(redis_url=None, session_factory=None)
        engine = PipelineEngine(dispatcher=None, state_manager=state_manager)

        result = await engine.execute(pipeline, context={"brand": "acme"})
        # The run itself must succeed before the persisted state is inspected.
        assert result.status == StageStatus.COMPLETED

        # Get execution by ID
        executions = await state_manager.list_executions()
        eid = executions[0]["id"]
        state = await state_manager.get_execution(eid)
        assert state is not None
        assert state["pipeline_name"] == "test_pipeline"
        assert state["status"] == "completed"

    @pytest.mark.asyncio
    async def test_engine_with_circular_dependency(self) -> None:
        """Engine should handle circular dependency gracefully."""
        pipeline = Pipeline(
            name="circular",
            version="1.0",
            description="Circular pipeline",
            stages=[
                PipelineStage(name="a", agent="agent1", action="do", depends_on=["b"]),
                PipelineStage(name="b", agent="agent2", action="do", depends_on=["a"]),
            ],
        )
        state_manager = PipelineStateManager(redis_url=None, session_factory=None)
        engine = PipelineEngine(dispatcher=None, state_manager=state_manager)

        result = await engine.execute(pipeline)
        assert result.status == StageStatus.FAILED
        assert "Circular" in result.error_message

        # No execution state should be created (topological sort fails before creation)
        executions = await state_manager.list_executions()
        assert len(executions) == 0