399 lines
16 KiB
Python
399 lines
16 KiB
Python
"""U13 — MCP Server 认证 + 合并至主 app 的单元测试。
|
||
|
||
覆盖场景:
|
||
- 无认证 / 无效凭据 → 401
|
||
- 有效 API Key / 有效 JWT(member)→ 200 + 工具列表
|
||
- 有效 JWT 但无权限(guest)→ 403
|
||
- tools/call 成功 / 未知工具(404) / 执行错误
|
||
- JSON-RPC 2.0 端点(initialize / tools/list / tools/call / 未知方法 / 解析错误)
|
||
- MCPServer 类向后兼容(DeprecationWarning)
|
||
- create_mcp_router() 返回 APIRouter / 空 registry
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import warnings
|
||
from typing import Any
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import httpx
|
||
import jwt
|
||
from fastapi import FastAPI
|
||
from fastapi.routing import APIRouter
|
||
|
||
from agentkit.mcp.server import MCPServer, create_mcp_router
|
||
from agentkit.server.auth.middleware import AuthMiddleware
|
||
|
||
# Fixed credentials for unit tests only — never use these in production.
# HS256 signing secret: used by _make_jwt to sign tokens and handed to the
# AuthMiddleware under test so it can verify them.
JWT_SECRET = "u13-test-jwt-secret-xxxxxxxxxxxxx"
# Static API key configured on the AuthMiddleware; tests send it in the
# X-API-Key header.
API_KEY = "u13-test-api-key-yyy"
|
||
|
||
|
||
# ── 测试辅助函数 ──────────────────────────────────────────
|
||
|
||
|
||
def _make_app(tool_registry: Any = None, *, dev_mode: bool = False) -> FastAPI:
    """Build a FastAPI app for tests with the MCP router mounted.

    The router returned by ``create_mcp_router`` is mounted under
    ``/api/v1/mcp``, and the registry is also stored on
    ``app.state.tool_registry``.

    Args:
        tool_registry: A ToolRegistry instance (or a mock). ``None`` means no
            registry is configured.
        dev_mode: When true, ``AuthMiddleware`` is not added (development
            mode: every request is let through).

    Returns:
        The configured FastAPI application.
    """
    application = FastAPI()
    application.state.tool_registry = tool_registry

    # Auth is wired in with the fixed test credentials, so tokens from
    # _make_jwt and the API_KEY constant are the ones it accepts.
    auth_enabled = not dev_mode
    if auth_enabled:
        application.add_middleware(AuthMiddleware, jwt_secret=JWT_SECRET, api_key=API_KEY)

    application.include_router(
        create_mcp_router(tool_registry=tool_registry),
        prefix="/api/v1/mcp",
    )
    return application
|
||
|
||
|
||
def _make_mock_tool(
|
||
name: str = "echo",
|
||
description: str = "Echo tool",
|
||
input_schema: dict | None = None,
|
||
result: str = "echo: hi",
|
||
) -> MagicMock:
|
||
"""构造一个 mock 工具,模拟 Tool 接口。"""
|
||
tool = MagicMock()
|
||
tool.name = name
|
||
tool.description = description
|
||
tool.input_schema = input_schema or {"type": "object", "properties": {}}
|
||
tool.safe_execute = AsyncMock(return_value=result)
|
||
return tool
|
||
|
||
|
||
def _make_mock_registry(tools: list) -> MagicMock:
|
||
"""构造一个 mock ToolRegistry,支持 list_tools() 和 get(name)。"""
|
||
registry = MagicMock()
|
||
registry.list_tools.return_value = tools
|
||
|
||
def _get(name: str):
|
||
for t in tools:
|
||
if t.name == name:
|
||
return t
|
||
raise KeyError(name)
|
||
|
||
registry.get = _get
|
||
return registry
|
||
|
||
|
||
def _make_jwt(
    role: str = "member",
    user_id: str = "u1",
    username: str = "alice",
) -> str:
    """Sign a test access token (``type="access"``, HS256) with ``JWT_SECRET``.

    The claims are ``sub``/``username``/``role``, ``type="access"``, a fixed
    ``iat`` and an ``exp`` set far in the future so the token never expires
    during a test run.

    Returns:
        The encoded token as ``str``. If ``jwt.encode`` hands back ``bytes``
        (as some PyJWT versions do), it is decoded as UTF-8 first.
    """
    claims = dict(
        sub=user_id,
        username=username,
        role=role,
        type="access",
        iat=1700000000,
        exp=9999999999,  # far future, so the token never expires in tests
    )
    encoded = jwt.encode(claims, JWT_SECRET, algorithm="HS256")
    if isinstance(encoded, bytes):
        return encoded.decode("utf-8")
    return encoded
|
||
|
||
|
||
def _make_registry_with_two_tools() -> MagicMock:
    """Return a mock registry holding two mock tools, ``echo`` then ``reverse``."""
    tool_specs = (
        ("echo", "Echo input", "echo: hi"),
        ("reverse", "Reverse text", "dcba"),
    )
    return _make_mock_registry(
        [
            _make_mock_tool(name=tool_name, description=summary, result=output)
            for tool_name, summary, output in tool_specs
        ]
    )
|
||
|
||
|
||
# ── 1-2: 无认证 / 无效凭据 ────────────────────────────────
|
||
|
||
|
||
class TestNoAuth:
    """Requests without credentials, or with invalid ones, must get 401."""

    @staticmethod
    async def _get_tools_list(headers: dict | None = None) -> httpx.Response:
        """GET ``/api/v1/mcp/tools/list`` on a fresh auth-enabled app."""
        app = _make_app(tool_registry=_make_registry_with_two_tools())
        async with httpx.AsyncClient(
            transport=httpx.ASGITransport(app=app),
            base_url="http://test",
        ) as client:
            return await client.get("/api/v1/mcp/tools/list", headers=headers)

    async def test_no_auth_returns_401(self):
        """Scenario 1: no Authorization header -> 401."""
        response = await self._get_tools_list()
        assert response.status_code == 401

    async def test_invalid_api_key_returns_401(self):
        """Scenario 2: an invalid X-API-Key -> 401."""
        response = await self._get_tools_list({"X-API-Key": "invalid-key"})
        assert response.status_code == 401
|
||
|
||
|
||
# ── 3-5: 有效凭据 + 权限校验 ──────────────────────────────
|
||
|
||
|
||
class TestValidAuth:
    """Valid credentials pass authentication; permissions then decide 200 vs 403."""

    @staticmethod
    async def _fetch_tools(headers: dict) -> httpx.Response:
        """GET ``/api/v1/mcp/tools/list`` with *headers* on a fresh app."""
        app = _make_app(tool_registry=_make_registry_with_two_tools())
        async with httpx.AsyncClient(
            transport=httpx.ASGITransport(app=app),
            base_url="http://test",
        ) as client:
            return await client.get("/api/v1/mcp/tools/list", headers=headers)

    async def test_valid_api_key_returns_tool_list(self):
        """Scenario 3: a valid X-API-Key -> 200 plus the tool list."""
        response = await self._fetch_tools({"X-API-Key": API_KEY})
        assert response.status_code == 200
        payload = response.json()
        assert "tools" in payload
        assert {entry["name"] for entry in payload["tools"]} == {"echo", "reverse"}

    async def test_valid_jwt_member_returns_tool_list(self):
        """Scenario 4: a valid JWT with role=member -> 200 plus the tool list."""
        bearer = _make_jwt(role="member")
        response = await self._fetch_tools({"Authorization": f"Bearer {bearer}"})
        assert response.status_code == 200
        assert len(response.json()["tools"]) == 2

    async def test_valid_jwt_guest_no_permission_returns_403(self):
        """Scenario 5: a valid JWT but role=guest (no CHAT permission) -> 403.

        The guest role is not listed in ROLE_PERMISSIONS, so has_permission
        returns False for it.
        """
        bearer = _make_jwt(role="guest")
        response = await self._fetch_tools({"Authorization": f"Bearer {bearer}"})
        assert response.status_code == 403
|
||
|
||
|
||
# ── 6-8: tools/call 端点 ─────────────────────────────────
|
||
|
||
|
||
class TestCallTool:
    """POST /api/v1/mcp/tools/call: success, unknown tool (404), execution error."""

    @staticmethod
    async def _post_call(registry, payload: dict) -> httpx.Response:
        """POST *payload* to ``/api/v1/mcp/tools/call`` with the valid API key."""
        app = _make_app(tool_registry=registry)
        async with httpx.AsyncClient(
            transport=httpx.ASGITransport(app=app),
            base_url="http://test",
        ) as client:
            return await client.post(
                "/api/v1/mcp/tools/call",
                json=payload,
                headers={"X-API-Key": API_KEY},
            )

    async def test_call_tool_success(self):
        """Scenario 6: valid auth and a known tool -> 200 with the tool's result."""
        response = await self._post_call(
            _make_registry_with_two_tools(),
            {"name": "echo", "arguments": {"text": "hi"}},
        )
        assert response.status_code == 200
        first_item = response.json()["content"][0]
        assert first_item["type"] == "text"
        assert "echo: hi" in first_item["text"]

    async def test_call_tool_unknown_returns_404(self):
        """Scenario 7: valid auth but an unknown tool name -> 404."""
        response = await self._post_call(
            _make_registry_with_two_tools(),
            {"name": "nonexistent", "arguments": {}},
        )
        assert response.status_code == 404

    async def test_call_tool_execution_error_returns_isError(self):
        """Scenario 8: valid auth, the tool raises -> 200 with isError=True."""
        exploding = _make_mock_tool(name="boom", result="should-not-reach")
        exploding.safe_execute = AsyncMock(side_effect=RuntimeError("boom!"))
        response = await self._post_call(
            _make_mock_registry([exploding]),
            {"name": "boom", "arguments": {}},
        )
        assert response.status_code == 200
        payload = response.json()
        assert payload.get("isError") is True
        assert "boom!" in payload["content"][0]["text"]
|
||
|
||
|
||
# ── 9: health 端点 ───────────────────────────────────────
|
||
|
||
|
||
class TestHealthEndpoint:
    """GET /api/v1/mcp/health with valid credentials -> 200."""

    async def test_health_with_valid_auth(self):
        """Scenario 9: valid credentials -> 200 with body {"status": "ok"}."""
        client = httpx.AsyncClient(
            transport=httpx.ASGITransport(app=_make_app(tool_registry=None)),
            base_url="http://test",
        )
        async with client:
            response = await client.get(
                "/api/v1/mcp/health",
                headers={"X-API-Key": API_KEY},
            )
        assert response.status_code == 200
        assert response.json() == {"status": "ok"}
|
||
|
||
|
||
# ── 10-14: JSON-RPC 2.0 端点 ─────────────────────────────
|
||
|
||
|
||
class TestJsonRpcEndpoint:
    """POST /api/v1/mcp/ — the MCP-compatible JSON-RPC 2.0 endpoint."""

    @staticmethod
    async def _post_rpc(tool_registry, **request_kwargs) -> httpx.Response:
        """POST to ``/api/v1/mcp/`` on a fresh app, forwarding *request_kwargs*."""
        app = _make_app(tool_registry=tool_registry)
        async with httpx.AsyncClient(
            transport=httpx.ASGITransport(app=app),
            base_url="http://test",
        ) as client:
            return await client.post("/api/v1/mcp/", **request_kwargs)

    async def test_jsonrpc_initialize(self):
        """Scenario 10: initialize -> 200 with protocolVersion and serverInfo."""
        response = await self._post_rpc(
            _make_registry_with_two_tools(),
            json={"jsonrpc": "2.0", "method": "initialize", "id": 1},
            headers={"X-API-Key": API_KEY},
        )
        assert response.status_code == 200
        envelope = response.json()
        assert envelope["jsonrpc"] == "2.0"
        assert envelope["id"] == 1
        result = envelope["result"]
        assert result["protocolVersion"] == "2024-11-05"
        assert result["serverInfo"]["name"] == "agentkit-mcp-server"

    async def test_jsonrpc_tools_list(self):
        """Scenario 11: tools/list over JSON-RPC -> 200 with a tools array."""
        response = await self._post_rpc(
            _make_registry_with_two_tools(),
            json={"jsonrpc": "2.0", "method": "tools/list", "id": 2},
            headers={"X-API-Key": API_KEY},
        )
        assert response.status_code == 200
        envelope = response.json()
        assert envelope["id"] == 2
        assert len(envelope["result"]["tools"]) == 2

    async def test_jsonrpc_tools_call(self):
        """Scenario 12: tools/call over JSON-RPC -> 200 with a result."""
        rpc_request = {
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "echo", "arguments": {}},
            "id": 3,
        }
        response = await self._post_rpc(
            _make_registry_with_two_tools(),
            json=rpc_request,
            headers={"X-API-Key": API_KEY},
        )
        assert response.status_code == 200
        envelope = response.json()
        assert envelope["id"] == 3
        assert "content" in envelope["result"]

    async def test_jsonrpc_unknown_method_returns_error_32601(self):
        """Scenario 13: an unknown method -> error code -32601."""
        response = await self._post_rpc(
            None,
            json={"jsonrpc": "2.0", "method": "unknown", "id": 4},
            headers={"X-API-Key": API_KEY},
        )
        assert response.status_code == 200
        envelope = response.json()
        assert envelope["error"]["code"] == -32601
        assert envelope["id"] == 4

    async def test_jsonrpc_parse_error_returns_error_32700(self):
        """Scenario 14: a body that is not valid JSON -> error code -32700."""
        response = await self._post_rpc(
            None,
            content=b"this is not json",
            headers={
                "X-API-Key": API_KEY,
                "Content-Type": "application/json",
            },
        )
        assert response.status_code == 200
        envelope = response.json()
        assert envelope["error"]["code"] == -32700
        assert envelope["id"] is None
|
||
|
||
|
||
# ── 15: MCPServer 向后兼容 ───────────────────────────────
|
||
|
||
|
||
class TestMCPServerDeprecation:
    """The legacy MCPServer class must emit a DeprecationWarning."""

    def test_create_app_emits_deprecation_warning(self):
        """Scenario 15: MCPServer()._create_app() emits a DeprecationWarning
        that points callers at create_mcp_router."""
        server = MCPServer()
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            server._create_app()

        deprecation_messages = [
            str(w.message) for w in caught if issubclass(w.category, DeprecationWarning)
        ]
        # Building the legacy app may also trigger unrelated DeprecationWarnings
        # (for example from third-party code it calls), so look for ours
        # anywhere in the recorded list instead of assuming it comes first.
        assert any("create_mcp_router" in msg for msg in deprecation_messages), (
            deprecation_messages
        )
|
||
|
||
|
||
# ── 16-17: create_mcp_router 工厂 ────────────────────────
|
||
|
||
|
||
class TestCreateMcpRouter:
    """Behaviour of the create_mcp_router() factory."""

    def test_create_mcp_router_returns_apirouter(self):
        """Scenario 16: create_mcp_router() returns an APIRouter instance."""
        router = create_mcp_router(tool_registry=None)
        assert isinstance(router, APIRouter)
        # Each expected route path must be registered: tools/list, tools/call,
        # health and "/" (the JSON-RPC endpoint). This checks membership only,
        # not the total number of routes.
        paths = {route.path for route in router.routes}
        expected = {"/tools/list", "/tools/call", "/health", "/"}
        missing = expected - paths
        assert not missing, f"missing routes: {sorted(missing)}; registered: {sorted(paths)}"

    async def test_empty_registry_returns_empty_list(self):
        """Scenario 17: tool_registry=None -> /tools/list returns {"tools": []}."""
        app = _make_app(tool_registry=None)
        transport = httpx.ASGITransport(app=app)
        async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.get(
                "/api/v1/mcp/tools/list",
                headers={"X-API-Key": API_KEY},
            )
        assert resp.status_code == 200
        assert resp.json() == {"tools": []}
|