371 lines
13 KiB
Python
371 lines
13 KiB
Python
"""Pipeline reflection and replanning module.

When a pipeline execution fails, an LLM reflects on the failure to analyze its
cause, and a corrected pipeline is generated and executed again.
"""
|
||
|
||
import json
|
||
import logging
|
||
from typing import Any
|
||
|
||
from agentkit.orchestrator.pipeline_schema import (
|
||
Pipeline,
|
||
PipelineResult,
|
||
PipelineStage,
|
||
ReflectionReport,
|
||
StageResult,
|
||
StageStatus,
|
||
)
|
||
|
||
# Module-level logger; the reflector and replanner use it to report LLM
# failures before falling back to their rule-based paths.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class PipelineReflector:
    """Analyze a failed pipeline run and produce a structured reflection report.

    With an LLM gateway configured, the failure context (failed stage, error
    message, outputs of completed stages) is sent to the LLM and its JSON reply
    is converted into a ``ReflectionReport`` carrying ``failure_type``,
    ``root_cause`` and ``suggested_fix``. Without a gateway, or when the LLM
    path raises, keyword rules applied to the error message are used instead.
    """

    # The failure categories offered to the LLM in the reflection prompt.
    _FAILURE_TYPES = frozenset(
        {"input_error", "resource_error", "logic_error", "timeout"}
    )
    # Category used when the LLM omits the field or returns an unknown value;
    # it matches the catch-all branch of the rule-based classifier.
    _DEFAULT_FAILURE_TYPE = "logic_error"

    def __init__(self, llm_gateway: Any = None):
        """Store the optional LLM gateway.

        Args:
            llm_gateway: Object exposing an awaitable
                ``chat(messages=..., model=...)``. ``None`` disables the LLM
                path so only the rule-based analysis runs.
        """
        self._llm_gateway = llm_gateway

    async def reflect(
        self,
        pipeline: Pipeline,
        result: PipelineResult,
        reflection_number: int = 1,
    ) -> ReflectionReport:
        """Analyze the failure and build a reflection report.

        Args:
            pipeline: The original pipeline definition.
            result: The ``PipelineResult`` of the failed execution.
            reflection_number: Which reflection attempt this is.

        Returns:
            A structured ``ReflectionReport``.
        """
        failed_stage, error_message = self._find_failure(result)

        if self._llm_gateway is not None:
            completed_outputs = self._collect_completed_outputs(result)
            try:
                return await self._llm_reflect(
                    pipeline, failed_stage, error_message,
                    completed_outputs, reflection_number,
                )
            except Exception as e:
                # Best effort: any LLM/gateway problem degrades to the rules
                # below instead of failing the reflection.
                logger.warning(
                    "LLM reflection failed, falling back to rule-based: %s", e
                )

        # Rule-based fallback: classify from the error message.
        return self._rule_based_reflect(
            failed_stage, error_message, reflection_number,
        )

    def _find_failure(
        self, result: PipelineResult,
    ) -> tuple[str, str]:
        """Return the name and error message of the first failed stage.

        Stages are scanned in the iteration order of ``result.stage_results``.
        When no stage is marked failed, ``("", "no failed stage found")`` is
        returned; a failed stage without a message yields ``"unknown error"``.
        """
        for name, sr in result.stage_results.items():
            if sr.status == StageStatus.FAILED:
                return name, sr.error_message or "unknown error"
        return "", "no failed stage found"

    def _collect_completed_outputs(
        self, result: PipelineResult,
    ) -> dict[str, Any]:
        """Map each completed stage that has truthy output to that output."""
        return {
            name: sr.output_data
            for name, sr in result.stage_results.items()
            if sr.status == StageStatus.COMPLETED and sr.output_data
        }

    async def _llm_reflect(
        self,
        pipeline: Pipeline,
        failed_stage: str,
        error_message: str,
        completed_outputs: dict[str, Any],
        reflection_number: int,
    ) -> ReflectionReport:
        """Ask the LLM to analyze the failure and parse its reply."""
        prompt = self._build_reflection_prompt(
            pipeline, failed_stage, error_message,
            completed_outputs, reflection_number,
        )

        response = await self._llm_gateway.chat(
            messages=[{"role": "user", "content": prompt}],
            model="default",
        )

        # The gateway may return an object with ``content`` or something that
        # is only meaningful as text.
        content = response.content if hasattr(response, "content") else str(response)
        return self._parse_reflection_response(
            content, failed_stage, reflection_number,
        )

    def _build_reflection_prompt(
        self,
        pipeline: Pipeline,
        failed_stage: str,
        error_message: str,
        completed_outputs: dict[str, Any],
        reflection_number: int,
    ) -> str:
        """Build the reflection prompt sent to the LLM.

        Each completed output is stringified and cut to 200 characters so the
        prompt carries only a summary of it.
        """
        stage_lines = "\n".join(
            f" - {s.name}: agent={s.agent}, action={s.action}, "
            f"depends_on={s.depends_on}"
            for s in pipeline.stages
        )

        completed_summary = json.dumps(
            {k: str(v)[:200] for k, v in completed_outputs.items()},
            ensure_ascii=False,
        )

        return f"""Analyze the following pipeline execution failure and provide a structured reflection report.

Pipeline: {pipeline.name}
Stages:
{stage_lines}

Failed stage: {failed_stage}
Error message: {error_message}
Completed outputs (summary): {completed_summary}
Reflection attempt: {reflection_number}

Respond in JSON format with these fields:
- failure_type: one of "input_error", "resource_error", "logic_error", "timeout"
- root_cause: brief description of the root cause
- suggested_fix: concrete fix to apply to the pipeline

JSON response:"""

    @staticmethod
    def _extract_json_object(content: str) -> dict[str, Any]:
        """Decode the first JSON object embedded in an LLM reply.

        Decoding starts at the first ``{`` and stops at the end of that
        object, so markdown code fences (closed or not) and prose around the
        JSON are ignored. Prose containing a ``{`` before the JSON still
        fails to decode.

        Raises:
            ValueError: If there is no ``{`` in ``content`` or the text from
                there on is not valid JSON (``json.JSONDecodeError`` is a
                ``ValueError`` subclass).
        """
        start = content.find("{")
        if start == -1:
            raise ValueError("no JSON object found in LLM response")
        data, _ = json.JSONDecoder().raw_decode(content, start)
        return data

    @classmethod
    def _normalize_failure_type(cls, value: Any) -> str:
        """Map the LLM-provided failure type onto a known category.

        Case and surrounding whitespace are ignored; anything that is not one
        of the categories listed in the prompt becomes the default category.
        """
        if isinstance(value, str):
            candidate = value.strip().lower()
            if candidate in cls._FAILURE_TYPES:
                return candidate
        return cls._DEFAULT_FAILURE_TYPE

    def _parse_reflection_response(
        self,
        content: str,
        failed_stage: str,
        reflection_number: int,
    ) -> ReflectionReport:
        """Turn the LLM reply into a ``ReflectionReport``.

        When no JSON object can be decoded, the raw reply text is handed to
        the rule-based classifier in place of an error message.
        """
        # Only decoding is guarded; errors raised while building the report
        # propagate to the caller.
        try:
            data = self._extract_json_object(content)
        except ValueError as e:
            logger.warning("Failed to parse LLM reflection response: %s", e)
            return self._rule_based_reflect(
                failed_stage, content, reflection_number,
            )

        # ``or`` also covers explicit JSON nulls / empty strings.
        return ReflectionReport(
            failure_type=self._normalize_failure_type(data.get("failure_type")),
            root_cause=data.get("root_cause") or "LLM analysis unavailable",
            suggested_fix=data.get("suggested_fix") or "",
            failed_stage=failed_stage,
            reflection_number=reflection_number,
        )

    def _rule_based_reflect(
        self,
        failed_stage: str,
        error_message: str,
        reflection_number: int,
    ) -> ReflectionReport:
        """Classify the failure from keywords in the error message.

        Checks run in order: timeout, missing resource, invalid input; any
        other message is reported as a logic error with the first 200
        characters of the message as root cause.
        """
        error_lower = error_message.lower()

        if "timeout" in error_lower or "timed out" in error_lower:
            failure_type = "timeout"
            root_cause = f"Stage '{failed_stage}' timed out"
            suggested_fix = "Increase timeout_seconds and add retry_policy"
        elif "not found" in error_lower or "404" in error_lower:
            failure_type = "resource_error"
            root_cause = f"Required resource not found in stage '{failed_stage}'"
            suggested_fix = "Add pre-check step or adjust resource reference"
        elif "invalid" in error_lower or "validation" in error_lower:
            failure_type = "input_error"
            root_cause = f"Invalid input to stage '{failed_stage}'"
            suggested_fix = "Add input validation step before this stage"
        else:
            failure_type = "logic_error"
            root_cause = f"Stage '{failed_stage}' failed: {error_message[:200]}"
            suggested_fix = "Review stage logic and adjust action or inputs"

        return ReflectionReport(
            failure_type=failure_type,
            root_cause=root_cause,
            suggested_fix=suggested_fix,
            failed_stage=failed_stage,
            reflection_number=reflection_number,
        )
|
||
|
||
|
||
class PipelineReplanner:
    """Build a corrected pipeline from a reflection report.

    With an LLM gateway configured, the LLM is asked for a corrected pipeline.
    Without a gateway, or when the LLM path raises or yields nothing usable,
    the failed stage is adjusted by fixed rules keyed on the report's
    ``failure_type``. All other stages are carried over unchanged.
    """

    # Upper bound applied when a timed-out stage's timeout is doubled.
    _MAX_TIMEOUT_SECONDS = 3600

    def __init__(self, llm_gateway: Any = None):
        """Store the optional LLM gateway.

        Args:
            llm_gateway: Object exposing an awaitable
                ``chat(messages=..., model=...)``. ``None`` disables the LLM
                path so only rule-based replanning runs.
        """
        self._llm_gateway = llm_gateway

    async def replan(
        self,
        pipeline: Pipeline,
        result: PipelineResult,
        report: ReflectionReport,
    ) -> Pipeline:
        """Replan the pipeline according to the reflection report.

        Args:
            pipeline: The original pipeline.
            result: The ``PipelineResult`` of the failed execution.
            report: The reflection report describing the failure.

        Returns:
            The corrected pipeline.
        """
        if self._llm_gateway is not None:
            try:
                replanned = await self._llm_replan(pipeline, result, report)
            except Exception as e:
                logger.warning(
                    "LLM replanning failed, falling back to rule-based: %s", e
                )
            else:
                # ``_parse_pipeline_response`` hands back the original object
                # when the reply was unusable. Returning it would re-run the
                # exact pipeline that just failed, so use the rules instead.
                if replanned is not pipeline:
                    return replanned
                logger.warning(
                    "LLM replanning produced no usable pipeline, "
                    "falling back to rule-based"
                )

        # Rule-based fallback: adjust according to failure_type.
        return self._rule_based_replan(pipeline, result, report)

    async def _llm_replan(
        self,
        pipeline: Pipeline,
        result: PipelineResult,
        report: ReflectionReport,
    ) -> Pipeline:
        """Ask the LLM for a corrected pipeline and parse its reply."""
        completed_stages = [
            name for name, sr in result.stage_results.items()
            if sr.status == StageStatus.COMPLETED
        ]

        # The LLM can only keep "the same structure" if it is shown that
        # structure, so the full stage definitions are included.
        # NOTE(review): relies on PipelineStage being a pydantic v2 model (it
        # is already used through model_copy in _adjust_failed_stage) —
        # confirm model_dump(mode="json") serializes every field.
        stage_definitions = json.dumps(
            [stage.model_dump(mode="json") for stage in pipeline.stages],
            ensure_ascii=False,
        )

        prompt = f"""Based on the reflection report, generate a corrected pipeline.

Original pipeline: {pipeline.name}
Stages: {[s.name for s in pipeline.stages]}
Stage definitions (JSON): {stage_definitions}
Completed stages: {completed_stages}
Failed stage: {report.failed_stage}
Failure type: {report.failure_type}
Root cause: {report.root_cause}
Suggested fix: {report.suggested_fix}

Generate a corrected pipeline in JSON format with the same structure as the original:
a JSON object whose "stages" list uses the same fields as the stage definitions above.
Only modify stages that need changes based on the reflection.
Keep completed stages unchanged.

JSON pipeline:"""

        response = await self._llm_gateway.chat(
            messages=[{"role": "user", "content": prompt}],
            model="default",
        )

        content = response.content if hasattr(response, "content") else str(response)
        return self._parse_pipeline_response(content, pipeline)

    @staticmethod
    def _extract_json_object(content: str) -> dict[str, Any]:
        """Decode the first JSON object embedded in an LLM reply.

        Decoding starts at the first ``{`` and stops at the end of that
        object, so markdown code fences (closed or not) and prose around the
        JSON are ignored. Prose containing a ``{`` before the JSON still
        fails to decode.

        Raises:
            ValueError: If there is no ``{`` in ``content`` or the text from
                there on is not valid JSON (``json.JSONDecodeError`` is a
                ``ValueError`` subclass).
        """
        start = content.find("{")
        if start == -1:
            raise ValueError("no JSON object found in LLM response")
        data, _ = json.JSONDecoder().raw_decode(content, start)
        return data

    def _parse_pipeline_response(
        self, content: str, original: Pipeline,
    ) -> Pipeline:
        """Parse the pipeline JSON returned by the LLM.

        Fields missing from the reply (name, version, description, variables)
        are taken from ``original``.

        Returns:
            The parsed pipeline, or the ``original`` object itself when the
            reply cannot be decoded, has no non-empty ``stages`` list, or the
            stage/pipeline objects cannot be constructed from it.
        """
        try:
            data = self._extract_json_object(content)
            stage_specs = data.get("stages")
            # A reply without stages would otherwise become an empty pipeline.
            if not isinstance(stage_specs, list) or not stage_specs:
                raise ValueError("LLM replan response contains no stages")
            stages = [PipelineStage(**spec) for spec in stage_specs]
            return Pipeline(
                name=data.get("name", original.name),
                version=data.get("version", original.version),
                description=data.get("description", original.description),
                stages=stages,
                variables=data.get("variables", original.variables),
            )
        except Exception as e:
            # Deliberately broad: besides JSON errors, constructing the stage
            # and pipeline objects from LLM-provided data can fail.
            logger.warning("Failed to parse LLM replan response: %s", e)
            return original

    def _rule_based_replan(
        self,
        pipeline: Pipeline,
        result: PipelineResult,
        report: ReflectionReport,
    ) -> Pipeline:
        """Rebuild the pipeline, adjusting only the failed stage.

        Completed stages and stages other than the failed one are carried
        over as-is. The new pipeline is named ``<original name>_replanned``
        and its description records the report's root cause.
        """
        completed_stages = {
            name for name, sr in result.stage_results.items()
            if sr.status == StageStatus.COMPLETED
        }

        new_stages: list[PipelineStage] = [
            self._adjust_failed_stage(stage, report)
            if stage.name not in completed_stages
            and stage.name == report.failed_stage
            else stage
            for stage in pipeline.stages
        ]

        return Pipeline(
            name=f"{pipeline.name}_replanned",
            version=pipeline.version,
            description=f"Replanned after reflection: {report.root_cause}",
            stages=new_stages,
            variables=pipeline.variables,
        )

    @staticmethod
    def _default_retry_policy() -> Any:
        """Create the retry policy (two attempts) given to stages lacking one."""
        # Imported lazily, as in the original code.
        # NOTE(review): presumably to avoid an import cycle — confirm.
        from agentkit.orchestrator.retry import StepRetryPolicy

        return StepRetryPolicy(max_attempts=2)

    def _adjust_failed_stage(
        self, stage: PipelineStage, report: ReflectionReport,
    ) -> PipelineStage:
        """Return a copy of the failed stage adjusted for the failure type.

        - ``timeout``: double ``timeout_seconds`` (capped at
          ``_MAX_TIMEOUT_SECONDS``) and add a retry policy if none is set.
        - ``resource_error``: set ``continue_on_failure``.
        - ``input_error``: add a retry policy if none is set, since the input
          may become available later.

        Any other failure type yields an unmodified copy.
        """
        adjustments: dict[str, Any] = {}

        if report.failure_type == "timeout":
            # Doubling an unset timeout would raise TypeError, so skip it.
            if stage.timeout_seconds is not None:
                adjustments["timeout_seconds"] = min(
                    stage.timeout_seconds * 2, self._MAX_TIMEOUT_SECONDS,
                )
            if stage.retry_policy is None:
                adjustments["retry_policy"] = self._default_retry_policy()

        elif report.failure_type == "resource_error":
            adjustments["continue_on_failure"] = True

        elif report.failure_type == "input_error":
            if stage.retry_policy is None:
                adjustments["retry_policy"] = self._default_retry_policy()

        return stage.model_copy(update=adjustments)
|