527 lines
20 KiB
Python
527 lines
20 KiB
Python
"""Tests for _execute_team_collab in chat.py (U6).
|
||
|
||
Tests cover:
|
||
- @team prefix triggers team collaboration (returns True)
|
||
- Non-@team input does not trigger (returns False)
|
||
- @team without task content sends error
|
||
- Team events are relayed to WebSocket via emit_team_event
|
||
- final_answer is sent after execution
|
||
- User message and final result are persisted to session history
|
||
- Team is dissolved after execution
|
||
- Execution failure sends error and dissolves team
|
||
- @team and @board do not interfere with each other
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from agentkit.experts.config import ExpertConfig, ExpertTemplate
|
||
from agentkit.experts.registry import ExpertTemplateRegistry
|
||
from agentkit.server.routes.chat import _execute_team_collab
|
||
from agentkit.session.manager import SessionManager
|
||
from agentkit.session.models import MessageRole
|
||
from agentkit.session.store import InMemorySessionStore
|
||
|
||
|
||
# ── 辅助函数 ──────────────────────────────────────────────
|
||
|
||
|
||
def _make_expert_template(
    name: str,
    persona: str = "测试专家",
    is_lead: bool = False,
) -> ExpertTemplate:
    """Build a built-in ExpertTemplate for tests.

    Args:
        name: Template name; also used as the wrapped ExpertConfig name.
        persona: Persona text, reused as the prompt's ``identity`` entry.
        is_lead: Forwarded to ``ExpertConfig.is_lead``.

    Returns:
        An ``ExpertTemplate`` with ``is_builtin=True``, an
        ``llm_generate`` task mode and no bound skills.
    """
    expert_config = ExpertConfig(
        name=name,
        persona=persona,
        is_lead=is_lead,
        agent_type="expert",
        thinking_style="analytical",
        task_mode="llm_generate",
        bound_skills=[],
        prompt={"identity": persona},
    )
    return ExpertTemplate(
        name=name,
        description=f"{name} 模板",
        config=expert_config,
        is_builtin=True,
    )
|
||
|
||
|
||
def _make_registry_with_dev_team() -> ExpertTemplateRegistry:
    """Return a registry with three member templates plus a ``dev_team`` template.

    Members are registered first (tech_lead, frontend_engineer,
    backend_engineer). The ``dev_team`` template then lists those member
    names, in that order, in its ``bound_skills``.
    """
    members = (
        ("tech_lead", "技术负责人"),
        ("frontend_engineer", "前端工程师"),
        ("backend_engineer", "后端工程师"),
    )

    registry = ExpertTemplateRegistry()
    for member_name, member_persona in members:
        registry.register(_make_expert_template(member_name, persona=member_persona))

    # The dev_team template stores its member list in bound_skills.
    team_config = ExpertConfig(
        name="dev_team",
        agent_type="expert",
        persona="编程团队",
        thinking_style="流水线",
        bound_skills=[n for n, _ in members],
        task_mode="llm_generate",
        prompt={"identity": "Dev Team"},
    )
    registry.register(
        ExpertTemplate(
            name="dev_team",
            config=team_config,
            is_builtin=True,
            description="编程团队模板",
        )
    )
    return registry
|
||
|
||
|
||
class FakeWebSocket:
    """In-memory WebSocket stand-in that records every JSON payload sent."""

    def __init__(self) -> None:
        # Payloads passed to send_json, in send order.
        self.sent: list[dict] = []
        app = MagicMock()
        app.state.agent_pool = MagicMock()
        # Left as None here; the ``websocket`` fixture installs a registry per test.
        app.state.expert_template_registry = None
        self.app = app

    async def send_json(self, data: dict) -> None:
        """Record *data* instead of transmitting it."""
        self.sent.append(data)
|
||
|
||
|
||
def _make_mock_team() -> MagicMock:
    """Return a MagicMock standing in for an ExpertTeam instance.

    ``create_team`` and ``dissolve`` are AsyncMocks so the code under test can
    await them. ``handoff_transport.register_handler`` is a plain MagicMock
    so tests can inspect the handler it received. ``team_channel`` is
    ``"team:test"``.
    """
    transport = MagicMock()
    transport.register_handler = MagicMock()
    transport._handlers = {}

    team = MagicMock()
    team.team_channel = "team:test"
    team.handoff_transport = transport
    team.create_team = AsyncMock()
    team.dissolve = AsyncMock()
    return team
|
||
|
||
|
||
def _make_mock_orchestrator(result: dict) -> MagicMock:
    """Return a MagicMock TeamOrchestrator whose awaited ``execute`` yields *result*."""
    orchestrator = MagicMock()
    orchestrator.execute = AsyncMock()
    orchestrator.execute.return_value = result
    return orchestrator
|
||
|
||
|
||
@pytest.fixture
async def session_manager() -> SessionManager:
    """Provide a SessionManager on an in-memory store with one pre-created session.

    The session's id is stashed on the manager as ``_test_session_id`` so
    tests can reach it without a second fixture.
    """
    # NOTE(review): an ``async def`` under plain ``@pytest.fixture`` relies on
    # pytest-asyncio running in auto mode — confirm against the project config.
    manager = SessionManager(store=InMemorySessionStore())
    created = await manager.create_session(agent_name="test-agent")
    manager._test_session_id = created.session_id
    return manager
|
||
|
||
|
||
@pytest.fixture
def websocket() -> FakeWebSocket:
    """Provide a FakeWebSocket whose app state carries the dev_team registry."""
    fake = FakeWebSocket()
    registry = _make_registry_with_dev_team()
    fake.app.state.expert_template_registry = registry
    return fake
|
||
|
||
|
||
@pytest.fixture
def mock_orchestrator_result() -> dict:
    """Canned successful return value for ``TeamOrchestrator.execute()``.

    Holds a non-empty synthesized ``result`` plus two per-phase results.
    """
    phases = {
        "phase-1": {"content": "规划完成", "phase_name": "规划"},
        "phase-2": {"content": "前端实现完成", "phase_name": "前端"},
    }
    outcome = {
        "status": "completed",
        "result": {"content": "## 团队最终结果\n\n用户登录功能已实现"},
        "phase_results": phases,
        "plan": MagicMock(),
    }
    return outcome
|
||
|
||
|
||
# ── 路由匹配测试 ──────────────────────────────────────────
|
||
|
||
|
||
class TestTeamCollabRouting:
    """Routing tests for _execute_team_collab: which inputs start a team run.

    NOTE(review): ExpertTeam / TeamOrchestrator are patched on their defining
    modules. That only takes effect if chat.py resolves those names at call
    time (e.g. via a function-local import), so confirm this if these tests
    ever start touching the real classes.
    """

    @pytest.mark.asyncio
    async def test_team_prefix_triggers_collab(
        self, websocket, session_manager, mock_orchestrator_result
    ):
        """An ``@team`` prefix starts team collaboration and returns True."""
        session_id = session_manager._test_session_id
        with patch(
            "agentkit.experts.team.ExpertTeam"
        ) as mock_team_cls, patch(
            "agentkit.experts.orchestrator.TeamOrchestrator"
        ) as mock_orch_cls:
            mock_team_cls.return_value = _make_mock_team()
            mock_orch = _make_mock_orchestrator(mock_orchestrator_result)
            mock_orch_cls.return_value = mock_orch

            result = await _execute_team_collab(
                websocket, session_id, "@team 开发用户登录功能", session_manager
            )

        assert result is True
        # The return value alone would also pass if routing matched but the
        # team never ran, so pin that the orchestrator actually executed.
        mock_orch.execute.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_non_team_input_does_not_trigger(self, websocket, session_manager):
        """Input without an ``@team`` prefix is not handled: returns False."""
        session_id = session_manager._test_session_id
        result = await _execute_team_collab(
            websocket, session_id, "普通问题", session_manager
        )

        assert result is False
        # An unhandled message must not leak any frames to the client.
        assert websocket.sent == []

    @pytest.mark.asyncio
    async def test_board_prefix_does_not_trigger_team(
        self, websocket, session_manager
    ):
        """An ``@board`` prefix is left alone by the @team handler."""
        session_id = session_manager._test_session_id
        result = await _execute_team_collab(
            websocket, session_id, "@board 讨论主题", session_manager
        )

        assert result is False
        # The @team handler must not emit anything for @board input.
        assert websocket.sent == []

    @pytest.mark.asyncio
    async def test_team_with_dev_team_template(
        self, websocket, session_manager, mock_orchestrator_result
    ):
        """``@team:dev_team`` starts team collaboration with the named template."""
        session_id = session_manager._test_session_id
        with patch(
            "agentkit.experts.team.ExpertTeam"
        ) as mock_team_cls, patch(
            "agentkit.experts.orchestrator.TeamOrchestrator"
        ) as mock_orch_cls:
            mock_team = _make_mock_team()
            mock_team_cls.return_value = mock_team
            mock_orch = _make_mock_orchestrator(mock_orchestrator_result)
            mock_orch_cls.return_value = mock_orch

            result = await _execute_team_collab(
                websocket, session_id, "@team:dev_team 开发功能", session_manager
            )

        assert result is True
        # The team must have been created and run, not just routed.
        mock_team.create_team.assert_awaited_once()
        mock_orch.execute.assert_awaited_once()
|
||
|
||
|
||
# ── 错误处理测试 ──────────────────────────────────────────
|
||
|
||
|
||
class TestTeamCollabErrorHandling:
    """Error-path tests for _execute_team_collab."""

    @pytest.mark.asyncio
    async def test_team_without_task_sends_error(
        self, websocket, session_manager
    ):
        """A matched ``@team`` with no task content sends an error frame.

        ExpertTeamRouter is mocked so that routing matches but yields an
        empty ``task_content``.
        """
        session_id = session_manager._test_session_id
        mock_routing_result = MagicMock()
        mock_routing_result.matched = True
        mock_routing_result.task_content = ""  # empty task content
        mock_routing_result.specified_experts = ["tech_lead"]

        with patch(
            "agentkit.experts.router.ExpertTeamRouter"
        ) as mock_router_cls:
            mock_router = MagicMock()
            mock_router.resolve = MagicMock(return_value=mock_routing_result)
            mock_router_cls.return_value = mock_router

            result = await _execute_team_collab(
                websocket, session_id, "@team", session_manager
            )

        assert result is True
        # Expect an error frame whose message contains "描述" ("description").
        assert any(
            msg.get("type") == "error"
            and "描述" in msg.get("data", {}).get("message", "")
            for msg in websocket.sent
        )

    @pytest.mark.asyncio
    async def test_team_execution_failure_sends_error(
        self, websocket, session_manager
    ):
        """An orchestrator failure sends an error frame and still dissolves the team."""
        session_id = session_manager._test_session_id
        with patch(
            "agentkit.experts.team.ExpertTeam"
        ) as mock_team_cls, patch(
            "agentkit.experts.orchestrator.TeamOrchestrator"
        ) as mock_orch_cls:
            mock_team = _make_mock_team()
            mock_team_cls.return_value = mock_team
            mock_orch_cls.return_value = _make_mock_orchestrator_result_failing()

            result = await _execute_team_collab(
                websocket, session_id, "@team 开发功能", session_manager
            )

        assert result is True
        # Expect an error frame reporting the failed team run.
        assert any(
            msg.get("type") == "error"
            and "团队协作执行失败" in msg.get("data", {}).get("message", "")
            for msg in websocket.sent
        )
        # Cleanup must actually run: assert_called() would also pass if
        # dissolve() were called but its coroutine were never awaited.
        mock_team.dissolve.assert_awaited()
        # A failed run must not produce a final_answer.
        assert not any(msg.get("type") == "final_answer" for msg in websocket.sent)
|
||
|
||
|
||
def _make_mock_orchestrator_result_failing() -> MagicMock:
    """Return a MagicMock TeamOrchestrator whose awaited ``execute`` raises.

    Awaiting ``execute`` raises ``RuntimeError("LLM 不可用")``.
    """
    failing_execute = AsyncMock(side_effect=RuntimeError("LLM 不可用"))
    orchestrator = MagicMock()
    orchestrator.execute = failing_execute
    return orchestrator
|
||
|
||
|
||
# ── 事件中继与持久化测试 ──────────────────────────────────
|
||
|
||
|
||
class TestTeamCollabEventRelay:
    """Event relay, persistence and cleanup tests for _execute_team_collab."""

    @pytest.mark.asyncio
    async def test_final_answer_sent_after_execution(
        self, websocket, session_manager, mock_orchestrator_result
    ):
        """Exactly one ``final_answer`` frame, marked final, is sent after execution."""
        session_id = session_manager._test_session_id
        with patch(
            "agentkit.experts.team.ExpertTeam"
        ) as mock_team_cls, patch(
            "agentkit.experts.orchestrator.TeamOrchestrator"
        ) as mock_orch_cls:
            mock_team_cls.return_value = _make_mock_team()
            mock_orch_cls.return_value = _make_mock_orchestrator(mock_orchestrator_result)

            await _execute_team_collab(
                websocket, session_id, "@team 开发登录功能", session_manager
            )

        final_msgs = [msg for msg in websocket.sent if msg.get("type") == "final_answer"]
        assert len(final_msgs) == 1
        assert "团队最终结果" in final_msgs[0]["content"]
        assert final_msgs[0]["is_final"] is True

    @pytest.mark.asyncio
    async def test_user_message_persisted(
        self, websocket, session_manager, mock_orchestrator_result
    ):
        """The user's @team message is saved to session history exactly once."""
        session_id = session_manager._test_session_id
        with patch(
            "agentkit.experts.team.ExpertTeam"
        ) as mock_team_cls, patch(
            "agentkit.experts.orchestrator.TeamOrchestrator"
        ) as mock_orch_cls:
            mock_team_cls.return_value = _make_mock_team()
            mock_orch_cls.return_value = _make_mock_orchestrator(mock_orchestrator_result)

            await _execute_team_collab(
                websocket, session_id, "@team 开发登录功能", session_manager
            )

        messages = await session_manager.get_messages(session_id)
        user_msgs = [m for m in messages if m.role == MessageRole.USER]
        assert len(user_msgs) == 1
        assert "@team 开发登录功能" in user_msgs[0].content

    @pytest.mark.asyncio
    async def test_final_result_persisted(
        self, websocket, session_manager, mock_orchestrator_result
    ):
        """The final result is saved as one assistant message from ``team_collab``."""
        session_id = session_manager._test_session_id
        with patch(
            "agentkit.experts.team.ExpertTeam"
        ) as mock_team_cls, patch(
            "agentkit.experts.orchestrator.TeamOrchestrator"
        ) as mock_orch_cls:
            mock_team_cls.return_value = _make_mock_team()
            mock_orch_cls.return_value = _make_mock_orchestrator(mock_orchestrator_result)

            await _execute_team_collab(
                websocket, session_id, "@team 开发登录功能", session_manager
            )

        messages = await session_manager.get_messages(session_id)
        assistant_msgs = [m for m in messages if m.role == MessageRole.ASSISTANT]
        assert len(assistant_msgs) == 1
        assert "团队最终结果" in assistant_msgs[0].content
        assert assistant_msgs[0].agent_name == "team_collab"

    @pytest.mark.asyncio
    async def test_team_dissolved_after_execution(
        self, websocket, session_manager, mock_orchestrator_result
    ):
        """``team.dissolve()`` is awaited exactly once after a successful run."""
        session_id = session_manager._test_session_id
        with patch(
            "agentkit.experts.team.ExpertTeam"
        ) as mock_team_cls, patch(
            "agentkit.experts.orchestrator.TeamOrchestrator"
        ) as mock_orch_cls:
            mock_team = _make_mock_team()
            mock_team_cls.return_value = mock_team
            mock_orch_cls.return_value = _make_mock_orchestrator(mock_orchestrator_result)

            await _execute_team_collab(
                websocket, session_id, "@team 开发登录功能", session_manager
            )

        # assert_called_once() would also pass if dissolve() were called without
        # being awaited, in which case the cleanup coroutine never runs.
        mock_team.dissolve.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_handoff_handler_registered(
        self, websocket, session_manager, mock_orchestrator_result
    ):
        """A relay handler is registered on the team's channel via handoff_transport."""
        session_id = session_manager._test_session_id
        with patch(
            "agentkit.experts.team.ExpertTeam"
        ) as mock_team_cls, patch(
            "agentkit.experts.orchestrator.TeamOrchestrator"
        ) as mock_orch_cls:
            mock_team = _make_mock_team()
            mock_team.team_channel = "team:test-channel"
            mock_team_cls.return_value = mock_team
            mock_orch_cls.return_value = _make_mock_orchestrator(mock_orchestrator_result)

            await _execute_team_collab(
                websocket, session_id, "@team 开发登录功能", session_manager
            )

        register = mock_team.handoff_transport.register_handler
        register.assert_called_once()
        assert register.call_args.args[0] == "team:test-channel"

    @pytest.mark.asyncio
    async def test_team_events_relayed_to_websocket(
        self, websocket, session_manager, mock_orchestrator_result
    ):
        """The registered handler forwards team events to the WebSocket."""
        session_id = session_manager._test_session_id
        with patch(
            "agentkit.experts.team.ExpertTeam"
        ) as mock_team_cls, patch(
            "agentkit.experts.orchestrator.TeamOrchestrator"
        ) as mock_orch_cls:
            mock_team = _make_mock_team()
            mock_team.team_channel = "team:test-channel"
            mock_team_cls.return_value = mock_team
            mock_orch_cls.return_value = _make_mock_orchestrator(mock_orchestrator_result)

            await _execute_team_collab(
                websocket, session_id, "@team 开发功能", session_manager
            )

        # Grab the callback passed as register_handler's second positional
        # argument and feed it a sample team event.
        handler = mock_team.handoff_transport.register_handler.call_args.args[1]
        await handler({"type": "expert_step", "expert": "lead", "message": "开始分析"})

        # The event should reach the WebSocket (via emit_team_event).
        team_event_msgs = [
            msg for msg in websocket.sent if msg.get("type") == "expert_step"
        ]
        assert len(team_event_msgs) == 1
        assert team_event_msgs[0].get("data", {}).get("expert") == "lead"

    @pytest.mark.asyncio
    async def test_create_team_called_with_lead_and_members(
        self, websocket, session_manager, mock_orchestrator_result
    ):
        """``create_team`` is awaited with ``lead_config`` and ``member_configs``."""
        session_id = session_manager._test_session_id
        with patch(
            "agentkit.experts.team.ExpertTeam"
        ) as mock_team_cls, patch(
            "agentkit.experts.orchestrator.TeamOrchestrator"
        ) as mock_orch_cls:
            mock_team = _make_mock_team()
            mock_team_cls.return_value = mock_team
            mock_orch_cls.return_value = _make_mock_orchestrator(mock_orchestrator_result)

            await _execute_team_collab(
                websocket, session_id, "@team:dev_team 开发功能", session_manager
            )

        # Awaited, not merely called: an un-awaited create_team never builds the team.
        mock_team.create_team.assert_awaited_once()
        call_kwargs = mock_team.create_team.call_args.kwargs
        # lead_config should be the first dev_team expert (tech_lead).
        assert call_kwargs["lead_config"].name == "tech_lead"
        # member_configs should contain the remaining experts.
        member_names = [c.name for c in call_kwargs["member_configs"]]
        assert "frontend_engineer" in member_names
        assert "backend_engineer" in member_names

    @pytest.mark.asyncio
    async def test_empty_synthesis_falls_back_to_phase_results(
        self, websocket, session_manager
    ):
        """An empty synthesized result falls back to the per-phase results."""
        session_id = session_manager._test_session_id
        empty_result = {
            "status": "completed",
            "result": {"content": ""},  # empty synthesis
            "phase_results": {
                "phase-1": {"content": "阶段1结果", "phase_name": "规划"},
                "phase-2": {"content": "阶段2结果", "phase_name": "执行"},
            },
            "plan": MagicMock(),
        }
        with patch(
            "agentkit.experts.team.ExpertTeam"
        ) as mock_team_cls, patch(
            "agentkit.experts.orchestrator.TeamOrchestrator"
        ) as mock_orch_cls:
            mock_team_cls.return_value = _make_mock_team()
            mock_orch_cls.return_value = _make_mock_orchestrator(empty_result)

            await _execute_team_collab(
                websocket, session_id, "@team 开发功能", session_manager
            )

        final_msgs = [msg for msg in websocket.sent if msg.get("type") == "final_answer"]
        assert len(final_msgs) == 1
        # Both phase outputs should appear; separate asserts pinpoint which is missing.
        content = final_msgs[0]["content"]
        assert "阶段1结果" in content
        assert "阶段2结果" in content

    @pytest.mark.asyncio
    async def test_completely_empty_result(
        self, websocket, session_manager
    ):
        """With neither a result nor phase results, a default message is sent."""
        session_id = session_manager._test_session_id
        empty_result = {
            "status": "completed",
            "result": {},
            "phase_results": {},
            "plan": MagicMock(),
        }
        with patch(
            "agentkit.experts.team.ExpertTeam"
        ) as mock_team_cls, patch(
            "agentkit.experts.orchestrator.TeamOrchestrator"
        ) as mock_orch_cls:
            mock_team_cls.return_value = _make_mock_team()
            mock_orch_cls.return_value = _make_mock_orchestrator(empty_result)

            await _execute_team_collab(
                websocket, session_id, "@team 开发功能", session_manager
            )

        final_msgs = [msg for msg in websocket.sent if msg.get("type") == "final_answer"]
        assert len(final_msgs) == 1
        assert "未生成最终结果" in final_msgs[0]["content"]
|