feat(pipeline): U5 state persistence with Redis hot + PG cold dual-write

Add the PipelineStateMemory, PipelineStateRedis and PipelineStatePG backends,
plus a PipelineStateManager that pairs Redis hot state (JSON values indexed by
a Sorted Set) with PostgreSQL JSONB cold persistence. PipelineEngine now
persists state at each step transition when a state manager is configured.
This commit is contained in:
chiguyong 2026-06-07 17:25:52 +08:00
parent 2e547e345a
commit 4db637cd4f
5 changed files with 1421 additions and 6 deletions

View File

@ -5,6 +5,18 @@ from agentkit.orchestrator.pipeline_engine import PipelineEngine
from agentkit.orchestrator.pipeline_loader import PipelineLoader from agentkit.orchestrator.pipeline_loader import PipelineLoader
from agentkit.orchestrator.handoff import HandoffManager from agentkit.orchestrator.handoff import HandoffManager
from agentkit.orchestrator.dynamic_pipeline import DynamicPipeline from agentkit.orchestrator.dynamic_pipeline import DynamicPipeline
from agentkit.orchestrator.pipeline_state import (
PipelineStateMemory,
PipelineStateRedis,
PipelineStatePG,
PipelineStateManager,
)
from agentkit.orchestrator.retry import StepRetryPolicy, execute_with_retry
from agentkit.orchestrator.compensation import (
CompletedStep,
CompensationResult,
SagaOrchestrator,
)
__all__ = [ __all__ = [
"Pipeline", "Pipeline",
@ -14,4 +26,13 @@ __all__ = [
"PipelineLoader", "PipelineLoader",
"HandoffManager", "HandoffManager",
"DynamicPipeline", "DynamicPipeline",
"PipelineStateMemory",
"PipelineStateRedis",
"PipelineStatePG",
"PipelineStateManager",
"StepRetryPolicy",
"execute_with_retry",
"CompletedStep",
"CompensationResult",
"SagaOrchestrator",
] ]

View File

@ -1,4 +1,4 @@
"""Pipeline Engine - DAG + 并行执行""" """Pipeline Engine - DAG + 并行执行 + 步骤重试 + Saga 补偿"""
import asyncio import asyncio
import logging import logging
@ -6,6 +6,7 @@ from collections import defaultdict
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any from typing import Any
from agentkit.orchestrator.compensation import SagaOrchestrator
from agentkit.orchestrator.pipeline_schema import ( from agentkit.orchestrator.pipeline_schema import (
Pipeline, Pipeline,
PipelineResult, PipelineResult,
@ -13,6 +14,7 @@ from agentkit.orchestrator.pipeline_schema import (
StageResult, StageResult,
StageStatus, StageStatus,
) )
from agentkit.orchestrator.retry import StepRetryPolicy, execute_with_retry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -25,11 +27,14 @@ class PipelineEngine:
- 同层并行执行asyncio.gather - 同层并行执行asyncio.gather
- 变量解析 - 变量解析
- 条件执行 - 条件执行
- 重试 - 步骤级指数退避重试StepRetryPolicy
- Saga 补偿LIFO 回滚已完成步骤
- 状态持久化可选
""" """
def __init__(self, dispatcher: Any = None): def __init__(self, dispatcher: Any = None, state_manager: Any = None):
self._dispatcher = dispatcher self._dispatcher = dispatcher
self._state_manager = state_manager
async def execute( async def execute(
self, self,
@ -48,6 +53,22 @@ class PipelineEngine:
result.error_message = str(e) result.error_message = str(e)
return result return result
# Create execution state if state_manager is configured
execution_id: str | None = None
if self._state_manager is not None:
try:
step_names = [s.name for s in pipeline.stages]
execution_id = await self._state_manager.create_execution(
pipeline_name=pipeline.name,
steps=step_names,
input_data=context,
)
except Exception as exc:
logger.warning(f"Failed to create execution state: {exc}")
# Create Saga orchestrator for compensation tracking
saga = SagaOrchestrator()
# 逐层执行 # 逐层执行
for level, stages in enumerate(level_groups): for level, stages in enumerate(level_groups):
logger.info(f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)") logger.info(f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)")
@ -55,7 +76,7 @@ class PipelineEngine:
# 并行执行同层 stages # 并行执行同层 stages
tasks = [] tasks = []
for stage in stages: for stage in stages:
tasks.append(self._execute_stage(stage, result)) tasks.append(self._execute_stage(stage, result, saga))
stage_results = await asyncio.gather(*tasks, return_exceptions=True) stage_results = await asyncio.gather(*tasks, return_exceptions=True)
@ -69,6 +90,22 @@ class PipelineEngine:
) )
result.stage_results[stage.name] = sr result.stage_results[stage.name] = sr
# Update step state
if self._state_manager is not None and execution_id is not None:
try:
step_status = "completed" if sr.status == StageStatus.COMPLETED else sr.status.value
step_output = sr.output_data if hasattr(sr, 'output_data') else None
step_error = sr.error_message if hasattr(sr, 'error_message') else None
await self._state_manager.update_step(
execution_id=execution_id,
step_name=stage.name,
status=step_status,
output=step_output,
error=step_error,
)
except Exception as exc:
logger.warning(f"Failed to update step state: {exc}")
# 收集输出变量 # 收集输出变量
if sr.output_data and isinstance(sr, dict): if sr.output_data and isinstance(sr, dict):
pass pass
@ -80,17 +117,56 @@ class PipelineEngine:
# 检查是否需要中止 # 检查是否需要中止
if hasattr(sr, 'status') and sr.status == StageStatus.FAILED: if hasattr(sr, 'status') and sr.status == StageStatus.FAILED:
if not stage.continue_on_failure: if not stage.continue_on_failure:
# Execute Saga compensation for completed steps
compensation_results = await saga.compensate()
if compensation_results:
failed_compensations = [
cr for cr in compensation_results if not cr.success and cr.error != "no_compensation_needed"
]
if failed_compensations:
logger.warning(
f"Compensation had {len(failed_compensations)} failures: "
f"{[c.step_name for c in failed_compensations]}"
)
result.status = StageStatus.FAILED result.status = StageStatus.FAILED
result.error_message = f"Stage '{stage.name}' failed" result.error_message = f"Stage '{stage.name}' failed"
# Fail execution state
if self._state_manager is not None and execution_id is not None:
try:
await self._state_manager.fail_execution(
execution_id=execution_id,
step_name=stage.name,
error=result.error_message,
)
except Exception as exc:
logger.warning(f"Failed to persist failure state: {exc}")
return result return result
result.status = StageStatus.COMPLETED result.status = StageStatus.COMPLETED
# Complete execution state
if self._state_manager is not None and execution_id is not None:
try:
final_output = {
name: sr.output_data
for name, sr in result.stage_results.items()
if sr.output_data is not None
}
await self._state_manager.complete_execution(
execution_id=execution_id,
final_output=final_output,
)
except Exception as exc:
logger.warning(f"Failed to persist completion state: {exc}")
return result return result
async def _execute_stage( async def _execute_stage(
self, self,
stage: PipelineStage, stage: PipelineStage,
pipeline_result: PipelineResult, pipeline_result: PipelineResult,
saga: SagaOrchestrator,
) -> StageResult: ) -> StageResult:
"""执行单个 stage""" """执行单个 stage"""
started_at = datetime.now(timezone.utc).isoformat() started_at = datetime.now(timezone.utc).isoformat()
@ -110,13 +186,20 @@ class PipelineEngine:
# 执行 # 执行
if self._dispatcher is None: if self._dispatcher is None:
# Dry-run 模式 # Dry-run 模式
return StageResult( result = StageResult(
stage_name=stage.name, stage_name=stage.name,
status=StageStatus.COMPLETED, status=StageStatus.COMPLETED,
output_data={"dry_run": True, "inputs": resolved_inputs}, output_data={"dry_run": True, "inputs": resolved_inputs},
started_at=started_at, started_at=started_at,
completed_at=datetime.now(timezone.utc).isoformat(), completed_at=datetime.now(timezone.utc).isoformat(),
) )
# Record completed step for Saga compensation
saga.record_completed(
step_name=stage.name,
result=result.output_data,
compensate_action=stage.compensate,
)
return result
# 通过 Dispatcher 分发任务 # 通过 Dispatcher 分发任务
from agentkit.core.protocol import TaskMessage from agentkit.core.protocol import TaskMessage
@ -133,7 +216,8 @@ class PipelineEngine:
timeout_seconds=stage.timeout_seconds, timeout_seconds=stage.timeout_seconds,
) )
try: async def _dispatch_and_wait() -> StageResult:
"""Dispatch task and wait for result"""
await self._dispatcher.dispatch(task) await self._dispatcher.dispatch(task)
# 等待结果 # 等待结果
@ -158,6 +242,24 @@ class PipelineEngine:
completed_at=datetime.now(timezone.utc).isoformat(), completed_at=datetime.now(timezone.utc).isoformat(),
) )
try:
# Execute with retry if retry_policy is configured
sr = await execute_with_retry(
func=_dispatch_and_wait,
retry_policy=stage.retry_policy,
step_name=stage.name,
)
# Record completed step for Saga compensation on success
if sr.status == StageStatus.COMPLETED:
saga.record_completed(
step_name=stage.name,
result=sr.output_data,
compensate_action=stage.compensate,
)
return sr
except Exception as e: except Exception as e:
return StageResult( return StageResult(
stage_name=stage.name, stage_name=stage.name,

View File

@ -0,0 +1,59 @@
"""Pipeline execution ORM models for PostgreSQL persistence."""
import uuid
from datetime import datetime, timezone
from sqlalchemy import Column, DateTime, Index, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import DeclarativeBase
class Base(DeclarativeBase):
    """Declarative base class for the pipeline persistence ORM models."""
class PipelineExecutionModel(Base):
    """Pipeline execution record — persisted final state.

    One row per pipeline run, keyed by the execution id. Step bookkeeping
    (``completed_steps``, ``step_results``) and payloads are stored as JSONB.
    """

    __tablename__ = "pipeline_executions"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    pipeline_name = Column(String(128), nullable=False, index=True)
    status = Column(String(32), nullable=False, index=True)
    current_step = Column(String(128))
    completed_steps = Column(JSONB, default=list)
    step_results = Column(JSONB, default=dict)
    input_data = Column(JSONB)
    final_output = Column(JSONB)
    error_message = Column(Text)
    tenant_id = Column(String(64), index=True)

    # The defaults below produce timezone-aware UTC datetimes, so the columns
    # must be timezone-aware as well; a naive TIMESTAMP column would either
    # reject aware values or drop the offset, depending on the driver.
    created_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
    )
    updated_at = Column(
        DateTime(timezone=True),
        default=lambda: datetime.now(timezone.utc),
        onupdate=lambda: datetime.now(timezone.utc),
    )
    completed_at = Column(DateTime(timezone=True))

    __table_args__ = (
        # Supports "executions with status X, newest first" listings.
        Index("ix_pipeline_status_created", "status", "created_at"),
    )
class PipelineStepHistoryModel(Base):
    """Step execution history — audit trail.

    One row per recorded step event. ``execution_id`` refers to the owning
    execution by value; no foreign-key constraint is declared here.
    """

    __tablename__ = "pipeline_step_history"

    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    execution_id = Column(String(36), nullable=False, index=True)
    step_name = Column(String(128), nullable=False)
    # Position of the event within the execution's recorded history.
    step_index = Column(Integer, nullable=False)
    status = Column(String(32), nullable=False)
    input_data = Column(JSONB)
    output_data = Column(JSONB)
    error_message = Column(Text)
    duration_ms = Column(Integer)
    retry_attempt = Column(Integer, default=0)

    # Timezone-aware so that UTC-aware datetimes can be stored without the
    # offset being rejected or stripped by the driver.
    started_at = Column(DateTime(timezone=True))
    completed_at = Column(DateTime(timezone=True))

View File

@ -0,0 +1,572 @@
"""Pipeline execution state persistence — Redis hot state + PostgreSQL cold storage.
Architecture:
PipelineStateMemory In-memory fallback (always available, for testing)
PipelineStateRedis Redis hot state (low-latency reads/writes)
PipelineStatePG PostgreSQL cold persistence (durable audit trail)
PipelineStateManager Unified manager (Redis + PG dual write, fallback chain)
"""
from __future__ import annotations
import json
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Callable, Coroutine
from agentkit.orchestrator.pipeline_models import (
PipelineExecutionModel,
PipelineStepHistoryModel,
)
logger = logging.getLogger(__name__)
# Redis key patterns
_EXEC_KEY_PREFIX = "agentkit:pipeline:exec:"
_INDEX_KEY = "agentkit:pipeline:index"
_TTL_SECONDS = 7 * 24 * 3600 # 7 days
class PipelineStateMemory:
    """In-memory pipeline state storage (testing / fallback).

    Execution snapshots and per-step history events are kept in plain dicts
    keyed by execution id, so nothing is shared between instances.
    """

    def __init__(self) -> None:
        # execution_id -> execution state snapshot
        self._executions: dict[str, dict[str, Any]] = {}
        # execution_id -> step events, in the order they were reported
        self._step_history: dict[str, list[dict[str, Any]]] = {}

    async def create_execution(
        self,
        pipeline_name: str,
        steps: list[str],
        input_data: dict[str, Any] | None = None,
        tenant_id: str | None = None,
    ) -> str:
        """Register a new execution in the ``running`` state.

        Args:
            pipeline_name: Name of the pipeline being executed.
            steps: Step names; the first one (if any) becomes ``current_step``.
            input_data: Optional input payload stored with the execution.
            tenant_id: Optional tenant identifier stored with the execution.

        Returns:
            The generated execution id (a UUID4 string).
        """
        execution_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc).isoformat()
        self._executions[execution_id] = {
            "id": execution_id,
            "pipeline_name": pipeline_name,
            "status": "running",
            "current_step": steps[0] if steps else None,
            "completed_steps": [],
            "step_results": {},
            "input_data": input_data,
            "final_output": None,
            "error_message": None,
            "tenant_id": tenant_id,
            "created_at": now,
            "updated_at": now,
            "completed_at": None,
        }
        self._step_history[execution_id] = []
        return execution_id

    async def update_step(
        self,
        execution_id: str,
        step_name: str,
        status: str,
        output: dict[str, Any] | None = None,
        error: str | None = None,
        duration_ms: int | None = None,
    ) -> None:
        """Record a step transition and append an event to the step history.

        Unknown execution ids are logged and ignored. When ``duration_ms`` is
        given, the event's ``started_at`` is back-dated by that duration so the
        recorded timestamps agree with the reported duration.
        """
        from datetime import timedelta

        exec_state = self._executions.get(execution_id)
        if exec_state is None:
            logger.warning(f"Execution '{execution_id}' not found for step update")
            return

        # One clock reading keeps every timestamp of this update consistent.
        reported_at = datetime.now(timezone.utc)
        now = reported_at.isoformat()

        exec_state["current_step"] = step_name
        exec_state["updated_at"] = now

        if status == "completed":
            if step_name not in exec_state["completed_steps"]:
                exec_state["completed_steps"].append(step_name)
            if output is not None:
                exec_state["step_results"][step_name] = output
        elif status == "failed":
            exec_state["error_message"] = error

        started_at = reported_at
        if duration_ms is not None and duration_ms > 0:
            started_at = reported_at - timedelta(milliseconds=duration_ms)

        step_event: dict[str, Any] = {
            "id": str(uuid.uuid4()),
            "execution_id": execution_id,
            "step_name": step_name,
            "status": status,
            "output_data": output,
            "error_message": error,
            "duration_ms": duration_ms,
            "started_at": started_at.isoformat(),
            # Only terminal statuses carry a completion timestamp.
            "completed_at": now if status in ("completed", "failed") else None,
        }
        self._step_history.setdefault(execution_id, []).append(step_event)

    async def complete_execution(
        self,
        execution_id: str,
        final_output: dict[str, Any] | None = None,
    ) -> None:
        """Mark an execution as ``completed``; unknown ids are ignored."""
        exec_state = self._executions.get(execution_id)
        if exec_state is None:
            return
        now = datetime.now(timezone.utc).isoformat()
        exec_state["status"] = "completed"
        exec_state["final_output"] = final_output
        exec_state["updated_at"] = now
        exec_state["completed_at"] = now

    async def fail_execution(
        self,
        execution_id: str,
        step_name: str,
        error: str,
    ) -> None:
        """Mark an execution as ``failed`` at ``step_name``; unknown ids are ignored."""
        exec_state = self._executions.get(execution_id)
        if exec_state is None:
            return
        now = datetime.now(timezone.utc).isoformat()
        exec_state["status"] = "failed"
        exec_state["error_message"] = f"Step '{step_name}' failed: {error}"
        exec_state["updated_at"] = now
        exec_state["completed_at"] = now

    async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
        """Return the stored state dict (the live object, not a copy) or None."""
        return self._executions.get(execution_id)

    async def list_executions(
        self,
        status: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """List executions newest-first, filtered by status, then paginated."""
        results = list(self._executions.values())
        if status:
            results = [e for e in results if e.get("status") == status]
        results.sort(key=lambda e: e.get("created_at", ""), reverse=True)
        return results[offset : offset + limit]

    async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]:
        """Return the recorded step events for an execution (empty if unknown)."""
        return self._step_history.get(execution_id, [])
class PipelineStateRedis:
    """Redis-backed pipeline state storage (hot state).

    Each execution is stored as one JSON string under
    ``agentkit:pipeline:exec:<id>`` with a TTL of ``_TTL_SECONDS`` and is
    indexed by creation time in the ``agentkit:pipeline:index`` sorted set.

    Every write is applied to an embedded PipelineStateMemory first and then
    mirrored to Redis. After the first failed Redis write the instance stays
    in memory-only mode. Step history is kept in the memory store only.
    """

    # Index entries fetched per round trip when filtering by status.
    _SCAN_BATCH = 200

    def __init__(self, redis_url: str = "redis://localhost:6379/0") -> None:
        self._redis_url = redis_url
        self._redis: Any = None
        self._fallback = PipelineStateMemory()
        self._use_fallback = False

    async def _get_redis(self) -> Any:
        """Return the Redis client, creating it lazily on first use."""
        if self._redis is None:
            # Imported here so the module loads even without redis installed.
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(
                self._redis_url,
                decode_responses=True,
            )
        return self._redis

    async def _safe_redis_call(
        self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any
    ) -> Any:
        """Execute a Redis call, falling back to memory on failure.

        ``fn`` receives the Redis client as its first argument. Any exception
        flips the instance into memory-only mode and yields ``None``.
        """
        if self._use_fallback:
            return None
        try:
            redis = await self._get_redis()
            return await fn(redis, *args, **kwargs)
        except Exception as exc:
            logger.warning(f"Redis operation failed, switching to memory fallback: {exc}")
            self._use_fallback = True
            self._redis = None
            return None

    def _key(self, execution_id: str) -> str:
        """Build the Redis key holding the JSON state of one execution."""
        return f"{_EXEC_KEY_PREFIX}{execution_id}"

    async def _write_state(self, execution_id: str, *, add_to_index: bool = False) -> None:
        """Mirror the in-memory state of one execution into Redis.

        With ``add_to_index`` the write is pipelined together with a ZADD that
        registers the execution in the sorted-set index (score = current time).
        """

        async def _write(redis: Any) -> None:
            state = self._fallback._executions.get(execution_id)
            if state is None:
                return
            # NOTE: a serialization error is handled like any other failure
            # by _safe_redis_call, i.e. it switches to memory-only mode.
            payload = json.dumps(state)
            if not add_to_index:
                await redis.set(self._key(execution_id), payload, ex=_TTL_SECONDS)
                return
            score = datetime.now(timezone.utc).timestamp()
            pipe = redis.pipeline()
            pipe.set(self._key(execution_id), payload, ex=_TTL_SECONDS)
            pipe.zadd(_INDEX_KEY, {execution_id: score})
            await pipe.execute()

        await self._safe_redis_call(_write)

    async def create_execution(
        self,
        pipeline_name: str,
        steps: list[str],
        input_data: dict[str, Any] | None = None,
        tenant_id: str | None = None,
    ) -> str:
        """Create an execution in memory, then mirror and index it in Redis."""
        # Always write to the memory store first so it stays authoritative.
        execution_id = await self._fallback.create_execution(
            pipeline_name, steps, input_data, tenant_id
        )
        await self._write_state(execution_id, add_to_index=True)
        return execution_id

    async def update_step(
        self,
        execution_id: str,
        step_name: str,
        status: str,
        output: dict[str, Any] | None = None,
        error: str | None = None,
        duration_ms: int | None = None,
    ) -> None:
        """Record a step transition and refresh the Redis snapshot."""
        await self._fallback.update_step(
            execution_id, step_name, status, output, error, duration_ms
        )
        await self._write_state(execution_id)

    async def complete_execution(
        self,
        execution_id: str,
        final_output: dict[str, Any] | None = None,
    ) -> None:
        """Mark the execution completed and refresh the Redis snapshot."""
        await self._fallback.complete_execution(execution_id, final_output)
        await self._write_state(execution_id)

    async def fail_execution(
        self,
        execution_id: str,
        step_name: str,
        error: str,
    ) -> None:
        """Mark the execution failed and refresh the Redis snapshot."""
        await self._fallback.fail_execution(execution_id, step_name, error)
        await self._write_state(execution_id)

    async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
        """Return the execution state from Redis, else from the memory store."""
        if not self._use_fallback:
            try:
                redis = await self._get_redis()
                raw = await redis.get(self._key(execution_id))
                if raw is not None:
                    return json.loads(raw)
            except Exception as exc:
                # Best-effort read: report the problem, then use the memory copy.
                logger.warning(
                    "Redis read for execution '%s' failed, using memory copy: %s",
                    execution_id,
                    exc,
                )
        return await self._fallback.get_execution(execution_id)

    async def _load_states(self, redis: Any, ids: list[str]) -> list[dict[str, Any]]:
        """Fetch and decode the states for ``ids``, skipping missing keys."""
        values = await redis.mget([self._key(eid) for eid in ids])
        return [json.loads(raw) for raw in values if raw is not None]

    async def _list_from_redis(
        self,
        redis: Any,
        status: str | None,
        limit: int,
        offset: int,
    ) -> list[dict[str, Any]] | None:
        """List executions via the sorted-set index, newest first.

        Returns ``None`` when the index yields no ids, so the caller can fall
        back to the memory store.
        """
        if status is None:
            ids = await redis.zrevrange(_INDEX_KEY, offset, offset + limit - 1)
            if not ids:
                return None
            return await self._load_states(redis, ids)

        # With a status filter, offset/limit must count matching executions,
        # not index positions, so walk the index in batches and filter first.
        matches: list[dict[str, Any]] = []
        to_skip = offset
        start = 0
        index_seen = False
        while len(matches) < limit:
            ids = await redis.zrevrange(_INDEX_KEY, start, start + self._SCAN_BATCH - 1)
            if not ids:
                break
            index_seen = True
            for state in await self._load_states(redis, ids):
                if state.get("status") != status:
                    continue
                if to_skip > 0:
                    to_skip -= 1
                    continue
                matches.append(state)
                if len(matches) == limit:
                    break
            if len(ids) < self._SCAN_BATCH:
                # A short batch means the end of the index was reached.
                break
            start += len(ids)
        return matches if index_seen else None

    async def list_executions(
        self,
        status: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """List executions newest-first, optionally filtered by status.

        Redis is used when available; if it is unavailable, errors out, or its
        index is empty, the memory store answers instead.
        """
        if limit <= 0:
            # zrevrange(offset, offset - 1) would otherwise select through the
            # end of the index (stop == -1) instead of nothing.
            return []
        if not self._use_fallback:
            try:
                redis = await self._get_redis()
                listed = await self._list_from_redis(redis, status, limit, offset)
                if listed is not None:
                    return listed
            except Exception as exc:
                logger.warning(
                    "Redis listing failed, using memory copy: %s", exc
                )
        return await self._fallback.list_executions(status, limit, offset)

    async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]:
        """Return step events from the memory store (not mirrored to Redis)."""
        return await self._fallback.get_step_history(execution_id)

    async def health_check(self) -> bool:
        """Return True when Redis answers PING and fallback mode is not active."""
        if self._use_fallback:
            return False
        try:
            redis = await self._get_redis()
            return bool(await redis.ping())
        except Exception:
            return False

    @property
    def using_fallback(self) -> bool:
        """Whether the instance has switched to memory-only mode."""
        return self._use_fallback
class PipelineStatePG:
    """PostgreSQL cold persistence for pipeline execution records.

    If session_factory is None, all methods are no-op. Database errors are
    logged (with traceback) and never propagated to the caller.
    """

    def __init__(self, session_factory: Any = None) -> None:
        # Callable returning an async context manager that yields a session.
        self._session_factory = session_factory

    @property
    def enabled(self) -> bool:
        """Whether a session factory was supplied."""
        return self._session_factory is not None

    @staticmethod
    def _parse_iso(value: str | None) -> datetime | None:
        """Convert an ISO-8601 string to a datetime; falsy values give None."""
        return datetime.fromisoformat(value) if value else None

    async def persist_execution(self, state: dict[str, Any]) -> None:
        """Write a completed/failed execution to PostgreSQL.

        ``state`` is the execution snapshot dict; ``id``, ``pipeline_name`` and
        ``status`` are required keys. The row is merged, so persisting the same
        execution id again updates the existing row.
        """
        if not self.enabled:
            return
        try:
            async with self._session_factory() as session:
                model = PipelineExecutionModel(
                    id=state["id"],
                    pipeline_name=state["pipeline_name"],
                    status=state["status"],
                    current_step=state.get("current_step"),
                    completed_steps=state.get("completed_steps", []),
                    step_results=state.get("step_results", {}),
                    input_data=state.get("input_data"),
                    final_output=state.get("final_output"),
                    error_message=state.get("error_message"),
                    tenant_id=state.get("tenant_id"),
                    created_at=self._parse_iso(state.get("created_at")),
                    updated_at=self._parse_iso(state.get("updated_at")),
                    completed_at=self._parse_iso(state.get("completed_at")),
                )
                await session.merge(model)
                await session.commit()
        except Exception as exc:
            logger.error("Failed to persist execution to PG: %s", exc, exc_info=True)

    async def persist_step_history(
        self, execution_id: str, steps: list[dict[str, Any]]
    ) -> None:
        """Write step history to PostgreSQL.

        Each event is merged by its ``id`` (a new UUID is generated when the
        event has none) and all events are committed together. ``step_index``
        is the event's position in ``steps``.
        """
        if not self.enabled:
            return
        try:
            async with self._session_factory() as session:
                for idx, step in enumerate(steps):
                    model = PipelineStepHistoryModel(
                        id=step.get("id") or str(uuid.uuid4()),
                        execution_id=execution_id,
                        step_name=step["step_name"],
                        step_index=idx,
                        status=step["status"],
                        input_data=step.get("input_data"),
                        output_data=step.get("output_data"),
                        error_message=step.get("error_message"),
                        duration_ms=step.get("duration_ms"),
                        retry_attempt=step.get("retry_attempt", 0),
                        started_at=self._parse_iso(step.get("started_at")),
                        completed_at=self._parse_iso(step.get("completed_at")),
                    )
                    await session.merge(model)
                await session.commit()
        except Exception as exc:
            logger.error("Failed to persist step history to PG: %s", exc, exc_info=True)

    async def query_executions(
        self,
        pipeline_name: str | None = None,
        status: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """Query historical executions from PostgreSQL, newest first.

        Returns an empty list when persistence is disabled or the query fails.
        """
        if not self.enabled:
            return []
        try:
            from sqlalchemy import select

            async with self._session_factory() as session:
                stmt = select(PipelineExecutionModel).order_by(
                    PipelineExecutionModel.created_at.desc()
                )
                if pipeline_name:
                    stmt = stmt.where(
                        PipelineExecutionModel.pipeline_name == pipeline_name
                    )
                if status:
                    stmt = stmt.where(PipelineExecutionModel.status == status)
                stmt = stmt.offset(offset).limit(limit)
                result = await session.execute(stmt)
                rows = result.scalars().all()
                return [self._model_to_dict(row) for row in rows]
        except Exception as exc:
            logger.error("Failed to query executions from PG: %s", exc, exc_info=True)
            return []

    async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
        """Get a single execution from PostgreSQL (for Redis miss fallback).

        Returns None when persistence is disabled, the row does not exist, or
        the query fails.
        """
        if not self.enabled:
            return None
        try:
            from sqlalchemy import select

            async with self._session_factory() as session:
                stmt = select(PipelineExecutionModel).where(
                    PipelineExecutionModel.id == execution_id
                )
                result = await session.execute(stmt)
                row = result.scalar_one_or_none()
                if row is None:
                    return None
                return self._model_to_dict(row)
        except Exception as exc:
            logger.error("Failed to get execution from PG: %s", exc, exc_info=True)
            return None

    @staticmethod
    def _model_to_dict(model: PipelineExecutionModel) -> dict[str, Any]:
        """Convert an ORM row into the dict shape used by the hot stores."""

        def _iso(value: datetime | None) -> str | None:
            return value.isoformat() if value else None

        return {
            "id": model.id,
            "pipeline_name": model.pipeline_name,
            "status": model.status,
            "current_step": model.current_step,
            "completed_steps": model.completed_steps or [],
            "step_results": model.step_results or {},
            "input_data": model.input_data,
            "final_output": model.final_output,
            "error_message": model.error_message,
            "tenant_id": model.tenant_id,
            "created_at": _iso(model.created_at),
            "updated_at": _iso(model.updated_at),
            "completed_at": _iso(model.completed_at),
        }
class PipelineStateManager:
    """Unified pipeline state manager — Redis hot + PG cold.

    - create / update  -> hot store only (Redis, or memory without a URL)
    - complete / fail  -> hot store, then an awaited write to PG
    - get              -> hot store first, PG on a miss
    - list             -> hot store; PG only when the hot store returns nothing
    """

    def __init__(
        self,
        redis_url: str | None = None,
        session_factory: Any = None,
    ) -> None:
        # Without a Redis URL the hot store is purely in-memory.
        self._hot = (
            PipelineStateRedis(redis_url=redis_url)
            if redis_url
            else PipelineStateMemory()
        )
        self._cold = PipelineStatePG(session_factory=session_factory)

    @property
    def hot_store(self) -> PipelineStateMemory | PipelineStateRedis:
        """The hot (low-latency) store in use."""
        return self._hot

    @property
    def cold_store(self) -> PipelineStatePG:
        """The PostgreSQL cold store."""
        return self._cold

    async def _flush_to_cold(self, execution_id: str) -> None:
        """Copy the execution snapshot and its step history to the cold store."""
        snapshot = await self._hot.get_execution(execution_id)
        if snapshot:
            await self._cold.persist_execution(snapshot)
        events = await self._hot.get_step_history(execution_id)
        if events:
            await self._cold.persist_step_history(execution_id, events)

    async def create_execution(
        self,
        pipeline_name: str,
        steps: list[str],
        input_data: dict[str, Any] | None = None,
        tenant_id: str | None = None,
    ) -> str:
        """Create an execution in the hot store and return its id."""
        execution_id = await self._hot.create_execution(
            pipeline_name, steps, input_data, tenant_id
        )
        return execution_id

    async def update_step(
        self,
        execution_id: str,
        step_name: str,
        status: str,
        output: dict[str, Any] | None = None,
        error: str | None = None,
        duration_ms: int | None = None,
    ) -> None:
        """Record a step transition in the hot store."""
        await self._hot.update_step(
            execution_id, step_name, status, output, error, duration_ms
        )

    async def complete_execution(
        self,
        execution_id: str,
        final_output: dict[str, Any] | None = None,
    ) -> None:
        """Mark the execution completed, then persist it to the cold store."""
        await self._hot.complete_execution(execution_id, final_output)
        await self._flush_to_cold(execution_id)

    async def fail_execution(
        self,
        execution_id: str,
        step_name: str,
        error: str,
    ) -> None:
        """Mark the execution failed, then persist it to the cold store."""
        await self._hot.fail_execution(execution_id, step_name, error)
        await self._flush_to_cold(execution_id)

    async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
        """Look up an execution in the hot store, then in the cold store."""
        hot_state = await self._hot.get_execution(execution_id)
        if hot_state is None:
            return await self._cold.get_execution(execution_id)
        return hot_state

    async def list_executions(
        self,
        status: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> list[dict[str, Any]]:
        """List executions from the hot store, or the cold store if it is empty."""
        recent = await self._hot.list_executions(status, limit, offset)
        if not recent:
            return await self._cold.query_executions(
                status=status, limit=limit, offset=offset
            )
        return recent

    async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]:
        """Return the step events held by the hot store."""
        return await self._hot.get_step_history(execution_id)

    async def health_check(self) -> dict[str, bool]:
        """Check health of both stores.

        ``hot`` is the Redis health check result, or True for the memory
        store; ``cold`` reports whether the PG store is enabled.
        """
        if isinstance(self._hot, PipelineStateRedis):
            hot_ok = await self._hot.health_check()
        else:
            hot_ok = True
        return {"hot": hot_ok, "cold": self._cold.enabled}

View File

@ -0,0 +1,661 @@
"""Unit tests for Pipeline execution state persistence."""
from __future__ import annotations
import asyncio
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.orchestrator.pipeline_engine import PipelineEngine
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage, StageStatus
from agentkit.orchestrator.pipeline_state import (
PipelineStateMemory,
PipelineStatePG,
PipelineStateRedis,
PipelineStateManager,
)
# ═══════════════════════════════════════════════════════════════
# PipelineStateMemory
# ═══════════════════════════════════════════════════════════════
class TestPipelineStateMemory:
    """Behavioural tests for the in-memory pipeline state store."""

    @pytest.fixture
    def store(self) -> PipelineStateMemory:
        """Provide a fresh, empty store for every test."""
        return PipelineStateMemory()

    @pytest.mark.asyncio
    async def test_create_execution(self, store: PipelineStateMemory):
        """A new execution is running and positioned on its first step."""
        execution_id = await store.create_execution(
            pipeline_name="test_pipeline",
            steps=["step_a", "step_b"],
            input_data={"key": "value"},
        )
        assert execution_id is not None

        snapshot = await store.get_execution(execution_id)
        assert snapshot is not None
        assert snapshot["pipeline_name"] == "test_pipeline"
        assert snapshot["status"] == "running"
        assert snapshot["current_step"] == "step_a"
        assert snapshot["completed_steps"] == []
        assert snapshot["input_data"] == {"key": "value"}

    @pytest.mark.asyncio
    async def test_update_step_completed(self, store: PipelineStateMemory):
        """Completing a step records it together with its output."""
        execution_id = await store.create_execution("p", ["s1", "s2"])
        await store.update_step(execution_id, "s1", "completed", output={"result": 42})

        snapshot = await store.get_execution(execution_id)
        assert "s1" in snapshot["completed_steps"]
        assert snapshot["step_results"]["s1"] == {"result": 42}

    @pytest.mark.asyncio
    async def test_update_step_failed(self, store: PipelineStateMemory):
        """A failed step stores its error on the execution."""
        execution_id = await store.create_execution("p", ["s1"])
        await store.update_step(execution_id, "s1", "failed", error="boom")

        snapshot = await store.get_execution(execution_id)
        assert snapshot["error_message"] == "boom"

    @pytest.mark.asyncio
    async def test_complete_execution(self, store: PipelineStateMemory):
        """Completion sets the status, final output and completion time."""
        execution_id = await store.create_execution("p", ["s1"])
        await store.complete_execution(execution_id, final_output={"done": True})

        snapshot = await store.get_execution(execution_id)
        assert snapshot["status"] == "completed"
        assert snapshot["final_output"] == {"done": True}
        assert snapshot["completed_at"] is not None

    @pytest.mark.asyncio
    async def test_fail_execution(self, store: PipelineStateMemory):
        """Failure sets the status and an error naming the step and cause."""
        execution_id = await store.create_execution("p", ["s1"])
        await store.fail_execution(execution_id, "s1", "timeout")

        snapshot = await store.get_execution(execution_id)
        assert snapshot["status"] == "failed"
        assert "s1" in snapshot["error_message"]
        assert "timeout" in snapshot["error_message"]
        assert snapshot["completed_at"] is not None

    @pytest.mark.asyncio
    async def test_get_execution_not_found(self, store: PipelineStateMemory):
        """Unknown execution ids yield None."""
        missing = await store.get_execution("nonexistent")
        assert missing is None

    @pytest.mark.asyncio
    async def test_list_executions(self, store: PipelineStateMemory):
        """Listing returns everything, or only the requested status."""
        first_id = await store.create_execution("p1", ["s1"])
        second_id = await store.create_execution("p2", ["s2"])
        await store.complete_execution(first_id)

        # No filter: both executions are listed.
        everything = await store.list_executions()
        assert len(everything) == 2

        # Status filter: only the completed one is listed.
        only_completed = await store.list_executions(status="completed")
        assert len(only_completed) == 1
        assert only_completed[0]["id"] == first_id

    @pytest.mark.asyncio
    async def test_list_executions_pagination(self, store: PipelineStateMemory):
        """limit/offset slice the listing into pages."""
        for index in range(5):
            await store.create_execution(f"p{index}", ["s1"])

        first_page = await store.list_executions(limit=2, offset=0)
        second_page = await store.list_executions(limit=2, offset=2)
        assert len(first_page) == 2
        assert len(second_page) == 2

    @pytest.mark.asyncio
    async def test_get_step_history(self, store: PipelineStateMemory):
        """Step events are returned in the order they were reported."""
        execution_id = await store.create_execution("p", ["s1", "s2"])
        await store.update_step(execution_id, "s1", "completed", output={"r": 1})
        await store.update_step(execution_id, "s2", "failed", error="err")

        events = await store.get_step_history(execution_id)
        assert len(events) == 2
        assert events[0]["step_name"] == "s1"
        assert events[0]["status"] == "completed"
        assert events[1]["step_name"] == "s2"
        assert events[1]["status"] == "failed"

    @pytest.mark.asyncio
    async def test_update_step_nonexistent_execution(self, store: PipelineStateMemory):
        """Updating an unknown execution must not raise (it only logs)."""
        await store.update_step("nonexistent", "s1", "completed")

    @pytest.mark.asyncio
    async def test_create_execution_with_tenant(self, store: PipelineStateMemory):
        """The tenant id given at creation is stored on the execution."""
        execution_id = await store.create_execution("p", ["s1"], tenant_id="tenant_123")

        snapshot = await store.get_execution(execution_id)
        assert snapshot["tenant_id"] == "tenant_123"
# ═══════════════════════════════════════════════════════════════
# PipelineStateRedis
# ═══════════════════════════════════════════════════════════════
class TestPipelineStateRedis:
    """Exercise PipelineStateRedis against a mocked Redis client."""

    @pytest.fixture
    def mock_redis(self):
        """Build the mocked Redis client shared by the tests in this class."""
        # Pipeline commands are plain (non-awaited) calls that return the
        # pipeline itself so they can be chained; only execute() is awaited.
        pipeline = MagicMock()
        pipeline.set = MagicMock(return_value=pipeline)
        pipeline.zadd = MagicMock(return_value=pipeline)
        pipeline.execute = AsyncMock(return_value=[True, 1])

        client = AsyncMock()
        client.get = AsyncMock(return_value=None)
        client.set = AsyncMock(return_value=True)
        client.zadd = AsyncMock(return_value=1)
        client.zrevrange = AsyncMock(return_value=[])
        client.mget = AsyncMock(return_value=[])
        client.pipeline = MagicMock(return_value=pipeline)
        return client

    @pytest.fixture
    def store(self, mock_redis) -> PipelineStateRedis:
        """Return a PipelineStateRedis with the mocked client pre-injected."""
        redis_store = PipelineStateRedis(redis_url="redis://localhost:6379/0")
        redis_store._redis = mock_redis
        return redis_store

    @pytest.mark.asyncio
    async def test_create_execution_writes_to_redis(self, store: PipelineStateRedis, mock_redis):
        """create_execution issues one pipelined SET with a JSON payload and a 7-day TTL."""
        seven_days_in_seconds = 7 * 24 * 3600

        execution_id = await store.create_execution("test_pipeline", ["s1", "s2"])
        assert execution_id is not None

        # The write goes through the Redis pipeline (set + zadd + execute).
        pipeline = mock_redis.pipeline.return_value
        pipeline.set.assert_called_once()
        set_call = pipeline.set.call_args

        # First positional argument is the key, second is the serialized state.
        assert set_call.args[0].startswith("agentkit:pipeline:exec:")
        payload = json.loads(set_call.args[1])
        assert payload["pipeline_name"] == "test_pipeline"
        assert payload["status"] == "running"
        assert set_call.kwargs.get("ex") == seven_days_in_seconds

    @pytest.mark.asyncio
    async def test_create_execution_adds_to_sorted_set(self, store: PipelineStateRedis, mock_redis):
        """create_execution issues exactly one pipelined ZADD."""
        await store.create_execution("p", ["s1"])

        pipeline = mock_redis.pipeline.return_value
        pipeline.zadd.assert_called_once()

    @pytest.mark.asyncio
    async def test_update_step_writes_to_redis(self, store: PipelineStateRedis, mock_redis):
        """update_step performs exactly one direct SET on the client."""
        execution_id = await store.create_execution("p", ["s1"])
        # Forget any SET recorded during creation so only the update is counted.
        mock_redis.set.reset_mock()

        await store.update_step(execution_id, "s1", "completed", output={"r": 1})

        mock_redis.set.assert_called_once()

    @pytest.mark.asyncio
    async def test_complete_execution_writes_to_redis(self, store: PipelineStateRedis, mock_redis):
        """complete_execution performs exactly one direct SET on the client."""
        execution_id = await store.create_execution("p", ["s1"])
        mock_redis.set.reset_mock()

        await store.complete_execution(execution_id, final_output={"done": True})

        mock_redis.set.assert_called_once()

    @pytest.mark.asyncio
    async def test_fail_execution_writes_to_redis(self, store: PipelineStateRedis, mock_redis):
        """fail_execution performs exactly one direct SET on the client."""
        execution_id = await store.create_execution("p", ["s1"])
        mock_redis.set.reset_mock()

        await store.fail_execution(execution_id, "s1", "error")

        mock_redis.set.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_execution_from_redis(self, store: PipelineStateRedis, mock_redis):
        """get_execution returns the state when the mocked GET yields its JSON."""
        execution_id = await store.create_execution("p", ["s1"])
        # Use the fallback store's copy of the state as the value Redis "returns".
        fallback_state = await store._fallback.get_execution(execution_id)
        mock_redis.get.return_value = json.dumps(fallback_state)

        fetched = await store.get_execution(execution_id)

        assert fetched is not None
        assert fetched["pipeline_name"] == "p"

    @pytest.mark.asyncio
    async def test_get_execution_redis_miss_falls_back_to_memory(self, store: PipelineStateRedis, mock_redis):
        """A Redis miss (GET returns None) still resolves the execution."""
        execution_id = await store.create_execution("p", ["s1"])
        mock_redis.get.return_value = None

        # The state should still be found via the in-memory fallback.
        fetched = await store.get_execution(execution_id)

        assert fetched is not None
        assert fetched["pipeline_name"] == "p"

    @pytest.mark.asyncio
    async def test_list_executions_from_sorted_set(self, store: PipelineStateRedis, mock_redis):
        """list_executions returns the states served by the mocked ZREVRANGE/MGET."""
        execution_id = await store.create_execution("p", ["s1"])
        fallback_state = await store._fallback.get_execution(execution_id)
        mock_redis.zrevrange.return_value = [execution_id]
        mock_redis.mget.return_value = [json.dumps(fallback_state)]

        listed = await store.list_executions()

        assert len(listed) == 1
        assert listed[0]["pipeline_name"] == "p"

    @pytest.mark.asyncio
    async def test_fallback_on_redis_failure(self, mock_redis):
        """When every Redis call raises, the store still creates the execution and reports fallback."""
        failing_store = PipelineStateRedis(redis_url="redis://localhost:6379/0")

        # Make ping, set and pipeline() all raise on the injected client.
        mock_redis.ping = AsyncMock(side_effect=Exception("connection refused"))
        failing_store._redis = mock_redis
        mock_redis.set = AsyncMock(side_effect=Exception("connection refused"))
        mock_redis.pipeline = MagicMock(side_effect=Exception("connection refused"))

        execution_id = await failing_store.create_execution("p", ["s1"])

        assert execution_id is not None
        assert failing_store.using_fallback is True

    @pytest.mark.asyncio
    async def test_health_check(self, store: PipelineStateRedis, mock_redis):
        """health_check is True when ping succeeds and False when ping raises."""
        mock_redis.ping = AsyncMock(return_value=True)
        healthy = await store.health_check()
        assert healthy is True

        mock_redis.ping = AsyncMock(side_effect=Exception("fail"))
        unhealthy = await store.health_check()
        assert unhealthy is False
# ═══════════════════════════════════════════════════════════════
# PipelineStatePG
# ═══════════════════════════════════════════════════════════════
class TestPipelineStatePG:
    """Tests for PostgreSQL cold persistence (using mocks)."""

    # ── helpers ──────────────────────────────────────────────────

    @staticmethod
    def _session_factory(session):
        """Return a mock factory whose call result is an async context manager yielding ``session``."""
        factory = MagicMock()
        context = factory.return_value
        context.__aenter__ = AsyncMock(return_value=session)
        # A falsy __aexit__ result means exceptions raised inside the
        # ``async with`` body are not suppressed by the mock.
        context.__aexit__ = AsyncMock(return_value=False)
        return factory

    @staticmethod
    def _write_session(merge_error=None):
        """Return a mock session with awaitable merge/commit; merge raises ``merge_error`` if given."""
        session = AsyncMock()
        session.merge = AsyncMock(side_effect=merge_error)
        session.commit = AsyncMock()
        return session

    @staticmethod
    def _read_session(result):
        """Return a mock session whose awaited execute() yields ``result``."""
        session = AsyncMock()
        session.execute = AsyncMock(return_value=result)
        return session

    @staticmethod
    def _completed_model():
        """Build a completed PipelineExecutionModel with id "test-id"."""
        # Imported locally so only the tests that build a model need this module.
        from agentkit.orchestrator.pipeline_models import PipelineExecutionModel

        return PipelineExecutionModel(
            id="test-id",
            pipeline_name="test_pipeline",
            status="completed",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )

    # ── tests ────────────────────────────────────────────────────

    @pytest.mark.asyncio
    async def test_no_op_when_session_factory_is_none(self):
        """Without a session factory the backend is disabled and every call is a no-op."""
        pg = PipelineStatePG(session_factory=None)
        assert pg.enabled is False

        # Writes must complete silently; reads must return empty results.
        await pg.persist_execution({"id": "1", "pipeline_name": "p", "status": "completed"})
        await pg.persist_step_history("1", [])
        assert await pg.query_executions() == []
        assert await pg.get_execution("1") is None

    @pytest.mark.asyncio
    async def test_persist_execution(self):
        """persist_execution merges and commits exactly once for a full state dict."""
        session = self._write_session()
        pg = PipelineStatePG(session_factory=self._session_factory(session))
        assert pg.enabled is True

        state = {
            "id": "test-id-123",
            "pipeline_name": "test_pipeline",
            "status": "completed",
            "current_step": None,
            "completed_steps": ["s1"],
            "step_results": {"s1": {"r": 1}},
            "input_data": {"key": "val"},
            "final_output": {"done": True},
            "error_message": None,
            "tenant_id": None,
            "created_at": datetime.now(timezone.utc).isoformat(),
            "updated_at": datetime.now(timezone.utc).isoformat(),
            "completed_at": datetime.now(timezone.utc).isoformat(),
        }
        await pg.persist_execution(state)

        session.merge.assert_called_once()
        session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_persist_step_history(self):
        """persist_step_history merges and commits exactly once for a single step record."""
        session = self._write_session()
        pg = PipelineStatePG(session_factory=self._session_factory(session))

        steps = [
            {
                "id": "step-id-1",
                "step_name": "s1",
                "status": "completed",
                "output_data": {"r": 1},
                "error_message": None,
                "duration_ms": 100,
                "started_at": datetime.now(timezone.utc).isoformat(),
                "completed_at": datetime.now(timezone.utc).isoformat(),
            }
        ]
        await pg.persist_step_history("exec-1", steps)

        session.merge.assert_called_once()
        session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_persist_execution_handles_error(self):
        """A failing merge must not propagate out of persist_execution."""
        session = self._write_session(merge_error=Exception("DB error"))
        pg = PipelineStatePG(session_factory=self._session_factory(session))

        # The test passes as long as this await does not raise.
        await pg.persist_execution({
            "id": "1",
            "pipeline_name": "p",
            "status": "completed",
            "created_at": datetime.now(timezone.utc).isoformat(),
            "updated_at": datetime.now(timezone.utc).isoformat(),
        })

    @pytest.mark.asyncio
    async def test_query_executions(self):
        """query_executions converts the rows from scalars().all() into dicts."""
        query_result = MagicMock()
        query_result.scalars.return_value.all.return_value = [self._completed_model()]
        session = self._read_session(query_result)
        pg = PipelineStatePG(session_factory=self._session_factory(session))

        results = await pg.query_executions(pipeline_name="test_pipeline")

        assert len(results) == 1
        assert results[0]["pipeline_name"] == "test_pipeline"

    @pytest.mark.asyncio
    async def test_get_execution_found(self):
        """get_execution returns a dict when scalar_one_or_none() yields a model."""
        query_result = MagicMock()
        query_result.scalar_one_or_none.return_value = self._completed_model()
        session = self._read_session(query_result)
        pg = PipelineStatePG(session_factory=self._session_factory(session))

        result = await pg.get_execution("test-id")

        assert result is not None
        assert result["id"] == "test-id"

    @pytest.mark.asyncio
    async def test_get_execution_not_found(self):
        """get_execution returns None when scalar_one_or_none() yields None."""
        query_result = MagicMock()
        query_result.scalar_one_or_none.return_value = None
        session = self._read_session(query_result)
        pg = PipelineStatePG(session_factory=self._session_factory(session))

        result = await pg.get_execution("nonexistent")

        assert result is None
# ═══════════════════════════════════════════════════════════════
# PipelineStateManager
# ═══════════════════════════════════════════════════════════════
class TestPipelineStateManager:
    """Tests for the unified state manager."""

    @staticmethod
    def _session_factory(session):
        """Return a mock factory whose call result is an async context manager yielding ``session``."""
        factory = MagicMock()
        context = factory.return_value
        context.__aenter__ = AsyncMock(return_value=session)
        # A falsy __aexit__ result means exceptions raised inside the
        # ``async with`` body are not suppressed by the mock.
        context.__aexit__ = AsyncMock(return_value=False)
        return factory

    @staticmethod
    def _write_session():
        """Return a mock session exposing awaitable merge() and commit()."""
        session = AsyncMock()
        session.merge = AsyncMock()
        session.commit = AsyncMock()
        return session

    @pytest.fixture
    def manager(self) -> PipelineStateManager:
        """Create a manager with memory-only backend."""
        return PipelineStateManager(redis_url=None, session_factory=None)

    @pytest.mark.asyncio
    async def test_create_and_get_execution(self, manager: PipelineStateManager):
        """A freshly created execution is retrievable and marked as running."""
        eid = await manager.create_execution("p", ["s1"], input_data={"k": "v"})

        state = await manager.get_execution(eid)

        assert state is not None
        assert state["pipeline_name"] == "p"
        assert state["status"] == "running"

    @pytest.mark.asyncio
    async def test_update_step(self, manager: PipelineStateManager):
        """A step updated to "completed" appears in completed_steps."""
        eid = await manager.create_execution("p", ["s1"])
        await manager.update_step(eid, "s1", "completed", output={"r": 1})

        state = await manager.get_execution(eid)

        assert "s1" in state["completed_steps"]

    @pytest.mark.asyncio
    async def test_complete_persists_to_cold(self):
        """Test that completing an execution triggers PG persist."""
        session = self._write_session()
        manager = PipelineStateManager(
            redis_url=None, session_factory=self._session_factory(session)
        )

        eid = await manager.create_execution("p", ["s1"])
        await manager.update_step(eid, "s1", "completed", output={"r": 1})
        await manager.complete_execution(eid, final_output={"done": True})

        # At least one merge on the mocked session shows the cold write happened.
        session.merge.assert_called()

    @pytest.mark.asyncio
    async def test_fail_persists_to_cold(self):
        """Test that failing an execution triggers PG persist."""
        session = self._write_session()
        manager = PipelineStateManager(
            redis_url=None, session_factory=self._session_factory(session)
        )

        eid = await manager.create_execution("p", ["s1"])
        await manager.fail_execution(eid, "s1", "error")

        session.merge.assert_called()

    @pytest.mark.asyncio
    async def test_get_execution_pg_fallback(self):
        """An id missing from the hot store is resolved through the PG session."""
        from agentkit.orchestrator.pipeline_models import PipelineExecutionModel

        model = PipelineExecutionModel(
            id="pg-exec-id",
            pipeline_name="pg_pipeline",
            status="completed",
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
        query_result = MagicMock()
        query_result.scalar_one_or_none.return_value = model
        session = AsyncMock()
        session.execute = AsyncMock(return_value=query_result)
        manager = PipelineStateManager(
            redis_url=None, session_factory=self._session_factory(session)
        )

        # "pg-exec-id" was never created through the manager, so it is absent
        # from the hot store and must come from the mocked PG query.
        result = await manager.get_execution("pg-exec-id")

        assert result is not None
        assert result["pipeline_name"] == "pg_pipeline"

    @pytest.mark.asyncio
    async def test_list_executions_hot_first(self, manager: PipelineStateManager):
        """list_executions returns the execution just created in the hot store."""
        eid = await manager.create_execution("p", ["s1"])

        results = await manager.list_executions()

        assert len(results) == 1
        assert results[0]["id"] == eid

    @pytest.mark.asyncio
    async def test_health_check_memory_only(self, manager: PipelineStateManager):
        """With no session factory, health reports hot=True and cold=False."""
        health = await manager.health_check()

        assert health["hot"] is True
        assert health["cold"] is False

    @pytest.mark.asyncio
    async def test_health_check_with_pg(self):
        """With a session factory supplied, health reports both hot and cold as True."""
        manager = PipelineStateManager(redis_url=None, session_factory=MagicMock())

        health = await manager.health_check()

        assert health["hot"] is True
        assert health["cold"] is True
# ═══════════════════════════════════════════════════════════════
# PipelineEngine with state persistence
# ═══════════════════════════════════════════════════════════════
class TestPipelineEngineWithState:
    """Tests for PipelineEngine integration with state persistence."""

    @pytest.fixture
    def pipeline(self) -> Pipeline:
        """Two-stage pipeline in which step_b depends on step_a."""
        return Pipeline(
            name="test_pipeline",
            version="1.0",
            description="Test pipeline",
            stages=[
                PipelineStage(name="step_a", agent="agent1", action="do_a"),
                PipelineStage(name="step_b", agent="agent2", action="do_b", depends_on=["step_a"]),
            ],
        )

    @pytest.mark.asyncio
    async def test_engine_without_state_backward_compatible(self, pipeline: Pipeline):
        """Engine without state_manager should work as before."""
        engine = PipelineEngine(dispatcher=None)

        result = await engine.execute(pipeline)

        assert result.status == StageStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_engine_with_state_creates_execution(self, pipeline: Pipeline):
        """Engine with state_manager should create execution state."""
        state_manager = PipelineStateManager(redis_url=None, session_factory=None)
        engine = PipelineEngine(dispatcher=None, state_manager=state_manager)

        result = await engine.execute(pipeline)
        assert result.status == StageStatus.COMPLETED

        # Exactly one execution record, marked completed, for this pipeline.
        executions = await state_manager.list_executions()
        assert len(executions) == 1
        assert executions[0]["status"] == "completed"
        assert executions[0]["pipeline_name"] == "test_pipeline"

    @pytest.mark.asyncio
    async def test_engine_with_state_updates_steps(self, pipeline: Pipeline):
        """Engine should update step state after each stage."""
        state_manager = PipelineStateManager(redis_url=None, session_factory=None)
        engine = PipelineEngine(dispatcher=None, state_manager=state_manager)

        await engine.execute(pipeline)

        executions = await state_manager.list_executions()
        exec_state = executions[0]
        # Both stages of the fixture pipeline must be recorded as completed.
        for stage_name in ("step_a", "step_b"):
            assert stage_name in exec_state["completed_steps"]

    @pytest.mark.asyncio
    async def test_engine_with_state_on_failure(self):
        """Engine should persist failure state when a stage fails."""
        pipeline = Pipeline(
            name="fail_pipeline",
            version="1.0",
            description="Pipeline that fails",
            stages=[
                PipelineStage(name="bad_step", agent="agent1", action="fail"),
            ],
        )
        # Dispatcher whose dispatch() always raises, forcing the stage to fail.
        mock_dispatcher = AsyncMock()
        mock_dispatcher.dispatch = AsyncMock(side_effect=Exception("boom"))
        state_manager = PipelineStateManager(redis_url=None, session_factory=None)
        engine = PipelineEngine(dispatcher=mock_dispatcher, state_manager=state_manager)

        result = await engine.execute(pipeline)
        assert result.status == StageStatus.FAILED

        # The failure must be visible in the state store as well.
        executions = await state_manager.list_executions()
        assert len(executions) == 1
        assert executions[0]["status"] == "failed"

    @pytest.mark.asyncio
    async def test_engine_state_survives_check(self, pipeline: Pipeline):
        """Verify state can be retrieved after execution."""
        state_manager = PipelineStateManager(redis_url=None, session_factory=None)
        engine = PipelineEngine(dispatcher=None, state_manager=state_manager)

        # Only the persisted state is inspected, so the engine's return value
        # is intentionally not bound to a name.
        await engine.execute(pipeline, context={"brand": "acme"})

        # Look the execution up by the id taken from the listing.
        executions = await state_manager.list_executions()
        eid = executions[0]["id"]
        state = await state_manager.get_execution(eid)

        assert state is not None
        assert state["pipeline_name"] == "test_pipeline"
        assert state["status"] == "completed"

    @pytest.mark.asyncio
    async def test_engine_with_circular_dependency(self):
        """Engine should handle circular dependency gracefully."""
        pipeline = Pipeline(
            name="circular",
            version="1.0",
            description="Circular pipeline",
            stages=[
                PipelineStage(name="a", agent="agent1", action="do", depends_on=["b"]),
                PipelineStage(name="b", agent="agent2", action="do", depends_on=["a"]),
            ],
        )
        state_manager = PipelineStateManager(redis_url=None, session_factory=None)
        engine = PipelineEngine(dispatcher=None, state_manager=state_manager)

        result = await engine.execute(pipeline)

        assert result.status == StageStatus.FAILED
        assert "Circular" in result.error_message
        # No execution state should be created (topological sort fails before creation).
        executions = await state_manager.list_executions()
        assert len(executions) == 0