392 lines
15 KiB
Python
392 lines
15 KiB
Python
"""U16 — MCPClient langchain-mcp-adapters 路径单元测试。
|
||
|
||
覆盖场景:
|
||
1. URL scheme 检测(http / sse / stdio)
|
||
2. list_tools via langchain(三种传输)
|
||
3. call_tool via langchain(正常 / 未知工具)
|
||
4. 旧 Transport 路径兼容(DeprecationWarning)
|
||
5. MCPTool.execute via 新客户端(JSON / 纯文本 / 无 content)
|
||
6. langchain 未安装时 ImportError
|
||
7. transport.py 模块级 DeprecationWarning
|
||
8. timeout 参数透传
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import importlib
|
||
import sys
|
||
import warnings
|
||
from typing import Any
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from agentkit.mcp.client import MCPClient, MCPTool
|
||
from agentkit.mcp.transport import HTTPTransport
|
||
|
||
|
||
# ── 辅助工具 ──────────────────────────────────────────────────
|
||
|
||
|
||
def _make_mock_langchain_tool(
    name: str = "my_tool",
    description: str = "Test tool",
    result: Any = None,
) -> MagicMock:
    """Build a fake LangChain tool object.

    Args:
        name: Value exposed as ``tool.name``.
        description: Value exposed as ``tool.description``.
        result: Value the async ``ainvoke`` resolves to; ``None`` falls back
            to ``{"result": "ok"}``.

    Returns:
        A ``MagicMock`` whose ``args_schema.model_json_schema()`` returns an
        empty object schema and whose ``ainvoke`` is an ``AsyncMock``.
    """
    invoke_result = {"result": "ok"} if result is None else result

    schema = MagicMock()
    schema.model_json_schema.return_value = {"type": "object", "properties": {}}

    fake_tool = MagicMock()
    fake_tool.name = name
    fake_tool.description = description
    fake_tool.args_schema = schema
    fake_tool.ainvoke = AsyncMock(return_value=invoke_result)
    return fake_tool
|
||
|
||
|
||
@pytest.fixture
def mock_langchain_client() -> Any:
    """Patch MultiServerMCPClient and provide a mock instance exposing get_tools.

    Yields:
        A ``(mock_class, mock_instance)`` tuple. Calling ``mock_class`` returns
        ``mock_instance``, whose ``get_tools`` is an ``AsyncMock`` returning
        ``[]`` by default; individual tests override it as needed.
    """
    fake_instance = MagicMock(name="mcp_client_instance")
    fake_instance.get_tools = AsyncMock(return_value=[])
    fake_cls = MagicMock(name="MultiServerMCPClient", return_value=fake_instance)

    # The client resolves the langchain class lazily through this hook, so
    # patching it swaps in the fake class without langchain being installed.
    with patch("agentkit.mcp.client._import_langchain_client", return_value=fake_cls):
        yield fake_cls, fake_instance
|
||
|
||
|
||
# ── URL scheme 检测测试 ───────────────────────────────────────
|
||
|
||
|
||
class TestURLSchemeDetection:
    """MCPClient picks the transport type automatically from the URL scheme."""

    def test_http_url_uses_streamable_http(self):
        """An http:// URL should use the streamable_http transport."""
        client = MCPClient("http://localhost:8080/mcp")
        config = client._langchain_config
        assert config is not None
        assert config["transport"] == "streamable_http"
        assert config["url"] == "http://localhost:8080/mcp"

    def test_https_url_uses_streamable_http(self):
        """An https:// URL should use the streamable_http transport."""
        client = MCPClient("https://remote.example.com/mcp")
        config = client._langchain_config
        assert config is not None
        assert config["transport"] == "streamable_http"
        assert config["url"] == "https://remote.example.com/mcp"

    def test_sse_url_uses_sse_transport(self):
        """An sse:// URL should use the sse transport, rewritten to http://."""
        client = MCPClient("sse://localhost:8080/sse")
        config = client._langchain_config
        assert config is not None
        assert config["transport"] == "sse"
        assert config["url"] == "http://localhost:8080/sse"

    def test_stdio_url_parses_command_and_args(self):
        """A stdio:// URL should be parsed into command + args."""
        client = MCPClient("stdio://python -m my_mcp_server")
        config = client._langchain_config
        assert config is not None
        assert config["transport"] == "stdio"
        assert config["command"] == "python"
        assert config["args"] == ["-m", "my_mcp_server"]

    def test_stdio_url_no_args(self):
        """A stdio:// URL with only a command and no arguments."""
        client = MCPClient("stdio://./server")
        config = client._langchain_config
        assert config is not None
        # The argument-less form must still be routed to the stdio transport.
        assert config["transport"] == "stdio"
        assert config["command"] == "./server"
        assert config["args"] == []

    def test_unsupported_scheme_raises(self):
        """An unsupported scheme should raise ValueError."""
        with pytest.raises(ValueError, match="不支持的 URL scheme"):
            MCPClient("ftp://example.com/server")

    def test_empty_url_raises(self):
        """An empty URL should raise ValueError."""
        with pytest.raises(ValueError, match="server_url 不能为空"):
            MCPClient("")
|
||
|
||
|
||
# ── list_tools via langchain 测试 ─────────────────────────────
|
||
|
||
|
||
class TestListToolsViaLangchain:
    """list_tools fetches the tool list through langchain-mcp-adapters."""

    async def test_list_tools_http(self, mock_langchain_client):
        """HTTP transport — list_tools returns a list of tool dicts."""
        _, fake_client = mock_langchain_client
        fake_tools = [
            _make_mock_langchain_tool("search", "Search tool"),
            _make_mock_langchain_tool("calc", "Calculator"),
        ]
        fake_client.get_tools = AsyncMock(return_value=fake_tools)

        tools = await MCPClient("http://localhost:8080/mcp").list_tools()

        assert len(tools) == 2
        first, second = tools
        assert first["name"] == "search"
        assert first["description"] == "Search tool"
        assert "inputSchema" in first
        assert second["name"] == "calc"

    async def test_list_tools_stdio(self, mock_langchain_client):
        """Stdio transport — list_tools returns a list of tool dicts."""
        _, fake_client = mock_langchain_client
        fake_tools = [_make_mock_langchain_tool("read_file", "Read file")]
        fake_client.get_tools = AsyncMock(return_value=fake_tools)

        tools = await MCPClient("stdio://python -m fs_server").list_tools()

        assert len(tools) == 1
        assert tools[0]["name"] == "read_file"

    async def test_list_tools_sse(self, mock_langchain_client):
        """SSE transport — list_tools returns a list of tool dicts."""
        _, fake_client = mock_langchain_client
        fake_tools = [_make_mock_langchain_tool("query", "Query data")]
        fake_client.get_tools = AsyncMock(return_value=fake_tools)

        tools = await MCPClient("sse://localhost:8080/sse").list_tools()

        assert len(tools) == 1
        assert tools[0]["name"] == "query"

    async def test_list_tools_caches_result(self, mock_langchain_client):
        """The list_tools result should be stored in _tools_cache."""
        _, fake_client = mock_langchain_client
        fake_client.get_tools = AsyncMock(return_value=[_make_mock_langchain_tool("tool1")])

        client = MCPClient("http://localhost:8080/mcp")
        tools = await client.list_tools()

        cached = client._tools_cache
        assert cached == tools
        assert cached[0]["name"] == "tool1"

    async def test_list_tools_empty(self, mock_langchain_client):
        """An empty tool list should yield []."""
        _, fake_client = mock_langchain_client
        fake_client.get_tools = AsyncMock(return_value=[])

        tools = await MCPClient("http://localhost:8080/mcp").list_tools()
        assert tools == []
|
||
|
||
|
||
# ── call_tool via langchain 测试 ──────────────────────────────
|
||
|
||
|
||
class TestCallToolViaLangchain:
    """call_tool invokes tools through langchain-mcp-adapters."""

    async def test_call_tool_returns_mcp_format(self, mock_langchain_client):
        """call_tool should return the MCP response shape {"content": [{"type":"text","text":...}]}."""
        import json

        _, mock_instance = mock_langchain_client
        mock_instance.get_tools = AsyncMock(
            return_value=[_make_mock_langchain_tool("my_tool", result={"result": "hello"})]
        )

        client = MCPClient("http://localhost:8080/mcp")
        result = await client.call_tool("my_tool", {"input": "x"})

        assert "content" in result
        first_item = result["content"][0]
        assert first_item["type"] == "text"
        # Parse the payload rather than substring-matching it: a substring check
        # also passes when the tool result is wrapped or carries extra keys.
        assert json.loads(first_item["text"]) == {"result": "hello"}

    async def test_call_tool_invokes_ainvoke(self, mock_langchain_client):
        """call_tool should delegate to the LangChain tool's ainvoke."""
        _, mock_instance = mock_langchain_client
        mock_tool = _make_mock_langchain_tool("my_tool", result={"ok": True})
        mock_instance.get_tools = AsyncMock(return_value=[mock_tool])

        client = MCPClient("http://localhost:8080/mcp")
        await client.call_tool("my_tool", {"input": "x"})

        mock_tool.ainvoke.assert_called_once_with({"input": "x"})

    async def test_call_tool_unknown_raises(self, mock_langchain_client):
        """Calling a tool that does not exist should raise KeyError."""
        _, mock_instance = mock_langchain_client
        mock_instance.get_tools = AsyncMock(return_value=[])

        client = MCPClient("http://localhost:8080/mcp")
        with pytest.raises(KeyError, match="nonexistent"):
            await client.call_tool("nonexistent", {})
|
||
|
||
|
||
# ── 旧 Transport 路径兼容测试 ─────────────────────────────────
|
||
|
||
|
||
class TestLegacyTransportPath:
    """The legacy Transport-injection path stays backward compatible (with a DeprecationWarning)."""

    def test_from_transport_emits_deprecation_warning(self):
        """from_transport should emit a DeprecationWarning."""
        legacy = HTTPTransport(endpoint="http://localhost:8080/mcp")

        with pytest.warns(DeprecationWarning, match="Transport"):
            client = MCPClient.from_transport(legacy)

        assert client._transport is legacy
        assert client._langchain_config is None

    def test_init_with_transport_emits_deprecation_warning(self):
        """Passing transport to __init__ should emit a DeprecationWarning."""
        legacy = HTTPTransport(endpoint="http://localhost:8080/mcp")

        with pytest.warns(DeprecationWarning):
            client = MCPClient(server_url="http://localhost:8080", transport=legacy)

        assert client._transport is legacy

    async def test_legacy_list_tools_works(self):
        """list_tools on the legacy Transport path should still work."""
        legacy = HTTPTransport(endpoint="http://localhost:8080")
        # Pretend an open HTTP client exists so no real connection is made.
        fake_http = MagicMock()
        fake_http.is_closed = False
        legacy._client = fake_http
        legacy.send_request = AsyncMock(
            return_value={"tools": [{"name": "legacy_tool", "description": "Legacy"}]}
        )

        with pytest.warns(DeprecationWarning):
            client = MCPClient.from_transport(legacy)

        tools = await client.list_tools()
        assert len(tools) == 1
        assert tools[0]["name"] == "legacy_tool"
        legacy.send_request.assert_called_once_with("tools/list")
|
||
|
||
|
||
# ── MCPTool.execute 测试 ──────────────────────────────────────
|
||
|
||
|
||
class TestMCPToolExecute:
    """MCPTool.execute calls the tool through the new client and parses the response."""

    async def test_execute_parses_json_text(self, mock_langchain_client):
        """execute should parse the JSON carried in content[0].text."""
        _, mock_instance = mock_langchain_client
        mock_instance.get_tools = AsyncMock(
            return_value=[_make_mock_langchain_tool("my_tool", result={"answer": 42})]
        )

        client = MCPClient("http://localhost:8080/mcp")
        tool = client.as_tool("my_tool", "desc")
        result = await tool.execute(input="x")

        assert result == {"answer": 42}

    async def test_execute_non_json_text(self):
        """When content[0].text is plain text, execute returns {"result": text}."""
        client = MCPClient("http://localhost:8080/mcp")
        client.call_tool = AsyncMock(  # type: ignore[method-assign]
            return_value={"content": [{"type": "text", "text": "plain text"}]}
        )
        tool = client.as_tool("echo", "desc")

        result = await tool.execute(msg="hello")
        assert result == {"result": "plain text"}

    async def test_execute_no_content_field(self):
        """Without a content field, execute returns the raw dict."""
        client = MCPClient("http://localhost:8080/mcp")
        client.call_tool = AsyncMock(  # type: ignore[method-assign]
            return_value={"unexpected": "shape"}
        )
        tool = client.as_tool("status", "desc")

        result = await tool.execute()
        assert result == {"unexpected": "shape"}

    def test_as_tool_returns_mcp_tool(self):
        """as_tool should return an MCPTool bound to this client.

        This is a plain synchronous test: nothing here is awaited, so it needs
        no event loop.
        """
        client = MCPClient("http://localhost:8080/mcp")
        tool = client.as_tool("search", "Search the web")

        assert isinstance(tool, MCPTool)
        assert tool.name == "search"
        assert tool.description == "Search the web"
        assert tool._client is client
        assert "mcp" in tool.tags
|
||
|
||
|
||
# ── langchain 未安装测试 ─────────────────────────────────────
|
||
|
||
|
||
class TestLangchainNotInstalled:
    """Error handling when langchain-mcp-adapters is not installed."""

    async def test_import_error_when_langchain_missing(self):
        """list_tools should raise ImportError when langchain is missing."""
        client = MCPClient("http://localhost:8080/mcp")
        missing_dependency = ImportError("langchain-mcp-adapters 未安装")
        patcher = patch(
            "agentkit.mcp.client._import_langchain_client",
            side_effect=missing_dependency,
        )

        # Patch first, then expect the error: the same order as nested blocks.
        with patcher, pytest.raises(ImportError, match="langchain-mcp-adapters"):
            await client.list_tools()
|
||
|
||
|
||
# ── transport.py 模块级 DeprecationWarning 测试 ──────────────
|
||
|
||
|
||
class TestTransportDeprecation:
    """Importing transport.py should emit a DeprecationWarning."""

    def test_transport_module_emits_deprecation_warning(self):
        """Importing agentkit.mcp.transport should trigger a DeprecationWarning.

        The module is evicted from ``sys.modules`` to force a fresh import, so
        its import-time warning fires again. Afterwards the original module
        objects, and the parent package attribute, are restored. This module
        and ``agentkit.mcp.client`` hold references to the original classes
        (e.g. ``HTTPTransport``). Leaving a second copy of the module installed
        would leak into later tests.
        """
        module_name = "agentkit.mcp.transport"
        parent_name, _, child_name = module_name.rpartition(".")

        def _is_target(key: str) -> bool:
            # Match the module itself and its submodules, but not siblings that
            # merely share the textual prefix (e.g. "agentkit.mcp.transport_x").
            return key == module_name or key.startswith(module_name + ".")

        saved_modules = {key: sys.modules.pop(key) for key in list(sys.modules) if _is_target(key)}
        parent = sys.modules.get(parent_name)
        missing = object()
        # Re-importing rebinds the submodule as an attribute of its parent package;
        # remember the original binding so it can be put back.
        saved_attr = getattr(parent, child_name, missing)

        try:
            with warnings.catch_warnings(record=True) as warning_list:
                warnings.simplefilter("always")
                importlib.import_module(module_name)
        finally:
            # Drop the freshly imported copies and reinstate the originals.
            for key in [key for key in sys.modules if _is_target(key)]:
                del sys.modules[key]
            sys.modules.update(saved_modules)
            if parent is not None:
                if saved_attr is missing:
                    if hasattr(parent, child_name):
                        delattr(parent, child_name)
                else:
                    setattr(parent, child_name, saved_attr)

        deprecation_warnings = [
            w for w in warning_list if issubclass(w.category, DeprecationWarning)
        ]
        assert deprecation_warnings, "importing agentkit.mcp.transport emitted no DeprecationWarning"
        # Other imports may emit unrelated DeprecationWarnings; require at least
        # one that concerns the transport module rather than checking only the first.
        assert any("transport" in str(w.message) for w in deprecation_warnings)
|
||
|
||
|
||
# ── timeout 参数透传测试 ─────────────────────────────────────
|
||
|
||
|
||
class TestTimeoutPropagation:
    """The timeout argument should be forwarded into the langchain connection config."""

    def test_timeout_passed_to_http_config(self):
        """The HTTP transport writes timeout into its config."""
        config = MCPClient("http://localhost:8080/mcp", timeout=60)._langchain_config
        assert config["timeout"] == 60

    def test_timeout_passed_to_sse_config(self):
        """The SSE transport writes timeout into its config."""
        config = MCPClient("sse://localhost:8080/sse", timeout=45)._langchain_config
        assert config["timeout"] == 45

    def test_default_timeout_is_30(self):
        """The default timeout is 30."""
        client = MCPClient("http://localhost:8080/mcp")
        expected = 30
        assert client._timeout == expected
        assert client._langchain_config["timeout"] == expected

    def test_stdio_config_has_no_timeout(self):
        """The stdio transport config carries no timeout (StdioConnection has no such field)."""
        config = MCPClient("stdio://python -m server", timeout=60)._langchain_config
        assert "timeout" not in config
|