"""U2 — unit tests for the MCP dangerous-tool blocklist filter.

Exposure paths covered (list + call on each surface):
- create_mcp_router() REST: GET /tools/list, POST /tools/call
- create_mcp_router() JSON-RPC: tools/list, tools/call
- legacy MCPServer REST: GET /tools/list, POST /tools/call
- legacy MCPServer JSON-RPC: tools/list, tools/call

The blocked tools (terminal/shell/file_write/file_read/file_delete) rely on
the chat confirmation flow (WebSocket); exposing them through MCP would
bypass user confirmation.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import warnings
|
||
from typing import Any
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import httpx
|
||
from fastapi import FastAPI
|
||
|
||
from agentkit.mcp.server import MCPServer, create_mcp_router
|
||
from agentkit.server.auth.middleware import AuthMiddleware
|
||
|
||
# Test credentials shared with test_server_auth.py — for unit tests only.
JWT_SECRET = "u13-test-jwt-secret-xxxxxxxxxxxxx"
API_KEY = "u13-test-api-key-yyy"
|
||
|
||
|
||
# ── 辅助函数 ──────────────────────────────────────────────
|
||
|
||
|
||
def _make_mock_tool(
    name: str,
    description: str = "",
    result: str = "ok",
) -> MagicMock:
    """Build a mock tool that mimics the Tool interface.

    Args:
        name: Value exposed as ``tool.name``.
        description: Value exposed as ``tool.description``.
        result: Value returned when ``tool.safe_execute(...)`` is awaited.

    Returns:
        A ``MagicMock`` carrying ``name``, ``description``, an empty
        object-type ``input_schema`` and an async ``safe_execute``.
    """
    tool = MagicMock()
    # Assigned after construction: ``MagicMock(name=...)`` would label the
    # mock itself instead of creating a plain ``name`` attribute.
    tool.name = name
    tool.description = description
    tool.input_schema = {"type": "object", "properties": {}}
    tool.safe_execute = AsyncMock(return_value=result)
    return tool
|
||
|
||
|
||
def _make_mock_registry(tools: list) -> MagicMock:
    """Build a mock ToolRegistry supporting ``list_tools()`` and ``get(name)``.

    Args:
        tools: Tool objects; each must expose a ``name`` attribute.

    Returns:
        A ``MagicMock`` whose ``list_tools()`` returns ``tools`` and whose
        ``get(name)`` returns the first tool with that name.

    Raises:
        KeyError: From ``registry.get(name)`` when no tool has that name.
    """
    registry = MagicMock()
    registry.list_tools.return_value = tools

    def _get(name: str):
        # Linear scan keeps first-match semantics if names are duplicated.
        for tool in tools:
            if tool.name == name:
                return tool
        raise KeyError(name)

    # A real function (not a Mock) so that unknown names actually raise.
    registry.get = _get
    return registry
|
||
|
||
|
||
def _make_registry_with_safe_and_dangerous() -> MagicMock:
    """Build a registry holding one safe tool and five dangerous ones.

    The dangerous tool names are meant to match the server-side
    ``_MCP_BLOCKED_TOOLS`` exactly so the whole blocklist is exercised.
    NOTE(review): that constant is not visible from this file — keep the
    two lists in sync.

    Returns:
        A mock registry with ``echo`` plus ``shell``, ``file_write``,
        ``file_read``, ``file_delete`` and ``terminal``.
    """
    return _make_mock_registry(
        [
            _make_mock_tool("echo", "safe tool", "echo: hi"),
            # The result string flags any leak: it must never reach a client.
            _make_mock_tool("shell", "shell tool", "should-not-reach"),
            _make_mock_tool("file_write", "write tool", "should-not-reach"),
            _make_mock_tool("file_read", "read tool", "should-not-reach"),
            _make_mock_tool("file_delete", "delete tool", "should-not-reach"),
            _make_mock_tool("terminal", "terminal tool", "should-not-reach"),
        ]
    )
|
||
|
||
|
||
def _make_app(tool_registry: Any = None) -> FastAPI:
    """Build the FastAPI app under test: MCP router behind AuthMiddleware.

    Args:
        tool_registry: Registry stored on ``app.state`` and passed to
            ``create_mcp_router()``.

    Returns:
        A FastAPI app with the MCP router mounted at ``/api/v1/mcp`` and
        ``AuthMiddleware`` configured with the test JWT secret and API key.
    """
    app = FastAPI()
    app.state.tool_registry = tool_registry
    app.add_middleware(AuthMiddleware, jwt_secret=JWT_SECRET, api_key=API_KEY)
    mcp_router = create_mcp_router(tool_registry=tool_registry)
    app.include_router(mcp_router, prefix="/api/v1/mcp")
    return app
|
||
|
||
|
||
def _make_legacy_app(tool_registry: Any = None) -> FastAPI:
    """Build the legacy MCPServer app (no auth) with DeprecationWarning muted.

    Args:
        tool_registry: Registry handed to the legacy ``MCPServer``.

    Returns:
        The app produced by ``MCPServer._create_app()``.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        # Construct inside the filter too: this file does not show whether
        # the constructor or ``_create_app()`` emits the DeprecationWarning,
        # and the helper is meant to silence it either way.
        server = MCPServer(tool_registry=tool_registry)
        return server._create_app()
|
||
|
||
|
||
# Tool names that must never be exposed through MCP. Frozen so that no test
# can accidentally mutate the constant shared by every test class.
_BLOCKED = frozenset({"terminal", "shell", "file_write", "file_read", "file_delete"})
# API-key header sent with requests to the AuthMiddleware-protected app.
_HEADERS = {"X-API-Key": API_KEY}
|
||
|
||
|
||
# ── create_mcp_router() REST 端点 ─────────────────────────
|
||
|
||
|
||
class TestCreateMcpRouterRestFilter:
    """Blocklist filtering on the REST endpoints of create_mcp_router()."""

    async def test_rest_tools_list_excludes_blocked(self):
        """GET /tools/list must not return any blocked tool."""
        app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous())
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.get("/api/v1/mcp/tools/list", headers=_HEADERS)
        assert resp.status_code == 200
        names = {t["name"] for t in resp.json()["tools"]}
        assert names == {"echo"}
        assert not (names & _BLOCKED)

    async def test_rest_tools_call_safe_tool_succeeds(self):
        """POST /tools/call on a safe tool returns 200 and the tool result."""
        app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous())
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post(
                "/api/v1/mcp/tools/call",
                json={"name": "echo", "arguments": {}},
                headers=_HEADERS,
            )
        assert resp.status_code == 200
        assert "echo: hi" in resp.json()["content"][0]["text"]

    async def test_rest_tools_call_shell_returns_404(self):
        """POST /tools/call on ``shell`` returns 404 (blocked)."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_app(tool_registry=registry)
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post(
                "/api/v1/mcp/tools/call",
                json={"name": "shell", "arguments": {"cmd": "rm -rf /"}},
                headers=_HEADERS,
            )
        assert resp.status_code == 404
        # The status code alone does not prove the tool was skipped.
        registry.get("shell").safe_execute.assert_not_awaited()

    async def test_rest_tools_call_file_delete_returns_404(self):
        """POST /tools/call on ``file_delete`` returns 404 (blocked)."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_app(tool_registry=registry)
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post(
                "/api/v1/mcp/tools/call",
                json={"name": "file_delete", "arguments": {"path": "/etc/passwd"}},
                headers=_HEADERS,
            )
        assert resp.status_code == 404
        # The status code alone does not prove the tool was skipped.
        registry.get("file_delete").safe_execute.assert_not_awaited()
|
||
|
||
|
||
# ── create_mcp_router() JSON-RPC 端点 ─────────────────────
|
||
|
||
|
||
class TestCreateMcpRouterJsonRpcFilter:
    """Blocklist filtering on the JSON-RPC endpoint of create_mcp_router()."""

    async def test_jsonrpc_tools_list_excludes_blocked(self):
        """JSON-RPC tools/list must not contain any blocked tool."""
        app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous())
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post(
                "/api/v1/mcp/",
                json={"jsonrpc": "2.0", "method": "tools/list", "id": 1},
                headers=_HEADERS,
            )
        assert resp.status_code == 200
        body = resp.json()
        names = {t["name"] for t in body["result"]["tools"]}
        assert names == {"echo"}
        assert not (names & _BLOCKED)

    async def test_jsonrpc_tools_call_shell_returns_iserror(self):
        """JSON-RPC tools/call on ``shell`` yields a result with isError=True."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_app(tool_registry=registry)
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post(
                "/api/v1/mcp/",
                json={
                    "jsonrpc": "2.0",
                    "method": "tools/call",
                    "params": {"name": "shell", "arguments": {"cmd": "rm -rf /"}},
                    "id": 2,
                },
                headers=_HEADERS,
            )
        # The transport succeeds; the refusal is carried in the result body.
        assert resp.status_code == 200
        result = resp.json()["result"]
        assert result.get("isError") is True
        assert "blocked" in result["content"][0]["text"].lower()
        # The error payload alone does not prove the tool was skipped.
        registry.get("shell").safe_execute.assert_not_awaited()
|
||
|
||
|
||
# ── legacy MCPServer REST 端点 ────────────────────────────
|
||
|
||
|
||
class TestLegacyMcpServerRestFilter:
    """Blocklist filtering on the REST endpoints of the legacy MCPServer app."""

    async def test_legacy_rest_tools_list_excludes_blocked(self):
        """Legacy GET /tools/list must not return any blocked tool."""
        app = _make_legacy_app(tool_registry=_make_registry_with_safe_and_dangerous())
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.get("/tools/list")
        assert resp.status_code == 200
        names = {t["name"] for t in resp.json()["tools"]}
        assert names == {"echo"}
        assert not (names & _BLOCKED)

    async def test_legacy_rest_tools_call_shell_returns_error(self):
        """Legacy POST /tools/call on ``shell`` returns a "blocked" error body."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_legacy_app(tool_registry=registry)
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post(
                "/tools/call",
                json={"name": "shell", "arguments": {"cmd": "rm -rf /"}},
            )
        # The legacy endpoint reports the refusal in the body, not the status.
        assert resp.status_code == 200
        body = resp.json()
        assert "error" in body
        assert "blocked" in body["error"].lower()
        # The error payload alone does not prove the tool was skipped.
        registry.get("shell").safe_execute.assert_not_awaited()
|
||
|
||
|
||
# ── legacy MCPServer JSON-RPC 端点 ────────────────────────
|
||
|
||
|
||
class TestLegacyMcpServerJsonRpcFilter:
    """Blocklist filtering on the JSON-RPC endpoint of the legacy MCPServer."""

    async def test_legacy_jsonrpc_tools_list_excludes_blocked(self):
        """Legacy JSON-RPC tools/list must not contain any blocked tool."""
        app = _make_legacy_app(tool_registry=_make_registry_with_safe_and_dangerous())
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post(
                "/",
                json={"jsonrpc": "2.0", "method": "tools/list", "id": 1},
            )
        assert resp.status_code == 200
        body = resp.json()
        names = {t["name"] for t in body["result"]["tools"]}
        assert names == {"echo"}
        assert not (names & _BLOCKED)

    async def test_legacy_jsonrpc_tools_call_shell_returns_iserror(self):
        """Legacy JSON-RPC tools/call on ``shell`` yields isError=True."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_legacy_app(tool_registry=registry)
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post(
                "/",
                json={
                    "jsonrpc": "2.0",
                    "method": "tools/call",
                    "params": {"name": "shell", "arguments": {"cmd": "rm -rf /"}},
                    "id": 2,
                },
            )
        # The transport succeeds; the refusal is carried in the result body.
        assert resp.status_code == 200
        result = resp.json()["result"]
        assert result.get("isError") is True
        assert "blocked" in result["content"][0]["text"].lower()
        # The error payload alone does not prove the tool was skipped.
        registry.get("shell").safe_execute.assert_not_awaited()
|