851 lines
33 KiB
Python
851 lines
33 KiB
Python
"""Pipeline Engine - DAG + 并行执行 + 步骤重试 + Saga 补偿"""
|
||
|
||
import asyncio
import logging
import uuid
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any

from agentkit.orchestrator.compensation import SagaOrchestrator
from agentkit.orchestrator.pipeline_schema import (
    AdversarialState,
    AdaptiveConfig,
    Pipeline,
    PipelineResult,
    PipelineStage,
    ReflectionReport,
    ReviewFeedback,
    ReviewIssue,
    StageResult,
    StageStatus,
)
from agentkit.orchestrator.reflection import PipelineReflector, PipelineReplanner
from agentkit.orchestrator.retry import execute_with_retry
|
||
|
||
# Module-level logger used by PipelineEngine for progress and best-effort error reporting.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class PipelineEngine:
    """Pipeline execution engine.

    Supports:
    - DAG topological sort into dependency levels
    - Parallel execution of the stages of one level (asyncio.gather)
    - ``${var.path}`` variable resolution for stage inputs
    - Conditional execution
    - Per-step exponential-backoff retry (StepRetryPolicy)
    - Saga compensation (LIFO rollback of completed steps)
    - Optional state persistence
    - Optional adversarial worker/verifier loop per stage
    - Optional reflection/replanning loop when a pipeline fails
    """

    # Errors from the dispatcher / state manager that are converted into a
    # FAILED stage or a logged warning instead of propagating.
    _RECOVERABLE_ERRORS = (asyncio.TimeoutError, ConnectionError, RuntimeError, ValueError)

    def __init__(self, dispatcher: Any = None, state_manager: Any = None, llm_gateway: Any = None):
        """Create an engine.

        Args:
            dispatcher: Object exposing async ``dispatch(task)`` and
                ``get_task_status(task_id)``. When ``None`` the engine runs in
                dry-run mode and every agent call completes immediately.
            state_manager: Optional persistence backend exposing async
                ``create_execution`` / ``update_step`` / ``fail_execution`` /
                ``complete_execution``.
            llm_gateway: Passed to the reflector and replanner of the adaptive loop.
        """
        self._dispatcher = dispatcher
        self._state_manager = state_manager
        self._llm_gateway = llm_gateway

    async def execute(
        self,
        pipeline: Pipeline,
        context: dict[str, Any] | None = None,
        adaptive_config: AdaptiveConfig | None = None,
    ) -> PipelineResult:
        """Execute a pipeline.

        Args:
            pipeline: Pipeline definition.
            context: Runtime context variables; they override ``pipeline.variables``.
            adaptive_config: Adaptive configuration. When enabled and the first
                run fails, the reflection/replanning loop is entered.

        Returns:
            The result of the first run, or of the last adaptive re-run.
        """
        result = await self._execute_pipeline(pipeline, context)

        if result.status == StageStatus.FAILED and adaptive_config and adaptive_config.enabled:
            result = await self._adaptive_loop(pipeline, context, result, adaptive_config)

        return result

    async def _adaptive_loop(
        self,
        pipeline: Pipeline,
        context: dict[str, Any] | None,
        failed_result: PipelineResult,
        adaptive_config: AdaptiveConfig,
    ) -> PipelineResult:
        """Reflection/replanning loop: analyse the failure, revise the pipeline, re-run.

        Runs at most ``adaptive_config.max_reflections`` iterations and stops at
        the first re-run that completes. The returned result carries every
        reflection report so far under ``metadata["reflections"]``.
        """
        reflector = PipelineReflector(llm_gateway=self._llm_gateway)
        replanner = PipelineReplanner(llm_gateway=self._llm_gateway)

        current_pipeline = pipeline
        current_result = failed_result
        reflections: list[ReflectionReport] = []

        for reflection_num in range(1, adaptive_config.max_reflections + 1):
            report = await reflector.reflect(current_pipeline, current_result, reflection_num)
            reflections.append(report)
            logger.info(
                f"Pipeline reflection #{reflection_num}: "
                f"failure_type={report.failure_type}, "
                f"root_cause={report.root_cause}"
            )

            new_pipeline = await replanner.replan(current_pipeline, current_result, report)
            logger.info(f"Pipeline replanned: {new_pipeline.name} ({len(new_pipeline.stages)} stages)")

            current_result = await self._execute_pipeline(new_pipeline, context)
            current_pipeline = new_pipeline
            current_result.metadata["reflections"] = [r.model_dump() for r in reflections]

            if current_result.status == StageStatus.COMPLETED:
                logger.info(f"Pipeline succeeded after {reflection_num} reflection(s)")
                return current_result

        logger.warning(
            f"Pipeline failed after {adaptive_config.max_reflections} reflection(s)"
        )
        current_result.metadata["reflections"] = [r.model_dump() for r in reflections]
        return current_result

    async def _execute_pipeline(
        self,
        pipeline: Pipeline,
        context: dict[str, Any] | None = None,
    ) -> PipelineResult:
        """Core pipeline execution (without reflection/replanning).

        Stages are grouped into dependency levels, and the stages of one level
        run concurrently. Every result of a level is recorded, persisted and has
        its outputs collected *before* the engine decides whether to abort. This
        way the results of sibling stages are not lost when one of them fails.
        If a failed stage does not allow ``continue_on_failure``, the completed
        steps are compensated (Saga) and a FAILED result is returned.
        """
        result = PipelineResult(pipeline_name=pipeline.name)
        result.variables = {**pipeline.variables, **(context or {})}

        try:
            level_groups = self._topological_group(pipeline.stages)
        except ValueError as e:
            result.status = StageStatus.FAILED
            result.error_message = str(e)
            return result

        execution_id = await self._create_execution_state(pipeline, context)
        saga = SagaOrchestrator()

        for level, stages in enumerate(level_groups):
            logger.info(f"Pipeline '{pipeline.name}' executing level {level} with {len(stages)} stage(s)")

            outcomes = await asyncio.gather(
                *(self._execute_stage(stage, result, saga) for stage in stages),
                return_exceptions=True,
            )

            aborting_stage = None
            for stage, outcome in zip(stages, outcomes):
                sr = self._as_stage_result(stage, outcome)
                result.stage_results[stage.name] = sr
                await self._persist_step_state(execution_id, stage, sr)
                self._collect_outputs(stage, sr, result.variables)
                # The first fatal failure (in declaration order) names the abort.
                if aborting_stage is None and sr.status == StageStatus.FAILED and not stage.continue_on_failure:
                    aborting_stage = stage

            if aborting_stage is not None:
                await self._compensate(saga)
                result.status = StageStatus.FAILED
                result.error_message = f"Stage '{aborting_stage.name}' failed"
                await self._persist_failure_state(execution_id, aborting_stage.name, result.error_message)
                return result

        result.status = StageStatus.COMPLETED
        await self._persist_completion_state(execution_id, result)
        return result

    @staticmethod
    def _as_stage_result(stage: PipelineStage, outcome: Any) -> StageResult:
        """Normalise one ``asyncio.gather(return_exceptions=True)`` outcome into a StageResult.

        gather may hand back ``BaseException`` instances, not only ``Exception``.
        For example, ``asyncio.CancelledError`` is a ``BaseException`` since
        Python 3.8. Those are turned into a FAILED result rather than stored raw.
        KeyboardInterrupt/SystemExit are re-raised.
        """
        if isinstance(outcome, (KeyboardInterrupt, SystemExit)):
            raise outcome
        if isinstance(outcome, BaseException):
            return StageResult(
                stage_name=stage.name,
                status=StageStatus.FAILED,
                error_message=str(outcome) or type(outcome).__name__,
            )
        return outcome

    @staticmethod
    def _collect_outputs(stage: PipelineStage, sr: StageResult, variables: dict[str, Any]) -> None:
        """Copy the keys declared in ``stage.outputs`` from the stage output into ``variables``.

        Missing keys are skipped. Empty or non-dict outputs contribute nothing.
        """
        output = sr.output_data
        if not output or not isinstance(output, dict):
            return
        for output_key in stage.outputs:
            if output_key in output:
                variables[output_key] = output[output_key]

    @staticmethod
    async def _compensate(saga: SagaOrchestrator) -> None:
        """Run Saga compensation for the completed steps and log any compensation that failed."""
        compensation_results = await saga.compensate()
        if not compensation_results:
            return
        failed_compensations = [
            cr for cr in compensation_results if not cr.success and cr.error != "no_compensation_needed"
        ]
        if failed_compensations:
            logger.warning(
                f"Compensation had {len(failed_compensations)} failures: "
                f"{[c.step_name for c in failed_compensations]}"
            )

    async def _create_execution_state(
        self,
        pipeline: Pipeline,
        context: dict[str, Any] | None,
    ) -> str | None:
        """Create a persisted execution record when a state manager is configured.

        Returns:
            The execution id. Returns ``None`` when there is no state manager or
            when creation failed; failures are logged and execution proceeds
            without persistence.
        """
        if self._state_manager is None:
            return None
        try:
            return await self._state_manager.create_execution(
                pipeline_name=pipeline.name,
                steps=[s.name for s in pipeline.stages],
                input_data=context,
            )
        except self._RECOVERABLE_ERRORS as exc:
            logger.warning(f"Failed to create execution state: {exc}")
            return None

    async def _persist_step_state(
        self,
        execution_id: str | None,
        stage: PipelineStage,
        sr: StageResult,
    ) -> None:
        """Best-effort persistence of one stage result; recoverable errors are only logged."""
        if self._state_manager is None or execution_id is None:
            return
        try:
            step_status = "completed" if sr.status == StageStatus.COMPLETED else sr.status.value
            await self._state_manager.update_step(
                execution_id=execution_id,
                step_name=stage.name,
                status=step_status,
                output=sr.output_data,
                error=sr.error_message,
            )
        except self._RECOVERABLE_ERRORS as exc:
            logger.warning(f"Failed to update step state: {exc}")

    async def _persist_failure_state(self, execution_id: str | None, step_name: str, error: str) -> None:
        """Best-effort: mark the persisted execution as failed at ``step_name``."""
        if self._state_manager is None or execution_id is None:
            return
        try:
            await self._state_manager.fail_execution(
                execution_id=execution_id,
                step_name=step_name,
                error=error,
            )
        except self._RECOVERABLE_ERRORS as exc:
            logger.warning(f"Failed to persist failure state: {exc}")

    async def _persist_completion_state(self, execution_id: str | None, result: PipelineResult) -> None:
        """Best-effort: mark the persisted execution as complete with every non-None stage output."""
        if self._state_manager is None or execution_id is None:
            return
        try:
            final_output = {
                name: sr.output_data
                for name, sr in result.stage_results.items()
                if sr.output_data is not None
            }
            await self._state_manager.complete_execution(
                execution_id=execution_id,
                final_output=final_output,
            )
        except self._RECOVERABLE_ERRORS as exc:
            logger.warning(f"Failed to persist completion state: {exc}")

    async def _execute_stage(
        self,
        stage: PipelineStage,
        pipeline_result: PipelineResult,
        saga: SagaOrchestrator,
    ) -> StageResult:
        """Execute a single stage.

        - If the stage condition evaluates false, the stage is SKIPPED.
        - If a verifier is configured, the adversarial loop runs.
        - Otherwise the worker agent runs once, honouring the stage's retry policy.

        In every case a COMPLETED result is recorded with the Saga for later
        compensation.
        """
        started_at = datetime.now(timezone.utc).isoformat()

        if stage.condition and not self._evaluate_condition(stage.condition, pipeline_result.variables):
            return StageResult(
                stage_name=stage.name,
                status=StageStatus.SKIPPED,
                started_at=started_at,
                completed_at=datetime.now(timezone.utc).isoformat(),
            )

        if stage.verifier:
            return await self._execute_stage_with_adversarial(
                stage, pipeline_result, saga, started_at
            )

        resolved_inputs = self._resolve_variables(stage.inputs, pipeline_result.variables)
        sr = await self._execute_agent_stage(
            stage.agent,
            stage.action,
            resolved_inputs,
            stage,
            started_at,
        )

        if sr.status == StageStatus.COMPLETED:
            saga.record_completed(
                step_name=stage.name,
                result=sr.output_data,
                compensate_action=stage.compensate,
            )
        return sr

    @staticmethod
    def _topological_group(stages: list[PipelineStage]) -> list[list[PipelineStage]]:
        """Topologically sort stages and group them into dependency levels (Kahn's algorithm).

        Level 0 holds the stages without dependencies. Level ``n`` holds the
        stages whose dependencies all sit in earlier levels. Within a level,
        stages keep their declaration order, so execution order is deterministic.

        Raises:
            ValueError: On a duplicate stage name, a dependency on an unknown
                stage, or a dependency cycle.
        """
        stage_map: dict[str, PipelineStage] = {}
        for s in stages:
            if s.name in stage_map:
                # A duplicate would otherwise silently shadow the earlier stage.
                raise ValueError(f"Duplicate stage name '{s.name}' in pipeline")
            stage_map[s.name] = s

        in_degree = {s.name: 0 for s in stages}
        dependents = defaultdict(list)
        for s in stages:
            for dep in s.depends_on:
                if dep not in stage_map:
                    raise ValueError(f"Stage '{s.name}' depends on unknown stage '{dep}'")
                in_degree[s.name] += 1
                dependents[dep].append(s.name)

        levels: list[list[PipelineStage]] = []
        remaining = list(stages)
        while remaining:
            current_level = [s for s in remaining if in_degree[s.name] == 0]
            if not current_level:
                raise ValueError("Circular dependency detected in pipeline")
            levels.append(current_level)

            done = {s.name for s in current_level}
            for name in done:
                for dependent in dependents[name]:
                    in_degree[dependent] -= 1
            remaining = [s for s in remaining if s.name not in done]

        return levels

    @staticmethod
    def _resolve_variables(template: dict, context: dict) -> dict:
        """Resolve top-level ``"${var.path}"`` references in ``template`` against ``context``.

        Only string values that are exactly one ``${...}`` reference are
        resolved. Unresolvable paths become ``None``. All other values are
        copied unchanged.
        """
        resolved = {}
        for key, value in template.items():
            if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
                resolved[key] = PipelineEngine._get_nested(context, value[2:-1])
            else:
                resolved[key] = value
        return resolved

    @staticmethod
    def _get_nested(data: dict, path: str) -> Any:
        """Follow a dot-separated ``path`` through nested dicts; return ``None`` if it cannot be followed."""
        current = data
        for key in path.split("."):
            if not isinstance(current, dict):
                return None
            current = current.get(key)
        return current

    @staticmethod
    def _evaluate_condition(condition: str, variables: dict) -> bool:
        """Evaluate a simple stage condition.

        Supported forms:
        - ``name == value`` and ``name != value``: ``name`` is looked up in
          ``variables``, falling back to the literal text. Quotes are stripped
          from ``value``, and the comparison is ``str(left)`` against ``value``.
        - ``name``: truthiness of ``variables[name]``; missing names are falsy.
        """
        if "==" in condition:
            lhs, rhs = condition.split("==", 1)
            left = variables.get(lhs.strip(), lhs.strip())
            return str(left) == rhs.strip().strip("'\"")
        if "!=" in condition:
            lhs, rhs = condition.split("!=", 1)
            left = variables.get(lhs.strip(), lhs.strip())
            return str(left) != rhs.strip().strip("'\"")
        return bool(variables.get(condition.strip()))

    async def _execute_stage_with_adversarial(
        self,
        stage: PipelineStage,
        pipeline_result: PipelineResult,
        saga: SagaOrchestrator,
        started_at: str,
    ) -> StageResult:
        """Execute a stage under the adversarial worker/verifier loop.

        In each round the worker produces output and the verifier reviews it. On
        rejection, the worker is re-run with the feedback and its previous
        output. The loop stops when the review passes or the rounds run out;
        exhaustion is handed to ``_escalate``. Every COMPLETED outcome, whether a
        passed review or a successful escalation, is recorded with the Saga.
        """
        adversarial_state = AdversarialState(
            current_round=0,
            max_rounds=stage.max_adversarial_rounds,
        )

        resolved_inputs = self._resolve_variables(stage.inputs, pipeline_result.variables)
        current_context = resolved_inputs.copy()

        for round_num in range(1, stage.max_adversarial_rounds + 1):
            adversarial_state.current_round = round_num
            logger.info(
                f"Adversarial round {round_num}/{stage.max_adversarial_rounds} "
                f"for stage '{stage.name}'"
            )

            # 1. Run the worker agent.
            worker_result = await self._execute_agent_stage(
                stage.agent,
                stage.action,
                current_context,
                stage,
                started_at,
            )
            if worker_result.status != StageStatus.COMPLETED:
                # The worker itself failed: there is nothing to review.
                return worker_result

            # 2. Have the verifier review the worker output.
            try:
                verifier_feedback = await self._execute_verifier(
                    stage.verifier,
                    worker_result.output_data or {},
                    stage,
                    started_at,
                )
            except asyncio.CancelledError:
                raise
            except self._RECOVERABLE_ERRORS as e:
                logger.error(f"Verifier execution failed for stage '{stage.name}': {e}")
                return StageResult(
                    stage_name=stage.name,
                    status=StageStatus.FAILED,
                    error_message=f"Verifier failed: {e}",
                    started_at=started_at,
                    completed_at=datetime.now(timezone.utc).isoformat(),
                )

            # 3. Record the feedback history.
            adversarial_state.feedback_history.append(verifier_feedback)
            adversarial_state.last_feedback = verifier_feedback

            if verifier_feedback.passed:
                logger.info(
                    f"Stage '{stage.name}' passed review in round {round_num}"
                )
                worker_result.output_data = worker_result.output_data or {}
                worker_result.output_data["adversarial_metadata"] = {
                    "passed_round": round_num,
                    "total_rounds": round_num,
                    "feedback_summary": verifier_feedback.summary,
                    "score": verifier_feedback.score,
                }
                saga.record_completed(
                    step_name=stage.name,
                    result=worker_result.output_data,
                    compensate_action=stage.compensate,
                )
                return worker_result

            # 4. Review failed: escalate if no rounds are left.
            logger.warning(
                f"Stage '{stage.name}' failed review in round {round_num}: "
                f"{verifier_feedback.summary}"
            )

            if round_num >= stage.max_adversarial_rounds:
                escalated = await self._escalate(
                    stage,
                    worker_result,
                    adversarial_state,
                    started_at,
                )
                if escalated.status == StageStatus.COMPLETED:
                    # An accepted escalation completes the stage. Track it like any
                    # other completed stage so a later failure compensates it too.
                    saga.record_completed(
                        step_name=stage.name,
                        result=escalated.output_data,
                        compensate_action=stage.compensate,
                    )
                return escalated

            # 5. Send the work back to the worker with the feedback and its previous output.
            feedback_context = self._build_feedback_context(
                verifier_feedback,
                stage.feedback_mode,
            )
            current_context = {
                **resolved_inputs,
                "previous_output": worker_result.output_data,
                **feedback_context,
            }

        # Only reachable when max_adversarial_rounds < 1 (the loop body never runs).
        return StageResult(
            stage_name=stage.name,
            status=StageStatus.FAILED,
            error_message="Adversarial loop exited unexpectedly",
            started_at=started_at,
            completed_at=datetime.now(timezone.utc).isoformat(),
        )

    async def _execute_agent_stage(
        self,
        agent_name: str,
        action: str,
        input_data: dict[str, Any],
        stage: PipelineStage,
        started_at: str,
        timeout_seconds: int | None = None,
    ) -> StageResult:
        """Execute one agent call for a stage (no adversarial logic, no Saga bookkeeping).

        Args:
            agent_name: Agent to dispatch to.
            action: Task type / action to perform.
            input_data: Task input.
            stage: Owning stage; supplies the name, retry policy and default timeout.
            started_at: Stage start timestamp (ISO 8601).
            timeout_seconds: Independent timeout; defaults to ``stage.timeout_seconds``.

        Returns:
            A COMPLETED or FAILED StageResult. Recoverable dispatcher errors
            become FAILED results instead of propagating.
        """
        effective_timeout = timeout_seconds if timeout_seconds is not None else stage.timeout_seconds

        if self._dispatcher is None:
            # Dry-run mode: echo the inputs back.
            return StageResult(
                stage_name=stage.name,
                status=StageStatus.COMPLETED,
                output_data={"dry_run": True, "inputs": input_data},
                started_at=started_at,
                completed_at=datetime.now(timezone.utc).isoformat(),
            )

        # Imported lazily; presumably to avoid an import cycle with agentkit.core.
        from agentkit.core.protocol import TaskMessage

        task = TaskMessage(
            task_id=str(uuid.uuid4()),
            agent_name=agent_name,
            task_type=action,
            priority=0,
            input_data=input_data,
            callback_url=None,
            created_at=datetime.now(timezone.utc),
            timeout_seconds=effective_timeout,
        )

        async def _dispatch_and_wait() -> StageResult:
            """Dispatch the task and poll its status once per second until a terminal state or timeout."""
            await self._dispatcher.dispatch(task)

            for _ in range(effective_timeout):
                status = await self._dispatcher.get_task_status(task.task_id)
                if status["status"] in ("completed", "failed", "cancelled"):
                    return StageResult(
                        stage_name=stage.name,
                        status=StageStatus.COMPLETED if status["status"] == "completed" else StageStatus.FAILED,
                        output_data=status.get("output_data"),
                        error_message=status.get("error_message"),
                        started_at=started_at,
                        completed_at=datetime.now(timezone.utc).isoformat(),
                    )
                await asyncio.sleep(1)

            return StageResult(
                stage_name=stage.name,
                status=StageStatus.FAILED,
                error_message=f"Timeout after {effective_timeout}s",
                started_at=started_at,
                completed_at=datetime.now(timezone.utc).isoformat(),
            )

        try:
            return await execute_with_retry(
                func=_dispatch_and_wait,
                retry_policy=stage.retry_policy,
                step_name=stage.name,
            )
        except asyncio.CancelledError:
            raise
        except self._RECOVERABLE_ERRORS as e:
            # Dispatcher / agent failure: report it as a FAILED stage instead of raising.
            return StageResult(
                stage_name=stage.name,
                status=StageStatus.FAILED,
                error_message=str(e),
                started_at=started_at,
                completed_at=datetime.now(timezone.utc).isoformat(),
            )

    async def _execute_verifier(
        self,
        verifier_name: str,
        worker_output: dict[str, Any],
        stage: PipelineStage,
        started_at: str,
    ) -> ReviewFeedback:
        """Have the verifier agent review the worker output.

        Returns:
            The parsed structured review. Missing ``passed`` is treated as
            ``False`` and a missing ``score`` as ``0.0``.

        Raises:
            RuntimeError: If the verifier call fails or its output cannot be
                parsed into a ReviewFeedback. Raising here, rather than looping,
                avoids an endless review cycle.
        """
        logger.info(f"Executing verifier '{verifier_name}' for stage '{stage.name}'")

        verifier_input = {
            "review_target": worker_output,
            "review_instruction": (
                "Please review the following output for quality, correctness, and completeness. "
                "Return a structured review with pass/fail status, issues found, and a summary."
            ),
        }

        # The verifier runs with its own timeout.
        verifier_result = await self._execute_agent_stage(
            verifier_name,
            "review",
            verifier_input,
            stage,
            started_at,
            timeout_seconds=stage.verifier_timeout_seconds,
        )

        if verifier_result.status != StageStatus.COMPLETED:
            raise RuntimeError(
                f"Verifier '{verifier_name}' failed: {verifier_result.error_message}"
            )

        output_data = verifier_result.output_data or {}
        if not isinstance(output_data, dict):
            raise RuntimeError(
                f"Verifier '{verifier_name}' returned unparseable output: "
                f"expected a mapping, got {type(output_data).__name__}"
            )

        try:
            return ReviewFeedback(
                passed=output_data.get("passed", False),
                issues=[ReviewIssue(**issue) for issue in output_data.get("issues", [])],
                summary=output_data.get("summary", "No summary provided"),
                score=output_data.get("score", 0.0),
            )
        except (TypeError, KeyError, ValueError) as e:
            logger.error(f"Failed to parse verifier output: {e}")
            raise RuntimeError(
                f"Verifier '{verifier_name}' returned unparseable output: {e}. "
                f"Raw output keys: {list(output_data.keys())}"
            ) from e

    def _build_feedback_context(
        self,
        feedback: ReviewFeedback,
        feedback_mode: str = "structured+natural",
    ) -> dict[str, Any]:
        """Build the context that tells the worker why its output was rejected.

        Args:
            feedback: The verifier's review.
            feedback_mode: One of ``"structured+natural"`` (summary + issues),
                ``"structured"`` (issues only) or ``"natural"`` (summary only).
                Unknown modes fall back to ``"structured+natural"`` with a warning.

        Returns:
            A dict with ``previous_attempt_failed``, ``review_feedback`` and
            ``instruction`` keys.
        """
        if feedback_mode not in ("structured+natural", "structured", "natural"):
            logger.warning(f"Unknown feedback_mode '{feedback_mode}', falling back to structured+natural")
            feedback_mode = "structured+natural"

        issues_list = [
            {
                "severity": issue.severity,
                "category": issue.category,
                "description": issue.description,
                "location": issue.location,
                "suggestion": issue.suggestion,
            }
            for issue in feedback.issues
        ]

        # Key order (summary, issues, previous_score) is kept stable for serialisation.
        review_feedback: dict[str, Any] = {}
        if feedback_mode != "structured":
            review_feedback["summary"] = feedback.summary
        if feedback_mode != "natural":
            review_feedback["issues"] = issues_list
        review_feedback["previous_score"] = feedback.score

        if feedback_mode == "structured":
            instruction = (
                "Your previous output did not pass review. "
                "Please fix the issues listed above and regenerate."
            )
        elif feedback_mode == "natural":
            instruction = (
                "Your previous output did not pass review. "
                f"Review feedback: {feedback.summary}. "
                "Please regenerate addressing the feedback."
            )
        else:
            instruction = (
                "Your previous output did not pass review. "
                "Please fix the issues listed above and regenerate. "
                f"Review summary: {feedback.summary}"
            )

        return {
            "previous_attempt_failed": True,
            "review_feedback": review_feedback,
            "instruction": instruction,
        }

    async def _escalate(
        self,
        stage: PipelineStage,
        worker_result: StageResult,
        adversarial_state: AdversarialState,
        started_at: str,
    ) -> StageResult:
        """Handle exhaustion of the adversarial rounds.

        If ``stage.escalate_on_exhaust`` names an agent, the last worker output
        and the adversarial state are forwarded to it as a
        ``handle_escalation`` task. That agent's result is returned, with an
        enriched error message if it fails too. Otherwise a FAILED result is
        returned carrying the full review history.

        Args:
            stage: The current stage.
            worker_result: The last worker result.
            adversarial_state: The adversarial loop state.
            started_at: Stage start timestamp (ISO 8601).

        Returns:
            The escalation result, or a FAILED result.
        """
        logger.warning(
            f"Adversarial rounds exhausted for stage '{stage.name}' "
            f"({adversarial_state.current_round}/{adversarial_state.max_rounds})"
        )

        if not stage.escalate_on_exhaust:
            last_feedback = adversarial_state.last_feedback
            return StageResult(
                stage_name=stage.name,
                status=StageStatus.FAILED,
                error_message=(
                    f"Adversarial rounds exhausted ({adversarial_state.current_round}/"
                    f"{adversarial_state.max_rounds}). "
                    f"Last review: {last_feedback.summary if last_feedback else 'N/A'}"
                ),
                output_data={
                    "adversarial_metadata": {
                        "total_rounds": adversarial_state.current_round,
                        "feedback_history": [
                            fb.model_dump() for fb in adversarial_state.feedback_history
                        ],
                    }
                },
                started_at=started_at,
                completed_at=datetime.now(timezone.utc).isoformat(),
            )

        logger.info(f"Escalating stage '{stage.name}' to '{stage.escalate_on_exhaust}'")
        escalate_result = await self._execute_agent_stage(
            stage.escalate_on_exhaust,
            "handle_escalation",
            {
                "original_output": worker_result.output_data,
                "adversarial_state": adversarial_state.model_dump(),
                "escalation_reason": (
                    f"Failed to pass review after {adversarial_state.current_round} rounds"
                ),
            },
            stage,
            started_at,
        )
        escalate_result.output_data = escalate_result.output_data or {}
        escalate_result.output_data["adversarial_metadata"] = {
            "escalated_to": stage.escalate_on_exhaust,
            "total_rounds": adversarial_state.current_round,
            "feedback_history_summary": [
                {"round": i + 1, "passed": fb.passed, "score": fb.score}
                for i, fb in enumerate(adversarial_state.feedback_history)
            ],
        }
        if escalate_result.status == StageStatus.FAILED:
            # The escalation target failed as well: merge both errors.
            escalate_result.error_message = (
                f"Escalation to '{stage.escalate_on_exhaust}' also failed: "
                f"{escalate_result.error_message}. "
                f"Original adversarial rounds exhausted: {adversarial_state.current_round}/{adversarial_state.max_rounds}"
            )
        return escalate_result