"""U2 — unit tests for the MCP dangerous-tool blacklist filter.

Covers all eight exposure paths (four surfaces x list/call):

- ``create_mcp_router()`` REST: ``GET /tools/list``, ``POST /tools/call``
- ``create_mcp_router()`` JSON-RPC: ``tools/list``, ``tools/call``
- legacy ``MCPServer`` REST: ``GET /tools/list``, ``POST /tools/call``
- legacy ``MCPServer`` JSON-RPC: ``tools/list``, ``tools/call``

The blacklisted tools (terminal/shell/file_write/file_read/file_delete) depend
on the chat confirmation flow (WebSocket); exposing them over MCP would bypass
user confirmation. Besides checking the REST / JSON-RPC responses, the call
tests assert that a blocked tool's ``safe_execute`` is never awaited — the
security property that actually matters.
"""

from __future__ import annotations

import warnings
from typing import Any
from unittest.mock import AsyncMock, MagicMock

import httpx
from fastapi import FastAPI

from agentkit.mcp.server import MCPServer, create_mcp_router
from agentkit.server.auth.middleware import AuthMiddleware

# Test credentials shared with test_server_auth.py — for unit tests only.
JWT_SECRET = "u13-test-jwt-secret-xxxxxxxxxxxxx"
API_KEY = "u13-test-api-key-yyy"

# Test-side mirror of the server's blacklist (_MCP_BLOCKED_TOOLS).
_BLOCKED = {"terminal", "shell", "file_write", "file_read", "file_delete"}
_HEADERS = {"X-API-Key": API_KEY}


# ── Helpers ───────────────────────────────────────────────


def _make_mock_tool(
    name: str,
    description: str = "",
    result: str = "ok",
) -> MagicMock:
    """Build a mock tool that mimics the Tool interface.

    ``safe_execute`` is an ``AsyncMock`` returning *result*, so tests can both
    read the result through the API and assert whether it was awaited.
    """
    tool = MagicMock()
    tool.name = name
    tool.description = description
    tool.input_schema = {"type": "object", "properties": {}}
    tool.safe_execute = AsyncMock(return_value=result)
    return tool


def _make_mock_registry(tools: list[MagicMock]) -> MagicMock:
    """Build a mock ToolRegistry supporting ``list_tools()`` and ``get(name)``.

    ``get`` returns the matching tool object (so tests can inspect its mocks)
    and raises ``KeyError`` for unknown names.
    """
    registry = MagicMock()
    registry.list_tools.return_value = tools

    def _get(name: str) -> MagicMock:
        for tool in tools:
            if tool.name == name:
                return tool
        raise KeyError(name)

    registry.get = _get
    return registry


def _make_registry_with_safe_and_dangerous() -> MagicMock:
    """Build a registry with 1 safe tool and 5 dangerous tools.

    The dangerous tool names match ``_MCP_BLOCKED_TOOLS`` exactly, so every
    blacklist entry is exercised. Their ``"should-not-reach"`` results mark
    that they must never be executed.
    """
    return _make_mock_registry(
        [
            _make_mock_tool("echo", "safe tool", "echo: hi"),
            _make_mock_tool("shell", "shell tool", "should-not-reach"),
            _make_mock_tool("file_write", "write tool", "should-not-reach"),
            _make_mock_tool("file_read", "read tool", "should-not-reach"),
            _make_mock_tool("file_delete", "delete tool", "should-not-reach"),
            _make_mock_tool("terminal", "terminal tool", "should-not-reach"),
        ]
    )


def _make_app(tool_registry: Any = None) -> FastAPI:
    """Build a test FastAPI app: MCP router mounted behind AuthMiddleware."""
    app = FastAPI()
    app.state.tool_registry = tool_registry
    app.add_middleware(AuthMiddleware, jwt_secret=JWT_SECRET, api_key=API_KEY)
    mcp_router = create_mcp_router(tool_registry=tool_registry)
    app.include_router(mcp_router, prefix="/api/v1/mcp")
    return app


def _make_legacy_app(tool_registry: Any = None) -> FastAPI:
    """Build the legacy MCPServer app (no auth), silencing DeprecationWarning.

    Both the constructor and ``_create_app()`` run under the filter, so the
    warning is suppressed whichever of the two emits it.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", DeprecationWarning)
        server = MCPServer(tool_registry=tool_registry)
        return server._create_app()


def _client(app: FastAPI) -> httpx.AsyncClient:
    """Return an httpx client that drives *app* in-process over ASGI."""
    return httpx.AsyncClient(
        transport=httpx.ASGITransport(app=app),
        base_url="http://test",
    )


def _assert_never_executed(registry: MagicMock, *names: str) -> None:
    """Assert that none of the named tools had ``safe_execute`` awaited."""
    for name in names:
        registry.get(name).safe_execute.assert_not_awaited()


# ── create_mcp_router() REST endpoints ────────────────────


class TestCreateMcpRouterRestFilter:
    """Blacklist filtering on the REST endpoints of create_mcp_router()."""

    async def test_rest_tools_list_excludes_blocked(self):
        """GET /tools/list does not return blacklisted tools."""
        app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous())
        async with _client(app) as client:
            resp = await client.get("/api/v1/mcp/tools/list", headers=_HEADERS)
        assert resp.status_code == 200
        names = {t["name"] for t in resp.json()["tools"]}
        assert names == {"echo"}
        assert not (names & _BLOCKED)

    async def test_rest_tools_call_safe_tool_succeeds(self):
        """POST /tools/call on a safe tool → 200 with its result."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_app(tool_registry=registry)
        async with _client(app) as client:
            resp = await client.post(
                "/api/v1/mcp/tools/call",
                json={"name": "echo", "arguments": {}},
                headers=_HEADERS,
            )
        assert resp.status_code == 200
        assert "echo: hi" in resp.json()["content"][0]["text"]
        # Positive control: the filter must not stop safe tools from running.
        registry.get("echo").safe_execute.assert_awaited_once()

    async def test_rest_tools_call_shell_returns_404(self):
        """POST /tools/call shell → 404 (blacklisted) and never executed."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_app(tool_registry=registry)
        async with _client(app) as client:
            resp = await client.post(
                "/api/v1/mcp/tools/call",
                json={"name": "shell", "arguments": {"cmd": "rm -rf /"}},
                headers=_HEADERS,
            )
        assert resp.status_code == 404
        _assert_never_executed(registry, "shell")

    async def test_rest_tools_call_file_delete_returns_404(self):
        """POST /tools/call file_delete → 404 (blacklisted) and never executed."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_app(tool_registry=registry)
        async with _client(app) as client:
            resp = await client.post(
                "/api/v1/mcp/tools/call",
                json={"name": "file_delete", "arguments": {"path": "/etc/passwd"}},
                headers=_HEADERS,
            )
        assert resp.status_code == 404
        _assert_never_executed(registry, "file_delete")

    async def test_rest_tools_call_every_blocked_tool_returns_404(self):
        """POST /tools/call → 404 for every blacklist entry; none executed."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_app(tool_registry=registry)
        async with _client(app) as client:
            for name in sorted(_BLOCKED):
                resp = await client.post(
                    "/api/v1/mcp/tools/call",
                    json={"name": name, "arguments": {}},
                    headers=_HEADERS,
                )
                assert resp.status_code == 404, name
        _assert_never_executed(registry, *_BLOCKED)


# ── create_mcp_router() JSON-RPC endpoint ─────────────────


class TestCreateMcpRouterJsonRpcFilter:
    """Blacklist filtering on the JSON-RPC endpoint of create_mcp_router()."""

    async def test_jsonrpc_tools_list_excludes_blocked(self):
        """JSON-RPC tools/list does not include blacklisted tools."""
        app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous())
        async with _client(app) as client:
            resp = await client.post(
                "/api/v1/mcp/",
                json={"jsonrpc": "2.0", "method": "tools/list", "id": 1},
                headers=_HEADERS,
            )
        assert resp.status_code == 200
        body = resp.json()
        names = {t["name"] for t in body["result"]["tools"]}
        assert names == {"echo"}
        assert not (names & _BLOCKED)

    async def test_jsonrpc_tools_call_shell_returns_iserror(self):
        """JSON-RPC tools/call shell → isError=True and never executed."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_app(tool_registry=registry)
        async with _client(app) as client:
            resp = await client.post(
                "/api/v1/mcp/",
                json={
                    "jsonrpc": "2.0",
                    "method": "tools/call",
                    "params": {"name": "shell", "arguments": {"cmd": "rm -rf /"}},
                    "id": 2,
                },
                headers=_HEADERS,
            )
        assert resp.status_code == 200
        result = resp.json()["result"]
        assert result.get("isError") is True
        assert "blocked" in result["content"][0]["text"].lower()
        _assert_never_executed(registry, "shell")

    async def test_jsonrpc_tools_call_every_blocked_tool_returns_iserror(self):
        """JSON-RPC tools/call → isError for every blacklist entry; none executed."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_app(tool_registry=registry)
        async with _client(app) as client:
            for req_id, name in enumerate(sorted(_BLOCKED), start=1):
                resp = await client.post(
                    "/api/v1/mcp/",
                    json={
                        "jsonrpc": "2.0",
                        "method": "tools/call",
                        "params": {"name": name, "arguments": {}},
                        "id": req_id,
                    },
                    headers=_HEADERS,
                )
                assert resp.status_code == 200, name
                result = resp.json()["result"]
                assert result.get("isError") is True, name
                assert "blocked" in result["content"][0]["text"].lower(), name
        _assert_never_executed(registry, *_BLOCKED)


# ── legacy MCPServer REST endpoints ───────────────────────


class TestLegacyMcpServerRestFilter:
    """Blacklist filtering on the REST endpoints of the legacy MCPServer app."""

    async def test_legacy_rest_tools_list_excludes_blocked(self):
        """Legacy GET /tools/list does not return blacklisted tools."""
        app = _make_legacy_app(tool_registry=_make_registry_with_safe_and_dangerous())
        async with _client(app) as client:
            resp = await client.get("/tools/list")
        assert resp.status_code == 200
        names = {t["name"] for t in resp.json()["tools"]}
        assert names == {"echo"}
        assert not (names & _BLOCKED)

    async def test_legacy_rest_tools_call_shell_returns_error(self):
        """Legacy POST /tools/call shell → error (blocked), never executed."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_legacy_app(tool_registry=registry)
        async with _client(app) as client:
            resp = await client.post(
                "/tools/call",
                json={"name": "shell", "arguments": {"cmd": "rm -rf /"}},
            )
        assert resp.status_code == 200
        body = resp.json()
        assert "error" in body
        assert "blocked" in body["error"].lower()
        _assert_never_executed(registry, "shell")

    async def test_legacy_rest_tools_call_every_blocked_tool_returns_error(self):
        """Legacy POST /tools/call → error for every blacklist entry; none executed."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_legacy_app(tool_registry=registry)
        async with _client(app) as client:
            for name in sorted(_BLOCKED):
                resp = await client.post(
                    "/tools/call",
                    json={"name": name, "arguments": {}},
                )
                assert resp.status_code == 200, name
                body = resp.json()
                assert "error" in body, name
                assert "blocked" in body["error"].lower(), name
        _assert_never_executed(registry, *_BLOCKED)


# ── legacy MCPServer JSON-RPC endpoint ────────────────────


class TestLegacyMcpServerJsonRpcFilter:
    """Blacklist filtering on the JSON-RPC endpoint of the legacy MCPServer."""

    async def test_legacy_jsonrpc_tools_list_excludes_blocked(self):
        """Legacy JSON-RPC tools/list does not include blacklisted tools."""
        app = _make_legacy_app(tool_registry=_make_registry_with_safe_and_dangerous())
        async with _client(app) as client:
            resp = await client.post(
                "/",
                json={"jsonrpc": "2.0", "method": "tools/list", "id": 1},
            )
        assert resp.status_code == 200
        body = resp.json()
        names = {t["name"] for t in body["result"]["tools"]}
        assert names == {"echo"}
        assert not (names & _BLOCKED)

    async def test_legacy_jsonrpc_tools_call_shell_returns_iserror(self):
        """Legacy JSON-RPC tools/call shell → isError=True, never executed."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_legacy_app(tool_registry=registry)
        async with _client(app) as client:
            resp = await client.post(
                "/",
                json={
                    "jsonrpc": "2.0",
                    "method": "tools/call",
                    "params": {"name": "shell", "arguments": {"cmd": "rm -rf /"}},
                    "id": 2,
                },
            )
        assert resp.status_code == 200
        result = resp.json()["result"]
        assert result.get("isError") is True
        assert "blocked" in result["content"][0]["text"].lower()
        _assert_never_executed(registry, "shell")

    async def test_legacy_jsonrpc_tools_call_every_blocked_tool_returns_iserror(self):
        """Legacy JSON-RPC tools/call → isError for every blacklist entry."""
        registry = _make_registry_with_safe_and_dangerous()
        app = _make_legacy_app(tool_registry=registry)
        async with _client(app) as client:
            for req_id, name in enumerate(sorted(_BLOCKED), start=1):
                resp = await client.post(
                    "/",
                    json={
                        "jsonrpc": "2.0",
                        "method": "tools/call",
                        "params": {"name": name, "arguments": {}},
                        "id": req_id,
                    },
                )
                assert resp.status_code == 200, name
                result = resp.json()["result"]
                assert result.get("isError") is True, name
                assert "blocked" in result["content"][0]["text"].lower(), name
        _assert_never_executed(registry, *_BLOCKED)