test(pipeline): add adversarial loop unit tests
11 test cases covering: - PipelineSchemaAdversarial (4): verifier fields, backward compat, serialization, state tracking - AdversarialExecution (3): no verifier passthrough, first round pass, max rounds exhausted - FeedbackContext (3): structured+natural, structured, natural modes - Escalation (1): no escalation configured
This commit is contained in:
parent
6731d96c65
commit
3392413614
|
|
@ -0,0 +1,368 @@
|
|||
"""Pipeline 对抗闭环单元测试"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from agentkit.orchestrator.pipeline_engine import PipelineEngine
|
||||
from agentkit.orchestrator.pipeline_schema import (
|
||||
AdversarialState,
|
||||
Pipeline,
|
||||
PipelineResult,
|
||||
PipelineStage,
|
||||
ReviewFeedback,
|
||||
ReviewIssue,
|
||||
StageResult,
|
||||
StageStatus,
|
||||
)
|
||||
from agentkit.orchestrator.compensation import SagaOrchestrator
|
||||
|
||||
|
||||
class TestPipelineSchemaAdversarial:
|
||||
"""测试对抗闭环相关的 Schema 模型"""
|
||||
|
||||
def test_stage_with_verifier(self):
|
||||
"""Happy path: 创建带 verifier 字段的 PipelineStage"""
|
||||
stage = PipelineStage(
|
||||
name="review",
|
||||
agent="developer_agent",
|
||||
action="fix_code_issues",
|
||||
verifier="code_reviewer",
|
||||
max_adversarial_rounds=3,
|
||||
feedback_mode="structured+natural",
|
||||
escalate_on_exhaust="human_approval",
|
||||
)
|
||||
|
||||
assert stage.verifier == "code_reviewer"
|
||||
assert stage.max_adversarial_rounds == 3
|
||||
assert stage.feedback_mode == "structured+natural"
|
||||
assert stage.escalate_on_exhaust == "human_approval"
|
||||
|
||||
def test_stage_without_verifier_backward_compat(self):
|
||||
"""Edge case: verifier=None 时,PipelineStage 正常创建(向后兼容)"""
|
||||
stage = PipelineStage(
|
||||
name="develop",
|
||||
agent="developer_agent",
|
||||
action="implement_feature",
|
||||
)
|
||||
|
||||
assert stage.verifier is None
|
||||
assert stage.max_adversarial_rounds == 3 # 默认值
|
||||
assert stage.feedback_mode == "structured+natural" # 默认值
|
||||
assert stage.escalate_on_exhaust is None
|
||||
|
||||
def test_review_feedback_serialization(self):
|
||||
"""Happy path: 创建 ReviewFeedback 对象,验证序列化和反序列化正常"""
|
||||
feedback = ReviewFeedback(
|
||||
passed=False,
|
||||
issues=[
|
||||
ReviewIssue(
|
||||
severity="critical",
|
||||
category="security",
|
||||
description="SQL injection vulnerability",
|
||||
location="src/db.py:42",
|
||||
suggestion="Use parameterized queries",
|
||||
),
|
||||
ReviewIssue(
|
||||
severity="minor",
|
||||
category="style",
|
||||
description="Variable name too generic",
|
||||
),
|
||||
],
|
||||
summary="Found critical security issue",
|
||||
score=0.3,
|
||||
)
|
||||
|
||||
# 序列化
|
||||
data = feedback.model_dump()
|
||||
assert data["passed"] is False
|
||||
assert len(data["issues"]) == 2
|
||||
assert data["issues"][0]["severity"] == "critical"
|
||||
assert data["score"] == 0.3
|
||||
|
||||
# 反序列化
|
||||
restored = ReviewFeedback(**data)
|
||||
assert restored.passed is False
|
||||
assert len(restored.issues) == 2
|
||||
assert restored.issues[0].severity == "critical"
|
||||
|
||||
def test_adversarial_state_tracking(self):
|
||||
"""Happy path: AdversarialState 正确追踪对抗轮次"""
|
||||
state = AdversarialState(
|
||||
current_round=0,
|
||||
max_rounds=3,
|
||||
)
|
||||
|
||||
assert state.current_round == 0
|
||||
assert state.max_rounds == 3
|
||||
assert len(state.feedback_history) == 0
|
||||
assert state.last_feedback is None
|
||||
|
||||
# 模拟添加反馈
|
||||
feedback1 = ReviewFeedback(
|
||||
passed=False,
|
||||
issues=[ReviewIssue(severity="major", category="logic_error", description="Bug")],
|
||||
summary="Needs fix",
|
||||
score=0.5,
|
||||
)
|
||||
state.feedback_history.append(feedback1)
|
||||
state.last_feedback = feedback1
|
||||
state.current_round = 1
|
||||
|
||||
assert len(state.feedback_history) == 1
|
||||
assert state.last_feedback.passed is False
|
||||
assert state.current_round == 1
|
||||
|
||||
|
||||
class TestAdversarialExecution:
|
||||
"""测试对抗流转执行逻辑"""
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self):
|
||||
"""创建带有 mock dispatcher 的 PipelineEngine"""
|
||||
dispatcher = AsyncMock()
|
||||
engine = PipelineEngine(dispatcher=dispatcher)
|
||||
return engine
|
||||
|
||||
@pytest.fixture
|
||||
def saga(self):
|
||||
"""创建 SagaOrchestrator"""
|
||||
return SagaOrchestrator()
|
||||
|
||||
@pytest.fixture
|
||||
def pipeline_result(self):
|
||||
"""创建空的 PipelineResult"""
|
||||
return PipelineResult(pipeline_name="test")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_verifier_passthrough(self, engine, saga, pipeline_result):
|
||||
"""Happy path: Stage 无 verifier → 走原有逻辑"""
|
||||
stage = PipelineStage(
|
||||
name="develop",
|
||||
agent="developer_agent",
|
||||
action="implement",
|
||||
)
|
||||
|
||||
# Mock dispatcher
|
||||
engine._dispatcher.dispatch = AsyncMock()
|
||||
engine._dispatcher.get_task_status = AsyncMock(side_effect=[
|
||||
{"status": "running"},
|
||||
{"status": "completed", "output_data": {"code": "print('hello')"}},
|
||||
])
|
||||
|
||||
result = await engine._execute_stage(stage, pipeline_result, saga)
|
||||
|
||||
assert result.status == StageStatus.COMPLETED
|
||||
assert result.output_data["code"] == "print('hello')"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verifier_passes_first_round(self, engine, saga, pipeline_result):
|
||||
"""Happy path: Stage 有 verifier,审查通过 → 一次完成"""
|
||||
stage = PipelineStage(
|
||||
name="review",
|
||||
agent="developer_agent",
|
||||
action="fix",
|
||||
verifier="code_reviewer",
|
||||
max_adversarial_rounds=3,
|
||||
)
|
||||
|
||||
# Mock worker execution
|
||||
call_count = 0
|
||||
|
||||
async def mock_dispatch(task):
|
||||
pass
|
||||
|
||||
async def mock_get_status(task_id):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count <= 2:
|
||||
return {"status": "running"}
|
||||
else:
|
||||
return {
|
||||
"status": "completed",
|
||||
"output_data": {
|
||||
"passed": True,
|
||||
"score": 0.9,
|
||||
"summary": "Code looks good",
|
||||
"issues": [],
|
||||
},
|
||||
}
|
||||
|
||||
engine._dispatcher.dispatch = AsyncMock(side_effect=mock_dispatch)
|
||||
engine._dispatcher.get_task_status = AsyncMock(side_effect=mock_get_status)
|
||||
|
||||
result = await engine._execute_stage(stage, pipeline_result, saga)
|
||||
|
||||
assert result.status == StageStatus.COMPLETED
|
||||
assert "adversarial_metadata" in result.output_data
|
||||
assert result.output_data["adversarial_metadata"]["passed_round"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_rounds_exhausted_no_escalate(self, engine, saga, pipeline_result):
|
||||
"""Edge case: escalate_on_exhaust=None → 返回失败,附带审查历史"""
|
||||
stage = PipelineStage(
|
||||
name="review",
|
||||
agent="developer_agent",
|
||||
action="fix",
|
||||
verifier="code_reviewer",
|
||||
max_adversarial_rounds=2,
|
||||
escalate_on_exhaust=None,
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_dispatch(task):
|
||||
pass
|
||||
|
||||
async def mock_get_status(task_id):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
# 总是返回审查不通过
|
||||
return {
|
||||
"status": "completed",
|
||||
"output_data": {
|
||||
"passed": False,
|
||||
"score": 0.3,
|
||||
"summary": "Still has issues",
|
||||
"issues": [
|
||||
{
|
||||
"severity": "major",
|
||||
"category": "logic_error",
|
||||
"description": "Bug not fixed",
|
||||
}
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
engine._dispatcher.dispatch = AsyncMock(side_effect=mock_dispatch)
|
||||
engine._dispatcher.get_task_status = AsyncMock(side_effect=mock_get_status)
|
||||
|
||||
result = await engine._execute_stage(stage, pipeline_result, saga)
|
||||
|
||||
assert result.status == StageStatus.FAILED
|
||||
assert "Adversarial rounds exhausted" in result.error_message
|
||||
assert "adversarial_metadata" in result.output_data
|
||||
assert result.output_data["adversarial_metadata"]["total_rounds"] == 2
|
||||
|
||||
|
||||
class TestFeedbackContext:
|
||||
"""测试反馈上下文构建"""
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self):
|
||||
return PipelineEngine(dispatcher=None)
|
||||
|
||||
def test_structured_and_natural_mode(self, engine):
|
||||
"""Happy path: feedback_mode="structured+natural" → 上下文包含 issues 和 summary"""
|
||||
feedback = ReviewFeedback(
|
||||
passed=False,
|
||||
issues=[
|
||||
ReviewIssue(
|
||||
severity="critical",
|
||||
category="security",
|
||||
description="SQL injection",
|
||||
suggestion="Use params",
|
||||
)
|
||||
],
|
||||
summary="Security issues found",
|
||||
score=0.2,
|
||||
)
|
||||
|
||||
context = engine._build_feedback_context(feedback, "structured+natural")
|
||||
|
||||
assert context["previous_attempt_failed"] is True
|
||||
assert "review_feedback" in context
|
||||
assert "summary" in context["review_feedback"]
|
||||
assert "issues" in context["review_feedback"]
|
||||
assert len(context["review_feedback"]["issues"]) == 1
|
||||
assert "instruction" in context
|
||||
assert "Security issues found" in context["instruction"]
|
||||
|
||||
def test_structured_only_mode(self, engine):
|
||||
"""Happy path: feedback_mode="structured" → 上下文只包含 issues"""
|
||||
feedback = ReviewFeedback(
|
||||
passed=False,
|
||||
issues=[
|
||||
ReviewIssue(
|
||||
severity="major",
|
||||
category="logic_error",
|
||||
description="Bug",
|
||||
)
|
||||
],
|
||||
summary="Logic error",
|
||||
score=0.4,
|
||||
)
|
||||
|
||||
context = engine._build_feedback_context(feedback, "structured")
|
||||
|
||||
assert "review_feedback" in context
|
||||
assert "issues" in context["review_feedback"]
|
||||
assert "summary" not in context["review_feedback"]
|
||||
assert "previous_score" in context["review_feedback"]
|
||||
|
||||
def test_natural_only_mode(self, engine):
|
||||
"""Happy path: feedback_mode="natural" → 上下文只包含 summary"""
|
||||
feedback = ReviewFeedback(
|
||||
passed=False,
|
||||
issues=[],
|
||||
summary="Please improve code quality",
|
||||
score=0.5,
|
||||
)
|
||||
|
||||
context = engine._build_feedback_context(feedback, "natural")
|
||||
|
||||
assert "review_feedback" in context
|
||||
assert "summary" in context["review_feedback"]
|
||||
assert "issues" not in context["review_feedback"]
|
||||
assert "Please improve code quality" in context["instruction"]
|
||||
|
||||
|
||||
class TestEscalation:
|
||||
"""测试升级处理"""
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self):
|
||||
dispatcher = AsyncMock()
|
||||
return PipelineEngine(dispatcher=dispatcher)
|
||||
|
||||
@pytest.fixture
|
||||
def started_at(self):
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_escalation_configured(self, engine, started_at):
|
||||
"""Edge case: 没有配置 escalate_on_exhaust → 返回失败"""
|
||||
stage = PipelineStage(
|
||||
name="review",
|
||||
agent="developer",
|
||||
action="fix",
|
||||
verifier="reviewer",
|
||||
max_adversarial_rounds=3,
|
||||
escalate_on_exhaust=None,
|
||||
)
|
||||
|
||||
worker_result = StageResult(
|
||||
stage_name="review",
|
||||
status=StageStatus.COMPLETED,
|
||||
output_data={"code": "bad code"},
|
||||
)
|
||||
|
||||
adversarial_state = AdversarialState(
|
||||
current_round=3,
|
||||
max_rounds=3,
|
||||
feedback_history=[
|
||||
ReviewFeedback(
|
||||
passed=False,
|
||||
issues=[ReviewIssue(severity="major", category="logic_error", description="Bug")],
|
||||
summary="Failed review",
|
||||
score=0.3,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
result = await engine._escalate(stage, worker_result, adversarial_state, started_at)
|
||||
|
||||
assert result.status == StageStatus.FAILED
|
||||
assert "Adversarial rounds exhausted" in result.error_message
|
||||
assert "adversarial_metadata" in result.output_data
|
||||
assert result.output_data["adversarial_metadata"]["total_rounds"] == 3
|
||||
Loading…
Reference in New Issue