fischer-agentkit/tests/unit/mcp/test_server_dangerous_tools.py

270 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U2 — MCP 危险工具黑名单过滤单元测试。
覆盖 6 个暴露路径:
- create_mcp_router() REST: GET /tools/list, POST /tools/call
- create_mcp_router() JSON-RPC: tools/list, tools/call
- legacy MCPServer REST: GET /tools/list, POST /tools/call
- legacy MCPServer JSON-RPC: tools/list, tools/call
黑名单工具terminal/shell/file_write/file_read/file_delete依赖 chat
confirmation 流程WebSocket通过 MCP 暴露会绕过用户确认。
"""
from __future__ import annotations
import warnings
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import httpx
from fastapi import FastAPI
from agentkit.mcp.server import MCPServer, create_mcp_router
from agentkit.server.auth.middleware import AuthMiddleware
# Test-only credentials, shared with test_server_auth.py.
# These values must never be used outside unit tests.
JWT_SECRET: str = "u13-test-jwt-secret-xxxxxxxxxxxxx"
API_KEY: str = "u13-test-api-key-yyy"
# ── 辅助函数 ──────────────────────────────────────────────
def _make_mock_tool(
name: str,
description: str = "",
result: str = "ok",
) -> MagicMock:
"""构造一个 mock 工具,模拟 Tool 接口。"""
tool = MagicMock()
tool.name = name
tool.description = description
tool.input_schema = {"type": "object", "properties": {}}
tool.safe_execute = AsyncMock(return_value=result)
return tool
def _make_mock_registry(tools: list) -> MagicMock:
"""构造一个 mock ToolRegistry支持 list_tools() 和 get(name)。"""
registry = MagicMock()
registry.list_tools.return_value = tools
def _get(name: str):
for t in tools:
if t.name == name:
return t
raise KeyError(name)
registry.get = _get
return registry
def _make_registry_with_safe_and_dangerous() -> MagicMock:
    """Build a mock registry holding 1 safe tool and 5 dangerous tools.

    The dangerous tool names are intended to match the server's
    ``_MCP_BLOCKED_TOOLS`` exactly, so every blocklisted name is exercised.
    Each dangerous tool returns the sentinel text "should-not-reach"; it
    should never appear in any response.
    """
    safe_tool = _make_mock_tool("echo", "safe tool", "echo: hi")
    dangerous_specs = (
        ("shell", "shell tool"),
        ("file_write", "write tool"),
        ("file_read", "read tool"),
        ("file_delete", "delete tool"),
        ("terminal", "terminal tool"),
    )
    dangerous_tools = [
        _make_mock_tool(tool_name, summary, "should-not-reach")
        for tool_name, summary in dangerous_specs
    ]
    return _make_mock_registry([safe_tool, *dangerous_tools])
def _make_app(tool_registry: Any = None) -> FastAPI:
    """Build a FastAPI test app with the MCP router behind AuthMiddleware.

    The registry is stored on ``app.state.tool_registry`` and also handed
    to ``create_mcp_router``. AuthMiddleware is configured with the shared
    test JWT secret and API key, and the router is mounted at
    ``/api/v1/mcp``.
    """
    application = FastAPI()
    application.state.tool_registry = tool_registry
    application.add_middleware(AuthMiddleware, jwt_secret=JWT_SECRET, api_key=API_KEY)
    application.include_router(
        create_mcp_router(tool_registry=tool_registry),
        prefix="/api/v1/mcp",
    )
    return application
def _make_legacy_app(tool_registry: Any = None) -> FastAPI:
    """Build the standalone app served by the legacy ``MCPServer``.

    Unlike ``_make_app``, no AuthMiddleware is added here, so tests call
    this app without credentials.

    The whole construction runs under a DeprecationWarning filter. That
    covers both ``MCPServer(...)`` and ``_create_app()``, so the legacy
    API's deprecation notice is suppressed whichever call emits it.
    Previously the constructor ran outside the filter.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        server = MCPServer(tool_registry=tool_registry)
        return server._create_app()
_BLOCKED = {"terminal", "shell", "file_write", "file_read", "file_delete"}
_HEADERS = {"X-API-Key": API_KEY}
# ── create_mcp_router() REST endpoints ───────────────────
class TestCreateMcpRouterRestFilter:
    """Blocklist filtering on the REST endpoints of create_mcp_router()."""

    @staticmethod
    async def _call_tool(app: FastAPI, name: str, arguments: dict) -> httpx.Response:
        """POST /tools/call for *name* through the authenticated router app."""
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            return await client.post(
                "/api/v1/mcp/tools/call",
                json={"name": name, "arguments": arguments},
                headers=_HEADERS,
            )

    async def test_rest_tools_list_excludes_blocked(self):
        """GET /tools/list returns no blocklisted tool."""
        app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous())
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.get("/api/v1/mcp/tools/list", headers=_HEADERS)
        assert resp.status_code == 200
        names = {t["name"] for t in resp.json()["tools"]}
        assert names == {"echo"}
        assert not (names & _BLOCKED)

    async def test_rest_tools_call_safe_tool_succeeds(self):
        """POST /tools/call on a safe tool → 200 plus the tool's result."""
        registry = _make_registry_with_safe_and_dangerous()
        resp = await self._call_tool(_make_app(tool_registry=registry), "echo", {})
        assert resp.status_code == 200
        assert "echo: hi" in resp.json()["content"][0]["text"]
        # Positive control: proves the not-awaited checks below are meaningful.
        registry.get("echo").safe_execute.assert_awaited()

    async def test_rest_tools_call_shell_returns_404(self):
        """POST /tools/call on shell → 404 (blocklisted) and the tool never runs."""
        registry = _make_registry_with_safe_and_dangerous()
        resp = await self._call_tool(
            _make_app(tool_registry=registry), "shell", {"cmd": "rm -rf /"}
        )
        assert resp.status_code == 404
        # A 404 alone would also pass if the server ran the tool first.
        registry.get("shell").safe_execute.assert_not_awaited()

    async def test_rest_tools_call_file_delete_returns_404(self):
        """POST /tools/call on file_delete → 404 (blocklisted) and the tool never runs."""
        registry = _make_registry_with_safe_and_dangerous()
        resp = await self._call_tool(
            _make_app(tool_registry=registry), "file_delete", {"path": "/etc/passwd"}
        )
        assert resp.status_code == 404
        registry.get("file_delete").safe_execute.assert_not_awaited()

    async def test_rest_tools_call_every_blocked_tool_returns_404(self):
        """Every blocklisted name is rejected on POST /tools/call, not just shell/file_delete."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_app(tool_registry=registry)
        for name in sorted(_BLOCKED):
            resp = await self._call_tool(app, name, {})
            assert resp.status_code == 404, name
            assert "should-not-reach" not in resp.text, name
            registry.get(name).safe_execute.assert_not_awaited()
# ── create_mcp_router() JSON-RPC endpoint ────────────────
class TestCreateMcpRouterJsonRpcFilter:
    """Blocklist filtering on the JSON-RPC endpoint of create_mcp_router()."""

    @staticmethod
    async def _rpc(app: FastAPI, payload: dict) -> httpx.Response:
        """POST a JSON-RPC *payload* to the authenticated router endpoint."""
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            return await client.post("/api/v1/mcp/", json=payload, headers=_HEADERS)

    async def test_jsonrpc_tools_list_excludes_blocked(self):
        """JSON-RPC tools/list contains no blocklisted tool."""
        app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous())
        resp = await self._rpc(app, {"jsonrpc": "2.0", "method": "tools/list", "id": 1})
        assert resp.status_code == 200
        body = resp.json()
        names = {t["name"] for t in body["result"]["tools"]}
        assert names == {"echo"}
        assert not (names & _BLOCKED)

    async def test_jsonrpc_tools_call_shell_returns_iserror(self):
        """JSON-RPC tools/call on shell → isError=True and the tool never runs."""
        registry = _make_registry_with_safe_and_dangerous()
        resp = await self._rpc(
            _make_app(tool_registry=registry),
            {
                "jsonrpc": "2.0",
                "method": "tools/call",
                "params": {"name": "shell", "arguments": {"cmd": "rm -rf /"}},
                "id": 2,
            },
        )
        assert resp.status_code == 200
        result = resp.json()["result"]
        assert result.get("isError") is True
        assert "blocked" in result["content"][0]["text"].lower()
        registry.get("shell").safe_execute.assert_not_awaited()

    async def test_jsonrpc_tools_call_every_blocked_tool_returns_iserror(self):
        """Every blocklisted name yields isError=True on JSON-RPC tools/call."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_app(tool_registry=registry)
        for request_id, name in enumerate(sorted(_BLOCKED), start=2):
            resp = await self._rpc(
                app,
                {
                    "jsonrpc": "2.0",
                    "method": "tools/call",
                    "params": {"name": name, "arguments": {}},
                    "id": request_id,
                },
            )
            assert resp.status_code == 200, name
            result = resp.json()["result"]
            assert result.get("isError") is True, name
            assert "blocked" in result["content"][0]["text"].lower(), name
            assert "should-not-reach" not in resp.text, name
            registry.get(name).safe_execute.assert_not_awaited()
# ── legacy MCPServer REST endpoints ──────────────────────
class TestLegacyMcpServerRestFilter:
    """Blocklist filtering on the REST endpoints of the standalone legacy MCPServer app."""

    @staticmethod
    async def _call_tool(app: FastAPI, name: str, arguments: dict) -> httpx.Response:
        """POST /tools/call for *name* to the legacy app (no credentials sent)."""
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            return await client.post(
                "/tools/call",
                json={"name": name, "arguments": arguments},
            )

    async def test_legacy_rest_tools_list_excludes_blocked(self):
        """legacy GET /tools/list returns no blocklisted tool."""
        app = _make_legacy_app(tool_registry=_make_registry_with_safe_and_dangerous())
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.get("/tools/list")
        assert resp.status_code == 200
        names = {t["name"] for t in resp.json()["tools"]}
        assert names == {"echo"}
        assert not (names & _BLOCKED)

    async def test_legacy_rest_tools_call_shell_returns_error(self):
        """legacy POST /tools/call on shell → error payload (blocked); the tool never runs."""
        registry = _make_registry_with_safe_and_dangerous()
        resp = await self._call_tool(
            _make_legacy_app(tool_registry=registry), "shell", {"cmd": "rm -rf /"}
        )
        # The legacy REST API reports the rejection in-band with HTTP 200.
        assert resp.status_code == 200
        body = resp.json()
        assert "error" in body
        assert "blocked" in body["error"].lower()
        registry.get("shell").safe_execute.assert_not_awaited()

    async def test_legacy_rest_tools_call_every_blocked_tool_returns_error(self):
        """Every blocklisted name is rejected on legacy POST /tools/call."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_legacy_app(tool_registry=registry)
        for name in sorted(_BLOCKED):
            resp = await self._call_tool(app, name, {})
            assert resp.status_code == 200, name
            body = resp.json()
            assert "blocked" in body.get("error", "").lower(), name
            assert "should-not-reach" not in resp.text, name
            registry.get(name).safe_execute.assert_not_awaited()
# ── legacy MCPServer JSON-RPC endpoint ───────────────────
class TestLegacyMcpServerJsonRpcFilter:
    """Blocklist filtering on the JSON-RPC endpoint of the legacy MCPServer app."""

    @staticmethod
    async def _rpc(app: FastAPI, payload: dict) -> httpx.Response:
        """POST a JSON-RPC *payload* to the legacy app's root route (no credentials)."""
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            return await client.post("/", json=payload)

    async def test_legacy_jsonrpc_tools_list_excludes_blocked(self):
        """legacy JSON-RPC tools/list contains no blocklisted tool."""
        app = _make_legacy_app(tool_registry=_make_registry_with_safe_and_dangerous())
        resp = await self._rpc(app, {"jsonrpc": "2.0", "method": "tools/list", "id": 1})
        assert resp.status_code == 200
        body = resp.json()
        names = {t["name"] for t in body["result"]["tools"]}
        assert names == {"echo"}
        assert not (names & _BLOCKED)

    async def test_legacy_jsonrpc_tools_call_shell_returns_iserror(self):
        """legacy JSON-RPC tools/call on shell → isError=True and the tool never runs."""
        registry = _make_registry_with_safe_and_dangerous()
        resp = await self._rpc(
            _make_legacy_app(tool_registry=registry),
            {
                "jsonrpc": "2.0",
                "method": "tools/call",
                "params": {"name": "shell", "arguments": {"cmd": "rm -rf /"}},
                "id": 2,
            },
        )
        assert resp.status_code == 200
        result = resp.json()["result"]
        assert result.get("isError") is True
        assert "blocked" in result["content"][0]["text"].lower()
        registry.get("shell").safe_execute.assert_not_awaited()

    async def test_legacy_jsonrpc_tools_call_every_blocked_tool_returns_iserror(self):
        """Every blocklisted name yields isError=True on legacy JSON-RPC tools/call."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_legacy_app(tool_registry=registry)
        for request_id, name in enumerate(sorted(_BLOCKED), start=2):
            resp = await self._rpc(
                app,
                {
                    "jsonrpc": "2.0",
                    "method": "tools/call",
                    "params": {"name": name, "arguments": {}},
                    "id": request_id,
                },
            )
            assert resp.status_code == 200, name
            result = resp.json()["result"]
            assert result.get("isError") is True, name
            assert "blocked" in result["content"][0]["text"].lower(), name
            assert "should-not-reach" not in resp.text, name
            registry.get(name).safe_execute.assert_not_awaited()