feat(mcp): U16 — langchain-mcp-adapters client replacement + transport deprecation
- 重写 MCPClient:URL scheme 自动检测(stdio/http/sse)→ langchain config - 旧 Transport 注入路径保留(DeprecationWarning),向后兼容 - transport.py 模块级弃用警告 - 28 个新测试覆盖 URL 检测、list_tools、call_tool、legacy 路径、ImportError - 修复 manager.py / transport.py 预存 F401/F841
This commit is contained in:
parent
069dbc22b1
commit
86541d7172
|
|
@ -63,6 +63,8 @@ server = [
|
||||||
]
|
]
|
||||||
mcp = [
|
mcp = [
|
||||||
"mcp>=1.0",
|
"mcp>=1.0",
|
||||||
|
# U16 — langchain-mcp-adapters 替换自研 MCP 客户端传输层
|
||||||
|
"langchain-mcp-adapters>=0.1",
|
||||||
]
|
]
|
||||||
evolution = [
|
evolution = [
|
||||||
"scipy>=1.12",
|
"scipy>=1.12",
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,56 @@
|
||||||
"""MCP Client - 调用外部 MCP 工具服务器"""
|
"""MCP Client - 调用外部 MCP 工具服务器
|
||||||
|
|
||||||
|
U16: 重写为 langchain-mcp-adapters 包装层。保留 MCPClient / MCPTool API 向后兼容。
|
||||||
|
旧的 Transport 注入路径保留(发出 DeprecationWarning),供现有调用方过渡。
|
||||||
|
|
||||||
|
传输类型由 URL scheme 自动检测(transport=None 时):
|
||||||
|
- ``stdio://command arg1 arg2`` → stdio 传输
|
||||||
|
- ``http://`` / ``https://`` → streamable_http 传输
|
||||||
|
- ``sse://...`` → sse 传输(自动转为 http://)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
import warnings
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
import httpx
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport
|
from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport
|
||||||
from agentkit.tools.base import Tool
|
from agentkit.tools.base import Tool
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from langchain_mcp_adapters.client import MultiServerMCPClient
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _import_langchain_client() -> type["MultiServerMCPClient"]:
|
||||||
|
"""延迟导入 langchain-mcp-adapters 的 MultiServerMCPClient。
|
||||||
|
|
||||||
|
未安装时抛出带提示信息的 ImportError,便于调用方回退到 Transport 路径。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from langchain_mcp_adapters.client import MultiServerMCPClient as _Client
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"langchain-mcp-adapters 未安装,无法使用 langchain 传输路径。"
|
||||||
|
"请执行 pip install 'fischer-agentkit[mcp]' 或 pip install langchain-mcp-adapters。"
|
||||||
|
"如需使用旧的 Transport 路径,请通过 MCPClient.from_transport(transport) 创建客户端。"
|
||||||
|
) from e
|
||||||
|
return _Client
|
||||||
|
|
||||||
|
|
||||||
class MCPClient:
|
class MCPClient:
|
||||||
"""MCP Client - 连接外部 MCP Server 并调用工具
|
"""MCP Client - 连接外部 MCP Server 并调用工具
|
||||||
|
|
||||||
|
U16: 内部使用 langchain-mcp-adapters 的 MultiServerMCPClient 管理连接。
|
||||||
|
保留旧 Transport 注入路径以向后兼容(发出 DeprecationWarning)。
|
||||||
|
|
||||||
支持两种模式:
|
支持两种模式:
|
||||||
1. 通过 Transport 层发送 JSON-RPC 请求(推荐)
|
1. URL scheme 自动检测(推荐):transport=None,根据 server_url 的 scheme 选择传输
|
||||||
2. 直接 HTTP 调用(向后兼容)
|
2. Transport 注入(旧路径,弃用):传入 Transport 实例,走原有 JSON-RPC 路径
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -25,14 +59,84 @@ class MCPClient:
|
||||||
timeout: int = 30,
|
timeout: int = 30,
|
||||||
transport: Transport | None = None,
|
transport: Transport | None = None,
|
||||||
):
|
):
|
||||||
self._server_url = server_url.rstrip("/")
|
self._server_url = server_url.rstrip("/") if server_url else ""
|
||||||
self._timeout = timeout
|
self._timeout = timeout
|
||||||
self._tools_cache: list[dict] | None = None
|
self._tools_cache: list[dict] | None = None
|
||||||
self._transport = transport
|
self._transport = transport
|
||||||
|
|
||||||
|
if transport is not None:
|
||||||
|
# 旧 Transport 路径 — 发出 DeprecationWarning,但保持原有行为
|
||||||
|
warnings.warn(
|
||||||
|
"通过 Transport 实例创建 MCPClient 已弃用,请使用 URL scheme 自动检测"
|
||||||
|
"(如 MCPClient('http://...') 或 MCPClient('stdio://...'))。"
|
||||||
|
"详见 U16 迁移指南。将在下个版本移除。",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
self._langchain_config: dict[str, Any] | None = None
|
||||||
|
else:
|
||||||
|
# 新 langchain 路径 — 解析 URL scheme 构建连接配置
|
||||||
|
self._langchain_config = self._build_langchain_config(self._server_url, timeout)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_langchain_config(server_url: str, timeout: float) -> dict[str, Any]:
|
||||||
|
"""根据 URL scheme 构建 langchain-mcp-adapters 连接配置。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_url: 服务器 URL,支持 stdio:// http:// https:// sse://
|
||||||
|
timeout: 超时秒数(仅 http/sse 传输使用)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
langchain-mcp-adapters 连接配置 dict
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: URL 为空或 scheme 不支持
|
||||||
|
"""
|
||||||
|
if not server_url:
|
||||||
|
raise ValueError("server_url 不能为空(transport=None 时需要有效 URL)")
|
||||||
|
|
||||||
|
# stdio 特殊处理:stdio://command arg1 arg2 → command + args
|
||||||
|
# 不用 urlparse,因为命令行参数可能含特殊字符
|
||||||
|
if server_url.startswith("stdio://"):
|
||||||
|
raw = server_url[len("stdio://") :]
|
||||||
|
parts = raw.split()
|
||||||
|
if not parts:
|
||||||
|
raise ValueError(f"无效的 stdio URL: {server_url!r}")
|
||||||
|
return {
|
||||||
|
"transport": "stdio",
|
||||||
|
"command": parts[0],
|
||||||
|
"args": parts[1:],
|
||||||
|
}
|
||||||
|
|
||||||
|
parsed = urlparse(server_url)
|
||||||
|
scheme = parsed.scheme.lower()
|
||||||
|
|
||||||
|
if scheme in ("http", "https"):
|
||||||
|
return {
|
||||||
|
"transport": "streamable_http",
|
||||||
|
"url": server_url,
|
||||||
|
"timeout": timeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
if scheme == "sse":
|
||||||
|
# sse://host/path → http://host/path(sse scheme 不携带 TLS 信息)
|
||||||
|
converted = "http://" + server_url[len("sse://") :]
|
||||||
|
return {
|
||||||
|
"transport": "sse",
|
||||||
|
"url": converted,
|
||||||
|
"timeout": timeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
raise ValueError(
|
||||||
|
f"不支持的 URL scheme: {scheme!r}。支持 stdio://, http://, https://, sse://"
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_transport(cls, transport: Transport) -> "MCPClient":
|
def from_transport(cls, transport: Transport) -> "MCPClient":
|
||||||
"""从 Transport 实例创建 MCPClient"""
|
"""从 Transport 实例创建 MCPClient(旧路径,向后兼容)
|
||||||
|
|
||||||
|
发出 DeprecationWarning。建议迁移到 URL scheme 自动检测。
|
||||||
|
"""
|
||||||
if isinstance(transport, HTTPTransport):
|
if isinstance(transport, HTTPTransport):
|
||||||
server_url = transport._endpoint
|
server_url = transport._endpoint
|
||||||
elif isinstance(transport, SSETransport):
|
elif isinstance(transport, SSETransport):
|
||||||
|
|
@ -44,8 +148,16 @@ class MCPClient:
|
||||||
return cls(server_url=server_url, transport=transport)
|
return cls(server_url=server_url, transport=transport)
|
||||||
|
|
||||||
async def list_tools(self) -> list[dict]:
|
async def list_tools(self) -> list[dict]:
|
||||||
"""列出远程 MCP Server 上的工具"""
|
"""列出远程 MCP Server 上的工具
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
工具列表,每项为 ``{"name":..., "description":..., "inputSchema":...}``
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: langchain-mcp-adapters 未安装(仅 transport=None 路径)
|
||||||
|
"""
|
||||||
if self._transport is not None:
|
if self._transport is not None:
|
||||||
|
# 旧 Transport 路径
|
||||||
if not self._transport.is_connected:
|
if not self._transport.is_connected:
|
||||||
await self._transport.connect()
|
await self._transport.connect()
|
||||||
result = await self._transport.send_request("tools/list")
|
result = await self._transport.send_request("tools/list")
|
||||||
|
|
@ -53,16 +165,52 @@ class MCPClient:
|
||||||
self._tools_cache = tools
|
self._tools_cache = tools
|
||||||
return self._tools_cache
|
return self._tools_cache
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
# 新 langchain 路径
|
||||||
response = await client.get(f"{self._server_url}/tools/list")
|
client_cls = _import_langchain_client()
|
||||||
response.raise_for_status()
|
client = client_cls({"server": self._langchain_config})
|
||||||
data = response.json()
|
lc_tools = await client.get_tools()
|
||||||
self._tools_cache = data.get("tools", [])
|
tools = [
|
||||||
return self._tools_cache
|
{
|
||||||
|
"name": t.name,
|
||||||
|
"description": t.description or "",
|
||||||
|
"inputSchema": self._extract_schema(t),
|
||||||
|
}
|
||||||
|
for t in lc_tools
|
||||||
|
]
|
||||||
|
self._tools_cache = tools
|
||||||
|
return tools
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_schema(tool: Any) -> dict[str, Any]:
|
||||||
|
"""从 LangChain Tool 提取 inputSchema(JSON Schema 格式)。
|
||||||
|
|
||||||
|
LangChain 工具通常有 args_schema(pydantic model),回退到 tool.args dict。
|
||||||
|
"""
|
||||||
|
args_schema = getattr(tool, "args_schema", None)
|
||||||
|
if args_schema is not None and hasattr(args_schema, "model_json_schema"):
|
||||||
|
return args_schema.model_json_schema()
|
||||||
|
# ponytail: tool.args 在 langchain BaseTool 上是 dict,够用作回退
|
||||||
|
args = getattr(tool, "args", None)
|
||||||
|
if isinstance(args, dict):
|
||||||
|
return args
|
||||||
|
return {"type": "object", "properties": {}}
|
||||||
|
|
||||||
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
|
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
|
||||||
"""调用远程 MCP 工具"""
|
"""调用远程 MCP 工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_name: 工具名称
|
||||||
|
arguments: 工具参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MCP 响应格式的 dict:``{"content": [{"type":"text","text":...}]}``
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError: 工具不存在(仅 transport=None 路径)
|
||||||
|
ImportError: langchain-mcp-adapters 未安装
|
||||||
|
"""
|
||||||
if self._transport is not None:
|
if self._transport is not None:
|
||||||
|
# 旧 Transport 路径
|
||||||
if not self._transport.is_connected:
|
if not self._transport.is_connected:
|
||||||
await self._transport.connect()
|
await self._transport.connect()
|
||||||
return await self._transport.send_request(
|
return await self._transport.send_request(
|
||||||
|
|
@ -70,13 +218,23 @@ class MCPClient:
|
||||||
params={"name": tool_name, "arguments": arguments},
|
params={"name": tool_name, "arguments": arguments},
|
||||||
)
|
)
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
# 新 langchain 路径
|
||||||
response = await client.post(
|
client_cls = _import_langchain_client()
|
||||||
f"{self._server_url}/tools/call",
|
client = client_cls({"server": self._langchain_config})
|
||||||
json={"name": tool_name, "arguments": arguments},
|
lc_tools = await client.get_tools()
|
||||||
)
|
for tool in lc_tools:
|
||||||
response.raise_for_status()
|
if tool.name == tool_name:
|
||||||
return response.json()
|
result = await tool.ainvoke(arguments)
|
||||||
|
# 包装为 MCP 响应格式,保持 call_tool 返回形状与旧路径一致
|
||||||
|
return {
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": json.dumps(result, ensure_ascii=False, default=str),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
raise KeyError(f"MCP 工具不存在: {tool_name!r}")
|
||||||
|
|
||||||
def as_tool(self, tool_name: str, description: str = "") -> "MCPTool":
|
def as_tool(self, tool_name: str, description: str = "") -> "MCPTool":
|
||||||
"""将远程 MCP 工具包装为本地 Tool 对象"""
|
"""将远程 MCP 工具包装为本地 Tool 对象"""
|
||||||
|
|
@ -116,7 +274,6 @@ class MCPTool(Tool):
|
||||||
if "content" in result:
|
if "content" in result:
|
||||||
for item in result["content"]:
|
for item in result["content"]:
|
||||||
if item.get("type") == "text":
|
if item.get("type") == "text":
|
||||||
import json
|
|
||||||
try:
|
try:
|
||||||
return json.loads(item["text"])
|
return json.loads(item["text"])
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from agentkit.mcp.client import MCPClient
|
from agentkit.mcp.client import MCPClient
|
||||||
from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport
|
from agentkit.mcp.transport import HTTPTransport, SSETransport, StdioTransport, Transport
|
||||||
|
|
@ -39,10 +39,7 @@ class MCPManager:
|
||||||
|
|
||||||
使用 asyncio.gather 并发启动,单个服务器失败不影响其他服务器。
|
使用 asyncio.gather 并发启动,单个服务器失败不影响其他服务器。
|
||||||
"""
|
"""
|
||||||
tasks = [
|
tasks = [self._start_server_safe(name, config) for name, config in self._configs.items()]
|
||||||
self._start_server_safe(name, config)
|
|
||||||
for name, config in self._configs.items()
|
|
||||||
]
|
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
async def _start_server_safe(self, name: str, config: MCPServerConfig) -> None:
|
async def _start_server_safe(self, name: str, config: MCPServerConfig) -> None:
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,16 @@
|
||||||
"""MCP Transport - 传输层抽象
|
"""MCP Transport - 传输层抽象
|
||||||
|
|
||||||
提供 MCP 协议的传输层实现,支持 Streamable HTTP、SSE 和 Stdio 三种传输方式。
|
提供 MCP 协议的传输层实现,支持 Streamable HTTP、SSE 和 Stdio 三种传输方式。
|
||||||
|
|
||||||
|
U16: 本模块已弃用。请通过 ``MCPClient(url)`` 使用 langchain-mcp-adapters 传输层。
|
||||||
|
保留所有类以向后兼容,将在下个版本移除。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
@ -14,6 +18,14 @@ import httpx
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# U16: 模块级弃用警告 — 导入时发出,提示迁移到 langchain-mcp-adapters
|
||||||
|
warnings.warn(
|
||||||
|
"agentkit.mcp.transport 已弃用 — 请通过 MCPClient(url) 使用 langchain-mcp-adapters 传输层。"
|
||||||
|
"详见 U16 迁移指南。将在下个版本移除。",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TransportError(Exception):
|
class TransportError(Exception):
|
||||||
"""传输层错误"""
|
"""传输层错误"""
|
||||||
|
|
@ -158,9 +170,7 @@ class HTTPTransport(Transport):
|
||||||
# 检查 JSON-RPC 错误
|
# 检查 JSON-RPC 错误
|
||||||
if "error" in data:
|
if "error" in data:
|
||||||
error = data["error"]
|
error = data["error"]
|
||||||
raise TransportError(
|
raise TransportError(f"JSON-RPC error {error.get('code')}: {error.get('message')}")
|
||||||
f"JSON-RPC error {error.get('code')}: {error.get('message')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return data.get("result")
|
return data.get("result")
|
||||||
|
|
||||||
|
|
@ -264,7 +274,7 @@ class SSETransport(Transport):
|
||||||
if not line or line.startswith(":"):
|
if not line or line.startswith(":"):
|
||||||
continue
|
continue
|
||||||
if line.startswith("data:"):
|
if line.startswith("data:"):
|
||||||
data_str = line[len("data:"):].strip()
|
data_str = line[len("data:") :].strip()
|
||||||
try:
|
try:
|
||||||
data = json.loads(data_str)
|
data = json.loads(data_str)
|
||||||
await self._response_queue.put(data)
|
await self._response_queue.put(data)
|
||||||
|
|
@ -328,9 +338,7 @@ class SSETransport(Transport):
|
||||||
# 检查 JSON-RPC 错误
|
# 检查 JSON-RPC 错误
|
||||||
if "error" in data:
|
if "error" in data:
|
||||||
error = data["error"]
|
error = data["error"]
|
||||||
raise TransportError(
|
raise TransportError(f"JSON-RPC error {error.get('code')}: {error.get('message')}")
|
||||||
f"JSON-RPC error {error.get('code')}: {error.get('message')}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return data.get("result")
|
return data.get("result")
|
||||||
|
|
||||||
|
|
@ -382,11 +390,7 @@ class StdioTransport(Transport):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def is_connected(self) -> bool:
|
def is_connected(self) -> bool:
|
||||||
return (
|
return self._connected and self._process is not None and self._process.returncode is None
|
||||||
self._connected
|
|
||||||
and self._process is not None
|
|
||||||
and self._process.returncode is None
|
|
||||||
)
|
|
||||||
|
|
||||||
def _next_request_id(self) -> int:
|
def _next_request_id(self) -> int:
|
||||||
"""生成下一个请求 ID"""
|
"""生成下一个请求 ID"""
|
||||||
|
|
@ -427,7 +431,7 @@ class StdioTransport(Transport):
|
||||||
|
|
||||||
# 发送 initialize 请求并等待响应
|
# 发送 initialize 请求并等待响应
|
||||||
try:
|
try:
|
||||||
init_result = await asyncio.wait_for(
|
await asyncio.wait_for(
|
||||||
self._send_request_internal(
|
self._send_request_internal(
|
||||||
"initialize",
|
"initialize",
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,391 @@
|
||||||
|
"""U16 — MCPClient langchain-mcp-adapters 路径单元测试。
|
||||||
|
|
||||||
|
覆盖场景:
|
||||||
|
1. URL scheme 检测(http / sse / stdio)
|
||||||
|
2. list_tools via langchain(三种传输)
|
||||||
|
3. call_tool via langchain(正常 / 未知工具)
|
||||||
|
4. 旧 Transport 路径兼容(DeprecationWarning)
|
||||||
|
5. MCPTool.execute via 新客户端(JSON / 纯文本 / 无 content)
|
||||||
|
6. langchain 未安装时 ImportError
|
||||||
|
7. transport.py 模块级 DeprecationWarning
|
||||||
|
8. timeout 参数透传
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.mcp.client import MCPClient, MCPTool
|
||||||
|
from agentkit.mcp.transport import HTTPTransport
|
||||||
|
|
||||||
|
|
||||||
|
# ── 辅助工具 ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_mock_langchain_tool(
|
||||||
|
name: str = "my_tool",
|
||||||
|
description: str = "Test tool",
|
||||||
|
result: Any = None,
|
||||||
|
) -> MagicMock:
|
||||||
|
"""构造一个假的 LangChain Tool 对象。"""
|
||||||
|
tool = MagicMock()
|
||||||
|
tool.name = name
|
||||||
|
tool.description = description
|
||||||
|
tool.args_schema = MagicMock()
|
||||||
|
tool.args_schema.model_json_schema.return_value = {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
}
|
||||||
|
tool.ainvoke = AsyncMock(return_value=result if result is not None else {"result": "ok"})
|
||||||
|
return tool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_langchain_client() -> Any:
|
||||||
|
"""Patch MultiServerMCPClient,返回带 get_tools 的 mock 实例。
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
(mock_class, mock_instance) 二元组
|
||||||
|
"""
|
||||||
|
with patch("agentkit.mcp.client._import_langchain_client") as mock_import:
|
||||||
|
mock_cls = MagicMock(name="MultiServerMCPClient")
|
||||||
|
mock_instance = MagicMock(name="mcp_client_instance")
|
||||||
|
mock_instance.get_tools = AsyncMock(return_value=[])
|
||||||
|
mock_cls.return_value = mock_instance
|
||||||
|
mock_import.return_value = mock_cls
|
||||||
|
yield mock_cls, mock_instance
|
||||||
|
|
||||||
|
|
||||||
|
# ── URL scheme 检测测试 ───────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestURLSchemeDetection:
|
||||||
|
"""MCPClient 根据 URL scheme 自动检测传输类型。"""
|
||||||
|
|
||||||
|
def test_http_url_uses_streamable_http(self):
|
||||||
|
"""http:// URL 应使用 streamable_http 传输。"""
|
||||||
|
client = MCPClient("http://localhost:8080/mcp")
|
||||||
|
config = client._langchain_config
|
||||||
|
assert config is not None
|
||||||
|
assert config["transport"] == "streamable_http"
|
||||||
|
assert config["url"] == "http://localhost:8080/mcp"
|
||||||
|
|
||||||
|
def test_https_url_uses_streamable_http(self):
|
||||||
|
"""https:// URL 应使用 streamable_http 传输。"""
|
||||||
|
client = MCPClient("https://remote.example.com/mcp")
|
||||||
|
config = client._langchain_config
|
||||||
|
assert config["transport"] == "streamable_http"
|
||||||
|
assert config["url"] == "https://remote.example.com/mcp"
|
||||||
|
|
||||||
|
def test_sse_url_uses_sse_transport(self):
|
||||||
|
"""sse:// URL 应使用 sse 传输,并转为 http://。"""
|
||||||
|
client = MCPClient("sse://localhost:8080/sse")
|
||||||
|
config = client._langchain_config
|
||||||
|
assert config["transport"] == "sse"
|
||||||
|
assert config["url"] == "http://localhost:8080/sse"
|
||||||
|
|
||||||
|
def test_stdio_url_parses_command_and_args(self):
|
||||||
|
"""stdio:// URL 应解析为 command + args。"""
|
||||||
|
client = MCPClient("stdio://python -m my_mcp_server")
|
||||||
|
config = client._langchain_config
|
||||||
|
assert config["transport"] == "stdio"
|
||||||
|
assert config["command"] == "python"
|
||||||
|
assert config["args"] == ["-m", "my_mcp_server"]
|
||||||
|
|
||||||
|
def test_stdio_url_no_args(self):
|
||||||
|
"""stdio:// 只有 command 无参数。"""
|
||||||
|
client = MCPClient("stdio://./server")
|
||||||
|
config = client._langchain_config
|
||||||
|
assert config["command"] == "./server"
|
||||||
|
assert config["args"] == []
|
||||||
|
|
||||||
|
def test_unsupported_scheme_raises(self):
|
||||||
|
"""不支持的 scheme 应抛出 ValueError。"""
|
||||||
|
with pytest.raises(ValueError, match="不支持的 URL scheme"):
|
||||||
|
MCPClient("ftp://example.com/server")
|
||||||
|
|
||||||
|
def test_empty_url_raises(self):
|
||||||
|
"""空 URL 应抛出 ValueError。"""
|
||||||
|
with pytest.raises(ValueError, match="server_url 不能为空"):
|
||||||
|
MCPClient("")
|
||||||
|
|
||||||
|
|
||||||
|
# ── list_tools via langchain 测试 ─────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestListToolsViaLangchain:
|
||||||
|
"""list_tools 通过 langchain-mcp-adapters 获取工具列表。"""
|
||||||
|
|
||||||
|
async def test_list_tools_http(self, mock_langchain_client):
|
||||||
|
"""HTTP 传输 — list_tools 返回工具字典列表。"""
|
||||||
|
mock_cls, mock_instance = mock_langchain_client
|
||||||
|
mock_instance.get_tools = AsyncMock(
|
||||||
|
return_value=[
|
||||||
|
_make_mock_langchain_tool("search", "Search tool"),
|
||||||
|
_make_mock_langchain_tool("calc", "Calculator"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
client = MCPClient("http://localhost:8080/mcp")
|
||||||
|
tools = await client.list_tools()
|
||||||
|
|
||||||
|
assert len(tools) == 2
|
||||||
|
assert tools[0]["name"] == "search"
|
||||||
|
assert tools[0]["description"] == "Search tool"
|
||||||
|
assert "inputSchema" in tools[0]
|
||||||
|
assert tools[1]["name"] == "calc"
|
||||||
|
|
||||||
|
async def test_list_tools_stdio(self, mock_langchain_client):
|
||||||
|
"""Stdio 传输 — list_tools 返回工具字典列表。"""
|
||||||
|
mock_cls, mock_instance = mock_langchain_client
|
||||||
|
mock_instance.get_tools = AsyncMock(
|
||||||
|
return_value=[_make_mock_langchain_tool("read_file", "Read file")]
|
||||||
|
)
|
||||||
|
|
||||||
|
client = MCPClient("stdio://python -m fs_server")
|
||||||
|
tools = await client.list_tools()
|
||||||
|
|
||||||
|
assert len(tools) == 1
|
||||||
|
assert tools[0]["name"] == "read_file"
|
||||||
|
|
||||||
|
async def test_list_tools_sse(self, mock_langchain_client):
|
||||||
|
"""SSE 传输 — list_tools 返回工具字典列表。"""
|
||||||
|
mock_cls, mock_instance = mock_langchain_client
|
||||||
|
mock_instance.get_tools = AsyncMock(
|
||||||
|
return_value=[_make_mock_langchain_tool("query", "Query data")]
|
||||||
|
)
|
||||||
|
|
||||||
|
client = MCPClient("sse://localhost:8080/sse")
|
||||||
|
tools = await client.list_tools()
|
||||||
|
|
||||||
|
assert len(tools) == 1
|
||||||
|
assert tools[0]["name"] == "query"
|
||||||
|
|
||||||
|
async def test_list_tools_caches_result(self, mock_langchain_client):
|
||||||
|
"""list_tools 结果应缓存到 _tools_cache。"""
|
||||||
|
_, mock_instance = mock_langchain_client
|
||||||
|
mock_instance.get_tools = AsyncMock(return_value=[_make_mock_langchain_tool("tool1")])
|
||||||
|
|
||||||
|
client = MCPClient("http://localhost:8080/mcp")
|
||||||
|
tools = await client.list_tools()
|
||||||
|
|
||||||
|
assert client._tools_cache == tools
|
||||||
|
assert client._tools_cache[0]["name"] == "tool1"
|
||||||
|
|
||||||
|
async def test_list_tools_empty(self, mock_langchain_client):
|
||||||
|
"""空工具列表应返回 []。"""
|
||||||
|
_, mock_instance = mock_langchain_client
|
||||||
|
mock_instance.get_tools = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
client = MCPClient("http://localhost:8080/mcp")
|
||||||
|
tools = await client.list_tools()
|
||||||
|
assert tools == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── call_tool via langchain 测试 ──────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCallToolViaLangchain:
|
||||||
|
"""call_tool 通过 langchain-mcp-adapters 调用工具。"""
|
||||||
|
|
||||||
|
async def test_call_tool_returns_mcp_format(self, mock_langchain_client):
|
||||||
|
"""call_tool 应返回 MCP 响应格式 {"content": [{"type":"text","text":...}]}。"""
|
||||||
|
_, mock_instance = mock_langchain_client
|
||||||
|
mock_instance.get_tools = AsyncMock(
|
||||||
|
return_value=[_make_mock_langchain_tool("my_tool", result={"result": "hello"})]
|
||||||
|
)
|
||||||
|
|
||||||
|
client = MCPClient("http://localhost:8080/mcp")
|
||||||
|
result = await client.call_tool("my_tool", {"input": "x"})
|
||||||
|
|
||||||
|
assert "content" in result
|
||||||
|
assert result["content"][0]["type"] == "text"
|
||||||
|
assert '"result": "hello"' in result["content"][0]["text"]
|
||||||
|
|
||||||
|
async def test_call_tool_invokes_ainvoke(self, mock_langchain_client):
|
||||||
|
"""call_tool 应调用 LangChain Tool 的 ainvoke。"""
|
||||||
|
_, mock_instance = mock_langchain_client
|
||||||
|
mock_tool = _make_mock_langchain_tool("my_tool", result={"ok": True})
|
||||||
|
mock_instance.get_tools = AsyncMock(return_value=[mock_tool])
|
||||||
|
|
||||||
|
client = MCPClient("http://localhost:8080/mcp")
|
||||||
|
await client.call_tool("my_tool", {"input": "x"})
|
||||||
|
|
||||||
|
mock_tool.ainvoke.assert_called_once_with({"input": "x"})
|
||||||
|
|
||||||
|
async def test_call_tool_unknown_raises(self, mock_langchain_client):
|
||||||
|
"""调用不存在的工具应抛出 KeyError。"""
|
||||||
|
_, mock_instance = mock_langchain_client
|
||||||
|
mock_instance.get_tools = AsyncMock(return_value=[])
|
||||||
|
|
||||||
|
client = MCPClient("http://localhost:8080/mcp")
|
||||||
|
with pytest.raises(KeyError, match="nonexistent"):
|
||||||
|
await client.call_tool("nonexistent", {})
|
||||||
|
|
||||||
|
|
||||||
|
# ── 旧 Transport 路径兼容测试 ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestLegacyTransportPath:
|
||||||
|
"""旧 Transport 注入路径应保持向后兼容(发出 DeprecationWarning)。"""
|
||||||
|
|
||||||
|
def test_from_transport_emits_deprecation_warning(self):
|
||||||
|
"""from_transport 应发出 DeprecationWarning。"""
|
||||||
|
transport = HTTPTransport(endpoint="http://localhost:8080/mcp")
|
||||||
|
with pytest.warns(DeprecationWarning, match="Transport"):
|
||||||
|
client = MCPClient.from_transport(transport)
|
||||||
|
assert client._transport is transport
|
||||||
|
assert client._langchain_config is None
|
||||||
|
|
||||||
|
def test_init_with_transport_emits_deprecation_warning(self):
|
||||||
|
"""__init__ 传入 transport 应发出 DeprecationWarning。"""
|
||||||
|
transport = HTTPTransport(endpoint="http://localhost:8080/mcp")
|
||||||
|
with pytest.warns(DeprecationWarning):
|
||||||
|
client = MCPClient(server_url="http://localhost:8080", transport=transport)
|
||||||
|
assert client._transport is transport
|
||||||
|
|
||||||
|
async def test_legacy_list_tools_works(self):
|
||||||
|
"""旧 Transport 路径的 list_tools 仍应正常工作。"""
|
||||||
|
transport = HTTPTransport(endpoint="http://localhost:8080")
|
||||||
|
transport._client = MagicMock()
|
||||||
|
transport._client.is_closed = False
|
||||||
|
transport.send_request = AsyncMock(
|
||||||
|
return_value={"tools": [{"name": "legacy_tool", "description": "Legacy"}]}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.warns(DeprecationWarning):
|
||||||
|
client = MCPClient.from_transport(transport)
|
||||||
|
|
||||||
|
tools = await client.list_tools()
|
||||||
|
assert len(tools) == 1
|
||||||
|
assert tools[0]["name"] == "legacy_tool"
|
||||||
|
transport.send_request.assert_called_once_with("tools/list")
|
||||||
|
|
||||||
|
|
||||||
|
# ── MCPTool.execute 测试 ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMCPToolExecute:
|
||||||
|
"""MCPTool.execute 通过新客户端调用工具并解析响应。"""
|
||||||
|
|
||||||
|
async def test_execute_parses_json_text(self, mock_langchain_client):
|
||||||
|
"""execute 应解析 content[0].text 中的 JSON。"""
|
||||||
|
_, mock_instance = mock_langchain_client
|
||||||
|
mock_instance.get_tools = AsyncMock(
|
||||||
|
return_value=[_make_mock_langchain_tool("my_tool", result={"answer": 42})]
|
||||||
|
)
|
||||||
|
|
||||||
|
client = MCPClient("http://localhost:8080/mcp")
|
||||||
|
tool = client.as_tool("my_tool", "desc")
|
||||||
|
result = await tool.execute(input="x")
|
||||||
|
|
||||||
|
assert result == {"answer": 42}
|
||||||
|
|
||||||
|
async def test_execute_non_json_text(self):
|
||||||
|
"""content[0].text 为纯文本时返回 {"result": text}。"""
|
||||||
|
client = MCPClient("http://localhost:8080/mcp")
|
||||||
|
client.call_tool = AsyncMock( # type: ignore[method-assign]
|
||||||
|
return_value={"content": [{"type": "text", "text": "plain text"}]}
|
||||||
|
)
|
||||||
|
tool = client.as_tool("echo", "desc")
|
||||||
|
|
||||||
|
result = await tool.execute(msg="hello")
|
||||||
|
assert result == {"result": "plain text"}
|
||||||
|
|
||||||
|
async def test_execute_no_content_field(self):
|
||||||
|
"""无 content 字段时返回原始 dict。"""
|
||||||
|
client = MCPClient("http://localhost:8080/mcp")
|
||||||
|
client.call_tool = AsyncMock( # type: ignore[method-assign]
|
||||||
|
return_value={"unexpected": "shape"}
|
||||||
|
)
|
||||||
|
tool = client.as_tool("status", "desc")
|
||||||
|
|
||||||
|
result = await tool.execute()
|
||||||
|
assert result == {"unexpected": "shape"}
|
||||||
|
|
||||||
|
async def test_as_tool_returns_mcp_tool(self):
|
||||||
|
"""as_tool 应返回 MCPTool 实例。"""
|
||||||
|
client = MCPClient("http://localhost:8080/mcp")
|
||||||
|
tool = client.as_tool("search", "Search the web")
|
||||||
|
|
||||||
|
assert isinstance(tool, MCPTool)
|
||||||
|
assert tool.name == "search"
|
||||||
|
assert tool.description == "Search the web"
|
||||||
|
assert tool._client is client
|
||||||
|
assert "mcp" in tool.tags
|
||||||
|
|
||||||
|
|
||||||
|
# ── langchain 未安装测试 ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestLangchainNotInstalled:
|
||||||
|
"""langchain-mcp-adapters 未安装时的错误处理。"""
|
||||||
|
|
||||||
|
async def test_import_error_when_langchain_missing(self):
|
||||||
|
"""langchain 未安装时 list_tools 应抛出 ImportError。"""
|
||||||
|
client = MCPClient("http://localhost:8080/mcp")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"agentkit.mcp.client._import_langchain_client",
|
||||||
|
side_effect=ImportError("langchain-mcp-adapters 未安装"),
|
||||||
|
):
|
||||||
|
with pytest.raises(ImportError, match="langchain-mcp-adapters"):
|
||||||
|
await client.list_tools()
|
||||||
|
|
||||||
|
|
||||||
|
# ── transport.py 模块级 DeprecationWarning 测试 ──────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestTransportDeprecation:
|
||||||
|
"""transport.py 导入时应发出 DeprecationWarning。"""
|
||||||
|
|
||||||
|
def test_transport_module_emits_deprecation_warning(self):
|
||||||
|
"""导入 agentkit.mcp.transport 应触发 DeprecationWarning。"""
|
||||||
|
# 先从 sys.modules 移除,强制重新导入
|
||||||
|
mods_to_remove = [k for k in sys.modules if k.startswith("agentkit.mcp.transport")]
|
||||||
|
for mod in mods_to_remove:
|
||||||
|
del sys.modules[mod]
|
||||||
|
|
||||||
|
with warnings.catch_warnings(record=True) as warning_list:
|
||||||
|
warnings.simplefilter("always")
|
||||||
|
importlib.import_module("agentkit.mcp.transport")
|
||||||
|
|
||||||
|
deprecation_warnings = [
|
||||||
|
w for w in warning_list if issubclass(w.category, DeprecationWarning)
|
||||||
|
]
|
||||||
|
assert len(deprecation_warnings) >= 1
|
||||||
|
assert "transport" in str(deprecation_warnings[0].message)
|
||||||
|
|
||||||
|
|
||||||
|
# ── timeout 参数透传测试 ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestTimeoutPropagation:
|
||||||
|
"""timeout 参数应透传到 langchain 连接配置。"""
|
||||||
|
|
||||||
|
def test_timeout_passed_to_http_config(self):
|
||||||
|
"""HTTP 传输应将 timeout 写入配置。"""
|
||||||
|
client = MCPClient("http://localhost:8080/mcp", timeout=60)
|
||||||
|
assert client._langchain_config["timeout"] == 60
|
||||||
|
|
||||||
|
def test_timeout_passed_to_sse_config(self):
|
||||||
|
"""SSE 传输应将 timeout 写入配置。"""
|
||||||
|
client = MCPClient("sse://localhost:8080/sse", timeout=45)
|
||||||
|
assert client._langchain_config["timeout"] == 45
|
||||||
|
|
||||||
|
def test_default_timeout_is_30(self):
|
||||||
|
"""默认 timeout 为 30。"""
|
||||||
|
client = MCPClient("http://localhost:8080/mcp")
|
||||||
|
assert client._timeout == 30
|
||||||
|
assert client._langchain_config["timeout"] == 30
|
||||||
|
|
||||||
|
def test_stdio_config_has_no_timeout(self):
|
||||||
|
"""stdio 传输配置不含 timeout(StdioConnection 无此字段)。"""
|
||||||
|
client = MCPClient("stdio://python -m server", timeout=60)
|
||||||
|
assert "timeout" not in client._langchain_config
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""MCP Client 单元测试"""
|
"""MCP Client 单元测试"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
|
|
@ -149,112 +150,12 @@ class TestMCPClientTransportMode:
|
||||||
await transport.disconnect()
|
await transport.disconnect()
|
||||||
|
|
||||||
|
|
||||||
# ── MCPClient 直接 HTTP 模式测试 ────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestMCPClientDirectHTTP:
|
|
||||||
"""MCPClient 直接 HTTP 模式测试(无 Transport)"""
|
|
||||||
|
|
||||||
async def test_list_tools_direct_http(self, httpx_mock):
|
|
||||||
httpx_mock.add_response(
|
|
||||||
url="http://localhost:8080/tools/list",
|
|
||||||
json={
|
|
||||||
"tools": [
|
|
||||||
{"name": "search", "description": "Search tool"},
|
|
||||||
]
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
client = MCPClient(server_url="http://localhost:8080")
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
assert len(tools) == 1
|
|
||||||
assert tools[0]["name"] == "search"
|
|
||||||
assert client._tools_cache == tools
|
|
||||||
|
|
||||||
async def test_call_tool_direct_http(self, httpx_mock):
|
|
||||||
httpx_mock.add_response(
|
|
||||||
url="http://localhost:8080/tools/call",
|
|
||||||
json={"result": "computed value"},
|
|
||||||
)
|
|
||||||
|
|
||||||
client = MCPClient(server_url="http://localhost:8080")
|
|
||||||
result = await client.call_tool("compute", {"x": 42})
|
|
||||||
|
|
||||||
assert result == {"result": "computed value"}
|
|
||||||
|
|
||||||
# 验证请求体
|
|
||||||
request = httpx_mock.get_request()
|
|
||||||
body = json.loads(request.content)
|
|
||||||
assert body["name"] == "compute"
|
|
||||||
assert body["arguments"] == {"x": 42}
|
|
||||||
|
|
||||||
async def test_list_tools_caches_result(self, httpx_mock):
|
|
||||||
httpx_mock.add_response(
|
|
||||||
url="http://localhost:8080/tools/list",
|
|
||||||
json={"tools": [{"name": "tool1"}]},
|
|
||||||
)
|
|
||||||
|
|
||||||
client = MCPClient(server_url="http://localhost:8080")
|
|
||||||
tools = await client.list_tools()
|
|
||||||
|
|
||||||
# 验证缓存被设置
|
|
||||||
assert client._tools_cache == tools
|
|
||||||
assert client._tools_cache[0]["name"] == "tool1"
|
|
||||||
|
|
||||||
async def test_call_tool_sends_post_request(self, httpx_mock):
|
|
||||||
httpx_mock.add_response(
|
|
||||||
url="http://localhost:8080/tools/call",
|
|
||||||
json={"output": "done"},
|
|
||||||
)
|
|
||||||
|
|
||||||
client = MCPClient(server_url="http://localhost:8080")
|
|
||||||
await client.call_tool("my_tool", {"arg": "val"})
|
|
||||||
|
|
||||||
request = httpx_mock.get_request()
|
|
||||||
assert request.method == "POST"
|
|
||||||
|
|
||||||
|
|
||||||
# ── MCPClient 连接错误处理测试 ──────────────────────────────────
|
# ── MCPClient 连接错误处理测试 ──────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class TestMCPClientErrorHandling:
|
class TestMCPClientErrorHandling:
|
||||||
"""MCPClient 连接错误处理测试"""
|
"""MCPClient 连接错误处理测试"""
|
||||||
|
|
||||||
async def test_list_tools_http_error(self, httpx_mock):
|
|
||||||
httpx_mock.add_response(
|
|
||||||
url="http://localhost:8080/tools/list",
|
|
||||||
status_code=500,
|
|
||||||
)
|
|
||||||
|
|
||||||
client = MCPClient(server_url="http://localhost:8080")
|
|
||||||
with pytest.raises(httpx.HTTPStatusError):
|
|
||||||
await client.list_tools()
|
|
||||||
|
|
||||||
async def test_call_tool_http_error(self, httpx_mock):
|
|
||||||
httpx_mock.add_response(
|
|
||||||
url="http://localhost:8080/tools/call",
|
|
||||||
status_code=404,
|
|
||||||
)
|
|
||||||
|
|
||||||
client = MCPClient(server_url="http://localhost:8080")
|
|
||||||
with pytest.raises(httpx.HTTPStatusError):
|
|
||||||
await client.call_tool("missing_tool", {})
|
|
||||||
|
|
||||||
async def test_list_tools_connection_error(self, httpx_mock):
|
|
||||||
httpx_mock.add_exception(httpx.ConnectError("Connection refused"))
|
|
||||||
|
|
||||||
client = MCPClient(server_url="http://localhost:8080")
|
|
||||||
with pytest.raises(httpx.ConnectError):
|
|
||||||
await client.list_tools()
|
|
||||||
|
|
||||||
async def test_call_tool_connection_error(self, httpx_mock):
|
|
||||||
httpx_mock.add_exception(httpx.ConnectError("Connection refused"))
|
|
||||||
|
|
||||||
client = MCPClient(server_url="http://localhost:8080")
|
|
||||||
with pytest.raises(httpx.ConnectError):
|
|
||||||
await client.call_tool("any_tool", {})
|
|
||||||
|
|
||||||
async def test_transport_error_propagates(self, httpx_mock):
|
async def test_transport_error_propagates(self, httpx_mock):
|
||||||
httpx_mock.add_exception(httpx.ConnectError("Connection refused"))
|
httpx_mock.add_exception(httpx.ConnectError("Connection refused"))
|
||||||
|
|
||||||
|
|
@ -355,41 +256,38 @@ class TestMCPTool:
|
||||||
assert tool._client is client
|
assert tool._client is client
|
||||||
assert "mcp" in tool.tags
|
assert "mcp" in tool.tags
|
||||||
|
|
||||||
async def test_mcp_tool_execute_text_content(self, httpx_mock):
|
async def test_mcp_tool_execute_text_content(self):
|
||||||
httpx_mock.add_response(
|
"""execute 应解析 content[0].text 中的 JSON。"""
|
||||||
url="http://localhost:8080/tools/call",
|
|
||||||
json={
|
|
||||||
"content": [{"type": "text", "text": '{"answer": 42}'}],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
client = MCPClient(server_url="http://localhost:8080")
|
client = MCPClient(server_url="http://localhost:8080")
|
||||||
|
client.call_tool = AsyncMock( # type: ignore[method-assign]
|
||||||
|
return_value={
|
||||||
|
"content": [{"type": "text", "text": '{"answer": 42}'}],
|
||||||
|
}
|
||||||
|
)
|
||||||
tool = client.as_tool("ask", description="Ask a question")
|
tool = client.as_tool("ask", description="Ask a question")
|
||||||
|
|
||||||
result = await tool.execute(question="meaning of life")
|
result = await tool.execute(question="meaning of life")
|
||||||
assert result == {"answer": 42}
|
assert result == {"answer": 42}
|
||||||
|
|
||||||
async def test_mcp_tool_execute_non_json_text(self, httpx_mock):
|
async def test_mcp_tool_execute_non_json_text(self):
|
||||||
httpx_mock.add_response(
|
"""content[0].text 为纯文本时返回 {"result": text}。"""
|
||||||
url="http://localhost:8080/tools/call",
|
|
||||||
json={
|
|
||||||
"content": [{"type": "text", "text": "plain text response"}],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
client = MCPClient(server_url="http://localhost:8080")
|
client = MCPClient(server_url="http://localhost:8080")
|
||||||
|
client.call_tool = AsyncMock( # type: ignore[method-assign]
|
||||||
|
return_value={
|
||||||
|
"content": [{"type": "text", "text": "plain text response"}],
|
||||||
|
}
|
||||||
|
)
|
||||||
tool = client.as_tool("echo", description="Echo input")
|
tool = client.as_tool("echo", description="Echo input")
|
||||||
|
|
||||||
result = await tool.execute(msg="hello")
|
result = await tool.execute(msg="hello")
|
||||||
assert result == {"result": "plain text response"}
|
assert result == {"result": "plain text response"}
|
||||||
|
|
||||||
async def test_mcp_tool_execute_no_content(self, httpx_mock):
|
async def test_mcp_tool_execute_no_content(self):
|
||||||
httpx_mock.add_response(
|
"""无 content 字段时返回原始 dict。"""
|
||||||
url="http://localhost:8080/tools/call",
|
|
||||||
json={"status": "ok", "data": "some data"},
|
|
||||||
)
|
|
||||||
|
|
||||||
client = MCPClient(server_url="http://localhost:8080")
|
client = MCPClient(server_url="http://localhost:8080")
|
||||||
|
client.call_tool = AsyncMock( # type: ignore[method-assign]
|
||||||
|
return_value={"status": "ok", "data": "some data"}
|
||||||
|
)
|
||||||
tool = client.as_tool("status", description="Check status")
|
tool = client.as_tool("status", description="Check status")
|
||||||
|
|
||||||
result = await tool.execute()
|
result = await tool.execute()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue