# fischer-agentkit/tests/unit/mcp/test_client.py
#
# NOTE: some source viewers flag this file for containing ambiguous
# (confusable) Unicode characters.

"""U16 — unit tests for the MCPClient langchain-mcp-adapters path.

Scenarios covered:
1. URL scheme detection (http / sse / stdio)
2. list_tools via langchain (all three transports)
3. call_tool via langchain (normal / unknown tool)
4. Legacy Transport path compatibility (DeprecationWarning)
5. MCPTool.execute via the new client (JSON / plain text / no content)
6. ImportError when langchain is not installed
7. Module-level DeprecationWarning from transport.py
8. timeout parameter propagation
"""
from __future__ import annotations

import importlib
import json
import sys
import warnings
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from agentkit.mcp.client import MCPClient, MCPTool
from agentkit.mcp.transport import HTTPTransport
# ── 辅助工具 ──────────────────────────────────────────────────
def _make_mock_langchain_tool(
name: str = "my_tool",
description: str = "Test tool",
result: Any = None,
) -> MagicMock:
"""构造一个假的 LangChain Tool 对象。"""
tool = MagicMock()
tool.name = name
tool.description = description
tool.args_schema = MagicMock()
tool.args_schema.model_json_schema.return_value = {
"type": "object",
"properties": {},
}
tool.ainvoke = AsyncMock(return_value=result if result is not None else {"result": "ok"})
return tool
@pytest.fixture
def mock_langchain_client() -> Any:
    """Replace the lazily imported ``MultiServerMCPClient`` with a mock.

    ``agentkit.mcp.client._import_langchain_client`` is patched for the
    duration of the test so that it returns a mock class. Calling that class
    yields a mock client whose ``get_tools`` is an ``AsyncMock`` returning
    ``[]`` by default; individual tests override ``get_tools`` as needed.

    Yields:
        ``(mock_class, mock_instance)`` tuple.
    """
    fake_instance = MagicMock(name="mcp_client_instance")
    fake_instance.get_tools = AsyncMock(return_value=[])
    fake_cls = MagicMock(name="MultiServerMCPClient", return_value=fake_instance)

    with patch(
        "agentkit.mcp.client._import_langchain_client",
        return_value=fake_cls,
    ):
        yield fake_cls, fake_instance


# ── URL scheme 检测测试 ───────────────────────────────────────
class TestURLSchemeDetection:
    """MCPClient infers its transport type from the server URL scheme."""

    def test_http_url_uses_streamable_http(self):
        """An http:// URL selects the streamable_http transport."""
        url = "http://localhost:8080/mcp"
        cfg = MCPClient(url)._langchain_config
        assert cfg is not None
        assert cfg["transport"] == "streamable_http"
        assert cfg["url"] == url

    def test_https_url_uses_streamable_http(self):
        """An https:// URL also selects streamable_http, URL unchanged."""
        url = "https://remote.example.com/mcp"
        cfg = MCPClient(url)._langchain_config
        assert cfg["transport"] == "streamable_http"
        assert cfg["url"] == url

    def test_sse_url_uses_sse_transport(self):
        """An sse:// URL selects the sse transport and is rewritten to http://."""
        cfg = MCPClient("sse://localhost:8080/sse")._langchain_config
        assert cfg["transport"] == "sse"
        assert cfg["url"] == "http://localhost:8080/sse"

    def test_stdio_url_parses_command_and_args(self):
        """A stdio:// URL is split into a command and its argument list."""
        cfg = MCPClient("stdio://python -m my_mcp_server")._langchain_config
        assert cfg["transport"] == "stdio"
        assert cfg["command"] == "python"
        assert cfg["args"] == ["-m", "my_mcp_server"]

    def test_stdio_url_no_args(self):
        """A stdio:// URL holding only a command yields an empty argument list."""
        cfg = MCPClient("stdio://./server")._langchain_config
        assert cfg["command"] == "./server"
        assert cfg["args"] == []

    def test_unsupported_scheme_raises(self):
        """An unsupported scheme is rejected with ValueError."""
        with pytest.raises(ValueError, match="不支持的 URL scheme"):
            MCPClient("ftp://example.com/server")

    def test_empty_url_raises(self):
        """An empty server URL is rejected with ValueError."""
        with pytest.raises(ValueError, match="server_url 不能为空"):
            MCPClient("")


# ── list_tools via langchain 测试 ─────────────────────────────
class TestListToolsViaLangchain:
    """list_tools fetches the tool list through langchain-mcp-adapters.

    Every test relies on the ``mock_langchain_client`` fixture, so only the
    mocked client's ``get_tools`` needs to be configured here.
    """

    async def test_list_tools_http(self, mock_langchain_client):
        """HTTP transport: list_tools returns one dict per adapter tool."""
        _, mock_instance = mock_langchain_client
        mock_instance.get_tools = AsyncMock(
            return_value=[
                _make_mock_langchain_tool("search", "Search tool"),
                _make_mock_langchain_tool("calc", "Calculator"),
            ]
        )
        client = MCPClient("http://localhost:8080/mcp")

        tools = await client.list_tools()

        assert len(tools) == 2
        assert tools[0]["name"] == "search"
        assert tools[0]["description"] == "Search tool"
        assert "inputSchema" in tools[0]
        assert tools[1]["name"] == "calc"

    async def test_list_tools_stdio(self, mock_langchain_client):
        """stdio transport: list_tools returns one dict per adapter tool."""
        _, mock_instance = mock_langchain_client
        mock_instance.get_tools = AsyncMock(
            return_value=[_make_mock_langchain_tool("read_file", "Read file")]
        )
        client = MCPClient("stdio://python -m fs_server")

        tools = await client.list_tools()

        assert len(tools) == 1
        assert tools[0]["name"] == "read_file"

    async def test_list_tools_sse(self, mock_langchain_client):
        """SSE transport: list_tools returns one dict per adapter tool."""
        _, mock_instance = mock_langchain_client
        mock_instance.get_tools = AsyncMock(
            return_value=[_make_mock_langchain_tool("query", "Query data")]
        )
        client = MCPClient("sse://localhost:8080/sse")

        tools = await client.list_tools()

        assert len(tools) == 1
        assert tools[0]["name"] == "query"

    async def test_list_tools_caches_result(self, mock_langchain_client):
        """The list_tools result is stored in ``_tools_cache``."""
        _, mock_instance = mock_langchain_client
        mock_instance.get_tools = AsyncMock(return_value=[_make_mock_langchain_tool("tool1")])
        client = MCPClient("http://localhost:8080/mcp")

        tools = await client.list_tools()

        assert client._tools_cache == tools
        assert client._tools_cache[0]["name"] == "tool1"

    async def test_list_tools_empty(self, mock_langchain_client):
        """An empty adapter tool list yields []."""
        _, mock_instance = mock_langchain_client
        mock_instance.get_tools = AsyncMock(return_value=[])
        client = MCPClient("http://localhost:8080/mcp")

        tools = await client.list_tools()

        assert tools == []


# ── call_tool via langchain 测试 ──────────────────────────────
class TestCallToolViaLangchain:
    """call_tool invokes tools through langchain-mcp-adapters."""

    async def test_call_tool_returns_mcp_format(self, mock_langchain_client):
        """call_tool wraps the result as {"content": [{"type": "text", "text": ...}]}."""
        _, mock_instance = mock_langchain_client
        mock_instance.get_tools = AsyncMock(
            return_value=[_make_mock_langchain_tool("my_tool", result={"result": "hello"})]
        )
        client = MCPClient("http://localhost:8080/mcp")

        result = await client.call_tool("my_tool", {"input": "x"})

        assert "content" in result
        first_item = result["content"][0]
        assert first_item["type"] == "text"
        # Decode instead of substring-matching so the check does not depend on
        # the serializer's separators, indentation, or key order.
        assert json.loads(first_item["text"]) == {"result": "hello"}

    async def test_call_tool_invokes_ainvoke(self, mock_langchain_client):
        """call_tool forwards the arguments to the LangChain tool's ainvoke."""
        _, mock_instance = mock_langchain_client
        mock_tool = _make_mock_langchain_tool("my_tool", result={"ok": True})
        mock_instance.get_tools = AsyncMock(return_value=[mock_tool])
        client = MCPClient("http://localhost:8080/mcp")

        await client.call_tool("my_tool", {"input": "x"})

        mock_tool.ainvoke.assert_called_once_with({"input": "x"})

    async def test_call_tool_unknown_raises(self, mock_langchain_client):
        """Calling a tool that the server does not expose raises KeyError."""
        _, mock_instance = mock_langchain_client
        mock_instance.get_tools = AsyncMock(return_value=[])
        client = MCPClient("http://localhost:8080/mcp")

        with pytest.raises(KeyError, match="nonexistent"):
            await client.call_tool("nonexistent", {})


# ── 旧 Transport 路径兼容测试 ─────────────────────────────────
class TestLegacyTransportPath:
    """The legacy Transport-injection path stays backward compatible (with a DeprecationWarning)."""

    def test_from_transport_emits_deprecation_warning(self):
        """MCPClient.from_transport warns and wires the given transport in."""
        legacy = HTTPTransport(endpoint="http://localhost:8080/mcp")
        with pytest.warns(DeprecationWarning, match="Transport"):
            client = MCPClient.from_transport(legacy)
        assert client._transport is legacy
        # No langchain adapter config is built on the legacy path.
        assert client._langchain_config is None

    def test_init_with_transport_emits_deprecation_warning(self):
        """Passing transport= to the constructor also warns."""
        legacy = HTTPTransport(endpoint="http://localhost:8080/mcp")
        with pytest.warns(DeprecationWarning):
            client = MCPClient(server_url="http://localhost:8080", transport=legacy)
        assert client._transport is legacy

    async def test_legacy_list_tools_works(self):
        """list_tools on the legacy path delegates to the transport's send_request."""
        legacy = HTTPTransport(endpoint="http://localhost:8080")
        # Presumably makes the transport look already connected so no real
        # HTTP client is opened — confirm against HTTPTransport.
        legacy._client = MagicMock()
        legacy._client.is_closed = False
        legacy.send_request = AsyncMock(
            return_value={"tools": [{"name": "legacy_tool", "description": "Legacy"}]}
        )
        with pytest.warns(DeprecationWarning):
            client = MCPClient.from_transport(legacy)

        tools = await client.list_tools()

        assert len(tools) == 1
        assert tools[0]["name"] == "legacy_tool"
        legacy.send_request.assert_called_once_with("tools/list")


# ── MCPTool.execute 测试 ──────────────────────────────────────
class TestMCPToolExecute:
    """MCPTool.execute calls through the new client and decodes the response."""

    async def test_execute_parses_json_text(self, mock_langchain_client):
        """execute decodes the JSON carried in content[0].text."""
        _, fake_client = mock_langchain_client
        fake_client.get_tools = AsyncMock(
            return_value=[_make_mock_langchain_tool("my_tool", result={"answer": 42})]
        )
        tool = MCPClient("http://localhost:8080/mcp").as_tool("my_tool", "desc")

        outcome = await tool.execute(input="x")

        assert outcome == {"answer": 42}

    async def test_execute_non_json_text(self):
        """Plain (non-JSON) text in content[0].text comes back as {"result": text}."""
        client = MCPClient("http://localhost:8080/mcp")
        client.call_tool = AsyncMock(  # type: ignore[method-assign]
            return_value={"content": [{"type": "text", "text": "plain text"}]}
        )

        outcome = await client.as_tool("echo", "desc").execute(msg="hello")

        assert outcome == {"result": "plain text"}

    async def test_execute_no_content_field(self):
        """A response without a content field is returned as the raw dict."""
        client = MCPClient("http://localhost:8080/mcp")
        client.call_tool = AsyncMock(  # type: ignore[method-assign]
            return_value={"unexpected": "shape"}
        )

        outcome = await client.as_tool("status", "desc").execute()

        assert outcome == {"unexpected": "shape"}

    async def test_as_tool_returns_mcp_tool(self):
        """as_tool returns an MCPTool bound to this client with the given metadata."""
        client = MCPClient("http://localhost:8080/mcp")

        wrapped = client.as_tool("search", "Search the web")

        assert isinstance(wrapped, MCPTool)
        assert wrapped.name == "search"
        assert wrapped.description == "Search the web"
        assert wrapped._client is client
        assert "mcp" in wrapped.tags


# ── langchain 未安装测试 ─────────────────────────────────────
class TestLangchainNotInstalled:
    """Error handling when langchain-mcp-adapters is not installed."""

    async def test_import_error_when_langchain_missing(self):
        """list_tools surfaces the ImportError raised by the lazy adapter import."""
        client = MCPClient("http://localhost:8080/mcp")
        missing_dependency = ImportError("langchain-mcp-adapters 未安装")
        import_patcher = patch(
            "agentkit.mcp.client._import_langchain_client",
            side_effect=missing_dependency,
        )
        with import_patcher, pytest.raises(ImportError, match="langchain-mcp-adapters"):
            await client.list_tools()


# ── transport.py 模块级 DeprecationWarning 测试 ──────────────
class TestTransportDeprecation:
    """Importing transport.py emits a module-level DeprecationWarning."""

    def test_transport_module_emits_deprecation_warning(self):
        """A fresh import of agentkit.mcp.transport triggers a DeprecationWarning.

        The module is evicted from ``sys.modules`` so that the import really
        re-executes it. Afterwards the original module objects, and the
        ``transport`` attribute on the parent package, are restored. Leaving
        the fresh copy in place would give later tests a second set of
        transport classes that are distinct objects from the ``HTTPTransport``
        imported at the top of this file, which breaks identity and
        ``isinstance`` checks.
        """
        module_name = "agentkit.mcp.transport"
        evicted = {
            key: sys.modules.pop(key)
            for key in list(sys.modules)
            if key == module_name or key.startswith(module_name + ".")
        }
        parent = sys.modules.get("agentkit.mcp")
        original_attr = getattr(parent, "transport", None)
        try:
            with warnings.catch_warnings(record=True) as warning_list:
                warnings.simplefilter("always")
                importlib.import_module(module_name)
        finally:
            # Put the original module objects back over the freshly imported copy.
            sys.modules.update(evicted)
            if parent is not None and original_attr is not None:
                parent.transport = original_attr

        messages = [
            str(w.message)
            for w in warning_list
            if issubclass(w.category, DeprecationWarning)
        ]
        assert messages, "expected at least one DeprecationWarning"
        # Modules imported as a side effect may warn too, so search every
        # DeprecationWarning instead of assuming ours is the first one.
        assert any("transport" in message for message in messages), messages


# ── timeout 参数透传测试 ─────────────────────────────────────
class TestTimeoutPropagation:
    """The timeout argument is forwarded into the langchain connection config."""

    def test_timeout_passed_to_http_config(self):
        """HTTP transport copies timeout into its config."""
        cfg = MCPClient("http://localhost:8080/mcp", timeout=60)._langchain_config
        assert cfg["timeout"] == 60

    def test_timeout_passed_to_sse_config(self):
        """SSE transport copies timeout into its config."""
        cfg = MCPClient("sse://localhost:8080/sse", timeout=45)._langchain_config
        assert cfg["timeout"] == 45

    def test_default_timeout_is_30(self):
        """Without an explicit timeout, 30 is used."""
        client = MCPClient("http://localhost:8080/mcp")
        assert client._timeout == 30
        assert client._langchain_config["timeout"] == 30

    def test_stdio_config_has_no_timeout(self):
        """The stdio config omits timeout (StdioConnection has no such field)."""
        cfg = MCPClient("stdio://python -m server", timeout=60)._langchain_config
        assert "timeout" not in cfg