feat(pipeline): U5 state persistence with Redis hot + PG cold dual-write
Add PipelineStateMemory/Redis/PG backends, PipelineStateManager with Redis Sorted Set hot state + PostgreSQL JSONB cold persistence. Integrated into PipelineEngine with state persistence calls at each step transition.
This commit is contained in:
parent
2e547e345a
commit
4db637cd4f
|
|
@ -5,6 +5,18 @@ from agentkit.orchestrator.pipeline_engine import PipelineEngine
|
|||
from agentkit.orchestrator.pipeline_loader import PipelineLoader
|
||||
from agentkit.orchestrator.handoff import HandoffManager
|
||||
from agentkit.orchestrator.dynamic_pipeline import DynamicPipeline
|
||||
from agentkit.orchestrator.pipeline_state import (
|
||||
PipelineStateMemory,
|
||||
PipelineStateRedis,
|
||||
PipelineStatePG,
|
||||
PipelineStateManager,
|
||||
)
|
||||
from agentkit.orchestrator.retry import StepRetryPolicy, execute_with_retry
|
||||
from agentkit.orchestrator.compensation import (
|
||||
CompletedStep,
|
||||
CompensationResult,
|
||||
SagaOrchestrator,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Pipeline",
|
||||
|
|
@ -14,4 +26,13 @@ __all__ = [
|
|||
"PipelineLoader",
|
||||
"HandoffManager",
|
||||
"DynamicPipeline",
|
||||
"PipelineStateMemory",
|
||||
"PipelineStateRedis",
|
||||
"PipelineStatePG",
|
||||
"PipelineStateManager",
|
||||
"StepRetryPolicy",
|
||||
"execute_with_retry",
|
||||
"CompletedStep",
|
||||
"CompensationResult",
|
||||
"SagaOrchestrator",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""Pipeline Engine - DAG + 并行执行"""
|
||||
"""Pipeline Engine - DAG + 并行执行 + 步骤重试 + Saga 补偿"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
|
|
@ -6,6 +6,7 @@ from collections import defaultdict
|
|||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from agentkit.orchestrator.compensation import SagaOrchestrator
|
||||
from agentkit.orchestrator.pipeline_schema import (
|
||||
Pipeline,
|
||||
PipelineResult,
|
||||
|
|
@ -13,6 +14,7 @@ from agentkit.orchestrator.pipeline_schema import (
|
|||
StageResult,
|
||||
StageStatus,
|
||||
)
|
||||
from agentkit.orchestrator.retry import StepRetryPolicy, execute_with_retry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -25,11 +27,14 @@ class PipelineEngine:
|
|||
- 同层并行执行(asyncio.gather)
|
||||
- 变量解析
|
||||
- 条件执行
|
||||
- 重试
|
||||
- 步骤级指数退避重试(StepRetryPolicy)
|
||||
- Saga 补偿(LIFO 回滚已完成步骤)
|
||||
- 状态持久化(可选)
|
||||
"""
|
||||
|
||||
def __init__(self, dispatcher: Any = None):
|
||||
def __init__(self, dispatcher: Any = None, state_manager: Any = None):
|
||||
self._dispatcher = dispatcher
|
||||
self._state_manager = state_manager
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
|
|
@ -48,6 +53,22 @@ class PipelineEngine:
|
|||
result.error_message = str(e)
|
||||
return result
|
||||
|
||||
# Create execution state if state_manager is configured
|
||||
execution_id: str | None = None
|
||||
if self._state_manager is not None:
|
||||
try:
|
||||
step_names = [s.name for s in pipeline.stages]
|
||||
execution_id = await self._state_manager.create_execution(
|
||||
pipeline_name=pipeline.name,
|
||||
steps=step_names,
|
||||
input_data=context,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to create execution state: {exc}")
|
||||
|
||||
# Create Saga orchestrator for compensation tracking
|
||||
saga = SagaOrchestrator()
|
||||
|
||||
# 逐层执行
|
||||
for level, stages in enumerate(level_groups):
|
||||
logger.info(f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)")
|
||||
|
|
@ -55,7 +76,7 @@ class PipelineEngine:
|
|||
# 并行执行同层 stages
|
||||
tasks = []
|
||||
for stage in stages:
|
||||
tasks.append(self._execute_stage(stage, result))
|
||||
tasks.append(self._execute_stage(stage, result, saga))
|
||||
|
||||
stage_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
|
@ -69,6 +90,22 @@ class PipelineEngine:
|
|||
)
|
||||
result.stage_results[stage.name] = sr
|
||||
|
||||
# Update step state
|
||||
if self._state_manager is not None and execution_id is not None:
|
||||
try:
|
||||
step_status = "completed" if sr.status == StageStatus.COMPLETED else sr.status.value
|
||||
step_output = sr.output_data if hasattr(sr, 'output_data') else None
|
||||
step_error = sr.error_message if hasattr(sr, 'error_message') else None
|
||||
await self._state_manager.update_step(
|
||||
execution_id=execution_id,
|
||||
step_name=stage.name,
|
||||
status=step_status,
|
||||
output=step_output,
|
||||
error=step_error,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to update step state: {exc}")
|
||||
|
||||
# 收集输出变量
|
||||
if sr.output_data and isinstance(sr, dict):
|
||||
pass
|
||||
|
|
@ -80,17 +117,56 @@ class PipelineEngine:
|
|||
# 检查是否需要中止
|
||||
if hasattr(sr, 'status') and sr.status == StageStatus.FAILED:
|
||||
if not stage.continue_on_failure:
|
||||
# Execute Saga compensation for completed steps
|
||||
compensation_results = await saga.compensate()
|
||||
if compensation_results:
|
||||
failed_compensations = [
|
||||
cr for cr in compensation_results if not cr.success and cr.error != "no_compensation_needed"
|
||||
]
|
||||
if failed_compensations:
|
||||
logger.warning(
|
||||
f"Compensation had {len(failed_compensations)} failures: "
|
||||
f"{[c.step_name for c in failed_compensations]}"
|
||||
)
|
||||
|
||||
result.status = StageStatus.FAILED
|
||||
result.error_message = f"Stage '{stage.name}' failed"
|
||||
# Fail execution state
|
||||
if self._state_manager is not None and execution_id is not None:
|
||||
try:
|
||||
await self._state_manager.fail_execution(
|
||||
execution_id=execution_id,
|
||||
step_name=stage.name,
|
||||
error=result.error_message,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to persist failure state: {exc}")
|
||||
return result
|
||||
|
||||
result.status = StageStatus.COMPLETED
|
||||
|
||||
# Complete execution state
|
||||
if self._state_manager is not None and execution_id is not None:
|
||||
try:
|
||||
final_output = {
|
||||
name: sr.output_data
|
||||
for name, sr in result.stage_results.items()
|
||||
if sr.output_data is not None
|
||||
}
|
||||
await self._state_manager.complete_execution(
|
||||
execution_id=execution_id,
|
||||
final_output=final_output,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to persist completion state: {exc}")
|
||||
|
||||
return result
|
||||
|
||||
async def _execute_stage(
|
||||
self,
|
||||
stage: PipelineStage,
|
||||
pipeline_result: PipelineResult,
|
||||
saga: SagaOrchestrator,
|
||||
) -> StageResult:
|
||||
"""执行单个 stage"""
|
||||
started_at = datetime.now(timezone.utc).isoformat()
|
||||
|
|
@ -110,13 +186,20 @@ class PipelineEngine:
|
|||
# 执行
|
||||
if self._dispatcher is None:
|
||||
# Dry-run 模式
|
||||
return StageResult(
|
||||
result = StageResult(
|
||||
stage_name=stage.name,
|
||||
status=StageStatus.COMPLETED,
|
||||
output_data={"dry_run": True, "inputs": resolved_inputs},
|
||||
started_at=started_at,
|
||||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
# Record completed step for Saga compensation
|
||||
saga.record_completed(
|
||||
step_name=stage.name,
|
||||
result=result.output_data,
|
||||
compensate_action=stage.compensate,
|
||||
)
|
||||
return result
|
||||
|
||||
# 通过 Dispatcher 分发任务
|
||||
from agentkit.core.protocol import TaskMessage
|
||||
|
|
@ -133,7 +216,8 @@ class PipelineEngine:
|
|||
timeout_seconds=stage.timeout_seconds,
|
||||
)
|
||||
|
||||
try:
|
||||
async def _dispatch_and_wait() -> StageResult:
|
||||
"""Dispatch task and wait for result"""
|
||||
await self._dispatcher.dispatch(task)
|
||||
|
||||
# 等待结果
|
||||
|
|
@ -158,6 +242,24 @@ class PipelineEngine:
|
|||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
try:
|
||||
# Execute with retry if retry_policy is configured
|
||||
sr = await execute_with_retry(
|
||||
func=_dispatch_and_wait,
|
||||
retry_policy=stage.retry_policy,
|
||||
step_name=stage.name,
|
||||
)
|
||||
|
||||
# Record completed step for Saga compensation on success
|
||||
if sr.status == StageStatus.COMPLETED:
|
||||
saga.record_completed(
|
||||
step_name=stage.name,
|
||||
result=sr.output_data,
|
||||
compensate_action=stage.compensate,
|
||||
)
|
||||
|
||||
return sr
|
||||
|
||||
except Exception as e:
|
||||
return StageResult(
|
||||
stage_name=stage.name,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,59 @@
|
|||
"""Pipeline execution ORM models for PostgreSQL persistence."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import Column, DateTime, Index, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class PipelineExecutionModel(Base):
|
||||
"""Pipeline execution record — persisted final state."""
|
||||
|
||||
__tablename__ = "pipeline_executions"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
pipeline_name = Column(String(128), nullable=False, index=True)
|
||||
status = Column(String(32), nullable=False, index=True)
|
||||
current_step = Column(String(128))
|
||||
completed_steps = Column(JSONB, default=list)
|
||||
step_results = Column(JSONB, default=dict)
|
||||
input_data = Column(JSONB)
|
||||
final_output = Column(JSONB)
|
||||
error_message = Column(Text)
|
||||
tenant_id = Column(String(64), index=True)
|
||||
created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(
|
||||
DateTime,
|
||||
default=lambda: datetime.now(timezone.utc),
|
||||
onupdate=lambda: datetime.now(timezone.utc),
|
||||
)
|
||||
completed_at = Column(DateTime)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_pipeline_status_created", "status", "created_at"),
|
||||
)
|
||||
|
||||
|
||||
class PipelineStepHistoryModel(Base):
|
||||
"""Step execution history — audit trail."""
|
||||
|
||||
__tablename__ = "pipeline_step_history"
|
||||
|
||||
id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
execution_id = Column(String(36), nullable=False, index=True)
|
||||
step_name = Column(String(128), nullable=False)
|
||||
step_index = Column(Integer, nullable=False)
|
||||
status = Column(String(32), nullable=False)
|
||||
input_data = Column(JSONB)
|
||||
output_data = Column(JSONB)
|
||||
error_message = Column(Text)
|
||||
duration_ms = Column(Integer)
|
||||
retry_attempt = Column(Integer, default=0)
|
||||
started_at = Column(DateTime)
|
||||
completed_at = Column(DateTime)
|
||||
|
|
@ -0,0 +1,572 @@
|
|||
"""Pipeline execution state persistence — Redis hot state + PostgreSQL cold storage.
|
||||
|
||||
Architecture:
|
||||
PipelineStateMemory — In-memory fallback (always available, for testing)
|
||||
PipelineStateRedis — Redis hot state (low-latency reads/writes)
|
||||
PipelineStatePG — PostgreSQL cold persistence (durable audit trail)
|
||||
PipelineStateManager — Unified manager (Redis + PG dual write, fallback chain)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
from agentkit.orchestrator.pipeline_models import (
|
||||
PipelineExecutionModel,
|
||||
PipelineStepHistoryModel,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Redis key patterns
|
||||
_EXEC_KEY_PREFIX = "agentkit:pipeline:exec:"
|
||||
_INDEX_KEY = "agentkit:pipeline:index"
|
||||
_TTL_SECONDS = 7 * 24 * 3600 # 7 days
|
||||
|
||||
|
||||
class PipelineStateMemory:
|
||||
"""In-memory pipeline state storage (testing / fallback)."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._executions: dict[str, dict[str, Any]] = {}
|
||||
self._step_history: dict[str, list[dict[str, Any]]] = {}
|
||||
|
||||
async def create_execution(
|
||||
self,
|
||||
pipeline_name: str,
|
||||
steps: list[str],
|
||||
input_data: dict[str, Any] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> str:
|
||||
execution_id = str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
self._executions[execution_id] = {
|
||||
"id": execution_id,
|
||||
"pipeline_name": pipeline_name,
|
||||
"status": "running",
|
||||
"current_step": steps[0] if steps else None,
|
||||
"completed_steps": [],
|
||||
"step_results": {},
|
||||
"input_data": input_data,
|
||||
"final_output": None,
|
||||
"error_message": None,
|
||||
"tenant_id": tenant_id,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"completed_at": None,
|
||||
}
|
||||
self._step_history[execution_id] = []
|
||||
return execution_id
|
||||
|
||||
async def update_step(
|
||||
self,
|
||||
execution_id: str,
|
||||
step_name: str,
|
||||
status: str,
|
||||
output: dict[str, Any] | None = None,
|
||||
error: str | None = None,
|
||||
duration_ms: int | None = None,
|
||||
) -> None:
|
||||
exec_state = self._executions.get(execution_id)
|
||||
if exec_state is None:
|
||||
logger.warning(f"Execution '{execution_id}' not found for step update")
|
||||
return
|
||||
|
||||
exec_state["current_step"] = step_name
|
||||
exec_state["updated_at"] = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
if status == "completed":
|
||||
if step_name not in exec_state["completed_steps"]:
|
||||
exec_state["completed_steps"].append(step_name)
|
||||
if output is not None:
|
||||
exec_state["step_results"][step_name] = output
|
||||
elif status == "failed":
|
||||
exec_state["error_message"] = error
|
||||
|
||||
# Record step history event
|
||||
step_event: dict[str, Any] = {
|
||||
"id": str(uuid.uuid4()),
|
||||
"execution_id": execution_id,
|
||||
"step_name": step_name,
|
||||
"status": status,
|
||||
"output_data": output,
|
||||
"error_message": error,
|
||||
"duration_ms": duration_ms,
|
||||
"started_at": datetime.now(timezone.utc).isoformat(),
|
||||
"completed_at": datetime.now(timezone.utc).isoformat() if status in ("completed", "failed") else None,
|
||||
}
|
||||
self._step_history[execution_id].append(step_event)
|
||||
|
||||
async def complete_execution(
|
||||
self,
|
||||
execution_id: str,
|
||||
final_output: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
exec_state = self._executions.get(execution_id)
|
||||
if exec_state is None:
|
||||
return
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
exec_state["status"] = "completed"
|
||||
exec_state["final_output"] = final_output
|
||||
exec_state["updated_at"] = now
|
||||
exec_state["completed_at"] = now
|
||||
|
||||
async def fail_execution(
|
||||
self,
|
||||
execution_id: str,
|
||||
step_name: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
exec_state = self._executions.get(execution_id)
|
||||
if exec_state is None:
|
||||
return
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
exec_state["status"] = "failed"
|
||||
exec_state["error_message"] = f"Step '{step_name}' failed: {error}"
|
||||
exec_state["updated_at"] = now
|
||||
exec_state["completed_at"] = now
|
||||
|
||||
async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
|
||||
return self._executions.get(execution_id)
|
||||
|
||||
async def list_executions(
|
||||
self,
|
||||
status: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
results = list(self._executions.values())
|
||||
if status:
|
||||
results = [e for e in results if e.get("status") == status]
|
||||
results.sort(key=lambda e: e.get("created_at", ""), reverse=True)
|
||||
return results[offset : offset + limit]
|
||||
|
||||
async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]:
|
||||
return self._step_history.get(execution_id, [])
|
||||
|
||||
|
||||
class PipelineStateRedis:
|
||||
"""Redis-backed pipeline state storage (hot state).
|
||||
|
||||
Uses Redis Hash for execution state and Sorted Set for indexing.
|
||||
Falls back to PipelineStateMemory if Redis is unavailable.
|
||||
"""
|
||||
|
||||
def __init__(self, redis_url: str = "redis://localhost:6379/0") -> None:
|
||||
self._redis_url = redis_url
|
||||
self._redis: Any = None
|
||||
self._fallback = PipelineStateMemory()
|
||||
self._use_fallback = False
|
||||
|
||||
async def _get_redis(self):
|
||||
if self._redis is None:
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
self._redis = aioredis.from_url(
|
||||
self._redis_url,
|
||||
decode_responses=True,
|
||||
)
|
||||
return self._redis
|
||||
|
||||
async def _safe_redis_call(
|
||||
self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Execute a Redis call, falling back to memory on failure."""
|
||||
if self._use_fallback:
|
||||
return None
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
return await fn(redis, *args, **kwargs)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Redis operation failed, switching to memory fallback: {exc}")
|
||||
self._use_fallback = True
|
||||
self._redis = None
|
||||
return None
|
||||
|
||||
def _key(self, execution_id: str) -> str:
|
||||
return f"{_EXEC_KEY_PREFIX}{execution_id}"
|
||||
|
||||
async def create_execution(
|
||||
self,
|
||||
pipeline_name: str,
|
||||
steps: list[str],
|
||||
input_data: dict[str, Any] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> str:
|
||||
# Always write to fallback first for consistency
|
||||
execution_id = await self._fallback.create_execution(
|
||||
pipeline_name, steps, input_data, tenant_id
|
||||
)
|
||||
|
||||
# Try Redis
|
||||
async def _redis_create(redis: Any) -> None:
|
||||
state = self._fallback._executions[execution_id]
|
||||
score = datetime.now(timezone.utc).timestamp()
|
||||
pipe = redis.pipeline()
|
||||
pipe.set(self._key(execution_id), json.dumps(state), ex=_TTL_SECONDS)
|
||||
pipe.zadd(_INDEX_KEY, {execution_id: score})
|
||||
await pipe.execute()
|
||||
|
||||
await self._safe_redis_call(_redis_create)
|
||||
return execution_id
|
||||
|
||||
async def update_step(
|
||||
self,
|
||||
execution_id: str,
|
||||
step_name: str,
|
||||
status: str,
|
||||
output: dict[str, Any] | None = None,
|
||||
error: str | None = None,
|
||||
duration_ms: int | None = None,
|
||||
) -> None:
|
||||
await self._fallback.update_step(execution_id, step_name, status, output, error, duration_ms)
|
||||
|
||||
async def _redis_update(redis: Any) -> None:
|
||||
state = self._fallback._executions.get(execution_id)
|
||||
if state is None:
|
||||
return
|
||||
await redis.set(self._key(execution_id), json.dumps(state), ex=_TTL_SECONDS)
|
||||
|
||||
await self._safe_redis_call(_redis_update)
|
||||
|
||||
async def complete_execution(
|
||||
self,
|
||||
execution_id: str,
|
||||
final_output: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
await self._fallback.complete_execution(execution_id, final_output)
|
||||
|
||||
async def _redis_complete(redis: Any) -> None:
|
||||
state = self._fallback._executions.get(execution_id)
|
||||
if state is None:
|
||||
return
|
||||
await redis.set(self._key(execution_id), json.dumps(state), ex=_TTL_SECONDS)
|
||||
|
||||
await self._safe_redis_call(_redis_complete)
|
||||
|
||||
async def fail_execution(
|
||||
self,
|
||||
execution_id: str,
|
||||
step_name: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
await self._fallback.fail_execution(execution_id, step_name, error)
|
||||
|
||||
async def _redis_fail(redis: Any) -> None:
|
||||
state = self._fallback._executions.get(execution_id)
|
||||
if state is None:
|
||||
return
|
||||
await redis.set(self._key(execution_id), json.dumps(state), ex=_TTL_SECONDS)
|
||||
|
||||
await self._safe_redis_call(_redis_fail)
|
||||
|
||||
async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
|
||||
# Try Redis first
|
||||
if not self._use_fallback:
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
raw = await redis.get(self._key(execution_id))
|
||||
if raw is not None:
|
||||
return json.loads(raw)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Fallback to memory
|
||||
return await self._fallback.get_execution(execution_id)
|
||||
|
||||
async def list_executions(
|
||||
self,
|
||||
status: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
# Try Redis sorted set for efficient listing
|
||||
if not self._use_fallback:
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
# Get recent execution IDs from sorted set (newest first)
|
||||
ids = await redis.zrevrange(_INDEX_KEY, offset, offset + limit - 1)
|
||||
if ids:
|
||||
keys = [self._key(eid) for eid in ids]
|
||||
values = await redis.mget(keys)
|
||||
results = []
|
||||
for raw in values:
|
||||
if raw is None:
|
||||
continue
|
||||
state = json.loads(raw)
|
||||
if status is None or state.get("status") == status:
|
||||
results.append(state)
|
||||
return results
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return await self._fallback.list_executions(status, limit, offset)
|
||||
|
||||
async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]:
|
||||
return await self._fallback.get_step_history(execution_id)
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
if self._use_fallback:
|
||||
return False
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
return await redis.ping()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@property
|
||||
def using_fallback(self) -> bool:
|
||||
return self._use_fallback
|
||||
|
||||
|
||||
class PipelineStatePG:
|
||||
"""PostgreSQL cold persistence for pipeline execution records.
|
||||
|
||||
If session_factory is None, all methods are no-op.
|
||||
"""
|
||||
|
||||
def __init__(self, session_factory: Any = None) -> None:
|
||||
self._session_factory = session_factory
|
||||
|
||||
@property
|
||||
def enabled(self) -> bool:
|
||||
return self._session_factory is not None
|
||||
|
||||
async def persist_execution(self, state: dict[str, Any]) -> None:
|
||||
"""Write a completed/failed execution to PostgreSQL."""
|
||||
if not self.enabled:
|
||||
return
|
||||
try:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
async with self._session_factory() as session:
|
||||
model = PipelineExecutionModel(
|
||||
id=state["id"],
|
||||
pipeline_name=state["pipeline_name"],
|
||||
status=state["status"],
|
||||
current_step=state.get("current_step"),
|
||||
completed_steps=state.get("completed_steps", []),
|
||||
step_results=state.get("step_results", {}),
|
||||
input_data=state.get("input_data"),
|
||||
final_output=state.get("final_output"),
|
||||
error_message=state.get("error_message"),
|
||||
tenant_id=state.get("tenant_id"),
|
||||
created_at=datetime.fromisoformat(state["created_at"]) if state.get("created_at") else None,
|
||||
updated_at=datetime.fromisoformat(state["updated_at"]) if state.get("updated_at") else None,
|
||||
completed_at=datetime.fromisoformat(state["completed_at"]) if state.get("completed_at") else None,
|
||||
)
|
||||
await session.merge(model)
|
||||
await session.commit()
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to persist execution to PG: {exc}")
|
||||
|
||||
async def persist_step_history(
|
||||
self, execution_id: str, steps: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Write step history to PostgreSQL."""
|
||||
if not self.enabled:
|
||||
return
|
||||
try:
|
||||
async with self._session_factory() as session:
|
||||
for idx, step in enumerate(steps):
|
||||
model = PipelineStepHistoryModel(
|
||||
id=step.get("id", str(uuid.uuid4())),
|
||||
execution_id=execution_id,
|
||||
step_name=step["step_name"],
|
||||
step_index=idx,
|
||||
status=step["status"],
|
||||
input_data=step.get("input_data"),
|
||||
output_data=step.get("output_data"),
|
||||
error_message=step.get("error_message"),
|
||||
duration_ms=step.get("duration_ms"),
|
||||
retry_attempt=step.get("retry_attempt", 0),
|
||||
started_at=datetime.fromisoformat(step["started_at"]) if step.get("started_at") else None,
|
||||
completed_at=datetime.fromisoformat(step["completed_at"]) if step.get("completed_at") else None,
|
||||
)
|
||||
await session.merge(model)
|
||||
await session.commit()
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to persist step history to PG: {exc}")
|
||||
|
||||
async def query_executions(
|
||||
self,
|
||||
pipeline_name: str | None = None,
|
||||
status: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Query historical executions from PostgreSQL."""
|
||||
if not self.enabled:
|
||||
return []
|
||||
try:
|
||||
from sqlalchemy import select
|
||||
|
||||
async with self._session_factory() as session:
|
||||
stmt = select(PipelineExecutionModel).order_by(
|
||||
PipelineExecutionModel.created_at.desc()
|
||||
)
|
||||
if pipeline_name:
|
||||
stmt = stmt.where(
|
||||
PipelineExecutionModel.pipeline_name == pipeline_name
|
||||
)
|
||||
if status:
|
||||
stmt = stmt.where(PipelineExecutionModel.status == status)
|
||||
stmt = stmt.offset(offset).limit(limit)
|
||||
result = await session.execute(stmt)
|
||||
rows = result.scalars().all()
|
||||
return [self._model_to_dict(row) for row in rows]
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to query executions from PG: {exc}")
|
||||
return []
|
||||
|
||||
async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
|
||||
"""Get a single execution from PostgreSQL (for Redis miss fallback)."""
|
||||
if not self.enabled:
|
||||
return None
|
||||
try:
|
||||
from sqlalchemy import select
|
||||
|
||||
async with self._session_factory() as session:
|
||||
stmt = select(PipelineExecutionModel).where(
|
||||
PipelineExecutionModel.id == execution_id
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
return None
|
||||
return self._model_to_dict(row)
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to get execution from PG: {exc}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _model_to_dict(model: PipelineExecutionModel) -> dict[str, Any]:
|
||||
return {
|
||||
"id": model.id,
|
||||
"pipeline_name": model.pipeline_name,
|
||||
"status": model.status,
|
||||
"current_step": model.current_step,
|
||||
"completed_steps": model.completed_steps or [],
|
||||
"step_results": model.step_results or {},
|
||||
"input_data": model.input_data,
|
||||
"final_output": model.final_output,
|
||||
"error_message": model.error_message,
|
||||
"tenant_id": model.tenant_id,
|
||||
"created_at": model.created_at.isoformat() if model.created_at else None,
|
||||
"updated_at": model.updated_at.isoformat() if model.updated_at else None,
|
||||
"completed_at": model.completed_at.isoformat() if model.completed_at else None,
|
||||
}
|
||||
|
||||
|
||||
class PipelineStateManager:
|
||||
"""Unified pipeline state manager — Redis hot + PG cold.
|
||||
|
||||
- create / update → Redis (with in-memory fallback)
|
||||
- complete / fail → Redis + async persist to PG
|
||||
- get → Redis first, PG fallback
|
||||
- list → Redis for recent, PG for historical
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_url: str | None = None,
|
||||
session_factory: Any = None,
|
||||
) -> None:
|
||||
if redis_url:
|
||||
self._hot = PipelineStateRedis(redis_url=redis_url)
|
||||
else:
|
||||
self._hot = PipelineStateMemory()
|
||||
self._cold = PipelineStatePG(session_factory=session_factory)
|
||||
|
||||
@property
|
||||
def hot_store(self) -> PipelineStateMemory | PipelineStateRedis:
|
||||
return self._hot
|
||||
|
||||
@property
|
||||
def cold_store(self) -> PipelineStatePG:
|
||||
return self._cold
|
||||
|
||||
async def create_execution(
|
||||
self,
|
||||
pipeline_name: str,
|
||||
steps: list[str],
|
||||
input_data: dict[str, Any] | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> str:
|
||||
return await self._hot.create_execution(pipeline_name, steps, input_data, tenant_id)
|
||||
|
||||
async def update_step(
|
||||
self,
|
||||
execution_id: str,
|
||||
step_name: str,
|
||||
status: str,
|
||||
output: dict[str, Any] | None = None,
|
||||
error: str | None = None,
|
||||
duration_ms: int | None = None,
|
||||
) -> None:
|
||||
await self._hot.update_step(execution_id, step_name, status, output, error, duration_ms)
|
||||
|
||||
async def complete_execution(
|
||||
self,
|
||||
execution_id: str,
|
||||
final_output: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
await self._hot.complete_execution(execution_id, final_output)
|
||||
# Persist to PG
|
||||
state = await self._hot.get_execution(execution_id)
|
||||
if state:
|
||||
await self._cold.persist_execution(state)
|
||||
step_history = await self._hot.get_step_history(execution_id)
|
||||
if step_history:
|
||||
await self._cold.persist_step_history(execution_id, step_history)
|
||||
|
||||
async def fail_execution(
|
||||
self,
|
||||
execution_id: str,
|
||||
step_name: str,
|
||||
error: str,
|
||||
) -> None:
|
||||
await self._hot.fail_execution(execution_id, step_name, error)
|
||||
# Persist to PG
|
||||
state = await self._hot.get_execution(execution_id)
|
||||
if state:
|
||||
await self._cold.persist_execution(state)
|
||||
step_history = await self._hot.get_step_history(execution_id)
|
||||
if step_history:
|
||||
await self._cold.persist_step_history(execution_id, step_history)
|
||||
|
||||
async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
|
||||
# Redis / memory first
|
||||
state = await self._hot.get_execution(execution_id)
|
||||
if state is not None:
|
||||
return state
|
||||
# PG fallback
|
||||
return await self._cold.get_execution(execution_id)
|
||||
|
||||
async def list_executions(
|
||||
self,
|
||||
status: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[dict[str, Any]]:
|
||||
# Hot store for recent executions
|
||||
results = await self._hot.list_executions(status, limit, offset)
|
||||
if results:
|
||||
return results
|
||||
# Cold store for historical queries
|
||||
return await self._cold.query_executions(status=status, limit=limit, offset=offset)
|
||||
|
||||
async def get_step_history(self, execution_id: str) -> list[dict[str, Any]]:
|
||||
return await self._hot.get_step_history(execution_id)
|
||||
|
||||
async def health_check(self) -> dict[str, bool]:
|
||||
"""Check health of both stores."""
|
||||
hot_ok = True
|
||||
if isinstance(self._hot, PipelineStateRedis):
|
||||
hot_ok = await self._hot.health_check()
|
||||
cold_ok = self._cold.enabled
|
||||
return {"hot": hot_ok, "cold": cold_ok}
|
||||
|
|
@ -0,0 +1,661 @@
|
|||
"""Unit tests for Pipeline execution state persistence."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.orchestrator.pipeline_engine import PipelineEngine
|
||||
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage, StageStatus
|
||||
from agentkit.orchestrator.pipeline_state import (
|
||||
PipelineStateMemory,
|
||||
PipelineStatePG,
|
||||
PipelineStateRedis,
|
||||
PipelineStateManager,
|
||||
)
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# PipelineStateMemory
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestPipelineStateMemory:
|
||||
"""Tests for in-memory pipeline state storage."""
|
||||
|
||||
@pytest.fixture
|
||||
def store(self) -> PipelineStateMemory:
|
||||
return PipelineStateMemory()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_execution(self, store: PipelineStateMemory):
|
||||
eid = await store.create_execution(
|
||||
pipeline_name="test_pipeline",
|
||||
steps=["step_a", "step_b"],
|
||||
input_data={"key": "value"},
|
||||
)
|
||||
assert eid is not None
|
||||
state = await store.get_execution(eid)
|
||||
assert state is not None
|
||||
assert state["pipeline_name"] == "test_pipeline"
|
||||
assert state["status"] == "running"
|
||||
assert state["current_step"] == "step_a"
|
||||
assert state["completed_steps"] == []
|
||||
assert state["input_data"] == {"key": "value"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_step_completed(self, store: PipelineStateMemory):
|
||||
eid = await store.create_execution("p", ["s1", "s2"])
|
||||
await store.update_step(eid, "s1", "completed", output={"result": 42})
|
||||
state = await store.get_execution(eid)
|
||||
assert "s1" in state["completed_steps"]
|
||||
assert state["step_results"]["s1"] == {"result": 42}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_step_failed(self, store: PipelineStateMemory):
|
||||
eid = await store.create_execution("p", ["s1"])
|
||||
await store.update_step(eid, "s1", "failed", error="boom")
|
||||
state = await store.get_execution(eid)
|
||||
assert state["error_message"] == "boom"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_execution(self, store: PipelineStateMemory):
|
||||
eid = await store.create_execution("p", ["s1"])
|
||||
await store.complete_execution(eid, final_output={"done": True})
|
||||
state = await store.get_execution(eid)
|
||||
assert state["status"] == "completed"
|
||||
assert state["final_output"] == {"done": True}
|
||||
assert state["completed_at"] is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fail_execution(self, store: PipelineStateMemory):
|
||||
eid = await store.create_execution("p", ["s1"])
|
||||
await store.fail_execution(eid, "s1", "timeout")
|
||||
state = await store.get_execution(eid)
|
||||
assert state["status"] == "failed"
|
||||
assert "s1" in state["error_message"]
|
||||
assert "timeout" in state["error_message"]
|
||||
assert state["completed_at"] is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_execution_not_found(self, store: PipelineStateMemory):
|
||||
result = await store.get_execution("nonexistent")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_executions(self, store: PipelineStateMemory):
|
||||
eid1 = await store.create_execution("p1", ["s1"])
|
||||
eid2 = await store.create_execution("p2", ["s2"])
|
||||
await store.complete_execution(eid1)
|
||||
# List all
|
||||
all_execs = await store.list_executions()
|
||||
assert len(all_execs) == 2
|
||||
# Filter by status
|
||||
completed = await store.list_executions(status="completed")
|
||||
assert len(completed) == 1
|
||||
assert completed[0]["id"] == eid1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_executions_pagination(self, store: PipelineStateMemory):
|
||||
for i in range(5):
|
||||
await store.create_execution(f"p{i}", ["s1"])
|
||||
page1 = await store.list_executions(limit=2, offset=0)
|
||||
page2 = await store.list_executions(limit=2, offset=2)
|
||||
assert len(page1) == 2
|
||||
assert len(page2) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_step_history(self, store: PipelineStateMemory):
|
||||
eid = await store.create_execution("p", ["s1", "s2"])
|
||||
await store.update_step(eid, "s1", "completed", output={"r": 1})
|
||||
await store.update_step(eid, "s2", "failed", error="err")
|
||||
history = await store.get_step_history(eid)
|
||||
assert len(history) == 2
|
||||
assert history[0]["step_name"] == "s1"
|
||||
assert history[0]["status"] == "completed"
|
||||
assert history[1]["step_name"] == "s2"
|
||||
assert history[1]["status"] == "failed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_step_nonexistent_execution(self, store: PipelineStateMemory):
|
||||
# Should not raise, just log warning
|
||||
await store.update_step("nonexistent", "s1", "completed")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_execution_with_tenant(self, store: PipelineStateMemory):
|
||||
eid = await store.create_execution("p", ["s1"], tenant_id="tenant_123")
|
||||
state = await store.get_execution(eid)
|
||||
assert state["tenant_id"] == "tenant_123"
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# PipelineStateRedis
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestPipelineStateRedis:
|
||||
"""Tests for Redis-backed pipeline state storage (using mocks)."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis(self):
|
||||
"""Create a mock Redis client."""
|
||||
redis = AsyncMock()
|
||||
redis.get = AsyncMock(return_value=None)
|
||||
redis.set = AsyncMock(return_value=True)
|
||||
redis.zadd = AsyncMock(return_value=1)
|
||||
redis.zrevrange = AsyncMock(return_value=[])
|
||||
redis.mget = AsyncMock(return_value=[])
|
||||
# Redis pipeline: set/zadd are synchronous (return self for chaining), execute is async
|
||||
pipe = MagicMock()
|
||||
pipe.set = MagicMock(return_value=pipe)
|
||||
pipe.zadd = MagicMock(return_value=pipe)
|
||||
pipe.execute = AsyncMock(return_value=[True, 1])
|
||||
redis.pipeline = MagicMock(return_value=pipe)
|
||||
return redis
|
||||
|
||||
@pytest.fixture
|
||||
def store(self, mock_redis) -> PipelineStateRedis:
|
||||
"""Create a PipelineStateRedis with mocked Redis."""
|
||||
store = PipelineStateRedis(redis_url="redis://localhost:6379/0")
|
||||
# Pre-inject the mock Redis client
|
||||
store._redis = mock_redis
|
||||
return store
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_execution_writes_to_redis(self, store: PipelineStateRedis, mock_redis):
|
||||
eid = await store.create_execution("test_pipeline", ["s1", "s2"])
|
||||
assert eid is not None
|
||||
# Redis pipeline should have been used (pipe.set + pipe.zadd + pipe.execute)
|
||||
pipe = mock_redis.pipeline.return_value
|
||||
pipe.set.assert_called_once()
|
||||
call_args = pipe.set.call_args
|
||||
assert call_args[0][0].startswith("agentkit:pipeline:exec:")
|
||||
# Verify the stored data is valid JSON
|
||||
stored_data = json.loads(call_args[0][1])
|
||||
assert stored_data["pipeline_name"] == "test_pipeline"
|
||||
assert stored_data["status"] == "running"
|
||||
# Verify TTL was set (7 days)
|
||||
assert call_args[1].get("ex") == 7 * 24 * 3600
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_execution_adds_to_sorted_set(self, store: PipelineStateRedis, mock_redis):
|
||||
await store.create_execution("p", ["s1"])
|
||||
# ZADD should have been called via pipeline
|
||||
pipe = mock_redis.pipeline.return_value
|
||||
pipe.zadd.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_step_writes_to_redis(self, store: PipelineStateRedis, mock_redis):
|
||||
eid = await store.create_execution("p", ["s1"])
|
||||
mock_redis.set.reset_mock()
|
||||
await store.update_step(eid, "s1", "completed", output={"r": 1})
|
||||
mock_redis.set.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_execution_writes_to_redis(self, store: PipelineStateRedis, mock_redis):
|
||||
eid = await store.create_execution("p", ["s1"])
|
||||
mock_redis.set.reset_mock()
|
||||
await store.complete_execution(eid, final_output={"done": True})
|
||||
mock_redis.set.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fail_execution_writes_to_redis(self, store: PipelineStateRedis, mock_redis):
|
||||
eid = await store.create_execution("p", ["s1"])
|
||||
mock_redis.set.reset_mock()
|
||||
await store.fail_execution(eid, "s1", "error")
|
||||
mock_redis.set.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_execution_from_redis(self, store: PipelineStateRedis, mock_redis):
|
||||
eid = await store.create_execution("p", ["s1"])
|
||||
# Simulate Redis returning data
|
||||
state = await store._fallback.get_execution(eid)
|
||||
mock_redis.get.return_value = json.dumps(state)
|
||||
result = await store.get_execution(eid)
|
||||
assert result is not None
|
||||
assert result["pipeline_name"] == "p"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_execution_redis_miss_falls_back_to_memory(self, store: PipelineStateRedis, mock_redis):
|
||||
eid = await store.create_execution("p", ["s1"])
|
||||
# Redis returns None (miss)
|
||||
mock_redis.get.return_value = None
|
||||
# Should still find it in memory fallback
|
||||
result = await store.get_execution(eid)
|
||||
assert result is not None
|
||||
assert result["pipeline_name"] == "p"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_executions_from_sorted_set(self, store: PipelineStateRedis, mock_redis):
|
||||
eid = await store.create_execution("p", ["s1"])
|
||||
state = await store._fallback.get_execution(eid)
|
||||
mock_redis.zrevrange.return_value = [eid]
|
||||
mock_redis.mget.return_value = [json.dumps(state)]
|
||||
results = await store.list_executions()
|
||||
assert len(results) == 1
|
||||
assert results[0]["pipeline_name"] == "p"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_on_redis_failure(self, mock_redis):
|
||||
store = PipelineStateRedis(redis_url="redis://localhost:6379/0")
|
||||
# Make Redis initialization fail
|
||||
mock_redis.ping = AsyncMock(side_effect=Exception("connection refused"))
|
||||
store._redis = mock_redis
|
||||
# Force a Redis operation to fail
|
||||
mock_redis.set = AsyncMock(side_effect=Exception("connection refused"))
|
||||
mock_redis.pipeline = MagicMock(side_effect=Exception("connection refused"))
|
||||
# Should fall back to memory
|
||||
eid = await store.create_execution("p", ["s1"])
|
||||
assert eid is not None
|
||||
assert store.using_fallback is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check(self, store: PipelineStateRedis, mock_redis):
|
||||
mock_redis.ping = AsyncMock(return_value=True)
|
||||
assert await store.health_check() is True
|
||||
mock_redis.ping = AsyncMock(side_effect=Exception("fail"))
|
||||
assert await store.health_check() is False
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# PipelineStatePG
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestPipelineStatePG:
|
||||
"""Tests for PostgreSQL cold persistence (using mocks)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_op_when_session_factory_is_none(self):
|
||||
pg = PipelineStatePG(session_factory=None)
|
||||
assert pg.enabled is False
|
||||
# All methods should be no-op
|
||||
await pg.persist_execution({"id": "1", "pipeline_name": "p", "status": "completed"})
|
||||
await pg.persist_step_history("1", [])
|
||||
result = await pg.query_executions()
|
||||
assert result == []
|
||||
result = await pg.get_execution("1")
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_execution(self):
|
||||
mock_session = AsyncMock()
|
||||
mock_session.merge = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_factory.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
pg = PipelineStatePG(session_factory=mock_factory)
|
||||
assert pg.enabled is True
|
||||
|
||||
state = {
|
||||
"id": "test-id-123",
|
||||
"pipeline_name": "test_pipeline",
|
||||
"status": "completed",
|
||||
"current_step": None,
|
||||
"completed_steps": ["s1"],
|
||||
"step_results": {"s1": {"r": 1}},
|
||||
"input_data": {"key": "val"},
|
||||
"final_output": {"done": True},
|
||||
"error_message": None,
|
||||
"tenant_id": None,
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"updated_at": datetime.now(timezone.utc).isoformat(),
|
||||
"completed_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
await pg.persist_execution(state)
|
||||
mock_session.merge.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_step_history(self):
|
||||
mock_session = AsyncMock()
|
||||
mock_session.merge = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_factory.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
pg = PipelineStatePG(session_factory=mock_factory)
|
||||
|
||||
steps = [
|
||||
{
|
||||
"id": "step-id-1",
|
||||
"step_name": "s1",
|
||||
"status": "completed",
|
||||
"output_data": {"r": 1},
|
||||
"error_message": None,
|
||||
"duration_ms": 100,
|
||||
"started_at": datetime.now(timezone.utc).isoformat(),
|
||||
"completed_at": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
]
|
||||
await pg.persist_step_history("exec-1", steps)
|
||||
mock_session.merge.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_persist_execution_handles_error(self):
|
||||
mock_session = AsyncMock()
|
||||
mock_session.merge = AsyncMock(side_effect=Exception("DB error"))
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_factory.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
pg = PipelineStatePG(session_factory=mock_factory)
|
||||
# Should not raise
|
||||
await pg.persist_execution({
|
||||
"id": "1",
|
||||
"pipeline_name": "p",
|
||||
"status": "completed",
|
||||
"created_at": datetime.now(timezone.utc).isoformat(),
|
||||
"updated_at": datetime.now(timezone.utc).isoformat(),
|
||||
})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_executions(self):
|
||||
from agentkit.orchestrator.pipeline_models import PipelineExecutionModel
|
||||
|
||||
# Create a mock model instance
|
||||
model = PipelineExecutionModel(
|
||||
id="test-id",
|
||||
pipeline_name="test_pipeline",
|
||||
status="completed",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = [model]
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_factory.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
pg = PipelineStatePG(session_factory=mock_factory)
|
||||
results = await pg.query_executions(pipeline_name="test_pipeline")
|
||||
assert len(results) == 1
|
||||
assert results[0]["pipeline_name"] == "test_pipeline"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_execution_found(self):
|
||||
from agentkit.orchestrator.pipeline_models import PipelineExecutionModel
|
||||
|
||||
model = PipelineExecutionModel(
|
||||
id="test-id",
|
||||
pipeline_name="test_pipeline",
|
||||
status="completed",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = model
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_factory.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
pg = PipelineStatePG(session_factory=mock_factory)
|
||||
result = await pg.get_execution("test-id")
|
||||
assert result is not None
|
||||
assert result["id"] == "test-id"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_execution_not_found(self):
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_factory.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
pg = PipelineStatePG(session_factory=mock_factory)
|
||||
result = await pg.get_execution("nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# PipelineStateManager
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestPipelineStateManager:
|
||||
"""Tests for the unified state manager."""
|
||||
|
||||
@pytest.fixture
|
||||
def manager(self) -> PipelineStateManager:
|
||||
"""Create a manager with memory-only backend."""
|
||||
return PipelineStateManager(redis_url=None, session_factory=None)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_and_get_execution(self, manager: PipelineStateManager):
|
||||
eid = await manager.create_execution("p", ["s1"], input_data={"k": "v"})
|
||||
state = await manager.get_execution(eid)
|
||||
assert state is not None
|
||||
assert state["pipeline_name"] == "p"
|
||||
assert state["status"] == "running"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_step(self, manager: PipelineStateManager):
|
||||
eid = await manager.create_execution("p", ["s1"])
|
||||
await manager.update_step(eid, "s1", "completed", output={"r": 1})
|
||||
state = await manager.get_execution(eid)
|
||||
assert "s1" in state["completed_steps"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_persists_to_cold(self):
|
||||
"""Test that completing an execution triggers PG persist."""
|
||||
mock_session = AsyncMock()
|
||||
mock_session.merge = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_factory.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
manager = PipelineStateManager(redis_url=None, session_factory=mock_factory)
|
||||
eid = await manager.create_execution("p", ["s1"])
|
||||
await manager.update_step(eid, "s1", "completed", output={"r": 1})
|
||||
await manager.complete_execution(eid, final_output={"done": True})
|
||||
|
||||
# PG persist should have been called
|
||||
mock_session.merge.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fail_persists_to_cold(self):
|
||||
"""Test that failing an execution triggers PG persist."""
|
||||
mock_session = AsyncMock()
|
||||
mock_session.merge = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_factory.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
manager = PipelineStateManager(redis_url=None, session_factory=mock_factory)
|
||||
eid = await manager.create_execution("p", ["s1"])
|
||||
await manager.fail_execution(eid, "s1", "error")
|
||||
|
||||
mock_session.merge.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_execution_pg_fallback(self):
|
||||
"""Test Redis miss falls back to PG."""
|
||||
from agentkit.orchestrator.pipeline_models import PipelineExecutionModel
|
||||
|
||||
model = PipelineExecutionModel(
|
||||
id="pg-exec-id",
|
||||
pipeline_name="pg_pipeline",
|
||||
status="completed",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = model
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
mock_factory = MagicMock()
|
||||
mock_factory.return_value.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_factory.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
manager = PipelineStateManager(redis_url=None, session_factory=mock_factory)
|
||||
# This execution_id doesn't exist in hot store, should fall back to PG
|
||||
result = await manager.get_execution("pg-exec-id")
|
||||
assert result is not None
|
||||
assert result["pipeline_name"] == "pg_pipeline"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_executions_hot_first(self, manager: PipelineStateManager):
|
||||
eid = await manager.create_execution("p", ["s1"])
|
||||
results = await manager.list_executions()
|
||||
assert len(results) == 1
|
||||
assert results[0]["id"] == eid
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_memory_only(self, manager: PipelineStateManager):
|
||||
health = await manager.health_check()
|
||||
assert health["hot"] is True
|
||||
assert health["cold"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_check_with_pg(self):
|
||||
mock_factory = MagicMock()
|
||||
manager = PipelineStateManager(redis_url=None, session_factory=mock_factory)
|
||||
health = await manager.health_check()
|
||||
assert health["hot"] is True
|
||||
assert health["cold"] is True
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
# PipelineEngine with state persistence
|
||||
# ═══════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
class TestPipelineEngineWithState:
|
||||
"""Tests for PipelineEngine integration with state persistence."""
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline(self) -> Pipeline:
|
||||
return Pipeline(
|
||||
name="test_pipeline",
|
||||
version="1.0",
|
||||
description="Test pipeline",
|
||||
stages=[
|
||||
PipelineStage(name="step_a", agent="agent1", action="do_a"),
|
||||
PipelineStage(name="step_b", agent="agent2", action="do_b", depends_on=["step_a"]),
|
||||
],
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_without_state_backward_compatible(self, pipeline: Pipeline):
|
||||
"""Engine without state_manager should work as before."""
|
||||
engine = PipelineEngine(dispatcher=None)
|
||||
result = await engine.execute(pipeline)
|
||||
assert result.status == StageStatus.COMPLETED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_with_state_creates_execution(self, pipeline: Pipeline):
|
||||
"""Engine with state_manager should create execution state."""
|
||||
state_manager = PipelineStateManager(redis_url=None, session_factory=None)
|
||||
engine = PipelineEngine(dispatcher=None, state_manager=state_manager)
|
||||
result = await engine.execute(pipeline)
|
||||
assert result.status == StageStatus.COMPLETED
|
||||
# Check that execution was created in state store
|
||||
executions = await state_manager.list_executions()
|
||||
assert len(executions) == 1
|
||||
assert executions[0]["status"] == "completed"
|
||||
assert executions[0]["pipeline_name"] == "test_pipeline"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_with_state_updates_steps(self, pipeline: Pipeline):
|
||||
"""Engine should update step state after each stage."""
|
||||
state_manager = PipelineStateManager(redis_url=None, session_factory=None)
|
||||
engine = PipelineEngine(dispatcher=None, state_manager=state_manager)
|
||||
await engine.execute(pipeline)
|
||||
executions = await state_manager.list_executions()
|
||||
exec_state = executions[0]
|
||||
# Both steps should be completed
|
||||
assert "step_a" in exec_state["completed_steps"]
|
||||
assert "step_b" in exec_state["completed_steps"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_with_state_on_failure(self):
|
||||
"""Engine should persist failure state when a stage fails."""
|
||||
pipeline = Pipeline(
|
||||
name="fail_pipeline",
|
||||
version="1.0",
|
||||
description="Pipeline that fails",
|
||||
stages=[
|
||||
PipelineStage(name="bad_step", agent="agent1", action="fail"),
|
||||
],
|
||||
)
|
||||
|
||||
# Create a dispatcher that raises
|
||||
mock_dispatcher = AsyncMock()
|
||||
mock_dispatcher.dispatch = AsyncMock(side_effect=Exception("boom"))
|
||||
|
||||
state_manager = PipelineStateManager(redis_url=None, session_factory=None)
|
||||
engine = PipelineEngine(dispatcher=mock_dispatcher, state_manager=state_manager)
|
||||
result = await engine.execute(pipeline)
|
||||
assert result.status == StageStatus.FAILED
|
||||
# Check state was persisted
|
||||
executions = await state_manager.list_executions()
|
||||
assert len(executions) == 1
|
||||
assert executions[0]["status"] == "failed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_state_survives_check(self, pipeline: Pipeline):
|
||||
"""Verify state can be retrieved after execution."""
|
||||
state_manager = PipelineStateManager(redis_url=None, session_factory=None)
|
||||
engine = PipelineEngine(dispatcher=None, state_manager=state_manager)
|
||||
result = await engine.execute(pipeline, context={"brand": "acme"})
|
||||
# Get execution by ID
|
||||
executions = await state_manager.list_executions()
|
||||
eid = executions[0]["id"]
|
||||
state = await state_manager.get_execution(eid)
|
||||
assert state is not None
|
||||
assert state["pipeline_name"] == "test_pipeline"
|
||||
assert state["status"] == "completed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engine_with_circular_dependency(self):
|
||||
"""Engine should handle circular dependency gracefully."""
|
||||
pipeline = Pipeline(
|
||||
name="circular",
|
||||
version="1.0",
|
||||
description="Circular pipeline",
|
||||
stages=[
|
||||
PipelineStage(name="a", agent="agent1", action="do", depends_on=["b"]),
|
||||
PipelineStage(name="b", agent="agent2", action="do", depends_on=["a"]),
|
||||
],
|
||||
)
|
||||
state_manager = PipelineStateManager(redis_url=None, session_factory=None)
|
||||
engine = PipelineEngine(dispatcher=None, state_manager=state_manager)
|
||||
result = await engine.execute(pipeline)
|
||||
assert result.status == StageStatus.FAILED
|
||||
assert "Circular" in result.error_message
|
||||
# No execution state should be created (topological sort fails before creation)
|
||||
executions = await state_manager.list_executions()
|
||||
assert len(executions) == 0
|
||||
Loading…
Reference in New Issue