feat(pipeline): implement adversarial loop execution logic
Add a Worker-Verifier adversarial loop to PipelineEngine:
- _execute_stage_with_adversarial: main loop for Worker → Verifier → retry
- _execute_agent_stage: extracted agent execution logic
- _execute_verifier: executes the verifier and parses its output into ReviewFeedback
- _build_feedback_context: builds the feedback context for a worker retry
- _escalate: handles round exhaustion (escalate or fail)

Stages are routed to adversarial mode when stage.verifier is configured.
Three feedback modes are supported: structured+natural, structured, natural.
This commit is contained in:
parent
b733b3a732
commit
dc07c7c60a
|
|
@ -8,11 +8,14 @@ from typing import Any
|
|||
|
||||
from agentkit.orchestrator.compensation import SagaOrchestrator
|
||||
from agentkit.orchestrator.pipeline_schema import (
|
||||
AdversarialState,
|
||||
AdaptiveConfig,
|
||||
Pipeline,
|
||||
PipelineResult,
|
||||
PipelineStage,
|
||||
ReflectionReport,
|
||||
ReviewFeedback,
|
||||
ReviewIssue,
|
||||
StageResult,
|
||||
StageStatus,
|
||||
)
|
||||
|
|
@ -257,6 +260,12 @@ class PipelineEngine:
|
|||
completed_at=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
# 如果配置了 verifier,进入对抗模式
|
||||
if stage.verifier:
|
||||
return await self._execute_stage_with_adversarial(
|
||||
stage, pipeline_result, saga, started_at
|
||||
)
|
||||
|
||||
# 解析输入变量
|
||||
resolved_inputs = self._resolve_variables(stage.inputs, pipeline_result.variables)
|
||||
|
||||
|
|
@ -418,3 +427,402 @@ class PipelineEngine:
|
|||
return str(left) != right
|
||||
else:
|
||||
return bool(variables.get(condition))
|
||||
|
||||
async def _execute_stage_with_adversarial(
    self,
    stage: PipelineStage,
    pipeline_result: PipelineResult,
    saga: SagaOrchestrator,
    started_at: str,
) -> StageResult:
    """Run a stage through the Worker -> Verifier adversarial loop.

    Each round runs the worker agent and then has the verifier review the
    worker's output. A passing review ends the loop successfully; a failing
    review sends the worker back with the feedback merged into its inputs,
    until ``stage.max_adversarial_rounds`` is used up and ``_escalate``
    decides the outcome.

    Args:
        stage: Stage definition; ``stage.verifier`` names the reviewing agent.
        pipeline_result: Provides the variables used to resolve
            ``stage.inputs``.
        saga: Receives ``record_completed`` (with ``stage.compensate``) when
            the stage ends up completed.
        started_at: Start timestamp copied into every ``StageResult``.

    Returns:
        The worker's result with ``adversarial_metadata`` added when a review
        passes; the worker's own result when the worker does not complete; a
        FAILED result when the verifier raises; otherwise whatever
        ``_escalate`` returns once the rounds are exhausted.
    """
    max_rounds = stage.max_adversarial_rounds
    adversarial_state = AdversarialState(
        current_round=0,
        max_rounds=max_rounds,
    )

    resolved_inputs = self._resolve_variables(stage.inputs, pipeline_result.variables)
    current_context = resolved_inputs.copy()

    for round_num in range(1, max_rounds + 1):
        adversarial_state.current_round = round_num
        logger.info(
            "Adversarial round %s/%s for stage '%s'",
            round_num,
            max_rounds,
            stage.name,
        )

        # 1. Run the worker agent.
        worker_result = await self._execute_agent_stage(
            stage.agent,
            stage.action,
            current_context,
            stage,
            started_at,
        )

        if worker_result.status != StageStatus.COMPLETED:
            # The worker itself failed: there is nothing to review.
            return worker_result

        # 2. Let the verifier review the worker's output.
        try:
            verifier_feedback = await self._execute_verifier(
                stage.verifier,
                worker_result.output_data or {},
                stage,
                started_at,
            )
        except Exception as e:
            logger.error(
                "Verifier execution failed for stage '%s': %s", stage.name, e
            )
            return StageResult(
                stage_name=stage.name,
                status=StageStatus.FAILED,
                error_message=f"Verifier failed: {e}",
                started_at=started_at,
                completed_at=datetime.now(timezone.utc).isoformat(),
            )

        # 3. Keep the review in the round history.
        adversarial_state.feedback_history.append(verifier_feedback)
        adversarial_state.last_feedback = verifier_feedback

        if verifier_feedback.passed:
            logger.info(
                "Stage '%s' passed review in round %s", stage.name, round_num
            )
            worker_result.output_data = worker_result.output_data or {}
            worker_result.output_data["adversarial_metadata"] = {
                "passed_round": round_num,
                "total_rounds": round_num,
                "feedback_summary": verifier_feedback.summary,
                "score": verifier_feedback.score,
            }
            saga.record_completed(
                step_name=stage.name,
                result=worker_result.output_data,
                compensate_action=stage.compensate,
            )
            return worker_result

        # 4. Review failed: check whether another attempt is allowed.
        logger.warning(
            "Stage '%s' failed review in round %s: %s",
            stage.name,
            round_num,
            verifier_feedback.summary,
        )

        if round_num >= max_rounds:
            # Rounds exhausted: hand the decision over to the escalation logic.
            escalated = await self._escalate(
                stage,
                worker_result,
                adversarial_state,
                started_at,
            )
            if escalated.status == StageStatus.COMPLETED:
                # A stage rescued by escalation is completed just like one
                # that passed review, so register it (and its compensation)
                # with the saga as the pass path above does.
                saga.record_completed(
                    step_name=stage.name,
                    result=escalated.output_data,
                    compensate_action=stage.compensate,
                )
            return escalated

        # 5. Send the worker back with the original inputs plus the feedback.
        feedback_context = self._build_feedback_context(
            verifier_feedback,
            stage.feedback_mode,
        )
        current_context = {**resolved_inputs, **feedback_context}

    # Only reachable when the loop body never ran (max_rounds < 1).
    return StageResult(
        stage_name=stage.name,
        status=StageStatus.FAILED,
        error_message="Adversarial loop exited unexpectedly",
        started_at=started_at,
        completed_at=datetime.now(timezone.utc).isoformat(),
    )
|
||||
|
||||
async def _execute_agent_stage(
    self,
    agent_name: str,
    action: str,
    input_data: dict[str, Any],
    stage: PipelineStage,
    started_at: str,
) -> StageResult:
    """Dispatch one task to an agent and poll for its outcome (no adversarial logic).

    Args:
        agent_name: Agent the task is addressed to.
        action: Used as the task's ``task_type``.
        input_data: Payload handed to the agent.
        stage: Supplies the stage name, ``timeout_seconds`` and ``retry_policy``.
        started_at: Start timestamp copied into the returned ``StageResult``.

    Returns:
        COMPLETED with a ``dry_run`` marker when no dispatcher is configured;
        otherwise COMPLETED or FAILED according to the reported task status,
        FAILED on timeout, and FAILED with the exception text if dispatching
        or retrying raises.
    """

    def _finish(status: StageStatus, **fields: Any) -> StageResult:
        # Every exit path shares the stage name and the two timestamps.
        return StageResult(
            stage_name=stage.name,
            status=status,
            started_at=started_at,
            completed_at=datetime.now(timezone.utc).isoformat(),
            **fields,
        )

    if self._dispatcher is None:
        # Dry-run mode: echo the inputs instead of dispatching anything.
        return _finish(
            StageStatus.COMPLETED,
            output_data={"dry_run": True, "inputs": input_data},
        )

    from agentkit.core.protocol import TaskMessage
    import uuid

    task = TaskMessage(
        task_id=str(uuid.uuid4()),
        agent_name=agent_name,
        task_type=action,
        priority=0,
        input_data=input_data,
        callback_url=None,
        created_at=datetime.now(timezone.utc),
        timeout_seconds=stage.timeout_seconds,
    )

    terminal_states = ("completed", "failed", "cancelled")

    async def _dispatch_and_wait() -> StageResult:
        """Dispatch the task, then poll its status once per second."""
        await self._dispatcher.dispatch(task)

        # One poll per iteration, so at most ``timeout_seconds`` polls.
        for _ in range(stage.timeout_seconds):
            report = await self._dispatcher.get_task_status(task.task_id)
            state = report["status"]
            if state not in terminal_states:
                await asyncio.sleep(1)
                continue

            # Anything terminal other than "completed" counts as a failure.
            outcome = (
                StageStatus.COMPLETED if state == "completed" else StageStatus.FAILED
            )
            return _finish(
                outcome,
                output_data=report.get("output_data"),
                error_message=report.get("error_message"),
            )

        return _finish(
            StageStatus.FAILED,
            error_message=f"Timeout after {stage.timeout_seconds}s",
        )

    try:
        return await execute_with_retry(
            func=_dispatch_and_wait,
            retry_policy=stage.retry_policy,
            step_name=stage.name,
        )
    except Exception as e:
        # Any error escaping the retry wrapper becomes a FAILED result.
        return _finish(StageStatus.FAILED, error_message=str(e))
|
||||
|
||||
async def _execute_verifier(
    self,
    verifier_name: str,
    worker_output: dict[str, Any],
    stage: PipelineStage,
    started_at: str,
) -> ReviewFeedback:
    """Have the verifier agent review the worker's output.

    Args:
        verifier_name: Agent that performs the review.
        worker_output: Worker output, passed to the verifier as
            ``review_target``.
        stage: Stage being reviewed; forwarded to ``_execute_agent_stage``.
        started_at: Start timestamp forwarded to ``_execute_agent_stage``.

    Returns:
        ReviewFeedback: The structured review. When the verifier's output
        cannot be parsed, a failing review describing the parse error is
        returned instead. In dry-run mode a passing placeholder is returned.

    Raises:
        RuntimeError: If the verifier agent does not end with COMPLETED status.
    """
    logger.info(
        "Executing verifier '%s' for stage '%s'", verifier_name, stage.name
    )

    if self._dispatcher is None:
        # Dry-run mode: _execute_agent_stage would only echo its inputs, so
        # there is no review to parse. Without this short-circuit the missing
        # "passed" key defaults to False and every dry run of a verified
        # stage exhausts its rounds and fails.
        return ReviewFeedback(
            passed=True,
            issues=[],
            summary="Dry run: verifier review skipped",
            score=0.0,  # no real score exists in a dry run
        )

    # Build the review request.
    verifier_input = {
        "review_target": worker_output,
        "review_instruction": (
            "Please review the following output for quality, correctness, and completeness. "
            "Return a structured review with pass/fail status, issues found, and a summary."
        ),
    }

    # Run the verifier agent.
    verifier_result = await self._execute_agent_stage(
        verifier_name,
        "review",
        verifier_input,
        stage,
        started_at,
    )

    if verifier_result.status != StageStatus.COMPLETED:
        raise RuntimeError(
            f"Verifier '{verifier_name}' failed: {verifier_result.error_message}"
        )

    # Parse the verifier's output into a ReviewFeedback. ``or`` defaults are
    # used so that an explicit null behaves like a missing key instead of
    # pushing an otherwise valid review into the parse-failure fallback.
    output_data = verifier_result.output_data or {}
    try:
        return ReviewFeedback(
            passed=output_data.get("passed") or False,
            issues=[
                ReviewIssue(**issue)
                for issue in output_data.get("issues") or []
            ],
            summary=output_data.get("summary") or "No summary provided",
            score=output_data.get("score") or 0.0,
        )
    except Exception as e:
        # Malformed verifier output is reported as a failed review so the
        # adversarial loop can continue rather than crash.
        logger.warning("Failed to parse verifier output: %s", e)
        return ReviewFeedback(
            passed=False,
            issues=[
                ReviewIssue(
                    severity="major",
                    category="logic_error",
                    description=f"Failed to parse verifier output: {e}",
                )
            ],
            summary="Verifier output parsing failed",
            score=0.0,
        )
|
||||
|
||||
def _build_feedback_context(
    self,
    feedback: "ReviewFeedback",
    feedback_mode: str = "structured+natural",
) -> dict[str, Any]:
    """Build the extra inputs that tell the worker why its last attempt failed.

    Args:
        feedback: Review of the previous attempt; its ``summary``, ``score``
            and ``issues`` are read.
        feedback_mode: ``"structured+natural"`` (summary and issue list),
            ``"structured"`` (issue list only) or ``"natural"`` (summary
            only). Any other value is treated as ``"structured+natural"``.

    Returns:
        dict: ``previous_attempt_failed`` (always True), ``review_feedback``
        (the parts selected by the mode plus ``previous_score``) and
        ``instruction`` (retry prompt text for the worker).
    """
    # Only these two modes leave something out; every other value, including
    # unrecognized ones, gets the full "structured+natural" payload.
    include_summary = feedback_mode != "structured"
    include_issues = feedback_mode != "natural"

    review_feedback: dict[str, Any] = {}
    if include_summary:
        review_feedback["summary"] = feedback.summary
    if include_issues:
        review_feedback["issues"] = [
            {
                "severity": issue.severity,
                "category": issue.category,
                "description": issue.description,
                "location": issue.location,
                "suggestion": issue.suggestion,
            }
            for issue in feedback.issues
        ]
    review_feedback["previous_score"] = feedback.score

    if feedback_mode == "structured":
        instruction = (
            "Your previous output did not pass review. "
            "Please fix the issues listed above and regenerate."
        )
    elif feedback_mode == "natural":
        instruction = (
            "Your previous output did not pass review. "
            f"Review feedback: {feedback.summary}. "
            "Please regenerate addressing the feedback."
        )
    else:
        instruction = (
            "Your previous output did not pass review. "
            "Please fix the issues listed above and regenerate. "
            f"Review summary: {feedback.summary}"
        )

    return {
        "previous_attempt_failed": True,
        "review_feedback": review_feedback,
        "instruction": instruction,
    }
|
||||
|
||||
async def _escalate(
    self,
    stage: PipelineStage,
    worker_result: StageResult,
    adversarial_state: AdversarialState,
    started_at: str,
) -> StageResult:
    """Decide the stage outcome once the adversarial rounds are used up.

    Args:
        stage: Current stage; ``stage.escalate_on_exhaust`` names the agent
            to escalate to, if any.
        worker_result: Result of the last worker attempt.
        adversarial_state: Round counter and review history of the loop.
        started_at: Start timestamp copied into the returned result.

    Returns:
        StageResult: The escalation agent's result with
        ``adversarial_metadata`` attached when an escalation target is
        configured; otherwise a FAILED result carrying the full review
        history.
    """
    logger.warning(
        f"Adversarial rounds exhausted for stage '{stage.name}' "
        f"({adversarial_state.current_round}/{adversarial_state.max_rounds})"
    )

    target = stage.escalate_on_exhaust
    if not target:
        # No escalation target: fail the stage and attach the review history.
        last_feedback = adversarial_state.last_feedback
        last_summary = last_feedback.summary if last_feedback else 'N/A'
        history_dump = [fb.model_dump() for fb in adversarial_state.feedback_history]
        return StageResult(
            stage_name=stage.name,
            status=StageStatus.FAILED,
            error_message=(
                f"Adversarial rounds exhausted ({adversarial_state.current_round}/"
                f"{adversarial_state.max_rounds}). "
                f"Last review: {last_summary}"
            ),
            output_data={
                "adversarial_metadata": {
                    "total_rounds": adversarial_state.current_round,
                    "feedback_history": history_dump,
                }
            },
            started_at=started_at,
            completed_at=datetime.now(timezone.utc).isoformat(),
        )

    # Forward the last attempt and the loop state to the escalation target.
    logger.info(f"Escalating stage '{stage.name}' to '{stage.escalate_on_exhaust}'")
    payload = {
        "original_output": worker_result.output_data,
        "adversarial_state": adversarial_state.model_dump(),
        "escalation_reason": (
            f"Failed to pass review after {adversarial_state.current_round} rounds"
        ),
    }
    escalate_result = await self._execute_agent_stage(
        target,
        "handle_escalation",
        payload,
        stage,
        started_at,
    )

    if not escalate_result.output_data:
        escalate_result.output_data = {}

    # Compact per-round summary: 1-based round number, verdict and score.
    round_summaries = [
        {"round": number, "passed": fb.passed, "score": fb.score}
        for number, fb in enumerate(adversarial_state.feedback_history, start=1)
    ]
    escalate_result.output_data["adversarial_metadata"] = {
        "escalated_to": target,
        "total_rounds": adversarial_state.current_round,
        "feedback_history_summary": round_summaries,
    }
    return escalate_result
|
||||
|
|
|
|||
Loading…
Reference in New Issue