geo/backend/app/agent_framework/pipeline/engine.py

546 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Pipeline执行引擎 - 编排多阶段Agent任务链的执行"""
import asyncio
import logging
import time
import uuid
from datetime import datetime, timezone
from typing import Any, Optional
from .loader import PipelineLoader
from .schema import (
Pipeline,
PipelineResult,
PipelineStage,
StageResult,
StageStatus,
)
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class PipelineExecutionError(Exception):
    """Error describing a failure while executing a pipeline."""

    def __init__(self, pipeline_name: str, reason: str = ""):
        """Build the exception and its human-readable message.

        Args:
            pipeline_name: Name of the pipeline the error relates to.
            reason: Explanation of the failure; defaults to an empty string.
        """
        message = f"Pipeline '{pipeline_name}' execution error: {reason}"
        super().__init__(message)
        # Kept as attributes so handlers can inspect them without parsing
        # the formatted message.
        self.pipeline_name = pipeline_name
        self.reason = reason
class PipelineEngine:
    """Pipeline execution engine.

    Orchestrates multi-stage Agent tasks according to the dependency DAG
    declared in the YAML pipeline definition. Supports variable passing,
    timeout control, retries and conditional execution.

    When ``dispatcher`` is ``None`` the engine performs a simulated
    (dry-run) execution, intended for testing and debugging pipeline
    definitions.
    """

    def __init__(self, dispatcher=None):
        """Create an engine.

        Args:
            dispatcher: Optional TaskDispatcher instance. When ``None`` the
                engine runs in simulated dry-run mode.
        """
        self.loader = PipelineLoader()
        self.dispatcher = dispatcher
async def execute(
    self,
    pipeline: Pipeline,
    context: dict | None = None,
) -> PipelineResult:
    """Run every stage of ``pipeline`` and collect the results.

    Flow:
        1. Merge the pipeline's global variables with the supplied context.
        2. Topologically sort the stages to fix the execution order.
        3. Walk the stages one at a time in that order.
        4. Skip a stage when a dependency failed or was skipped, or when
           its condition expression is not met.
        5. Execute the remaining stages and expose each stage's outputs
           and status to the stages that follow.

    Args:
        pipeline: Pipeline definition object.
        context: Externally supplied variables; they override the defaults
            found in ``pipeline.variables``.

    Returns:
        A PipelineResult holding the per-stage results. Its status is
        FAILED when the stages cannot be sorted (the ``ValueError`` text is
        stored in ``error``) or when a stage without
        ``continue_on_failure`` fails; otherwise it is COMPLETED.
    """
    began = time.monotonic()

    def elapsed_ms() -> int:
        # Milliseconds since execute() was entered.
        return int((time.monotonic() - began) * 1000)

    pipeline_name = pipeline.name

    # Global variables first, caller-supplied context layered on top.
    variables: dict[str, Any] = dict(pipeline.variables)
    if context:
        variables.update(context)

    # results: StageResult per stage name (returned to the caller).
    # shared: outputs/status per stage name (used for variable resolution).
    results: dict[str, StageResult] = {}
    shared: dict[str, Any] = {}

    try:
        sorted_stages = self._topological_sort(pipeline.stages)
    except ValueError as e:
        duration_ms = elapsed_ms()
        return PipelineResult(
            pipeline_name=pipeline_name,
            status=StageStatus.FAILED,
            stages_results=results,
            duration_ms=duration_ms,
            error=str(e),
        )

    logger.info(
        f"Pipeline '{pipeline_name}' starting execution "
        f"({len(sorted_stages)} stages, order: {[s.name for s in sorted_stages]})"
    )

    overall_status = StageStatus.COMPLETED
    failed: set[str] = set()
    skipped: set[str] = set()

    def mark_skipped(skipped_stage, message: str) -> None:
        # Record a skipped stage everywhere downstream logic looks for it.
        skipped.add(skipped_stage.name)
        results[skipped_stage.name] = StageResult(
            stage_name=skipped_stage.name,
            status=StageStatus.SKIPPED,
        )
        shared[skipped_stage.name] = {"outputs": {}, "status": StageStatus.SKIPPED}
        logger.info(message)

    for stage in sorted_stages:
        # A failed or skipped dependency propagates as a skip.
        if self._should_skip(stage, failed, skipped):
            mark_skipped(stage, f"Stage '{stage.name}' skipped (dependency failed)")
            continue

        # Optional condition expression gates the stage.
        if stage.condition and not self._evaluate_condition(stage.condition, variables, shared):
            mark_skipped(
                stage,
                f"Stage '{stage.name}' skipped (condition not met: {stage.condition})",
            )
            continue

        stage_result = await self._execute_stage(stage, variables, shared)
        results[stage.name] = stage_result
        # Make the outcome visible to downstream variable references.
        shared[stage.name] = {
            "outputs": stage_result.outputs,
            "status": stage_result.status,
        }

        if stage_result.status == StageStatus.FAILED:
            if not stage.continue_on_failure:
                failed.add(stage.name)
                overall_status = StageStatus.FAILED
                logger.error(
                    f"Stage '{stage.name}' failed: {stage_result.error}"
                )
            else:
                # Tolerated failure: dependants still run and the pipeline
                # status is left untouched.
                logger.warning(
                    f"Stage '{stage.name}' failed but continue_on_failure=True"
                )

        logger.info(
            f"Stage '{stage.name}' {stage_result.status.value} "
            f"({stage_result.duration_ms}ms)"
        )

    duration_ms = elapsed_ms()
    result = PipelineResult(
        pipeline_name=pipeline_name,
        status=overall_status,
        stages_results=results,
        duration_ms=duration_ms,
    )
    logger.info(
        f"Pipeline '{pipeline_name}' finished with status={overall_status.value} "
        f"({duration_ms}ms)"
    )
    return result
async def _execute_stage(
    self,
    stage: PipelineStage,
    exec_context: dict[str, Any],
    stages_context: dict[str, Any],
) -> StageResult:
    """Execute a single stage, retrying when an attempt raises.

    Steps:
        1. Resolve variable references inside ``stage.inputs``.
        2. With a dispatcher, dispatch the task and wait for its result;
           without one, produce a simulated dry-run result.
        3. On timeout or any other exception, try again until
           ``stage.retry_count`` extra attempts have been used.

    Only raised exceptions trigger another attempt. A StageResult returned
    by the dispatcher or dry-run path is passed back as-is, whatever its
    status.

    Args:
        stage: Stage definition.
        exec_context: Global execution context.
        stages_context: Outputs/status of the stages processed so far.

    Returns:
        The StageResult of the first attempt that did not raise, with
        ``duration_ms`` filled in; or a FAILED StageResult carrying the last
        error message once every attempt has raised.
    """
    began = time.monotonic()

    # Variable lookup context: global variables plus a "stages" entry.
    lookup = {**exec_context, "stages": stages_context}
    inputs = self._resolve_stage_inputs(stage.inputs, lookup)

    logger.info(
        f"Stage '{stage.name}' starting "
        f"(agent={stage.agent}, action={stage.action})"
    )

    failure: str | None = None
    for attempt in range(stage.retry_count + 1):
        if attempt > 0:
            logger.info(
                f"Stage '{stage.name}' retry attempt {attempt}/{stage.retry_count}"
            )
        try:
            if self.dispatcher is None:
                # Dry-run mode: fabricate a result.
                outcome = self._dry_run_stage(stage, inputs)
            else:
                # Real execution through the dispatcher.
                outcome = await self._dispatch_and_wait(stage, inputs)
            # The duration covers all attempts made so far.
            outcome.duration_ms = int((time.monotonic() - began) * 1000)
            return outcome
        except asyncio.TimeoutError:
            failure = f"Stage timed out after {stage.timeout_seconds}s"
            logger.warning(
                f"Stage '{stage.name}' timeout (attempt {attempt + 1}/{stage.retry_count + 1})"
            )
        except Exception as e:
            failure = str(e)
            logger.warning(
                f"Stage '{stage.name}' error (attempt {attempt + 1}/{stage.retry_count + 1}): {e}"
            )

    # Every attempt raised.
    return StageResult(
        stage_name=stage.name,
        status=StageStatus.FAILED,
        error=failure,
        duration_ms=int((time.monotonic() - began) * 1000),
    )
async def _dispatch_and_wait(
    self,
    stage: PipelineStage,
    resolved_inputs: dict[str, Any],
    *,
    poll_interval: float = 1.0,
) -> StageResult:
    """Dispatch the stage's task and poll until it reaches a final state.

    The timeout is enforced against a ``time.monotonic()`` deadline, so
    time spent awaiting ``get_task_status`` counts towards
    ``stage.timeout_seconds``. Each sleep is capped to the time remaining,
    so a timeout shorter than ``poll_interval`` is honoured as well.

    Args:
        stage: Stage definition; supplies agent, action and timeout.
        resolved_inputs: Input parameters with variables already resolved.
        poll_interval: Seconds to wait between status polls. Expected to be
            positive.

    Returns:
        A COMPLETED StageResult carrying the extracted outputs, or a FAILED
        StageResult when the task reports FAILED or CANCELLED.

    Raises:
        asyncio.TimeoutError: No final status was seen before the deadline.
        Exception: Anything raised by the dispatcher is propagated.
    """
    from app.agent_framework.protocol import TaskMessage, TaskStatus

    task_message = TaskMessage(
        task_id=str(uuid.uuid4()),
        agent_name=stage.agent,
        task_type=stage.action,
        priority=0,
        input_data=resolved_inputs,
        callback_url=None,
        created_at=datetime.now(timezone.utc),
        timeout_seconds=stage.timeout_seconds,
    )
    dispatched_id = await self.dispatcher.dispatch(task_message)

    deadline = time.monotonic() + stage.timeout_seconds
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise asyncio.TimeoutError()

        # Sleep first, as before, but never past the deadline.
        await asyncio.sleep(min(poll_interval, remaining))

        task_status = await self.dispatcher.get_task_status(dispatched_id)
        status = task_status.get("status")

        if status == TaskStatus.COMPLETED:
            # Keep only the outputs the stage declares.
            outputs = self._extract_outputs(stage, task_status.get("output_data", {}))
            return StageResult(
                stage_name=stage.name,
                status=StageStatus.COMPLETED,
                outputs=outputs,
            )
        if status == TaskStatus.FAILED:
            return StageResult(
                stage_name=stage.name,
                status=StageStatus.FAILED,
                error=task_status.get("error_message", "Unknown error"),
            )
        if status == TaskStatus.CANCELLED:
            return StageResult(
                stage_name=stage.name,
                status=StageStatus.FAILED,
                error="Task was cancelled",
            )
        # Any other status: keep polling until the deadline.
def _dry_run_stage(
    self,
    stage: PipelineStage,
    resolved_inputs: dict[str, Any],
) -> StageResult:
    """Simulate an agent's result for ``stage`` (dry-run mode).

    Used when no dispatcher is available, for testing and debugging
    pipeline definitions. If the ``ENV`` environment variable equals
    ``"production"`` an ERROR-level message is logged, otherwise a WARNING;
    in both cases a simulated result is still returned.

    Args:
        stage: Stage definition.
        resolved_inputs: Input parameters with variables already resolved;
            only their keys are logged.

    Returns:
        A COMPLETED StageResult whose outputs hold one placeholder string
        per name declared in ``stage.outputs``.
    """
    import os

    in_production = os.environ.get("ENV", "development") == "production"
    if not in_production:
        logger.warning(f"[DRY-RUN] stage={stage.name} 返回模拟输出")
    else:
        logger.error(
            f"Pipeline 进入 dry-run 模式stage={stage.name}"
            "生产环境中 TaskDispatcher 未正确初始化，请检查系统配置。"
        )

    # One placeholder value per declared output variable.
    mock_outputs: dict[str, Any] = {
        output_name: f"[dry-run] {output_name} from {stage.agent}.{stage.action}"
        for output_name in stage.outputs
    }

    logger.info(
        f"Stage '{stage.name}' dry-run completed "
        f"(agent={stage.agent}, action={stage.action}, "
        f"inputs={list(resolved_inputs.keys())}, "
        f"outputs={stage.outputs})"
    )
    return StageResult(
        stage_name=stage.name,
        status=StageStatus.COMPLETED,
        outputs=mock_outputs,
    )
def _topological_sort(self, stages: list[PipelineStage]) -> list[PipelineStage]:
    """Order stages so that each one comes after its dependencies.

    Uses Kahn's algorithm. Among the stages that are ready at the same
    time, the one with the smallest name is emitted first, which makes the
    order deterministic. The ready set is a min-heap, so picking the next
    stage costs O(log n) instead of re-sorting a list on every insertion.

    Names in ``depends_on`` that do not match any stage in ``stages`` are
    ignored.

    Args:
        stages: Stage definitions; each exposes ``name`` and ``depends_on``.

    Returns:
        The stages in execution order.

    Raises:
        ValueError: The dependencies contain a cycle.
    """
    import heapq

    name_to_stage = {s.name: s for s in stages}

    # in_degree: number of unmet dependencies per stage.
    # dependents: for each stage, the stages waiting on it.
    in_degree: dict[str, int] = dict.fromkeys(name_to_stage, 0)
    dependents: dict[str, list[str]] = {name: [] for name in name_to_stage}
    for stage in stages:
        for dep in stage.depends_on:
            if dep in name_to_stage:
                dependents[dep].append(stage.name)
                in_degree[stage.name] += 1

    # Min-heap of stages whose dependencies are all satisfied.
    ready: list[str] = [name for name, deg in in_degree.items() if deg == 0]
    heapq.heapify(ready)

    sorted_names: list[str] = []
    while ready:
        node = heapq.heappop(ready)
        sorted_names.append(node)
        for neighbor in dependents[node]:
            in_degree[neighbor] -= 1
            if in_degree[neighbor] == 0:
                heapq.heappush(ready, neighbor)

    # Stages inside a cycle never reach in-degree zero and are left over.
    if len(sorted_names) != len(name_to_stage):
        raise ValueError(
            f"Cyclic dependency detected in pipeline stages: "
            f"only {len(sorted_names)}/{len(name_to_stage)} stages sorted"
        )
    return [name_to_stage[name] for name in sorted_names]
def _resolve_stage_inputs(
    self,
    inputs: dict[str, Any],
    context: dict[str, Any],
) -> dict[str, Any]:
    """Resolve the variable references in a stage's inputs.

    Delegates to ``PipelineLoader.resolve_variables`` so that ``${...}``
    references in ``inputs`` are replaced with values from ``context``.

    Args:
        inputs: Raw input parameters, possibly containing variable
            references.
        context: Full lookup context (global variables plus the stages
            context).

    Returns:
        The resolved input parameters.
    """
    resolved = PipelineLoader.resolve_variables(inputs, context)
    return resolved
def _should_skip(
    self,
    stage: PipelineStage,
    failed_stages: set[str],
    skipped_stages: set[str],
) -> bool:
    """Tell whether ``stage`` must be skipped because of its dependencies.

    A stage is skipped as soon as one of the names in ``stage.depends_on``
    appears in ``failed_stages`` or in ``skipped_stages``.

    Args:
        stage: Stage definition.
        failed_stages: Names of stages recorded as failed.
        skipped_stages: Names of stages recorded as skipped.

    Returns:
        True when the stage should be skipped, False otherwise.
    """
    return any(
        dep in failed_stages or dep in skipped_stages
        for dep in stage.depends_on
    )
def _evaluate_condition(
    self,
    condition: str,
    exec_context: dict[str, Any],
    stages_context: dict[str, Any],
) -> bool:
    """Evaluate a stage's condition expression.

    Supported forms:
        - ``"${var}"``            -> the variable exists and is non-empty
        - ``"${var} == 'value'"`` -> the variable equals the value
        - ``"${var} != 'value'"`` -> the variable differs from the value

    Variables are resolved first through
    ``PipelineLoader.resolve_variables``. A resolved value that is not a
    string (for example a bool, ``None``, a number or a container) is
    judged by its truthiness, so empty lists and dicts count as "not met".
    For strings, the operator that appears first in the text is applied.

    NOTE(review): operators are searched in the *resolved* text, so a
    substituted value that itself contains ``==`` or ``!=`` can still be
    misparsed.

    Args:
        condition: Condition expression string.
        exec_context: Global execution context.
        stages_context: Outputs/status of the stages processed so far.

    Returns:
        True when the condition is met.
    """
    resolve_ctx = dict(exec_context)
    resolve_ctx["stages"] = stages_context

    resolved = PipelineLoader.resolve_variables(condition, resolve_ctx)

    # Non-string results carry their own truth value; stringifying them
    # would turn empty containers ("[]", "{}") into truthy text.
    if not isinstance(resolved, str):
        return bool(resolved)

    resolved_str = resolved.strip()

    # Use whichever comparison operator occurs first, so that a quoted
    # right-hand side containing the other operator is not split on it.
    found = [
        (pos, op)
        for op in ("==", "!=")
        if (pos := resolved_str.find(op)) != -1
    ]
    if found:
        _, op = min(found)
        left, right = resolved_str.split(op, 1)
        same = left.strip().strip("'\"") == right.strip().strip("'\"")
        return same if op == "==" else not same

    # Bare value: empty text and the usual "falsy" spellings are not met.
    return resolved_str.lower() not in ("false", "0", "none", "")
def _extract_outputs(
    self,
    stage: PipelineStage,
    output_data: dict[str, Any] | None,
) -> dict[str, Any]:
    """Pick the stage's declared output variables out of ``output_data``.

    Args:
        stage: Stage definition; ``stage.outputs`` lists the declared names.
        output_data: Complete output data returned by the agent, or None.

    Returns:
        An empty dict when ``output_data`` is empty or None. ``output_data``
        itself when the stage declares no outputs. Otherwise a new dict
        holding only the declared names that are present in ``output_data``.
    """
    if not output_data:
        return {}

    declared = stage.outputs
    if declared:
        return {
            name: output_data[name]
            for name in declared
            if name in output_data
        }

    # Nothing declared: hand back everything the agent produced.
    return output_data