fischer-agentkit/tests/unit/test_orchestrator_integrati...

334 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U7 tests: multi-agent collaboration enhancements.

Covers PipelineEngine parallel execution, Handoff messaging, and
DynamicPipeline integration.

NOTE(review): every test method below is a bare ``async def`` with no
asyncio marker; this presumably relies on pytest-asyncio's
``asyncio_mode = "auto"`` (or an equivalent conftest hook). Confirm in the
project's pytest configuration, otherwise these tests are not awaited.
"""
import time
import pytest
from agentkit.core.protocol import HandoffMessage
from agentkit.orchestrator.dynamic_pipeline import DynamicPipeline
from agentkit.orchestrator.handoff import HandoffManager
from agentkit.orchestrator.pipeline_engine import PipelineEngine
from agentkit.orchestrator.pipeline_schema import Pipeline, PipelineStage, StageStatus
# ── Fixtures ──────────────────────────────────────────────
def _make_stage(name, agent="test_agent", action="process", depends_on=None,
                inputs=None, outputs=None, condition=None, timeout_seconds=60):
    """Build a PipelineStage with test-friendly defaults.

    Falsy ``depends_on`` / ``outputs`` become a fresh empty list and a falsy
    ``inputs`` becomes a fresh empty dict, so stages built without those
    arguments never share a mutable container.
    """
    stage_kwargs = {
        "name": name,
        "agent": agent,
        "action": action,
        "depends_on": depends_on if depends_on else [],
        "inputs": inputs if inputs else {},
        "outputs": outputs if outputs else [],
        "condition": condition,
        "timeout_seconds": timeout_seconds,
    }
    return PipelineStage(**stage_kwargs)
def _make_pipeline(name, stages, variables=None):
    """Wrap *stages* in a version "1.0.0" Pipeline called *name*.

    A falsy ``variables`` is replaced with a fresh empty dict.
    """
    description = f"Test pipeline {name}"
    return Pipeline(
        name=name,
        stages=stages,
        version="1.0.0",
        description=description,
        variables=variables if variables else {},
    )
# ── Pipeline 并行执行测试 ────────────────────────────────
class TestPipelineParallel:
    """PipelineEngine scheduling: parallel groups, dependencies, conditions, variables.

    Every test builds the engine with ``dispatcher=None``; as
    ``test_pipeline_variable_resolution`` shows, stages then report
    ``output_data["dry_run"] is True`` and echo their resolved inputs.
    """

    async def test_parallel_stages_execute_concurrently(self):
        """Stages with no dependencies form a single concurrent group."""
        engine = PipelineEngine(dispatcher=None)
        pipeline = _make_pipeline("parallel_test", [
            _make_stage("stage_a", inputs={"x": 1}),
            _make_stage("stage_b", inputs={"y": 2}),
            _make_stage("stage_c", inputs={"z": 3}),
        ])

        start = time.monotonic()
        result = await engine.execute(pipeline)
        elapsed = time.monotonic() - start

        assert result.status == StageStatus.COMPLETED
        assert len(result.stage_results) == 3
        # The stages are dry runs, so the timing bound alone cannot tell
        # parallel from serial execution; the grouping check is what pins
        # "no dependencies -> one concurrent batch".
        groups = PipelineEngine._topological_group(pipeline.stages)
        assert len(groups) == 1
        assert len(groups[0]) == 3
        assert elapsed < 1.0

    async def test_sequential_stages_with_dependencies(self):
        """A dependent stage is scheduled in a later group than its dependency."""
        engine = PipelineEngine(dispatcher=None)
        pipeline = _make_pipeline("sequential_test", [
            _make_stage("step1", outputs=["result1"]),
            _make_stage("step2", depends_on=["step1"], inputs={"data": "${result1}"}),
        ])

        result = await engine.execute(pipeline)

        assert result.status == StageStatus.COMPLETED
        assert "step1" in result.stage_results
        assert "step2" in result.stage_results
        # Two single-stage groups: the dependency forces sequential execution.
        groups = PipelineEngine._topological_group(pipeline.stages)
        assert len(groups) == 2
        assert all(len(group) == 1 for group in groups)

    async def test_mixed_parallel_and_sequential(self):
        """Mixed: A and B run in parallel, then C depends on both A and B."""
        engine = PipelineEngine(dispatcher=None)
        pipeline = _make_pipeline("mixed_test", [
            _make_stage("fetch_data", outputs=["raw_data"]),
            _make_stage("fetch_config", outputs=["config"]),
            _make_stage("process", depends_on=["fetch_data", "fetch_config"],
                        inputs={"data": "${raw_data}", "cfg": "${config}"}),
        ])

        result = await engine.execute(pipeline)

        assert result.status == StageStatus.COMPLETED
        assert len(result.stage_results) == 3
        groups = PipelineEngine._topological_group(pipeline.stages)
        assert len(groups) == 2
        assert len(groups[0]) == 2
        assert len(groups[1]) == 1

    async def test_pipeline_with_condition(self):
        """A conditional stage runs when its condition variable is truthy."""
        engine = PipelineEngine(dispatcher=None)
        pipeline = _make_pipeline("conditional_test", [
            _make_stage("always_run"),
            _make_stage("conditional_run", condition="run_analysis"),
        ], variables={"run_analysis": True})

        result = await engine.execute(pipeline)

        assert result.stage_results["always_run"].status == StageStatus.COMPLETED
        assert result.stage_results["conditional_run"].status == StageStatus.COMPLETED

    async def test_pipeline_condition_skip(self):
        """A conditional stage is skipped when its condition variable is falsy."""
        engine = PipelineEngine(dispatcher=None)
        pipeline = _make_pipeline("skip_test", [
            _make_stage("always_run"),
            _make_stage("conditional_run", condition="run_analysis"),
        ], variables={"run_analysis": False})

        result = await engine.execute(pipeline)

        assert result.stage_results["always_run"].status == StageStatus.COMPLETED
        assert result.stage_results["conditional_run"].status == StageStatus.SKIPPED

    async def test_pipeline_circular_dependency(self):
        """A dependency cycle fails the pipeline with a 'Circular' error."""
        engine = PipelineEngine(dispatcher=None)
        pipeline = _make_pipeline("circular_test", [
            _make_stage("a", depends_on=["b"]),
            _make_stage("b", depends_on=["a"]),
        ])

        result = await engine.execute(pipeline)

        assert result.status == StageStatus.FAILED
        # Check presence first so a missing message fails as an assertion,
        # not as a TypeError from `in None`.
        assert result.error_message
        assert "Circular" in result.error_message

    async def test_pipeline_variable_resolution(self):
        """Dotted ``${var.path}`` references resolve into nested variables."""
        engine = PipelineEngine(dispatcher=None)
        pipeline = _make_pipeline("var_test", [
            _make_stage("step1", inputs={"url": "${source.url}"}),
        ], variables={"source": {"url": "https://example.com"}})

        result = await engine.execute(pipeline)

        step_result = result.stage_results["step1"]
        assert step_result.output_data["dry_run"] is True
        assert step_result.output_data["inputs"]["url"] == "https://example.com"
# ── Handoff 测试 ─────────────────────────────────────────
class TestHandoff:
    """HandoffMessage construction/serialization and HandoffManager basics."""

    async def test_handoff_message_creation(self):
        """HandoffMessage stores every field it is constructed with."""
        msg = HandoffMessage(
            source_agent="agent_a",
            target_agent="agent_b",
            task_id="t-001",
            task_type="analysis",
            context={"data": "test"},
            reason="Need specialized analysis",
        )
        assert msg.source_agent == "agent_a"
        assert msg.target_agent == "agent_b"
        assert msg.task_id == "t-001"
        assert msg.task_type == "analysis"
        assert msg.context["data"] == "test"
        assert msg.reason == "Need specialized analysis"

    async def test_handoff_message_serialization(self):
        """to_dict() followed by from_dict() round-trips every constructor field."""
        msg = HandoffMessage(
            source_agent="a", target_agent="b",
            task_id="t-001", task_type="test",
            context={"key": "value"}, reason="test handoff",
        )

        data = msg.to_dict()
        restored = HandoffMessage.from_dict(data)

        assert restored.source_agent == "a"
        assert restored.target_agent == "b"
        assert restored.task_id == "t-001"
        assert restored.task_type == "test"
        assert restored.reason == "test handoff"
        assert restored.context["key"] == "value"

    async def test_handoff_manager_register_handler(self):
        """register_handler() records a handler under the target agent's name."""
        manager = HandoffManager()
        manager.register_handler("agent_b", lambda h: None)
        assert "agent_b" in manager._handlers

    async def test_handoff_manager_without_redis(self):
        """send_handoff() raises a Redis-related RuntimeError when no Redis is configured."""
        manager = HandoffManager()
        msg = HandoffMessage(
            source_agent="a", target_agent="b",
            task_id="t-001", task_type="test",
            context={}, reason="test",
        )
        with pytest.raises(RuntimeError, match="Redis"):
            await manager.send_handoff(msg)
# ── DynamicPipeline 测试 ─────────────────────────────────
class TestDynamicPipeline:
    """DynamicPipeline conditional, loop and nested execution on a dry-run engine."""

    async def test_conditional_pipeline(self):
        """The sub-pipeline keyed by ``context[condition_key]`` runs, and only it."""
        engine = PipelineEngine(dispatcher=None)
        dynamic = DynamicPipeline(engine=engine)
        pipeline_a = _make_pipeline("pipeline_a", [_make_stage("step_a")])
        pipeline_b = _make_pipeline("pipeline_b", [_make_stage("step_b")])

        result = await dynamic.execute_conditional(
            pipelines={"type_a": pipeline_a, "type_b": pipeline_b},
            condition_key="task_type",
            context={"task_type": "type_a"},
        )

        assert result.status == StageStatus.COMPLETED
        assert "step_a" in result.stage_results
        # The unselected branch must not have been executed.
        assert "step_b" not in result.stage_results

    async def test_conditional_pipeline_no_match(self):
        """An unknown condition value yields a FAILED result."""
        engine = PipelineEngine(dispatcher=None)
        dynamic = DynamicPipeline(engine=engine)
        pipeline_a = _make_pipeline("pipeline_a", [_make_stage("step_a")])

        result = await dynamic.execute_conditional(
            pipelines={"type_a": pipeline_a},
            condition_key="task_type",
            context={"task_type": "type_z"},
        )

        assert result.status == StageStatus.FAILED

    async def test_loop_pipeline_exits_on_condition(self):
        """A loop whose exit condition is already truthy completes."""
        engine = PipelineEngine(dispatcher=None)
        dynamic = DynamicPipeline(engine=engine)
        pipeline = _make_pipeline("loop_test", [_make_stage("iterate")])

        result = await dynamic.execute_loop(
            pipeline=pipeline,
            max_iterations=3,
            exit_condition="done",
            context={"done": True},
        )

        assert result.status == StageStatus.COMPLETED

    async def test_loop_pipeline_max_iterations(self):
        """A loop whose exit condition never holds still completes at max_iterations."""
        engine = PipelineEngine(dispatcher=None)
        dynamic = DynamicPipeline(engine=engine)
        pipeline = _make_pipeline("loop_max_test", [_make_stage("iterate")])

        result = await dynamic.execute_loop(
            pipeline=pipeline,
            max_iterations=2,
            exit_condition="done",
            context={},
        )

        # NOTE(review): only the final status is checked; the number of
        # iterations actually run is not observable from this test.
        assert result.status == StageStatus.COMPLETED

    async def test_nested_pipeline(self):
        """execute_nested() completes with a parent and a sub-pipeline map."""
        engine = PipelineEngine(dispatcher=None)
        dynamic = DynamicPipeline(engine=engine)
        parent = _make_pipeline("parent", [_make_stage("parent_step")])
        sub = _make_pipeline("sub_pipeline", [_make_stage("sub_step")])

        # NOTE(review): "parent_step" does not visibly reference "sub_pipeline";
        # confirm this exercises the nesting path rather than only the parent.
        result = await dynamic.execute_nested(
            parent=parent,
            sub_pipeline_map={"sub_pipeline": sub},
        )

        assert result.status == StageStatus.COMPLETED
# ── 端到端集成测试 ───────────────────────────────────────
class TestOrchestratorIntegration:
    """End-to-end scenarios combining pipelines, variables and handoff messages."""

    async def test_full_pipeline_with_handoff_message(self):
        """A three-stage content pipeline completes; a handoff message carries context."""
        engine = PipelineEngine(dispatcher=None)
        pipeline = _make_pipeline("content_production", [
            _make_stage("research", agent="research_agent", action="search",
                        inputs={"query": "${topic}"}, outputs=["research_data"]),
            _make_stage("generate", agent="content_agent", action="generate",
                        depends_on=["research"],
                        inputs={"data": "${research_data}"}, outputs=["draft"]),
            _make_stage("optimize", agent="seo_agent", action="optimize",
                        depends_on=["generate"],
                        inputs={"content": "${draft}"}, outputs=["final_content"]),
        ], variables={"topic": "AI trends 2026"})

        result = await engine.execute(pipeline)

        assert result.status == StageStatus.COMPLETED
        assert len(result.stage_results) == 3
        assert {"research", "generate", "optimize"} <= set(result.stage_results)
        # Pipeline variables must reach the first stage's inputs.
        research_inputs = result.stage_results["research"].output_data["inputs"]
        assert research_inputs["query"] == "AI trends 2026"

        handoff = HandoffMessage(
            source_agent="research_agent",
            target_agent="content_agent",
            task_id="t-handoff-001",
            task_type="handoff_research_to_content",
            context={"research_data": "AI market analysis..."},
            reason="Research complete, handing off to content generation",
        )
        assert handoff.source_agent == "research_agent"
        assert handoff.context["research_data"] == "AI market analysis..."

    async def test_parallel_pipeline_with_variables(self):
        """Two parallel stages both receive the shared variable, then a join stage runs."""
        engine = PipelineEngine(dispatcher=None)
        pipeline = _make_pipeline("parallel_with_vars", [
            _make_stage("check_citation", agent="citation_agent",
                        inputs={"url": "${url}"}, outputs=["citation_result"]),
            _make_stage("check_trends", agent="trend_agent",
                        inputs={"url": "${url}"}, outputs=["trend_result"]),
            _make_stage("compile_report", agent="report_agent",
                        depends_on=["check_citation", "check_trends"],
                        inputs={"citation": "${citation_result}", "trends": "${trend_result}"}),
        ], variables={"url": "https://example.com"})

        result = await engine.execute(pipeline)

        assert result.status == StageStatus.COMPLETED
        # Both parallel branches must see the resolved variable.
        for stage_name in ("check_citation", "check_trends"):
            stage_inputs = result.stage_results[stage_name].output_data["inputs"]
            assert stage_inputs["url"] == "https://example.com"

        # Assert the group count before indexing so a wrong schedule fails
        # cleanly instead of raising IndexError or passing with extra groups.
        groups = PipelineEngine._topological_group(pipeline.stages)
        assert len(groups) == 2
        assert len(groups[0]) == 2
        assert len(groups[1]) == 1