502 lines
18 KiB
Python
502 lines
18 KiB
Python
"""PlanExecutor — 执行计划执行器

按确认后的 ExecutionPlan 执行,自动并行调度无依赖步骤,支持执行中调整。

执行流程:
1. 按 parallel_groups 分组执行步骤
2. 每组内使用 asyncio.gather 并行执行
3. 步骤级状态机:PENDING → RUNNING → COMPLETED/FAILED
4. 失败处理:重试 / 调整计划(跳过/替换)/ 请求人工介入
5. 与 AgentPool 集成:每个步骤通过 AgentPool 创建 Agent 执行
"""
|
||
|
||
from __future__ import annotations

import asyncio
import inspect
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Awaitable

from agentkit.core.plan_schema import ExecutionPlan, PlanStep, PlanStepStatus
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class FailureAction(str, Enum):
    """Strategy for handling a step that is still failing after its retries.

    Produced by the ``on_step_failed`` / ``on_human_intervention`` callbacks or
    by ``PlanExecutor._default_failure_action`` and interpreted by
    ``PlanExecutor._handle_step_failure``. Being a ``str`` enum, members compare
    equal to their plain string values.
    """

    # Retries already happened inside _execute_step_with_retry; choosing RETRY
    # at this stage leaves the step marked FAILED.
    RETRY = "retry"
    # Mark the step SKIPPED and also skip every step that depends on it.
    SKIP = "skip"
    # Mark the step SKIPPED; its dependents are not skipped.
    REPLACE = "replace"
    # Delegate the decision to the on_human_intervention callback, if configured.
    REQUEST_HUMAN = "request_human"
    # Mark the step SKIPPED and skip every step that is still PENDING.
    ABORT = "abort"
|
||
|
||
|
||
@dataclass
class StepExecutionResult:
    """Outcome of executing (or skipping) a single plan step."""

    # ID of the PlanStep this result belongs to.
    step_id: str
    # Final status of the step (this module sets COMPLETED, FAILED or SKIPPED).
    status: PlanStepStatus
    # Output dict produced by the agent on success; None otherwise.
    result: dict[str, Any] | None = None
    # Error / skip reason; None on success.
    error: str | None = None
    # Number of retries made after the first attempt.
    retry_count: int = 0
    # Elapsed wall-clock time in milliseconds, as recorded by the executor.
    duration_ms: float = 0.0
|
||
|
||
|
||
@dataclass
class PlanExecutionResult:
    """Aggregate outcome of running a whole ExecutionPlan.

    Attributes:
        plan_id: ID of the executed plan.
        step_results: Per-step results keyed by step ID.
        status: Overall task status derived from the step results.
        total_duration_ms: Wall-clock duration of the whole run in milliseconds.
        adjusted: True when failure handling changed the plan (skip/replace/abort).
        human_intervention_requested: True when a failure was escalated to a human.
    """

    plan_id: str
    step_results: dict[str, StepExecutionResult]
    status: TaskStatus
    total_duration_ms: float
    adjusted: bool = False
    human_intervention_requested: bool = False

    def _ids_with_status(self, wanted: PlanStepStatus) -> list[str]:
        """Return the IDs of steps whose result status equals *wanted*, in insertion order."""
        matching: list[str] = []
        for step_id, outcome in self.step_results.items():
            if outcome.status == wanted:
                matching.append(step_id)
        return matching

    @property
    def completed_steps(self) -> list[str]:
        """IDs of steps that finished successfully."""
        return self._ids_with_status(PlanStepStatus.COMPLETED)

    @property
    def failed_steps(self) -> list[str]:
        """IDs of steps that are still marked as failed."""
        return self._ids_with_status(PlanStepStatus.FAILED)

    @property
    def skipped_steps(self) -> list[str]:
        """IDs of steps that were skipped (by failure handling or abort)."""
        return self._ids_with_status(PlanStepStatus.SKIPPED)
|
||
|
||
|
||
# Callback type aliases used by PlanExecutor.

# Called after a step completes successfully; the executor awaits its result.
OnStepCompleteCallback = Callable[[PlanStep, StepExecutionResult], Awaitable[None]]

# Called once a step has exhausted its retries. The executor awaits the value
# it returns, so it must produce an awaitable resolving to a FailureAction
# (previously declared as returning FailureAction synchronously, which did not
# match how the executor invokes it).
OnStepFailedCallback = Callable[[PlanStep, StepExecutionResult], Awaitable[FailureAction]]

# Called for FailureAction.REQUEST_HUMAN; resolves to the action the human chose.
OnHumanInterventionCallback = Callable[[PlanStep, StepExecutionResult], Awaitable[FailureAction]]
|
||
|
||
|
||
class PlanExecutor:
    """Executor for confirmed execution plans.

    Runs a confirmed ``ExecutionPlan`` group by group, following
    ``plan.parallel_groups``. Steps inside one group run concurrently, with at
    most ``max_parallel`` of them active at a time. Each step gets up to
    ``max_retries`` extra attempts. After that, a ``FailureAction`` decides
    whether to skip the step (and its dependents), keep going, abort the
    remaining steps, or escalate to a human.

    Usage:
        executor = PlanExecutor(agent_pool=pool)
        result = await executor.execute(plan, original_task)
    """

    def __init__(
        self,
        agent_pool: Any,
        max_retries: int = 2,
        step_timeout: float = 300.0,
        max_parallel: int = 5,
        on_step_complete: OnStepCompleteCallback | None = None,
        on_step_failed: OnStepFailedCallback | None = None,
        on_human_intervention: OnHumanInterventionCallback | None = None,
    ):
        """
        Args:
            agent_pool: AgentPool instance used to obtain an agent for each step.
            max_retries: Extra attempts after a step's first failure
                (negative values behave like 0).
            step_timeout: Timeout of a single attempt, in seconds.
            max_parallel: Maximum number of steps of one group running at the
                same time (values below 1 behave like 1).
            on_step_complete: Callback invoked after a step succeeds. Errors it
                raises are logged and never fail or re-run the step.
            on_step_failed: Callback invoked after a step exhausts its retries.
                It returns (sync) or resolves to (async) the FailureAction to apply.
            on_human_intervention: Callback invoked for
                FailureAction.REQUEST_HUMAN. It returns or resolves to the
                FailureAction chosen by the human.
        """
        self._agent_pool = agent_pool
        self._max_retries = max_retries
        self._step_timeout = step_timeout
        self._max_parallel = max_parallel
        self._on_step_complete = on_step_complete
        self._on_step_failed = on_step_failed
        self._on_human_intervention = on_human_intervention

    async def execute(
        self,
        plan: ExecutionPlan,
        original_task: TaskMessage,
    ) -> PlanExecutionResult:
        """Execute a confirmed ExecutionPlan.

        Args:
            plan: The confirmed execution plan. Step statuses are mutated in place.
            original_task: The task message the plan was derived from. Its
                type, priority and creation time are copied into each step's task.

        Returns:
            PlanExecutionResult: Per-step results plus the overall status.
        """
        start_time = time.monotonic()
        step_results: dict[str, StepExecutionResult] = {}
        plan_adjusted = False
        human_intervention_requested = False

        step_map = {s.step_id: s for s in plan.steps}
        # One semaphore per run so that max_parallel actually caps concurrency.
        semaphore = asyncio.Semaphore(max(1, self._max_parallel))

        for group in plan.parallel_groups:
            # Only PENDING steps run: failure handling in earlier groups may
            # already have marked some of them SKIPPED.
            active_step_ids = [
                sid for sid in group
                if sid in step_map and step_map[sid].status == PlanStepStatus.PENDING
            ]
            if not active_step_ids:
                continue

            # Inject dependency outputs (from earlier groups) into each step's input.
            coros = [
                self._run_step_bounded(
                    semaphore,
                    step_map[step_id],
                    self._inject_dependency_results(step_map[step_id], step_results),
                    original_task,
                )
                for step_id in active_step_ids
            ]
            results = await asyncio.gather(*coros, return_exceptions=True)

            for step_id, result in zip(active_step_ids, results):
                step = step_map[step_id]
                # BaseException (not Exception): a child CancelledError is also
                # returned here and must not be stored as if it were a result.
                if isinstance(result, BaseException):
                    step.status = PlanStepStatus.FAILED
                    step.error = str(result)
                    step_results[step_id] = StepExecutionResult(
                        step_id=step_id,
                        status=PlanStepStatus.FAILED,
                        error=str(result),
                    )
                else:
                    step_results[step_id] = result

                if step_results[step_id].status != PlanStepStatus.FAILED:
                    continue

                action_taken = await self._handle_step_failure(
                    step, step_results[step_id], step_map, step_results, plan,
                )
                if action_taken in ("adjusted", "human_adjusted"):
                    plan_adjusted = True
                if action_taken in ("human", "human_adjusted"):
                    human_intervention_requested = True

        total_duration_ms = (time.monotonic() - start_time) * 1000
        status = self._determine_overall_status(plan, step_results)

        return PlanExecutionResult(
            plan_id=plan.plan_id,
            step_results=step_results,
            status=status,
            total_duration_ms=total_duration_ms,
            adjusted=plan_adjusted,
            human_intervention_requested=human_intervention_requested,
        )

    async def _run_step_bounded(
        self,
        semaphore: asyncio.Semaphore,
        step: PlanStep,
        input_data: dict[str, Any],
        original_task: TaskMessage,
    ) -> StepExecutionResult:
        """Run one step (with retries) while holding a slot of *semaphore*."""
        async with semaphore:
            return await self._execute_step_with_retry(step, input_data, original_task)

    async def _execute_step_with_retry(
        self,
        step: PlanStep,
        input_data: dict[str, Any],
        original_task: TaskMessage,
    ) -> StepExecutionResult:
        """Execute a single step, retrying on error or timeout.

        Args:
            step: The plan step. Its status/error attributes are updated in place.
            input_data: Step input with dependency results already injected.
            original_task: The original task message.

        Returns:
            StepExecutionResult: COMPLETED with the agent output, or FAILED with
            the last error once all ``max_retries + 1`` attempts have failed.
        """
        step.status = PlanStepStatus.RUNNING
        attempts = max(0, self._max_retries) + 1
        last_error: str | None = None
        overall_start = time.monotonic()

        for attempt in range(attempts):
            start = time.monotonic()
            try:
                result = await asyncio.wait_for(
                    self._execute_step_once(step, input_data, original_task),
                    timeout=self._step_timeout,
                )
            except asyncio.TimeoutError:
                last_error = f"Step '{step.step_id}' timed out after {self._step_timeout}s"
                logger.warning(last_error)
                continue
            except Exception as e:
                last_error = str(e)
                logger.warning(f"Step '{step.step_id}' failed (attempt {attempt + 1}): {e}")
                continue

            duration_ms = (time.monotonic() - start) * 1000
            step.status = PlanStepStatus.COMPLETED
            exec_result = StepExecutionResult(
                step_id=step.step_id,
                status=PlanStepStatus.COMPLETED,
                result=result,
                retry_count=attempt,
                duration_ms=duration_ms,
            )
            # Runs outside the try above, so a faulty callback can no longer
            # cause an already-successful step to be executed again.
            await self._notify_step_complete(step, exec_result)
            return exec_result

        # All attempts exhausted.
        step.status = PlanStepStatus.FAILED
        step.error = last_error
        return StepExecutionResult(
            step_id=step.step_id,
            status=PlanStepStatus.FAILED,
            error=last_error,
            retry_count=attempts - 1,
            duration_ms=(time.monotonic() - overall_start) * 1000,
        )

    async def _notify_step_complete(self, step: PlanStep, exec_result: StepExecutionResult) -> None:
        """Invoke on_step_complete (sync or async), logging instead of propagating its errors."""
        if self._on_step_complete is None:
            return
        try:
            await self._maybe_await(self._on_step_complete(step, exec_result))
        except Exception:
            logger.exception("on_step_complete callback failed for step '%s'", step.step_id)

    async def _execute_step_once(
        self,
        step: PlanStep,
        input_data: dict[str, Any],
        original_task: TaskMessage,
    ) -> dict[str, Any]:
        """Execute a single step once through an agent from the AgentPool.

        Agent lookup order:
        1. ``create_agent_from_skill`` for each required skill, stopping at the
           first one that does not raise.
        2. ``get_agent(step.step_id)``.
        3. ``get_agent(step.required_skills[0])``.

        Args:
            step: The plan step.
            input_data: Input data for the agent.
            original_task: The original task message.

        Returns:
            The agent's output dict. A non-dict, non-TaskResult return value is
            wrapped as ``{"output": value}``.

        Raises:
            RuntimeError: If no agent is available, or the agent returns a
                FAILED TaskResult.
        """
        agent = None
        for skill_name in step.required_skills:
            try:
                agent = await self._agent_pool.create_agent_from_skill(skill_name)
                break
            except Exception as e:
                logger.debug(f"Failed to create agent from skill '{skill_name}': {e}")

        # Skill-based creation failed: fall back to agents already in the pool.
        if agent is None:
            agent = self._agent_pool.get_agent(step.step_id)
            if agent is None and step.required_skills:
                agent = self._agent_pool.get_agent(step.required_skills[0])

        if agent is None:
            raise RuntimeError(
                f"No agent available for step '{step.step_id}' "
                f"(required_skills: {step.required_skills})"
            )

        task_msg = TaskMessage(
            task_id=step.step_id,
            agent_name=getattr(agent, "name", step.step_id),
            task_type=original_task.task_type,
            priority=original_task.priority,
            input_data=input_data,
            callback_url=None,
            created_at=original_task.created_at,
            timeout_seconds=int(self._step_timeout),
        )

        result = await agent.execute(task_msg)

        if isinstance(result, TaskResult):
            if result.status == TaskStatus.FAILED:
                raise RuntimeError(result.error_message or "Agent execution failed")
            return result.output_data or {}

        return result if isinstance(result, dict) else {"output": result}

    @staticmethod
    async def _maybe_await(value: Any) -> Any:
        """Return *value*, awaiting it first if it is awaitable (supports sync and async callbacks)."""
        if inspect.isawaitable(value):
            return await value
        return value

    async def _handle_step_failure(
        self,
        step: PlanStep,
        exec_result: StepExecutionResult,
        step_map: dict[str, PlanStep],
        step_results: dict[str, StepExecutionResult],
        plan: ExecutionPlan,
    ) -> str:
        """Decide and apply how to proceed after a step exhausted its retries.

        The action comes from ``on_step_failed`` when configured, otherwise from
        ``_default_failure_action``.

        Args:
            step: The failed step.
            exec_result: Its execution result. The status may be rewritten to SKIPPED.
            step_map: Step ID -> step mapping.
            step_results: Results of all steps so far. Skipped steps are added here.
            plan: The execution plan.

        Returns:
            One of:
            - "none": the step stays FAILED.
            - "adjusted": SKIP/REPLACE/ABORT was applied.
            - "human": escalated to a human, but no adjustment was made.
            - "human_adjusted": a human-chosen SKIP/REPLACE/ABORT was applied.
        """
        if self._on_step_failed is not None:
            action = await self._maybe_await(self._on_step_failed(step, exec_result))
        else:
            action = self._default_failure_action(step, exec_result)

        if action == FailureAction.REQUEST_HUMAN:
            if self._on_human_intervention is not None:
                human_action = await self._maybe_await(
                    self._on_human_intervention(step, exec_result)
                )
                if self._apply_adjustment(
                    human_action, step, exec_result, step_map, step_results, plan,
                ):
                    return "human_adjusted"
            # NOTE(review): a human RETRY is not re-executed here; the step stays FAILED.
            return "human"

        if self._apply_adjustment(action, step, exec_result, step_map, step_results, plan):
            return "adjusted"

        # RETRY (retries already happened in _execute_step_with_retry) or an
        # unknown action: leave the step FAILED.
        return "none"

    def _apply_adjustment(
        self,
        action: Any,
        step: PlanStep,
        exec_result: StepExecutionResult,
        step_map: dict[str, PlanStep],
        step_results: dict[str, StepExecutionResult],
        plan: ExecutionPlan,
    ) -> bool:
        """Apply SKIP / REPLACE / ABORT to the plan.

        Returns:
            True if *action* was one of those three and has been applied.
        """
        if action not in (FailureAction.SKIP, FailureAction.REPLACE, FailureAction.ABORT):
            return False

        # All three adjustments mark the failed step itself SKIPPED.
        step.status = PlanStepStatus.SKIPPED
        exec_result.status = PlanStepStatus.SKIPPED

        if action == FailureAction.SKIP:
            # Steps that depend on the skipped one cannot run either.
            self._skip_dependent_steps(step.step_id, step_map, step_results, plan)
        elif action == FailureAction.ABORT:
            self._abort_remaining_steps(step_map, step_results, plan)
        # REPLACE: dependents still run. They see the SKIPPED status in dependency_results.
        return True

    def _default_failure_action(self, step: PlanStep, exec_result: StepExecutionResult) -> FailureAction:
        """Default failure policy when no on_step_failed callback is configured.

        Timeouts, missing agents and all other errors were already retried in
        ``_execute_step_with_retry``. Every remaining failure is therefore
        handled by skipping the step (and its dependents).
        """
        return FailureAction.SKIP

    def _skip_dependent_steps(
        self,
        failed_step_id: str,
        step_map: dict[str, PlanStep],
        step_results: dict[str, StepExecutionResult],
        plan: ExecutionPlan,
    ) -> None:
        """Recursively mark PENDING steps that depend on *failed_step_id* as SKIPPED."""
        for step in plan.steps:
            if failed_step_id in step.dependencies and step.status == PlanStepStatus.PENDING:
                step.status = PlanStepStatus.SKIPPED
                step_results[step.step_id] = StepExecutionResult(
                    step_id=step.step_id,
                    status=PlanStepStatus.SKIPPED,
                    error=f"Skipped due to failed dependency '{failed_step_id}'",
                )
                # Transitively skip everything downstream of this step.
                self._skip_dependent_steps(step.step_id, step_map, step_results, plan)

    def _abort_remaining_steps(
        self,
        step_map: dict[str, PlanStep],
        step_results: dict[str, StepExecutionResult],
        plan: ExecutionPlan,
    ) -> None:
        """Mark every still-PENDING step as SKIPPED (plan aborted)."""
        for step in plan.steps:
            if step.status == PlanStepStatus.PENDING:
                step.status = PlanStepStatus.SKIPPED
                step_results[step.step_id] = StepExecutionResult(
                    step_id=step.step_id,
                    status=PlanStepStatus.SKIPPED,
                    error="Aborted due to previous step failure",
                )

    def _inject_dependency_results(
        self,
        step: PlanStep,
        step_results: dict[str, StepExecutionResult],
    ) -> dict[str, Any]:
        """Build a step's input by merging in its dependencies' results.

        Matches the Orchestrator's accumulated ``subtask_results`` pattern:
        - Results of dependencies that already have a result are added under
          ``"dependency_results"``.
        - The step's name and description are added as metadata.
        """
        # Copy so the step's own input_data is never mutated; tolerate None.
        enriched = dict(step.input_data or {})

        if step.dependencies:
            dep_results: dict[str, dict[str, Any]] = {}
            for dep_id in step.dependencies:
                if dep_id in step_results:
                    dep_result = step_results[dep_id]
                    dep_results[dep_id] = {
                        "status": dep_result.status.value,
                        "result": dep_result.result,
                        "error": dep_result.error,
                    }
            if dep_results:
                enriched["dependency_results"] = dep_results

        enriched["step_name"] = step.name
        enriched["step_description"] = step.description

        return enriched

    def _determine_overall_status(
        self,
        plan: ExecutionPlan,
        step_results: dict[str, StepExecutionResult],
    ) -> TaskStatus:
        """Derive the overall task status from the final status of every plan step.

        - COMPLETED: every step completed; or none failed and the rest were
          deliberately skipped.
        - FAILED: no step completed at all (everything failed, was skipped,
          or never ran).
        - PARTIALLY_COMPLETED: anything else (some steps completed, others
          failed or never ran).
        """
        total = len(plan.steps)
        if total == 0:
            return TaskStatus.COMPLETED

        # Steps without a result (e.g. listed in no parallel group, or already
        # non-PENDING before this run) fall back to the status on the step.
        statuses = [
            step_results[s.step_id].status if s.step_id in step_results else s.status
            for s in plan.steps
        ]
        completed = statuses.count(PlanStepStatus.COMPLETED)
        failed = statuses.count(PlanStepStatus.FAILED)
        skipped = statuses.count(PlanStepStatus.SKIPPED)

        if completed == total:
            return TaskStatus.COMPLETED
        if completed == 0:
            return TaskStatus.FAILED
        if failed == 0 and completed + skipped == total:
            # Skipped steps were a deliberate adjustment; do not fail the plan.
            return TaskStatus.COMPLETED
        return TaskStatus.PARTIALLY_COMPLETED
|