diff --git a/src/agentkit/orchestrator/__init__.py b/src/agentkit/orchestrator/__init__.py index 0907993..3658902 100644 --- a/src/agentkit/orchestrator/__init__.py +++ b/src/agentkit/orchestrator/__init__.py @@ -5,6 +5,18 @@ from agentkit.orchestrator.pipeline_engine import PipelineEngine from agentkit.orchestrator.pipeline_loader import PipelineLoader from agentkit.orchestrator.handoff import HandoffManager from agentkit.orchestrator.dynamic_pipeline import DynamicPipeline +from agentkit.orchestrator.pipeline_state import ( + PipelineStateMemory, + PipelineStateRedis, + PipelineStatePG, + PipelineStateManager, +) +from agentkit.orchestrator.retry import StepRetryPolicy, execute_with_retry +from agentkit.orchestrator.compensation import ( + CompletedStep, + CompensationResult, + SagaOrchestrator, +) __all__ = [ "Pipeline", @@ -14,4 +26,13 @@ __all__ = [ "PipelineLoader", "HandoffManager", "DynamicPipeline", + "PipelineStateMemory", + "PipelineStateRedis", + "PipelineStatePG", + "PipelineStateManager", + "StepRetryPolicy", + "execute_with_retry", + "CompletedStep", + "CompensationResult", + "SagaOrchestrator", ] diff --git a/src/agentkit/orchestrator/pipeline_engine.py b/src/agentkit/orchestrator/pipeline_engine.py index 26bca97..3262fe9 100644 --- a/src/agentkit/orchestrator/pipeline_engine.py +++ b/src/agentkit/orchestrator/pipeline_engine.py @@ -1,4 +1,4 @@ -"""Pipeline Engine - DAG + 并行执行""" +"""Pipeline Engine - DAG + 并行执行 + 步骤重试 + Saga 补偿""" import asyncio import logging @@ -6,6 +6,7 @@ from collections import defaultdict from datetime import datetime, timezone from typing import Any +from agentkit.orchestrator.compensation import SagaOrchestrator from agentkit.orchestrator.pipeline_schema import ( Pipeline, PipelineResult, @@ -13,6 +14,7 @@ from agentkit.orchestrator.pipeline_schema import ( StageResult, StageStatus, ) +from agentkit.orchestrator.retry import StepRetryPolicy, execute_with_retry logger = logging.getLogger(__name__) @@ -25,11 +27,14 @@ class PipelineEngine: - 同层并行执行(asyncio.gather) - 变量解析 - 条件执行 - - 重试 + - 步骤级指数退避重试(StepRetryPolicy) + - Saga 补偿(LIFO 回滚已完成步骤) + - 状态持久化(可选) """ - def __init__(self, dispatcher: Any = None): + def __init__(self, dispatcher: Any = None, state_manager: Any = None): self._dispatcher = dispatcher + self._state_manager = state_manager async def execute( self, @@ -48,6 +53,22 @@ class PipelineEngine: result.error_message = str(e) return result + # Create execution state if state_manager is configured + execution_id: str | None = None + if self._state_manager is not None: + try: + step_names = [s.name for s in pipeline.stages] + execution_id = await self._state_manager.create_execution( + pipeline_name=pipeline.name, + steps=step_names, + input_data=context, + ) + except Exception as exc: + logger.warning(f"Failed to create execution state: {exc}") + + # Create Saga orchestrator for compensation tracking + saga = SagaOrchestrator() + # 逐层执行 for level, stages in enumerate(level_groups): logger.info(f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)") @@ -55,7 +76,7 @@ class PipelineEngine: # 并行执行同层 stages tasks = [] for stage in stages: - tasks.append(self._execute_stage(stage, result)) + tasks.append(self._execute_stage(stage, result, saga)) stage_results = await asyncio.gather(*tasks, return_exceptions=True) @@ -69,6 +90,22 @@ class PipelineEngine: ) result.stage_results[stage.name] = sr + # Update step state + if self._state_manager is not None and execution_id is not None: + try: + step_status = "completed" if sr.status == StageStatus.COMPLETED else sr.status.value + step_output = sr.output_data if hasattr(sr, 'output_data') else None + step_error = sr.error_message if hasattr(sr, 'error_message') else None + await self._state_manager.update_step( + execution_id=execution_id, + step_name=stage.name, + status=step_status, + output=step_output, + error=step_error, + ) + except Exception as exc: + logger.warning(f"Failed to update step state: {exc}") + # 收集输出变量 if sr.output_data and isinstance(sr, dict): pass @@ -80,17 +117,56 @@ class PipelineEngine: # 检查是否需要中止 if hasattr(sr, 'status') and sr.status == StageStatus.FAILED: if not stage.continue_on_failure: + # Execute Saga compensation for completed steps + compensation_results = await saga.compensate() + if compensation_results: + failed_compensations = [ + cr for cr in compensation_results if not cr.success and cr.error != "no_compensation_needed" + ] + if failed_compensations: + logger.warning( + f"Compensation had {len(failed_compensations)} failures: " + f"{[c.step_name for c in failed_compensations]}" + ) + result.status = StageStatus.FAILED result.error_message = f"Stage '{stage.name}' failed" + # Fail execution state + if self._state_manager is not None and execution_id is not None: + try: + await self._state_manager.fail_execution( + execution_id=execution_id, + step_name=stage.name, + error=result.error_message, + ) + except Exception as exc: + logger.warning(f"Failed to persist failure state: {exc}") return result result.status = StageStatus.COMPLETED + + # Complete execution state + if self._state_manager is not None and execution_id is not None: + try: + final_output = { + name: sr.output_data + for name, sr in result.stage_results.items() + if sr.output_data is not None + } + await self._state_manager.complete_execution( + execution_id=execution_id, + final_output=final_output, + ) + except Exception as exc: + logger.warning(f"Failed to persist completion state: {exc}") + return result async def _execute_stage( self, stage: PipelineStage, pipeline_result: PipelineResult, + saga: SagaOrchestrator, ) -> StageResult: """执行单个 stage""" started_at = datetime.now(timezone.utc).isoformat() @@ -110,13 +186,20 @@ class PipelineEngine: # 执行 if self._dispatcher is None: # Dry-run 模式 - return StageResult( + result = StageResult( stage_name=stage.name, status=StageStatus.COMPLETED, output_data={"dry_run": True, "inputs": resolved_inputs}, started_at=started_at, completed_at=datetime.now(timezone.utc).isoformat(), ) + # Record completed step for Saga compensation + saga.record_completed( + step_name=stage.name, + result=result.output_data, + compensate_action=stage.compensate, + ) + return result # 通过 Dispatcher 分发任务 from agentkit.core.protocol import TaskMessage @@ -133,7 +216,8 @@ class PipelineEngine: timeout_seconds=stage.timeout_seconds, ) - try: + async def _dispatch_and_wait() -> StageResult: + """Dispatch task and wait for result""" await self._dispatcher.dispatch(task) # 等待结果 @@ -158,6 +242,24 @@ class PipelineEngine: completed_at=datetime.now(timezone.utc).isoformat(), ) + try: + # Execute with retry if retry_policy is configured + sr = await execute_with_retry( + func=_dispatch_and_wait, + retry_policy=stage.retry_policy, + step_name=stage.name, + ) + + # Record completed step for Saga compensation on success + if sr.status == StageStatus.COMPLETED: + saga.record_completed( + step_name=stage.name, + result=sr.output_data, + compensate_action=stage.compensate, + ) + + return sr + except Exception as e: return StageResult( stage_name=stage.name, diff --git a/src/agentkit/orchestrator/pipeline_models.py b/src/agentkit/orchestrator/pipeline_models.py new file mode 100644 index 0000000..3fa1208 --- /dev/null +++ b/src/agentkit/orchestrator/pipeline_models.py @@ -0,0 +1,59 @@ +"""Pipeline execution ORM models for PostgreSQL persistence.""" + +import uuid +from datetime import datetime, timezone + +from sqlalchemy import Column, DateTime, Index, Integer, String, Text +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import DeclarativeBase + + +class Base(DeclarativeBase): + pass + + +class PipelineExecutionModel(Base): + """Pipeline execution record — persisted final state.""" + + __tablename__ = "pipeline_executions" + + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + pipeline_name = Column(String(128), nullable=False, index=True) + status = Column(String(32), nullable=False, index=True) + current_step = Column(String(128)) + completed_steps = Column(JSONB, default=list) + step_results = Column(JSONB, default=dict) + input_data = Column(JSONB) + final_output = Column(JSONB) + error_message = Column(Text) + tenant_id = Column(String(64), index=True) + created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) + updated_at = Column( + DateTime, + default=lambda: datetime.now(timezone.utc), + onupdate=lambda: datetime.now(timezone.utc), + ) + completed_at = Column(DateTime) + + __table_args__ = ( + Index("ix_pipeline_status_created", "status", "created_at"), + ) + + +class PipelineStepHistoryModel(Base): + """Step execution history — audit trail.""" + + __tablename__ = "pipeline_step_history" + + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + execution_id = Column(String(36), nullable=False, index=True) + step_name = Column(String(128), nullable=False) + step_index = Column(Integer, nullable=False) + status = Column(String(32), nullable=False) + input_data = Column(JSONB) + output_data = Column(JSONB) + error_message = Column(Text) + duration_ms = Column(Integer) + retry_attempt = Column(Integer, default=0) + started_at = Column(DateTime) + completed_at = Column(DateTime) diff --git a/src/agentkit/orchestrator/pipeline_state.py b/src/agentkit/orchestrator/pipeline_state.py new file mode 100644 index 0000000..a266803 --- /dev/null +++ b/src/agentkit/orchestrator/pipeline_state.py @@ -0,0 +1,572 @@ +"""Pipeline execution state persistence — Redis hot state + PostgreSQL cold storage. + +Architecture: + PipelineStateMemory — In-memory fallback (always available, for testing) + PipelineStateRedis — Redis hot state (low-latency reads/writes) + PipelineStatePG — PostgreSQL cold persistence (durable audit trail) + PipelineStateManager — Unified manager (Redis + PG dual write, fallback chain) +""" + +from __future__ import annotations + +import json +import logging +import uuid +from datetime import datetime, timezone +from typing import Any, Callable, Coroutine + +from agentkit.orchestrator.pipeline_models import ( + PipelineExecutionModel, + PipelineStepHistoryModel, +) + +logger = logging.getLogger(__name__) + +# Redis key patterns +_EXEC_KEY_PREFIX = "agentkit:pipeline:exec:" +_INDEX_KEY = "agentkit:pipeline:index" +_TTL_SECONDS = 7 * 24 * 3600 # 7 days + + +class PipelineStateMemory: + """In-memory pipeline state storage (testing / fallback).""" + + def __init__(self) -> None: + self._executions: dict[str, dict[str, Any]] = {} + self._step_history: dict[str, list[dict[str, Any]]] = {} + + async def create_execution( + self, + pipeline_name: str, + steps: list[str], + input_data: dict[str, Any] | None = None, + tenant_id: str | None = None, + ) -> str: + execution_id = str(uuid.uuid4()) + now = datetime.now(timezone.utc).isoformat() + self._executions[execution_id] = { + "id": execution_id, + "pipeline_name": pipeline_name, + "status": "running", + "current_step": steps[0] if steps else None, + "completed_steps": [], + "step_results": {}, + "input_data": input_data, + "final_output": None, + "error_message": None, + "tenant_id": tenant_id, + "created_at": now, + "updated_at": now, + "completed_at": None, + } + self._step_history[execution_id] = [] + return execution_id + + async def update_step( + self, + execution_id: str, + step_name: str, + status: str, + output: dict[str, Any] | None = None, + error: str | None = None, + duration_ms: int | None = None, + ) -> None: + exec_state = self._executions.get(execution_id) + if exec_state is None: + logger.warning(f"Execution '{execution_id}' not found for step update") + return + + exec_state["current_step"] = step_name + exec_state["updated_at"] = datetime.now(timezone.utc).isoformat() + + if status == "completed": + if step_name not in exec_state["completed_steps"]: + exec_state["completed_steps"].append(step_name) + if output is not None: + exec_state["step_results"][step_name] = output + elif status == "failed": + exec_state["error_message"] = error + + # Record step history event + step_event: dict[str, Any] = { + "id": str(uuid.uuid4()), + "execution_id": execution_id, + "step_name": step_name, + "status": status, + "output_data": output, + "error_message": error, + "duration_ms": duration_ms, + "started_at": datetime.now(timezone.utc).isoformat(), + "completed_at": datetime.now(timezone.utc).isoformat() if status in ("completed", "failed") else None, + } + self._step_history[execution_id].append(step_event) + + async def complete_execution( + self, + execution_id: str, + final_output: dict[str, Any] | None = None, + ) -> None: + exec_state = self._executions.get(execution_id) + if exec_state is None: + return + now = datetime.now(timezone.utc).isoformat() + exec_state["status"] = "completed" + exec_state["final_output"] = final_output + exec_state["updated_at"] = now + exec_state["completed_at"] = now + + async def fail_execution( + self, + execution_id: str, + step_name: str, + error: str, + ) -> None: + exec_state = self._executions.get(execution_id) + if exec_state is None: + return + now = datetime.now(timezone.utc).isoformat() + exec_state["status"] = "failed" + exec_state["error_message"] = f"Step '{step_name}' failed: {error}" + exec_state["updated_at"] = now + exec_state["completed_at"] = now + + async def get_execution(self, execution_id: str) -> dict[str, Any] | None: + return self._executions.get(execution_id) + + async def list_executions( + self, + status: str | None = None, + limit: int = 50, + offset: int = 0, + ) -> list[dict[str, Any]]: + results = list(self._executions.values()) + if status: + results = [e for e in results if e.get("status") == status] + results.sort(key=lambda e: e.get("created_at", ""), reverse=True) + return results[offset : offset + limit] + + async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]: + return self._step_history.get(execution_id, []) + + +class PipelineStateRedis: + """Redis-backed pipeline state storage (hot state). + + Uses Redis Hash for execution state and Sorted Set for indexing. + Falls back to PipelineStateMemory if Redis is unavailable. + """ + + def __init__(self, redis_url: str = "redis://localhost:6379/0") -> None: + self._redis_url = redis_url + self._redis: Any = None + self._fallback = PipelineStateMemory() + self._use_fallback = False + + async def _get_redis(self): + if self._redis is None: + import redis.asyncio as aioredis + + self._redis = aioredis.from_url( + self._redis_url, + decode_responses=True, + ) + return self._redis + + async def _safe_redis_call( + self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any + ) -> Any: + """Execute a Redis call, falling back to memory on failure.""" + if self._use_fallback: + return None + try: + redis = await self._get_redis() + return await fn(redis, *args, **kwargs) + except Exception as exc: + logger.warning(f"Redis operation failed, switching to memory fallback: {exc}") + self._use_fallback = True + self._redis = None + return None + + def _key(self, execution_id: str) -> str: + return f"{_EXEC_KEY_PREFIX}{execution_id}" + + async def create_execution( + self, + pipeline_name: str, + steps: list[str], + input_data: dict[str, Any] | None = None, + tenant_id: str | None = None, + ) -> str: + # Always write to fallback first for consistency + execution_id = await self._fallback.create_execution( + pipeline_name, steps, input_data, tenant_id + ) + + # Try Redis + async def _redis_create(redis: Any) -> None: + state = self._fallback._executions[execution_id] + score = datetime.now(timezone.utc).timestamp() + pipe = redis.pipeline() + pipe.set(self._key(execution_id), json.dumps(state), ex=_TTL_SECONDS) + pipe.zadd(_INDEX_KEY, {execution_id: score}) + await pipe.execute() + + await self._safe_redis_call(_redis_create) + return execution_id + + async def update_step( + self, + execution_id: str, + step_name: str, + status: str, + output: dict[str, Any] | None = None, + error: str | None = None, + duration_ms: int | None = None, + ) -> None: + await self._fallback.update_step(execution_id, step_name, status, output, error, duration_ms) + + async def _redis_update(redis: Any) -> None: + state = self._fallback._executions.get(execution_id) + if state is None: + return + await redis.set(self._key(execution_id), json.dumps(state), ex=_TTL_SECONDS) + + await self._safe_redis_call(_redis_update) + + async def complete_execution( + self, + execution_id: str, + final_output: dict[str, Any] | None = None, + ) -> None: + await self._fallback.complete_execution(execution_id, final_output) + + async def _redis_complete(redis: Any) -> None: + state = self._fallback._executions.get(execution_id) + if state is None: + return + await redis.set(self._key(execution_id), json.dumps(state), ex=_TTL_SECONDS) + + await self._safe_redis_call(_redis_complete) + + async def fail_execution( + self, + execution_id: str, + step_name: str, + error: str, + ) -> None: + await self._fallback.fail_execution(execution_id, step_name, error) + + async def _redis_fail(redis: Any) -> None: + state = self._fallback._executions.get(execution_id) + if state is None: + return + await redis.set(self._key(execution_id), json.dumps(state), ex=_TTL_SECONDS) + + await self._safe_redis_call(_redis_fail) + + async def get_execution(self, execution_id: str) -> dict[str, Any] | None: + # Try Redis first + if not self._use_fallback: + try: + redis = await self._get_redis() + raw = await redis.get(self._key(execution_id)) + if raw is not None: + return json.loads(raw) + except Exception: + pass + + # Fallback to memory + return await self._fallback.get_execution(execution_id) + + async def list_executions( + self, + status: str | None = None, + limit: int = 50, + offset: int = 0, + ) -> list[dict[str, Any]]: + # Try Redis sorted set for efficient listing + if not self._use_fallback: + try: + redis = await self._get_redis() + # Get recent execution IDs from sorted set (newest first) + ids = await redis.zrevrange(_INDEX_KEY, offset, offset + limit - 1) + if ids: + keys = [self._key(eid) for eid in ids] + values = await redis.mget(keys) + results = [] + for raw in values: + if raw is None: + continue + state = json.loads(raw) + if status is None or state.get("status") == status: + results.append(state) + return results + except Exception: + pass + + return await self._fallback.list_executions(status, limit, offset) + + async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]: + return await self._fallback.get_step_history(execution_id) + + async def health_check(self) -> bool: + if self._use_fallback: + return False + try: + redis = await self._get_redis() + return await redis.ping() + except Exception: + return False + + @property + def using_fallback(self) -> bool: + return self._use_fallback + + +class PipelineStatePG: + """PostgreSQL cold persistence for pipeline execution records. + + If session_factory is None, all methods are no-op. + """ + + def __init__(self, session_factory: Any = None) -> None: + self._session_factory = session_factory + + @property + def enabled(self) -> bool: + return self._session_factory is not None + + async def persist_execution(self, state: dict[str, Any]) -> None: + """Write a completed/failed execution to PostgreSQL.""" + if not self.enabled: + return + try: + from sqlalchemy.ext.asyncio import AsyncSession + + async with self._session_factory() as session: + model = PipelineExecutionModel( + id=state["id"], + pipeline_name=state["pipeline_name"], + status=state["status"], + current_step=state.get("current_step"), + completed_steps=state.get("completed_steps", []), + step_results=state.get("step_results", {}), + input_data=state.get("input_data"), + final_output=state.get("final_output"), + error_message=state.get("error_message"), + tenant_id=state.get("tenant_id"), + created_at=datetime.fromisoformat(state["created_at"]) if state.get("created_at") else None, + updated_at=datetime.fromisoformat(state["updated_at"]) if state.get("updated_at") else None, + completed_at=datetime.fromisoformat(state["completed_at"]) if state.get("completed_at") else None, + ) + await session.merge(model) + await session.commit() + except Exception as exc: + logger.error(f"Failed to persist execution to PG: {exc}") + + async def persist_step_history( + self, execution_id: str, steps: list[dict[str, Any]] + ) -> None: + """Write step history to PostgreSQL.""" + if not self.enabled: + return + try: + async with self._session_factory() as session: + for idx, step in enumerate(steps): + model = PipelineStepHistoryModel( + id=step.get("id", str(uuid.uuid4())), + execution_id=execution_id, + step_name=step["step_name"], + step_index=idx, + status=step["status"], + input_data=step.get("input_data"), + output_data=step.get("output_data"), + error_message=step.get("error_message"), + duration_ms=step.get("duration_ms"), + retry_attempt=step.get("retry_attempt", 0), + started_at=datetime.fromisoformat(step["started_at"]) if step.get("started_at") else None, + completed_at=datetime.fromisoformat(step["completed_at"]) if step.get("completed_at") else None, + ) + await session.merge(model) + await session.commit() + except Exception as exc: + logger.error(f"Failed to persist step history to PG: {exc}") + + async def query_executions( + self, + pipeline_name: str | None = None, + status: str | None = None, + limit: int = 50, + offset: int = 0, + ) -> list[dict[str, Any]]: + """Query historical executions from PostgreSQL.""" + if not self.enabled: + return [] + try: + from sqlalchemy import select + + async with self._session_factory() as session: + stmt = select(PipelineExecutionModel).order_by( + PipelineExecutionModel.created_at.desc() + ) + if pipeline_name: + stmt = stmt.where( + PipelineExecutionModel.pipeline_name == pipeline_name + ) + if status: + stmt = stmt.where(PipelineExecutionModel.status == status) + stmt = stmt.offset(offset).limit(limit) + result = await session.execute(stmt) + rows = result.scalars().all() + return [self._model_to_dict(row) for row in rows] + except Exception as exc: + logger.error(f"Failed to query executions from PG: {exc}") + return [] + + async def get_execution(self, execution_id: str) -> dict[str, Any] | None: + """Get a single execution from PostgreSQL (for Redis miss fallback).""" + if not self.enabled: + return None + try: + from sqlalchemy import select + + async with self._session_factory() as session: + stmt = select(PipelineExecutionModel).where( + PipelineExecutionModel.id == execution_id + ) + result = await session.execute(stmt) + row = result.scalar_one_or_none() + if row is None: + return None + return self._model_to_dict(row) + except Exception as exc: + logger.error(f"Failed to get execution from PG: {exc}") + return None + + @staticmethod + def _model_to_dict(model: PipelineExecutionModel) -> dict[str, Any]: + return { + "id": model.id, + "pipeline_name": model.pipeline_name, + "status": model.status, + "current_step": model.current_step, + "completed_steps": model.completed_steps or [], + "step_results": model.step_results or {}, + "input_data": model.input_data, + "final_output": model.final_output, + "error_message": model.error_message, + "tenant_id": model.tenant_id, + "created_at": model.created_at.isoformat() if model.created_at else None, + "updated_at": model.updated_at.isoformat() if model.updated_at else None, + "completed_at": model.completed_at.isoformat() if model.completed_at else None, + } + + +class PipelineStateManager: + """Unified pipeline state manager — Redis hot + PG cold. + + - create / update → Redis (with in-memory fallback) + - complete / fail → Redis + async persist to PG + - get → Redis first, PG fallback + - list → Redis for recent, PG for historical + """ + + def __init__( + self, + redis_url: str | None = None, + session_factory: Any = None, + ) -> None: + if redis_url: + self._hot = PipelineStateRedis(redis_url=redis_url) + else: + self._hot = PipelineStateMemory() + self._cold = PipelineStatePG(session_factory=session_factory) + + @property + def hot_store(self) -> PipelineStateMemory | PipelineStateRedis: + return self._hot + + @property + def cold_store(self) -> PipelineStatePG: + return self._cold + + async def create_execution( + self, + pipeline_name: str, + steps: list[str], + input_data: dict[str, Any] | None = None, + tenant_id: str | None = None, + ) -> str: + return await self._hot.create_execution(pipeline_name, steps, input_data, tenant_id) + + async def update_step( + self, + execution_id: str, + step_name: str, + status: str, + output: dict[str, Any] | None = None, + error: str | None = None, + duration_ms: int | None = None, + ) -> None: + await self._hot.update_step(execution_id, step_name, status, output, error, duration_ms) + + async def complete_execution( + self, + execution_id: str, + final_output: dict[str, Any] | None = None, + ) -> None: + await self._hot.complete_execution(execution_id, final_output) + # Persist to PG + state = await self._hot.get_execution(execution_id) + if state: + await self._cold.persist_execution(state) + step_history = await self._hot.get_step_history(execution_id) + if step_history: + await self._cold.persist_step_history(execution_id, step_history) + + async def fail_execution( + self, + execution_id: str, + step_name: str, + error: str, + ) -> None: + await self._hot.fail_execution(execution_id, step_name, error) + # Persist to PG + state = await self._hot.get_execution(execution_id) + if state: + await self._cold.persist_execution(state) + step_history = await self._hot.get_step_history(execution_id) + if step_history: + await self._cold.persist_step_history(execution_id, step_history) + + async def get_execution(self, execution_id: str) -> dict[str, Any] | None: + # Redis / memory first + state = await self._hot.get_execution(execution_id) + if state is not None: + return state + # PG fallback + return await self._cold.get_execution(execution_id) + + async def list_executions( + self, + status: str | None = None, + limit: int = 50, + offset: int = 0, + ) -> list[dict[str, Any]]: + # Hot store for recent executions + results = await self._hot.list_executions(status, limit, offset) + if results: + return results + # Cold store for historical queries + return await self._cold.query_executions(status=status, limit=limit, offset=offset) + + async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]: + return await self._hot.get_step_history(execution_id) + + async def health_check(self) -> dict[str, bool]: + """Check health of both stores.""" + hot_ok = True + if isinstance(self._hot, PipelineStateRedis): + hot_ok = await self._hot.health_check() + cold_ok = self._cold.enabled + return {"hot": hot_ok, "cold": cold_ok} diff --git a/tests/unit/test_pipeline_state.py b/tests/unit/test_pipeline_state.py new file mode 100644 index 0000000..55ad5d8 --- /dev/null +++ b/tests/unit/test_pipeline_state.py @@ -0,0 +1,661 @@ +"""Unit tests for Pipeline execution state persistence.""" + +from __future__ import annotations + +import asyncio +import json +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentkit.orchestrator.pipeline_engine import PipelineEngine +from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage, StageStatus +from agentkit.orchestrator.pipeline_state import ( + PipelineStateMemory, + PipelineStatePG, + PipelineStateRedis, + PipelineStateManager, +) + + +# ═══════════════════════════════════════════════════════════════ +# PipelineStateMemory +# ═══════════════════════════════════════════════════════════════ + + +class TestPipelineStateMemory: + """Tests for in-memory pipeline state storage.""" + + @pytest.fixture + def store(self) -> PipelineStateMemory: + return PipelineStateMemory() + + @pytest.mark.asyncio + async def test_create_execution(self, store: PipelineStateMemory): + eid = await store.create_execution( + pipeline_name="test_pipeline", + steps=["step_a", "step_b"], + input_data={"key": "value"}, + ) + assert eid is not None + state = await store.get_execution(eid) + assert state is not None + assert state["pipeline_name"] == "test_pipeline" + assert state["status"] == "running" + assert state["current_step"] == "step_a" + assert state["completed_steps"] == [] + assert state["input_data"] == {"key": "value"} + + @pytest.mark.asyncio + async def test_update_step_completed(self, store: PipelineStateMemory): + eid = await store.create_execution("p", ["s1", "s2"]) + await store.update_step(eid, "s1", "completed", output={"result": 42}) + state = await store.get_execution(eid) + assert "s1" in state["completed_steps"] + assert state["step_results"]["s1"] == {"result": 42} + + @pytest.mark.asyncio + async def test_update_step_failed(self, store: PipelineStateMemory): + eid = await store.create_execution("p", ["s1"]) + await store.update_step(eid, "s1", "failed", error="boom") + state = await store.get_execution(eid) + assert state["error_message"] == "boom" + + @pytest.mark.asyncio + async def test_complete_execution(self, store: PipelineStateMemory): + eid = await store.create_execution("p", ["s1"]) + await store.complete_execution(eid, final_output={"done": True}) + state = await store.get_execution(eid) + assert state["status"] == "completed" + assert state["final_output"] == {"done": True} + assert state["completed_at"] is not None + + @pytest.mark.asyncio + async def test_fail_execution(self, store: PipelineStateMemory): + eid = await store.create_execution("p", ["s1"]) + await store.fail_execution(eid, "s1", "timeout") + state = await store.get_execution(eid) + assert state["status"] == "failed" + assert "s1" in state["error_message"] + assert "timeout" in state["error_message"] + assert state["completed_at"] is not None + + @pytest.mark.asyncio + async def test_get_execution_not_found(self, store: PipelineStateMemory): + result = await store.get_execution("nonexistent") + assert result is None + + @pytest.mark.asyncio + async def test_list_executions(self, store: PipelineStateMemory): + eid1 = await store.create_execution("p1", ["s1"]) + eid2 = await store.create_execution("p2", ["s2"]) + await store.complete_execution(eid1) + # List all + all_execs = await store.list_executions() + assert len(all_execs) == 2 + # Filter by status + completed = await store.list_executions(status="completed") + assert len(completed) == 1 + assert completed[0]["id"] == eid1 + + @pytest.mark.asyncio + async def test_list_executions_pagination(self, store: PipelineStateMemory): + for i in range(5): + await store.create_execution(f"p{i}", ["s1"]) + page1 = await store.list_executions(limit=2, offset=0) + page2 = await store.list_executions(limit=2, offset=2) + assert len(page1) == 2 + assert len(page2) == 2 + + @pytest.mark.asyncio + async def test_get_step_history(self, store: PipelineStateMemory): + eid = await store.create_execution("p", ["s1", "s2"]) + await store.update_step(eid, "s1", "completed", output={"r": 1}) + await store.update_step(eid, "s2", "failed", error="err") + history = await store.get_step_history(eid) + assert len(history) == 2 + assert history[0]["step_name"] == "s1" + assert history[0]["status"] == "completed" + assert history[1]["step_name"] == "s2" + assert history[1]["status"] == "failed" + + @pytest.mark.asyncio + async def test_update_step_nonexistent_execution(self, store: PipelineStateMemory): + # Should not raise, just log warning + await store.update_step("nonexistent", "s1", "completed") + + @pytest.mark.asyncio + async def test_create_execution_with_tenant(self, store: PipelineStateMemory): + eid = await store.create_execution("p", ["s1"], tenant_id="tenant_123") + state = await store.get_execution(eid) + assert state["tenant_id"] == "tenant_123" + + +# ═══════════════════════════════════════════════════════════════ +# PipelineStateRedis +# ═══════════════════════════════════════════════════════════════ + + +class TestPipelineStateRedis: + """Tests for Redis-backed pipeline state storage (using mocks).""" + + @pytest.fixture + def mock_redis(self): + """Create a mock Redis client.""" + redis = AsyncMock() + redis.get = AsyncMock(return_value=None) + redis.set = AsyncMock(return_value=True) + redis.zadd = AsyncMock(return_value=1) + redis.zrevrange = AsyncMock(return_value=[]) + redis.mget = AsyncMock(return_value=[]) + # Redis pipeline: set/zadd are synchronous (return self for chaining), execute is async + pipe = MagicMock() + pipe.set = MagicMock(return_value=pipe) + pipe.zadd = MagicMock(return_value=pipe) + pipe.execute = AsyncMock(return_value=[True, 1]) + redis.pipeline = MagicMock(return_value=pipe) + return redis + + @pytest.fixture + def store(self, mock_redis) -> PipelineStateRedis: + """Create a PipelineStateRedis with mocked Redis.""" + store = PipelineStateRedis(redis_url="redis://localhost:6379/0") + # Pre-inject the mock Redis client + store._redis = mock_redis + return store + + @pytest.mark.asyncio + async def test_create_execution_writes_to_redis(self, store: PipelineStateRedis, mock_redis): + eid = await store.create_execution("test_pipeline", ["s1", "s2"]) + assert eid is not None + # Redis pipeline should have been used (pipe.set + pipe.zadd + pipe.execute) + pipe = mock_redis.pipeline.return_value + pipe.set.assert_called_once() + call_args = pipe.set.call_args + assert call_args[0][0].startswith("agentkit:pipeline:exec:") + # Verify the stored data is valid JSON + stored_data = json.loads(call_args[0][1]) + assert stored_data["pipeline_name"] == "test_pipeline" + assert stored_data["status"] == "running" + # Verify TTL was set (7 days) + assert call_args[1].get("ex") == 7 * 24 * 3600 + + @pytest.mark.asyncio + async def test_create_execution_adds_to_sorted_set(self, store: PipelineStateRedis, mock_redis): + await store.create_execution("p", ["s1"]) + # ZADD should have been called via pipeline + pipe = mock_redis.pipeline.return_value + pipe.zadd.assert_called_once() + + @pytest.mark.asyncio + async def test_update_step_writes_to_redis(self, store: PipelineStateRedis, mock_redis): + eid = await store.create_execution("p", ["s1"]) + mock_redis.set.reset_mock() + await store.update_step(eid, "s1", "completed", output={"r": 1}) + mock_redis.set.assert_called_once() + + @pytest.mark.asyncio + async def test_complete_execution_writes_to_redis(self, store: PipelineStateRedis, mock_redis): + eid = await store.create_execution("p", ["s1"]) + mock_redis.set.reset_mock() + await store.complete_execution(eid, final_output={"done": True}) + mock_redis.set.assert_called_once() + + @pytest.mark.asyncio + async def test_fail_execution_writes_to_redis(self, store: PipelineStateRedis, mock_redis): + eid = await store.create_execution("p", ["s1"]) + mock_redis.set.reset_mock() + await store.fail_execution(eid, "s1", "error") + mock_redis.set.assert_called_once() + + @pytest.mark.asyncio + async def test_get_execution_from_redis(self, store: PipelineStateRedis, mock_redis): + eid = await store.create_execution("p", ["s1"]) + # Simulate Redis returning data + state = await store._fallback.get_execution(eid) + mock_redis.get.return_value = json.dumps(state) + result = await store.get_execution(eid) + assert result is not None + assert result["pipeline_name"] == "p" + + @pytest.mark.asyncio + async def test_get_execution_redis_miss_falls_back_to_memory(self, store: PipelineStateRedis, mock_redis): + eid = await store.create_execution("p", ["s1"]) + # Redis returns None (miss) + mock_redis.get.return_value = None + # Should still find it in memory fallback + result = await store.get_execution(eid) + assert result is not None + assert result["pipeline_name"] == "p" + + @pytest.mark.asyncio + async def test_list_executions_from_sorted_set(self, store: PipelineStateRedis, mock_redis): + eid = await store.create_execution("p", ["s1"]) + state = await store._fallback.get_execution(eid) + mock_redis.zrevrange.return_value = [eid] + mock_redis.mget.return_value = [json.dumps(state)] + results = await store.list_executions() + assert len(results) == 1 + assert results[0]["pipeline_name"] == "p" + + @pytest.mark.asyncio + async def test_fallback_on_redis_failure(self, mock_redis): + store = PipelineStateRedis(redis_url="redis://localhost:6379/0") + # Make Redis initialization fail + mock_redis.ping = AsyncMock(side_effect=Exception("connection refused")) + store._redis = mock_redis + # Force a Redis operation to fail + mock_redis.set = AsyncMock(side_effect=Exception("connection refused")) + mock_redis.pipeline = MagicMock(side_effect=Exception("connection refused")) + # Should fall back to memory + eid = await store.create_execution("p", ["s1"]) + assert eid is not None + assert store.using_fallback is True + + @pytest.mark.asyncio + async def test_health_check(self, store: PipelineStateRedis, mock_redis): + mock_redis.ping = AsyncMock(return_value=True) + assert await store.health_check() is True + mock_redis.ping = AsyncMock(side_effect=Exception("fail")) + assert await store.health_check() is False + + +# ═══════════════════════════════════════════════════════════════ +# PipelineStatePG +# ═══════════════════════════════════════════════════════════════ + + +class TestPipelineStatePG: + """Tests for PostgreSQL cold persistence (using mocks).""" + + @pytest.mark.asyncio + async def test_no_op_when_session_factory_is_none(self): + pg = PipelineStatePG(session_factory=None) + assert pg.enabled is False + # All methods should be no-op + await pg.persist_execution({"id": "1", "pipeline_name": "p", "status": "completed"}) + await pg.persist_step_history("1", []) + result = await pg.query_executions() + assert result == [] + result = await pg.get_execution("1") + assert result is None + + @pytest.mark.asyncio + async def test_persist_execution(self): + mock_session = AsyncMock() + mock_session.merge = AsyncMock() + mock_session.commit = AsyncMock() + + mock_factory = MagicMock() + mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_factory.return_value.__aexit__ = AsyncMock(return_value=False) + + pg = PipelineStatePG(session_factory=mock_factory) + assert pg.enabled is True + + state = { + "id": "test-id-123", + "pipeline_name": "test_pipeline", + "status": "completed", + "current_step": None, + "completed_steps": ["s1"], + "step_results": {"s1": {"r": 1}}, + "input_data": {"key": "val"}, + "final_output": {"done": True}, + "error_message": None, + "tenant_id": None, + "created_at": datetime.now(timezone.utc).isoformat(), + "updated_at": datetime.now(timezone.utc).isoformat(), + "completed_at": datetime.now(timezone.utc).isoformat(), + } + await pg.persist_execution(state) + mock_session.merge.assert_called_once() + mock_session.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_persist_step_history(self): + mock_session = AsyncMock() + mock_session.merge = AsyncMock() + mock_session.commit = AsyncMock() + + mock_factory = MagicMock() + mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_factory.return_value.__aexit__ = AsyncMock(return_value=False) + + pg = PipelineStatePG(session_factory=mock_factory) + + steps = [ + { + "id": "step-id-1", + "step_name": "s1", + "status": "completed", + "output_data": {"r": 1}, + "error_message": None, + "duration_ms": 100, + "started_at": datetime.now(timezone.utc).isoformat(), + "completed_at": datetime.now(timezone.utc).isoformat(), + } + ] + await pg.persist_step_history("exec-1", steps) + mock_session.merge.assert_called_once() + mock_session.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_persist_execution_handles_error(self): + mock_session = AsyncMock() + mock_session.merge = AsyncMock(side_effect=Exception("DB error")) + mock_session.commit = AsyncMock() + + mock_factory = MagicMock() + mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_factory.return_value.__aexit__ = AsyncMock(return_value=False) + + pg = PipelineStatePG(session_factory=mock_factory) + # Should not raise + await pg.persist_execution({ + "id": "1", + "pipeline_name": "p", + "status": "completed", + "created_at": datetime.now(timezone.utc).isoformat(), + "updated_at": datetime.now(timezone.utc).isoformat(), + }) + + @pytest.mark.asyncio + async def test_query_executions(self): + from agentkit.orchestrator.pipeline_models import PipelineExecutionModel + + # Create a mock model instance + model = PipelineExecutionModel( + id="test-id", + pipeline_name="test_pipeline", + status="completed", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [model] + + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + + mock_factory = MagicMock() + mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_factory.return_value.__aexit__ = AsyncMock(return_value=False) + + pg = PipelineStatePG(session_factory=mock_factory) + results = await pg.query_executions(pipeline_name="test_pipeline") + assert len(results) == 1 + assert results[0]["pipeline_name"] == "test_pipeline" + + @pytest.mark.asyncio + async def test_get_execution_found(self): + from agentkit.orchestrator.pipeline_models import PipelineExecutionModel + + model = PipelineExecutionModel( + id="test-id", + pipeline_name="test_pipeline", + status="completed", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = model + + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + + mock_factory = MagicMock() + mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_factory.return_value.__aexit__ = AsyncMock(return_value=False) + + pg = PipelineStatePG(session_factory=mock_factory) + result = await pg.get_execution("test-id") + assert result is not None + assert result["id"] == "test-id" + + @pytest.mark.asyncio + async def test_get_execution_not_found(self): + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + + mock_factory = MagicMock() + mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_factory.return_value.__aexit__ = AsyncMock(return_value=False) + + pg = PipelineStatePG(session_factory=mock_factory) + result = await pg.get_execution("nonexistent") + assert result is None + + +# ═══════════════════════════════════════════════════════════════ +# PipelineStateManager +# ═══════════════════════════════════════════════════════════════ + + +class TestPipelineStateManager: + """Tests for the unified state manager.""" + + @pytest.fixture + def manager(self) -> PipelineStateManager: + """Create a manager with memory-only backend.""" + return PipelineStateManager(redis_url=None, session_factory=None) + + @pytest.mark.asyncio + async def test_create_and_get_execution(self, manager: PipelineStateManager): + eid = await manager.create_execution("p", ["s1"], input_data={"k": "v"}) + state = await manager.get_execution(eid) + assert state is not None + assert state["pipeline_name"] == "p" + assert state["status"] == "running" + + @pytest.mark.asyncio + async def test_update_step(self, manager: PipelineStateManager): + eid = await manager.create_execution("p", ["s1"]) + await manager.update_step(eid, "s1", "completed", output={"r": 1}) + state = await manager.get_execution(eid) + assert "s1" in state["completed_steps"] + + @pytest.mark.asyncio + async def test_complete_persists_to_cold(self): + """Test that completing an execution triggers PG persist.""" + mock_session = AsyncMock() + mock_session.merge = AsyncMock() + mock_session.commit = AsyncMock() + + mock_factory = MagicMock() + mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_factory.return_value.__aexit__ = AsyncMock(return_value=False) + + manager = PipelineStateManager(redis_url=None, session_factory=mock_factory) + eid = await manager.create_execution("p", ["s1"]) + await manager.update_step(eid, "s1", "completed", output={"r": 1}) + await manager.complete_execution(eid, final_output={"done": True}) + + # PG persist should have been called + mock_session.merge.assert_called() + + @pytest.mark.asyncio + async def test_fail_persists_to_cold(self): + """Test that failing an execution triggers PG persist.""" + mock_session = AsyncMock() + mock_session.merge = AsyncMock() + mock_session.commit = AsyncMock() + + mock_factory = MagicMock() + mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_factory.return_value.__aexit__ = AsyncMock(return_value=False) + + manager = PipelineStateManager(redis_url=None, session_factory=mock_factory) + eid = await manager.create_execution("p", ["s1"]) + await manager.fail_execution(eid, "s1", "error") + + mock_session.merge.assert_called() + + @pytest.mark.asyncio + async def test_get_execution_pg_fallback(self): + """Test Redis miss falls back to PG.""" + from agentkit.orchestrator.pipeline_models import PipelineExecutionModel + + model = PipelineExecutionModel( + id="pg-exec-id", + pipeline_name="pg_pipeline", + status="completed", + created_at=datetime.now(timezone.utc), + updated_at=datetime.now(timezone.utc), + ) + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = model + + mock_session = AsyncMock() + mock_session.execute = AsyncMock(return_value=mock_result) + + mock_factory = MagicMock() + mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_factory.return_value.__aexit__ = AsyncMock(return_value=False) + + manager = PipelineStateManager(redis_url=None, session_factory=mock_factory) + # This execution_id doesn't exist in hot store, should fall back to PG + result = await manager.get_execution("pg-exec-id") + assert result is not None + assert result["pipeline_name"] == "pg_pipeline" + + @pytest.mark.asyncio + async def test_list_executions_hot_first(self, manager: PipelineStateManager): + eid = await manager.create_execution("p", ["s1"]) + results = await manager.list_executions() + assert len(results) == 1 + assert results[0]["id"] == eid + + @pytest.mark.asyncio + async def test_health_check_memory_only(self, manager: PipelineStateManager): + health = await manager.health_check() + assert health["hot"] is True + assert health["cold"] is False + + @pytest.mark.asyncio + async def test_health_check_with_pg(self): + mock_factory = MagicMock() + manager = PipelineStateManager(redis_url=None, session_factory=mock_factory) + health = await manager.health_check() + assert health["hot"] is True + assert health["cold"] is True + + +# ═══════════════════════════════════════════════════════════════ +# PipelineEngine with state persistence +# ═══════════════════════════════════════════════════════════════ + + +class TestPipelineEngineWithState: + """Tests for PipelineEngine integration with state persistence.""" + + @pytest.fixture + def pipeline(self) -> Pipeline: + return Pipeline( + name="test_pipeline", + version="1.0", + description="Test pipeline", + stages=[ + PipelineStage(name="step_a", agent="agent1", action="do_a"), + PipelineStage(name="step_b", agent="agent2", action="do_b", depends_on=["step_a"]), + ], + ) + + @pytest.mark.asyncio + async def test_engine_without_state_backward_compatible(self, pipeline: Pipeline): + """Engine without state_manager should work as before.""" + engine = PipelineEngine(dispatcher=None) + result = await engine.execute(pipeline) + assert result.status == StageStatus.COMPLETED + + @pytest.mark.asyncio + async def test_engine_with_state_creates_execution(self, pipeline: Pipeline): + """Engine with state_manager should create execution state.""" + state_manager = PipelineStateManager(redis_url=None, session_factory=None) + engine = PipelineEngine(dispatcher=None, state_manager=state_manager) + result = await engine.execute(pipeline) + assert result.status == StageStatus.COMPLETED + # Check that execution was created in state store + executions = await state_manager.list_executions() + assert len(executions) == 1 + assert executions[0]["status"] == "completed" + assert executions[0]["pipeline_name"] == "test_pipeline" + + @pytest.mark.asyncio + async def test_engine_with_state_updates_steps(self, pipeline: Pipeline): + """Engine should update step state after each stage.""" + state_manager = PipelineStateManager(redis_url=None, session_factory=None) + engine = PipelineEngine(dispatcher=None, state_manager=state_manager) + await engine.execute(pipeline) + executions = await state_manager.list_executions() + exec_state = executions[0] + # Both steps should be completed + assert "step_a" in exec_state["completed_steps"] + assert "step_b" in exec_state["completed_steps"] + + @pytest.mark.asyncio + async def test_engine_with_state_on_failure(self): + """Engine should persist failure state when a stage fails.""" + pipeline = Pipeline( + name="fail_pipeline", + version="1.0", + description="Pipeline that fails", + stages=[ + PipelineStage(name="bad_step", agent="agent1", action="fail"), + ], + ) + + # Create a dispatcher that raises + mock_dispatcher = AsyncMock() + mock_dispatcher.dispatch = AsyncMock(side_effect=Exception("boom")) + + state_manager = PipelineStateManager(redis_url=None, session_factory=None) + engine = PipelineEngine(dispatcher=mock_dispatcher, state_manager=state_manager) + result = await engine.execute(pipeline) + assert result.status == StageStatus.FAILED + # Check state was persisted + executions = await state_manager.list_executions() + assert len(executions) == 1 + assert executions[0]["status"] == "failed" + + @pytest.mark.asyncio + async def test_engine_state_survives_check(self, pipeline: Pipeline): + """Verify state can be retrieved after execution.""" + state_manager = PipelineStateManager(redis_url=None, session_factory=None) + engine = PipelineEngine(dispatcher=None, state_manager=state_manager) + result = await engine.execute(pipeline, context={"brand": "acme"}) + # Get execution by ID + executions = await state_manager.list_executions() + eid = executions[0]["id"] + state = await state_manager.get_execution(eid) + assert state is not None + assert state["pipeline_name"] == "test_pipeline" + assert state["status"] == "completed" + + @pytest.mark.asyncio + async def test_engine_with_circular_dependency(self): + """Engine should handle circular dependency gracefully.""" + pipeline = Pipeline( + name="circular", + version="1.0", + description="Circular pipeline", + stages=[ + PipelineStage(name="a", agent="agent1", action="do", depends_on=["b"]), + PipelineStage(name="b", agent="agent2", action="do", depends_on=["a"]), + ], + ) + state_manager = PipelineStateManager(redis_url=None, session_factory=None) + engine = PipelineEngine(dispatcher=None, state_manager=state_manager) + result = await engine.execute(pipeline) + assert result.status == StageStatus.FAILED + assert "Circular" in result.error_message + # No execution state should be created (topological sort fails before creation) + executions = await state_manager.list_executions() + assert len(executions) == 0