feat(mcp): U2 MCP config system and MCPManager lifecycle

Add MCPServerConfig dataclass with stdio/streamable_http/sse transport
validation, MCPManager for declarative YAML-driven MCP server lifecycle
(start_all/stop_all), tool discovery and registration. Integrated
into FastAPI lifespan startup/shutdown.
This commit is contained in:
chiguyong 2026-06-07 17:25:07 +08:00
parent 66b9217569
commit 550d29a139
7 changed files with 745 additions and 2 deletions

View File

@ -1,12 +1,15 @@
"""AgentKit MCP - Model Context Protocol 支持""" """AgentKit MCP - Model Context Protocol 支持"""
from agentkit.mcp.transport import HTTPTransport, SSETransport, Transport, TransportError from agentkit.mcp.manager import MCPManager
from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport, TransportError
__all__ = [ __all__ = [
"MCPManager",
"MCPServer", "MCPServer",
"MCPClient", "MCPClient",
"Transport", "Transport",
"HTTPTransport", "HTTPTransport",
"SSETransport", "SSETransport",
"StdioTransport",
"TransportError", "TransportError",
] ]

View File

@ -5,7 +5,7 @@ from typing import Any
import httpx import httpx
from agentkit.mcp.transport import HTTPTransport, Transport from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport
from agentkit.tools.base import Tool from agentkit.tools.base import Tool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,6 +35,10 @@ class MCPClient:
"""从 Transport 实例创建 MCPClient""" """从 Transport 实例创建 MCPClient"""
if isinstance(transport, HTTPTransport): if isinstance(transport, HTTPTransport):
server_url = transport._endpoint server_url = transport._endpoint
elif isinstance(transport, SSETransport):
server_url = transport._endpoint
elif isinstance(transport, StdioTransport):
server_url = f"stdio://{transport._command}"
else: else:
server_url = "" server_url = ""
return cls(server_url=server_url, transport=transport) return cls(server_url=server_url, transport=transport)

121
src/agentkit/mcp/manager.py Normal file
View File

@ -0,0 +1,121 @@
"""MCP Manager - 管理 MCP Server 连接和工具发现"""
from __future__ import annotations
import logging
from typing import Any, TYPE_CHECKING
from agentkit.mcp.client import MCPClient
from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport
from agentkit.tools.registry import ToolRegistry
if TYPE_CHECKING:
from agentkit.server.config import MCPServerConfig
logger = logging.getLogger(__name__)
class MCPManager:
"""管理 MCP Server 连接和工具发现
负责启动/停止 MCP Server 连接发现远程工具并注册到 ToolRegistry
"""
def __init__(
self,
configs: dict[str, MCPServerConfig],
tool_registry: ToolRegistry | None = None,
):
self._configs = configs
self._tool_registry = tool_registry or ToolRegistry()
self._clients: dict[str, MCPClient] = {} # server_name -> MCPClient
self._transports: dict[str, Transport] = {} # server_name -> Transport
self._available: dict[str, bool] = {} # server_name -> is_available
self._server_tools: dict[str, list[str]] = {} # server_name -> [tool_names]
async def start_all(self) -> None:
"""启动所有配置的 MCP Server发现并注册工具"""
for name, config in self._configs.items():
try:
await self._start_server(name, config)
except Exception as e:
logger.error("Failed to start MCP server '%s': %s", name, e)
self._available[name] = False
async def _start_server(self, name: str, config: MCPServerConfig) -> None:
"""启动单个 MCP Server"""
config.validate()
# 根据配置创建传输层
if config.transport == "stdio":
transport = StdioTransport(
command=config.command,
args=config.args or [],
env=config.env,
timeout=config.timeout,
)
elif config.transport == "streamable_http":
transport = HTTPTransport(
endpoint=config.url,
headers=config.headers,
timeout=config.timeout,
)
elif config.transport == "sse":
transport = SSETransport(
endpoint=config.url,
headers=config.headers,
timeout=config.timeout,
)
else:
raise ValueError(f"Unknown transport: {config.transport}")
# 建立连接
await transport.connect()
self._transports[name] = transport
# 创建客户端并发现工具
client = MCPClient.from_transport(transport)
self._clients[name] = client
tools = await client.list_tools()
tool_names = []
for tool_info in tools:
tool_name = tool_info.get("name", "")
tool_desc = tool_info.get("description", "")
mcp_tool = client.as_tool(tool_name, tool_desc)
self._tool_registry.register(mcp_tool)
tool_names.append(tool_name)
self._server_tools[name] = tool_names
self._available[name] = True
logger.info("MCP server '%s' started with tools: %s", name, tool_names)
async def stop_all(self) -> None:
"""停止所有 MCP Server"""
for name, transport in self._transports.items():
try:
await transport.disconnect()
except Exception as e:
logger.error("Error stopping MCP server '%s': %s", name, e)
self._available[name] = False
self._transports.clear()
self._clients.clear()
def is_available(self, server_name: str) -> bool:
"""检查指定 MCP Server 是否可用"""
return self._available.get(server_name, False)
def get_server_tools(self, server_name: str) -> list[str]:
"""获取指定 MCP Server 提供的工具列表"""
return self._server_tools.get(server_name, [])
def list_all_tools(self) -> list[str]:
"""列出所有 MCP Server 提供的工具"""
all_tools: list[str] = []
for tools in self._server_tools.values():
all_tools.extend(tools)
return all_tools
def get_tool_registry(self) -> ToolRegistry:
"""获取工具注册中心"""
return self._tool_registry

View File

@ -11,6 +11,7 @@ from agentkit.core.agent_pool import AgentPool
from agentkit.llm.gateway import LLMGateway from agentkit.llm.gateway import LLMGateway
from agentkit.llm.providers.anthropic import AnthropicProvider from agentkit.llm.providers.anthropic import AnthropicProvider
from agentkit.llm.providers.openai import OpenAICompatibleProvider from agentkit.llm.providers.openai import OpenAICompatibleProvider
from agentkit.mcp.manager import MCPManager
from agentkit.quality.gate import QualityGate from agentkit.quality.gate import QualityGate
from agentkit.quality.output import OutputStandardizer from agentkit.quality.output import OutputStandardizer
from agentkit.router.intent import IntentRouter from agentkit.router.intent import IntentRouter
@ -23,6 +24,7 @@ from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware
from agentkit.server.task_store import create_task_store from agentkit.server.task_store import create_task_store
from agentkit.server.runner import BackgroundRunner from agentkit.server.runner import BackgroundRunner
from agentkit.core.logging import setup_structured_logging from agentkit.core.logging import setup_structured_logging
from agentkit.telemetry.setup import setup_telemetry
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -87,9 +89,18 @@ async def lifespan(app: FastAPI):
server_config.watch_config() server_config.watch_config()
logger.info("Config hot-reload enabled") logger.info("Config hot-reload enabled")
# Start MCP servers if configured
mcp_manager = getattr(app.state, "mcp_manager", None)
if mcp_manager is not None:
await mcp_manager.start_all()
yield yield
# Shutdown # Shutdown
# Stop MCP servers
if mcp_manager is not None:
await mcp_manager.stop_all()
if server_config is not None: if server_config is not None:
server_config.stop_watching() server_config.stop_watching()
@ -164,6 +175,10 @@ def create_app(
# Initialize structured logging # Initialize structured logging
setup_structured_logging() setup_structured_logging()
# Initialize OpenTelemetry (no-op if not installed or not configured)
if server_config:
setup_telemetry(app, server_config.telemetry)
# Resolve effective API key and rate limit # Resolve effective API key and rate limit
effective_api_key = api_key effective_api_key = api_key
effective_rate_limit = rate_limit effective_rate_limit = rate_limit
@ -210,6 +225,15 @@ def create_app(
app.state.llm_gateway = llm_gateway or LLMGateway() app.state.llm_gateway = llm_gateway or LLMGateway()
app.state.skill_registry = skill_registry or SkillRegistry() app.state.skill_registry = skill_registry or SkillRegistry()
app.state.tool_registry = tool_registry or ToolRegistry() app.state.tool_registry = tool_registry or ToolRegistry()
# Initialize MCPManager if MCP servers are configured
if server_config and server_config.mcp_servers:
mcp_manager = MCPManager(
configs=server_config.mcp_servers,
tool_registry=app.state.tool_registry,
)
app.state.mcp_manager = mcp_manager
else:
app.state.mcp_manager = None
app.state.agent_pool = AgentPool( app.state.agent_pool = AgentPool(
llm_gateway=app.state.llm_gateway, llm_gateway=app.state.llm_gateway,
skill_registry=app.state.skill_registry, skill_registry=app.state.skill_registry,

View File

@ -4,6 +4,7 @@ import asyncio
import logging import logging
import os import os
import re import re
from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from typing import Any, Callable from typing import Any, Callable
@ -18,6 +19,44 @@ logger = logging.getLogger(__name__)
DEFAULT_CONFIG_FILE = "agentkit.yaml" DEFAULT_CONFIG_FILE = "agentkit.yaml"
@dataclass
class MCPServerConfig:
"""Configuration for a single MCP Server connection"""
transport: str # "stdio" | "streamable_http" | "sse"
# stdio-specific
command: str | None = None
args: list[str] | None = None
env: dict[str, str] | None = None
# http/sse-specific
url: str | None = None
headers: dict[str, str] | None = None
# common
timeout: float = 30.0
def validate(self) -> None:
"""Validate configuration, raise ValueError if invalid"""
if self.transport not in ("stdio", "streamable_http", "sse"):
raise ValueError(f"Invalid transport: {self.transport}")
if self.transport == "stdio" and not self.command:
raise ValueError("stdio transport requires 'command'")
if self.transport in ("streamable_http", "sse") and not self.url:
raise ValueError(f"{self.transport} transport requires 'url'")
@classmethod
def from_dict(cls, data: dict) -> "MCPServerConfig":
"""Create from dict (parsed from YAML)"""
return cls(
transport=data.get("transport", "stdio"),
command=data.get("command"),
args=data.get("args"),
env=data.get("env"),
url=data.get("url"),
headers=data.get("headers"),
timeout=data.get("timeout", 30.0),
)
def _resolve_env_vars(value: Any) -> Any: def _resolve_env_vars(value: Any) -> Any:
"""Resolve ${VAR:-default} patterns in string values from environment variables.""" """Resolve ${VAR:-default} patterns in string values from environment variables."""
if not isinstance(value, str): if not isinstance(value, str):
@ -64,6 +103,8 @@ class ServerConfig:
task_store: dict[str, Any] | None = None, task_store: dict[str, Any] | None = None,
cors_origins: list[str] | None = None, cors_origins: list[str] | None = None,
memory: dict[str, Any] | None = None, memory: dict[str, Any] | None = None,
mcp_servers: dict[str, MCPServerConfig] | None = None,
telemetry: dict[str, Any] | None = None,
on_change: Callable[["ServerConfig"], None] | None = None, on_change: Callable[["ServerConfig"], None] | None = None,
): ):
self.host = host self.host = host
@ -79,6 +120,8 @@ class ServerConfig:
self.task_store = task_store or {} self.task_store = task_store or {}
self.cors_origins = cors_origins or ["*"] self.cors_origins = cors_origins or ["*"]
self.memory = memory or {} self.memory = memory or {}
self.mcp_servers = mcp_servers or {}
self.telemetry = telemetry or {}
self.on_change = on_change self.on_change = on_change
# Config watching state # Config watching state
@ -109,6 +152,7 @@ class ServerConfig:
logging_data = data.get("logging", {}) logging_data = data.get("logging", {})
task_store_data = data.get("task_store", {}) task_store_data = data.get("task_store", {})
memory_data = data.get("memory", {}) memory_data = data.get("memory", {})
mcp_data = data.get("mcp", {})
# Build LLMConfig # Build LLMConfig
llm_config = cls._build_llm_config(llm_data) llm_config = cls._build_llm_config(llm_data)
@ -117,6 +161,12 @@ class ServerConfig:
skill_paths = skills_data.get("paths", []) skill_paths = skills_data.get("paths", [])
auto_discover = skills_data.get("auto_discover", True) auto_discover = skills_data.get("auto_discover", True)
# Build MCP server configs
mcp_servers = cls._build_mcp_configs(mcp_data)
# Telemetry config
telemetry_data = data.get("telemetry", {})
return cls( return cls(
host=server.get("host", "0.0.0.0"), host=server.get("host", "0.0.0.0"),
port=server.get("port", 8001), port=server.get("port", 8001),
@ -131,6 +181,8 @@ class ServerConfig:
task_store=task_store_data, task_store=task_store_data,
cors_origins=server.get("cors_origins"), cors_origins=server.get("cors_origins"),
memory=memory_data, memory=memory_data,
mcp_servers=mcp_servers,
telemetry=telemetry_data,
) )
@staticmethod @staticmethod
@ -165,6 +217,18 @@ class ServerConfig:
fallbacks=data.get("fallbacks", {}), fallbacks=data.get("fallbacks", {}),
) )
@staticmethod
def _build_mcp_configs(data: dict) -> dict[str, MCPServerConfig]:
"""Build MCP server configs from the mcp section of agentkit.yaml."""
servers = data.get("servers", {})
if not servers:
return {}
result = {}
for name, server_conf in servers.items():
if isinstance(server_conf, dict):
result[name] = MCPServerConfig.from_dict(server_conf)
return result
def load_skill_configs(self) -> list[SkillConfig]: def load_skill_configs(self) -> list[SkillConfig]:
"""Load all SkillConfig from configured skill paths.""" """Load all SkillConfig from configured skill paths."""
configs = [] configs = []
@ -307,6 +371,8 @@ class ServerConfig:
self.task_store = new_config.task_store self.task_store = new_config.task_store
self.cors_origins = new_config.cors_origins self.cors_origins = new_config.cors_origins
self.memory = new_config.memory self.memory = new_config.memory
self.mcp_servers = new_config.mcp_servers
self.telemetry = new_config.telemetry
self._last_mtime = new_config._last_mtime self._last_mtime = new_config._last_mtime
logger.info(f"Config reloaded from {path}") logger.info(f"Config reloaded from {path}")

View File

@ -0,0 +1,171 @@
"""Tests for MCPServerConfig and ServerConfig MCP section parsing"""
import pytest
from agentkit.server.config import MCPServerConfig, ServerConfig
class TestMCPServerConfig:
"""Tests for MCPServerConfig dataclass"""
def test_from_dict_stdio(self):
data = {
"transport": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
"env": {"NODE_ENV": "production"},
"timeout": 60.0,
}
config = MCPServerConfig.from_dict(data)
assert config.transport == "stdio"
assert config.command == "npx"
assert config.args == ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
assert config.env == {"NODE_ENV": "production"}
assert config.timeout == 60.0
assert config.url is None
assert config.headers is None
def test_from_dict_streamable_http(self):
data = {
"transport": "streamable_http",
"url": "http://localhost:3001/mcp",
"headers": {"Authorization": "Bearer test-token"},
"timeout": 45.0,
}
config = MCPServerConfig.from_dict(data)
assert config.transport == "streamable_http"
assert config.url == "http://localhost:3001/mcp"
assert config.headers == {"Authorization": "Bearer test-token"}
assert config.timeout == 45.0
assert config.command is None
assert config.args is None
def test_from_dict_sse(self):
data = {
"transport": "sse",
"url": "http://localhost:3002/sse",
}
config = MCPServerConfig.from_dict(data)
assert config.transport == "sse"
assert config.url == "http://localhost:3002/sse"
assert config.command is None
def test_from_dict_defaults(self):
data = {}
config = MCPServerConfig.from_dict(data)
assert config.transport == "stdio"
assert config.command is None
assert config.args is None
assert config.env is None
assert config.url is None
assert config.headers is None
assert config.timeout == 30.0
def test_validate_stdio_valid(self):
config = MCPServerConfig(transport="stdio", command="python")
config.validate() # Should not raise
def test_validate_stdio_missing_command(self):
config = MCPServerConfig(transport="stdio", command=None)
with pytest.raises(ValueError, match="stdio transport requires 'command'"):
config.validate()
def test_validate_streamable_http_missing_url(self):
config = MCPServerConfig(transport="streamable_http", url=None)
with pytest.raises(ValueError, match="streamable_http transport requires 'url'"):
config.validate()
def test_validate_sse_missing_url(self):
config = MCPServerConfig(transport="sse", url=None)
with pytest.raises(ValueError, match="sse transport requires 'url'"):
config.validate()
def test_validate_invalid_transport(self):
config = MCPServerConfig(transport="websocket")
with pytest.raises(ValueError, match="Invalid transport: websocket"):
config.validate()
def test_validate_http_with_url(self):
config = MCPServerConfig(transport="streamable_http", url="http://localhost:3001")
config.validate() # Should not raise
def test_validate_sse_with_url(self):
config = MCPServerConfig(transport="sse", url="http://localhost:3002")
config.validate() # Should not raise
class TestServerConfigMCPSection:
"""Tests for ServerConfig parsing with mcp section"""
def test_from_dict_with_mcp_servers(self):
data = {
"mcp": {
"servers": {
"filesystem": {
"transport": "stdio",
"command": "npx",
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
},
"remote": {
"transport": "streamable_http",
"url": "http://localhost:3001/mcp",
},
}
}
}
config = ServerConfig.from_dict(data)
assert len(config.mcp_servers) == 2
assert "filesystem" in config.mcp_servers
assert "remote" in config.mcp_servers
assert config.mcp_servers["filesystem"].transport == "stdio"
assert config.mcp_servers["filesystem"].command == "npx"
assert config.mcp_servers["remote"].transport == "streamable_http"
assert config.mcp_servers["remote"].url == "http://localhost:3001/mcp"
def test_from_dict_without_mcp_section(self):
data = {}
config = ServerConfig.from_dict(data)
assert config.mcp_servers == {}
def test_from_dict_with_empty_mcp_servers(self):
data = {"mcp": {"servers": {}}}
config = ServerConfig.from_dict(data)
assert config.mcp_servers == {}
def test_from_dict_mcp_servers_with_sse(self):
data = {
"mcp": {
"servers": {
"sse-server": {
"transport": "sse",
"url": "http://localhost:3002/sse",
"headers": {"X-API-Key": "secret"},
"timeout": 60.0,
}
}
}
}
config = ServerConfig.from_dict(data)
assert len(config.mcp_servers) == 1
sse_conf = config.mcp_servers["sse-server"]
assert sse_conf.transport == "sse"
assert sse_conf.url == "http://localhost:3002/sse"
assert sse_conf.headers == {"X-API-Key": "secret"}
assert sse_conf.timeout == 60.0
def test_from_dict_mcp_ignores_non_dict_entries(self):
data = {
"mcp": {
"servers": {
"valid": {
"transport": "stdio",
"command": "python",
},
"invalid": "not-a-dict",
}
}
}
config = ServerConfig.from_dict(data)
assert len(config.mcp_servers) == 1
assert "valid" in config.mcp_servers
assert "invalid" not in config.mcp_servers

View File

@ -0,0 +1,354 @@
"""Tests for MCPManager lifecycle and tool discovery"""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.mcp.manager import MCPManager
from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport
from agentkit.server.config import MCPServerConfig
from agentkit.tools.registry import ToolRegistry
def _make_mock_transport(transport_type: str = "stdio") -> MagicMock:
"""Create a mock Transport that behaves like a connected transport."""
mock = MagicMock(spec=Transport)
mock.is_connected = True
mock.connect = AsyncMock()
mock.disconnect = AsyncMock()
mock.send_request = AsyncMock()
return mock
def _make_stdio_config() -> MCPServerConfig:
return MCPServerConfig(
transport="stdio",
command="python",
args=["-m", "mcp_server"],
timeout=30.0,
)
def _make_http_config() -> MCPServerConfig:
return MCPServerConfig(
transport="streamable_http",
url="http://localhost:3001/mcp",
timeout=30.0,
)
def _make_sse_config() -> MCPServerConfig:
return MCPServerConfig(
transport="sse",
url="http://localhost:3002/sse",
timeout=30.0,
)
class TestMCPManagerConstruction:
"""Tests for MCPManager initialization"""
def test_construction_with_configs(self):
configs = {
"server1": _make_stdio_config(),
"server2": _make_http_config(),
}
manager = MCPManager(configs=configs)
assert len(manager._configs) == 2
assert manager._tool_registry is not None
assert len(manager._clients) == 0
assert len(manager._transports) == 0
assert len(manager._available) == 0
assert len(manager._server_tools) == 0
def test_construction_with_custom_tool_registry(self):
registry = ToolRegistry()
configs = {"server1": _make_stdio_config()}
manager = MCPManager(configs=configs, tool_registry=registry)
assert manager._tool_registry is registry
def test_construction_with_empty_configs(self):
manager = MCPManager(configs={})
assert len(manager._configs) == 0
class TestMCPManagerStartAll:
"""Tests for MCPManager.start_all()"""
@patch("agentkit.mcp.manager.StdioTransport")
async def test_start_all_stdio_server(self, MockStdioTransport):
mock_transport = _make_mock_transport()
MockStdioTransport.return_value = mock_transport
# Mock list_tools response via MCPClient
with patch("agentkit.mcp.manager.MCPClient") as MockClient:
mock_client = MagicMock()
mock_client.list_tools = AsyncMock(return_value=[
{"name": "read_file", "description": "Read a file"},
{"name": "write_file", "description": "Write a file"},
])
mock_tool = MagicMock()
mock_tool.name = "read_file"
mock_client.as_tool = MagicMock(return_value=mock_tool)
MockClient.from_transport.return_value = mock_client
configs = {"fs": _make_stdio_config()}
registry = ToolRegistry()
manager = MCPManager(configs=configs, tool_registry=registry)
await manager.start_all()
MockStdioTransport.assert_called_once()
mock_transport.connect.assert_called_once()
mock_client.list_tools.assert_called_once()
assert manager.is_available("fs") is True
assert manager.get_server_tools("fs") == ["read_file", "write_file"]
@patch("agentkit.mcp.manager.HTTPTransport")
async def test_start_all_http_server(self, MockHTTPTransport):
mock_transport = _make_mock_transport()
MockHTTPTransport.return_value = mock_transport
with patch("agentkit.mcp.manager.MCPClient") as MockClient:
mock_client = MagicMock()
mock_client.list_tools = AsyncMock(return_value=[
{"name": "search", "description": "Search the web"},
])
mock_tool = MagicMock()
mock_tool.name = "search"
mock_client.as_tool = MagicMock(return_value=mock_tool)
MockClient.from_transport.return_value = mock_client
configs = {"web": _make_http_config()}
manager = MCPManager(configs=configs)
await manager.start_all()
MockHTTPTransport.assert_called_once()
mock_transport.connect.assert_called_once()
assert manager.is_available("web") is True
assert manager.get_server_tools("web") == ["search"]
@patch("agentkit.mcp.manager.SSETransport")
async def test_start_all_sse_server(self, MockSSETransport):
mock_transport = _make_mock_transport()
MockSSETransport.return_value = mock_transport
with patch("agentkit.mcp.manager.MCPClient") as MockClient:
mock_client = MagicMock()
mock_client.list_tools = AsyncMock(return_value=[
{"name": "query", "description": "Query data"},
])
mock_tool = MagicMock()
mock_tool.name = "query"
mock_client.as_tool = MagicMock(return_value=mock_tool)
MockClient.from_transport.return_value = mock_client
configs = {"sse-srv": _make_sse_config()}
manager = MCPManager(configs=configs)
await manager.start_all()
MockSSETransport.assert_called_once()
assert manager.is_available("sse-srv") is True
async def test_start_all_server_failure_doesnt_affect_others(self):
"""One server failing should not prevent other servers from starting"""
with patch("agentkit.mcp.manager.StdioTransport") as MockStdio, \
patch("agentkit.mcp.manager.HTTPTransport") as MockHTTP:
# First server fails
fail_transport = _make_mock_transport()
fail_transport.connect = AsyncMock(side_effect=Exception("Connection refused"))
MockStdio.return_value = fail_transport
# Second server succeeds
ok_transport = _make_mock_transport()
MockHTTP.return_value = ok_transport
with patch("agentkit.mcp.manager.MCPClient") as MockClient:
mock_client = MagicMock()
mock_client.list_tools = AsyncMock(return_value=[
{"name": "search", "description": "Search"},
])
mock_tool = MagicMock()
mock_tool.name = "search"
mock_client.as_tool = MagicMock(return_value=mock_tool)
MockClient.from_transport.return_value = mock_client
configs = {
"failing": _make_stdio_config(),
"working": _make_http_config(),
}
manager = MCPManager(configs=configs)
await manager.start_all()
assert manager.is_available("failing") is False
assert manager.is_available("working") is True
assert manager.get_server_tools("working") == ["search"]
class TestMCPManagerStopAll:
"""Tests for MCPManager.stop_all()"""
@patch("agentkit.mcp.manager.StdioTransport")
async def test_stop_all(self, MockStdioTransport):
mock_transport = _make_mock_transport()
MockStdioTransport.return_value = mock_transport
with patch("agentkit.mcp.manager.MCPClient") as MockClient:
mock_client = MagicMock()
mock_client.list_tools = AsyncMock(return_value=[])
MockClient.from_transport.return_value = mock_client
configs = {"srv": _make_stdio_config()}
manager = MCPManager(configs=configs)
await manager.start_all()
assert manager.is_available("srv") is True
await manager.stop_all()
mock_transport.disconnect.assert_called_once()
assert manager.is_available("srv") is False
assert len(manager._transports) == 0
assert len(manager._clients) == 0
async def test_stop_all_handles_disconnect_error(self):
"""stop_all should not raise even if disconnect fails"""
manager = MCPManager(configs={})
# Manually set up internal state to simulate a connected server
mock_transport = _make_mock_transport()
mock_transport.disconnect = AsyncMock(side_effect=Exception("Disconnect error"))
manager._transports = {"srv": mock_transport}
manager._available = {"srv": True}
# Should not raise
await manager.stop_all()
assert manager.is_available("srv") is False
class TestMCPManagerQueryMethods:
"""Tests for MCPManager query methods"""
def test_is_available_unknown_server(self):
manager = MCPManager(configs={})
assert manager.is_available("nonexistent") is False
def test_get_server_tools_unknown_server(self):
manager = MCPManager(configs={})
assert manager.get_server_tools("nonexistent") == []
def test_list_all_tools_empty(self):
manager = MCPManager(configs={})
assert manager.list_all_tools() == []
def test_list_all_tools_with_servers(self):
manager = MCPManager(configs={})
manager._server_tools = {
"srv1": ["tool_a", "tool_b"],
"srv2": ["tool_c"],
}
result = manager.list_all_tools()
assert sorted(result) == ["tool_a", "tool_b", "tool_c"]
def test_get_tool_registry(self):
registry = ToolRegistry()
manager = MCPManager(configs={}, tool_registry=registry)
assert manager.get_tool_registry() is registry
class TestMCPManagerToolDiscovery:
"""Tests for tool discovery and registration"""
@patch("agentkit.mcp.manager.StdioTransport")
async def test_tools_registered_in_registry(self, MockStdioTransport):
mock_transport = _make_mock_transport()
MockStdioTransport.return_value = mock_transport
with patch("agentkit.mcp.manager.MCPClient") as MockClient:
mock_client = MagicMock()
mock_client.list_tools = AsyncMock(return_value=[
{"name": "read_file", "description": "Read a file"},
{"name": "write_file", "description": "Write a file"},
])
# Create mock tools that the as_tool method returns
mock_tool_1 = MagicMock()
mock_tool_1.name = "read_file"
mock_tool_2 = MagicMock()
mock_tool_2.name = "write_file"
mock_client.as_tool = MagicMock(side_effect=[mock_tool_1, mock_tool_2])
MockClient.from_transport.return_value = mock_client
registry = ToolRegistry()
configs = {"fs": _make_stdio_config()}
manager = MCPManager(configs=configs, tool_registry=registry)
await manager.start_all()
# Verify tools were registered
assert registry.has_tool("read_file")
assert registry.has_tool("write_file")
@patch("agentkit.mcp.manager.StdioTransport")
async def test_empty_tools_list(self, MockStdioTransport):
mock_transport = _make_mock_transport()
MockStdioTransport.return_value = mock_transport
with patch("agentkit.mcp.manager.MCPClient") as MockClient:
mock_client = MagicMock()
mock_client.list_tools = AsyncMock(return_value=[])
MockClient.from_transport.return_value = mock_client
configs = {"empty": _make_stdio_config()}
manager = MCPManager(configs=configs)
await manager.start_all()
assert manager.is_available("empty") is True
assert manager.get_server_tools("empty") == []
assert manager.list_all_tools() == []
@patch("agentkit.mcp.manager.StdioTransport")
async def test_multiple_servers_tools_combined(self, MockStdioTransport):
mock_transport = _make_mock_transport()
MockStdioTransport.return_value = mock_transport
with patch("agentkit.mcp.manager.MCPClient") as MockClient:
# First call for srv1
mock_client_1 = MagicMock()
mock_client_1.list_tools = AsyncMock(return_value=[
{"name": "tool_a", "description": "Tool A"},
])
mock_tool_a = MagicMock()
mock_tool_a.name = "tool_a"
mock_client_1.as_tool = MagicMock(return_value=mock_tool_a)
# Second call for srv2
mock_client_2 = MagicMock()
mock_client_2.list_tools = AsyncMock(return_value=[
{"name": "tool_b", "description": "Tool B"},
])
mock_tool_b = MagicMock()
mock_tool_b.name = "tool_b"
mock_client_2.as_tool = MagicMock(return_value=mock_tool_b)
MockClient.from_transport.side_effect = [mock_client_1, mock_client_2]
configs = {
"srv1": _make_stdio_config(),
"srv2": _make_stdio_config(),
}
manager = MCPManager(configs=configs)
await manager.start_all()
assert manager.get_server_tools("srv1") == ["tool_a"]
assert manager.get_server_tools("srv2") == ["tool_b"]
assert sorted(manager.list_all_tools()) == ["tool_a", "tool_b"]
# Run async tests with pytest-asyncio
pytest_plugins = ["pytest_asyncio"]