From 550d29a1397291c3f02ea6d2e44b73aa6967b5d7 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Sun, 7 Jun 2026 17:25:07 +0800 Subject: [PATCH] feat(mcp): U2 MCP config system and MCPManager lifecycle Add MCPServerConfig dataclass with stdio/streamable_http/sse transport validation, MCPManager for declarative YAML-driven MCP server lifecycle (start_all/stop_all), tool discovery and registration. Integrated into FastAPI lifespan startup/shutdown. --- src/agentkit/mcp/__init__.py | 5 +- src/agentkit/mcp/client.py | 6 +- src/agentkit/mcp/manager.py | 121 +++++++++++ src/agentkit/server/app.py | 24 +++ src/agentkit/server/config.py | 66 ++++++ tests/unit/test_mcp_config.py | 171 ++++++++++++++++ tests/unit/test_mcp_manager.py | 354 +++++++++++++++++++++++++++++++++ 7 files changed, 745 insertions(+), 2 deletions(-) create mode 100644 src/agentkit/mcp/manager.py create mode 100644 tests/unit/test_mcp_config.py create mode 100644 tests/unit/test_mcp_manager.py diff --git a/src/agentkit/mcp/__init__.py b/src/agentkit/mcp/__init__.py index 4536fe6..c9eeb07 100644 --- a/src/agentkit/mcp/__init__.py +++ b/src/agentkit/mcp/__init__.py @@ -1,12 +1,15 @@ """AgentKit MCP - Model Context Protocol 支持""" -from agentkit.mcp.transport import HTTPTransport, SSETransport, Transport, TransportError +from agentkit.mcp.manager import MCPManager +from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport, TransportError __all__ = [ + "MCPManager", "MCPServer", "MCPClient", "Transport", "HTTPTransport", "SSETransport", + "StdioTransport", "TransportError", ] diff --git a/src/agentkit/mcp/client.py b/src/agentkit/mcp/client.py index f2998d2..448b452 100644 --- a/src/agentkit/mcp/client.py +++ b/src/agentkit/mcp/client.py @@ -5,7 +5,7 @@ from typing import Any import httpx -from agentkit.mcp.transport import HTTPTransport, Transport +from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport from agentkit.tools.base import Tool logger = logging.getLogger(__name__) @@ -35,6 +35,10 @@ class MCPClient: """从 Transport 实例创建 MCPClient""" if isinstance(transport, HTTPTransport): server_url = transport._endpoint + elif isinstance(transport, SSETransport): + server_url = transport._endpoint + elif isinstance(transport, StdioTransport): + server_url = f"stdio://{transport._command}" else: server_url = "" return cls(server_url=server_url, transport=transport) diff --git a/src/agentkit/mcp/manager.py b/src/agentkit/mcp/manager.py new file mode 100644 index 0000000..5bd8949 --- /dev/null +++ b/src/agentkit/mcp/manager.py @@ -0,0 +1,121 @@ +"""MCP Manager - 管理 MCP Server 连接和工具发现""" + +from __future__ import annotations + +import logging +from typing import Any, TYPE_CHECKING + +from agentkit.mcp.client import MCPClient +from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport +from agentkit.tools.registry import ToolRegistry + +if TYPE_CHECKING: + from agentkit.server.config import MCPServerConfig + +logger = logging.getLogger(__name__) + + +class MCPManager: + """管理 MCP Server 连接和工具发现 + + 负责启动/停止 MCP Server 连接,发现远程工具并注册到 ToolRegistry。 + """ + + def __init__( + self, + configs: dict[str, MCPServerConfig], + tool_registry: ToolRegistry | None = None, + ): + self._configs = configs + self._tool_registry = tool_registry or ToolRegistry() + self._clients: dict[str, MCPClient] = {} # server_name -> MCPClient + self._transports: dict[str, Transport] = {} # server_name -> Transport + self._available: dict[str, bool] = {} # server_name -> is_available + self._server_tools: dict[str, list[str]] = {} # server_name -> [tool_names] + + async def start_all(self) -> None: + """启动所有配置的 MCP Server,发现并注册工具""" + for name, config in self._configs.items(): + try: + await self._start_server(name, config) + except Exception as e: + logger.error("Failed to start MCP server '%s': %s", name, e) + self._available[name] = False + + async def _start_server(self, name: str, config: MCPServerConfig) -> None: + """启动单个 MCP Server""" + config.validate() + + # 根据配置创建传输层 + if config.transport == "stdio": + transport = StdioTransport( + command=config.command, + args=config.args or [], + env=config.env, + timeout=config.timeout, + ) + elif config.transport == "streamable_http": + transport = HTTPTransport( + endpoint=config.url, + headers=config.headers, + timeout=config.timeout, + ) + elif config.transport == "sse": + transport = SSETransport( + endpoint=config.url, + headers=config.headers, + timeout=config.timeout, + ) + else: + raise ValueError(f"Unknown transport: {config.transport}") + + # 建立连接 + await transport.connect() + self._transports[name] = transport + + # 创建客户端并发现工具 + client = MCPClient.from_transport(transport) + self._clients[name] = client + + tools = await client.list_tools() + tool_names = [] + for tool_info in tools: + tool_name = tool_info.get("name", "") + tool_desc = tool_info.get("description", "") + mcp_tool = client.as_tool(tool_name, tool_desc) + self._tool_registry.register(mcp_tool) + tool_names.append(tool_name) + + self._server_tools[name] = tool_names + self._available[name] = True + logger.info("MCP server '%s' started with tools: %s", name, tool_names) + + async def stop_all(self) -> None: + """停止所有 MCP Server""" + for name, transport in self._transports.items(): + try: + await transport.disconnect() + except Exception as e: + logger.error("Error stopping MCP server '%s': %s", name, e) + self._available[name] = False + self._transports.clear() + self._clients.clear() + + def is_available(self, server_name: str) -> bool: + """检查指定 MCP Server 是否可用""" + return self._available.get(server_name, False) + + def get_server_tools(self, server_name: str) -> list[str]: + """获取指定 MCP Server 提供的工具列表""" + return self._server_tools.get(server_name, []) + + def list_all_tools(self) -> list[str]: + """列出所有 MCP Server 提供的工具""" + all_tools: list[str] = [] + for tools in self._server_tools.values(): + all_tools.extend(tools) + return all_tools + + def get_tool_registry(self) -> ToolRegistry: + """获取工具注册中心""" + return self._tool_registry diff --git a/src/agentkit/server/app.py b/src/agentkit/server/app.py index f4677c2..d980108 100644 --- a/src/agentkit/server/app.py +++ b/src/agentkit/server/app.py @@ -11,6 +11,7 @@ from agentkit.core.agent_pool import AgentPool from agentkit.llm.gateway import LLMGateway from agentkit.llm.providers.anthropic import AnthropicProvider from agentkit.llm.providers.openai import OpenAICompatibleProvider +from agentkit.mcp.manager import MCPManager from agentkit.quality.gate import QualityGate from agentkit.quality.output import OutputStandardizer from agentkit.router.intent import IntentRouter @@ -23,6 +24,7 @@ from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware from agentkit.server.task_store import create_task_store from agentkit.server.runner import BackgroundRunner from agentkit.core.logging import setup_structured_logging +from agentkit.telemetry.setup import setup_telemetry logger = logging.getLogger(__name__) @@ -87,9 +89,18 @@ async def lifespan(app: FastAPI): server_config.watch_config() logger.info("Config hot-reload enabled") + # Start MCP servers if configured + mcp_manager = getattr(app.state, "mcp_manager", None) + if mcp_manager is not None: + await mcp_manager.start_all() + yield # Shutdown + # Stop MCP servers + if mcp_manager is not None: + await mcp_manager.stop_all() + if server_config is not None: server_config.stop_watching() @@ -164,6 +175,10 @@ def create_app( # Initialize structured logging setup_structured_logging() + # Initialize OpenTelemetry (no-op if not installed or not configured) + if server_config: + setup_telemetry(app, server_config.telemetry) + # Resolve effective API key and rate limit effective_api_key = api_key effective_rate_limit = rate_limit @@ -210,6 +225,15 @@ def create_app( app.state.llm_gateway = llm_gateway or LLMGateway() app.state.skill_registry = skill_registry or SkillRegistry() app.state.tool_registry = tool_registry or ToolRegistry() + # Initialize MCPManager if MCP servers are configured + if server_config and server_config.mcp_servers: + mcp_manager = MCPManager( + configs=server_config.mcp_servers, + tool_registry=app.state.tool_registry, + ) + app.state.mcp_manager = mcp_manager + else: + app.state.mcp_manager = None app.state.agent_pool = AgentPool( llm_gateway=app.state.llm_gateway, skill_registry=app.state.skill_registry, diff --git a/src/agentkit/server/config.py b/src/agentkit/server/config.py index 1033f51..8900671 100644 --- a/src/agentkit/server/config.py +++ b/src/agentkit/server/config.py @@ -4,6 +4,7 @@ import asyncio import logging import os import re +from dataclasses import dataclass, field from pathlib import Path from typing import Any, Callable @@ -18,6 +19,44 @@ logger = logging.getLogger(__name__) DEFAULT_CONFIG_FILE = "agentkit.yaml" +@dataclass +class MCPServerConfig: + """Configuration for a single MCP Server connection""" + + transport: str # "stdio" | "streamable_http" | "sse" + # stdio-specific + command: str | None = None + args: list[str] | None = None + env: dict[str, str] | None = None + # http/sse-specific + url: str | None = None + headers: dict[str, str] | None = None + # common + timeout: float = 30.0 + + def validate(self) -> None: + """Validate configuration, raise ValueError if invalid""" + if self.transport not in ("stdio", "streamable_http", "sse"): + raise ValueError(f"Invalid transport: {self.transport}") + if self.transport == "stdio" and not self.command: + raise ValueError("stdio transport requires 'command'") + if self.transport in ("streamable_http", "sse") and not self.url: + raise ValueError(f"{self.transport} transport requires 'url'") + + @classmethod + def from_dict(cls, data: dict) -> "MCPServerConfig": + """Create from dict (parsed from YAML)""" + return cls( + transport=data.get("transport", "stdio"), + command=data.get("command"), + args=data.get("args"), + env=data.get("env"), + url=data.get("url"), + headers=data.get("headers"), + timeout=data.get("timeout", 30.0), + ) + + def _resolve_env_vars(value: Any) -> Any: """Resolve ${VAR:-default} patterns in string values from environment variables.""" if not isinstance(value, str): @@ -64,6 +103,8 @@ class ServerConfig: task_store: dict[str, Any] | None = None, cors_origins: list[str] | None = None, memory: dict[str, Any] | None = None, + mcp_servers: dict[str, MCPServerConfig] | None = None, + telemetry: dict[str, Any] | None = None, on_change: Callable[["ServerConfig"], None] | None = None, ): self.host = host @@ -79,6 +120,8 @@ class ServerConfig: self.task_store = task_store or {} self.cors_origins = cors_origins or ["*"] self.memory = memory or {} + self.mcp_servers = mcp_servers or {} + self.telemetry = telemetry or {} self.on_change = on_change # Config watching state @@ -109,6 +152,7 @@ class ServerConfig: logging_data = data.get("logging", {}) task_store_data = data.get("task_store", {}) memory_data = data.get("memory", {}) + mcp_data = data.get("mcp", {}) # Build LLMConfig llm_config = cls._build_llm_config(llm_data) @@ -117,6 +161,12 @@ class ServerConfig: skill_paths = skills_data.get("paths", []) auto_discover = skills_data.get("auto_discover", True) + # Build MCP server configs + mcp_servers = cls._build_mcp_configs(mcp_data) + + # Telemetry config + telemetry_data = data.get("telemetry", {}) + return cls( host=server.get("host", "0.0.0.0"), port=server.get("port", 8001), @@ -131,6 +181,8 @@ class ServerConfig: task_store=task_store_data, cors_origins=server.get("cors_origins"), memory=memory_data, + mcp_servers=mcp_servers, + telemetry=telemetry_data, ) @staticmethod @@ -165,6 +217,18 @@ class ServerConfig: fallbacks=data.get("fallbacks", {}), ) + @staticmethod + def _build_mcp_configs(data: dict) -> dict[str, MCPServerConfig]: + """Build MCP server configs from the mcp section of agentkit.yaml.""" + servers = data.get("servers", {}) + if not servers: + return {} + result = {} + for name, server_conf in servers.items(): + if isinstance(server_conf, dict): + result[name] = MCPServerConfig.from_dict(server_conf) + return result + def load_skill_configs(self) -> list[SkillConfig]: """Load all SkillConfig from configured skill paths.""" configs = [] @@ -307,6 +371,8 @@ class ServerConfig: self.task_store = new_config.task_store self.cors_origins = new_config.cors_origins self.memory = new_config.memory + self.mcp_servers = new_config.mcp_servers + self.telemetry = new_config.telemetry self._last_mtime = new_config._last_mtime logger.info(f"Config reloaded from {path}") diff --git a/tests/unit/test_mcp_config.py b/tests/unit/test_mcp_config.py new file mode 100644 index 0000000..af6c573 --- /dev/null +++ b/tests/unit/test_mcp_config.py @@ -0,0 +1,171 @@ +"""Tests for MCPServerConfig and ServerConfig MCP section parsing""" + +import pytest + +from agentkit.server.config import MCPServerConfig, ServerConfig + + +class TestMCPServerConfig: + """Tests for MCPServerConfig dataclass""" + + def test_from_dict_stdio(self): + data = { + "transport": "stdio", + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"], + "env": {"NODE_ENV": "production"}, + "timeout": 60.0, + } + config = MCPServerConfig.from_dict(data) + assert config.transport == "stdio" + assert config.command == "npx" + assert config.args == ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"] + assert config.env == {"NODE_ENV": "production"} + assert config.timeout == 60.0 + assert config.url is None + assert config.headers is None + + def test_from_dict_streamable_http(self): + data = { + "transport": "streamable_http", + "url": "http://localhost:3001/mcp", + "headers": {"Authorization": "Bearer test-token"}, + "timeout": 45.0, + } + config = MCPServerConfig.from_dict(data) + assert config.transport == "streamable_http" + assert config.url == "http://localhost:3001/mcp" + assert config.headers == {"Authorization": "Bearer test-token"} + assert config.timeout == 45.0 + assert config.command is None + assert config.args is None + + def test_from_dict_sse(self): + data = { + "transport": "sse", + "url": "http://localhost:3002/sse", + } + config = MCPServerConfig.from_dict(data) + assert config.transport == "sse" + assert config.url == "http://localhost:3002/sse" + assert config.command is None + + def test_from_dict_defaults(self): + data = {} + config = MCPServerConfig.from_dict(data) + assert config.transport == "stdio" + assert config.command is None + assert config.args is None + assert config.env is None + assert config.url is None + assert config.headers is None + assert config.timeout == 30.0 + + def test_validate_stdio_valid(self): + config = MCPServerConfig(transport="stdio", command="python") + config.validate() # Should not raise + + def test_validate_stdio_missing_command(self): + config = MCPServerConfig(transport="stdio", command=None) + with pytest.raises(ValueError, match="stdio transport requires 'command'"): + config.validate() + + def test_validate_streamable_http_missing_url(self): + config = MCPServerConfig(transport="streamable_http", url=None) + with pytest.raises(ValueError, match="streamable_http transport requires 'url'"): + config.validate() + + def test_validate_sse_missing_url(self): + config = MCPServerConfig(transport="sse", url=None) + with pytest.raises(ValueError, match="sse transport requires 'url'"): + config.validate() + + def test_validate_invalid_transport(self): + config = MCPServerConfig(transport="websocket") + with pytest.raises(ValueError, match="Invalid transport: websocket"): + config.validate() + + def test_validate_http_with_url(self): + config = MCPServerConfig(transport="streamable_http", url="http://localhost:3001") + config.validate() # Should not raise + + def test_validate_sse_with_url(self): + config = MCPServerConfig(transport="sse", url="http://localhost:3002") + config.validate() # Should not raise + + +class TestServerConfigMCPSection: + """Tests for ServerConfig parsing with mcp section""" + + def test_from_dict_with_mcp_servers(self): + data = { + "mcp": { + "servers": { + "filesystem": { + "transport": "stdio", + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"], + }, + "remote": { + "transport": "streamable_http", + "url": "http://localhost:3001/mcp", + }, + } + } + } + config = ServerConfig.from_dict(data) + assert len(config.mcp_servers) == 2 + assert "filesystem" in config.mcp_servers + assert "remote" in config.mcp_servers + assert config.mcp_servers["filesystem"].transport == "stdio" + assert config.mcp_servers["filesystem"].command == "npx" + assert config.mcp_servers["remote"].transport == "streamable_http" + assert config.mcp_servers["remote"].url == "http://localhost:3001/mcp" + + def test_from_dict_without_mcp_section(self): + data = {} + config = ServerConfig.from_dict(data) + assert config.mcp_servers == {} + + def test_from_dict_with_empty_mcp_servers(self): + data = {"mcp": {"servers": {}}} + config = ServerConfig.from_dict(data) + assert config.mcp_servers == {} + + def test_from_dict_mcp_servers_with_sse(self): + data = { + "mcp": { + "servers": { + "sse-server": { + "transport": "sse", + "url": "http://localhost:3002/sse", + "headers": {"X-API-Key": "secret"}, + "timeout": 60.0, + } + } + } + } + config = ServerConfig.from_dict(data) + assert len(config.mcp_servers) == 1 + sse_conf = config.mcp_servers["sse-server"] + assert sse_conf.transport == "sse" + assert sse_conf.url == "http://localhost:3002/sse" + assert sse_conf.headers == {"X-API-Key": "secret"} + assert sse_conf.timeout == 60.0 + + def test_from_dict_mcp_ignores_non_dict_entries(self): + data = { + "mcp": { + "servers": { + "valid": { + "transport": "stdio", + "command": "python", + }, + "invalid": "not-a-dict", + } + } + } + config = ServerConfig.from_dict(data) + assert len(config.mcp_servers) == 1 + assert "valid" in config.mcp_servers + assert "invalid" not in config.mcp_servers diff --git a/tests/unit/test_mcp_manager.py b/tests/unit/test_mcp_manager.py new file mode 100644 index 0000000..06916b9 --- /dev/null +++ b/tests/unit/test_mcp_manager.py @@ -0,0 +1,354 @@ +"""Tests for MCPManager lifecycle and tool discovery""" + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentkit.mcp.manager import MCPManager +from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport +from agentkit.server.config import MCPServerConfig +from agentkit.tools.registry import ToolRegistry + + +def _make_mock_transport(transport_type: str = "stdio") -> MagicMock: + """Create a mock Transport that behaves like a connected transport.""" + mock = MagicMock(spec=Transport) + mock.is_connected = True + mock.connect = AsyncMock() + mock.disconnect = AsyncMock() + mock.send_request = AsyncMock() + return mock + + +def _make_stdio_config() -> MCPServerConfig: + return MCPServerConfig( + transport="stdio", + command="python", + args=["-m", "mcp_server"], + timeout=30.0, + ) + + +def _make_http_config() -> MCPServerConfig: + return MCPServerConfig( + transport="streamable_http", + url="http://localhost:3001/mcp", + timeout=30.0, + ) + + +def _make_sse_config() -> MCPServerConfig: + return MCPServerConfig( + transport="sse", + url="http://localhost:3002/sse", + timeout=30.0, + ) + + +class TestMCPManagerConstruction: + """Tests for MCPManager initialization""" + + def test_construction_with_configs(self): + configs = { + "server1": _make_stdio_config(), + "server2": _make_http_config(), + } + manager = MCPManager(configs=configs) + assert len(manager._configs) == 2 + assert manager._tool_registry is not None + assert len(manager._clients) == 0 + assert len(manager._transports) == 0 + assert len(manager._available) == 0 + assert len(manager._server_tools) == 0 + + def test_construction_with_custom_tool_registry(self): + registry = ToolRegistry() + configs = {"server1": _make_stdio_config()} + manager = MCPManager(configs=configs, tool_registry=registry) + assert manager._tool_registry is registry + + def test_construction_with_empty_configs(self): + manager = MCPManager(configs={}) + assert len(manager._configs) == 0 + + +class TestMCPManagerStartAll: + """Tests for MCPManager.start_all()""" + + @patch("agentkit.mcp.manager.StdioTransport") + async def test_start_all_stdio_server(self, MockStdioTransport): + mock_transport = _make_mock_transport() + MockStdioTransport.return_value = mock_transport + + # Mock list_tools response via MCPClient + with patch("agentkit.mcp.manager.MCPClient") as MockClient: + mock_client = MagicMock() + mock_client.list_tools = AsyncMock(return_value=[ + {"name": "read_file", "description": "Read a file"}, + {"name": "write_file", "description": "Write a file"}, + ]) + mock_tool = MagicMock() + mock_tool.name = "read_file" + mock_client.as_tool = MagicMock(return_value=mock_tool) + MockClient.from_transport.return_value = mock_client + + configs = {"fs": _make_stdio_config()} + registry = ToolRegistry() + manager = MCPManager(configs=configs, tool_registry=registry) + + await manager.start_all() + + MockStdioTransport.assert_called_once() + mock_transport.connect.assert_called_once() + mock_client.list_tools.assert_called_once() + assert manager.is_available("fs") is True + assert manager.get_server_tools("fs") == ["read_file", "write_file"] + + @patch("agentkit.mcp.manager.HTTPTransport") + async def test_start_all_http_server(self, MockHTTPTransport): + mock_transport = _make_mock_transport() + MockHTTPTransport.return_value = mock_transport + + with patch("agentkit.mcp.manager.MCPClient") as MockClient: + mock_client = MagicMock() + mock_client.list_tools = AsyncMock(return_value=[ + {"name": "search", "description": "Search the web"}, + ]) + mock_tool = MagicMock() + mock_tool.name = "search" + mock_client.as_tool = MagicMock(return_value=mock_tool) + MockClient.from_transport.return_value = mock_client + + configs = {"web": _make_http_config()} + manager = MCPManager(configs=configs) + + await manager.start_all() + + MockHTTPTransport.assert_called_once() + mock_transport.connect.assert_called_once() + assert manager.is_available("web") is True + assert manager.get_server_tools("web") == ["search"] + + @patch("agentkit.mcp.manager.SSETransport") + async def test_start_all_sse_server(self, MockSSETransport): + mock_transport = _make_mock_transport() + MockSSETransport.return_value = mock_transport + + with patch("agentkit.mcp.manager.MCPClient") as MockClient: + mock_client = MagicMock() + mock_client.list_tools = AsyncMock(return_value=[ + {"name": "query", "description": "Query data"}, + ]) + mock_tool = MagicMock() + mock_tool.name = "query" + mock_client.as_tool = MagicMock(return_value=mock_tool) + MockClient.from_transport.return_value = mock_client + + configs = {"sse-srv": _make_sse_config()} + manager = MCPManager(configs=configs) + + await manager.start_all() + + MockSSETransport.assert_called_once() + assert manager.is_available("sse-srv") is True + + async def test_start_all_server_failure_doesnt_affect_others(self): + """One server failing should not prevent other servers from starting""" + with patch("agentkit.mcp.manager.StdioTransport") as MockStdio, \ + patch("agentkit.mcp.manager.HTTPTransport") as MockHTTP: + # First server fails + fail_transport = _make_mock_transport() + fail_transport.connect = AsyncMock(side_effect=Exception("Connection refused")) + MockStdio.return_value = fail_transport + + # Second server succeeds + ok_transport = _make_mock_transport() + MockHTTP.return_value = ok_transport + + with patch("agentkit.mcp.manager.MCPClient") as MockClient: + mock_client = MagicMock() + mock_client.list_tools = AsyncMock(return_value=[ + {"name": "search", "description": "Search"}, + ]) + mock_tool = MagicMock() + mock_tool.name = "search" + mock_client.as_tool = MagicMock(return_value=mock_tool) + MockClient.from_transport.return_value = mock_client + + configs = { + "failing": _make_stdio_config(), + "working": _make_http_config(), + } + manager = MCPManager(configs=configs) + + await manager.start_all() + + assert manager.is_available("failing") is False + assert manager.is_available("working") is True + assert manager.get_server_tools("working") == ["search"] + + +class TestMCPManagerStopAll: + """Tests for MCPManager.stop_all()""" + + @patch("agentkit.mcp.manager.StdioTransport") + async def test_stop_all(self, MockStdioTransport): + mock_transport = _make_mock_transport() + MockStdioTransport.return_value = mock_transport + + with patch("agentkit.mcp.manager.MCPClient") as MockClient: + mock_client = MagicMock() + mock_client.list_tools = AsyncMock(return_value=[]) + MockClient.from_transport.return_value = mock_client + + configs = {"srv": _make_stdio_config()} + manager = MCPManager(configs=configs) + await manager.start_all() + assert manager.is_available("srv") is True + + await manager.stop_all() + + mock_transport.disconnect.assert_called_once() + assert manager.is_available("srv") is False + assert len(manager._transports) == 0 + assert len(manager._clients) == 0 + + async def test_stop_all_handles_disconnect_error(self): + """stop_all should not raise even if disconnect fails""" + manager = MCPManager(configs={}) + + # Manually set up internal state to simulate a connected server + mock_transport = _make_mock_transport() + mock_transport.disconnect = AsyncMock(side_effect=Exception("Disconnect error")) + manager._transports = {"srv": mock_transport} + manager._available = {"srv": True} + + # Should not raise + await manager.stop_all() + assert manager.is_available("srv") is False + + +class TestMCPManagerQueryMethods: + """Tests for MCPManager query methods""" + + def test_is_available_unknown_server(self): + manager = MCPManager(configs={}) + assert manager.is_available("nonexistent") is False + + def test_get_server_tools_unknown_server(self): + manager = MCPManager(configs={}) + assert manager.get_server_tools("nonexistent") == [] + + def test_list_all_tools_empty(self): + manager = MCPManager(configs={}) + assert manager.list_all_tools() == [] + + def test_list_all_tools_with_servers(self): + manager = MCPManager(configs={}) + manager._server_tools = { + "srv1": ["tool_a", "tool_b"], + "srv2": ["tool_c"], + } + result = manager.list_all_tools() + assert sorted(result) == ["tool_a", "tool_b", "tool_c"] + + def test_get_tool_registry(self): + registry = ToolRegistry() + manager = MCPManager(configs={}, tool_registry=registry) + assert manager.get_tool_registry() is registry + + +class TestMCPManagerToolDiscovery: + """Tests for tool discovery and registration""" + + @patch("agentkit.mcp.manager.StdioTransport") + async def test_tools_registered_in_registry(self, MockStdioTransport): + mock_transport = _make_mock_transport() + MockStdioTransport.return_value = mock_transport + + with patch("agentkit.mcp.manager.MCPClient") as MockClient: + mock_client = MagicMock() + mock_client.list_tools = AsyncMock(return_value=[ + {"name": "read_file", "description": "Read a file"}, + {"name": "write_file", "description": "Write a file"}, + ]) + + # Create mock tools that the as_tool method returns + mock_tool_1 = MagicMock() + mock_tool_1.name = "read_file" + mock_tool_2 = MagicMock() + mock_tool_2.name = "write_file" + mock_client.as_tool = MagicMock(side_effect=[mock_tool_1, mock_tool_2]) + MockClient.from_transport.return_value = mock_client + + registry = ToolRegistry() + configs = {"fs": _make_stdio_config()} + manager = MCPManager(configs=configs, tool_registry=registry) + + await manager.start_all() + + # Verify tools were registered + assert registry.has_tool("read_file") + assert registry.has_tool("write_file") + + @patch("agentkit.mcp.manager.StdioTransport") + async def test_empty_tools_list(self, MockStdioTransport): + mock_transport = _make_mock_transport() + MockStdioTransport.return_value = mock_transport + + with patch("agentkit.mcp.manager.MCPClient") as MockClient: + mock_client = MagicMock() + mock_client.list_tools = AsyncMock(return_value=[]) + MockClient.from_transport.return_value = mock_client + + configs = {"empty": _make_stdio_config()} + manager = MCPManager(configs=configs) + + await manager.start_all() + + assert manager.is_available("empty") is True + assert manager.get_server_tools("empty") == [] + assert manager.list_all_tools() == [] + + @patch("agentkit.mcp.manager.StdioTransport") + async def test_multiple_servers_tools_combined(self, MockStdioTransport): + mock_transport = _make_mock_transport() + MockStdioTransport.return_value = mock_transport + + with patch("agentkit.mcp.manager.MCPClient") as MockClient: + # First call for srv1 + mock_client_1 = MagicMock() + mock_client_1.list_tools = AsyncMock(return_value=[ + {"name": "tool_a", "description": "Tool A"}, + ]) + mock_tool_a = MagicMock() + mock_tool_a.name = "tool_a" + mock_client_1.as_tool = MagicMock(return_value=mock_tool_a) + + # Second call for srv2 + mock_client_2 = MagicMock() + mock_client_2.list_tools = AsyncMock(return_value=[ + {"name": "tool_b", "description": "Tool B"}, + ]) + mock_tool_b = MagicMock() + mock_tool_b.name = "tool_b" + mock_client_2.as_tool = MagicMock(return_value=mock_tool_b) + + MockClient.from_transport.side_effect = [mock_client_1, mock_client_2] + + configs = { + "srv1": _make_stdio_config(), + "srv2": _make_stdio_config(), + } + manager = MCPManager(configs=configs) + + await manager.start_all() + + assert manager.get_server_tools("srv1") == ["tool_a"] + assert manager.get_server_tools("srv2") == ["tool_b"] + assert sorted(manager.list_all_tools()) == ["tool_a", "tool_b"] + + +# Run async tests with pytest-asyncio +pytest_plugins = ["pytest_asyncio"]