355 lines
13 KiB
Python
355 lines
13 KiB
Python
"""Tests for MCPManager lifecycle and tool discovery"""
|
|
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from agentkit.mcp.manager import MCPManager
|
|
from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport
|
|
from agentkit.server.config import MCPServerConfig
|
|
from agentkit.tools.registry import ToolRegistry
|
|
|
|
|
|
def _make_mock_transport(transport_type: str = "stdio") -> MagicMock:
    """Create a mock Transport that behaves like a connected transport.

    The mock is spec'd against the base ``Transport`` class, reports
    ``is_connected = True``, and exposes ``connect``, ``disconnect`` and
    ``send_request`` as ``AsyncMock`` objects so they can be awaited by the
    code under test and asserted on afterwards.

    Args:
        transport_type: Currently ignored -- every mock is spec'd against the
            base ``Transport`` regardless of this value. The parameter is kept
            so existing call sites continue to work.

    Returns:
        A ``MagicMock`` configured as described above.
    """
    del transport_type  # intentionally unused; see docstring
    mock = MagicMock(spec=Transport)
    mock.is_connected = True
    mock.connect = AsyncMock()
    mock.disconnect = AsyncMock()
    mock.send_request = AsyncMock()
    return mock
|
|
|
|
|
|
def _make_stdio_config() -> MCPServerConfig:
    """Build a stdio-transport server config that launches ``python -m mcp_server``."""
    stdio_kwargs = {
        "transport": "stdio",
        "command": "python",
        "args": ["-m", "mcp_server"],
        "timeout": 30.0,
    }
    return MCPServerConfig(**stdio_kwargs)
|
|
|
|
|
|
def _make_http_config() -> MCPServerConfig:
    """Build a streamable-HTTP server config pointing at a local MCP endpoint."""
    endpoint = "http://localhost:3001/mcp"
    return MCPServerConfig(transport="streamable_http", url=endpoint, timeout=30.0)
|
|
|
|
|
|
def _make_sse_config() -> MCPServerConfig:
    """Build an SSE server config pointing at a local MCP endpoint."""
    sse_url = "http://localhost:3002/sse"
    return MCPServerConfig(transport="sse", url=sse_url, timeout=30.0)
|
|
|
|
|
|
class TestMCPManagerConstruction:
    """Verify the state of a freshly constructed MCPManager."""

    def test_construction_with_configs(self):
        server_configs = {
            "server1": _make_stdio_config(),
            "server2": _make_http_config(),
        }
        mgr = MCPManager(configs=server_configs)

        assert len(mgr._configs) == 2
        assert mgr._tool_registry is not None
        # Construction alone starts nothing: every runtime map is still empty.
        assert len(mgr._clients) == 0
        assert len(mgr._transports) == 0
        assert len(mgr._available) == 0
        assert len(mgr._server_tools) == 0

    def test_construction_with_custom_tool_registry(self):
        supplied_registry = ToolRegistry()
        mgr = MCPManager(
            configs={"server1": _make_stdio_config()},
            tool_registry=supplied_registry,
        )
        # The caller's registry object must be kept, not copied or replaced.
        assert mgr._tool_registry is supplied_registry

    def test_construction_with_empty_configs(self):
        assert len(MCPManager(configs={})._configs) == 0
|
|
|
|
|
|
class TestMCPManagerStartAll:
    """Tests for MCPManager.start_all()"""

    # Every test in this class is a coroutine. Mark them explicitly so they run
    # under pytest-asyncio's default "strict" mode as well as asyncio_mode=auto;
    # in strict mode an unmarked ``async def`` test is not executed.
    pytestmark = pytest.mark.asyncio

    @patch("agentkit.mcp.manager.StdioTransport")
    async def test_start_all_stdio_server(self, MockStdioTransport):
        mock_transport = _make_mock_transport()
        MockStdioTransport.return_value = mock_transport

        # Mock list_tools response via MCPClient
        with patch("agentkit.mcp.manager.MCPClient") as MockClient:
            mock_client = MagicMock()
            mock_client.list_tools = AsyncMock(return_value=[
                {"name": "read_file", "description": "Read a file"},
                {"name": "write_file", "description": "Write a file"},
            ])
            # One distinct tool object per listed tool; a single shared mock
            # named "read_file" would otherwise stand in for both tools.
            read_tool = MagicMock()
            read_tool.name = "read_file"
            write_tool = MagicMock()
            write_tool.name = "write_file"
            mock_client.as_tool = MagicMock(side_effect=[read_tool, write_tool])
            MockClient.from_transport.return_value = mock_client

            configs = {"fs": _make_stdio_config()}
            registry = ToolRegistry()
            manager = MCPManager(configs=configs, tool_registry=registry)

            await manager.start_all()

            MockStdioTransport.assert_called_once()
            mock_transport.connect.assert_called_once()
            mock_client.list_tools.assert_called_once()
            assert manager.is_available("fs") is True
            assert manager.get_server_tools("fs") == ["read_file", "write_file"]
            # The supplied registry must receive both discovered tools.
            assert registry.has_tool("read_file")
            assert registry.has_tool("write_file")

    @patch("agentkit.mcp.manager.HTTPTransport")
    async def test_start_all_http_server(self, MockHTTPTransport):
        mock_transport = _make_mock_transport()
        MockHTTPTransport.return_value = mock_transport

        with patch("agentkit.mcp.manager.MCPClient") as MockClient:
            mock_client = MagicMock()
            mock_client.list_tools = AsyncMock(return_value=[
                {"name": "search", "description": "Search the web"},
            ])
            mock_tool = MagicMock()
            mock_tool.name = "search"
            mock_client.as_tool = MagicMock(return_value=mock_tool)
            MockClient.from_transport.return_value = mock_client

            configs = {"web": _make_http_config()}
            manager = MCPManager(configs=configs)

            await manager.start_all()

            MockHTTPTransport.assert_called_once()
            mock_transport.connect.assert_called_once()
            assert manager.is_available("web") is True
            assert manager.get_server_tools("web") == ["search"]

    @patch("agentkit.mcp.manager.SSETransport")
    async def test_start_all_sse_server(self, MockSSETransport):
        mock_transport = _make_mock_transport()
        MockSSETransport.return_value = mock_transport

        with patch("agentkit.mcp.manager.MCPClient") as MockClient:
            mock_client = MagicMock()
            mock_client.list_tools = AsyncMock(return_value=[
                {"name": "query", "description": "Query data"},
            ])
            mock_tool = MagicMock()
            mock_tool.name = "query"
            mock_client.as_tool = MagicMock(return_value=mock_tool)
            MockClient.from_transport.return_value = mock_client

            configs = {"sse-srv": _make_sse_config()}
            manager = MCPManager(configs=configs)

            await manager.start_all()

            # Same checks as the stdio/HTTP variants above.
            MockSSETransport.assert_called_once()
            mock_transport.connect.assert_called_once()
            assert manager.is_available("sse-srv") is True
            assert manager.get_server_tools("sse-srv") == ["query"]

    # Stacked patches: the bottom decorator's mock is the first argument.
    @patch("agentkit.mcp.manager.HTTPTransport")
    @patch("agentkit.mcp.manager.StdioTransport")
    async def test_start_all_server_failure_doesnt_affect_others(
        self, MockStdio, MockHTTP
    ):
        """One server failing should not prevent other servers from starting"""
        # First server fails
        fail_transport = _make_mock_transport()
        fail_transport.connect = AsyncMock(side_effect=Exception("Connection refused"))
        MockStdio.return_value = fail_transport

        # Second server succeeds
        ok_transport = _make_mock_transport()
        MockHTTP.return_value = ok_transport

        with patch("agentkit.mcp.manager.MCPClient") as MockClient:
            mock_client = MagicMock()
            mock_client.list_tools = AsyncMock(return_value=[
                {"name": "search", "description": "Search"},
            ])
            mock_tool = MagicMock()
            mock_tool.name = "search"
            mock_client.as_tool = MagicMock(return_value=mock_tool)
            MockClient.from_transport.return_value = mock_client

            configs = {
                "failing": _make_stdio_config(),
                "working": _make_http_config(),
            }
            manager = MCPManager(configs=configs)

            await manager.start_all()

            assert manager.is_available("failing") is False
            assert manager.is_available("working") is True
            assert manager.get_server_tools("working") == ["search"]
|
|
|
|
|
|
class TestMCPManagerStopAll:
    """Tests for MCPManager.stop_all()"""

    # All tests here are coroutines; mark them so they also run under
    # pytest-asyncio's default "strict" mode, not only with asyncio_mode=auto.
    pytestmark = pytest.mark.asyncio

    @patch("agentkit.mcp.manager.StdioTransport")
    async def test_stop_all(self, MockStdioTransport):
        mock_transport = _make_mock_transport()
        MockStdioTransport.return_value = mock_transport

        with patch("agentkit.mcp.manager.MCPClient") as MockClient:
            mock_client = MagicMock()
            mock_client.list_tools = AsyncMock(return_value=[])
            MockClient.from_transport.return_value = mock_client

            configs = {"srv": _make_stdio_config()}
            manager = MCPManager(configs=configs)
            await manager.start_all()
            assert manager.is_available("srv") is True

            await manager.stop_all()

            mock_transport.disconnect.assert_called_once()
            assert manager.is_available("srv") is False
            # After stop_all no transports or clients may remain tracked.
            assert len(manager._transports) == 0
            assert len(manager._clients) == 0

    async def test_stop_all_handles_disconnect_error(self):
        """stop_all should not raise even if disconnect fails"""
        manager = MCPManager(configs={})

        # Manually set up internal state to simulate a connected server
        mock_transport = _make_mock_transport()
        mock_transport.disconnect = AsyncMock(side_effect=Exception("Disconnect error"))
        manager._transports = {"srv": mock_transport}
        manager._available = {"srv": True}

        # Should not raise
        await manager.stop_all()
        assert manager.is_available("srv") is False
|
|
|
|
|
|
class TestMCPManagerQueryMethods:
    """Exercise the read-only query helpers of MCPManager."""

    def test_is_available_unknown_server(self):
        assert MCPManager(configs={}).is_available("nonexistent") is False

    def test_get_server_tools_unknown_server(self):
        assert MCPManager(configs={}).get_server_tools("nonexistent") == []

    def test_list_all_tools_empty(self):
        assert MCPManager(configs={}).list_all_tools() == []

    def test_list_all_tools_with_servers(self):
        mgr = MCPManager(configs={})
        # Seed per-server tool names directly instead of starting real servers.
        mgr._server_tools = {"srv1": ["tool_a", "tool_b"], "srv2": ["tool_c"]}
        combined = mgr.list_all_tools()
        assert sorted(combined) == ["tool_a", "tool_b", "tool_c"]

    def test_get_tool_registry(self):
        reg = ToolRegistry()
        assert MCPManager(configs={}, tool_registry=reg).get_tool_registry() is reg
|
|
|
|
|
|
class TestMCPManagerToolDiscovery:
    """Tests for tool discovery and registration"""

    # All tests here are coroutines; mark them so they also run under
    # pytest-asyncio's default "strict" mode, not only with asyncio_mode=auto.
    pytestmark = pytest.mark.asyncio

    @patch("agentkit.mcp.manager.StdioTransport")
    async def test_tools_registered_in_registry(self, MockStdioTransport):
        mock_transport = _make_mock_transport()
        MockStdioTransport.return_value = mock_transport

        with patch("agentkit.mcp.manager.MCPClient") as MockClient:
            mock_client = MagicMock()
            mock_client.list_tools = AsyncMock(return_value=[
                {"name": "read_file", "description": "Read a file"},
                {"name": "write_file", "description": "Write a file"},
            ])

            # Create mock tools that the as_tool method returns, one per call
            mock_tool_1 = MagicMock()
            mock_tool_1.name = "read_file"
            mock_tool_2 = MagicMock()
            mock_tool_2.name = "write_file"
            mock_client.as_tool = MagicMock(side_effect=[mock_tool_1, mock_tool_2])
            MockClient.from_transport.return_value = mock_client

            registry = ToolRegistry()
            configs = {"fs": _make_stdio_config()}
            manager = MCPManager(configs=configs, tool_registry=registry)

            await manager.start_all()

            # Verify tools were registered
            assert registry.has_tool("read_file")
            assert registry.has_tool("write_file")

    @patch("agentkit.mcp.manager.StdioTransport")
    async def test_empty_tools_list(self, MockStdioTransport):
        mock_transport = _make_mock_transport()
        MockStdioTransport.return_value = mock_transport

        with patch("agentkit.mcp.manager.MCPClient") as MockClient:
            mock_client = MagicMock()
            mock_client.list_tools = AsyncMock(return_value=[])
            MockClient.from_transport.return_value = mock_client

            configs = {"empty": _make_stdio_config()}
            manager = MCPManager(configs=configs)

            await manager.start_all()

            # A server exposing no tools is still considered available.
            assert manager.is_available("empty") is True
            assert manager.get_server_tools("empty") == []
            assert manager.list_all_tools() == []

    @patch("agentkit.mcp.manager.StdioTransport")
    async def test_multiple_servers_tools_combined(self, MockStdioTransport):
        mock_transport = _make_mock_transport()
        MockStdioTransport.return_value = mock_transport

        with patch("agentkit.mcp.manager.MCPClient") as MockClient:
            # Client returned by the first from_transport call (srv1)
            mock_client_1 = MagicMock()
            mock_client_1.list_tools = AsyncMock(return_value=[
                {"name": "tool_a", "description": "Tool A"},
            ])
            mock_tool_a = MagicMock()
            mock_tool_a.name = "tool_a"
            mock_client_1.as_tool = MagicMock(return_value=mock_tool_a)

            # Client returned by the second from_transport call (srv2)
            mock_client_2 = MagicMock()
            mock_client_2.list_tools = AsyncMock(return_value=[
                {"name": "tool_b", "description": "Tool B"},
            ])
            mock_tool_b = MagicMock()
            mock_tool_b.name = "tool_b"
            mock_client_2.as_tool = MagicMock(return_value=mock_tool_b)

            MockClient.from_transport.side_effect = [mock_client_1, mock_client_2]

            configs = {
                "srv1": _make_stdio_config(),
                "srv2": _make_stdio_config(),
            }
            manager = MCPManager(configs=configs)

            await manager.start_all()

            assert manager.get_server_tools("srv1") == ["tool_a"]
            assert manager.get_server_tools("srv2") == ["tool_b"]
            assert sorted(manager.list_all_tools()) == ["tool_a", "tool_b"]
|
|
|
|
|
|
# Ask pytest to import ``pytest_asyncio`` as a plugin module for this file.
# NOTE(review): loading the plugin does not by itself make bare ``async def``
# tests run -- pytest-asyncio's default "strict" mode also requires an
# ``asyncio`` marker (or ``asyncio_mode = auto`` in the pytest config).
# pytest-asyncio normally registers itself via its entry point, so this line
# is probably redundant; confirm before relying on it.
pytest_plugins = ["pytest_asyncio"]
|