# Source path: fischer-agentkit/src/agentkit/orchestrator/pipeline_engine.py
# (Captured from a repository web view: 851 lines, 33 KiB, Python.)
#
# Note from the web view: this file contains ambiguous Unicode characters,
# i.e. characters that might be confused with other characters. If that is
# intentional, the warning can safely be ignored.
"""Pipeline Engine - DAG + 并行执行 + 步骤重试 + Saga 补偿"""
import asyncio
import logging
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any
from agentkit.orchestrator.compensation import SagaOrchestrator
from agentkit.orchestrator.pipeline_schema import (
AdversarialState,
AdaptiveConfig,
Pipeline,
PipelineResult,
PipelineStage,
ReflectionReport,
ReviewFeedback,
ReviewIssue,
StageResult,
StageStatus,
)
from agentkit.orchestrator.reflection import PipelineReflector, PipelineReplanner
from agentkit.orchestrator.retry import execute_with_retry
logger = logging.getLogger(__name__)
class PipelineEngine:
    """Pipeline execution engine.

    Features:
    - DAG topological sort, grouping stages into dependency levels
    - Parallel execution of the stages of one level (``asyncio.gather``)
    - ``${var.path}`` variable resolution for stage inputs
    - Conditional execution (``==`` / ``!=`` / bare-variable truthiness)
    - Per-step retry via ``execute_with_retry`` and ``stage.retry_policy``
    - Saga compensation (LIFO rollback of completed steps via SagaOrchestrator)
    - Optional execution-state persistence through ``state_manager``
    - Optional worker/verifier adversarial loop per stage
    - Optional reflect -> replan -> re-execute loop when a run fails
    """

    def __init__(self, dispatcher: Any = None, state_manager: Any = None, llm_gateway: Any = None):
        """Create an engine.

        Args:
            dispatcher: Object with async ``dispatch(task)`` and
                ``get_task_status(task_id)``. When ``None`` the engine runs in
                dry-run mode: every agent call completes immediately and echoes
                its inputs.
            state_manager: Optional persistence backend with async
                ``create_execution`` / ``update_step`` / ``fail_execution`` /
                ``complete_execution``. Timeout, connection, runtime and value
                errors raised by it are logged as warnings and do not fail the
                pipeline.
            llm_gateway: Optional gateway passed to the reflector and replanner
                used by the adaptive loop.
        """
        self._dispatcher = dispatcher
        self._state_manager = state_manager
        self._llm_gateway = llm_gateway

    @staticmethod
    def _now_iso() -> str:
        """Return the current UTC time as an ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()

    async def execute(
        self,
        pipeline: Pipeline,
        context: dict[str, Any] | None = None,
        adaptive_config: AdaptiveConfig | None = None,
    ) -> PipelineResult:
        """Execute a pipeline.

        Args:
            pipeline: Pipeline definition.
            context: Runtime context variables; they override
                ``pipeline.variables`` on key collisions.
            adaptive_config: When enabled and the first run fails, the engine
                enters the reflect -> replan -> re-execute loop.

        Returns:
            The result of the first run, or of the last adaptive re-run.
        """
        result = await self._execute_pipeline(pipeline, context)
        # Only a failed first run triggers the reflection/replanning loop.
        if result.status == StageStatus.FAILED and adaptive_config and adaptive_config.enabled:
            result = await self._adaptive_loop(pipeline, context, result, adaptive_config)
        return result

    async def _adaptive_loop(
        self,
        pipeline: Pipeline,
        context: dict[str, Any] | None,
        failed_result: PipelineResult,
        adaptive_config: AdaptiveConfig,
    ) -> PipelineResult:
        """Reflect on a failure, replan the pipeline and re-execute it.

        Runs at most ``adaptive_config.max_reflections`` rounds and returns
        the first successful re-execution. Every round's ReflectionReport is
        accumulated and attached, as ``model_dump()`` dicts, to
        ``metadata["reflections"]`` of the returned result.
        """
        reflector = PipelineReflector(llm_gateway=self._llm_gateway)
        replanner = PipelineReplanner(llm_gateway=self._llm_gateway)

        current_pipeline = pipeline
        current_result = failed_result
        reflections: list[ReflectionReport] = []

        for reflection_num in range(1, adaptive_config.max_reflections + 1):
            # Reflect on the latest failure.
            report = await reflector.reflect(current_pipeline, current_result, reflection_num)
            reflections.append(report)
            logger.info(
                f"Pipeline reflection #{reflection_num}: "
                f"failure_type={report.failure_type}, "
                f"root_cause={report.root_cause}"
            )

            # Replan based on the reflection.
            new_pipeline = await replanner.replan(current_pipeline, current_result, report)
            logger.info(f"Pipeline replanned: {new_pipeline.name} ({len(new_pipeline.stages)} stages)")

            # Re-execute the revised pipeline.
            current_result = await self._execute_pipeline(new_pipeline, context)
            current_pipeline = new_pipeline

            current_result.metadata["reflections"] = [r.model_dump() for r in reflections]

            if current_result.status == StageStatus.COMPLETED:
                logger.info(f"Pipeline succeeded after {reflection_num} reflection(s)")
                return current_result

        # Reflection budget exhausted. The assignment below also covers
        # max_reflections == 0, where the loop body never ran.
        logger.warning(
            f"Pipeline failed after {adaptive_config.max_reflections} reflection(s)"
        )
        current_result.metadata["reflections"] = [r.model_dump() for r in reflections]
        return current_result

    async def _execute_pipeline(
        self,
        pipeline: Pipeline,
        context: dict[str, Any] | None = None,
    ) -> PipelineResult:
        """Run a pipeline once, without reflection or replanning.

        Stages run level by level, and the stages of one level run
        concurrently. After a level finishes, the following is recorded for
        EVERY stage in it before deciding whether to abort:
        - its StageResult,
        - its persisted step state,
        - its exported output variables.

        If any stage of the level failed without ``continue_on_failure``,
        completed steps are compensated through the saga and the pipeline is
        marked FAILED, naming the first such stage in level order.

        Returns:
            The PipelineResult. An invalid dependency graph (unknown
            dependency or cycle) yields FAILED with the graph error message,
            and no stage is executed.
        """
        result = PipelineResult(pipeline_name=pipeline.name)
        result.variables = {**pipeline.variables, **(context or {})}

        # Topological sort, grouped by dependency level.
        try:
            level_groups = self._topological_group(pipeline.stages)
        except ValueError as e:
            result.status = StageStatus.FAILED
            result.error_message = str(e)
            return result

        execution_id = await self._create_execution_state(pipeline, context)

        # Tracks completed steps so they can be compensated on abort.
        saga = SagaOrchestrator()

        for level, stages in enumerate(level_groups):
            logger.info(f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)")

            outcomes = await asyncio.gather(
                *(self._execute_stage(stage, result, saga) for stage in stages),
                return_exceptions=True,
            )

            aborting_stage: PipelineStage | None = None
            for stage, outcome in zip(stages, outcomes):
                sr = self._as_stage_result(stage, outcome)
                result.stage_results[stage.name] = sr
                await self._record_step_state(execution_id, stage.name, sr)

                # Export declared outputs so later levels can reference them.
                if isinstance(sr.output_data, dict):
                    for output_key in stage.outputs:
                        if output_key in sr.output_data:
                            result.variables[output_key] = sr.output_data[output_key]

                if (
                    aborting_stage is None
                    and sr.status == StageStatus.FAILED
                    and not stage.continue_on_failure
                ):
                    aborting_stage = stage

            if aborting_stage is not None:
                await self._compensate(saga)
                result.status = StageStatus.FAILED
                result.error_message = f"Stage '{aborting_stage.name}' failed"
                if self._state_manager is not None and execution_id is not None:
                    try:
                        await self._state_manager.fail_execution(
                            execution_id=execution_id,
                            step_name=aborting_stage.name,
                            error=result.error_message,
                        )
                    except (asyncio.TimeoutError, ConnectionError, RuntimeError, ValueError) as exc:
                        logger.warning(f"Failed to persist failure state: {exc}")
                return result

        result.status = StageStatus.COMPLETED

        if self._state_manager is not None and execution_id is not None:
            try:
                final_output = {
                    name: sr.output_data
                    for name, sr in result.stage_results.items()
                    if sr.output_data is not None
                }
                await self._state_manager.complete_execution(
                    execution_id=execution_id,
                    final_output=final_output,
                )
            except (asyncio.TimeoutError, ConnectionError, RuntimeError, ValueError) as exc:
                logger.warning(f"Failed to persist completion state: {exc}")
        return result

    async def _create_execution_state(
        self,
        pipeline: Pipeline,
        context: dict[str, Any] | None,
    ) -> str | None:
        """Register a new execution with the state manager, if one is configured.

        Returns:
            The execution id, or ``None`` when there is no state manager or
            creation failed (the failure is logged, not raised).
        """
        if self._state_manager is None:
            return None
        try:
            return await self._state_manager.create_execution(
                pipeline_name=pipeline.name,
                steps=[s.name for s in pipeline.stages],
                input_data=context,
            )
        except (asyncio.TimeoutError, ConnectionError, RuntimeError, ValueError) as exc:
            logger.warning(f"Failed to create execution state: {exc}")
            return None

    async def _record_step_state(
        self,
        execution_id: str | None,
        step_name: str,
        sr: StageResult,
    ) -> None:
        """Persist one stage outcome through the state manager (best effort)."""
        if self._state_manager is None or execution_id is None:
            return
        try:
            step_status = "completed" if sr.status == StageStatus.COMPLETED else sr.status.value
            await self._state_manager.update_step(
                execution_id=execution_id,
                step_name=step_name,
                status=step_status,
                output=sr.output_data,
                error=sr.error_message,
            )
        except (asyncio.TimeoutError, ConnectionError, RuntimeError, ValueError) as exc:
            logger.warning(f"Failed to update step state: {exc}")

    @staticmethod
    async def _compensate(saga: SagaOrchestrator) -> None:
        """Run saga compensation and log any compensation that failed.

        Results whose ``error`` is ``"no_compensation_needed"`` are not
        treated as failures.
        """
        compensation_results = await saga.compensate()
        failed_compensations = [
            cr
            for cr in compensation_results or []
            if not cr.success and cr.error != "no_compensation_needed"
        ]
        if failed_compensations:
            logger.warning(
                f"Compensation had {len(failed_compensations)} failures: "
                f"{[c.step_name for c in failed_compensations]}"
            )

    @staticmethod
    def _as_stage_result(stage: PipelineStage, outcome: Any) -> StageResult:
        """Normalize one ``asyncio.gather(..., return_exceptions=True)`` outcome.

        ``gather`` may return any ``BaseException``. This includes
        ``asyncio.CancelledError``, which is not an ``Exception`` subclass.
        Every exception becomes a FAILED StageResult, so callers can safely
        read ``.status`` and ``.output_data``.
        """
        if isinstance(outcome, BaseException):
            return StageResult(
                stage_name=stage.name,
                status=StageStatus.FAILED,
                # Some exceptions (e.g. a bare CancelledError) stringify to "".
                error_message=str(outcome) or type(outcome).__name__,
            )
        return outcome

    async def _execute_stage(
        self,
        stage: PipelineStage,
        pipeline_result: PipelineResult,
        saga: SagaOrchestrator,
    ) -> StageResult:
        """Execute a single stage.

        Steps:
        1. A stage whose condition evaluates false is SKIPPED.
        2. A stage with a ``verifier`` goes through the adversarial loop.
        3. Otherwise the resolved inputs are dispatched to ``stage.agent``.

        A COMPLETED result is recorded with the saga so it can be compensated
        if the pipeline later aborts.
        """
        started_at = self._now_iso()

        # Condition check.
        if stage.condition and not self._evaluate_condition(stage.condition, pipeline_result.variables):
            return StageResult(
                stage_name=stage.name,
                status=StageStatus.SKIPPED,
                started_at=started_at,
                completed_at=self._now_iso(),
            )

        # A configured verifier switches the stage into adversarial mode.
        if stage.verifier:
            return await self._execute_stage_with_adversarial(
                stage, pipeline_result, saga, started_at
            )

        resolved_inputs = self._resolve_variables(stage.inputs, pipeline_result.variables)
        sr = await self._execute_agent_stage(
            stage.agent,
            stage.action,
            resolved_inputs,
            stage,
            started_at,
        )
        if sr.status == StageStatus.COMPLETED:
            saga.record_completed(
                step_name=stage.name,
                result=sr.output_data,
                compensate_action=stage.compensate,
            )
        return sr

    @staticmethod
    def _topological_group(stages: list[PipelineStage]) -> list[list[PipelineStage]]:
        """Topologically sort stages and group them into dependency levels.

        Level 0 holds the stages with no dependencies. Level k holds the
        stages whose dependencies all lie in levels < k. Within a level,
        stages keep their declaration order, so execution and logging order
        are deterministic.

        Raises:
            ValueError: If a stage depends on an unknown stage, or the
                dependencies contain a cycle.
        """
        stage_map = {s.name: s for s in stages}
        in_degree: dict[str, int] = {}  # insertion order == declaration order
        dependents: defaultdict[str, list[str]] = defaultdict(list)

        for s in stages:
            in_degree.setdefault(s.name, 0)
            for dep in s.depends_on:
                if dep not in stage_map:
                    raise ValueError(f"Stage '{s.name}' depends on unknown stage '{dep}'")
                in_degree[s.name] += 1
                dependents[dep].append(s.name)

        levels: list[list[PipelineStage]] = []
        remaining = list(in_degree)
        while remaining:
            # Nodes whose dependencies are all satisfied (in-degree 0).
            current_level = [name for name in remaining if in_degree[name] == 0]
            if not current_level:
                raise ValueError("Circular dependency detected in pipeline")
            levels.append([stage_map[name] for name in current_level])

            ready = set(current_level)
            remaining = [name for name in remaining if name not in ready]
            for name in current_level:
                for dependent in dependents[name]:
                    in_degree[dependent] -= 1
        return levels

    @staticmethod
    def _resolve_variables(template: dict, context: dict) -> dict:
        """Resolve whole-value ``${var.path}`` references in a stage input map.

        A string value of exactly the form ``${a.b.c}`` is replaced by the
        nested lookup ``context[a][b][c]`` (``None`` if any key is missing).
        All other values, including strings that merely contain ``${...}``,
        are copied unchanged. Nested containers are not traversed.
        """
        resolved = {}
        for key, value in (template or {}).items():
            if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
                var_path = value[2:-1]
                resolved[key] = PipelineEngine._get_nested(context, var_path)
            else:
                resolved[key] = value
        return resolved

    @staticmethod
    def _get_nested(data: dict, path: str) -> Any:
        """Follow a dot-separated key path through nested dicts.

        Returns ``None`` when a key is missing or a non-dict value is reached
        before the path ends.
        """
        current = data
        for key in path.split("."):
            if not isinstance(current, dict):
                return None
            current = current.get(key)
        return current

    @staticmethod
    def _evaluate_condition(condition: str, variables: dict) -> bool:
        """Evaluate a simple stage condition.

        Supported forms:
        - ``name == value`` / ``name != value``: ``name`` is looked up in
          ``variables``. If it is absent, the literal text of ``name`` is
          used. ``value`` has surrounding single or double quotes stripped.
          The comparison is made against ``str(left)``. ``==`` is checked
          before ``!=``.
        - ``name``: the truthiness of ``variables.get(name)``.
        """
        if "==" in condition:
            parts = condition.split("==", 1)
            left = variables.get(parts[0].strip(), parts[0].strip())
            right = parts[1].strip().strip("'\"")
            return str(left) == right
        elif "!=" in condition:
            parts = condition.split("!=", 1)
            left = variables.get(parts[0].strip(), parts[0].strip())
            right = parts[1].strip().strip("'\"")
            return str(left) != right
        else:
            return bool(variables.get(condition))

    async def _execute_stage_with_adversarial(
        self,
        stage: PipelineStage,
        pipeline_result: PipelineResult,
        saga: SagaOrchestrator,
        started_at: str,
    ) -> StageResult:
        """Execute a stage under the worker/verifier adversarial loop.

        Each round, the worker agent produces output and the verifier agent
        reviews it:
        - A passing review returns the worker result, annotated with
          ``adversarial_metadata`` and recorded with the saga.
        - A failing review re-runs the worker with the original inputs plus
          its previous output and the review feedback.
        - After ``stage.max_adversarial_rounds`` rounds the stage is handed
          to ``_escalate``. A COMPLETED escalation result is also recorded
          with the saga, so it is compensated like any other completed step.

        Returns:
            The passing worker result, the escalation result, or a FAILED
            result (worker failure, verifier failure, or a non-positive round
            budget).
        """
        adversarial_state = AdversarialState(
            current_round=0,
            max_rounds=stage.max_adversarial_rounds,
        )

        resolved_inputs = self._resolve_variables(stage.inputs, pipeline_result.variables)
        current_context = resolved_inputs.copy()

        for round_num in range(1, stage.max_adversarial_rounds + 1):
            adversarial_state.current_round = round_num
            logger.info(
                f"Adversarial round {round_num}/{stage.max_adversarial_rounds} "
                f"for stage '{stage.name}'"
            )

            # 1. Run the worker agent. A worker failure ends the stage.
            worker_result = await self._execute_agent_stage(
                stage.agent,
                stage.action,
                current_context,
                stage,
                started_at,
            )
            if worker_result.status != StageStatus.COMPLETED:
                return worker_result

            # 2. Let the verifier review the worker output.
            try:
                verifier_feedback = await self._execute_verifier(
                    stage.verifier,
                    worker_result.output_data or {},
                    stage,
                    started_at,
                )
            except asyncio.CancelledError:
                raise
            except (asyncio.TimeoutError, ConnectionError, RuntimeError, ValueError) as e:
                logger.error(f"Verifier execution failed for stage '{stage.name}': {e}")
                return StageResult(
                    stage_name=stage.name,
                    status=StageStatus.FAILED,
                    error_message=f"Verifier failed: {e}",
                    started_at=started_at,
                    completed_at=self._now_iso(),
                )

            # 3. Record the feedback history.
            adversarial_state.feedback_history.append(verifier_feedback)
            adversarial_state.last_feedback = verifier_feedback

            if verifier_feedback.passed:
                logger.info(
                    f"Stage '{stage.name}' passed review in round {round_num}"
                )
                worker_result.output_data = worker_result.output_data or {}
                worker_result.output_data["adversarial_metadata"] = {
                    "passed_round": round_num,
                    "total_rounds": round_num,
                    "feedback_summary": verifier_feedback.summary,
                    "score": verifier_feedback.score,
                }
                saga.record_completed(
                    step_name=stage.name,
                    result=worker_result.output_data,
                    compensate_action=stage.compensate,
                )
                return worker_result

            # 4. Review failed: escalate if the round budget is used up.
            logger.warning(
                f"Stage '{stage.name}' failed review in round {round_num}: "
                f"{verifier_feedback.summary}"
            )

            if round_num >= stage.max_adversarial_rounds:
                escalated = await self._escalate(
                    stage,
                    worker_result,
                    adversarial_state,
                    started_at,
                )
                # A completed escalation is still a completed step: register
                # it so a later abort compensates it.
                if escalated.status == StageStatus.COMPLETED:
                    saga.record_completed(
                        step_name=stage.name,
                        result=escalated.output_data,
                        compensate_action=stage.compensate,
                    )
                return escalated

            # 5. Send the work back with the feedback and the previous output.
            feedback_context = self._build_feedback_context(
                verifier_feedback,
                stage.feedback_mode,
            )
            current_context = {
                **resolved_inputs,
                "previous_output": worker_result.output_data,
                **feedback_context,
            }

        # Only reachable when max_adversarial_rounds < 1 (the loop never ran).
        return StageResult(
            stage_name=stage.name,
            status=StageStatus.FAILED,
            error_message="Adversarial loop exited unexpectedly",
            started_at=started_at,
            completed_at=self._now_iso(),
        )

    async def _execute_agent_stage(
        self,
        agent_name: str,
        action: str,
        input_data: dict[str, Any],
        stage: PipelineStage,
        started_at: str,
        timeout_seconds: int | None = None,
    ) -> StageResult:
        """Dispatch one agent task and wait for its result (no adversarial logic).

        Args:
            agent_name: Target agent name.
            action: Task type/action passed to the agent.
            input_data: Task input payload.
            stage: Owning stage (name, timeout and retry policy are taken from it).
            started_at: Stage start timestamp reported on the result.
            timeout_seconds: Overrides ``stage.timeout_seconds`` when given.

        Returns:
            A COMPLETED or FAILED StageResult. In dry-run mode (no dispatcher)
            it completes immediately with
            ``{"dry_run": True, "inputs": input_data}``. Timeouts and
            dispatcher errors of the caught types become FAILED results;
            ``asyncio.CancelledError`` propagates.
        """
        effective_timeout = timeout_seconds if timeout_seconds is not None else stage.timeout_seconds

        if self._dispatcher is None:
            # Dry-run mode.
            return StageResult(
                stage_name=stage.name,
                status=StageStatus.COMPLETED,
                output_data={"dry_run": True, "inputs": input_data},
                started_at=started_at,
                completed_at=self._now_iso(),
            )

        from agentkit.core.protocol import TaskMessage
        import uuid

        task = TaskMessage(
            task_id=str(uuid.uuid4()),
            agent_name=agent_name,
            task_type=action,
            priority=0,
            input_data=input_data,
            callback_url=None,
            created_at=datetime.now(timezone.utc),
            timeout_seconds=effective_timeout,
        )

        async def _dispatch_and_wait() -> StageResult:
            """Dispatch the task, then poll its status once per second."""
            await self._dispatcher.dispatch(task)
            for _ in range(effective_timeout):
                status = await self._dispatcher.get_task_status(task.task_id)
                if status["status"] in ("completed", "failed", "cancelled"):
                    return StageResult(
                        stage_name=stage.name,
                        status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED,
                        output_data=status.get("output_data"),
                        error_message=status.get("error_message"),
                        started_at=started_at,
                        completed_at=self._now_iso(),
                    )
                await asyncio.sleep(1)
            return StageResult(
                stage_name=stage.name,
                status=StageStatus.FAILED,
                error_message=f"Timeout after {effective_timeout}s",
                started_at=started_at,
                completed_at=self._now_iso(),
            )

        try:
            return await execute_with_retry(
                func=_dispatch_and_wait,
                retry_policy=stage.retry_policy,
                step_name=stage.name,
            )
        except asyncio.CancelledError:
            raise
        except (asyncio.TimeoutError, ConnectionError, RuntimeError, ValueError) as e:
            # Dispatcher / agent failure: report as FAILED instead of raising.
            return StageResult(
                stage_name=stage.name,
                status=StageStatus.FAILED,
                error_message=str(e),
                started_at=started_at,
                completed_at=self._now_iso(),
            )

    async def _execute_verifier(
        self,
        verifier_name: str,
        worker_output: dict[str, Any],
        stage: PipelineStage,
        started_at: str,
    ) -> ReviewFeedback:
        """Run the verifier agent on the worker output.

        The verifier is dispatched with action ``"review"`` and its own
        timeout, ``stage.verifier_timeout_seconds``. Its output dict is read
        for ``passed`` (default False), ``issues`` (ReviewIssue kwargs),
        ``summary`` and ``score``.

        Returns:
            The structured review feedback.

        Raises:
            RuntimeError: If the verifier does not complete or its output
                cannot be parsed. Raising, rather than looping, prevents an
                endless review cycle.
        """
        logger.info(f"Executing verifier '{verifier_name}' for stage '{stage.name}'")

        verifier_input = {
            "review_target": worker_output,
            "review_instruction": (
                "Please review the following output for quality, correctness, and completeness. "
                "Return a structured review with pass/fail status, issues found, and a summary."
            ),
        }

        verifier_result = await self._execute_agent_stage(
            verifier_name,
            "review",
            verifier_input,
            stage,
            started_at,
            timeout_seconds=stage.verifier_timeout_seconds,
        )

        if verifier_result.status != StageStatus.COMPLETED:
            raise RuntimeError(
                f"Verifier '{verifier_name}' failed: {verifier_result.error_message}"
            )

        output_data = verifier_result.output_data or {}
        if not isinstance(output_data, dict):
            raise RuntimeError(
                f"Verifier '{verifier_name}' returned non-dict output of type "
                f"{type(output_data).__name__}"
            )
        try:
            return ReviewFeedback(
                passed=output_data.get("passed", False),
                issues=[
                    ReviewIssue(**issue)
                    for issue in output_data.get("issues", [])
                ],
                summary=output_data.get("summary", "No summary provided"),
                score=output_data.get("score", 0.0),
            )
        except (TypeError, KeyError, ValueError) as e:
            logger.error(f"Failed to parse verifier output: {e}")
            raise RuntimeError(
                f"Verifier '{verifier_name}' returned unparseable output: {e}. "
                f"Raw output keys: {list(output_data.keys())}"
            ) from e

    def _build_feedback_context(
        self,
        feedback: ReviewFeedback,
        feedback_mode: str = "structured+natural",
    ) -> dict[str, Any]:
        """Build the context that sends a rejected output back to the worker.

        Args:
            feedback: Verifier feedback for the rejected attempt.
            feedback_mode: ``"structured+natural"`` (summary + issues),
                ``"structured"`` (issues only) or ``"natural"`` (summary
                only). Unknown modes log a warning and fall back to
                ``"structured+natural"``.

        Returns:
            A dict with ``previous_attempt_failed``, ``review_feedback`` and
            ``instruction`` keys.
        """
        if feedback_mode not in ("structured+natural", "structured", "natural"):
            logger.warning(f"Unknown feedback_mode '{feedback_mode}', falling back to structured+natural")
            feedback_mode = "structured+natural"

        issues_list = [
            {
                "severity": issue.severity,
                "category": issue.category,
                "description": issue.description,
                "location": issue.location,
                "suggestion": issue.suggestion,
            }
            for issue in feedback.issues
        ]

        if feedback_mode == "structured":
            review_feedback = {
                "issues": issues_list,
                "previous_score": feedback.score,
            }
            instruction = (
                "Your previous output did not pass review. "
                "Please fix the issues listed above and regenerate."
            )
        elif feedback_mode == "natural":
            review_feedback = {
                "summary": feedback.summary,
                "previous_score": feedback.score,
            }
            instruction = (
                f"Your previous output did not pass review. "
                f"Review feedback: {feedback.summary}. "
                "Please regenerate addressing the feedback."
            )
        else:  # "structured+natural"
            review_feedback = {
                "summary": feedback.summary,
                "issues": issues_list,
                "previous_score": feedback.score,
            }
            instruction = (
                "Your previous output did not pass review. "
                "Please fix the issues listed above and regenerate. "
                f"Review summary: {feedback.summary}"
            )

        return {
            "previous_attempt_failed": True,
            "review_feedback": review_feedback,
            "instruction": instruction,
        }

    async def _escalate(
        self,
        stage: PipelineStage,
        worker_result: StageResult,
        adversarial_state: AdversarialState,
        started_at: str,
    ) -> StageResult:
        """Handle a stage whose adversarial rounds are exhausted.

        If ``stage.escalate_on_exhaust`` names an agent, the last worker
        output and the adversarial state are dispatched to it with action
        ``"handle_escalation"``. Its result is annotated with
        ``adversarial_metadata``; if it also fails, both failures are merged
        into one error message. Without an escalation target, a FAILED
        result carrying the full feedback history is returned.

        Args:
            stage: The current stage.
            worker_result: The last worker result.
            adversarial_state: Round counters and feedback history.
            started_at: Stage start timestamp.

        Returns:
            The escalation agent's result, or a FAILED result.
        """
        logger.warning(
            f"Adversarial rounds exhausted for stage '{stage.name}' "
            f"({adversarial_state.current_round}/{adversarial_state.max_rounds})"
        )

        if not stage.escalate_on_exhaust:
            # No escalation target: fail with the review history attached.
            last_feedback = adversarial_state.last_feedback
            return StageResult(
                stage_name=stage.name,
                status=StageStatus.FAILED,
                error_message=(
                    f"Adversarial rounds exhausted ({adversarial_state.current_round}/"
                    f"{adversarial_state.max_rounds}). "
                    f"Last review: {last_feedback.summary if last_feedback else 'N/A'}"
                ),
                output_data={
                    "adversarial_metadata": {
                        "total_rounds": adversarial_state.current_round,
                        "feedback_history": [
                            fb.model_dump() for fb in adversarial_state.feedback_history
                        ],
                    }
                },
                started_at=started_at,
                completed_at=self._now_iso(),
            )

        logger.info(f"Escalating stage '{stage.name}' to '{stage.escalate_on_exhaust}'")
        escalate_result = await self._execute_agent_stage(
            stage.escalate_on_exhaust,
            "handle_escalation",
            {
                "original_output": worker_result.output_data,
                "adversarial_state": adversarial_state.model_dump(),
                "escalation_reason": (
                    f"Failed to pass review after {adversarial_state.current_round} rounds"
                ),
            },
            stage,
            started_at,
        )
        escalate_result.output_data = escalate_result.output_data or {}
        escalate_result.output_data["adversarial_metadata"] = {
            "escalated_to": stage.escalate_on_exhaust,
            "total_rounds": adversarial_state.current_round,
            "feedback_history_summary": [
                {"round": i + 1, "passed": fb.passed, "score": fb.score}
                for i, fb in enumerate(adversarial_state.feedback_history)
            ],
        }
        # The escalation agent failed too: merge both failures into the message.
        if escalate_result.status == StageStatus.FAILED:
            escalate_result.error_message = (
                f"Escalation to '{stage.escalate_on_exhaust}' also failed: "
                f"{escalate_result.error_message}. "
                f"Original adversarial rounds exhausted: {adversarial_state.current_round}/{adversarial_state.max_rounds}"
            )
        return escalate_result