fischer-agentkit/src/agentkit/core/plan_executor.py

502 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""PlanExecutor — 执行计划执行器
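Runs a confirmed ExecutionPlan, automatically scheduling steps that have no
dependencies on each other in parallel, and supports adjusting the plan while
it is running.

Execution flow:
1. Steps are executed group by group, following ``parallel_groups``.
2. Steps inside one group run concurrently via ``asyncio.gather``.
3. Step-level state machine: PENDING → RUNNING → COMPLETED/FAILED.
4. Failure handling: retry / adjust the plan (skip or replace) / request human intervention.
5. AgentPool integration: every step is executed by an Agent obtained through the AgentPool.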
"""
from __future__ import annotations

import asyncio
import inspect
import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Awaitable

from agentkit.core.plan_schema import ExecutionPlan, PlanStep, PlanStepStatus
from agentkit.core.protocol import TaskMessage, TaskResult, TaskStatus
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class FailureAction(str, Enum):
    """Strategy to apply after a plan step has failed.

    The enum mixes in ``str``, so every member compares equal to its
    lowercase string value.
    """

    # Run the step again.
    RETRY = "retry"
    # Give up on the step.
    SKIP = "skip"
    # Treat the step as replaced.
    REPLACE = "replace"
    # Hand the decision over to a human operator.
    REQUEST_HUMAN = "request_human"
    # Stop executing the plan.
    ABORT = "abort"
@dataclass
class StepExecutionResult:
    """Outcome of executing a single plan step."""

    # Identifier of the step this result belongs to.
    step_id: str
    # Final status recorded for the step.
    status: PlanStepStatus
    # Output produced by the step, if any.
    result: dict[str, Any] | None = None
    # Error message, if any.
    error: str | None = None
    # Retry counter reported by the executor.
    retry_count: int = 0
    # Duration in milliseconds.
    duration_ms: float = 0.0
@dataclass
class PlanExecutionResult:
    """Outcome of executing a whole plan."""

    # Identifier of the executed plan.
    plan_id: str
    # Per-step results, keyed by step id.
    step_results: dict[str, StepExecutionResult]
    # Overall status of the run.
    status: TaskStatus
    # Total duration in milliseconds.
    total_duration_ms: float
    # Whether the plan was adjusted during the run.
    adjusted: bool = False
    # Whether human intervention was requested during the run.
    human_intervention_requested: bool = False

    def _select_step_ids(self, matches: Callable[[Any], bool]) -> list[str]:
        """Return, in insertion order, the ids of steps whose status satisfies *matches*."""
        return [
            step_id
            for step_id, outcome in self.step_results.items()
            if matches(outcome.status)
        ]

    @property
    def completed_steps(self) -> list[str]:
        """Ids of the steps whose result status is COMPLETED."""
        return self._select_step_ids(lambda status: status == PlanStepStatus.COMPLETED)

    @property
    def failed_steps(self) -> list[str]:
        """Ids of the steps whose result status is FAILED."""
        return self._select_step_ids(lambda status: status == PlanStepStatus.FAILED)

    @property
    def skipped_steps(self) -> list[str]:
        """Ids of the steps whose result status is SKIPPED."""
        return self._select_step_ids(lambda status: status == PlanStepStatus.SKIPPED)
# Callback types.
# Called after a step has completed; the executor awaits the returned value.
OnStepCompleteCallback = Callable[[PlanStep, StepExecutionResult], Awaitable[None]]
# Called after a step has failed and decides the follow-up action. The executor
# consumes the returned value with ``await``, so the callback is typed as
# returning an awaitable (i.e. an ``async def`` callable).
OnStepFailedCallback = Callable[[PlanStep, StepExecutionResult], Awaitable[FailureAction]]
# Called when a failure is escalated to a human; the awaited value is the
# action chosen by the operator.
OnHumanInterventionCallback = Callable[[PlanStep, StepExecutionResult], Awaitable[FailureAction]]
class PlanExecutor:
    """Executor for a confirmed ``ExecutionPlan``.

    Runs the plan group by group (``plan.parallel_groups``), executing the
    steps of one group concurrently, and supports retry on failure, plan
    adjustment and human intervention.

    Usage:
        executor = PlanExecutor(agent_pool=pool)
        result = await executor.execute(plan, original_task)
    """

    def __init__(
        self,
        agent_pool: Any,
        max_retries: int = 2,
        step_timeout: float = 300.0,
        max_parallel: int = 5,
        on_step_complete: OnStepCompleteCallback | None = None,
        on_step_failed: OnStepFailedCallback | None = None,
        on_human_intervention: OnHumanInterventionCallback | None = None,
    ) -> None:
        """Store the executor configuration.

        Args:
            agent_pool: AgentPool instance used to obtain agents.
            max_retries: Maximum number of retries after a step fails.
                Negative values are treated as 0 (a single attempt).
            step_timeout: Timeout of a single step attempt, in seconds.
            max_parallel: Maximum number of steps running at the same time.
                A value <= 0 disables the limit.
            on_step_complete: Callback invoked when a step completes.
            on_step_failed: Callback invoked when a step fails; its result
                (a FailureAction) decides the follow-up handling.
            on_human_intervention: Callback invoked when a failure is
                escalated to a human; returns the chosen FailureAction.

        All callbacks may be plain functions or coroutine functions.
        """
        self._agent_pool = agent_pool
        self._max_retries = max_retries
        self._step_timeout = step_timeout
        self._max_parallel = max_parallel
        self._on_step_complete = on_step_complete
        self._on_step_failed = on_step_failed
        self._on_human_intervention = on_human_intervention

    @staticmethod
    async def _maybe_await(value: Any) -> Any:
        """Return *value*, awaiting it first when it is awaitable.

        Lets callbacks be written either as plain functions or as coroutines.
        """
        if inspect.isawaitable(value):
            return await value
        return value

    async def execute(
        self,
        plan: ExecutionPlan,
        original_task: TaskMessage,
    ) -> PlanExecutionResult:
        """Execute a confirmed ExecutionPlan.

        Args:
            plan: The confirmed execution plan.
            original_task: The original task message.

        Returns:
            PlanExecutionResult: per-step results, overall status, total
            duration and whether the plan was adjusted / escalated to a human.
        """
        start_time = time.monotonic()
        step_results: dict[str, StepExecutionResult] = {}
        plan_adjusted = False
        human_intervention_requested = False

        # Index the steps by id.
        step_map = {s.step_id: s for s in plan.steps}

        # Enforce max_parallel; a non-positive value means "no limit"
        # (a Semaphore(0) would block every step forever).
        limiter: asyncio.Semaphore | None = None
        if self._max_parallel and self._max_parallel > 0:
            limiter = asyncio.Semaphore(self._max_parallel)

        # Execute group by group, in the order given by parallel_groups.
        for group in plan.parallel_groups:
            # Only PENDING steps are run: earlier failure handling may have
            # skipped some of them already.
            active_step_ids = [
                sid for sid in group
                if sid in step_map and step_map[sid].status == PlanStepStatus.PENDING
            ]
            if not active_step_ids:
                continue

            # Inject the results of the dependencies into every step's input.
            coros = []
            for step_id in active_step_ids:
                step = step_map[step_id]
                enriched_input = self._inject_dependency_results(step, step_results)
                coros.append(self._run_throttled(limiter, step, enriched_input, original_task))

            # Run the current group concurrently.
            results = await asyncio.gather(*coros, return_exceptions=True)

            for step_id, result in zip(active_step_ids, results):
                step = step_map[step_id]
                # BaseException (not just Exception): gather can also hand
                # back a CancelledError raised inside a child.
                if isinstance(result, BaseException):
                    message = str(result) or type(result).__name__
                    step.status = PlanStepStatus.FAILED
                    step.error = message
                    result = StepExecutionResult(
                        step_id=step_id,
                        status=PlanStepStatus.FAILED,
                        error=message,
                    )
                step_results[step_id] = result

                if result.status != PlanStepStatus.FAILED:
                    continue

                # Handle the failed step.
                action_taken = await self._handle_step_failure(
                    step, result, step_map, step_results, plan,
                )
                if action_taken in ("adjusted", "human_adjusted"):
                    plan_adjusted = True
                if action_taken in ("human", "human_adjusted"):
                    human_intervention_requested = True

        total_duration_ms = (time.monotonic() - start_time) * 1000
        status = self._determine_overall_status(plan, step_results)

        return PlanExecutionResult(
            plan_id=plan.plan_id,
            step_results=step_results,
            status=status,
            total_duration_ms=total_duration_ms,
            adjusted=plan_adjusted,
            human_intervention_requested=human_intervention_requested,
        )

    async def _run_throttled(
        self,
        limiter: asyncio.Semaphore | None,
        step: PlanStep,
        input_data: dict[str, Any],
        original_task: TaskMessage,
    ) -> StepExecutionResult:
        """Run one step, holding *limiter* (if any) for the whole execution.

        Waiting for a free slot happens outside the per-attempt timeout, so
        queued steps do not time out before they start.
        """
        if limiter is None:
            return await self._execute_step_with_retry(step, input_data, original_task)
        async with limiter:
            return await self._execute_step_with_retry(step, input_data, original_task)

    async def _execute_step_with_retry(
        self,
        step: PlanStep,
        input_data: dict[str, Any],
        original_task: TaskMessage,
    ) -> StepExecutionResult:
        """Execute a single step, retrying on failure.

        Args:
            step: The plan step.
            input_data: Input data with dependency results already injected.
            original_task: The original task message.

        Returns:
            StepExecutionResult: COMPLETED on the first successful attempt,
            FAILED once every attempt has failed or timed out.
        """
        step.status = PlanStepStatus.RUNNING
        last_error: str | None = None
        # One initial attempt plus the configured retries; a negative
        # max_retries still gets a single attempt.
        attempts = max(0, self._max_retries) + 1

        for attempt in range(attempts):
            started = time.monotonic()
            try:
                output = await asyncio.wait_for(
                    self._execute_step_once(step, input_data, original_task),
                    timeout=self._step_timeout,
                )
            except asyncio.TimeoutError:
                last_error = f"Step '{step.step_id}' timed out after {self._step_timeout}s"
                logger.warning(last_error)
                continue
            except Exception as exc:
                last_error = str(exc)
                logger.warning(
                    "Step '%s' failed (attempt %d): %s", step.step_id, attempt + 1, exc
                )
                continue

            duration_ms = (time.monotonic() - started) * 1000
            step.status = PlanStepStatus.COMPLETED
            exec_result = StepExecutionResult(
                step_id=step.step_id,
                status=PlanStepStatus.COMPLETED,
                result=output,
                retry_count=attempt,
                duration_ms=duration_ms,
            )
            # The callback runs outside the retry try-block: a failing
            # callback must not re-execute a step that already succeeded.
            await self._notify_step_complete(step, exec_result)
            return exec_result

        # Every attempt failed.
        step.status = PlanStepStatus.FAILED
        step.error = last_error
        return StepExecutionResult(
            step_id=step.step_id,
            status=PlanStepStatus.FAILED,
            error=last_error,
            retry_count=attempts - 1,
            duration_ms=0.0,
        )

    async def _notify_step_complete(
        self,
        step: PlanStep,
        exec_result: StepExecutionResult,
    ) -> None:
        """Invoke the completion callback; log (never propagate) its errors."""
        if not self._on_step_complete:
            return
        try:
            await self._maybe_await(self._on_step_complete(step, exec_result))
        except Exception:
            logger.exception(
                "on_step_complete callback raised for step '%s'", step.step_id
            )

    async def _execute_step_once(
        self,
        step: PlanStep,
        input_data: dict[str, Any],
        original_task: TaskMessage,
    ) -> dict[str, Any]:
        """Execute a single step once, using an agent from the AgentPool.

        Args:
            step: The plan step.
            input_data: Input data for the step.
            original_task: The original task message; its task type, priority
                and creation time are copied into the step's task message.

        Returns:
            The step's output as a dict.

        Raises:
            RuntimeError: No agent is available for the step, or the agent
                returned a TaskResult with status FAILED.
        """
        # First try to create an agent from one of the required skills.
        agent = None
        for skill_name in step.required_skills:
            try:
                agent = await self._agent_pool.create_agent_from_skill(skill_name)
                break
            except Exception as exc:
                logger.debug(
                    "Failed to create agent from skill '%s': %s", skill_name, exc
                )

        # Otherwise fall back to an agent already in the pool, looked up by
        # the step id and then by the first required skill.
        if agent is None:
            agent = self._agent_pool.get_agent(step.step_id)
        if agent is None and step.required_skills:
            agent = self._agent_pool.get_agent(step.required_skills[0])
        if agent is None:
            raise RuntimeError(
                f"No agent available for step '{step.step_id}' "
                f"(required_skills: {step.required_skills})"
            )

        # Build the TaskMessage for this step.
        task_msg = TaskMessage(
            task_id=step.step_id,
            agent_name=getattr(agent, "name", step.step_id),
            task_type=original_task.task_type,
            priority=original_task.priority,
            input_data=input_data,
            callback_url=None,
            created_at=original_task.created_at,
            timeout_seconds=int(self._step_timeout),
        )
        result = await agent.execute(task_msg)

        if isinstance(result, TaskResult):
            if result.status == TaskStatus.FAILED:
                raise RuntimeError(result.error_message or "Agent execution failed")
            return result.output_data or {}
        # Anything that is not a TaskResult is normalised to a dict.
        return result if isinstance(result, dict) else {"output": result}

    async def _handle_step_failure(
        self,
        step: PlanStep,
        exec_result: StepExecutionResult,
        step_map: dict[str, PlanStep],
        step_results: dict[str, StepExecutionResult],
        plan: ExecutionPlan,
    ) -> str:
        """Decide on and apply the follow-up action for a failed step.

        The action comes from the ``on_step_failed`` callback when one is
        configured, otherwise from ``_default_failure_action``.

        Args:
            step: The failed step.
            exec_result: Its execution result.
            step_map: Mapping of step id to step.
            step_results: Results of all steps so far (updated in place).
            plan: The execution plan.

        Returns:
            "none" (nothing changed), "adjusted" (plan changed), "human"
            (escalated to a human, plan unchanged) or "human_adjusted"
            (escalated and the human's decision changed the plan).
        """
        if self._on_step_failed:
            action = await self._maybe_await(self._on_step_failed(step, exec_result))
        else:
            action = self._default_failure_action(step, exec_result)

        if action == FailureAction.REQUEST_HUMAN:
            if not self._on_human_intervention:
                return "human"
            human_action = await self._maybe_await(
                self._on_human_intervention(step, exec_result)
            )
            if self._apply_adjustment(human_action, step, exec_result, step_map, step_results, plan):
                return "human_adjusted"
            # RETRY (or anything else) from the human changes nothing here.
            return "human"

        if self._apply_adjustment(action, step, exec_result, step_map, step_results, plan):
            return "adjusted"
        # RETRY: retries were already performed in _execute_step_with_retry,
        # so the step simply stays FAILED.
        return "none"

    def _apply_adjustment(
        self,
        action: Any,
        step: PlanStep,
        exec_result: StepExecutionResult,
        step_map: dict[str, PlanStep],
        step_results: dict[str, StepExecutionResult],
        plan: ExecutionPlan,
    ) -> bool:
        """Apply SKIP / REPLACE / ABORT to the plan.

        Returns:
            True when *action* was one of those three and the plan was
            changed, False otherwise.
        """
        if action not in (FailureAction.SKIP, FailureAction.REPLACE, FailureAction.ABORT):
            return False

        # All three actions mark the failed step itself as SKIPPED.
        step.status = PlanStepStatus.SKIPPED
        exec_result.status = PlanStepStatus.SKIPPED

        if action == FailureAction.SKIP:
            # Also skip every step that depends on this one.
            self._skip_dependent_steps(step.step_id, step_map, step_results, plan)
        elif action == FailureAction.ABORT:
            # Stop everything that has not run yet.
            self._abort_remaining_steps(step_map, step_results, plan)
        # REPLACE: dependents are left untouched and still run.
        return True

    def _default_failure_action(self, step: PlanStep, exec_result: StepExecutionResult) -> FailureAction:
        """Default failure strategy: always SKIP.

        Timeouts and ordinary errors have already been retried by
        ``_execute_step_with_retry`` by the time this is consulted, and a
        missing agent cannot be fixed by retrying, so every case ends in SKIP.
        """
        return FailureAction.SKIP

    def _skip_dependent_steps(
        self,
        failed_step_id: str,
        step_map: dict[str, PlanStep],
        step_results: dict[str, StepExecutionResult],
        plan: ExecutionPlan,
    ) -> None:
        """Skip, transitively, every PENDING step that depends on *failed_step_id*."""
        for step in plan.steps:
            if step.status != PlanStepStatus.PENDING:
                continue
            if failed_step_id not in (step.dependencies or ()):
                continue
            step.status = PlanStepStatus.SKIPPED
            step_results[step.step_id] = StepExecutionResult(
                step_id=step.step_id,
                status=PlanStepStatus.SKIPPED,
                error=f"Skipped due to failed dependency '{failed_step_id}'",
            )
            # Recurse: steps depending on the one just skipped are skipped too.
            # The PENDING check above keeps dependency cycles from looping.
            self._skip_dependent_steps(step.step_id, step_map, step_results, plan)

    def _abort_remaining_steps(
        self,
        step_map: dict[str, PlanStep],
        step_results: dict[str, StepExecutionResult],
        plan: ExecutionPlan,
    ) -> None:
        """Mark every step that is still PENDING as SKIPPED."""
        for step in plan.steps:
            if step.status != PlanStepStatus.PENDING:
                continue
            step.status = PlanStepStatus.SKIPPED
            step_results[step.step_id] = StepExecutionResult(
                step_id=step.step_id,
                status=PlanStepStatus.SKIPPED,
                error="Aborted due to previous step failure",
            )

    def _inject_dependency_results(
        self,
        step: PlanStep,
        step_results: dict[str, StepExecutionResult],
    ) -> dict[str, Any]:
        """Build a step's input, enriched with the results of its dependencies.

        Returns a shallow copy of ``step.input_data`` (an empty dict when it
        is None) plus:
        - ``dependency_results``: status / result / error of every dependency
          that already has a result (only present when there is at least one);
        - ``step_name`` and ``step_description``: metadata of the step.
        """
        enriched = dict(step.input_data or {})

        dep_results: dict[str, dict[str, Any]] = {
            dep_id: {
                "status": step_results[dep_id].status.value,
                "result": step_results[dep_id].result,
                "error": step_results[dep_id].error,
            }
            for dep_id in (step.dependencies or ())
            if dep_id in step_results
        }
        if dep_results:
            enriched["dependency_results"] = dep_results

        # Step metadata.
        enriched["step_name"] = step.name
        enriched["step_description"] = step.description
        return enriched

    def _determine_overall_status(
        self,
        plan: ExecutionPlan,
        step_results: dict[str, StepExecutionResult],
    ) -> TaskStatus:
        """Derive the overall status from the per-step results."""
        total = len(plan.steps)
        if total == 0:
            return TaskStatus.COMPLETED

        statuses = [r.status for r in step_results.values()]
        completed = sum(1 for s in statuses if s == PlanStepStatus.COMPLETED)
        failed = sum(1 for s in statuses if s == PlanStepStatus.FAILED)
        skipped = sum(1 for s in statuses if s == PlanStepStatus.SKIPPED)

        if completed == total:
            return TaskStatus.COMPLETED
        if failed == total:
            return TaskStatus.FAILED
        if completed + skipped == total:
            # Every step was either completed or skipped.
            # NOTE(review): a plan whose steps were all skipped after failing
            # also lands here and reports COMPLETED; confirm this is intended.
            return TaskStatus.COMPLETED
        if failed > 0:
            return TaskStatus.PARTIALLY_COMPLETED  # partial success
        return TaskStatus.COMPLETED