fischer-agentkit/tests/unit/experts/test_team_orchestrator.py

669 lines
22 KiB
Python

"""TeamOrchestrator 单元测试"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.core.handoff_transport import InProcessHandoffTransport
from agentkit.experts.config import ExpertConfig
from agentkit.experts.expert import Expert
from agentkit.experts.orchestrator import TeamOrchestrator
from agentkit.experts.plan import (
CollaborationPlan,
MergeStrategy,
ParallelType,
PhaseStatus,
PlanPhase,
PlanStatus,
)
from agentkit.experts.team import ExpertTeam, TeamStatus
# ── 辅助函数 ──────────────────────────────────────────────
def _make_expert_config(
    name: str = "test_expert",
    is_lead: bool = False,
) -> ExpertConfig:
    """Build a minimal ``ExpertConfig`` for tests.

    Args:
        name: Expert name.
        is_lead: Whether the expert is flagged as the team lead.

    Returns:
        An ``ExpertConfig`` with a fixed persona, thinking style, bound skill,
        task mode and prompt.
    """
    # Fields that are identical for every test expert; built fresh per call
    # so no list/dict is shared between configs.
    fixed_fields = {
        "agent_type": "expert",
        "persona": "测试专家",
        "thinking_style": "逻辑推理",
        "bound_skills": ["skill_a"],
        "task_mode": "llm_generate",
        "prompt": {"identity": "测试"},
    }
    return ExpertConfig(name=name, is_lead=is_lead, **fixed_fields)
def _make_mock_expert(
    name: str = "test_expert",
    is_lead: bool = False,
    is_active: bool = True,
) -> MagicMock:
    """Build a ``MagicMock`` standing in for an ``Expert``.

    The mock is spec'd against ``Expert``, carries a real ``ExpertConfig``,
    has no team (``team_id`` is ``None``) and returns a capabilities summary
    derived from its config.

    Args:
        name: Expert name.
        is_lead: Whether the expert is the team lead.
        is_active: Value exposed as ``is_active``.
    """
    cfg = _make_expert_config(name=name, is_lead=is_lead)
    mock_expert = MagicMock(spec=Expert)
    mock_expert.config = cfg
    mock_expert.is_active = is_active
    mock_expert.team_id = None
    summary = dict(
        name=name,
        persona=cfg.persona,
        thinking_style=cfg.thinking_style,
        bound_skills=cfg.bound_skills,
        is_lead=is_lead,
    )
    mock_expert.get_capabilities_summary.return_value = summary
    return mock_expert
def _make_team_with_experts(
    expert_names: list[str] | None = None,
    lead_name: str = "lead",
) -> ExpertTeam:
    """Build an ``ExpertTeam`` populated with mock experts.

    The team's handoff transport is replaced by an ``AsyncMock`` spec'd
    against ``InProcessHandoffTransport`` so tests can inspect ``send`` calls.

    Args:
        expert_names: Experts to register. Defaults to
            ``[lead_name, "member1", "member2"]``.
        lead_name: Name of the expert marked as lead.
    """
    team = ExpertTeam()
    team._handoff_transport = AsyncMock(spec=InProcessHandoffTransport)
    names = [lead_name, "member1", "member2"] if expert_names is None else expert_names
    for expert_name in names:
        lead_flag = expert_name == lead_name
        team._experts[expert_name] = _make_mock_expert(name=expert_name, is_lead=lead_flag)
        if lead_flag:
            team._lead_expert_name = expert_name
    return team
def _make_serial_plan(
    plan_id: str = "plan_1",
    task: str = "测试任务",
    lead_expert: str = "lead",
    num_phases: int = 1,
) -> CollaborationPlan:
    """Build a ``CollaborationPlan`` made of serial phases.

    Phases are ``phase_1`` .. ``phase_N``, all assigned to ``lead_expert``;
    every phase after the first depends on the one before it.
    """
    phases = [
        PlanPhase(
            id=f"phase_{n}",
            name=f"阶段{n}",
            assigned_expert=lead_expert,
            task_description=f"执行任务{n}",
            depends_on=[f"phase_{n - 1}"] if n > 1 else [],
            parallel_type=ParallelType.SERIAL,
        )
        for n in range(1, num_phases + 1)
    ]
    return CollaborationPlan(
        id=plan_id,
        task=task,
        phases=phases,
        lead_expert=lead_expert,
    )
def _make_parallel_plan(
    plan_id: str = "plan_parallel",
    task: str = "并行测试任务",
    parallel_type: ParallelType = ParallelType.SUBTASK_PARALLEL,
    merge_strategy: MergeStrategy | None = None,
) -> CollaborationPlan:
    """Build a two-phase plan whose phases share one parallel type.

    ``phase_1`` is assigned to ``member1`` and ``phase_2`` to ``member2``.
    Neither phase declares a dependency, and both use ``merge_strategy``.
    """
    # (id, name, assigned expert, task description)
    phase_specs = (
        ("phase_1", "并行阶段1", "member1", "并行任务1"),
        ("phase_2", "并行阶段2", "member2", "并行任务2"),
    )
    phases = [
        PlanPhase(
            id=phase_id,
            name=phase_name,
            assigned_expert=expert_name,
            task_description=description,
            parallel_type=parallel_type,
            merge_strategy=merge_strategy,
        )
        for phase_id, phase_name, expert_name, description in phase_specs
    ]
    return CollaborationPlan(
        id=plan_id,
        task=task,
        phases=phases,
        lead_expert="lead",
    )
# ── 串行阶段执行测试 ──────────────────────────────────────
class TestSerialPhaseExecution:
    """Tests for executing serial phases."""

    @pytest.mark.asyncio
    async def test_single_serial_phase_completes(self):
        """A single serial phase runs to completion."""
        orchestrator = TeamOrchestrator(_make_team_with_experts())
        plan = _make_serial_plan(num_phases=1)

        outcome = await orchestrator.execute_plan(plan)

        assert outcome["status"] == "completed"
        assert "phase_1" in outcome["phase_results"]
        assert plan.phases[0].status == PhaseStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_multiple_serial_phases_in_order(self):
        """Several chained serial phases all complete, honouring dependencies."""
        orchestrator = TeamOrchestrator(_make_team_with_experts())
        plan = _make_serial_plan(num_phases=3)

        outcome = await orchestrator.execute_plan(plan)

        assert outcome["status"] == "completed"
        assert len(outcome["phase_results"]) == 3
        assert all(p.status == PhaseStatus.COMPLETED for p in plan.phases)

    @pytest.mark.asyncio
    async def test_serial_phase_sets_plan_and_team_status(self):
        """Running a plan updates both the plan status and the team status."""
        team = _make_team_with_experts()
        orchestrator = TeamOrchestrator(team)
        plan = _make_serial_plan()

        await orchestrator.execute_plan(plan)

        assert plan.status == PlanStatus.COMPLETED
        assert team._status == TeamStatus.COMPLETED
# ── 子任务并行阶段执行测试 ────────────────────────────────
class TestSubtaskParallelExecution:
    """Tests for subtask-parallel phases."""

    @pytest.mark.asyncio
    async def test_subtask_parallel_phases_execute(self):
        """Both subtask-parallel phases execute and report results."""
        team = _make_team_with_experts()
        orchestrator = TeamOrchestrator(team)
        plan = _make_parallel_plan(parallel_type=ParallelType.SUBTASK_PARALLEL)

        result = await orchestrator.execute_plan(plan)

        assert result["status"] == "completed"
        assert "phase_1" in result["phase_results"]
        assert "phase_2" in result["phase_results"]

    @pytest.mark.asyncio
    async def test_subtask_parallel_phase_failure_recorded(self):
        """A failing subtask-parallel phase has its error recorded in the results."""
        team = _make_team_with_experts()
        orchestrator = TeamOrchestrator(team)
        plan = _make_parallel_plan(parallel_type=ParallelType.SUBTASK_PARALLEL)
        # Keep the real bound method so the stand-in can delegate to it.
        # (Previously the stand-in called an undefined name, so phase_2 died
        # with a NameError instead of exercising the real implementation.)
        original_execute = orchestrator._execute_phase
        # Ids of phases that went through the real implementation without raising.
        delegated: list[str] = []

        async def mock_execute_phase(phase, p, pr):
            # phase_1 always fails; any other phase runs the real implementation.
            if phase.id == "phase_1":
                raise RuntimeError("Simulated failure")
            outcome = await original_execute(phase, p, pr)
            delegated.append(phase.id)
            return outcome

        with patch.object(
            orchestrator, "_execute_phase", side_effect=mock_execute_phase
        ):
            result = await orchestrator.execute_plan(plan)

        # The exception is expected to be captured by the orchestrator
        # (presumably asyncio.gather(return_exceptions=True)) and recorded.
        assert "phase_1" in result["phase_results"]
        assert "error" in result["phase_results"]["phase_1"]
        # The healthy sibling phase must have run the real implementation successfully.
        assert "phase_2" in delegated
# ── 竞争并行阶段测试 ──────────────────────────────────────
class TestCompetitiveParallelExecution:
    """Tests for competitive-parallel phases and their merge strategies."""

    @staticmethod
    async def _run_merged(strategy: MergeStrategy, label: str) -> list:
        """Run a competitive plan with ``strategy`` and check both phases merged.

        Args:
            strategy: Merge strategy applied to both phases.
            label: Expected value of each merged result's ``strategy`` key.

        Returns:
            The merged results of ``phase_1`` and ``phase_2``, in that order.
        """
        orchestrator = TeamOrchestrator(_make_team_with_experts())
        plan = _make_parallel_plan(
            parallel_type=ParallelType.COMPETITIVE_PARALLEL,
            merge_strategy=strategy,
        )
        outcome = await orchestrator.execute_plan(plan)

        assert outcome["status"] == "completed"
        merged = []
        for phase_id in ("phase_1", "phase_2"):
            assert phase_id in outcome["phase_results"]
            phase_result = outcome["phase_results"][phase_id]
            # Competitive phases collapse into one merged result per phase.
            assert phase_result.get("merged") is True
            assert phase_result.get("strategy") == label
            merged.append(phase_result)
        return merged

    @pytest.mark.asyncio
    async def test_competitive_parallel_best_strategy(self):
        """The BEST strategy yields one merged result per competitive phase."""
        await self._run_merged(MergeStrategy.BEST, "best")

    @pytest.mark.asyncio
    async def test_competitive_parallel_vote_strategy(self):
        """The VOTE strategy yields one merged result per competitive phase."""
        await self._run_merged(MergeStrategy.VOTE, "vote")

    @pytest.mark.asyncio
    async def test_competitive_parallel_fusion_strategy(self):
        """The FUSION strategy fuses the outputs of every active expert."""
        for phase_result in await self._run_merged(MergeStrategy.FUSION, "fusion"):
            # The default team built by the helper has three active experts.
            assert phase_result.get("fused_from") == 3
# ── 里程碑检查点测试 ──────────────────────────────────────
class TestMilestoneCheckpoint:
    """Tests for milestone checks attached to a phase."""

    @staticmethod
    def _plan_with_milestone(plan_id: str, task: str) -> CollaborationPlan:
        """Build a one-phase plan, assigned to ``lead``, that declares a milestone."""
        milestone_phase = PlanPhase(
            id="phase_1",
            name="带里程碑阶段",
            assigned_expert="lead",
            task_description="执行带里程碑的任务",
            milestone="输出质量达标",
        )
        return CollaborationPlan(
            id=plan_id,
            task=task,
            phases=[milestone_phase],
            lead_expert="lead",
        )

    @pytest.mark.asyncio
    async def test_milestone_pass(self):
        """A passing milestone check lets the phase complete."""
        orchestrator = TeamOrchestrator(_make_team_with_experts())
        plan = self._plan_with_milestone("plan_milestone", "里程碑测试")

        outcome = await orchestrator.execute_plan(plan)

        assert outcome["status"] == "completed"
        assert plan.phases[0].status == PhaseStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_milestone_fail_phase_failed(self):
        """A failing milestone check marks the phase FAILED."""
        orchestrator = TeamOrchestrator(_make_team_with_experts())
        plan = self._plan_with_milestone("plan_milestone_fail", "里程碑失败测试")

        # Force every milestone check to fail.
        with patch.object(orchestrator, "_check_milestone", return_value=False):
            outcome = await orchestrator.execute_plan(plan)

        assert plan.phases[0].status == PhaseStatus.FAILED
        # Phase fails, the retry fails as well, so the plan falls back.
        assert outcome["status"] == "fallback"
# ── 重试与回退测试 ────────────────────────────────────────
class TestRetryAndFallback:
    """Tests for retrying failed phases and falling back to single-agent mode."""

    @pytest.mark.asyncio
    async def test_phase_failure_triggers_retry(self):
        """A failed phase is retried, and a successful retry completes the plan."""
        team = _make_team_with_experts()
        orchestrator = TeamOrchestrator(team)
        plan = _make_serial_plan(num_phases=1)
        call_count = 0

        async def mock_execute_phase(phase, p, pr):
            # First attempt fails; later attempts (the retry) succeed.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                p.update_phase_status(phase.id, PhaseStatus.FAILED)
                return None
            p.update_phase_status(phase.id, PhaseStatus.COMPLETED, {"output": "retry ok"})
            return {"output": "retry ok"}

        with patch.object(
            orchestrator, "_execute_phase", side_effect=mock_execute_phase
        ):
            result = await orchestrator.execute_plan(plan)

        # Exactly one retry was made, and it succeeded.
        assert call_count == 2
        assert result["status"] == "completed"

    @pytest.mark.asyncio
    async def test_retry_failure_triggers_fallback(self):
        """When the retry also fails, the plan falls back to single-agent mode."""
        team = _make_team_with_experts()
        orchestrator = TeamOrchestrator(team)
        plan = _make_serial_plan(num_phases=1)
        call_count = 0

        async def mock_execute_phase(phase, p, pr):
            # Every attempt fails. Update the plan handed in by the orchestrator
            # (``p``) rather than the closed-over ``plan``, matching the mock in
            # test_phase_failure_triggers_retry.
            nonlocal call_count
            call_count += 1
            p.update_phase_status(phase.id, PhaseStatus.FAILED)
            return None

        with patch.object(
            orchestrator, "_execute_phase", side_effect=mock_execute_phase
        ):
            result = await orchestrator.execute_plan(plan)

        # The initial attempt plus at least one retry must have run before fallback.
        assert call_count >= 2
        assert result["status"] == "fallback"
        assert plan.status == PlanStatus.FALLBACK
# ── 最大交互轮次测试 ──────────────────────────────────────
class TestMaxInteractionRounds:
    """Tests for the cap on interaction rounds."""

    @pytest.mark.asyncio
    async def test_max_interaction_rounds_limit(self):
        """Execution stops once the maximum number of interaction rounds is exceeded."""
        team = _make_team_with_experts()
        orchestrator = TeamOrchestrator(team)
        # Override the cap on this instance only; the class attribute is untouched.
        orchestrator.MAX_INTERACTION_ROUNDS = 1
        # Five chained serial phases cannot all finish within a single round.
        plan = _make_serial_plan(num_phases=5)

        await orchestrator.execute_plan(plan)

        assert orchestrator._interaction_count >= 1
        # The previous assertion alone holds for any execution, so also require
        # that the cap actually cut the run short: not every phase completed.
        assert not all(phase.status == PhaseStatus.COMPLETED for phase in plan.phases)
# ── 无效计划测试 ──────────────────────────────────────────
class TestInvalidPlan:
    """Tests for plans that fail validation."""

    @pytest.mark.asyncio
    async def test_invalid_plan_returns_failed_status(self):
        """A plan with a circular dependency is rejected with status ``failed``."""
        orchestrator = TeamOrchestrator(_make_team_with_experts())
        # p1 depends on p2 and p2 depends on p1, forming a cycle.
        cyclic_phases = [
            PlanPhase(
                id=phase_id,
                name=phase_name,
                assigned_expert="lead",
                task_description=description,
                depends_on=[other_id],
            )
            for phase_id, phase_name, description, other_id in (
                ("p1", "阶段1", "t1", "p2"),
                ("p2", "阶段2", "t2", "p1"),
            )
        ]
        plan = CollaborationPlan(
            id="invalid_plan",
            task="无效任务",
            phases=cyclic_phases,
            lead_expert="lead",
        )

        outcome = await orchestrator.execute_plan(plan)

        assert outcome["status"] == "failed"
        assert "errors" in outcome
        assert len(outcome["errors"]) > 0
# ── 结果综合测试 ──────────────────────────────────────────
class TestSynthesizeResults:
    """Tests for combining per-phase results into a final result."""

    @pytest.mark.asyncio
    async def test_synthesize_results(self):
        """The final result summarises every completed phase of the task."""
        orchestrator = TeamOrchestrator(_make_team_with_experts())
        plan = _make_serial_plan(num_phases=2)

        outcome = await orchestrator.execute_plan(plan)

        assert outcome["status"] == "completed"
        summary = outcome["result"]
        assert summary["task"] == "测试任务"
        assert summary["phases_completed"] == 2
        assert summary["phases_total"] == 2
        assert len(summary["results"]) == 2

    @pytest.mark.asyncio
    async def test_synthesize_results_only_completed_phases(self):
        """Only phases marked COMPLETED count towards ``phases_completed``."""
        orchestrator = TeamOrchestrator(_make_team_with_experts())
        first = PlanPhase(
            id="phase_1",
            name="完成阶段",
            assigned_expert="lead",
            task_description="任务1",
        )
        second = PlanPhase(
            id="phase_2",
            name="依赖阶段",
            assigned_expert="member1",
            task_description="任务2",
            depends_on=["phase_1"],
        )
        plan = CollaborationPlan(
            id="plan_partial",
            task="部分完成测试",
            phases=[first, second],
            lead_expert="lead",
        )
        # Mark phase_1 COMPLETED by hand; phase_2 stays pending.
        plan.update_phase_status("phase_1", PhaseStatus.COMPLETED, {"output": "done"})

        # Call the synthesis step directly instead of running the whole plan.
        summary = await orchestrator._synthesize_results(
            plan, {"phase_1": {"output": "done"}}
        )

        assert summary["phases_completed"] == 1
        assert summary["phases_total"] == 2
# ── 事件广播测试 ──────────────────────────────────────────
class TestBroadcastEvent:
    """Tests for broadcasting orchestration events over the team transport."""

    @pytest.mark.asyncio
    async def test_broadcast_event_sends_to_transport(self):
        """``_broadcast_event`` sends exactly one message on the team channel."""
        team = _make_team_with_experts()
        orchestrator = TeamOrchestrator(team)

        await orchestrator._broadcast_event("test_event", {"key": "value"})

        send = team._handoff_transport.send
        send.assert_awaited_once()
        positional = send.call_args.args
        assert positional[0] == team._team_channel
        sent_message = positional[1]
        # The event type and payload are flattened into one message dict.
        assert sent_message["type"] == "test_event"
        assert sent_message["key"] == "value"

    @pytest.mark.asyncio
    async def test_broadcast_event_no_transport(self):
        """Broadcasting without a handoff transport does not raise."""
        team = _make_team_with_experts()
        team._handoff_transport = None
        orchestrator = TeamOrchestrator(team)

        # Passes as long as no exception escapes.
        await orchestrator._broadcast_event("test_event", {"key": "value"})

    @pytest.mark.asyncio
    async def test_phase_execution_broadcasts_events(self):
        """Executing a phase broadcasts ``phase_started`` and ``phase_completed``."""
        team = _make_team_with_experts()
        orchestrator = TeamOrchestrator(team)

        await orchestrator.execute_plan(_make_serial_plan(num_phases=1))

        sent_types = {
            sent.args[1]["type"]
            for sent in team._handoff_transport.send.call_args_list
        }
        assert {"phase_started", "phase_completed"} <= sent_types
# ── 竞争并行全部失败测试 ──────────────────────────────────
class TestCompetitiveAllFail:
    """Tests for competitive phases in which every competitor fails."""

    @pytest.mark.asyncio
    async def test_all_competitors_fail(self):
        """When every competitor raises, the orchestrator falls back."""
        orchestrator = TeamOrchestrator(_make_team_with_experts())
        plan = _make_parallel_plan(
            parallel_type=ParallelType.COMPETITIVE_PARALLEL,
            merge_strategy=MergeStrategy.BEST,
        )

        async def always_failing_competitor(expert, phase):
            raise RuntimeError("Competitor failed")

        with patch.object(
            orchestrator, "_run_competitor", side_effect=always_failing_competitor
        ):
            outcome = await orchestrator.execute_plan(plan)

        # No competitor produced a result, so the plan falls back.
        assert outcome["status"] == "fallback"
# ── Expert 不可用测试 ────────────────────────────────────
class TestExpertUnavailable:
    """Tests for phases assigned to experts that cannot run them."""

    @pytest.mark.asyncio
    async def test_inactive_expert_causes_phase_failure(self):
        """An inactive assigned expert makes the phase fail."""
        team = _make_team_with_experts()
        lead = team._experts["lead"]
        lead.is_active = False
        orchestrator = TeamOrchestrator(team)
        plan = _make_serial_plan(num_phases=1)

        outcome = await orchestrator.execute_plan(plan)

        # The phase fails, its retry fails too, so the orchestrator falls back.
        assert outcome["status"] == "fallback"

    @pytest.mark.asyncio
    async def test_nonexistent_expert_causes_phase_failure(self):
        """An assigned expert missing from the team makes the phase fail."""
        orchestrator = TeamOrchestrator(_make_team_with_experts())
        orphan_phase = PlanPhase(
            id="phase_1",
            name="无专家阶段",
            assigned_expert="nonexistent_expert",
            task_description="执行任务",
        )
        plan = CollaborationPlan(
            id="plan_no_expert",
            task="无专家测试",
            phases=[orphan_phase],
            lead_expert="lead",
        )

        outcome = await orchestrator.execute_plan(plan)

        # Unknown expert: the phase fails, the retry fails, then fallback.
        assert outcome["status"] == "fallback"