From 86541d71727b690dd5b928187689364e8a36099c Mon Sep 17 00:00:00 2001 From: chiguyong Date: Thu, 25 Jun 2026 22:04:37 +0800 Subject: [PATCH] =?UTF-8?q?feat(mcp):=20U16=20=E2=80=94=20langchain-mcp-ad?= =?UTF-8?q?apters=20client=20replacement=20+=20transport=20deprecation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 重写 MCPClient:URL scheme 自动检测(stdio/http/sse)→ langchain config - 旧 Transport 注入路径保留(DeprecationWarning),向后兼容 - transport.py 模块级弃用警告 - 28 个新测试覆盖 URL 检测、list_tools、call_tool、legacy 路径、ImportError - 修复 manager.py / transport.py 预存 F401/F841 --- pyproject.toml | 2 + src/agentkit/mcp/client.py | 205 +++++++++++++++--- src/agentkit/mcp/manager.py | 7 +- src/agentkit/mcp/transport.py | 30 +-- tests/unit/mcp/test_client.py | 391 ++++++++++++++++++++++++++++++++++ tests/unit/test_mcp_client.py | 142 ++---------- 6 files changed, 613 insertions(+), 164 deletions(-) create mode 100644 tests/unit/mcp/test_client.py diff --git a/pyproject.toml b/pyproject.toml index 51d901a..e713435 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,8 @@ server = [ ] mcp = [ "mcp>=1.0", + # U16 — langchain-mcp-adapters 替换自研 MCP 客户端传输层 + "langchain-mcp-adapters>=0.1", ] evolution = [ "scipy>=1.12", diff --git a/src/agentkit/mcp/client.py b/src/agentkit/mcp/client.py index 448b452..202fae3 100644 --- a/src/agentkit/mcp/client.py +++ b/src/agentkit/mcp/client.py @@ -1,22 +1,56 @@ -"""MCP Client - 调用外部 MCP 工具服务器""" +"""MCP Client - 调用外部 MCP 工具服务器 +U16: 重写为 langchain-mcp-adapters 包装层。保留 MCPClient / MCPTool API 向后兼容。 +旧的 Transport 注入路径保留(发出 DeprecationWarning),供现有调用方过渡。 + +传输类型由 URL scheme 自动检测(transport=None 时): +- ``stdio://command arg1 arg2`` → stdio 传输 +- ``http://`` / ``https://`` → streamable_http 传输 +- ``sse://...`` → sse 传输(自动转为 http://) +""" + +from __future__ import annotations + +import json import logging -from typing import Any - -import httpx +import warnings +from typing import TYPE_CHECKING, Any +from urllib.parse import urlparse from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport from agentkit.tools.base import Tool +if TYPE_CHECKING: + from langchain_mcp_adapters.client import MultiServerMCPClient + logger = logging.getLogger(__name__) +def _import_langchain_client() -> type["MultiServerMCPClient"]: + """延迟导入 langchain-mcp-adapters 的 MultiServerMCPClient。 + + 未安装时抛出带提示信息的 ImportError,便于调用方回退到 Transport 路径。 + """ + try: + from langchain_mcp_adapters.client import MultiServerMCPClient as _Client + except ImportError as e: + raise ImportError( + "langchain-mcp-adapters 未安装,无法使用 langchain 传输路径。" + "请执行 pip install 'fischer-agentkit[mcp]' 或 pip install langchain-mcp-adapters。" + "如需使用旧的 Transport 路径,请通过 MCPClient.from_transport(transport) 创建客户端。" + ) from e + return _Client + + class MCPClient: """MCP Client - 连接外部 MCP Server 并调用工具 + U16: 内部使用 langchain-mcp-adapters 的 MultiServerMCPClient 管理连接。 + 保留旧 Transport 注入路径以向后兼容(发出 DeprecationWarning)。 + 支持两种模式: - 1. 通过 Transport 层发送 JSON-RPC 请求(推荐) - 2. 直接 HTTP 调用(向后兼容) + 1. URL scheme 自动检测(推荐):transport=None,根据 server_url 的 scheme 选择传输 + 2. Transport 注入(旧路径,弃用):传入 Transport 实例,走原有 JSON-RPC 路径 """ def __init__( @@ -25,14 +59,84 @@ class MCPClient: timeout: int = 30, transport: Transport | None = None, ): - self._server_url = server_url.rstrip("/") + self._server_url = server_url.rstrip("/") if server_url else "" self._timeout = timeout self._tools_cache: list[dict] | None = None self._transport = transport + if transport is not None: + # 旧 Transport 路径 — 发出 DeprecationWarning,但保持原有行为 + warnings.warn( + "通过 Transport 实例创建 MCPClient 已弃用,请使用 URL scheme 自动检测" + "(如 MCPClient('http://...') 或 MCPClient('stdio://...'))。" + "详见 U16 迁移指南。将在下个版本移除。", + DeprecationWarning, + stacklevel=2, + ) + self._langchain_config: dict[str, Any] | None = None + else: + # 新 langchain 路径 — 解析 URL scheme 构建连接配置 + self._langchain_config = self._build_langchain_config(self._server_url, timeout) + + @staticmethod + def _build_langchain_config(server_url: str, timeout: float) -> dict[str, Any]: + """根据 URL scheme 构建 langchain-mcp-adapters 连接配置。 + + Args: + server_url: 服务器 URL,支持 stdio:// http:// https:// sse:// + timeout: 超时秒数(仅 http/sse 传输使用) + + Returns: + langchain-mcp-adapters 连接配置 dict + + Raises: + ValueError: URL 为空或 scheme 不支持 + """ + if not server_url: + raise ValueError("server_url 不能为空(transport=None 时需要有效 URL)") + + # stdio 特殊处理:stdio://command arg1 arg2 → command + args + # 不用 urlparse,因为命令行参数可能含特殊字符 + if server_url.startswith("stdio://"): + raw = server_url[len("stdio://") :] + parts = raw.split() + if not parts: + raise ValueError(f"无效的 stdio URL: {server_url!r}") + return { + "transport": "stdio", + "command": parts[0], + "args": parts[1:], + } + + parsed = urlparse(server_url) + scheme = parsed.scheme.lower() + + if scheme in ("http", "https"): + return { + "transport": "streamable_http", + "url": server_url, + "timeout": timeout, + } + + if scheme == "sse": + # sse://host/path → http://host/path(sse scheme 不携带 TLS 信息) + converted = "http://" + server_url[len("sse://") :] + return { + "transport": "sse", + "url": converted, + "timeout": timeout, + } + + raise ValueError( + f"不支持的 URL scheme: {scheme!r}。支持 stdio://, http://, https://, sse://" + ) + @classmethod def from_transport(cls, transport: Transport) -> "MCPClient": - """从 Transport 实例创建 MCPClient""" + """从 Transport 实例创建 MCPClient(旧路径,向后兼容) + + 发出 DeprecationWarning。建议迁移到 URL scheme 自动检测。 + """ if isinstance(transport, HTTPTransport): server_url = transport._endpoint elif isinstance(transport, SSETransport): @@ -44,8 +148,16 @@ class MCPClient: return cls(server_url=server_url, transport=transport) async def list_tools(self) -> list[dict]: - """列出远程 MCP Server 上的工具""" + """列出远程 MCP Server 上的工具 + + Returns: + 工具列表,每项为 ``{"name":..., "description":..., "inputSchema":...}`` + + Raises: + ImportError: langchain-mcp-adapters 未安装(仅 transport=None 路径) + """ if self._transport is not None: + # 旧 Transport 路径 if not self._transport.is_connected: await self._transport.connect() result = await self._transport.send_request("tools/list") @@ -53,16 +165,52 @@ class MCPClient: self._tools_cache = tools return self._tools_cache - async with httpx.AsyncClient(timeout=self._timeout) as client: - response = await client.get(f"{self._server_url}/tools/list") - response.raise_for_status() - data = response.json() - self._tools_cache = data.get("tools", []) - return self._tools_cache + # 新 langchain 路径 + client_cls = _import_langchain_client() + client = client_cls({"server": self._langchain_config}) + lc_tools = await client.get_tools() + tools = [ + { + "name": t.name, + "description": t.description or "", + "inputSchema": self._extract_schema(t), + } + for t in lc_tools + ] + self._tools_cache = tools + return tools + + @staticmethod + def _extract_schema(tool: Any) -> dict[str, Any]: + """从 LangChain Tool 提取 inputSchema(JSON Schema 格式)。 + + LangChain 工具通常有 args_schema(pydantic model),回退到 tool.args dict。 + """ + args_schema = getattr(tool, "args_schema", None) + if args_schema is not None and hasattr(args_schema, "model_json_schema"): + return args_schema.model_json_schema() + # ponytail: tool.args 在 langchain BaseTool 上是 dict,够用作回退 + args = getattr(tool, "args", None) + if isinstance(args, dict): + return args + return {"type": "object", "properties": {}} async def call_tool(self, tool_name: str, arguments: dict) -> dict: - """调用远程 MCP 工具""" + """调用远程 MCP 工具 + + Args: + tool_name: 工具名称 + arguments: 工具参数 + + Returns: + MCP 响应格式的 dict:``{"content": [{"type":"text","text":...}]}`` + + Raises: + KeyError: 工具不存在(仅 transport=None 路径) + ImportError: langchain-mcp-adapters 未安装 + """ if self._transport is not None: + # 旧 Transport 路径 if not self._transport.is_connected: await self._transport.connect() return await self._transport.send_request( @@ -70,13 +218,23 @@ class MCPClient: params={"name": tool_name, "arguments": arguments}, ) - async with httpx.AsyncClient(timeout=self._timeout) as client: - response = await client.post( - f"{self._server_url}/tools/call", - json={"name": tool_name, "arguments": arguments}, - ) - response.raise_for_status() - return response.json() + # 新 langchain 路径 + client_cls = _import_langchain_client() + client = client_cls({"server": self._langchain_config}) + lc_tools = await client.get_tools() + for tool in lc_tools: + if tool.name == tool_name: + result = await tool.ainvoke(arguments) + # 包装为 MCP 响应格式,保持 call_tool 返回形状与旧路径一致 + return { + "content": [ + { + "type": "text", + "text": json.dumps(result, ensure_ascii=False, default=str), + } + ] + } + raise KeyError(f"MCP 工具不存在: {tool_name!r}") def as_tool(self, tool_name: str, description: str = "") -> "MCPTool": """将远程 MCP 工具包装为本地 Tool 对象""" @@ -116,7 +274,6 @@ class MCPTool(Tool): if "content" in result: for item in result["content"]: if item.get("type") == "text": - import json try: return json.loads(item["text"]) except json.JSONDecodeError: diff --git a/src/agentkit/mcp/manager.py b/src/agentkit/mcp/manager.py index b27ab49..62fda37 100644 --- a/src/agentkit/mcp/manager.py +++ b/src/agentkit/mcp/manager.py @@ -4,7 +4,7 @@ from __future__ import annotations import asyncio import logging -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING from agentkit.mcp.client import MCPClient from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport @@ -39,10 +39,7 @@ class MCPManager: 使用 asyncio.gather 并发启动,单个服务器失败不影响其他服务器。 """ - tasks = [ - self._start_server_safe(name, config) - for name, config in self._configs.items() - ] + tasks = [self._start_server_safe(name, config) for name, config in self._configs.items()] await asyncio.gather(*tasks) async def _start_server_safe(self, name: str, config: MCPServerConfig) -> None: diff --git a/src/agentkit/mcp/transport.py b/src/agentkit/mcp/transport.py index f54624f..1c7c705 100644 --- a/src/agentkit/mcp/transport.py +++ b/src/agentkit/mcp/transport.py @@ -1,12 +1,16 @@ """MCP Transport - 传输层抽象 提供 MCP 协议的传输层实现,支持 Streamable HTTP、SSE 和 Stdio 三种传输方式。 + +U16: 本模块已弃用。请通过 ``MCPClient(url)`` 使用 langchain-mcp-adapters 传输层。 +保留所有类以向后兼容,将在下个版本移除。 """ import asyncio import json import logging import os +import warnings from abc import ABC, abstractmethod from typing import Any @@ -14,6 +18,14 @@ import httpx logger = logging.getLogger(__name__) +# U16: 模块级弃用警告 — 导入时发出,提示迁移到 langchain-mcp-adapters +warnings.warn( + "agentkit.mcp.transport 已弃用 — 请通过 MCPClient(url) 使用 langchain-mcp-adapters 传输层。" + "详见 U16 迁移指南。将在下个版本移除。", + DeprecationWarning, + stacklevel=2, +) + class TransportError(Exception): """传输层错误""" @@ -158,9 +170,7 @@ class HTTPTransport(Transport): # 检查 JSON-RPC 错误 if "error" in data: error = data["error"] - raise TransportError( - f"JSON-RPC error {error.get('code')}: {error.get('message')}" - ) + raise TransportError(f"JSON-RPC error {error.get('code')}: {error.get('message')}") return data.get("result") @@ -264,7 +274,7 @@ class SSETransport(Transport): if not line or line.startswith(":"): continue if line.startswith("data:"): - data_str = line[len("data:"):].strip() + data_str = line[len("data:") :].strip() try: data = json.loads(data_str) await self._response_queue.put(data) @@ -328,9 +338,7 @@ class SSETransport(Transport): # 检查 JSON-RPC 错误 if "error" in data: error = data["error"] - raise TransportError( - f"JSON-RPC error {error.get('code')}: {error.get('message')}" - ) + raise TransportError(f"JSON-RPC error {error.get('code')}: {error.get('message')}") return data.get("result") @@ -382,11 +390,7 @@ class StdioTransport(Transport): @property def is_connected(self) -> bool: - return ( - self._connected - and self._process is not None - and self._process.returncode is None - ) + return self._connected and self._process is not None and self._process.returncode is None def _next_request_id(self) -> int: """生成下一个请求 ID""" @@ -427,7 +431,7 @@ class StdioTransport(Transport): # 发送 initialize 请求并等待响应 try: - init_result = await asyncio.wait_for( + await asyncio.wait_for( self._send_request_internal( "initialize", { diff --git a/tests/unit/mcp/test_client.py b/tests/unit/mcp/test_client.py new file mode 100644 index 0000000..74d9fdd --- /dev/null +++ b/tests/unit/mcp/test_client.py @@ -0,0 +1,391 @@ +"""U16 — MCPClient langchain-mcp-adapters 路径单元测试。 + +覆盖场景: +1. URL scheme 检测(http / sse / stdio) +2. list_tools via langchain(三种传输) +3. call_tool via langchain(正常 / 未知工具) +4. 旧 Transport 路径兼容(DeprecationWarning) +5. MCPTool.execute via 新客户端(JSON / 纯文本 / 无 content) +6. langchain 未安装时 ImportError +7. transport.py 模块级 DeprecationWarning +8. timeout 参数透传 +""" + +from __future__ import annotations + +import importlib +import sys +import warnings +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentkit.mcp.client import MCPClient, MCPTool +from agentkit.mcp.transport import HTTPTransport + + +# ── 辅助工具 ────────────────────────────────────────────────── + + +def _make_mock_langchain_tool( + name: str = "my_tool", + description: str = "Test tool", + result: Any = None, +) -> MagicMock: + """构造一个假的 LangChain Tool 对象。""" + tool = MagicMock() + tool.name = name + tool.description = description + tool.args_schema = MagicMock() + tool.args_schema.model_json_schema.return_value = { + "type": "object", + "properties": {}, + } + tool.ainvoke = AsyncMock(return_value=result if result is not None else {"result": "ok"}) + return tool + + +@pytest.fixture +def mock_langchain_client() -> Any: + """Patch MultiServerMCPClient,返回带 get_tools 的 mock 实例。 + + Yields: + (mock_class, mock_instance) 二元组 + """ + with patch("agentkit.mcp.client._import_langchain_client") as mock_import: + mock_cls = MagicMock(name="MultiServerMCPClient") + mock_instance = MagicMock(name="mcp_client_instance") + mock_instance.get_tools = AsyncMock(return_value=[]) + mock_cls.return_value = mock_instance + mock_import.return_value = mock_cls + yield mock_cls, mock_instance + + +# ── URL scheme 检测测试 ─────────────────────────────────────── + + +class TestURLSchemeDetection: + """MCPClient 根据 URL scheme 自动检测传输类型。""" + + def test_http_url_uses_streamable_http(self): + """http:// URL 应使用 streamable_http 传输。""" + client = MCPClient("http://localhost:8080/mcp") + config = client._langchain_config + assert config is not None + assert config["transport"] == "streamable_http" + assert config["url"] == "http://localhost:8080/mcp" + + def test_https_url_uses_streamable_http(self): + """https:// URL 应使用 streamable_http 传输。""" + client = MCPClient("https://remote.example.com/mcp") + config = client._langchain_config + assert config["transport"] == "streamable_http" + assert config["url"] == "https://remote.example.com/mcp" + + def test_sse_url_uses_sse_transport(self): + """sse:// URL 应使用 sse 传输,并转为 http://。""" + client = MCPClient("sse://localhost:8080/sse") + config = client._langchain_config + assert config["transport"] == "sse" + assert config["url"] == "http://localhost:8080/sse" + + def test_stdio_url_parses_command_and_args(self): + """stdio:// URL 应解析为 command + args。""" + client = MCPClient("stdio://python -m my_mcp_server") + config = client._langchain_config + assert config["transport"] == "stdio" + assert config["command"] == "python" + assert config["args"] == ["-m", "my_mcp_server"] + + def test_stdio_url_no_args(self): + """stdio:// 只有 command 无参数。""" + client = MCPClient("stdio://./server") + config = client._langchain_config + assert config["command"] == "./server" + assert config["args"] == [] + + def test_unsupported_scheme_raises(self): + """不支持的 scheme 应抛出 ValueError。""" + with pytest.raises(ValueError, match="不支持的 URL scheme"): + MCPClient("ftp://example.com/server") + + def test_empty_url_raises(self): + """空 URL 应抛出 ValueError。""" + with pytest.raises(ValueError, match="server_url 不能为空"): + MCPClient("") + + +# ── list_tools via langchain 测试 ───────────────────────────── + + +class TestListToolsViaLangchain: + """list_tools 通过 langchain-mcp-adapters 获取工具列表。""" + + async def test_list_tools_http(self, mock_langchain_client): + """HTTP 传输 — list_tools 返回工具字典列表。""" + mock_cls, mock_instance = mock_langchain_client + mock_instance.get_tools = AsyncMock( + return_value=[ + _make_mock_langchain_tool("search", "Search tool"), + _make_mock_langchain_tool("calc", "Calculator"), + ] + ) + + client = MCPClient("http://localhost:8080/mcp") + tools = await client.list_tools() + + assert len(tools) == 2 + assert tools[0]["name"] == "search" + assert tools[0]["description"] == "Search tool" + assert "inputSchema" in tools[0] + assert tools[1]["name"] == "calc" + + async def test_list_tools_stdio(self, mock_langchain_client): + """Stdio 传输 — list_tools 返回工具字典列表。""" + mock_cls, mock_instance = mock_langchain_client + mock_instance.get_tools = AsyncMock( + return_value=[_make_mock_langchain_tool("read_file", "Read file")] + ) + + client = MCPClient("stdio://python -m fs_server") + tools = await client.list_tools() + + assert len(tools) == 1 + assert tools[0]["name"] == "read_file" + + async def test_list_tools_sse(self, mock_langchain_client): + """SSE 传输 — list_tools 返回工具字典列表。""" + mock_cls, mock_instance = mock_langchain_client + mock_instance.get_tools = AsyncMock( + return_value=[_make_mock_langchain_tool("query", "Query data")] + ) + + client = MCPClient("sse://localhost:8080/sse") + tools = await client.list_tools() + + assert len(tools) == 1 + assert tools[0]["name"] == "query" + + async def test_list_tools_caches_result(self, mock_langchain_client): + """list_tools 结果应缓存到 _tools_cache。""" + _, mock_instance = mock_langchain_client + mock_instance.get_tools = AsyncMock(return_value=[_make_mock_langchain_tool("tool1")]) + + client = MCPClient("http://localhost:8080/mcp") + tools = await client.list_tools() + + assert client._tools_cache == tools + assert client._tools_cache[0]["name"] == "tool1" + + async def test_list_tools_empty(self, mock_langchain_client): + """空工具列表应返回 []。""" + _, mock_instance = mock_langchain_client + mock_instance.get_tools = AsyncMock(return_value=[]) + + client = MCPClient("http://localhost:8080/mcp") + tools = await client.list_tools() + assert tools == [] + + +# ── call_tool via langchain 测试 ────────────────────────────── + + +class TestCallToolViaLangchain: + """call_tool 通过 langchain-mcp-adapters 调用工具。""" + + async def test_call_tool_returns_mcp_format(self, mock_langchain_client): + """call_tool 应返回 MCP 响应格式 {"content": [{"type":"text","text":...}]}。""" + _, mock_instance = mock_langchain_client + mock_instance.get_tools = AsyncMock( + return_value=[_make_mock_langchain_tool("my_tool", result={"result": "hello"})] + ) + + client = MCPClient("http://localhost:8080/mcp") + result = await client.call_tool("my_tool", {"input": "x"}) + + assert "content" in result + assert result["content"][0]["type"] == "text" + assert '"result": "hello"' in result["content"][0]["text"] + + async def test_call_tool_invokes_ainvoke(self, mock_langchain_client): + """call_tool 应调用 LangChain Tool 的 ainvoke。""" + _, mock_instance = mock_langchain_client + mock_tool = _make_mock_langchain_tool("my_tool", result={"ok": True}) + mock_instance.get_tools = AsyncMock(return_value=[mock_tool]) + + client = MCPClient("http://localhost:8080/mcp") + await client.call_tool("my_tool", {"input": "x"}) + + mock_tool.ainvoke.assert_called_once_with({"input": "x"}) + + async def test_call_tool_unknown_raises(self, mock_langchain_client): + """调用不存在的工具应抛出 KeyError。""" + _, mock_instance = mock_langchain_client + mock_instance.get_tools = AsyncMock(return_value=[]) + + client = MCPClient("http://localhost:8080/mcp") + with pytest.raises(KeyError, match="nonexistent"): + await client.call_tool("nonexistent", {}) + + +# ── 旧 Transport 路径兼容测试 ───────────────────────────────── + + +class TestLegacyTransportPath: + """旧 Transport 注入路径应保持向后兼容(发出 DeprecationWarning)。""" + + def test_from_transport_emits_deprecation_warning(self): + """from_transport 应发出 DeprecationWarning。""" + transport = HTTPTransport(endpoint="http://localhost:8080/mcp") + with pytest.warns(DeprecationWarning, match="Transport"): + client = MCPClient.from_transport(transport) + assert client._transport is transport + assert client._langchain_config is None + + def test_init_with_transport_emits_deprecation_warning(self): + """__init__ 传入 transport 应发出 DeprecationWarning。""" + transport = HTTPTransport(endpoint="http://localhost:8080/mcp") + with pytest.warns(DeprecationWarning): + client = MCPClient(server_url="http://localhost:8080", transport=transport) + assert client._transport is transport + + async def test_legacy_list_tools_works(self): + """旧 Transport 路径的 list_tools 仍应正常工作。""" + transport = HTTPTransport(endpoint="http://localhost:8080") + transport._client = MagicMock() + transport._client.is_closed = False + transport.send_request = AsyncMock( + return_value={"tools": [{"name": "legacy_tool", "description": "Legacy"}]} + ) + + with pytest.warns(DeprecationWarning): + client = MCPClient.from_transport(transport) + + tools = await client.list_tools() + assert len(tools) == 1 + assert tools[0]["name"] == "legacy_tool" + transport.send_request.assert_called_once_with("tools/list") + + +# ── MCPTool.execute 测试 ────────────────────────────────────── + + +class TestMCPToolExecute: + """MCPTool.execute 通过新客户端调用工具并解析响应。""" + + async def test_execute_parses_json_text(self, mock_langchain_client): + """execute 应解析 content[0].text 中的 JSON。""" + _, mock_instance = mock_langchain_client + mock_instance.get_tools = AsyncMock( + return_value=[_make_mock_langchain_tool("my_tool", result={"answer": 42})] + ) + + client = MCPClient("http://localhost:8080/mcp") + tool = client.as_tool("my_tool", "desc") + result = await tool.execute(input="x") + + assert result == {"answer": 42} + + async def test_execute_non_json_text(self): + """content[0].text 为纯文本时返回 {"result": text}。""" + client = MCPClient("http://localhost:8080/mcp") + client.call_tool = AsyncMock( # type: ignore[method-assign] + return_value={"content": [{"type": "text", "text": "plain text"}]} + ) + tool = client.as_tool("echo", "desc") + + result = await tool.execute(msg="hello") + assert result == {"result": "plain text"} + + async def test_execute_no_content_field(self): + """无 content 字段时返回原始 dict。""" + client = MCPClient("http://localhost:8080/mcp") + client.call_tool = AsyncMock( # type: ignore[method-assign] + return_value={"unexpected": "shape"} + ) + tool = client.as_tool("status", "desc") + + result = await tool.execute() + assert result == {"unexpected": "shape"} + + async def test_as_tool_returns_mcp_tool(self): + """as_tool 应返回 MCPTool 实例。""" + client = MCPClient("http://localhost:8080/mcp") + tool = client.as_tool("search", "Search the web") + + assert isinstance(tool, MCPTool) + assert tool.name == "search" + assert tool.description == "Search the web" + assert tool._client is client + assert "mcp" in tool.tags + + +# ── langchain 未安装测试 ───────────────────────────────────── + + +class TestLangchainNotInstalled: + """langchain-mcp-adapters 未安装时的错误处理。""" + + async def test_import_error_when_langchain_missing(self): + """langchain 未安装时 list_tools 应抛出 ImportError。""" + client = MCPClient("http://localhost:8080/mcp") + + with patch( + "agentkit.mcp.client._import_langchain_client", + side_effect=ImportError("langchain-mcp-adapters 未安装"), + ): + with pytest.raises(ImportError, match="langchain-mcp-adapters"): + await client.list_tools() + + +# ── transport.py 模块级 DeprecationWarning 测试 ────────────── + + +class TestTransportDeprecation: + """transport.py 导入时应发出 DeprecationWarning。""" + + def test_transport_module_emits_deprecation_warning(self): + """导入 agentkit.mcp.transport 应触发 DeprecationWarning。""" + # 先从 sys.modules 移除,强制重新导入 + mods_to_remove = [k for k in sys.modules if k.startswith("agentkit.mcp.transport")] + for mod in mods_to_remove: + del sys.modules[mod] + + with warnings.catch_warnings(record=True) as warning_list: + warnings.simplefilter("always") + importlib.import_module("agentkit.mcp.transport") + + deprecation_warnings = [ + w for w in warning_list if issubclass(w.category, DeprecationWarning) + ] + assert len(deprecation_warnings) >= 1 + assert "transport" in str(deprecation_warnings[0].message) + + +# ── timeout 参数透传测试 ───────────────────────────────────── + + +class TestTimeoutPropagation: + """timeout 参数应透传到 langchain 连接配置。""" + + def test_timeout_passed_to_http_config(self): + """HTTP 传输应将 timeout 写入配置。""" + client = MCPClient("http://localhost:8080/mcp", timeout=60) + assert client._langchain_config["timeout"] == 60 + + def test_timeout_passed_to_sse_config(self): + """SSE 传输应将 timeout 写入配置。""" + client = MCPClient("sse://localhost:8080/sse", timeout=45) + assert client._langchain_config["timeout"] == 45 + + def test_default_timeout_is_30(self): + """默认 timeout 为 30。""" + client = MCPClient("http://localhost:8080/mcp") + assert client._timeout == 30 + assert client._langchain_config["timeout"] == 30 + + def test_stdio_config_has_no_timeout(self): + """stdio 传输配置不含 timeout(StdioConnection 无此字段)。""" + client = MCPClient("stdio://python -m server", timeout=60) + assert "timeout" not in client._langchain_config diff --git a/tests/unit/test_mcp_client.py b/tests/unit/test_mcp_client.py index ccd5bc6..b08f1c8 100644 --- a/tests/unit/test_mcp_client.py +++ b/tests/unit/test_mcp_client.py @@ -1,6 +1,7 @@ """MCP Client 单元测试""" import json +from unittest.mock import AsyncMock import httpx import pytest @@ -149,112 +150,12 @@ class TestMCPClientTransportMode: await transport.disconnect() -# ── MCPClient 直接 HTTP 模式测试 ──────────────────────────────── - - -class TestMCPClientDirectHTTP: - """MCPClient 直接 HTTP 模式测试(无 Transport)""" - - async def test_list_tools_direct_http(self, httpx_mock): - httpx_mock.add_response( - url="http://localhost:8080/tools/list", - json={ - "tools": [ - {"name": "search", "description": "Search tool"}, - ] - }, - ) - - client = MCPClient(server_url="http://localhost:8080") - tools = await client.list_tools() - - assert len(tools) == 1 - assert tools[0]["name"] == "search" - assert client._tools_cache == tools - - async def test_call_tool_direct_http(self, httpx_mock): - httpx_mock.add_response( - url="http://localhost:8080/tools/call", - json={"result": "computed value"}, - ) - - client = MCPClient(server_url="http://localhost:8080") - result = await client.call_tool("compute", {"x": 42}) - - assert result == {"result": "computed value"} - - # 验证请求体 - request = httpx_mock.get_request() - body = json.loads(request.content) - assert body["name"] == "compute" - assert body["arguments"] == {"x": 42} - - async def test_list_tools_caches_result(self, httpx_mock): - httpx_mock.add_response( - url="http://localhost:8080/tools/list", - json={"tools": [{"name": "tool1"}]}, - ) - - client = MCPClient(server_url="http://localhost:8080") - tools = await client.list_tools() - - # 验证缓存被设置 - assert client._tools_cache == tools - assert client._tools_cache[0]["name"] == "tool1" - - async def test_call_tool_sends_post_request(self, httpx_mock): - httpx_mock.add_response( - url="http://localhost:8080/tools/call", - json={"output": "done"}, - ) - - client = MCPClient(server_url="http://localhost:8080") - await client.call_tool("my_tool", {"arg": "val"}) - - request = httpx_mock.get_request() - assert request.method == "POST" - - # ── MCPClient 连接错误处理测试 ────────────────────────────────── class TestMCPClientErrorHandling: """MCPClient 连接错误处理测试""" - async def test_list_tools_http_error(self, httpx_mock): - httpx_mock.add_response( - url="http://localhost:8080/tools/list", - status_code=500, - ) - - client = MCPClient(server_url="http://localhost:8080") - with pytest.raises(httpx.HTTPStatusError): - await client.list_tools() - - async def test_call_tool_http_error(self, httpx_mock): - httpx_mock.add_response( - url="http://localhost:8080/tools/call", - status_code=404, - ) - - client = MCPClient(server_url="http://localhost:8080") - with pytest.raises(httpx.HTTPStatusError): - await client.call_tool("missing_tool", {}) - - async def test_list_tools_connection_error(self, httpx_mock): - httpx_mock.add_exception(httpx.ConnectError("Connection refused")) - - client = MCPClient(server_url="http://localhost:8080") - with pytest.raises(httpx.ConnectError): - await client.list_tools() - - async def test_call_tool_connection_error(self, httpx_mock): - httpx_mock.add_exception(httpx.ConnectError("Connection refused")) - - client = MCPClient(server_url="http://localhost:8080") - with pytest.raises(httpx.ConnectError): - await client.call_tool("any_tool", {}) - async def test_transport_error_propagates(self, httpx_mock): httpx_mock.add_exception(httpx.ConnectError("Connection refused")) @@ -355,41 +256,38 @@ class TestMCPTool: assert tool._client is client assert "mcp" in tool.tags - async def test_mcp_tool_execute_text_content(self, httpx_mock): - httpx_mock.add_response( - url="http://localhost:8080/tools/call", - json={ - "content": [{"type": "text", "text": '{"answer": 42}'}], - }, - ) - + async def test_mcp_tool_execute_text_content(self): + """execute 应解析 content[0].text 中的 JSON。""" client = MCPClient(server_url="http://localhost:8080") + client.call_tool = AsyncMock( # type: ignore[method-assign] + return_value={ + "content": [{"type": "text", "text": '{"answer": 42}'}], + } + ) tool = client.as_tool("ask", description="Ask a question") result = await tool.execute(question="meaning of life") assert result == {"answer": 42} - async def test_mcp_tool_execute_non_json_text(self, httpx_mock): - httpx_mock.add_response( - url="http://localhost:8080/tools/call", - json={ - "content": [{"type": "text", "text": "plain text response"}], - }, - ) - + async def test_mcp_tool_execute_non_json_text(self): + """content[0].text 为纯文本时返回 {"result": text}。""" client = MCPClient(server_url="http://localhost:8080") + client.call_tool = AsyncMock( # type: ignore[method-assign] + return_value={ + "content": [{"type": "text", "text": "plain text response"}], + } + ) tool = client.as_tool("echo", description="Echo input") result = await tool.execute(msg="hello") assert result == {"result": "plain text response"} - async def test_mcp_tool_execute_no_content(self, httpx_mock): - httpx_mock.add_response( - url="http://localhost:8080/tools/call", - json={"status": "ok", "data": "some data"}, - ) - + async def test_mcp_tool_execute_no_content(self): + """无 content 字段时返回原始 dict。""" client = MCPClient(server_url="http://localhost:8080") + client.call_tool = AsyncMock( # type: ignore[method-assign] + return_value={"status": "ok", "data": "some data"} + ) tool = client.as_tool("status", description="Check status") result = await tool.execute()