662 lines
28 KiB
Python
662 lines
28 KiB
Python
"""Unit tests for Pipeline execution state persistence."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from agentkit.orchestrator.pipeline_engine import PipelineEngine
|
|
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage, StageStatus
|
|
from agentkit.orchestrator.pipeline_state import (
|
|
PipelineStateMemory,
|
|
PipelineStatePG,
|
|
PipelineStateRedis,
|
|
PipelineStateManager,
|
|
)
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════
|
|
# PipelineStateMemory
|
|
# ═══════════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestPipelineStateMemory:
    """Tests for in-memory pipeline state storage (``PipelineStateMemory``)."""

    @pytest.fixture
    def store(self) -> PipelineStateMemory:
        """Provide a fresh, empty in-memory store for every test."""
        return PipelineStateMemory()

    @pytest.mark.asyncio
    async def test_create_execution(self, store: PipelineStateMemory):
        """A new execution is running, positioned at its first step, with nothing completed."""
        eid = await store.create_execution(
            pipeline_name="test_pipeline",
            steps=["step_a", "step_b"],
            input_data={"key": "value"},
        )
        assert eid is not None
        state = await store.get_execution(eid)
        assert state is not None
        assert state["pipeline_name"] == "test_pipeline"
        assert state["status"] == "running"
        assert state["current_step"] == "step_a"
        assert state["completed_steps"] == []
        assert state["input_data"] == {"key": "value"}

    @pytest.mark.asyncio
    async def test_update_step_completed(self, store: PipelineStateMemory):
        """Completing a step records it and stores its output under ``step_results``."""
        eid = await store.create_execution("p", ["s1", "s2"])
        await store.update_step(eid, "s1", "completed", output={"result": 42})
        state = await store.get_execution(eid)
        assert "s1" in state["completed_steps"]
        assert state["step_results"]["s1"] == {"result": 42}

    @pytest.mark.asyncio
    async def test_update_step_failed(self, store: PipelineStateMemory):
        """A failed step's error text is exposed as the execution's ``error_message``."""
        eid = await store.create_execution("p", ["s1"])
        await store.update_step(eid, "s1", "failed", error="boom")
        state = await store.get_execution(eid)
        assert state["error_message"] == "boom"

    @pytest.mark.asyncio
    async def test_complete_execution(self, store: PipelineStateMemory):
        """Completing an execution sets status, final output and a completion timestamp."""
        eid = await store.create_execution("p", ["s1"])
        await store.complete_execution(eid, final_output={"done": True})
        state = await store.get_execution(eid)
        assert state["status"] == "completed"
        assert state["final_output"] == {"done": True}
        assert state["completed_at"] is not None

    @pytest.mark.asyncio
    async def test_fail_execution(self, store: PipelineStateMemory):
        """Failing an execution records both the failing step and the error text."""
        eid = await store.create_execution("p", ["s1"])
        await store.fail_execution(eid, "s1", "timeout")
        state = await store.get_execution(eid)
        assert state["status"] == "failed"
        assert "s1" in state["error_message"]
        assert "timeout" in state["error_message"]
        assert state["completed_at"] is not None

    @pytest.mark.asyncio
    async def test_get_execution_not_found(self, store: PipelineStateMemory):
        """Unknown execution ids yield ``None`` rather than raising."""
        result = await store.get_execution("nonexistent")
        assert result is None

    @pytest.mark.asyncio
    async def test_list_executions(self, store: PipelineStateMemory):
        """Listing returns every execution; the status filter partitions them."""
        eid1 = await store.create_execution("p1", ["s1"])
        eid2 = await store.create_execution("p2", ["s2"])
        await store.complete_execution(eid1)
        # List all
        all_execs = await store.list_executions()
        assert len(all_execs) == 2
        assert {e["id"] for e in all_execs} == {eid1, eid2}
        # Filter by status: each execution must land in exactly one bucket.
        completed = await store.list_executions(status="completed")
        assert len(completed) == 1
        assert completed[0]["id"] == eid1
        running = await store.list_executions(status="running")
        assert [e["id"] for e in running] == [eid2]

    @pytest.mark.asyncio
    async def test_list_executions_pagination(self, store: PipelineStateMemory):
        """``limit``/``offset`` slice the listing into non-overlapping pages."""
        for i in range(5):
            await store.create_execution(f"p{i}", ["s1"])
        page1 = await store.list_executions(limit=2, offset=0)
        page2 = await store.list_executions(limit=2, offset=2)
        page3 = await store.list_executions(limit=2, offset=4)
        assert len(page1) == 2
        assert len(page2) == 2
        # 5 executions = 2 + 2 + 1, so the last page holds only the remainder.
        assert len(page3) == 1
        # Overlapping pages would mean ``offset`` is being ignored.
        ids1 = {e["id"] for e in page1}
        ids2 = {e["id"] for e in page2}
        assert ids1.isdisjoint(ids2)

    @pytest.mark.asyncio
    async def test_get_step_history(self, store: PipelineStateMemory):
        """Step history is returned in the order the steps were updated."""
        eid = await store.create_execution("p", ["s1", "s2"])
        await store.update_step(eid, "s1", "completed", output={"r": 1})
        await store.update_step(eid, "s2", "failed", error="err")
        history = await store.get_step_history(eid)
        assert len(history) == 2
        assert history[0]["step_name"] == "s1"
        assert history[0]["status"] == "completed"
        assert history[1]["step_name"] == "s2"
        assert history[1]["status"] == "failed"

    @pytest.mark.asyncio
    async def test_update_step_nonexistent_execution(self, store: PipelineStateMemory):
        """Updating a step of an unknown execution must not raise."""
        # Should not raise, just log warning
        await store.update_step("nonexistent", "s1", "completed")

    @pytest.mark.asyncio
    async def test_create_execution_with_tenant(self, store: PipelineStateMemory):
        """The optional ``tenant_id`` is stored on the execution."""
        eid = await store.create_execution("p", ["s1"], tenant_id="tenant_123")
        state = await store.get_execution(eid)
        assert state["tenant_id"] == "tenant_123"
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════
|
|
# PipelineStateRedis
|
|
# ═══════════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestPipelineStateRedis:
    """Tests for Redis-backed pipeline state storage (using mocks)."""

    @pytest.fixture
    def mock_redis(self):
        """Create a mock async Redis client.

        Plain commands are ``AsyncMock``s. ``pipeline()`` returns a ``MagicMock``
        whose ``set``/``zadd`` are synchronous and return the pipe for chaining.
        Only ``execute()`` on the pipe is awaitable.
        """
        redis = AsyncMock()
        redis.get = AsyncMock(return_value=None)
        redis.set = AsyncMock(return_value=True)
        redis.zadd = AsyncMock(return_value=1)
        redis.zrevrange = AsyncMock(return_value=[])
        redis.mget = AsyncMock(return_value=[])
        # Redis pipeline: set/zadd are synchronous (return self for chaining), execute is async
        pipe = MagicMock()
        pipe.set = MagicMock(return_value=pipe)
        pipe.zadd = MagicMock(return_value=pipe)
        pipe.execute = AsyncMock(return_value=[True, 1])
        redis.pipeline = MagicMock(return_value=pipe)
        return redis

    @pytest.fixture
    def store(self, mock_redis) -> PipelineStateRedis:
        """Create a PipelineStateRedis with mocked Redis."""
        store = PipelineStateRedis(redis_url="redis://localhost:6379/0")
        # Pre-inject the mock Redis client so no real connection is attempted.
        store._redis = mock_redis
        return store

    @pytest.mark.asyncio
    async def test_create_execution_writes_to_redis(self, store: PipelineStateRedis, mock_redis):
        """Creation writes a JSON blob under the execution key with a 7-day TTL."""
        eid = await store.create_execution("test_pipeline", ["s1", "s2"])
        assert eid is not None
        # Redis pipeline should have been used (pipe.set + pipe.zadd + pipe.execute)
        pipe = mock_redis.pipeline.return_value
        pipe.set.assert_called_once()
        call_args = pipe.set.call_args
        key, payload = call_args.args[0], call_args.args[1]
        assert key.startswith("agentkit:pipeline:exec:")
        # Verify the stored data is valid JSON
        stored_data = json.loads(payload)
        assert stored_data["pipeline_name"] == "test_pipeline"
        assert stored_data["status"] == "running"
        # Verify TTL was set (7 days)
        assert call_args.kwargs.get("ex") == 7 * 24 * 3600

    @pytest.mark.asyncio
    async def test_create_execution_adds_to_sorted_set(self, store: PipelineStateRedis, mock_redis):
        """Creation also indexes the execution in a sorted set via the pipeline."""
        await store.create_execution("p", ["s1"])
        # ZADD should have been called via pipeline
        pipe = mock_redis.pipeline.return_value
        pipe.zadd.assert_called_once()

    @pytest.mark.asyncio
    async def test_update_step_writes_to_redis(self, store: PipelineStateRedis, mock_redis):
        """A step update rewrites the execution blob with a single SET."""
        eid = await store.create_execution("p", ["s1"])
        mock_redis.set.reset_mock()
        await store.update_step(eid, "s1", "completed", output={"r": 1})
        mock_redis.set.assert_called_once()

    @pytest.mark.asyncio
    async def test_complete_execution_writes_to_redis(self, store: PipelineStateRedis, mock_redis):
        """Completion rewrites the execution blob with a single SET."""
        eid = await store.create_execution("p", ["s1"])
        mock_redis.set.reset_mock()
        await store.complete_execution(eid, final_output={"done": True})
        mock_redis.set.assert_called_once()

    @pytest.mark.asyncio
    async def test_fail_execution_writes_to_redis(self, store: PipelineStateRedis, mock_redis):
        """Failure rewrites the execution blob with a single SET."""
        eid = await store.create_execution("p", ["s1"])
        mock_redis.set.reset_mock()
        await store.fail_execution(eid, "s1", "error")
        mock_redis.set.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_execution_from_redis(self, store: PipelineStateRedis, mock_redis):
        """A Redis hit is served from the Redis payload, not from the memory copy."""
        eid = await store.create_execution("p", ["s1"])
        # Build the Redis payload from the in-memory copy, but mark it so a Redis
        # hit can be told apart from a memory-fallback hit (memory still says "running").
        state = await store._fallback.get_execution(eid)
        redis_state = {**state, "status": "completed"}
        mock_redis.get.return_value = json.dumps(redis_state)
        result = await store.get_execution(eid)
        assert result is not None
        assert result["pipeline_name"] == "p"
        assert result["status"] == "completed"
        mock_redis.get.assert_awaited()

    @pytest.mark.asyncio
    async def test_get_execution_redis_miss_falls_back_to_memory(self, store: PipelineStateRedis, mock_redis):
        """A Redis miss falls back to the in-memory copy."""
        eid = await store.create_execution("p", ["s1"])
        # Redis returns None (miss)
        mock_redis.get.return_value = None
        # Should still find it in memory fallback
        result = await store.get_execution(eid)
        assert result is not None
        assert result["pipeline_name"] == "p"

    @pytest.mark.asyncio
    async def test_list_executions_from_sorted_set(self, store: PipelineStateRedis, mock_redis):
        """Listing resolves ids from the sorted set and bulk-fetches them with MGET."""
        eid = await store.create_execution("p", ["s1"])
        state = await store._fallback.get_execution(eid)
        mock_redis.zrevrange.return_value = [eid]
        mock_redis.mget.return_value = [json.dumps(state)]
        results = await store.list_executions()
        assert len(results) == 1
        assert results[0]["pipeline_name"] == "p"

    @pytest.mark.asyncio
    async def test_fallback_on_redis_failure(self, mock_redis):
        """When every Redis call raises, creation still succeeds via the memory fallback."""
        store = PipelineStateRedis(redis_url="redis://localhost:6379/0")
        # Make any connectivity check fail as well
        mock_redis.ping = AsyncMock(side_effect=Exception("connection refused"))
        store._redis = mock_redis
        # Force a Redis operation to fail
        mock_redis.set = AsyncMock(side_effect=Exception("connection refused"))
        mock_redis.pipeline = MagicMock(side_effect=Exception("connection refused"))
        # Should fall back to memory
        eid = await store.create_execution("p", ["s1"])
        assert eid is not None
        assert store.using_fallback is True

    @pytest.mark.asyncio
    async def test_health_check(self, store: PipelineStateRedis, mock_redis):
        """Health check mirrors whether PING succeeds."""
        mock_redis.ping = AsyncMock(return_value=True)
        assert await store.health_check() is True
        mock_redis.ping = AsyncMock(side_effect=Exception("fail"))
        assert await store.health_check() is False
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════
|
|
# PipelineStatePG
|
|
# ═══════════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestPipelineStatePG:
    """Tests for PostgreSQL cold persistence (using mocks)."""

    @staticmethod
    def _session_factory(session: AsyncMock) -> MagicMock:
        """Return a mock factory where ``async with factory() as s`` yields *session*."""
        factory = MagicMock()
        factory.return_value.__aenter__ = AsyncMock(return_value=session)
        factory.return_value.__aexit__ = AsyncMock(return_value=False)
        return factory

    @staticmethod
    def _writable_session(merge_side_effect=None) -> AsyncMock:
        """Return a mock session with awaitable ``merge``/``commit``.

        *merge_side_effect* is forwarded to ``merge``, e.g. to simulate DB errors.
        """
        session = AsyncMock()
        session.merge = AsyncMock(side_effect=merge_side_effect)
        session.commit = AsyncMock()
        return session

    @staticmethod
    def _query_session(result: MagicMock) -> AsyncMock:
        """Return a mock session whose ``execute`` resolves to *result*."""
        session = AsyncMock()
        session.execute = AsyncMock(return_value=result)
        return session

    @staticmethod
    def _execution_model(exec_id: str, pipeline_name: str):
        """Build a completed ``PipelineExecutionModel`` row for query tests."""
        from agentkit.orchestrator.pipeline_models import PipelineExecutionModel

        now = datetime.now(timezone.utc)
        return PipelineExecutionModel(
            id=exec_id,
            pipeline_name=pipeline_name,
            status="completed",
            created_at=now,
            updated_at=now,
        )

    @pytest.mark.asyncio
    async def test_no_op_when_session_factory_is_none(self):
        """Without a session factory the store is disabled and every method is a no-op."""
        pg = PipelineStatePG(session_factory=None)
        assert pg.enabled is False
        # All methods should be no-op
        await pg.persist_execution({"id": "1", "pipeline_name": "p", "status": "completed"})
        await pg.persist_step_history("1", [])
        result = await pg.query_executions()
        assert result == []
        result = await pg.get_execution("1")
        assert result is None

    @pytest.mark.asyncio
    async def test_persist_execution(self):
        """Persisting an execution merges one row and commits."""
        mock_session = self._writable_session()
        pg = PipelineStatePG(session_factory=self._session_factory(mock_session))
        assert pg.enabled is True

        now_iso = datetime.now(timezone.utc).isoformat()
        state = {
            "id": "test-id-123",
            "pipeline_name": "test_pipeline",
            "status": "completed",
            "current_step": None,
            "completed_steps": ["s1"],
            "step_results": {"s1": {"r": 1}},
            "input_data": {"key": "val"},
            "final_output": {"done": True},
            "error_message": None,
            "tenant_id": None,
            "created_at": now_iso,
            "updated_at": now_iso,
            "completed_at": now_iso,
        }
        await pg.persist_execution(state)
        mock_session.merge.assert_called_once()
        mock_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_persist_step_history(self):
        """Persisting a one-step history merges one row and commits."""
        mock_session = self._writable_session()
        pg = PipelineStatePG(session_factory=self._session_factory(mock_session))

        now_iso = datetime.now(timezone.utc).isoformat()
        steps = [
            {
                "id": "step-id-1",
                "step_name": "s1",
                "status": "completed",
                "output_data": {"r": 1},
                "error_message": None,
                "duration_ms": 100,
                "started_at": now_iso,
                "completed_at": now_iso,
            }
        ]
        await pg.persist_step_history("exec-1", steps)
        mock_session.merge.assert_called_once()
        mock_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_persist_execution_handles_error(self):
        """A database error during persist is swallowed rather than propagated."""
        mock_session = self._writable_session(merge_side_effect=Exception("DB error"))
        pg = PipelineStatePG(session_factory=self._session_factory(mock_session))
        now_iso = datetime.now(timezone.utc).isoformat()
        # Should not raise
        await pg.persist_execution({
            "id": "1",
            "pipeline_name": "p",
            "status": "completed",
            "created_at": now_iso,
            "updated_at": now_iso,
        })

    @pytest.mark.asyncio
    async def test_query_executions(self):
        """Query results are converted from ORM rows into state dicts."""
        model = self._execution_model("test-id", "test_pipeline")
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [model]
        mock_session = self._query_session(mock_result)

        pg = PipelineStatePG(session_factory=self._session_factory(mock_session))
        results = await pg.query_executions(pipeline_name="test_pipeline")
        assert len(results) == 1
        assert results[0]["pipeline_name"] == "test_pipeline"

    @pytest.mark.asyncio
    async def test_get_execution_found(self):
        """A matching row is returned as a state dict."""
        model = self._execution_model("test-id", "test_pipeline")
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = model
        mock_session = self._query_session(mock_result)

        pg = PipelineStatePG(session_factory=self._session_factory(mock_session))
        result = await pg.get_execution("test-id")
        assert result is not None
        assert result["id"] == "test-id"

    @pytest.mark.asyncio
    async def test_get_execution_not_found(self):
        """No matching row yields ``None``."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session = self._query_session(mock_result)

        pg = PipelineStatePG(session_factory=self._session_factory(mock_session))
        result = await pg.get_execution("nonexistent")
        assert result is None
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════
|
|
# PipelineStateManager
|
|
# ═══════════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestPipelineStateManager:
    """Tests for the unified hot (memory/Redis) + cold (PostgreSQL) state manager."""

    @staticmethod
    def _session_factory(session: AsyncMock) -> MagicMock:
        """Return a mock factory where ``async with factory() as s`` yields *session*."""
        factory = MagicMock()
        factory.return_value.__aenter__ = AsyncMock(return_value=session)
        factory.return_value.__aexit__ = AsyncMock(return_value=False)
        return factory

    @staticmethod
    def _writable_session() -> AsyncMock:
        """Return a mock session with awaitable ``merge``/``commit``."""
        session = AsyncMock()
        session.merge = AsyncMock()
        session.commit = AsyncMock()
        return session

    @pytest.fixture
    def manager(self) -> PipelineStateManager:
        """Create a manager with memory-only backend."""
        return PipelineStateManager(redis_url=None, session_factory=None)

    @pytest.mark.asyncio
    async def test_create_and_get_execution(self, manager: PipelineStateManager):
        """A created execution can be read back and is running."""
        eid = await manager.create_execution("p", ["s1"], input_data={"k": "v"})
        state = await manager.get_execution(eid)
        assert state is not None
        assert state["pipeline_name"] == "p"
        assert state["status"] == "running"

    @pytest.mark.asyncio
    async def test_update_step(self, manager: PipelineStateManager):
        """Step updates are visible through the manager."""
        eid = await manager.create_execution("p", ["s1"])
        await manager.update_step(eid, "s1", "completed", output={"r": 1})
        state = await manager.get_execution(eid)
        assert "s1" in state["completed_steps"]

    @pytest.mark.asyncio
    async def test_complete_persists_to_cold(self):
        """Test that completing an execution triggers PG persist."""
        mock_session = self._writable_session()
        manager = PipelineStateManager(
            redis_url=None, session_factory=self._session_factory(mock_session)
        )
        eid = await manager.create_execution("p", ["s1"])
        await manager.update_step(eid, "s1", "completed", output={"r": 1})
        await manager.complete_execution(eid, final_output={"done": True})

        # PG persist should have been called and committed
        mock_session.merge.assert_called()
        mock_session.commit.assert_called()

    @pytest.mark.asyncio
    async def test_fail_persists_to_cold(self):
        """Test that failing an execution triggers PG persist."""
        mock_session = self._writable_session()
        manager = PipelineStateManager(
            redis_url=None, session_factory=self._session_factory(mock_session)
        )
        eid = await manager.create_execution("p", ["s1"])
        await manager.fail_execution(eid, "s1", "error")

        mock_session.merge.assert_called()
        mock_session.commit.assert_called()

    @pytest.mark.asyncio
    async def test_get_execution_pg_fallback(self):
        """Test Redis miss falls back to PG."""
        from agentkit.orchestrator.pipeline_models import PipelineExecutionModel

        now = datetime.now(timezone.utc)
        model = PipelineExecutionModel(
            id="pg-exec-id",
            pipeline_name="pg_pipeline",
            status="completed",
            created_at=now,
            updated_at=now,
        )
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = model
        mock_session = AsyncMock()
        mock_session.execute = AsyncMock(return_value=mock_result)

        manager = PipelineStateManager(
            redis_url=None, session_factory=self._session_factory(mock_session)
        )
        # This execution_id doesn't exist in hot store, should fall back to PG
        result = await manager.get_execution("pg-exec-id")
        assert result is not None
        assert result["pipeline_name"] == "pg_pipeline"

    @pytest.mark.asyncio
    async def test_list_executions_hot_first(self, manager: PipelineStateManager):
        """Listing returns executions held in the hot store."""
        eid = await manager.create_execution("p", ["s1"])
        results = await manager.list_executions()
        assert len(results) == 1
        assert results[0]["id"] == eid

    @pytest.mark.asyncio
    async def test_health_check_memory_only(self, manager: PipelineStateManager):
        """Without a session factory the cold tier reports unhealthy."""
        health = await manager.health_check()
        assert health["hot"] is True
        assert health["cold"] is False

    @pytest.mark.asyncio
    async def test_health_check_with_pg(self):
        """With a session factory configured the cold tier reports healthy."""
        mock_factory = MagicMock()
        manager = PipelineStateManager(redis_url=None, session_factory=mock_factory)
        health = await manager.health_check()
        assert health["hot"] is True
        assert health["cold"] is True
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════════
|
|
# PipelineEngine with state persistence
|
|
# ═══════════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestPipelineEngineWithState:
    """Tests for PipelineEngine integration with state persistence."""

    @pytest.fixture
    def pipeline(self) -> Pipeline:
        """Two-stage pipeline where ``step_b`` depends on ``step_a``."""
        return Pipeline(
            name="test_pipeline",
            version="1.0",
            description="Test pipeline",
            stages=[
                PipelineStage(name="step_a", agent="agent1", action="do_a"),
                PipelineStage(name="step_b", agent="agent2", action="do_b", depends_on=["step_a"]),
            ],
        )

    @pytest.mark.asyncio
    async def test_engine_without_state_backward_compatible(self, pipeline: Pipeline):
        """Engine without state_manager should work as before."""
        engine = PipelineEngine(dispatcher=None)
        result = await engine.execute(pipeline)
        assert result.status == StageStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_engine_with_state_creates_execution(self, pipeline: Pipeline):
        """Engine with state_manager should create execution state."""
        state_manager = PipelineStateManager(redis_url=None, session_factory=None)
        engine = PipelineEngine(dispatcher=None, state_manager=state_manager)
        result = await engine.execute(pipeline)
        assert result.status == StageStatus.COMPLETED
        # Check that execution was created in state store
        executions = await state_manager.list_executions()
        assert len(executions) == 1
        assert executions[0]["status"] == "completed"
        assert executions[0]["pipeline_name"] == "test_pipeline"

    @pytest.mark.asyncio
    async def test_engine_with_state_updates_steps(self, pipeline: Pipeline):
        """Engine should update step state after each stage."""
        state_manager = PipelineStateManager(redis_url=None, session_factory=None)
        engine = PipelineEngine(dispatcher=None, state_manager=state_manager)
        await engine.execute(pipeline)
        executions = await state_manager.list_executions()
        # Guard the index below with a clear failure instead of an IndexError.
        assert len(executions) == 1
        exec_state = executions[0]
        # Both steps should be completed
        assert "step_a" in exec_state["completed_steps"]
        assert "step_b" in exec_state["completed_steps"]

    @pytest.mark.asyncio
    async def test_engine_with_state_on_failure(self):
        """Engine should persist failure state when a stage fails."""
        pipeline = Pipeline(
            name="fail_pipeline",
            version="1.0",
            description="Pipeline that fails",
            stages=[
                PipelineStage(name="bad_step", agent="agent1", action="fail"),
            ],
        )

        # Create a dispatcher that raises
        mock_dispatcher = AsyncMock()
        mock_dispatcher.dispatch = AsyncMock(side_effect=Exception("boom"))

        state_manager = PipelineStateManager(redis_url=None, session_factory=None)
        engine = PipelineEngine(dispatcher=mock_dispatcher, state_manager=state_manager)
        result = await engine.execute(pipeline)
        assert result.status == StageStatus.FAILED
        # Check state was persisted
        executions = await state_manager.list_executions()
        assert len(executions) == 1
        assert executions[0]["status"] == "failed"

    @pytest.mark.asyncio
    async def test_engine_state_survives_check(self, pipeline: Pipeline):
        """Verify state can be retrieved after execution."""
        state_manager = PipelineStateManager(redis_url=None, session_factory=None)
        engine = PipelineEngine(dispatcher=None, state_manager=state_manager)
        result = await engine.execute(pipeline, context={"brand": "acme"})
        assert result.status == StageStatus.COMPLETED
        # Get execution by ID
        executions = await state_manager.list_executions()
        assert len(executions) == 1
        eid = executions[0]["id"]
        state = await state_manager.get_execution(eid)
        assert state is not None
        assert state["pipeline_name"] == "test_pipeline"
        assert state["status"] == "completed"

    @pytest.mark.asyncio
    async def test_engine_with_circular_dependency(self):
        """Engine should handle circular dependency gracefully."""
        pipeline = Pipeline(
            name="circular",
            version="1.0",
            description="Circular pipeline",
            stages=[
                PipelineStage(name="a", agent="agent1", action="do", depends_on=["b"]),
                PipelineStage(name="b", agent="agent2", action="do", depends_on=["a"]),
            ],
        )
        state_manager = PipelineStateManager(redis_url=None, session_factory=None)
        engine = PipelineEngine(dispatcher=None, state_manager=state_manager)
        result = await engine.execute(pipeline)
        assert result.status == StageStatus.FAILED
        # Fail with an assertion, not a TypeError, if no message was set.
        assert result.error_message is not None
        assert "Circular" in result.error_message
        # No execution state should be created (topological sort fails before creation)
        executions = await state_manager.list_executions()
        assert len(executions) == 0
|