feat(mcp): U2 MCP config system and MCPManager lifecycle
Add MCPServerConfig dataclass with stdio/streamable_http/sse transport validation, MCPManager for declarative YAML-driven MCP server lifecycle (start_all/stop_all), tool discovery and registration. Integrated into FastAPI lifespan startup/shutdown.
This commit is contained in:
parent
66b9217569
commit
550d29a139
|
|
@ -1,12 +1,15 @@
|
||||||
"""AgentKit MCP - Model Context Protocol 支持"""
|
"""AgentKit MCP - Model Context Protocol 支持"""
|
||||||
|
|
||||||
from agentkit.mcp.transport import HTTPTransport, SSETransport, Transport, TransportError
|
from agentkit.mcp.manager import MCPManager
|
||||||
|
from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport, TransportError
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"MCPManager",
|
||||||
"MCPServer",
|
"MCPServer",
|
||||||
"MCPClient",
|
"MCPClient",
|
||||||
"Transport",
|
"Transport",
|
||||||
"HTTPTransport",
|
"HTTPTransport",
|
||||||
"SSETransport",
|
"SSETransport",
|
||||||
|
"StdioTransport",
|
||||||
"TransportError",
|
"TransportError",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from agentkit.mcp.transport import HTTPTransport, Transport
|
from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport
|
||||||
from agentkit.tools.base import Tool
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -35,6 +35,10 @@ class MCPClient:
|
||||||
"""从 Transport 实例创建 MCPClient"""
|
"""从 Transport 实例创建 MCPClient"""
|
||||||
if isinstance(transport, HTTPTransport):
|
if isinstance(transport, HTTPTransport):
|
||||||
server_url = transport._endpoint
|
server_url = transport._endpoint
|
||||||
|
elif isinstance(transport, SSETransport):
|
||||||
|
server_url = transport._endpoint
|
||||||
|
elif isinstance(transport, StdioTransport):
|
||||||
|
server_url = f"stdio://{transport._command}"
|
||||||
else:
|
else:
|
||||||
server_url = ""
|
server_url = ""
|
||||||
return cls(server_url=server_url, transport=transport)
|
return cls(server_url=server_url, transport=transport)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,121 @@
|
||||||
|
"""MCP Manager - 管理 MCP Server 连接和工具发现"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
|
from agentkit.mcp.client import MCPClient
|
||||||
|
from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport
|
||||||
|
from agentkit.tools.registry import ToolRegistry
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from agentkit.server.config import MCPServerConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class MCPManager:
|
||||||
|
"""管理 MCP Server 连接和工具发现
|
||||||
|
|
||||||
|
负责启动/停止 MCP Server 连接,发现远程工具并注册到 ToolRegistry。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
configs: dict[str, MCPServerConfig],
|
||||||
|
tool_registry: ToolRegistry | None = None,
|
||||||
|
):
|
||||||
|
self._configs = configs
|
||||||
|
self._tool_registry = tool_registry or ToolRegistry()
|
||||||
|
self._clients: dict[str, MCPClient] = {} # server_name -> MCPClient
|
||||||
|
self._transports: dict[str, Transport] = {} # server_name -> Transport
|
||||||
|
self._available: dict[str, bool] = {} # server_name -> is_available
|
||||||
|
self._server_tools: dict[str, list[str]] = {} # server_name -> [tool_names]
|
||||||
|
|
||||||
|
async def start_all(self) -> None:
|
||||||
|
"""启动所有配置的 MCP Server,发现并注册工具"""
|
||||||
|
for name, config in self._configs.items():
|
||||||
|
try:
|
||||||
|
await self._start_server(name, config)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to start MCP server '%s': %s", name, e)
|
||||||
|
self._available[name] = False
|
||||||
|
|
||||||
|
async def _start_server(self, name: str, config: MCPServerConfig) -> None:
|
||||||
|
"""启动单个 MCP Server"""
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
# 根据配置创建传输层
|
||||||
|
if config.transport == "stdio":
|
||||||
|
transport = StdioTransport(
|
||||||
|
command=config.command,
|
||||||
|
args=config.args or [],
|
||||||
|
env=config.env,
|
||||||
|
timeout=config.timeout,
|
||||||
|
)
|
||||||
|
elif config.transport == "streamable_http":
|
||||||
|
transport = HTTPTransport(
|
||||||
|
endpoint=config.url,
|
||||||
|
headers=config.headers,
|
||||||
|
timeout=config.timeout,
|
||||||
|
)
|
||||||
|
elif config.transport == "sse":
|
||||||
|
transport = SSETransport(
|
||||||
|
endpoint=config.url,
|
||||||
|
headers=config.headers,
|
||||||
|
timeout=config.timeout,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown transport: {config.transport}")
|
||||||
|
|
||||||
|
# 建立连接
|
||||||
|
await transport.connect()
|
||||||
|
self._transports[name] = transport
|
||||||
|
|
||||||
|
# 创建客户端并发现工具
|
||||||
|
client = MCPClient.from_transport(transport)
|
||||||
|
self._clients[name] = client
|
||||||
|
|
||||||
|
tools = await client.list_tools()
|
||||||
|
tool_names = []
|
||||||
|
for tool_info in tools:
|
||||||
|
tool_name = tool_info.get("name", "")
|
||||||
|
tool_desc = tool_info.get("description", "")
|
||||||
|
mcp_tool = client.as_tool(tool_name, tool_desc)
|
||||||
|
self._tool_registry.register(mcp_tool)
|
||||||
|
tool_names.append(tool_name)
|
||||||
|
|
||||||
|
self._server_tools[name] = tool_names
|
||||||
|
self._available[name] = True
|
||||||
|
logger.info("MCP server '%s' started with tools: %s", name, tool_names)
|
||||||
|
|
||||||
|
async def stop_all(self) -> None:
|
||||||
|
"""停止所有 MCP Server"""
|
||||||
|
for name, transport in self._transports.items():
|
||||||
|
try:
|
||||||
|
await transport.disconnect()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error stopping MCP server '%s': %s", name, e)
|
||||||
|
self._available[name] = False
|
||||||
|
self._transports.clear()
|
||||||
|
self._clients.clear()
|
||||||
|
|
||||||
|
def is_available(self, server_name: str) -> bool:
|
||||||
|
"""检查指定 MCP Server 是否可用"""
|
||||||
|
return self._available.get(server_name, False)
|
||||||
|
|
||||||
|
def get_server_tools(self, server_name: str) -> list[str]:
|
||||||
|
"""获取指定 MCP Server 提供的工具列表"""
|
||||||
|
return self._server_tools.get(server_name, [])
|
||||||
|
|
||||||
|
def list_all_tools(self) -> list[str]:
|
||||||
|
"""列出所有 MCP Server 提供的工具"""
|
||||||
|
all_tools: list[str] = []
|
||||||
|
for tools in self._server_tools.values():
|
||||||
|
all_tools.extend(tools)
|
||||||
|
return all_tools
|
||||||
|
|
||||||
|
def get_tool_registry(self) -> ToolRegistry:
|
||||||
|
"""获取工具注册中心"""
|
||||||
|
return self._tool_registry
|
||||||
|
|
@ -11,6 +11,7 @@ from agentkit.core.agent_pool import AgentPool
|
||||||
from agentkit.llm.gateway import LLMGateway
|
from agentkit.llm.gateway import LLMGateway
|
||||||
from agentkit.llm.providers.anthropic import AnthropicProvider
|
from agentkit.llm.providers.anthropic import AnthropicProvider
|
||||||
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
from agentkit.llm.providers.openai import OpenAICompatibleProvider
|
||||||
|
from agentkit.mcp.manager import MCPManager
|
||||||
from agentkit.quality.gate import QualityGate
|
from agentkit.quality.gate import QualityGate
|
||||||
from agentkit.quality.output import OutputStandardizer
|
from agentkit.quality.output import OutputStandardizer
|
||||||
from agentkit.router.intent import IntentRouter
|
from agentkit.router.intent import IntentRouter
|
||||||
|
|
@ -23,6 +24,7 @@ from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware
|
||||||
from agentkit.server.task_store import create_task_store
|
from agentkit.server.task_store import create_task_store
|
||||||
from agentkit.server.runner import BackgroundRunner
|
from agentkit.server.runner import BackgroundRunner
|
||||||
from agentkit.core.logging import setup_structured_logging
|
from agentkit.core.logging import setup_structured_logging
|
||||||
|
from agentkit.telemetry.setup import setup_telemetry
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -87,9 +89,18 @@ async def lifespan(app: FastAPI):
|
||||||
server_config.watch_config()
|
server_config.watch_config()
|
||||||
logger.info("Config hot-reload enabled")
|
logger.info("Config hot-reload enabled")
|
||||||
|
|
||||||
|
# Start MCP servers if configured
|
||||||
|
mcp_manager = getattr(app.state, "mcp_manager", None)
|
||||||
|
if mcp_manager is not None:
|
||||||
|
await mcp_manager.start_all()
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Shutdown
|
# Shutdown
|
||||||
|
# Stop MCP servers
|
||||||
|
if mcp_manager is not None:
|
||||||
|
await mcp_manager.stop_all()
|
||||||
|
|
||||||
if server_config is not None:
|
if server_config is not None:
|
||||||
server_config.stop_watching()
|
server_config.stop_watching()
|
||||||
|
|
||||||
|
|
@ -164,6 +175,10 @@ def create_app(
|
||||||
# Initialize structured logging
|
# Initialize structured logging
|
||||||
setup_structured_logging()
|
setup_structured_logging()
|
||||||
|
|
||||||
|
# Initialize OpenTelemetry (no-op if not installed or not configured)
|
||||||
|
if server_config:
|
||||||
|
setup_telemetry(app, server_config.telemetry)
|
||||||
|
|
||||||
# Resolve effective API key and rate limit
|
# Resolve effective API key and rate limit
|
||||||
effective_api_key = api_key
|
effective_api_key = api_key
|
||||||
effective_rate_limit = rate_limit
|
effective_rate_limit = rate_limit
|
||||||
|
|
@ -210,6 +225,15 @@ def create_app(
|
||||||
app.state.llm_gateway = llm_gateway or LLMGateway()
|
app.state.llm_gateway = llm_gateway or LLMGateway()
|
||||||
app.state.skill_registry = skill_registry or SkillRegistry()
|
app.state.skill_registry = skill_registry or SkillRegistry()
|
||||||
app.state.tool_registry = tool_registry or ToolRegistry()
|
app.state.tool_registry = tool_registry or ToolRegistry()
|
||||||
|
# Initialize MCPManager if MCP servers are configured
|
||||||
|
if server_config and server_config.mcp_servers:
|
||||||
|
mcp_manager = MCPManager(
|
||||||
|
configs=server_config.mcp_servers,
|
||||||
|
tool_registry=app.state.tool_registry,
|
||||||
|
)
|
||||||
|
app.state.mcp_manager = mcp_manager
|
||||||
|
else:
|
||||||
|
app.state.mcp_manager = None
|
||||||
app.state.agent_pool = AgentPool(
|
app.state.agent_pool = AgentPool(
|
||||||
llm_gateway=app.state.llm_gateway,
|
llm_gateway=app.state.llm_gateway,
|
||||||
skill_registry=app.state.skill_registry,
|
skill_registry=app.state.skill_registry,
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
|
@ -18,6 +19,44 @@ logger = logging.getLogger(__name__)
|
||||||
DEFAULT_CONFIG_FILE = "agentkit.yaml"
|
DEFAULT_CONFIG_FILE = "agentkit.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MCPServerConfig:
|
||||||
|
"""Configuration for a single MCP Server connection"""
|
||||||
|
|
||||||
|
transport: str # "stdio" | "streamable_http" | "sse"
|
||||||
|
# stdio-specific
|
||||||
|
command: str | None = None
|
||||||
|
args: list[str] | None = None
|
||||||
|
env: dict[str, str] | None = None
|
||||||
|
# http/sse-specific
|
||||||
|
url: str | None = None
|
||||||
|
headers: dict[str, str] | None = None
|
||||||
|
# common
|
||||||
|
timeout: float = 30.0
|
||||||
|
|
||||||
|
def validate(self) -> None:
|
||||||
|
"""Validate configuration, raise ValueError if invalid"""
|
||||||
|
if self.transport not in ("stdio", "streamable_http", "sse"):
|
||||||
|
raise ValueError(f"Invalid transport: {self.transport}")
|
||||||
|
if self.transport == "stdio" and not self.command:
|
||||||
|
raise ValueError("stdio transport requires 'command'")
|
||||||
|
if self.transport in ("streamable_http", "sse") and not self.url:
|
||||||
|
raise ValueError(f"{self.transport} transport requires 'url'")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict) -> "MCPServerConfig":
|
||||||
|
"""Create from dict (parsed from YAML)"""
|
||||||
|
return cls(
|
||||||
|
transport=data.get("transport", "stdio"),
|
||||||
|
command=data.get("command"),
|
||||||
|
args=data.get("args"),
|
||||||
|
env=data.get("env"),
|
||||||
|
url=data.get("url"),
|
||||||
|
headers=data.get("headers"),
|
||||||
|
timeout=data.get("timeout", 30.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _resolve_env_vars(value: Any) -> Any:
|
def _resolve_env_vars(value: Any) -> Any:
|
||||||
"""Resolve ${VAR:-default} patterns in string values from environment variables."""
|
"""Resolve ${VAR:-default} patterns in string values from environment variables."""
|
||||||
if not isinstance(value, str):
|
if not isinstance(value, str):
|
||||||
|
|
@ -64,6 +103,8 @@ class ServerConfig:
|
||||||
task_store: dict[str, Any] | None = None,
|
task_store: dict[str, Any] | None = None,
|
||||||
cors_origins: list[str] | None = None,
|
cors_origins: list[str] | None = None,
|
||||||
memory: dict[str, Any] | None = None,
|
memory: dict[str, Any] | None = None,
|
||||||
|
mcp_servers: dict[str, MCPServerConfig] | None = None,
|
||||||
|
telemetry: dict[str, Any] | None = None,
|
||||||
on_change: Callable[["ServerConfig"], None] | None = None,
|
on_change: Callable[["ServerConfig"], None] | None = None,
|
||||||
):
|
):
|
||||||
self.host = host
|
self.host = host
|
||||||
|
|
@ -79,6 +120,8 @@ class ServerConfig:
|
||||||
self.task_store = task_store or {}
|
self.task_store = task_store or {}
|
||||||
self.cors_origins = cors_origins or ["*"]
|
self.cors_origins = cors_origins or ["*"]
|
||||||
self.memory = memory or {}
|
self.memory = memory or {}
|
||||||
|
self.mcp_servers = mcp_servers or {}
|
||||||
|
self.telemetry = telemetry or {}
|
||||||
self.on_change = on_change
|
self.on_change = on_change
|
||||||
|
|
||||||
# Config watching state
|
# Config watching state
|
||||||
|
|
@ -109,6 +152,7 @@ class ServerConfig:
|
||||||
logging_data = data.get("logging", {})
|
logging_data = data.get("logging", {})
|
||||||
task_store_data = data.get("task_store", {})
|
task_store_data = data.get("task_store", {})
|
||||||
memory_data = data.get("memory", {})
|
memory_data = data.get("memory", {})
|
||||||
|
mcp_data = data.get("mcp", {})
|
||||||
|
|
||||||
# Build LLMConfig
|
# Build LLMConfig
|
||||||
llm_config = cls._build_llm_config(llm_data)
|
llm_config = cls._build_llm_config(llm_data)
|
||||||
|
|
@ -117,6 +161,12 @@ class ServerConfig:
|
||||||
skill_paths = skills_data.get("paths", [])
|
skill_paths = skills_data.get("paths", [])
|
||||||
auto_discover = skills_data.get("auto_discover", True)
|
auto_discover = skills_data.get("auto_discover", True)
|
||||||
|
|
||||||
|
# Build MCP server configs
|
||||||
|
mcp_servers = cls._build_mcp_configs(mcp_data)
|
||||||
|
|
||||||
|
# Telemetry config
|
||||||
|
telemetry_data = data.get("telemetry", {})
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
host=server.get("host", "0.0.0.0"),
|
host=server.get("host", "0.0.0.0"),
|
||||||
port=server.get("port", 8001),
|
port=server.get("port", 8001),
|
||||||
|
|
@ -131,6 +181,8 @@ class ServerConfig:
|
||||||
task_store=task_store_data,
|
task_store=task_store_data,
|
||||||
cors_origins=server.get("cors_origins"),
|
cors_origins=server.get("cors_origins"),
|
||||||
memory=memory_data,
|
memory=memory_data,
|
||||||
|
mcp_servers=mcp_servers,
|
||||||
|
telemetry=telemetry_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -165,6 +217,18 @@ class ServerConfig:
|
||||||
fallbacks=data.get("fallbacks", {}),
|
fallbacks=data.get("fallbacks", {}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_mcp_configs(data: dict) -> dict[str, MCPServerConfig]:
|
||||||
|
"""Build MCP server configs from the mcp section of agentkit.yaml."""
|
||||||
|
servers = data.get("servers", {})
|
||||||
|
if not servers:
|
||||||
|
return {}
|
||||||
|
result = {}
|
||||||
|
for name, server_conf in servers.items():
|
||||||
|
if isinstance(server_conf, dict):
|
||||||
|
result[name] = MCPServerConfig.from_dict(server_conf)
|
||||||
|
return result
|
||||||
|
|
||||||
def load_skill_configs(self) -> list[SkillConfig]:
|
def load_skill_configs(self) -> list[SkillConfig]:
|
||||||
"""Load all SkillConfig from configured skill paths."""
|
"""Load all SkillConfig from configured skill paths."""
|
||||||
configs = []
|
configs = []
|
||||||
|
|
@ -307,6 +371,8 @@ class ServerConfig:
|
||||||
self.task_store = new_config.task_store
|
self.task_store = new_config.task_store
|
||||||
self.cors_origins = new_config.cors_origins
|
self.cors_origins = new_config.cors_origins
|
||||||
self.memory = new_config.memory
|
self.memory = new_config.memory
|
||||||
|
self.mcp_servers = new_config.mcp_servers
|
||||||
|
self.telemetry = new_config.telemetry
|
||||||
self._last_mtime = new_config._last_mtime
|
self._last_mtime = new_config._last_mtime
|
||||||
|
|
||||||
logger.info(f"Config reloaded from {path}")
|
logger.info(f"Config reloaded from {path}")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,171 @@
|
||||||
|
"""Tests for MCPServerConfig and ServerConfig MCP section parsing"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.server.config import MCPServerConfig, ServerConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPServerConfig:
|
||||||
|
"""Tests for MCPServerConfig dataclass"""
|
||||||
|
|
||||||
|
def test_from_dict_stdio(self):
|
||||||
|
data = {
|
||||||
|
"transport": "stdio",
|
||||||
|
"command": "npx",
|
||||||
|
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
|
||||||
|
"env": {"NODE_ENV": "production"},
|
||||||
|
"timeout": 60.0,
|
||||||
|
}
|
||||||
|
config = MCPServerConfig.from_dict(data)
|
||||||
|
assert config.transport == "stdio"
|
||||||
|
assert config.command == "npx"
|
||||||
|
assert config.args == ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"]
|
||||||
|
assert config.env == {"NODE_ENV": "production"}
|
||||||
|
assert config.timeout == 60.0
|
||||||
|
assert config.url is None
|
||||||
|
assert config.headers is None
|
||||||
|
|
||||||
|
def test_from_dict_streamable_http(self):
|
||||||
|
data = {
|
||||||
|
"transport": "streamable_http",
|
||||||
|
"url": "http://localhost:3001/mcp",
|
||||||
|
"headers": {"Authorization": "Bearer test-token"},
|
||||||
|
"timeout": 45.0,
|
||||||
|
}
|
||||||
|
config = MCPServerConfig.from_dict(data)
|
||||||
|
assert config.transport == "streamable_http"
|
||||||
|
assert config.url == "http://localhost:3001/mcp"
|
||||||
|
assert config.headers == {"Authorization": "Bearer test-token"}
|
||||||
|
assert config.timeout == 45.0
|
||||||
|
assert config.command is None
|
||||||
|
assert config.args is None
|
||||||
|
|
||||||
|
def test_from_dict_sse(self):
|
||||||
|
data = {
|
||||||
|
"transport": "sse",
|
||||||
|
"url": "http://localhost:3002/sse",
|
||||||
|
}
|
||||||
|
config = MCPServerConfig.from_dict(data)
|
||||||
|
assert config.transport == "sse"
|
||||||
|
assert config.url == "http://localhost:3002/sse"
|
||||||
|
assert config.command is None
|
||||||
|
|
||||||
|
def test_from_dict_defaults(self):
|
||||||
|
data = {}
|
||||||
|
config = MCPServerConfig.from_dict(data)
|
||||||
|
assert config.transport == "stdio"
|
||||||
|
assert config.command is None
|
||||||
|
assert config.args is None
|
||||||
|
assert config.env is None
|
||||||
|
assert config.url is None
|
||||||
|
assert config.headers is None
|
||||||
|
assert config.timeout == 30.0
|
||||||
|
|
||||||
|
def test_validate_stdio_valid(self):
|
||||||
|
config = MCPServerConfig(transport="stdio", command="python")
|
||||||
|
config.validate() # Should not raise
|
||||||
|
|
||||||
|
def test_validate_stdio_missing_command(self):
|
||||||
|
config = MCPServerConfig(transport="stdio", command=None)
|
||||||
|
with pytest.raises(ValueError, match="stdio transport requires 'command'"):
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
def test_validate_streamable_http_missing_url(self):
|
||||||
|
config = MCPServerConfig(transport="streamable_http", url=None)
|
||||||
|
with pytest.raises(ValueError, match="streamable_http transport requires 'url'"):
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
def test_validate_sse_missing_url(self):
|
||||||
|
config = MCPServerConfig(transport="sse", url=None)
|
||||||
|
with pytest.raises(ValueError, match="sse transport requires 'url'"):
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
def test_validate_invalid_transport(self):
|
||||||
|
config = MCPServerConfig(transport="websocket")
|
||||||
|
with pytest.raises(ValueError, match="Invalid transport: websocket"):
|
||||||
|
config.validate()
|
||||||
|
|
||||||
|
def test_validate_http_with_url(self):
|
||||||
|
config = MCPServerConfig(transport="streamable_http", url="http://localhost:3001")
|
||||||
|
config.validate() # Should not raise
|
||||||
|
|
||||||
|
def test_validate_sse_with_url(self):
|
||||||
|
config = MCPServerConfig(transport="sse", url="http://localhost:3002")
|
||||||
|
config.validate() # Should not raise
|
||||||
|
|
||||||
|
|
||||||
|
class TestServerConfigMCPSection:
|
||||||
|
"""Tests for ServerConfig parsing with mcp section"""
|
||||||
|
|
||||||
|
def test_from_dict_with_mcp_servers(self):
|
||||||
|
data = {
|
||||||
|
"mcp": {
|
||||||
|
"servers": {
|
||||||
|
"filesystem": {
|
||||||
|
"transport": "stdio",
|
||||||
|
"command": "npx",
|
||||||
|
"args": ["-y", "@modelcontextprotocol/server-filesystem", "/tmp"],
|
||||||
|
},
|
||||||
|
"remote": {
|
||||||
|
"transport": "streamable_http",
|
||||||
|
"url": "http://localhost:3001/mcp",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config = ServerConfig.from_dict(data)
|
||||||
|
assert len(config.mcp_servers) == 2
|
||||||
|
assert "filesystem" in config.mcp_servers
|
||||||
|
assert "remote" in config.mcp_servers
|
||||||
|
assert config.mcp_servers["filesystem"].transport == "stdio"
|
||||||
|
assert config.mcp_servers["filesystem"].command == "npx"
|
||||||
|
assert config.mcp_servers["remote"].transport == "streamable_http"
|
||||||
|
assert config.mcp_servers["remote"].url == "http://localhost:3001/mcp"
|
||||||
|
|
||||||
|
def test_from_dict_without_mcp_section(self):
|
||||||
|
data = {}
|
||||||
|
config = ServerConfig.from_dict(data)
|
||||||
|
assert config.mcp_servers == {}
|
||||||
|
|
||||||
|
def test_from_dict_with_empty_mcp_servers(self):
|
||||||
|
data = {"mcp": {"servers": {}}}
|
||||||
|
config = ServerConfig.from_dict(data)
|
||||||
|
assert config.mcp_servers == {}
|
||||||
|
|
||||||
|
def test_from_dict_mcp_servers_with_sse(self):
|
||||||
|
data = {
|
||||||
|
"mcp": {
|
||||||
|
"servers": {
|
||||||
|
"sse-server": {
|
||||||
|
"transport": "sse",
|
||||||
|
"url": "http://localhost:3002/sse",
|
||||||
|
"headers": {"X-API-Key": "secret"},
|
||||||
|
"timeout": 60.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config = ServerConfig.from_dict(data)
|
||||||
|
assert len(config.mcp_servers) == 1
|
||||||
|
sse_conf = config.mcp_servers["sse-server"]
|
||||||
|
assert sse_conf.transport == "sse"
|
||||||
|
assert sse_conf.url == "http://localhost:3002/sse"
|
||||||
|
assert sse_conf.headers == {"X-API-Key": "secret"}
|
||||||
|
assert sse_conf.timeout == 60.0
|
||||||
|
|
||||||
|
def test_from_dict_mcp_ignores_non_dict_entries(self):
|
||||||
|
data = {
|
||||||
|
"mcp": {
|
||||||
|
"servers": {
|
||||||
|
"valid": {
|
||||||
|
"transport": "stdio",
|
||||||
|
"command": "python",
|
||||||
|
},
|
||||||
|
"invalid": "not-a-dict",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config = ServerConfig.from_dict(data)
|
||||||
|
assert len(config.mcp_servers) == 1
|
||||||
|
assert "valid" in config.mcp_servers
|
||||||
|
assert "invalid" not in config.mcp_servers
|
||||||
|
|
@ -0,0 +1,354 @@
|
||||||
|
"""Tests for MCPManager lifecycle and tool discovery"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.mcp.manager import MCPManager
|
||||||
|
from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport
|
||||||
|
from agentkit.server.config import MCPServerConfig
|
||||||
|
from agentkit.tools.registry import ToolRegistry
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_transport(transport_type: str = "stdio") -> MagicMock:
|
||||||
|
"""Create a mock Transport that behaves like a connected transport."""
|
||||||
|
mock = MagicMock(spec=Transport)
|
||||||
|
mock.is_connected = True
|
||||||
|
mock.connect = AsyncMock()
|
||||||
|
mock.disconnect = AsyncMock()
|
||||||
|
mock.send_request = AsyncMock()
|
||||||
|
return mock
|
||||||
|
|
||||||
|
|
||||||
|
def _make_stdio_config() -> MCPServerConfig:
|
||||||
|
return MCPServerConfig(
|
||||||
|
transport="stdio",
|
||||||
|
command="python",
|
||||||
|
args=["-m", "mcp_server"],
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_http_config() -> MCPServerConfig:
|
||||||
|
return MCPServerConfig(
|
||||||
|
transport="streamable_http",
|
||||||
|
url="http://localhost:3001/mcp",
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sse_config() -> MCPServerConfig:
|
||||||
|
return MCPServerConfig(
|
||||||
|
transport="sse",
|
||||||
|
url="http://localhost:3002/sse",
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPManagerConstruction:
|
||||||
|
"""Tests for MCPManager initialization"""
|
||||||
|
|
||||||
|
def test_construction_with_configs(self):
|
||||||
|
configs = {
|
||||||
|
"server1": _make_stdio_config(),
|
||||||
|
"server2": _make_http_config(),
|
||||||
|
}
|
||||||
|
manager = MCPManager(configs=configs)
|
||||||
|
assert len(manager._configs) == 2
|
||||||
|
assert manager._tool_registry is not None
|
||||||
|
assert len(manager._clients) == 0
|
||||||
|
assert len(manager._transports) == 0
|
||||||
|
assert len(manager._available) == 0
|
||||||
|
assert len(manager._server_tools) == 0
|
||||||
|
|
||||||
|
def test_construction_with_custom_tool_registry(self):
|
||||||
|
registry = ToolRegistry()
|
||||||
|
configs = {"server1": _make_stdio_config()}
|
||||||
|
manager = MCPManager(configs=configs, tool_registry=registry)
|
||||||
|
assert manager._tool_registry is registry
|
||||||
|
|
||||||
|
def test_construction_with_empty_configs(self):
|
||||||
|
manager = MCPManager(configs={})
|
||||||
|
assert len(manager._configs) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPManagerStartAll:
|
||||||
|
"""Tests for MCPManager.start_all()"""
|
||||||
|
|
||||||
|
@patch("agentkit.mcp.manager.StdioTransport")
|
||||||
|
async def test_start_all_stdio_server(self, MockStdioTransport):
|
||||||
|
mock_transport = _make_mock_transport()
|
||||||
|
MockStdioTransport.return_value = mock_transport
|
||||||
|
|
||||||
|
# Mock list_tools response via MCPClient
|
||||||
|
with patch("agentkit.mcp.manager.MCPClient") as MockClient:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.list_tools = AsyncMock(return_value=[
|
||||||
|
{"name": "read_file", "description": "Read a file"},
|
||||||
|
{"name": "write_file", "description": "Write a file"},
|
||||||
|
])
|
||||||
|
mock_tool = MagicMock()
|
||||||
|
mock_tool.name = "read_file"
|
||||||
|
mock_client.as_tool = MagicMock(return_value=mock_tool)
|
||||||
|
MockClient.from_transport.return_value = mock_client
|
||||||
|
|
||||||
|
configs = {"fs": _make_stdio_config()}
|
||||||
|
registry = ToolRegistry()
|
||||||
|
manager = MCPManager(configs=configs, tool_registry=registry)
|
||||||
|
|
||||||
|
await manager.start_all()
|
||||||
|
|
||||||
|
MockStdioTransport.assert_called_once()
|
||||||
|
mock_transport.connect.assert_called_once()
|
||||||
|
mock_client.list_tools.assert_called_once()
|
||||||
|
assert manager.is_available("fs") is True
|
||||||
|
assert manager.get_server_tools("fs") == ["read_file", "write_file"]
|
||||||
|
|
||||||
|
@patch("agentkit.mcp.manager.HTTPTransport")
|
||||||
|
async def test_start_all_http_server(self, MockHTTPTransport):
|
||||||
|
mock_transport = _make_mock_transport()
|
||||||
|
MockHTTPTransport.return_value = mock_transport
|
||||||
|
|
||||||
|
with patch("agentkit.mcp.manager.MCPClient") as MockClient:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.list_tools = AsyncMock(return_value=[
|
||||||
|
{"name": "search", "description": "Search the web"},
|
||||||
|
])
|
||||||
|
mock_tool = MagicMock()
|
||||||
|
mock_tool.name = "search"
|
||||||
|
mock_client.as_tool = MagicMock(return_value=mock_tool)
|
||||||
|
MockClient.from_transport.return_value = mock_client
|
||||||
|
|
||||||
|
configs = {"web": _make_http_config()}
|
||||||
|
manager = MCPManager(configs=configs)
|
||||||
|
|
||||||
|
await manager.start_all()
|
||||||
|
|
||||||
|
MockHTTPTransport.assert_called_once()
|
||||||
|
mock_transport.connect.assert_called_once()
|
||||||
|
assert manager.is_available("web") is True
|
||||||
|
assert manager.get_server_tools("web") == ["search"]
|
||||||
|
|
||||||
|
@patch("agentkit.mcp.manager.SSETransport")
|
||||||
|
async def test_start_all_sse_server(self, MockSSETransport):
|
||||||
|
mock_transport = _make_mock_transport()
|
||||||
|
MockSSETransport.return_value = mock_transport
|
||||||
|
|
||||||
|
with patch("agentkit.mcp.manager.MCPClient") as MockClient:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.list_tools = AsyncMock(return_value=[
|
||||||
|
{"name": "query", "description": "Query data"},
|
||||||
|
])
|
||||||
|
mock_tool = MagicMock()
|
||||||
|
mock_tool.name = "query"
|
||||||
|
mock_client.as_tool = MagicMock(return_value=mock_tool)
|
||||||
|
MockClient.from_transport.return_value = mock_client
|
||||||
|
|
||||||
|
configs = {"sse-srv": _make_sse_config()}
|
||||||
|
manager = MCPManager(configs=configs)
|
||||||
|
|
||||||
|
await manager.start_all()
|
||||||
|
|
||||||
|
MockSSETransport.assert_called_once()
|
||||||
|
assert manager.is_available("sse-srv") is True
|
||||||
|
|
||||||
|
async def test_start_all_server_failure_doesnt_affect_others(self):
|
||||||
|
"""One server failing should not prevent other servers from starting"""
|
||||||
|
with patch("agentkit.mcp.manager.StdioTransport") as MockStdio, \
|
||||||
|
patch("agentkit.mcp.manager.HTTPTransport") as MockHTTP:
|
||||||
|
# First server fails
|
||||||
|
fail_transport = _make_mock_transport()
|
||||||
|
fail_transport.connect = AsyncMock(side_effect=Exception("Connection refused"))
|
||||||
|
MockStdio.return_value = fail_transport
|
||||||
|
|
||||||
|
# Second server succeeds
|
||||||
|
ok_transport = _make_mock_transport()
|
||||||
|
MockHTTP.return_value = ok_transport
|
||||||
|
|
||||||
|
with patch("agentkit.mcp.manager.MCPClient") as MockClient:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.list_tools = AsyncMock(return_value=[
|
||||||
|
{"name": "search", "description": "Search"},
|
||||||
|
])
|
||||||
|
mock_tool = MagicMock()
|
||||||
|
mock_tool.name = "search"
|
||||||
|
mock_client.as_tool = MagicMock(return_value=mock_tool)
|
||||||
|
MockClient.from_transport.return_value = mock_client
|
||||||
|
|
||||||
|
configs = {
|
||||||
|
"failing": _make_stdio_config(),
|
||||||
|
"working": _make_http_config(),
|
||||||
|
}
|
||||||
|
manager = MCPManager(configs=configs)
|
||||||
|
|
||||||
|
await manager.start_all()
|
||||||
|
|
||||||
|
assert manager.is_available("failing") is False
|
||||||
|
assert manager.is_available("working") is True
|
||||||
|
assert manager.get_server_tools("working") == ["search"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPManagerStopAll:
|
||||||
|
"""Tests for MCPManager.stop_all()"""
|
||||||
|
|
||||||
|
@patch("agentkit.mcp.manager.StdioTransport")
|
||||||
|
async def test_stop_all(self, MockStdioTransport):
|
||||||
|
mock_transport = _make_mock_transport()
|
||||||
|
MockStdioTransport.return_value = mock_transport
|
||||||
|
|
||||||
|
with patch("agentkit.mcp.manager.MCPClient") as MockClient:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.list_tools = AsyncMock(return_value=[])
|
||||||
|
MockClient.from_transport.return_value = mock_client
|
||||||
|
|
||||||
|
configs = {"srv": _make_stdio_config()}
|
||||||
|
manager = MCPManager(configs=configs)
|
||||||
|
await manager.start_all()
|
||||||
|
assert manager.is_available("srv") is True
|
||||||
|
|
||||||
|
await manager.stop_all()
|
||||||
|
|
||||||
|
mock_transport.disconnect.assert_called_once()
|
||||||
|
assert manager.is_available("srv") is False
|
||||||
|
assert len(manager._transports) == 0
|
||||||
|
assert len(manager._clients) == 0
|
||||||
|
|
||||||
|
async def test_stop_all_handles_disconnect_error(self):
|
||||||
|
"""stop_all should not raise even if disconnect fails"""
|
||||||
|
manager = MCPManager(configs={})
|
||||||
|
|
||||||
|
# Manually set up internal state to simulate a connected server
|
||||||
|
mock_transport = _make_mock_transport()
|
||||||
|
mock_transport.disconnect = AsyncMock(side_effect=Exception("Disconnect error"))
|
||||||
|
manager._transports = {"srv": mock_transport}
|
||||||
|
manager._available = {"srv": True}
|
||||||
|
|
||||||
|
# Should not raise
|
||||||
|
await manager.stop_all()
|
||||||
|
assert manager.is_available("srv") is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPManagerQueryMethods:
|
||||||
|
"""Tests for MCPManager query methods"""
|
||||||
|
|
||||||
|
def test_is_available_unknown_server(self):
|
||||||
|
manager = MCPManager(configs={})
|
||||||
|
assert manager.is_available("nonexistent") is False
|
||||||
|
|
||||||
|
def test_get_server_tools_unknown_server(self):
|
||||||
|
manager = MCPManager(configs={})
|
||||||
|
assert manager.get_server_tools("nonexistent") == []
|
||||||
|
|
||||||
|
def test_list_all_tools_empty(self):
|
||||||
|
manager = MCPManager(configs={})
|
||||||
|
assert manager.list_all_tools() == []
|
||||||
|
|
||||||
|
def test_list_all_tools_with_servers(self):
|
||||||
|
manager = MCPManager(configs={})
|
||||||
|
manager._server_tools = {
|
||||||
|
"srv1": ["tool_a", "tool_b"],
|
||||||
|
"srv2": ["tool_c"],
|
||||||
|
}
|
||||||
|
result = manager.list_all_tools()
|
||||||
|
assert sorted(result) == ["tool_a", "tool_b", "tool_c"]
|
||||||
|
|
||||||
|
def test_get_tool_registry(self):
|
||||||
|
registry = ToolRegistry()
|
||||||
|
manager = MCPManager(configs={}, tool_registry=registry)
|
||||||
|
assert manager.get_tool_registry() is registry
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPManagerToolDiscovery:
|
||||||
|
"""Tests for tool discovery and registration"""
|
||||||
|
|
||||||
|
@patch("agentkit.mcp.manager.StdioTransport")
|
||||||
|
async def test_tools_registered_in_registry(self, MockStdioTransport):
|
||||||
|
mock_transport = _make_mock_transport()
|
||||||
|
MockStdioTransport.return_value = mock_transport
|
||||||
|
|
||||||
|
with patch("agentkit.mcp.manager.MCPClient") as MockClient:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.list_tools = AsyncMock(return_value=[
|
||||||
|
{"name": "read_file", "description": "Read a file"},
|
||||||
|
{"name": "write_file", "description": "Write a file"},
|
||||||
|
])
|
||||||
|
|
||||||
|
# Create mock tools that the as_tool method returns
|
||||||
|
mock_tool_1 = MagicMock()
|
||||||
|
mock_tool_1.name = "read_file"
|
||||||
|
mock_tool_2 = MagicMock()
|
||||||
|
mock_tool_2.name = "write_file"
|
||||||
|
mock_client.as_tool = MagicMock(side_effect=[mock_tool_1, mock_tool_2])
|
||||||
|
MockClient.from_transport.return_value = mock_client
|
||||||
|
|
||||||
|
registry = ToolRegistry()
|
||||||
|
configs = {"fs": _make_stdio_config()}
|
||||||
|
manager = MCPManager(configs=configs, tool_registry=registry)
|
||||||
|
|
||||||
|
await manager.start_all()
|
||||||
|
|
||||||
|
# Verify tools were registered
|
||||||
|
assert registry.has_tool("read_file")
|
||||||
|
assert registry.has_tool("write_file")
|
||||||
|
|
||||||
|
@patch("agentkit.mcp.manager.StdioTransport")
|
||||||
|
async def test_empty_tools_list(self, MockStdioTransport):
|
||||||
|
mock_transport = _make_mock_transport()
|
||||||
|
MockStdioTransport.return_value = mock_transport
|
||||||
|
|
||||||
|
with patch("agentkit.mcp.manager.MCPClient") as MockClient:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.list_tools = AsyncMock(return_value=[])
|
||||||
|
MockClient.from_transport.return_value = mock_client
|
||||||
|
|
||||||
|
configs = {"empty": _make_stdio_config()}
|
||||||
|
manager = MCPManager(configs=configs)
|
||||||
|
|
||||||
|
await manager.start_all()
|
||||||
|
|
||||||
|
assert manager.is_available("empty") is True
|
||||||
|
assert manager.get_server_tools("empty") == []
|
||||||
|
assert manager.list_all_tools() == []
|
||||||
|
|
||||||
|
@patch("agentkit.mcp.manager.StdioTransport")
|
||||||
|
async def test_multiple_servers_tools_combined(self, MockStdioTransport):
|
||||||
|
mock_transport = _make_mock_transport()
|
||||||
|
MockStdioTransport.return_value = mock_transport
|
||||||
|
|
||||||
|
with patch("agentkit.mcp.manager.MCPClient") as MockClient:
|
||||||
|
# First call for srv1
|
||||||
|
mock_client_1 = MagicMock()
|
||||||
|
mock_client_1.list_tools = AsyncMock(return_value=[
|
||||||
|
{"name": "tool_a", "description": "Tool A"},
|
||||||
|
])
|
||||||
|
mock_tool_a = MagicMock()
|
||||||
|
mock_tool_a.name = "tool_a"
|
||||||
|
mock_client_1.as_tool = MagicMock(return_value=mock_tool_a)
|
||||||
|
|
||||||
|
# Second call for srv2
|
||||||
|
mock_client_2 = MagicMock()
|
||||||
|
mock_client_2.list_tools = AsyncMock(return_value=[
|
||||||
|
{"name": "tool_b", "description": "Tool B"},
|
||||||
|
])
|
||||||
|
mock_tool_b = MagicMock()
|
||||||
|
mock_tool_b.name = "tool_b"
|
||||||
|
mock_client_2.as_tool = MagicMock(return_value=mock_tool_b)
|
||||||
|
|
||||||
|
MockClient.from_transport.side_effect = [mock_client_1, mock_client_2]
|
||||||
|
|
||||||
|
configs = {
|
||||||
|
"srv1": _make_stdio_config(),
|
||||||
|
"srv2": _make_stdio_config(),
|
||||||
|
}
|
||||||
|
manager = MCPManager(configs=configs)
|
||||||
|
|
||||||
|
await manager.start_all()
|
||||||
|
|
||||||
|
assert manager.get_server_tools("srv1") == ["tool_a"]
|
||||||
|
assert manager.get_server_tools("srv2") == ["tool_b"]
|
||||||
|
assert sorted(manager.list_all_tools()) == ["tool_a", "tool_b"]
|
||||||
|
|
||||||
|
|
||||||
|
# Run async tests with pytest-asyncio
|
||||||
|
pytest_plugins = ["pytest_asyncio"]
|
||||||
Loading…
Reference in New Issue