"""U7 tests: multi-agent collaboration enhancements - parallel Pipeline execution, Handoff and DynamicPipeline integration."""
|
||
|
||
import time
|
||
|
||
import pytest
|
||
|
||
from agentkit.core.protocol import HandoffMessage
|
||
from agentkit.orchestrator.dynamic_pipeline import DynamicPipeline
|
||
from agentkit.orchestrator.handoff import HandoffManager
|
||
from agentkit.orchestrator.pipeline_engine import PipelineEngine
|
||
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage, StageStatus
|
||
|
||
|
||
# ── Fixtures ──────────────────────────────────────────────
|
||
|
||
|
||
def _make_stage(name, agent="test_agent", action="process", depends_on=None,
                inputs=None, outputs=None, condition=None, timeout_seconds=60):
    """Build a ``PipelineStage`` for use in tests.

    Falsy ``depends_on``, ``inputs`` and ``outputs`` arguments (including the
    ``None`` defaults) are replaced by a fresh empty list, dict and list
    respectively; every other argument is forwarded unchanged.
    """
    stage_fields = {
        "name": name,
        "agent": agent,
        "action": action,
        # Fresh containers per call, so stages never share a mutable default.
        "depends_on": depends_on if depends_on else [],
        "inputs": inputs if inputs else {},
        "outputs": outputs if outputs else [],
        "condition": condition,
        "timeout_seconds": timeout_seconds,
    }
    return PipelineStage(**stage_fields)
|
||
|
||
|
||
def _make_pipeline(name, stages, variables=None):
    """Build a ``Pipeline`` for use in tests.

    The version is fixed to ``"1.0.0"`` and the description is derived from
    ``name``. A falsy ``variables`` argument (including the ``None`` default)
    is replaced by a fresh empty dict.
    """
    pipeline_variables = variables if variables else {}
    return Pipeline(
        name=name,
        version="1.0.0",
        description=f"Test pipeline {name}",
        stages=stages,
        variables=pipeline_variables,
    )
|
||
|
||
|
||
# ── Pipeline parallel execution tests ────────────────────
|
||
|
||
|
||
class TestPipelineParallel:
    """PipelineEngine tests: stage ordering, conditions, cycles and variables.

    Every test constructs the engine with ``dispatcher=None``.
    """

    async def test_parallel_stages_execute_concurrently(self):
        """Stages without dependencies are executed in parallel."""
        runner = PipelineEngine(dispatcher=None)

        independent_stages = [
            _make_stage("stage_a", inputs={"x": 1}),
            _make_stage("stage_b", inputs={"y": 2}),
            _make_stage("stage_c", inputs={"z": 3}),
        ]
        pipe = _make_pipeline("parallel_test", independent_stages)

        began = time.monotonic()
        outcome = await runner.execute(pipe)
        duration = time.monotonic() - began

        assert outcome.status == StageStatus.COMPLETED
        assert len(outcome.stage_results) == 3
        # Wall-clock bound for the whole three-stage run.
        assert duration < 1.0

    async def test_sequential_stages_with_dependencies(self):
        """Stages with dependencies are executed in order."""
        runner = PipelineEngine(dispatcher=None)

        producer = _make_stage("step1", outputs=["result1"])
        consumer = _make_stage(
            "step2", depends_on=["step1"], inputs={"data": "${result1}"}
        )
        pipe = _make_pipeline("sequential_test", [producer, consumer])

        outcome = await runner.execute(pipe)
        assert outcome.status == StageStatus.COMPLETED
        assert "step1" in outcome.stage_results
        assert "step2" in outcome.stage_results

    async def test_mixed_parallel_and_sequential(self):
        """Mixed parallel and sequential: A and B in parallel, then C depends on both."""
        runner = PipelineEngine(dispatcher=None)

        fan_in_stage = _make_stage(
            "process",
            depends_on=["fetch_data", "fetch_config"],
            inputs={"data": "${raw_data}", "cfg": "${config}"},
        )
        pipe = _make_pipeline("mixed_test", [
            _make_stage("fetch_data", outputs=["raw_data"]),
            _make_stage("fetch_config", outputs=["config"]),
            fan_in_stage,
        ])

        outcome = await runner.execute(pipe)
        assert outcome.status == StageStatus.COMPLETED

        # Two fetch stages share the first group; the fan-in stage is alone in the second.
        layers = PipelineEngine._topological_group(pipe.stages)
        assert len(layers) == 2
        assert len(layers[0]) == 2
        assert len(layers[1]) == 1

    async def test_pipeline_with_condition(self):
        """Conditional stage: executed when the condition is satisfied."""
        runner = PipelineEngine(dispatcher=None)

        stages = [
            _make_stage("always_run"),
            _make_stage("conditional_run", condition="run_analysis"),
        ]
        pipe = _make_pipeline(
            "conditional_test", stages, variables={"run_analysis": True}
        )

        per_stage = (await runner.execute(pipe)).stage_results
        assert per_stage["always_run"].status == StageStatus.COMPLETED
        assert per_stage["conditional_run"].status == StageStatus.COMPLETED

    async def test_pipeline_condition_skip(self):
        """Conditional stage: skipped when the condition is not satisfied."""
        runner = PipelineEngine(dispatcher=None)

        stages = [
            _make_stage("always_run"),
            _make_stage("conditional_run", condition="run_analysis"),
        ]
        pipe = _make_pipeline(
            "skip_test", stages, variables={"run_analysis": False}
        )

        per_stage = (await runner.execute(pipe)).stage_results
        assert per_stage["always_run"].status == StageStatus.COMPLETED
        assert per_stage["conditional_run"].status == StageStatus.SKIPPED

    async def test_pipeline_circular_dependency(self):
        """Circular dependency detection."""
        runner = PipelineEngine(dispatcher=None)

        # "a" and "b" depend on each other, so no valid execution order exists.
        cyclic_stages = [
            _make_stage("a", depends_on=["b"]),
            _make_stage("b", depends_on=["a"]),
        ]
        pipe = _make_pipeline("circular_test", cyclic_stages)

        outcome = await runner.execute(pipe)
        assert outcome.status == StageStatus.FAILED
        assert "Circular" in outcome.error_message

    async def test_pipeline_variable_resolution(self):
        """Variable resolution for ``${var.path}`` references."""
        runner = PipelineEngine(dispatcher=None)

        pipe = _make_pipeline(
            "var_test",
            [_make_stage("step1", inputs={"url": "${source.url}"})],
            variables={"source": {"url": "https://example.com"}},
        )

        outcome = await runner.execute(pipe)
        produced = outcome.stage_results["step1"].output_data
        assert produced["dry_run"] is True
        assert produced["inputs"]["url"] == "https://example.com"
|
||
|
||
|
||
# ── Handoff tests ────────────────────────────────────────
|
||
|
||
|
||
class TestHandoff:
    """Tests for ``HandoffMessage`` construction/serialization and ``HandoffManager``."""

    async def test_handoff_message_creation(self):
        """A HandoffMessage is created with the given field values."""
        message = HandoffMessage(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="t-001",
            task_type="analysis",
            context={"data": "test"},
            reason="Need specialized analysis",
        )

        assert message.source_agent == "agent_a"
        assert message.target_agent == "agent_b"
        assert message.reason == "Need specialized analysis"

    async def test_handoff_message_serialization(self):
        """HandoffMessage serialization / deserialization round trip."""
        original = HandoffMessage(
            source_agent="a",
            target_agent="b",
            task_id="t-001",
            task_type="test",
            context={"key": "value"},
            reason="test handoff",
        )

        # to_dict() followed by from_dict() must give back the same field values.
        round_tripped = HandoffMessage.from_dict(original.to_dict())
        assert round_tripped.source_agent == "a"
        assert round_tripped.context["key"] == "value"

    async def test_handoff_manager_register_handler(self):
        """HandoffManager registers a handler."""
        handoff_manager = HandoffManager()
        handoff_manager.register_handler("agent_b", lambda h: None)

        registered = handoff_manager._handlers
        assert "agent_b" in registered

    async def test_handoff_manager_without_redis(self):
        """send_handoff raises when no Redis is available."""
        handoff_manager = HandoffManager()
        message = HandoffMessage(
            source_agent="a",
            target_agent="b",
            task_id="t-001",
            task_type="test",
            context={},
            reason="test",
        )

        with pytest.raises(RuntimeError, match="Redis"):
            await handoff_manager.send_handoff(message)
|
||
|
||
|
||
# ── DynamicPipeline tests ────────────────────────────────
|
||
|
||
|
||
class TestDynamicPipeline:
    """Tests for ``DynamicPipeline``: conditional, loop and nested execution.

    Each test wraps a ``PipelineEngine(dispatcher=None)``.
    """

    async def test_conditional_pipeline(self):
        """A sub-pipeline is selected according to the condition."""
        runner = PipelineEngine(dispatcher=None)
        dynamic_runner = DynamicPipeline(engine=runner)

        branches = {
            "type_a": _make_pipeline("pipeline_a", [_make_stage("step_a")]),
            "type_b": _make_pipeline("pipeline_b", [_make_stage("step_b")]),
        }

        outcome = await dynamic_runner.execute_conditional(
            pipelines=branches,
            condition_key="task_type",
            context={"task_type": "type_a"},
        )

        assert outcome.status == StageStatus.COMPLETED
        assert "step_a" in outcome.stage_results

    async def test_conditional_pipeline_no_match(self):
        """Execution fails when the condition matches no pipeline."""
        runner = PipelineEngine(dispatcher=None)
        dynamic_runner = DynamicPipeline(engine=runner)

        only_branch = _make_pipeline("pipeline_a", [_make_stage("step_a")])

        # "type_z" is not a key of the pipelines mapping.
        outcome = await dynamic_runner.execute_conditional(
            pipelines={"type_a": only_branch},
            condition_key="task_type",
            context={"task_type": "type_z"},
        )

        assert outcome.status == StageStatus.FAILED

    async def test_loop_pipeline_exits_on_condition(self):
        """The loop pipeline exits once the exit condition is satisfied."""
        runner = PipelineEngine(dispatcher=None)
        dynamic_runner = DynamicPipeline(engine=runner)

        loop_body = _make_pipeline("loop_test", [_make_stage("iterate")])

        outcome = await dynamic_runner.execute_loop(
            pipeline=loop_body,
            max_iterations=3,
            exit_condition="done",
            context={"done": True},
        )

        assert outcome.status == StageStatus.COMPLETED

    async def test_loop_pipeline_max_iterations(self):
        """The loop pipeline reaches the maximum number of iterations."""
        runner = PipelineEngine(dispatcher=None)
        dynamic_runner = DynamicPipeline(engine=runner)

        loop_body = _make_pipeline("loop_max_test", [_make_stage("iterate")])

        # The context never defines "done".
        outcome = await dynamic_runner.execute_loop(
            pipeline=loop_body,
            max_iterations=2,
            exit_condition="done",
            context={},
        )

        assert outcome.status == StageStatus.COMPLETED

    async def test_nested_pipeline(self):
        """Nested pipeline execution."""
        runner = PipelineEngine(dispatcher=None)
        dynamic_runner = DynamicPipeline(engine=runner)

        outer = _make_pipeline("parent", [_make_stage("parent_step")])
        inner = _make_pipeline("sub_pipeline", [_make_stage("sub_step")])

        outcome = await dynamic_runner.execute_nested(
            parent=outer,
            sub_pipeline_map={"sub_pipeline": inner},
        )

        assert outcome.status == StageStatus.COMPLETED
|
||
|
||
|
||
# ── End-to-end integration tests ─────────────────────────
|
||
|
||
|
||
class TestOrchestratorIntegration:
    """End-to-end tests combining ``PipelineEngine`` runs with ``HandoffMessage``.

    Each test constructs the engine with ``dispatcher=None``.
    """

    async def test_full_pipeline_with_handoff_message(self):
        """Full pipeline run followed by a handoff message between agents."""
        engine = PipelineEngine(dispatcher=None)

        # Three-stage chain: research -> generate -> optimize, each consuming
        # the output variable declared by the previous stage.
        pipeline = _make_pipeline("content_production", [
            _make_stage("research", agent="research_agent", action="search",
                        inputs={"query": "${topic}"}, outputs=["research_data"]),
            _make_stage("generate", agent="content_agent", action="generate",
                        depends_on=["research"],
                        inputs={"data": "${research_data}"}, outputs=["draft"]),
            _make_stage("optimize", agent="seo_agent", action="optimize",
                        depends_on=["generate"],
                        inputs={"content": "${draft}"}, outputs=["final_content"]),
        ], variables={"topic": "AI trends 2026"})

        result = await engine.execute(pipeline)
        assert result.status == StageStatus.COMPLETED
        assert len(result.stage_results) == 3

        handoff = HandoffMessage(
            source_agent="research_agent",
            target_agent="content_agent",
            task_id="t-handoff-001",
            task_type="handoff_research_to_content",
            context={"research_data": "AI market analysis..."},
            reason="Research complete, handing off to content generation",
        )
        assert handoff.source_agent == "research_agent"
        assert handoff.context["research_data"] == "AI market analysis..."

    async def test_parallel_pipeline_with_variables(self):
        """Parallel pipeline with variable passing."""
        engine = PipelineEngine(dispatcher=None)

        # Two independent checks read the same ${url}; the report stage
        # depends on both of them.
        pipeline = _make_pipeline("parallel_with_vars", [
            _make_stage("check_citation", agent="citation_agent",
                        inputs={"url": "${url}"}, outputs=["citation_result"]),
            _make_stage("check_trends", agent="trend_agent",
                        inputs={"url": "${url}"}, outputs=["trend_result"]),
            _make_stage("compile_report", agent="report_agent",
                        depends_on=["check_citation", "check_trends"],
                        inputs={"citation": "${citation_result}",
                                "trends": "${trend_result}"}),
        ], variables={"url": "https://example.com"})

        result = await engine.execute(pipeline)
        assert result.status == StageStatus.COMPLETED

        groups = PipelineEngine._topological_group(pipeline.stages)
        # Pin the number of groups before indexing: too few groups then fails
        # on a clear assertion instead of an IndexError, and unexpected extra
        # trailing groups no longer pass unnoticed.
        assert len(groups) == 2
        assert len(groups[0]) == 2
        assert len(groups[1]) == 1
|