feat(mcp): U16 — langchain-mcp-adapters client replacement + transport deprecation
- 重写 MCPClient:URL scheme 自动检测(stdio/http/sse)→ langchain config - 旧 Transport 注入路径保留(DeprecationWarning),向后兼容 - transport.py 模块级弃用警告 - 28 个新测试覆盖 URL 检测、list_tools、call_tool、legacy 路径、ImportError - 修复 manager.py / transport.py 预存 F401/F841
This commit is contained in:
parent
069dbc22b1
commit
86541d7172
|
|
@ -63,6 +63,8 @@ server = [
|
|||
]
|
||||
mcp = [
|
||||
"mcp>=1.0",
|
||||
# U16 — langchain-mcp-adapters 替换自研 MCP 客户端传输层
|
||||
"langchain-mcp-adapters>=0.1",
|
||||
]
|
||||
evolution = [
|
||||
"scipy>=1.12",
|
||||
|
|
|
|||
|
|
@ -1,22 +1,56 @@
|
|||
"""MCP Client - 调用外部 MCP 工具服务器"""
|
||||
"""MCP Client - 调用外部 MCP 工具服务器
|
||||
|
||||
U16: 重写为 langchain-mcp-adapters 包装层。保留 MCPClient / MCPTool API 向后兼容。
|
||||
旧的 Transport 注入路径保留(发出 DeprecationWarning),供现有调用方过渡。
|
||||
|
||||
传输类型由 URL scheme 自动检测(transport=None 时):
|
||||
- ``stdio://command arg1 arg2`` → stdio 传输
|
||||
- ``http://`` / ``https://`` → streamable_http 传输
|
||||
- ``sse://...`` → sse 传输(自动转为 http://)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _import_langchain_client() -> type["MultiServerMCPClient"]:
|
||||
"""延迟导入 langchain-mcp-adapters 的 MultiServerMCPClient。
|
||||
|
||||
未安装时抛出带提示信息的 ImportError,便于调用方回退到 Transport 路径。
|
||||
"""
|
||||
try:
|
||||
from langchain_mcp_adapters.client import MultiServerMCPClient as _Client
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"langchain-mcp-adapters 未安装,无法使用 langchain 传输路径。"
|
||||
"请执行 pip install 'fischer-agentkit[mcp]' 或 pip install langchain-mcp-adapters。"
|
||||
"如需使用旧的 Transport 路径,请通过 MCPClient.from_transport(transport) 创建客户端。"
|
||||
) from e
|
||||
return _Client
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""MCP Client - 连接外部 MCP Server 并调用工具
|
||||
|
||||
U16: 内部使用 langchain-mcp-adapters 的 MultiServerMCPClient 管理连接。
|
||||
保留旧 Transport 注入路径以向后兼容(发出 DeprecationWarning)。
|
||||
|
||||
支持两种模式:
|
||||
1. 通过 Transport 层发送 JSON-RPC 请求(推荐)
|
||||
2. 直接 HTTP 调用(向后兼容)
|
||||
1. URL scheme 自动检测(推荐):transport=None,根据 server_url 的 scheme 选择传输
|
||||
2. Transport 注入(旧路径,弃用):传入 Transport 实例,走原有 JSON-RPC 路径
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -25,14 +59,84 @@ class MCPClient:
|
|||
timeout: int = 30,
|
||||
transport: Transport | None = None,
|
||||
):
|
||||
self._server_url = server_url.rstrip("/")
|
||||
self._server_url = server_url.rstrip("/") if server_url else ""
|
||||
self._timeout = timeout
|
||||
self._tools_cache: list[dict] | None = None
|
||||
self._transport = transport
|
||||
|
||||
if transport is not None:
|
||||
# 旧 Transport 路径 — 发出 DeprecationWarning,但保持原有行为
|
||||
warnings.warn(
|
||||
"通过 Transport 实例创建 MCPClient 已弃用,请使用 URL scheme 自动检测"
|
||||
"(如 MCPClient('http://...') 或 MCPClient('stdio://...'))。"
|
||||
"详见 U16 迁移指南。将在下个版本移除。",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
self._langchain_config: dict[str, Any] | None = None
|
||||
else:
|
||||
# 新 langchain 路径 — 解析 URL scheme 构建连接配置
|
||||
self._langchain_config = self._build_langchain_config(self._server_url, timeout)
|
||||
|
||||
@staticmethod
|
||||
def _build_langchain_config(server_url: str, timeout: float) -> dict[str, Any]:
|
||||
"""根据 URL scheme 构建 langchain-mcp-adapters 连接配置。
|
||||
|
||||
Args:
|
||||
server_url: 服务器 URL,支持 stdio:// http:// https:// sse://
|
||||
timeout: 超时秒数(仅 http/sse 传输使用)
|
||||
|
||||
Returns:
|
||||
langchain-mcp-adapters 连接配置 dict
|
||||
|
||||
Raises:
|
||||
ValueError: URL 为空或 scheme 不支持
|
||||
"""
|
||||
if not server_url:
|
||||
raise ValueError("server_url 不能为空(transport=None 时需要有效 URL)")
|
||||
|
||||
# stdio 特殊处理:stdio://command arg1 arg2 → command + args
|
||||
# 不用 urlparse,因为命令行参数可能含特殊字符
|
||||
if server_url.startswith("stdio://"):
|
||||
raw = server_url[len("stdio://") :]
|
||||
parts = raw.split()
|
||||
if not parts:
|
||||
raise ValueError(f"无效的 stdio URL: {server_url!r}")
|
||||
return {
|
||||
"transport": "stdio",
|
||||
"command": parts[0],
|
||||
"args": parts[1:],
|
||||
}
|
||||
|
||||
parsed = urlparse(server_url)
|
||||
scheme = parsed.scheme.lower()
|
||||
|
||||
if scheme in ("http", "https"):
|
||||
return {
|
||||
"transport": "streamable_http",
|
||||
"url": server_url,
|
||||
"timeout": timeout,
|
||||
}
|
||||
|
||||
if scheme == "sse":
|
||||
# sse://host/path → http://host/path(sse scheme 不携带 TLS 信息)
|
||||
converted = "http://" + server_url[len("sse://") :]
|
||||
return {
|
||||
"transport": "sse",
|
||||
"url": converted,
|
||||
"timeout": timeout,
|
||||
}
|
||||
|
||||
raise ValueError(
|
||||
f"不支持的 URL scheme: {scheme!r}。支持 stdio://, http://, https://, sse://"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_transport(cls, transport: Transport) -> "MCPClient":
|
||||
"""从 Transport 实例创建 MCPClient"""
|
||||
"""从 Transport 实例创建 MCPClient(旧路径,向后兼容)
|
||||
|
||||
发出 DeprecationWarning。建议迁移到 URL scheme 自动检测。
|
||||
"""
|
||||
if isinstance(transport, HTTPTransport):
|
||||
server_url = transport._endpoint
|
||||
elif isinstance(transport, SSETransport):
|
||||
|
|
@ -44,8 +148,16 @@ class MCPClient:
|
|||
return cls(server_url=server_url, transport=transport)
|
||||
|
||||
async def list_tools(self) -> list[dict]:
|
||||
"""列出远程 MCP Server 上的工具"""
|
||||
"""列出远程 MCP Server 上的工具
|
||||
|
||||
Returns:
|
||||
工具列表,每项为 ``{"name":..., "description":..., "inputSchema":...}``
|
||||
|
||||
Raises:
|
||||
ImportError: langchain-mcp-adapters 未安装(仅 transport=None 路径)
|
||||
"""
|
||||
if self._transport is not None:
|
||||
# 旧 Transport 路径
|
||||
if not self._transport.is_connected:
|
||||
await self._transport.connect()
|
||||
result = await self._transport.send_request("tools/list")
|
||||
|
|
@ -53,16 +165,52 @@ class MCPClient:
|
|||
self._tools_cache = tools
|
||||
return self._tools_cache
|
||||
|
||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||
response = await client.get(f"{self._server_url}/tools/list")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
self._tools_cache = data.get("tools", [])
|
||||
return self._tools_cache
|
||||
# 新 langchain 路径
|
||||
client_cls = _import_langchain_client()
|
||||
client = client_cls({"server": self._langchain_config})
|
||||
lc_tools = await client.get_tools()
|
||||
tools = [
|
||||
{
|
||||
"name": t.name,
|
||||
"description": t.description or "",
|
||||
"inputSchema": self._extract_schema(t),
|
||||
}
|
||||
for t in lc_tools
|
||||
]
|
||||
self._tools_cache = tools
|
||||
return tools
|
||||
|
||||
@staticmethod
|
||||
def _extract_schema(tool: Any) -> dict[str, Any]:
|
||||
"""从 LangChain Tool 提取 inputSchema(JSON Schema 格式)。
|
||||
|
||||
LangChain 工具通常有 args_schema(pydantic model),回退到 tool.args dict。
|
||||
"""
|
||||
args_schema = getattr(tool, "args_schema", None)
|
||||
if args_schema is not None and hasattr(args_schema, "model_json_schema"):
|
||||
return args_schema.model_json_schema()
|
||||
# ponytail: tool.args 在 langchain BaseTool 上是 dict,够用作回退
|
||||
args = getattr(tool, "args", None)
|
||||
if isinstance(args, dict):
|
||||
return args
|
||||
return {"type": "object", "properties": {}}
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
|
||||
"""调用远程 MCP 工具"""
|
||||
"""调用远程 MCP 工具
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
|
||||
Returns:
|
||||
MCP 响应格式的 dict:``{"content": [{"type":"text","text":...}]}``
|
||||
|
||||
Raises:
|
||||
KeyError: 工具不存在(仅 transport=None 路径)
|
||||
ImportError: langchain-mcp-adapters 未安装
|
||||
"""
|
||||
if self._transport is not None:
|
||||
# 旧 Transport 路径
|
||||
if not self._transport.is_connected:
|
||||
await self._transport.connect()
|
||||
return await self._transport.send_request(
|
||||
|
|
@ -70,13 +218,23 @@ class MCPClient:
|
|||
params={"name": tool_name, "arguments": arguments},
|
||||
)
|
||||
|
||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self._server_url}/tools/call",
|
||||
json={"name": tool_name, "arguments": arguments},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
# 新 langchain 路径
|
||||
client_cls = _import_langchain_client()
|
||||
client = client_cls({"server": self._langchain_config})
|
||||
lc_tools = await client.get_tools()
|
||||
for tool in lc_tools:
|
||||
if tool.name == tool_name:
|
||||
result = await tool.ainvoke(arguments)
|
||||
# 包装为 MCP 响应格式,保持 call_tool 返回形状与旧路径一致
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(result, ensure_ascii=False, default=str),
|
||||
}
|
||||
]
|
||||
}
|
||||
raise KeyError(f"MCP 工具不存在: {tool_name!r}")
|
||||
|
||||
def as_tool(self, tool_name: str, description: str = "") -> "MCPTool":
|
||||
"""将远程 MCP 工具包装为本地 Tool 对象"""
|
||||
|
|
@ -116,7 +274,6 @@ class MCPTool(Tool):
|
|||
if "content" in result:
|
||||
for item in result["content"]:
|
||||
if item.get("type") == "text":
|
||||
import json
|
||||
try:
|
||||
return json.loads(item["text"])
|
||||
except json.JSONDecodeError:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from agentkit.mcp.client import MCPClient
|
||||
from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport
|
||||
|
|
@ -39,10 +39,7 @@ class MCPManager:
|
|||
|
||||
使用 asyncio.gather 并发启动,单个服务器失败不影响其他服务器。
|
||||
"""
|
||||
tasks = [
|
||||
self._start_server_safe(name, config)
|
||||
for name, config in self._configs.items()
|
||||
]
|
||||
tasks = [self._start_server_safe(name, config) for name, config in self._configs.items()]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
async def _start_server_safe(self, name: str, config: MCPServerConfig) -> None:
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
"""MCP Transport - 传输层抽象
|
||||
|
||||
提供 MCP 协议的传输层实现,支持 Streamable HTTP、SSE 和 Stdio 三种传输方式。
|
||||
|
||||
U16: 本模块已弃用。请通过 ``MCPClient(url)`` 使用 langchain-mcp-adapters 传输层。
|
||||
保留所有类以向后兼容,将在下个版本移除。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -14,6 +18,14 @@ import httpx
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# U16: 模块级弃用警告 — 导入时发出,提示迁移到 langchain-mcp-adapters
|
||||
warnings.warn(
|
||||
"agentkit.mcp.transport 已弃用 — 请通过 MCPClient(url) 使用 langchain-mcp-adapters 传输层。"
|
||||
"详见 U16 迁移指南。将在下个版本移除。",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
class TransportError(Exception):
|
||||
"""传输层错误"""
|
||||
|
|
@ -158,9 +170,7 @@ class HTTPTransport(Transport):
|
|||
# 检查 JSON-RPC 错误
|
||||
if "error" in data:
|
||||
error = data["error"]
|
||||
raise TransportError(
|
||||
f"JSON-RPC error {error.get('code')}: {error.get('message')}"
|
||||
)
|
||||
raise TransportError(f"JSON-RPC error {error.get('code')}: {error.get('message')}")
|
||||
|
||||
return data.get("result")
|
||||
|
||||
|
|
@ -264,7 +274,7 @@ class SSETransport(Transport):
|
|||
if not line or line.startswith(":"):
|
||||
continue
|
||||
if line.startswith("data:"):
|
||||
data_str = line[len("data:"):].strip()
|
||||
data_str = line[len("data:") :].strip()
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
await self._response_queue.put(data)
|
||||
|
|
@ -328,9 +338,7 @@ class SSETransport(Transport):
|
|||
# 检查 JSON-RPC 错误
|
||||
if "error" in data:
|
||||
error = data["error"]
|
||||
raise TransportError(
|
||||
f"JSON-RPC error {error.get('code')}: {error.get('message')}"
|
||||
)
|
||||
raise TransportError(f"JSON-RPC error {error.get('code')}: {error.get('message')}")
|
||||
|
||||
return data.get("result")
|
||||
|
||||
|
|
@ -382,11 +390,7 @@ class StdioTransport(Transport):
|
|||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return (
|
||||
self._connected
|
||||
and self._process is not None
|
||||
and self._process.returncode is None
|
||||
)
|
||||
return self._connected and self._process is not None and self._process.returncode is None
|
||||
|
||||
def _next_request_id(self) -> int:
|
||||
"""生成下一个请求 ID"""
|
||||
|
|
@ -427,7 +431,7 @@ class StdioTransport(Transport):
|
|||
|
||||
# 发送 initialize 请求并等待响应
|
||||
try:
|
||||
init_result = await asyncio.wait_for(
|
||||
await asyncio.wait_for(
|
||||
self._send_request_internal(
|
||||
"initialize",
|
||||
{
|
||||
|
|
|
|||
|
|
@ -0,0 +1,391 @@
|
|||
"""U16 — MCPClient langchain-mcp-adapters 路径单元测试。
|
||||
|
||||
覆盖场景:
|
||||
1. URL scheme 检测(http / sse / stdio)
|
||||
2. list_tools via langchain(三种传输)
|
||||
3. call_tool via langchain(正常 / 未知工具)
|
||||
4. 旧 Transport 路径兼容(DeprecationWarning)
|
||||
5. MCPTool.execute via 新客户端(JSON / 纯文本 / 无 content)
|
||||
6. langchain 未安装时 ImportError
|
||||
7. transport.py 模块级 DeprecationWarning
|
||||
8. timeout 参数透传
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.mcp.client import MCPClient, MCPTool
|
||||
from agentkit.mcp.transport import HTTPTransport
|
||||
|
||||
|
||||
# ── 辅助工具 ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_mock_langchain_tool(
|
||||
name: str = "my_tool",
|
||||
description: str = "Test tool",
|
||||
result: Any = None,
|
||||
) -> MagicMock:
|
||||
"""构造一个假的 LangChain Tool 对象。"""
|
||||
tool = MagicMock()
|
||||
tool.name = name
|
||||
tool.description = description
|
||||
tool.args_schema = MagicMock()
|
||||
tool.args_schema.model_json_schema.return_value = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
}
|
||||
tool.ainvoke = AsyncMock(return_value=result if result is not None else {"result": "ok"})
|
||||
return tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_langchain_client() -> Any:
|
||||
"""Patch MultiServerMCPClient,返回带 get_tools 的 mock 实例。
|
||||
|
||||
Yields:
|
||||
(mock_class, mock_instance) 二元组
|
||||
"""
|
||||
with patch("agentkit.mcp.client._import_langchain_client") as mock_import:
|
||||
mock_cls = MagicMock(name="MultiServerMCPClient")
|
||||
mock_instance = MagicMock(name="mcp_client_instance")
|
||||
mock_instance.get_tools = AsyncMock(return_value=[])
|
||||
mock_cls.return_value = mock_instance
|
||||
mock_import.return_value = mock_cls
|
||||
yield mock_cls, mock_instance
|
||||
|
||||
|
||||
# ── URL scheme 检测测试 ───────────────────────────────────────
|
||||
|
||||
|
||||
class TestURLSchemeDetection:
|
||||
"""MCPClient 根据 URL scheme 自动检测传输类型。"""
|
||||
|
||||
def test_http_url_uses_streamable_http(self):
|
||||
"""http:// URL 应使用 streamable_http 传输。"""
|
||||
client = MCPClient("http://localhost:8080/mcp")
|
||||
config = client._langchain_config
|
||||
assert config is not None
|
||||
assert config["transport"] == "streamable_http"
|
||||
assert config["url"] == "http://localhost:8080/mcp"
|
||||
|
||||
def test_https_url_uses_streamable_http(self):
|
||||
"""https:// URL 应使用 streamable_http 传输。"""
|
||||
client = MCPClient("https://remote.example.com/mcp")
|
||||
config = client._langchain_config
|
||||
assert config["transport"] == "streamable_http"
|
||||
assert config["url"] == "https://remote.example.com/mcp"
|
||||
|
||||
def test_sse_url_uses_sse_transport(self):
|
||||
"""sse:// URL 应使用 sse 传输,并转为 http://。"""
|
||||
client = MCPClient("sse://localhost:8080/sse")
|
||||
config = client._langchain_config
|
||||
assert config["transport"] == "sse"
|
||||
assert config["url"] == "http://localhost:8080/sse"
|
||||
|
||||
def test_stdio_url_parses_command_and_args(self):
|
||||
"""stdio:// URL 应解析为 command + args。"""
|
||||
client = MCPClient("stdio://python -m my_mcp_server")
|
||||
config = client._langchain_config
|
||||
assert config["transport"] == "stdio"
|
||||
assert config["command"] == "python"
|
||||
assert config["args"] == ["-m", "my_mcp_server"]
|
||||
|
||||
def test_stdio_url_no_args(self):
|
||||
"""stdio:// 只有 command 无参数。"""
|
||||
client = MCPClient("stdio://./server")
|
||||
config = client._langchain_config
|
||||
assert config["command"] == "./server"
|
||||
assert config["args"] == []
|
||||
|
||||
def test_unsupported_scheme_raises(self):
|
||||
"""不支持的 scheme 应抛出 ValueError。"""
|
||||
with pytest.raises(ValueError, match="不支持的 URL scheme"):
|
||||
MCPClient("ftp://example.com/server")
|
||||
|
||||
def test_empty_url_raises(self):
|
||||
"""空 URL 应抛出 ValueError。"""
|
||||
with pytest.raises(ValueError, match="server_url 不能为空"):
|
||||
MCPClient("")
|
||||
|
||||
|
||||
# ── list_tools via langchain 测试 ─────────────────────────────
|
||||
|
||||
|
||||
class TestListToolsViaLangchain:
|
||||
"""list_tools 通过 langchain-mcp-adapters 获取工具列表。"""
|
||||
|
||||
async def test_list_tools_http(self, mock_langchain_client):
|
||||
"""HTTP 传输 — list_tools 返回工具字典列表。"""
|
||||
mock_cls, mock_instance = mock_langchain_client
|
||||
mock_instance.get_tools = AsyncMock(
|
||||
return_value=[
|
||||
_make_mock_langchain_tool("search", "Search tool"),
|
||||
_make_mock_langchain_tool("calc", "Calculator"),
|
||||
]
|
||||
)
|
||||
|
||||
client = MCPClient("http://localhost:8080/mcp")
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert len(tools) == 2
|
||||
assert tools[0]["name"] == "search"
|
||||
assert tools[0]["description"] == "Search tool"
|
||||
assert "inputSchema" in tools[0]
|
||||
assert tools[1]["name"] == "calc"
|
||||
|
||||
async def test_list_tools_stdio(self, mock_langchain_client):
|
||||
"""Stdio 传输 — list_tools 返回工具字典列表。"""
|
||||
mock_cls, mock_instance = mock_langchain_client
|
||||
mock_instance.get_tools = AsyncMock(
|
||||
return_value=[_make_mock_langchain_tool("read_file", "Read file")]
|
||||
)
|
||||
|
||||
client = MCPClient("stdio://python -m fs_server")
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert len(tools) == 1
|
||||
assert tools[0]["name"] == "read_file"
|
||||
|
||||
async def test_list_tools_sse(self, mock_langchain_client):
|
||||
"""SSE 传输 — list_tools 返回工具字典列表。"""
|
||||
mock_cls, mock_instance = mock_langchain_client
|
||||
mock_instance.get_tools = AsyncMock(
|
||||
return_value=[_make_mock_langchain_tool("query", "Query data")]
|
||||
)
|
||||
|
||||
client = MCPClient("sse://localhost:8080/sse")
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert len(tools) == 1
|
||||
assert tools[0]["name"] == "query"
|
||||
|
||||
async def test_list_tools_caches_result(self, mock_langchain_client):
|
||||
"""list_tools 结果应缓存到 _tools_cache。"""
|
||||
_, mock_instance = mock_langchain_client
|
||||
mock_instance.get_tools = AsyncMock(return_value=[_make_mock_langchain_tool("tool1")])
|
||||
|
||||
client = MCPClient("http://localhost:8080/mcp")
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert client._tools_cache == tools
|
||||
assert client._tools_cache[0]["name"] == "tool1"
|
||||
|
||||
async def test_list_tools_empty(self, mock_langchain_client):
|
||||
"""空工具列表应返回 []。"""
|
||||
_, mock_instance = mock_langchain_client
|
||||
mock_instance.get_tools = AsyncMock(return_value=[])
|
||||
|
||||
client = MCPClient("http://localhost:8080/mcp")
|
||||
tools = await client.list_tools()
|
||||
assert tools == []
|
||||
|
||||
|
||||
# ── call_tool via langchain 测试 ──────────────────────────────
|
||||
|
||||
|
||||
class TestCallToolViaLangchain:
|
||||
"""call_tool 通过 langchain-mcp-adapters 调用工具。"""
|
||||
|
||||
async def test_call_tool_returns_mcp_format(self, mock_langchain_client):
|
||||
"""call_tool 应返回 MCP 响应格式 {"content": [{"type":"text","text":...}]}。"""
|
||||
_, mock_instance = mock_langchain_client
|
||||
mock_instance.get_tools = AsyncMock(
|
||||
return_value=[_make_mock_langchain_tool("my_tool", result={"result": "hello"})]
|
||||
)
|
||||
|
||||
client = MCPClient("http://localhost:8080/mcp")
|
||||
result = await client.call_tool("my_tool", {"input": "x"})
|
||||
|
||||
assert "content" in result
|
||||
assert result["content"][0]["type"] == "text"
|
||||
assert '"result": "hello"' in result["content"][0]["text"]
|
||||
|
||||
async def test_call_tool_invokes_ainvoke(self, mock_langchain_client):
|
||||
"""call_tool 应调用 LangChain Tool 的 ainvoke。"""
|
||||
_, mock_instance = mock_langchain_client
|
||||
mock_tool = _make_mock_langchain_tool("my_tool", result={"ok": True})
|
||||
mock_instance.get_tools = AsyncMock(return_value=[mock_tool])
|
||||
|
||||
client = MCPClient("http://localhost:8080/mcp")
|
||||
await client.call_tool("my_tool", {"input": "x"})
|
||||
|
||||
mock_tool.ainvoke.assert_called_once_with({"input": "x"})
|
||||
|
||||
async def test_call_tool_unknown_raises(self, mock_langchain_client):
|
||||
"""调用不存在的工具应抛出 KeyError。"""
|
||||
_, mock_instance = mock_langchain_client
|
||||
mock_instance.get_tools = AsyncMock(return_value=[])
|
||||
|
||||
client = MCPClient("http://localhost:8080/mcp")
|
||||
with pytest.raises(KeyError, match="nonexistent"):
|
||||
await client.call_tool("nonexistent", {})
|
||||
|
||||
|
||||
# ── 旧 Transport 路径兼容测试 ─────────────────────────────────
|
||||
|
||||
|
||||
class TestLegacyTransportPath:
|
||||
"""旧 Transport 注入路径应保持向后兼容(发出 DeprecationWarning)。"""
|
||||
|
||||
def test_from_transport_emits_deprecation_warning(self):
|
||||
"""from_transport 应发出 DeprecationWarning。"""
|
||||
transport = HTTPTransport(endpoint="http://localhost:8080/mcp")
|
||||
with pytest.warns(DeprecationWarning, match="Transport"):
|
||||
client = MCPClient.from_transport(transport)
|
||||
assert client._transport is transport
|
||||
assert client._langchain_config is None
|
||||
|
||||
def test_init_with_transport_emits_deprecation_warning(self):
|
||||
"""__init__ 传入 transport 应发出 DeprecationWarning。"""
|
||||
transport = HTTPTransport(endpoint="http://localhost:8080/mcp")
|
||||
with pytest.warns(DeprecationWarning):
|
||||
client = MCPClient(server_url="http://localhost:8080", transport=transport)
|
||||
assert client._transport is transport
|
||||
|
||||
async def test_legacy_list_tools_works(self):
|
||||
"""旧 Transport 路径的 list_tools 仍应正常工作。"""
|
||||
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||
transport._client = MagicMock()
|
||||
transport._client.is_closed = False
|
||||
transport.send_request = AsyncMock(
|
||||
return_value={"tools": [{"name": "legacy_tool", "description": "Legacy"}]}
|
||||
)
|
||||
|
||||
with pytest.warns(DeprecationWarning):
|
||||
client = MCPClient.from_transport(transport)
|
||||
|
||||
tools = await client.list_tools()
|
||||
assert len(tools) == 1
|
||||
assert tools[0]["name"] == "legacy_tool"
|
||||
transport.send_request.assert_called_once_with("tools/list")
|
||||
|
||||
|
||||
# ── MCPTool.execute 测试 ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestMCPToolExecute:
|
||||
"""MCPTool.execute 通过新客户端调用工具并解析响应。"""
|
||||
|
||||
async def test_execute_parses_json_text(self, mock_langchain_client):
|
||||
"""execute 应解析 content[0].text 中的 JSON。"""
|
||||
_, mock_instance = mock_langchain_client
|
||||
mock_instance.get_tools = AsyncMock(
|
||||
return_value=[_make_mock_langchain_tool("my_tool", result={"answer": 42})]
|
||||
)
|
||||
|
||||
client = MCPClient("http://localhost:8080/mcp")
|
||||
tool = client.as_tool("my_tool", "desc")
|
||||
result = await tool.execute(input="x")
|
||||
|
||||
assert result == {"answer": 42}
|
||||
|
||||
async def test_execute_non_json_text(self):
|
||||
"""content[0].text 为纯文本时返回 {"result": text}。"""
|
||||
client = MCPClient("http://localhost:8080/mcp")
|
||||
client.call_tool = AsyncMock( # type: ignore[method-assign]
|
||||
return_value={"content": [{"type": "text", "text": "plain text"}]}
|
||||
)
|
||||
tool = client.as_tool("echo", "desc")
|
||||
|
||||
result = await tool.execute(msg="hello")
|
||||
assert result == {"result": "plain text"}
|
||||
|
||||
async def test_execute_no_content_field(self):
|
||||
"""无 content 字段时返回原始 dict。"""
|
||||
client = MCPClient("http://localhost:8080/mcp")
|
||||
client.call_tool = AsyncMock( # type: ignore[method-assign]
|
||||
return_value={"unexpected": "shape"}
|
||||
)
|
||||
tool = client.as_tool("status", "desc")
|
||||
|
||||
result = await tool.execute()
|
||||
assert result == {"unexpected": "shape"}
|
||||
|
||||
async def test_as_tool_returns_mcp_tool(self):
|
||||
"""as_tool 应返回 MCPTool 实例。"""
|
||||
client = MCPClient("http://localhost:8080/mcp")
|
||||
tool = client.as_tool("search", "Search the web")
|
||||
|
||||
assert isinstance(tool, MCPTool)
|
||||
assert tool.name == "search"
|
||||
assert tool.description == "Search the web"
|
||||
assert tool._client is client
|
||||
assert "mcp" in tool.tags
|
||||
|
||||
|
||||
# ── langchain 未安装测试 ─────────────────────────────────────
|
||||
|
||||
|
||||
class TestLangchainNotInstalled:
|
||||
"""langchain-mcp-adapters 未安装时的错误处理。"""
|
||||
|
||||
async def test_import_error_when_langchain_missing(self):
|
||||
"""langchain 未安装时 list_tools 应抛出 ImportError。"""
|
||||
client = MCPClient("http://localhost:8080/mcp")
|
||||
|
||||
with patch(
|
||||
"agentkit.mcp.client._import_langchain_client",
|
||||
side_effect=ImportError("langchain-mcp-adapters 未安装"),
|
||||
):
|
||||
with pytest.raises(ImportError, match="langchain-mcp-adapters"):
|
||||
await client.list_tools()
|
||||
|
||||
|
||||
# ── transport.py 模块级 DeprecationWarning 测试 ──────────────
|
||||
|
||||
|
||||
class TestTransportDeprecation:
|
||||
"""transport.py 导入时应发出 DeprecationWarning。"""
|
||||
|
||||
def test_transport_module_emits_deprecation_warning(self):
|
||||
"""导入 agentkit.mcp.transport 应触发 DeprecationWarning。"""
|
||||
# 先从 sys.modules 移除,强制重新导入
|
||||
mods_to_remove = [k for k in sys.modules if k.startswith("agentkit.mcp.transport")]
|
||||
for mod in mods_to_remove:
|
||||
del sys.modules[mod]
|
||||
|
||||
with warnings.catch_warnings(record=True) as warning_list:
|
||||
warnings.simplefilter("always")
|
||||
importlib.import_module("agentkit.mcp.transport")
|
||||
|
||||
deprecation_warnings = [
|
||||
w for w in warning_list if issubclass(w.category, DeprecationWarning)
|
||||
]
|
||||
assert len(deprecation_warnings) >= 1
|
||||
assert "transport" in str(deprecation_warnings[0].message)
|
||||
|
||||
|
||||
# ── timeout 参数透传测试 ─────────────────────────────────────
|
||||
|
||||
|
||||
class TestTimeoutPropagation:
|
||||
"""timeout 参数应透传到 langchain 连接配置。"""
|
||||
|
||||
def test_timeout_passed_to_http_config(self):
|
||||
"""HTTP 传输应将 timeout 写入配置。"""
|
||||
client = MCPClient("http://localhost:8080/mcp", timeout=60)
|
||||
assert client._langchain_config["timeout"] == 60
|
||||
|
||||
def test_timeout_passed_to_sse_config(self):
|
||||
"""SSE 传输应将 timeout 写入配置。"""
|
||||
client = MCPClient("sse://localhost:8080/sse", timeout=45)
|
||||
assert client._langchain_config["timeout"] == 45
|
||||
|
||||
def test_default_timeout_is_30(self):
|
||||
"""默认 timeout 为 30。"""
|
||||
client = MCPClient("http://localhost:8080/mcp")
|
||||
assert client._timeout == 30
|
||||
assert client._langchain_config["timeout"] == 30
|
||||
|
||||
def test_stdio_config_has_no_timeout(self):
|
||||
"""stdio 传输配置不含 timeout(StdioConnection 无此字段)。"""
|
||||
client = MCPClient("stdio://python -m server", timeout=60)
|
||||
assert "timeout" not in client._langchain_config
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
"""MCP Client 单元测试"""
|
||||
|
||||
import json
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
|
@ -149,112 +150,12 @@ class TestMCPClientTransportMode:
|
|||
await transport.disconnect()
|
||||
|
||||
|
||||
# ── MCPClient 直接 HTTP 模式测试 ────────────────────────────────
|
||||
|
||||
|
||||
class TestMCPClientDirectHTTP:
|
||||
"""MCPClient 直接 HTTP 模式测试(无 Transport)"""
|
||||
|
||||
async def test_list_tools_direct_http(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/tools/list",
|
||||
json={
|
||||
"tools": [
|
||||
{"name": "search", "description": "Search tool"},
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
client = MCPClient(server_url="http://localhost:8080")
|
||||
tools = await client.list_tools()
|
||||
|
||||
assert len(tools) == 1
|
||||
assert tools[0]["name"] == "search"
|
||||
assert client._tools_cache == tools
|
||||
|
||||
async def test_call_tool_direct_http(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/tools/call",
|
||||
json={"result": "computed value"},
|
||||
)
|
||||
|
||||
client = MCPClient(server_url="http://localhost:8080")
|
||||
result = await client.call_tool("compute", {"x": 42})
|
||||
|
||||
assert result == {"result": "computed value"}
|
||||
|
||||
# 验证请求体
|
||||
request = httpx_mock.get_request()
|
||||
body = json.loads(request.content)
|
||||
assert body["name"] == "compute"
|
||||
assert body["arguments"] == {"x": 42}
|
||||
|
||||
async def test_list_tools_caches_result(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/tools/list",
|
||||
json={"tools": [{"name": "tool1"}]},
|
||||
)
|
||||
|
||||
client = MCPClient(server_url="http://localhost:8080")
|
||||
tools = await client.list_tools()
|
||||
|
||||
# 验证缓存被设置
|
||||
assert client._tools_cache == tools
|
||||
assert client._tools_cache[0]["name"] == "tool1"
|
||||
|
||||
async def test_call_tool_sends_post_request(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/tools/call",
|
||||
json={"output": "done"},
|
||||
)
|
||||
|
||||
client = MCPClient(server_url="http://localhost:8080")
|
||||
await client.call_tool("my_tool", {"arg": "val"})
|
||||
|
||||
request = httpx_mock.get_request()
|
||||
assert request.method == "POST"
|
||||
|
||||
|
||||
# ── MCPClient 连接错误处理测试 ──────────────────────────────────
|
||||
|
||||
|
||||
class TestMCPClientErrorHandling:
|
||||
"""MCPClient 连接错误处理测试"""
|
||||
|
||||
async def test_list_tools_http_error(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/tools/list",
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
client = MCPClient(server_url="http://localhost:8080")
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await client.list_tools()
|
||||
|
||||
async def test_call_tool_http_error(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/tools/call",
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
client = MCPClient(server_url="http://localhost:8080")
|
||||
with pytest.raises(httpx.HTTPStatusError):
|
||||
await client.call_tool("missing_tool", {})
|
||||
|
||||
async def test_list_tools_connection_error(self, httpx_mock):
|
||||
httpx_mock.add_exception(httpx.ConnectError("Connection refused"))
|
||||
|
||||
client = MCPClient(server_url="http://localhost:8080")
|
||||
with pytest.raises(httpx.ConnectError):
|
||||
await client.list_tools()
|
||||
|
||||
async def test_call_tool_connection_error(self, httpx_mock):
|
||||
httpx_mock.add_exception(httpx.ConnectError("Connection refused"))
|
||||
|
||||
client = MCPClient(server_url="http://localhost:8080")
|
||||
with pytest.raises(httpx.ConnectError):
|
||||
await client.call_tool("any_tool", {})
|
||||
|
||||
async def test_transport_error_propagates(self, httpx_mock):
|
||||
httpx_mock.add_exception(httpx.ConnectError("Connection refused"))
|
||||
|
||||
|
|
@ -355,41 +256,38 @@ class TestMCPTool:
|
|||
assert tool._client is client
|
||||
assert "mcp" in tool.tags
|
||||
|
||||
async def test_mcp_tool_execute_text_content(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/tools/call",
|
||||
json={
|
||||
"content": [{"type": "text", "text": '{"answer": 42}'}],
|
||||
},
|
||||
)
|
||||
|
||||
async def test_mcp_tool_execute_text_content(self):
|
||||
"""execute 应解析 content[0].text 中的 JSON。"""
|
||||
client = MCPClient(server_url="http://localhost:8080")
|
||||
client.call_tool = AsyncMock( # type: ignore[method-assign]
|
||||
return_value={
|
||||
"content": [{"type": "text", "text": '{"answer": 42}'}],
|
||||
}
|
||||
)
|
||||
tool = client.as_tool("ask", description="Ask a question")
|
||||
|
||||
result = await tool.execute(question="meaning of life")
|
||||
assert result == {"answer": 42}
|
||||
|
||||
async def test_mcp_tool_execute_non_json_text(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/tools/call",
|
||||
json={
|
||||
"content": [{"type": "text", "text": "plain text response"}],
|
||||
},
|
||||
)
|
||||
|
||||
async def test_mcp_tool_execute_non_json_text(self):
|
||||
"""content[0].text 为纯文本时返回 {"result": text}。"""
|
||||
client = MCPClient(server_url="http://localhost:8080")
|
||||
client.call_tool = AsyncMock( # type: ignore[method-assign]
|
||||
return_value={
|
||||
"content": [{"type": "text", "text": "plain text response"}],
|
||||
}
|
||||
)
|
||||
tool = client.as_tool("echo", description="Echo input")
|
||||
|
||||
result = await tool.execute(msg="hello")
|
||||
assert result == {"result": "plain text response"}
|
||||
|
||||
async def test_mcp_tool_execute_no_content(self, httpx_mock):
|
||||
httpx_mock.add_response(
|
||||
url="http://localhost:8080/tools/call",
|
||||
json={"status": "ok", "data": "some data"},
|
||||
)
|
||||
|
||||
async def test_mcp_tool_execute_no_content(self):
|
||||
"""无 content 字段时返回原始 dict。"""
|
||||
client = MCPClient(server_url="http://localhost:8080")
|
||||
client.call_tool = AsyncMock( # type: ignore[method-assign]
|
||||
return_value={"status": "ok", "data": "some data"}
|
||||
)
|
||||
tool = client.as_tool("status", description="Check status")
|
||||
|
||||
result = await tool.execute()
|
||||
|
|
|
|||
Loading…
Reference in New Issue