fischer-agentkit/src/agentkit/orchestrator/reflection.py

371 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Pipeline 反思-重规划模块
当 Pipeline 执行失败时,通过 LLM 反思分析失败原因,
生成修正后的 Pipeline 重新执行。
"""
import json
import logging
from typing import Any
from agentkit.orchestrator.pipeline_schema import (
Pipeline,
PipelineResult,
PipelineStage,
ReflectionReport,
StageResult,
StageStatus,
)
logger = logging.getLogger(__name__)
class PipelineReflector:
    """Analyze why a Pipeline run failed and produce a structured reflection report.

    With an LLM gateway configured, the failure context (which stage failed,
    its error message and a summary of completed stages' outputs) is sent to
    the LLM, and its JSON answer becomes a ``ReflectionReport`` carrying
    ``failure_type``, ``root_cause`` and ``suggested_fix``. Without a gateway,
    or when the LLM call or its answer is unusable, a keyword-based classifier
    over the failed stage's error message is used instead.
    """

    # The failure categories offered to the LLM in the prompt (and the only
    # ones the rule-based classifier produces). Any other value the LLM returns
    # is normalized to "logic_error", the rule-based default.
    _FAILURE_TYPES = ("input_error", "resource_error", "logic_error", "timeout")

    def __init__(self, llm_gateway: Any = None):
        """Create a reflector.

        Args:
            llm_gateway: Object with an async ``chat(messages=..., model=...)``
                method, or ``None`` to use rule-based reflection only.
        """
        self._llm_gateway = llm_gateway

    async def reflect(
        self,
        pipeline: Pipeline,
        result: PipelineResult,
        reflection_number: int = 1,
    ) -> ReflectionReport:
        """Analyze the failure and build a reflection report.

        Args:
            pipeline: The original Pipeline definition.
            result: The PipelineResult of the failed run.
            reflection_number: Which reflection attempt this is (starts at 1).

        Returns:
            A structured ReflectionReport. Exceptions raised on the LLM path
            are logged and trigger the rule-based fallback instead.
        """
        # Gather the failure context once; both strategies use it.
        failed_stage, error_message = self._find_failure(result)
        completed_outputs = self._collect_completed_outputs(result)

        if self._llm_gateway is not None:
            try:
                return await self._llm_reflect(
                    pipeline, failed_stage, error_message,
                    completed_outputs, reflection_number,
                )
            except Exception as e:  # best-effort: any LLM problem -> rules
                logger.warning(
                    "LLM reflection failed, falling back to rule-based: %s", e,
                )

        return self._rule_based_reflect(
            failed_stage, error_message, reflection_number,
        )

    def _find_failure(
        self, result: PipelineResult,
    ) -> tuple[str, str]:
        """Return ``(stage_name, error_message)`` for the first failed stage.

        "First" follows the iteration order of ``result.stage_results``. A
        failed stage without an error message reports ``"unknown error"``; if
        no stage failed, ``("", "no failed stage found")`` is returned.
        """
        for name, sr in result.stage_results.items():
            if sr.status == StageStatus.FAILED:
                return name, sr.error_message or "unknown error"
        return "", "no failed stage found"

    def _collect_completed_outputs(
        self, result: PipelineResult,
    ) -> dict[str, Any]:
        """Map each completed stage with non-empty ``output_data`` to that output."""
        return {
            name: sr.output_data
            for name, sr in result.stage_results.items()
            if sr.status == StageStatus.COMPLETED and sr.output_data
        }

    async def _llm_reflect(
        self,
        pipeline: Pipeline,
        failed_stage: str,
        error_message: str,
        completed_outputs: dict[str, Any],
        reflection_number: int,
    ) -> ReflectionReport:
        """Ask the LLM to analyze the failure and parse its JSON answer.

        Exceptions from the gateway propagate to the caller (``reflect``
        catches them and falls back to the rules).
        """
        prompt = self._build_reflection_prompt(
            pipeline, failed_stage, error_message,
            completed_outputs, reflection_number,
        )
        response = await self._llm_gateway.chat(
            messages=[{"role": "user", "content": prompt}],
            model="default",
        )
        # Accept either a response object exposing ``.content`` or anything
        # that stringifies to the reply text.
        content = response.content if hasattr(response, "content") else str(response)
        return self._parse_reflection_response(
            content, failed_stage, reflection_number,
            error_message=error_message,
        )

    def _build_reflection_prompt(
        self,
        pipeline: Pipeline,
        failed_stage: str,
        error_message: str,
        completed_outputs: dict[str, Any],
        reflection_number: int,
    ) -> str:
        """Build the reflection prompt sent to the LLM.

        Each completed output is stringified and truncated to 200 characters
        to keep the prompt bounded.
        """
        stage_descriptions = [
            f"  - {s.name}: agent={s.agent}, action={s.action}, "
            f"depends_on={s.depends_on}"
            for s in pipeline.stages
        ]
        completed_summary = json.dumps(
            {k: str(v)[:200] for k, v in completed_outputs.items()},
            ensure_ascii=False,
        )
        return f"""Analyze the following pipeline execution failure and provide a structured reflection report.
Pipeline: {pipeline.name}
Stages:
{chr(10).join(stage_descriptions)}
Failed stage: {failed_stage}
Error message: {error_message}
Completed outputs (summary): {completed_summary}
Reflection attempt: {reflection_number}
Respond in JSON format with these fields:
- failure_type: one of "input_error", "resource_error", "logic_error", "timeout"
- root_cause: brief description of the root cause
- suggested_fix: concrete fix to apply to the pipeline
JSON response:"""

    @staticmethod
    def _load_json_object(content: str) -> dict[str, Any]:
        """Extract and decode the JSON object contained in an LLM reply.

        Accepts bare JSON, JSON wrapped in a Markdown code fence (with or
        without a language tag, with or without a closing fence, possibly
        followed by more text), or a JSON object surrounded by prose.

        Raises:
            ValueError: if no JSON object can be decoded
                (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        """
        text = content.strip()
        if text.startswith("```"):
            # Drop the opening fence line (e.g. "```json") and everything from
            # the last closing fence onwards, if there is one.
            _, _, text = text.partition("\n")
            closing = text.rfind("```")
            if closing != -1:
                text = text[:closing]
            text = text.strip()
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            # Tolerate prose around the object: retry on the outermost braces.
            start, end = content.find("{"), content.rfind("}")
            if start == -1 or end < start:
                raise
            data = json.loads(content[start:end + 1])
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
        return data

    def _parse_reflection_response(
        self,
        content: str,
        failed_stage: str,
        reflection_number: int,
        error_message: str = "",
    ) -> ReflectionReport:
        """Turn the LLM's reply into a ReflectionReport.

        Args:
            content: Raw LLM reply text.
            failed_stage: Name of the failed stage.
            reflection_number: Which reflection attempt this is.
            error_message: The failed stage's real error message, used by the
                rule-based fallback when the reply cannot be parsed. When
                empty, the reply text itself is classified (legacy behavior).

        Returns:
            The LLM's report, with ``failure_type`` normalized to one of the
            known categories, or a rule-based report if parsing fails.
        """
        try:
            data = self._load_json_object(content)
        except ValueError as e:
            logger.warning("Failed to parse LLM reflection response: %s", e)
            # Classify the stage's actual error, not the unparseable LLM text.
            return self._rule_based_reflect(
                failed_stage, error_message or content, reflection_number,
            )

        failure_type = str(data.get("failure_type", "")).strip().lower()
        if failure_type not in self._FAILURE_TYPES:
            failure_type = "logic_error"
        root_cause = data.get("root_cause") or "LLM analysis unavailable"
        suggested_fix = data.get("suggested_fix") or ""
        return ReflectionReport(
            failure_type=failure_type,
            root_cause=str(root_cause),
            suggested_fix=str(suggested_fix),
            failed_stage=failed_stage,
            reflection_number=reflection_number,
        )

    def _rule_based_reflect(
        self,
        failed_stage: str,
        error_message: str,
        reflection_number: int,
    ) -> ReflectionReport:
        """Classify the failure from keywords in the error message.

        Checks run in order (first match wins): timeout, "not found"/404
        (resource_error), "invalid"/"validation" (input_error); anything else
        is a logic_error whose root cause quotes up to 200 chars of the error.
        """
        error_lower = error_message.lower()
        if "timeout" in error_lower or "timed out" in error_lower:
            failure_type = "timeout"
            root_cause = f"Stage '{failed_stage}' timed out"
            suggested_fix = "Increase timeout_seconds and add retry_policy"
        elif "not found" in error_lower or "404" in error_lower:
            failure_type = "resource_error"
            root_cause = f"Required resource not found in stage '{failed_stage}'"
            suggested_fix = "Add pre-check step or adjust resource reference"
        elif "invalid" in error_lower or "validation" in error_lower:
            failure_type = "input_error"
            root_cause = f"Invalid input to stage '{failed_stage}'"
            suggested_fix = "Add input validation step before this stage"
        else:
            failure_type = "logic_error"
            root_cause = f"Stage '{failed_stage}' failed: {error_message[:200]}"
            suggested_fix = "Review stage logic and adjust action or inputs"
        return ReflectionReport(
            failure_type=failure_type,
            root_cause=root_cause,
            suggested_fix=suggested_fix,
            failed_stage=failed_stage,
            reflection_number=reflection_number,
        )
class PipelineReplanner:
    """Generate a corrected Pipeline from a ReflectionReport.

    The rule-based path keeps every stage definition except the failed one,
    which is adjusted according to the report's ``failure_type``. The LLM path
    asks the model to change only what the reflection calls for and to keep
    completed stages unchanged (requested in the prompt, not enforced here).
    """

    def __init__(self, llm_gateway: Any = None):
        """Create a replanner.

        Args:
            llm_gateway: Object with an async ``chat(messages=..., model=...)``
                method, or ``None`` to use rule-based re-planning only.
        """
        self._llm_gateway = llm_gateway

    async def replan(
        self,
        pipeline: Pipeline,
        result: PipelineResult,
        report: ReflectionReport,
    ) -> Pipeline:
        """Re-plan the pipeline based on the reflection report.

        Args:
            pipeline: The original Pipeline.
            result: The PipelineResult of the failed run.
            report: The reflection report describing the failure.

        Returns:
            The corrected Pipeline. The LLM is tried first when a gateway is
            configured. If it raises or its reply cannot be turned into a
            pipeline, rule-based re-planning is used instead of returning the
            original pipeline unchanged.
        """
        if self._llm_gateway is not None:
            try:
                replanned = await self._llm_replan(pipeline, result, report)
            except Exception as e:  # best-effort: any LLM problem -> rules
                logger.warning(
                    "LLM replanning failed, falling back to rule-based: %s", e,
                )
            else:
                # _parse_pipeline_response hands back the original object when
                # the reply is unusable; re-running that exact pipeline would
                # likely repeat the failure, so apply the rules instead.
                if replanned is not pipeline:
                    return replanned
                logger.warning(
                    "LLM replanning produced no usable pipeline, "
                    "falling back to rule-based",
                )
        return self._rule_based_replan(pipeline, result, report)

    async def _llm_replan(
        self,
        pipeline: Pipeline,
        result: PipelineResult,
        report: ReflectionReport,
    ) -> Pipeline:
        """Ask the LLM for a corrected pipeline and parse its reply.

        Exceptions from the gateway propagate to the caller.
        """
        completed_stages = [
            name for name, sr in result.stage_results.items()
            if sr.status == StageStatus.COMPLETED
        ]
        # The prompt asks for "the same structure as the original", which the
        # model cannot reproduce from stage names alone, so send the full
        # stage definitions as well.
        stage_definitions = json.dumps(
            [s.model_dump() for s in pipeline.stages],
            ensure_ascii=False,
            default=str,  # stringify any value json cannot encode natively
        )
        prompt = f"""Based on the reflection report, generate a corrected pipeline.
Original pipeline: {pipeline.name}
Stages: {[s.name for s in pipeline.stages]}
Stage definitions: {stage_definitions}
Completed stages: {completed_stages}
Failed stage: {report.failed_stage}
Failure type: {report.failure_type}
Root cause: {report.root_cause}
Suggested fix: {report.suggested_fix}
Generate a corrected pipeline in JSON format with the same structure as the original.
Only modify stages that need changes based on the reflection.
Keep completed stages unchanged.
JSON pipeline:"""
        response = await self._llm_gateway.chat(
            messages=[{"role": "user", "content": prompt}],
            model="default",
        )
        # Accept either a response object exposing ``.content`` or anything
        # that stringifies to the reply text.
        content = response.content if hasattr(response, "content") else str(response)
        return self._parse_pipeline_response(content, pipeline)

    @staticmethod
    def _load_json_object(content: str) -> dict[str, Any]:
        """Extract and decode the JSON object contained in an LLM reply.

        Accepts bare JSON, JSON wrapped in a Markdown code fence (with or
        without a language tag, with or without a closing fence, possibly
        followed by more text), or a JSON object surrounded by prose.

        Raises:
            ValueError: if no JSON object can be decoded
                (``json.JSONDecodeError`` is a ``ValueError`` subclass).
        """
        text = content.strip()
        if text.startswith("```"):
            # Drop the opening fence line (e.g. "```json") and everything from
            # the last closing fence onwards, if there is one.
            _, _, text = text.partition("\n")
            closing = text.rfind("```")
            if closing != -1:
                text = text[:closing]
            text = text.strip()
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            # Tolerate prose around the object: retry on the outermost braces.
            start, end = content.find("{"), content.rfind("}")
            if start == -1 or end < start:
                raise
            data = json.loads(content[start:end + 1])
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
        return data

    def _parse_pipeline_response(
        self, content: str, original: Pipeline,
    ) -> Pipeline:
        """Build a Pipeline from the LLM's JSON reply.

        Top-level fields missing from the reply default to ``original``'s.

        Returns:
            The parsed Pipeline, or ``original`` itself (same object) when the
            reply is not a JSON object, has no non-empty ``stages`` list, or
            fails stage/pipeline construction.
        """
        try:
            data = self._load_json_object(content)
            raw_stages = data.get("stages")
            # An empty/missing stage list would yield a pipeline that "succeeds"
            # by doing nothing; treat it as an unusable reply.
            if not isinstance(raw_stages, list) or not raw_stages:
                raise ValueError("LLM reply contains no stages")
            # A non-mapping stage raises TypeError; model validation errors
            # are presumably ValueError subclasses (pydantic) -- confirm.
            stages = [PipelineStage(**s) for s in raw_stages]
            return Pipeline(
                name=data.get("name", original.name),
                version=data.get("version", original.version),
                description=data.get("description", original.description),
                stages=stages,
                variables=data.get("variables", original.variables),
            )
        except (ValueError, TypeError) as e:
            logger.warning("Failed to parse LLM replan response: %s", e)
            return original

    def _rule_based_replan(
        self,
        pipeline: Pipeline,
        result: PipelineResult,
        report: ReflectionReport,
    ) -> Pipeline:
        """Rebuild the pipeline, adjusting only the failed stage.

        Completed stages and all other stages are carried over unchanged. The
        failed stage is adjusted only if it is not itself marked completed.
        The result is named ``<name>_replanned``.
        """
        completed_stages = {
            name for name, sr in result.stage_results.items()
            if sr.status == StageStatus.COMPLETED
        }
        new_stages: list[PipelineStage] = [
            self._adjust_failed_stage(stage, report)
            if stage.name == report.failed_stage and stage.name not in completed_stages
            else stage
            for stage in pipeline.stages
        ]
        return Pipeline(
            name=f"{pipeline.name}_replanned",
            version=pipeline.version,
            description=f"Replanned after reflection: {report.root_cause}",
            stages=new_stages,
            variables=pipeline.variables,
        )

    @staticmethod
    def _increased_timeout(current: float, cap: float = 3600) -> float:
        """Double ``current``, capped at ``cap`` seconds, but never lower it.

        A plain ``min(current * 2, cap)`` would *shrink* a timeout that is
        already above the cap, which is the opposite of the intended fix.
        """
        return max(current, min(current * 2, cap))

    def _adjust_failed_stage(
        self, stage: PipelineStage, report: ReflectionReport,
    ) -> PipelineStage:
        """Return a copy of the failed stage adjusted for ``report.failure_type``.

        - ``timeout``: raise ``timeout_seconds`` (doubled, capped at 3600 s,
          never lowered) and add a 2-attempt retry policy if none is set.
        - ``resource_error``: set ``continue_on_failure``.
        - ``input_error``: add a 2-attempt retry policy if none is set, in case
          the input becomes available later.
        - any other type: the stage is copied unchanged.
        """
        adjustments: dict[str, Any] = {}
        if report.failure_type == "timeout":
            adjustments["timeout_seconds"] = self._increased_timeout(
                stage.timeout_seconds,
            )
        elif report.failure_type == "resource_error":
            adjustments["continue_on_failure"] = True

        needs_retry = report.failure_type in ("timeout", "input_error")
        if needs_retry and stage.retry_policy is None:
            # Kept as a function-local import, as in the original code.
            from agentkit.orchestrator.retry import StepRetryPolicy
            adjustments["retry_policy"] = StepRetryPolicy(max_attempts=2)
        return stage.model_copy(update=adjustments)