"""Unit tests for Pipeline execution state persistence.""" from __future__ import annotations import asyncio import json from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest from agentkit.orchestrator.pipeline_engine import PipelineEngine from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage, StageStatus from agentkit.orchestrator.pipeline_state import ( PipelineStateMemory, PipelineStatePG, PipelineStateRedis, PipelineStateManager, ) # ═══════════════════════════════════════════════════════════════ # PipelineStateMemory # ═══════════════════════════════════════════════════════════════ class TestPipelineStateMemory: """Tests for in-memory pipeline state storage.""" @pytest.fixture def store(self) -> PipelineStateMemory: return PipelineStateMemory() @pytest.mark.asyncio async def test_create_execution(self, store: PipelineStateMemory): eid = await store.create_execution( pipeline_name="test_pipeline", steps=["step_a", "step_b"], input_data={"key": "value"}, ) assert eid is not None state = await store.get_execution(eid) assert state is not None assert state["pipeline_name"] == "test_pipeline" assert state["status"] == "running" assert state["current_step"] == "step_a" assert state["completed_steps"] == [] assert state["input_data"] == {"key": "value"} @pytest.mark.asyncio async def test_update_step_completed(self, store: PipelineStateMemory): eid = await store.create_execution("p", ["s1", "s2"]) await store.update_step(eid, "s1", "completed", output={"result": 42}) state = await store.get_execution(eid) assert "s1" in state["completed_steps"] assert state["step_results"]["s1"] == {"result": 42} @pytest.mark.asyncio async def test_update_step_failed(self, store: PipelineStateMemory): eid = await store.create_execution("p", ["s1"]) await store.update_step(eid, "s1", "failed", error="boom") state = await store.get_execution(eid) assert state["error_message"] == "boom" @pytest.mark.asyncio async def test_complete_execution(self, store: PipelineStateMemory): eid = await store.create_execution("p", ["s1"]) await store.complete_execution(eid, final_output={"done": True}) state = await store.get_execution(eid) assert state["status"] == "completed" assert state["final_output"] == {"done": True} assert state["completed_at"] is not None @pytest.mark.asyncio async def test_fail_execution(self, store: PipelineStateMemory): eid = await store.create_execution("p", ["s1"]) await store.fail_execution(eid, "s1", "timeout") state = await store.get_execution(eid) assert state["status"] == "failed" assert "s1" in state["error_message"] assert "timeout" in state["error_message"] assert state["completed_at"] is not None @pytest.mark.asyncio async def test_get_execution_not_found(self, store: PipelineStateMemory): result = await store.get_execution("nonexistent") assert result is None @pytest.mark.asyncio async def test_list_executions(self, store: PipelineStateMemory): eid1 = await store.create_execution("p1", ["s1"]) eid2 = await store.create_execution("p2", ["s2"]) await store.complete_execution(eid1) # List all all_execs = await store.list_executions() assert len(all_execs) == 2 # Filter by status completed = await store.list_executions(status="completed") assert len(completed) == 1 assert completed[0]["id"] == eid1 @pytest.mark.asyncio async def test_list_executions_pagination(self, store: PipelineStateMemory): for i in range(5): await store.create_execution(f"p{i}", ["s1"]) page1 = await store.list_executions(limit=2, offset=0) page2 = await store.list_executions(limit=2, offset=2) assert len(page1) == 2 assert len(page2) == 2 @pytest.mark.asyncio async def test_get_step_history(self, store: PipelineStateMemory): eid = await store.create_execution("p", ["s1", "s2"]) await store.update_step(eid, "s1", "completed", output={"r": 1}) await store.update_step(eid, "s2", "failed", error="err") history = await store.get_step_history(eid) assert len(history) == 2 assert history[0]["step_name"] == "s1" assert history[0]["status"] == "completed" assert history[1]["step_name"] == "s2" assert history[1]["status"] == "failed" @pytest.mark.asyncio async def test_update_step_nonexistent_execution(self, store: PipelineStateMemory): # Should not raise, just log warning await store.update_step("nonexistent", "s1", "completed") @pytest.mark.asyncio async def test_create_execution_with_tenant(self, store: PipelineStateMemory): eid = await store.create_execution("p", ["s1"], tenant_id="tenant_123") state = await store.get_execution(eid) assert state["tenant_id"] == "tenant_123" # ═══════════════════════════════════════════════════════════════ # PipelineStateRedis # ═══════════════════════════════════════════════════════════════ class TestPipelineStateRedis: """Tests for Redis-backed pipeline state storage (using mocks).""" @pytest.fixture def mock_redis(self): """Create a mock Redis client.""" redis = AsyncMock() redis.get = AsyncMock(return_value=None) redis.set = AsyncMock(return_value=True) redis.zadd = AsyncMock(return_value=1) redis.zrevrange = AsyncMock(return_value=[]) redis.mget = AsyncMock(return_value=[]) # Redis pipeline: set/zadd are synchronous (return self for chaining), execute is async pipe = MagicMock() pipe.set = MagicMock(return_value=pipe) pipe.zadd = MagicMock(return_value=pipe) pipe.execute = AsyncMock(return_value=[True, 1]) redis.pipeline = MagicMock(return_value=pipe) return redis @pytest.fixture def store(self, mock_redis) -> PipelineStateRedis: """Create a PipelineStateRedis with mocked Redis.""" store = PipelineStateRedis(redis_url="redis://localhost:6379/0") # Pre-inject the mock Redis client store._redis = mock_redis return store @pytest.mark.asyncio async def test_create_execution_writes_to_redis(self, store: PipelineStateRedis, mock_redis): eid = await store.create_execution("test_pipeline", ["s1", "s2"]) assert eid is not None # Redis pipeline should have been used (pipe.set + pipe.zadd + pipe.execute) pipe = mock_redis.pipeline.return_value pipe.set.assert_called_once() call_args = pipe.set.call_args assert call_args[0][0].startswith("agentkit:pipeline:exec:") # Verify the stored data is valid JSON stored_data = json.loads(call_args[0][1]) assert stored_data["pipeline_name"] == "test_pipeline" assert stored_data["status"] == "running" # Verify TTL was set (7 days) assert call_args[1].get("ex") == 7 * 24 * 3600 @pytest.mark.asyncio async def test_create_execution_adds_to_sorted_set(self, store: PipelineStateRedis, mock_redis): await store.create_execution("p", ["s1"]) # ZADD should have been called via pipeline pipe = mock_redis.pipeline.return_value pipe.zadd.assert_called_once() @pytest.mark.asyncio async def test_update_step_writes_to_redis(self, store: PipelineStateRedis, mock_redis): eid = await store.create_execution("p", ["s1"]) mock_redis.set.reset_mock() await store.update_step(eid, "s1", "completed", output={"r": 1}) mock_redis.set.assert_called_once() @pytest.mark.asyncio async def test_complete_execution_writes_to_redis(self, store: PipelineStateRedis, mock_redis): eid = await store.create_execution("p", ["s1"]) mock_redis.set.reset_mock() await store.complete_execution(eid, final_output={"done": True}) mock_redis.set.assert_called_once() @pytest.mark.asyncio async def test_fail_execution_writes_to_redis(self, store: PipelineStateRedis, mock_redis): eid = await store.create_execution("p", ["s1"]) mock_redis.set.reset_mock() await store.fail_execution(eid, "s1", "error") mock_redis.set.assert_called_once() @pytest.mark.asyncio async def test_get_execution_from_redis(self, store: PipelineStateRedis, mock_redis): eid = await store.create_execution("p", ["s1"]) # Simulate Redis returning data state = await store._fallback.get_execution(eid) mock_redis.get.return_value = json.dumps(state) result = await store.get_execution(eid) assert result is not None assert result["pipeline_name"] == "p" @pytest.mark.asyncio async def test_get_execution_redis_miss_falls_back_to_memory(self, store: PipelineStateRedis, mock_redis): eid = await store.create_execution("p", ["s1"]) # Redis returns None (miss) mock_redis.get.return_value = None # Should still find it in memory fallback result = await store.get_execution(eid) assert result is not None assert result["pipeline_name"] == "p" @pytest.mark.asyncio async def test_list_executions_from_sorted_set(self, store: PipelineStateRedis, mock_redis): eid = await store.create_execution("p", ["s1"]) state = await store._fallback.get_execution(eid) mock_redis.zrevrange.return_value = [eid] mock_redis.mget.return_value = [json.dumps(state)] results = await store.list_executions() assert len(results) == 1 assert results[0]["pipeline_name"] == "p" @pytest.mark.asyncio async def test_fallback_on_redis_failure(self, mock_redis): store = PipelineStateRedis(redis_url="redis://localhost:6379/0") # Make Redis initialization fail mock_redis.ping = AsyncMock(side_effect=Exception("connection refused")) store._redis = mock_redis # Force a Redis operation to fail mock_redis.set = AsyncMock(side_effect=Exception("connection refused")) mock_redis.pipeline = MagicMock(side_effect=Exception("connection refused")) # Should fall back to memory eid = await store.create_execution("p", ["s1"]) assert eid is not None assert store.using_fallback is True @pytest.mark.asyncio async def test_health_check(self, store: PipelineStateRedis, mock_redis): mock_redis.ping = AsyncMock(return_value=True) assert await store.health_check() is True mock_redis.ping = AsyncMock(side_effect=Exception("fail")) assert await store.health_check() is False # ═══════════════════════════════════════════════════════════════ # PipelineStatePG # ═══════════════════════════════════════════════════════════════ class TestPipelineStatePG: """Tests for PostgreSQL cold persistence (using mocks).""" @pytest.mark.asyncio async def test_no_op_when_session_factory_is_none(self): pg = PipelineStatePG(session_factory=None) assert pg.enabled is False # All methods should be no-op await pg.persist_execution({"id": "1", "pipeline_name": "p", "status": "completed"}) await pg.persist_step_history("1", []) result = await pg.query_executions() assert result == [] result = await pg.get_execution("1") assert result is None @pytest.mark.asyncio async def test_persist_execution(self): mock_session = AsyncMock() mock_session.merge = AsyncMock() mock_session.commit = AsyncMock() mock_factory = MagicMock() mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) mock_factory.return_value.__aexit__ = AsyncMock(return_value=False) pg = PipelineStatePG(session_factory=mock_factory) assert pg.enabled is True state = { "id": "test-id-123", "pipeline_name": "test_pipeline", "status": "completed", "current_step": None, "completed_steps": ["s1"], "step_results": {"s1": {"r": 1}}, "input_data": {"key": "val"}, "final_output": {"done": True}, "error_message": None, "tenant_id": None, "created_at": datetime.now(timezone.utc).isoformat(), "updated_at": datetime.now(timezone.utc).isoformat(), "completed_at": datetime.now(timezone.utc).isoformat(), } await pg.persist_execution(state) mock_session.merge.assert_called_once() mock_session.commit.assert_called_once() @pytest.mark.asyncio async def test_persist_step_history(self): mock_session = AsyncMock() mock_session.merge = AsyncMock() mock_session.commit = AsyncMock() mock_factory = MagicMock() mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) mock_factory.return_value.__aexit__ = AsyncMock(return_value=False) pg = PipelineStatePG(session_factory=mock_factory) steps = [ { "id": "step-id-1", "step_name": "s1", "status": "completed", "output_data": {"r": 1}, "error_message": None, "duration_ms": 100, "started_at": datetime.now(timezone.utc).isoformat(), "completed_at": datetime.now(timezone.utc).isoformat(), } ] await pg.persist_step_history("exec-1", steps) mock_session.merge.assert_called_once() mock_session.commit.assert_called_once() @pytest.mark.asyncio async def test_persist_execution_handles_error(self): mock_session = AsyncMock() mock_session.merge = AsyncMock(side_effect=Exception("DB error")) mock_session.commit = AsyncMock() mock_factory = MagicMock() mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) mock_factory.return_value.__aexit__ = AsyncMock(return_value=False) pg = PipelineStatePG(session_factory=mock_factory) # Should not raise await pg.persist_execution({ "id": "1", "pipeline_name": "p", "status": "completed", "created_at": datetime.now(timezone.utc).isoformat(), "updated_at": datetime.now(timezone.utc).isoformat(), }) @pytest.mark.asyncio async def test_query_executions(self): from agentkit.orchestrator.pipeline_models import PipelineExecutionModel # Create a mock model instance model = PipelineExecutionModel( id="test-id", pipeline_name="test_pipeline", status="completed", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) mock_result = MagicMock() mock_result.scalars.return_value.all.return_value = [model] mock_session = AsyncMock() mock_session.execute = AsyncMock(return_value=mock_result) mock_factory = MagicMock() mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) mock_factory.return_value.__aexit__ = AsyncMock(return_value=False) pg = PipelineStatePG(session_factory=mock_factory) results = await pg.query_executions(pipeline_name="test_pipeline") assert len(results) == 1 assert results[0]["pipeline_name"] == "test_pipeline" @pytest.mark.asyncio async def test_get_execution_found(self): from agentkit.orchestrator.pipeline_models import PipelineExecutionModel model = PipelineExecutionModel( id="test-id", pipeline_name="test_pipeline", status="completed", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) mock_result = MagicMock() mock_result.scalar_one_or_none.return_value = model mock_session = AsyncMock() mock_session.execute = AsyncMock(return_value=mock_result) mock_factory = MagicMock() mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) mock_factory.return_value.__aexit__ = AsyncMock(return_value=False) pg = PipelineStatePG(session_factory=mock_factory) result = await pg.get_execution("test-id") assert result is not None assert result["id"] == "test-id" @pytest.mark.asyncio async def test_get_execution_not_found(self): mock_result = MagicMock() mock_result.scalar_one_or_none.return_value = None mock_session = AsyncMock() mock_session.execute = AsyncMock(return_value=mock_result) mock_factory = MagicMock() mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) mock_factory.return_value.__aexit__ = AsyncMock(return_value=False) pg = PipelineStatePG(session_factory=mock_factory) result = await pg.get_execution("nonexistent") assert result is None # ═══════════════════════════════════════════════════════════════ # PipelineStateManager # ═══════════════════════════════════════════════════════════════ class TestPipelineStateManager: """Tests for the unified state manager.""" @pytest.fixture def manager(self) -> PipelineStateManager: """Create a manager with memory-only backend.""" return PipelineStateManager(redis_url=None, session_factory=None) @pytest.mark.asyncio async def test_create_and_get_execution(self, manager: PipelineStateManager): eid = await manager.create_execution("p", ["s1"], input_data={"k": "v"}) state = await manager.get_execution(eid) assert state is not None assert state["pipeline_name"] == "p" assert state["status"] == "running" @pytest.mark.asyncio async def test_update_step(self, manager: PipelineStateManager): eid = await manager.create_execution("p", ["s1"]) await manager.update_step(eid, "s1", "completed", output={"r": 1}) state = await manager.get_execution(eid) assert "s1" in state["completed_steps"] @pytest.mark.asyncio async def test_complete_persists_to_cold(self): """Test that completing an execution triggers PG persist.""" mock_session = AsyncMock() mock_session.merge = AsyncMock() mock_session.commit = AsyncMock() mock_factory = MagicMock() mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) mock_factory.return_value.__aexit__ = AsyncMock(return_value=False) manager = PipelineStateManager(redis_url=None, session_factory=mock_factory) eid = await manager.create_execution("p", ["s1"]) await manager.update_step(eid, "s1", "completed", output={"r": 1}) await manager.complete_execution(eid, final_output={"done": True}) # PG persist should have been called mock_session.merge.assert_called() @pytest.mark.asyncio async def test_fail_persists_to_cold(self): """Test that failing an execution triggers PG persist.""" mock_session = AsyncMock() mock_session.merge = AsyncMock() mock_session.commit = AsyncMock() mock_factory = MagicMock() mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) mock_factory.return_value.__aexit__ = AsyncMock(return_value=False) manager = PipelineStateManager(redis_url=None, session_factory=mock_factory) eid = await manager.create_execution("p", ["s1"]) await manager.fail_execution(eid, "s1", "error") mock_session.merge.assert_called() @pytest.mark.asyncio async def test_get_execution_pg_fallback(self): """Test Redis miss falls back to PG.""" from agentkit.orchestrator.pipeline_models import PipelineExecutionModel model = PipelineExecutionModel( id="pg-exec-id", pipeline_name="pg_pipeline", status="completed", created_at=datetime.now(timezone.utc), updated_at=datetime.now(timezone.utc), ) mock_result = MagicMock() mock_result.scalar_one_or_none.return_value = model mock_session = AsyncMock() mock_session.execute = AsyncMock(return_value=mock_result) mock_factory = MagicMock() mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) mock_factory.return_value.__aexit__ = AsyncMock(return_value=False) manager = PipelineStateManager(redis_url=None, session_factory=mock_factory) # This execution_id doesn't exist in hot store, should fall back to PG result = await manager.get_execution("pg-exec-id") assert result is not None assert result["pipeline_name"] == "pg_pipeline" @pytest.mark.asyncio async def test_list_executions_hot_first(self, manager: PipelineStateManager): eid = await manager.create_execution("p", ["s1"]) results = await manager.list_executions() assert len(results) == 1 assert results[0]["id"] == eid @pytest.mark.asyncio async def test_health_check_memory_only(self, manager: PipelineStateManager): health = await manager.health_check() assert health["hot"] is True assert health["cold"] is False @pytest.mark.asyncio async def test_health_check_with_pg(self): mock_factory = MagicMock() manager = PipelineStateManager(redis_url=None, session_factory=mock_factory) health = await manager.health_check() assert health["hot"] is True assert health["cold"] is True # ═══════════════════════════════════════════════════════════════ # PipelineEngine with state persistence # ═══════════════════════════════════════════════════════════════ class TestPipelineEngineWithState: """Tests for PipelineEngine integration with state persistence.""" @pytest.fixture def pipeline(self) -> Pipeline: return Pipeline( name="test_pipeline", version="1.0", description="Test pipeline", stages=[ PipelineStage(name="step_a", agent="agent1", action="do_a"), PipelineStage(name="step_b", agent="agent2", action="do_b", depends_on=["step_a"]), ], ) @pytest.mark.asyncio async def test_engine_without_state_backward_compatible(self, pipeline: Pipeline): """Engine without state_manager should work as before.""" engine = PipelineEngine(dispatcher=None) result = await engine.execute(pipeline) assert result.status == StageStatus.COMPLETED @pytest.mark.asyncio async def test_engine_with_state_creates_execution(self, pipeline: Pipeline): """Engine with state_manager should create execution state.""" state_manager = PipelineStateManager(redis_url=None, session_factory=None) engine = PipelineEngine(dispatcher=None, state_manager=state_manager) result = await engine.execute(pipeline) assert result.status == StageStatus.COMPLETED # Check that execution was created in state store executions = await state_manager.list_executions() assert len(executions) == 1 assert executions[0]["status"] == "completed" assert executions[0]["pipeline_name"] == "test_pipeline" @pytest.mark.asyncio async def test_engine_with_state_updates_steps(self, pipeline: Pipeline): """Engine should update step state after each stage.""" state_manager = PipelineStateManager(redis_url=None, session_factory=None) engine = PipelineEngine(dispatcher=None, state_manager=state_manager) await engine.execute(pipeline) executions = await state_manager.list_executions() exec_state = executions[0] # Both steps should be completed assert "step_a" in exec_state["completed_steps"] assert "step_b" in exec_state["completed_steps"] @pytest.mark.asyncio async def test_engine_with_state_on_failure(self): """Engine should persist failure state when a stage fails.""" pipeline = Pipeline( name="fail_pipeline", version="1.0", description="Pipeline that fails", stages=[ PipelineStage(name="bad_step", agent="agent1", action="fail"), ], ) # Create a dispatcher that raises mock_dispatcher = AsyncMock() mock_dispatcher.dispatch = AsyncMock(side_effect=Exception("boom")) state_manager = PipelineStateManager(redis_url=None, session_factory=None) engine = PipelineEngine(dispatcher=mock_dispatcher, state_manager=state_manager) result = await engine.execute(pipeline) assert result.status == StageStatus.FAILED # Check state was persisted executions = await state_manager.list_executions() assert len(executions) == 1 assert executions[0]["status"] == "failed" @pytest.mark.asyncio async def test_engine_state_survives_check(self, pipeline: Pipeline): """Verify state can be retrieved after execution.""" state_manager = PipelineStateManager(redis_url=None, session_factory=None) engine = PipelineEngine(dispatcher=None, state_manager=state_manager) result = await engine.execute(pipeline, context={"brand": "acme"}) # Get execution by ID executions = await state_manager.list_executions() eid = executions[0]["id"] state = await state_manager.get_execution(eid) assert state is not None assert state["pipeline_name"] == "test_pipeline" assert state["status"] == "completed" @pytest.mark.asyncio async def test_engine_with_circular_dependency(self): """Engine should handle circular dependency gracefully.""" pipeline = Pipeline( name="circular", version="1.0", description="Circular pipeline", stages=[ PipelineStage(name="a", agent="agent1", action="do", depends_on=["b"]), PipelineStage(name="b", agent="agent2", action="do", depends_on=["a"]), ], ) state_manager = PipelineStateManager(redis_url=None, session_factory=None) engine = PipelineEngine(dispatcher=None, state_manager=state_manager) result = await engine.execute(pipeline) assert result.status == StageStatus.FAILED assert "Circular" in result.error_message # No execution state should be created (topological sort fails before creation) executions = await state_manager.list_executions() assert len(executions) == 0