397 lines
14 KiB
Python
397 lines
14 KiB
Python
"""MCP Client 单元测试"""
|
||
|
||
import json
|
||
|
||
import httpx
|
||
import pytest
|
||
|
||
from agentkit.mcp.client import MCPClient, MCPTool
|
||
from agentkit.mcp.transport import HTTPTransport, TransportError
|
||
|
||
|
||
# ── MCPClient construction tests ─────────────────────────────────


class TestMCPClientConstruction:
    """Check how MCPClient stores its constructor arguments and defaults."""

    def test_construction_with_server_url(self):
        mcp_client = MCPClient(server_url="http://localhost:8080")

        # Only server_url given: no transport, timeout defaults to 30,
        # and no tools have been cached yet.
        assert mcp_client._server_url == "http://localhost:8080"
        assert mcp_client._transport is None
        assert mcp_client._timeout == 30
        assert mcp_client._tools_cache is None

    def test_construction_strips_trailing_slash(self):
        mcp_client = MCPClient(server_url="http://localhost:8080/")
        # A trailing "/" on the server URL is removed.
        assert mcp_client._server_url == "http://localhost:8080"

    def test_construction_with_custom_timeout(self):
        mcp_client = MCPClient(server_url="http://localhost:8080", timeout=60)
        assert mcp_client._timeout == 60

    def test_construction_with_transport(self):
        http_transport = HTTPTransport(endpoint="http://localhost:8080")
        mcp_client = MCPClient(
            server_url="http://localhost:8080",
            transport=http_transport,
        )
        # The very same transport object is kept by the client.
        assert mcp_client._transport is http_transport

    def test_from_transport_with_http_transport(self):
        http_transport = HTTPTransport(endpoint="http://localhost:8080/mcp")
        mcp_client = MCPClient.from_transport(http_transport)

        assert mcp_client._transport is http_transport
        # server_url mirrors the transport's endpoint.
        assert mcp_client._server_url == "http://localhost:8080/mcp"

    def test_from_transport_preserves_endpoint(self):
        endpoint = "http://remote-server:3000/api"
        mcp_client = MCPClient.from_transport(HTTPTransport(endpoint=endpoint))
        assert mcp_client._server_url == endpoint


# ── MCPClient transport-mode tests ───────────────────────────────


class TestMCPClientTransportMode:
    """MCPClient tests where requests are sent through an HTTPTransport.

    Every test disconnects the transport in a ``finally`` block so that the
    transport is torn down even when an assertion or client call fails;
    otherwise a failing test would leave its connection open for later tests.
    """

    async def test_list_tools_via_transport(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "result": {
                    "tools": [
                        {"name": "echo", "description": "Echo tool"},
                        {"name": "calc", "description": "Calculator"},
                    ]
                },
            },
        )

        transport = HTTPTransport(endpoint="http://localhost:8080")
        client = MCPClient.from_transport(transport)

        try:
            tools = await client.list_tools()
            assert len(tools) == 2
            assert tools[0]["name"] == "echo"
            assert tools[1]["name"] == "calc"

            # The listed tools are also stored in the client's cache.
            assert client._tools_cache == tools
        finally:
            await transport.disconnect()

    async def test_list_tools_transport_auto_connects(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "result": {"tools": [{"name": "search"}]},
            },
        )

        transport = HTTPTransport(endpoint="http://localhost:8080")
        client = MCPClient.from_transport(transport)
        # The transport starts disconnected; list_tools() must connect it.
        assert not transport.is_connected

        try:
            tools = await client.list_tools()
            assert len(tools) == 1
            assert transport.is_connected
        finally:
            await transport.disconnect()

    async def test_call_tool_via_transport(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "result": {
                    "content": [{"type": "text", "text": "hello world"}],
                },
            },
        )

        transport = HTTPTransport(endpoint="http://localhost:8080")
        client = MCPClient.from_transport(transport)

        try:
            result = await client.call_tool("echo", {"msg": "hello world"})
            assert result["content"][0]["text"] == "hello world"

            # The request body must be a JSON-RPC "tools/call" request.
            request = httpx_mock.get_request()
            body = json.loads(request.content)
            assert body["jsonrpc"] == "2.0"
            assert body["method"] == "tools/call"
            assert body["params"]["name"] == "echo"
            assert body["params"]["arguments"] == {"msg": "hello world"}
        finally:
            await transport.disconnect()

    async def test_call_tool_transport_auto_connects(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "result": {"content": []},
            },
        )

        transport = HTTPTransport(endpoint="http://localhost:8080")
        client = MCPClient.from_transport(transport)
        # The transport starts disconnected; call_tool() must connect it.
        assert not transport.is_connected

        try:
            await client.call_tool("test_tool", {})
            assert transport.is_connected
        finally:
            await transport.disconnect()


# ── MCPClient direct-HTTP mode tests ─────────────────────────────


class TestMCPClientDirectHTTP:
    """MCPClient tests without a Transport.

    Here the client talks to ``/tools/list`` and ``/tools/call`` under its
    server URL directly.
    """

    async def test_list_tools_direct_http(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/tools/list",
            json={
                "tools": [
                    {"name": "search", "description": "Search tool"},
                ]
            },
        )

        mcp_client = MCPClient(server_url="http://localhost:8080")
        listed = await mcp_client.list_tools()

        assert len(listed) == 1
        assert listed[0]["name"] == "search"
        assert mcp_client._tools_cache == listed

    async def test_call_tool_direct_http(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/tools/call",
            json={"result": "computed value"},
        )

        mcp_client = MCPClient(server_url="http://localhost:8080")
        outcome = await mcp_client.call_tool("compute", {"x": 42})

        # The decoded response body is returned as-is.
        assert outcome == {"result": "computed value"}

        # Tool name and arguments sit at the top level of the request body.
        payload = json.loads(httpx_mock.get_request().content)
        assert payload["name"] == "compute"
        assert payload["arguments"] == {"x": 42}

    async def test_list_tools_caches_result(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/tools/list",
            json={"tools": [{"name": "tool1"}]},
        )

        mcp_client = MCPClient(server_url="http://localhost:8080")
        listed = await mcp_client.list_tools()

        # After listing, the cache holds a result equal to what was returned.
        cached = mcp_client._tools_cache
        assert cached == listed
        assert cached[0]["name"] == "tool1"

    async def test_call_tool_sends_post_request(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/tools/call",
            json={"output": "done"},
        )

        mcp_client = MCPClient(server_url="http://localhost:8080")
        await mcp_client.call_tool("my_tool", {"arg": "val"})

        assert httpx_mock.get_request().method == "POST"


# ── MCPClient connection error-handling tests ────────────────────


class TestMCPClientErrorHandling:
    """Error propagation from MCPClient.

    In direct-HTTP mode the httpx exceptions are expected to reach the caller
    unchanged. In transport mode a TransportError is expected instead.
    """

    async def test_list_tools_http_error(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/tools/list",
            status_code=500,
        )

        client = MCPClient(server_url="http://localhost:8080")
        with pytest.raises(httpx.HTTPStatusError):
            await client.list_tools()

    async def test_call_tool_http_error(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/tools/call",
            status_code=404,
        )

        client = MCPClient(server_url="http://localhost:8080")
        with pytest.raises(httpx.HTTPStatusError):
            await client.call_tool("missing_tool", {})

    async def test_list_tools_connection_error(self, httpx_mock):
        httpx_mock.add_exception(httpx.ConnectError("Connection refused"))

        client = MCPClient(server_url="http://localhost:8080")
        with pytest.raises(httpx.ConnectError):
            await client.list_tools()

    async def test_call_tool_connection_error(self, httpx_mock):
        httpx_mock.add_exception(httpx.ConnectError("Connection refused"))

        client = MCPClient(server_url="http://localhost:8080")
        with pytest.raises(httpx.ConnectError):
            await client.call_tool("any_tool", {})

    async def test_transport_error_propagates(self, httpx_mock):
        httpx_mock.add_exception(httpx.ConnectError("Connection refused"))

        transport = HTTPTransport(endpoint="http://localhost:8080")
        client = MCPClient.from_transport(transport)
        await transport.connect()

        # Disconnect in ``finally`` so the transport is torn down even when
        # pytest.raises fails (no exception, or a different one).
        try:
            # With a transport, the ConnectError is expected to surface as a
            # TransportError whose message matches "Request failed".
            with pytest.raises(TransportError, match="Request failed"):
                await client.list_tools()
        finally:
            await transport.disconnect()


# ── JSON-RPC 2.0 request-format tests ────────────────────────────


class TestMCPClientJSONRPCFormat:
    """Check the JSON-RPC 2.0 request bodies MCPClient sends via a transport.

    Every test disconnects the transport in a ``finally`` block so that the
    transport is torn down even when an assertion fails.
    """

    async def test_transport_list_tools_request_format(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 1, "result": {"tools": []}},
        )

        transport = HTTPTransport(endpoint="http://localhost:8080")
        client = MCPClient.from_transport(transport)

        try:
            await client.list_tools()

            request = httpx_mock.get_request()
            body = json.loads(request.content)
            assert body["jsonrpc"] == "2.0"
            assert "id" in body
            assert body["method"] == "tools/list"
        finally:
            await transport.disconnect()

    async def test_transport_call_tool_request_format(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 1, "result": {}},
        )

        transport = HTTPTransport(endpoint="http://localhost:8080")
        client = MCPClient.from_transport(transport)

        try:
            await client.call_tool("search", {"query": "test"})

            request = httpx_mock.get_request()
            body = json.loads(request.content)
            assert body["jsonrpc"] == "2.0"
            assert "id" in body
            assert body["method"] == "tools/call"
            assert body["params"]["name"] == "search"
            assert body["params"]["arguments"] == {"query": "test"}
        finally:
            await transport.disconnect()

    async def test_request_id_increments_across_calls(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 1, "result": {"tools": []}},
        )
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 2, "result": {}},
        )

        transport = HTTPTransport(endpoint="http://localhost:8080")
        client = MCPClient.from_transport(transport)

        try:
            await client.list_tools()
            await client.call_tool("test", {})

            requests = httpx_mock.get_requests()
            # Check the count first so a missing request fails with a clear
            # assertion instead of an IndexError below.
            assert len(requests) == 2
            body1 = json.loads(requests[0].content)
            body2 = json.loads(requests[1].content)
            # Consecutive calls on one client use ids 1 then 2.
            assert body1["id"] == 1
            assert body2["id"] == 2
        finally:
            await transport.disconnect()


# ── MCPTool tests ────────────────────────────────────────────────


class TestMCPTool:
    """Tests for the MCPTool wrapper returned by ``MCPClient.as_tool()``."""

    def test_as_tool_creates_mcp_tool(self):
        # Nothing here is awaited, so this is a plain (sync) test: it runs
        # whether or not an asyncio pytest plugin is configured.
        client = MCPClient(server_url="http://localhost:8080")
        tool = client.as_tool("search", description="Search the web")

        assert isinstance(tool, MCPTool)
        assert tool.name == "search"
        assert tool.description == "Search the web"
        assert tool._client is client
        assert "mcp" in tool.tags

    async def test_mcp_tool_execute_text_content(self, httpx_mock):
        # A text content item holding JSON is returned decoded.
        httpx_mock.add_response(
            url="http://localhost:8080/tools/call",
            json={
                "content": [{"type": "text", "text": '{"answer": 42}'}],
            },
        )

        client = MCPClient(server_url="http://localhost:8080")
        tool = client.as_tool("ask", description="Ask a question")

        result = await tool.execute(question="meaning of life")
        assert result == {"answer": 42}

    async def test_mcp_tool_execute_non_json_text(self, httpx_mock):
        # Text that is not JSON is wrapped as {"result": <text>}.
        httpx_mock.add_response(
            url="http://localhost:8080/tools/call",
            json={
                "content": [{"type": "text", "text": "plain text response"}],
            },
        )

        client = MCPClient(server_url="http://localhost:8080")
        tool = client.as_tool("echo", description="Echo input")

        result = await tool.execute(msg="hello")
        assert result == {"result": "plain text response"}

    async def test_mcp_tool_execute_no_content(self, httpx_mock):
        # Without a "content" key, the raw response body is returned.
        httpx_mock.add_response(
            url="http://localhost:8080/tools/call",
            json={"status": "ok", "data": "some data"},
        )

        client = MCPClient(server_url="http://localhost:8080")
        tool = client.as_tool("status", description="Check status")

        result = await tool.execute()
        assert result == {"status": "ok", "data": "some data"}