# fischer-agentkit/tests/unit/test_mcp_transport.py

"""MCP Transport 层单元测试"""
import asyncio
import contextlib
import json

import httpx
import pytest

from agentkit.mcp.transport import HTTPTransport, SSETransport, TransportError
# ── HTTPTransport 测试 ──────────────────────────────────────────
class TestHTTPTransport:
    """Unit tests for HTTPTransport."""

    # Endpoint handed to the transport, and the URL the mocked server answers
    # on (the endpoint's root path).
    _ENDPOINT = "http://localhost:8080"
    _RPC_URL = "http://localhost:8080/"

    @staticmethod
    @contextlib.asynccontextmanager
    async def _connected(transport):
        """Connect *transport* for the duration of an ``async with`` block.

        ``disconnect()`` runs in ``finally`` so a failing assertion (or a
        ``pytest.raises`` that did not fire) never leaks the transport's
        client into later tests.
        """
        await transport.connect()
        try:
            yield transport
        finally:
            await transport.disconnect()

    async def test_connect_creates_client(self):
        transport = HTTPTransport(endpoint=self._ENDPOINT)
        assert not transport.is_connected
        async with self._connected(transport):
            assert transport.is_connected
        assert not transport.is_connected

    async def test_disconnect_is_idempotent(self):
        transport = HTTPTransport(endpoint=self._ENDPOINT)
        async with self._connected(transport):
            pass
        # A second disconnect must not raise and must leave the transport closed.
        await transport.disconnect()
        assert not transport.is_connected

    async def test_connect_is_idempotent(self):
        transport = HTTPTransport(endpoint=self._ENDPOINT)
        async with self._connected(transport):
            await transport.connect()  # a second connect must not raise
            assert transport.is_connected

    async def test_send_request_not_connected_raises(self):
        transport = HTTPTransport(endpoint=self._ENDPOINT)
        with pytest.raises(TransportError, match="not connected"):
            await transport.send_request("tools/list")

    async def test_send_request_with_mock_server(self, httpx_mock):
        httpx_mock.add_response(
            url=self._RPC_URL,
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "result": {"tools": [{"name": "echo"}]},
            },
        )
        transport = HTTPTransport(endpoint=self._ENDPOINT)
        async with self._connected(transport):
            result = await transport.send_request("tools/list")
            assert result == {"tools": [{"name": "echo"}]}

    async def test_send_request_with_params(self, httpx_mock):
        httpx_mock.add_response(
            url=self._RPC_URL,
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "result": {"content": [{"type": "text", "text": "hello"}]},
            },
        )
        transport = HTTPTransport(endpoint=self._ENDPOINT)
        async with self._connected(transport):
            result = await transport.send_request(
                "tools/call", params={"name": "echo", "arguments": {"msg": "hello"}}
            )
            assert result == {"content": [{"type": "text", "text": "hello"}]}
            # Verify the JSON-RPC request body the transport actually sent.
            body = json.loads(httpx_mock.get_request().content)
            assert body["jsonrpc"] == "2.0"
            assert body["method"] == "tools/call"
            assert body["params"] == {"name": "echo", "arguments": {"msg": "hello"}}

    async def test_send_request_json_rpc_error(self, httpx_mock):
        httpx_mock.add_response(
            url=self._RPC_URL,
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "error": {"code": -32600, "message": "Invalid Request"},
            },
        )
        transport = HTTPTransport(endpoint=self._ENDPOINT)
        async with self._connected(transport):
            with pytest.raises(TransportError, match="JSON-RPC error"):
                await transport.send_request("invalid/method")

    async def test_send_request_http_error(self, httpx_mock):
        httpx_mock.add_response(url=self._RPC_URL, status_code=500)
        transport = HTTPTransport(endpoint=self._ENDPOINT)
        async with self._connected(transport):
            with pytest.raises(TransportError, match="HTTP error 500"):
                await transport.send_request("tools/list")

    async def test_send_request_network_error(self, httpx_mock):
        httpx_mock.add_exception(httpx.ConnectError("Connection refused"))
        transport = HTTPTransport(endpoint=self._ENDPOINT)
        async with self._connected(transport):
            with pytest.raises(TransportError, match="Request failed"):
                await transport.send_request("tools/list")

    async def test_send_request_invalid_json_response(self, httpx_mock):
        httpx_mock.add_response(url=self._RPC_URL, text="not json")
        transport = HTTPTransport(endpoint=self._ENDPOINT)
        async with self._connected(transport):
            with pytest.raises(TransportError, match="Invalid JSON response"):
                await transport.send_request("tools/list")

    async def test_request_id_increments(self, httpx_mock):
        for response_id in (1, 2):
            httpx_mock.add_response(
                url=self._RPC_URL,
                json={"jsonrpc": "2.0", "id": response_id, "result": {}},
            )
        transport = HTTPTransport(endpoint=self._ENDPOINT)
        async with self._connected(transport):
            await transport.send_request("method1")
            await transport.send_request("method2")
            requests = httpx_mock.get_requests()
            # Each outgoing request carries the next id in sequence.
            assert json.loads(requests[0].content)["id"] == 1
            assert json.loads(requests[1].content)["id"] == 2

    async def test_receive_response_no_pending_raises(self):
        transport = HTTPTransport(endpoint=self._ENDPOINT)
        async with self._connected(transport):
            with pytest.raises(TransportError, match="No pending response"):
                await transport.receive_response()

    async def test_custom_headers(self, httpx_mock):
        httpx_mock.add_response(
            url=self._RPC_URL,
            json={"jsonrpc": "2.0", "id": 1, "result": {}},
        )
        transport = HTTPTransport(
            endpoint=self._ENDPOINT,
            headers={"Authorization": "Bearer test-token"},
        )
        async with self._connected(transport):
            await transport.send_request("tools/list")
            request = httpx_mock.get_request()
            assert request.headers.get("authorization") == "Bearer test-token"

    async def test_custom_timeout(self):
        transport = HTTPTransport(endpoint=self._ENDPOINT, timeout=5.0)
        async with self._connected(transport):
            # White-box check: the configured timeout reaches the httpx client.
            assert transport._client is not None
            assert transport._client.timeout.read == 5.0
# ── SSETransport 测试 ──────────────────────────────────────────
class TestSSETransport:
    """Unit tests for SSETransport."""

    # Endpoint handed to the transport, and the default message URL the
    # mocked server answers on.
    _ENDPOINT = "http://localhost:8080"
    _MESSAGE_URL = "http://localhost:8080/message"

    @staticmethod
    @contextlib.asynccontextmanager
    async def _connected(transport):
        """Connect *transport* for the duration of an ``async with`` block.

        ``disconnect()`` runs in ``finally`` so a failing assertion (or a
        ``pytest.raises`` that did not fire) never leaves the background
        ``_sse_task`` pending on the event loop after the test ends.
        """
        await transport.connect()
        try:
            yield transport
        finally:
            await transport.disconnect()

    async def test_connect_sets_connected(self):
        transport = SSETransport(endpoint=self._ENDPOINT)
        assert not transport.is_connected
        async with self._connected(transport):
            assert transport.is_connected
        assert not transport.is_connected

    async def test_disconnect_cancels_sse_task(self):
        transport = SSETransport(endpoint=self._ENDPOINT)
        async with self._connected(transport):
            assert transport._sse_task is not None
        # disconnect() must clear the background listener task reference.
        assert transport._sse_task is None

    async def test_disconnect_is_idempotent(self):
        transport = SSETransport(endpoint=self._ENDPOINT)
        async with self._connected(transport):
            pass
        # A second disconnect must not raise and must leave the transport closed.
        await transport.disconnect()
        assert not transport.is_connected

    async def test_send_request_not_connected_raises(self):
        transport = SSETransport(endpoint=self._ENDPOINT)
        with pytest.raises(TransportError, match="not connected"):
            await transport.send_request("tools/list")

    async def test_receive_response_not_connected_raises(self):
        transport = SSETransport(endpoint=self._ENDPOINT)
        with pytest.raises(TransportError, match="not connected"):
            await transport.receive_response()

    async def test_send_request_with_mock_server(self, httpx_mock):
        httpx_mock.add_response(
            url=self._MESSAGE_URL,
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "result": {"status": "ok"},
            },
        )
        transport = SSETransport(endpoint=self._ENDPOINT)
        async with self._connected(transport):
            result = await transport.send_request(
                "initialize", params={"protocol": "2024-11-05"}
            )
            assert result == {"status": "ok"}

    async def test_send_request_http_error(self, httpx_mock):
        httpx_mock.add_response(url=self._MESSAGE_URL, status_code=503)
        transport = SSETransport(endpoint=self._ENDPOINT)
        async with self._connected(transport):
            with pytest.raises(TransportError, match="HTTP error 503"):
                await transport.send_request("tools/list")

    async def test_send_request_network_error(self, httpx_mock):
        httpx_mock.add_exception(httpx.ConnectError("Connection refused"))
        transport = SSETransport(endpoint=self._ENDPOINT)
        async with self._connected(transport):
            with pytest.raises(TransportError, match="Request failed"):
                await transport.send_request("tools/list")

    async def test_send_request_json_rpc_error(self, httpx_mock):
        httpx_mock.add_response(
            url=self._MESSAGE_URL,
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "error": {"code": -32601, "message": "Method not found"},
            },
        )
        transport = SSETransport(endpoint=self._ENDPOINT)
        async with self._connected(transport):
            with pytest.raises(TransportError, match="JSON-RPC error"):
                await transport.send_request("unknown/method")

    async def test_receive_response_from_sse_stream(self):
        """Receive a response from the SSE stream.

        Simulated by putting a message directly on the transport's internal
        response queue, as the SSE listener would.
        """
        message = {"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}
        transport = SSETransport(endpoint=self._ENDPOINT)
        async with self._connected(transport):
            await transport._response_queue.put(message)
            response = await asyncio.wait_for(transport.receive_response(), timeout=2.0)
            assert response == {"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}

    async def test_receive_response_timeout(self):
        """receive_response() raises a timeout error when nothing arrives."""
        transport = SSETransport(endpoint=self._ENDPOINT, timeout=0.1)
        async with self._connected(transport):
            with pytest.raises(TransportError, match="Timeout"):
                await transport.receive_response()

    async def test_custom_paths(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/custom-message",
            json={"jsonrpc": "2.0", "id": 1, "result": {}},
        )
        transport = SSETransport(
            endpoint=self._ENDPOINT,
            sse_path="/custom-sse",
            message_path="/custom-message",
        )
        async with self._connected(transport):
            await transport.send_request("test")
            assert httpx_mock.get_request().url.path == "/custom-message"

    async def test_custom_headers(self, httpx_mock):
        httpx_mock.add_response(
            url=self._MESSAGE_URL,
            json={"jsonrpc": "2.0", "id": 1, "result": {}},
        )
        transport = SSETransport(
            endpoint=self._ENDPOINT,
            headers={"Authorization": "Bearer sse-token"},
        )
        async with self._connected(transport):
            await transport.send_request("test")
            request = httpx_mock.get_request()
            assert request.headers.get("authorization") == "Bearer sse-token"

    async def test_sse_ignores_comments_and_empty_lines(self):
        """SSE comment lines and blank lines are ignored by the listener.

        Simulated by putting the already-filtered message directly on the
        internal response queue.

        NOTE(review): because the payload is injected *after* parsing, this
        test never exercises comment/blank-line filtering and is effectively
        a duplicate of ``test_receive_response_from_sse_stream``.
        TODO: drive the real SSE listener with raw lines (e.g. ``": ping"``,
        ``""``, ``"data: {...}"``) once the parser is reachable from tests.
        """
        message = {"jsonrpc": "2.0", "id": 1, "result": {"ok": True}}
        transport = SSETransport(endpoint=self._ENDPOINT)
        async with self._connected(transport):
            await transport._response_queue.put(message)
            response = await asyncio.wait_for(transport.receive_response(), timeout=2.0)
            assert response == {"jsonrpc": "2.0", "id": 1, "result": {"ok": True}}
# ── Transport 生命周期测试 ──────────────────────────────────────
class TestTransportLifecycle:
    """Connect -> request -> disconnect sequences for both transports.

    Each step is spelled out explicitly. ``disconnect()`` sits in
    ``finally`` so a failure in an earlier step never leaks the transport.
    """

    async def test_http_transport_full_lifecycle(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 1, "result": {"initialized": True}},
        )
        transport = HTTPTransport(endpoint="http://localhost:8080")
        # 1. Connect.
        await transport.connect()
        try:
            assert transport.is_connected
            # 2. Send a request.
            result = await transport.send_request("initialize")
            assert result == {"initialized": True}
        finally:
            # 3. Disconnect (always, even if a step above failed).
            await transport.disconnect()
        assert not transport.is_connected

    async def test_sse_transport_full_lifecycle(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/message",
            json={"jsonrpc": "2.0", "id": 1, "result": {"initialized": True}},
        )
        transport = SSETransport(endpoint="http://localhost:8080")
        # 1. Connect.
        await transport.connect()
        try:
            assert transport.is_connected
            # 2. Send a request.
            result = await transport.send_request("initialize")
            assert result == {"initialized": True}
        finally:
            # 3. Disconnect (always, so the SSE background task is torn down).
            await transport.disconnect()
        assert not transport.is_connected

    async def test_reconnect_after_disconnect(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 1, "result": {"first": True}},
        )
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 2, "result": {"second": True}},
        )
        transport = HTTPTransport(endpoint="http://localhost:8080")

        # First session.
        await transport.connect()
        try:
            result1 = await transport.send_request("method1")
            assert result1 == {"first": True}
        finally:
            await transport.disconnect()

        # The same instance must be usable again after a reconnect.
        await transport.connect()
        try:
            result2 = await transport.send_request("method2")
            assert result2 == {"second": True}
        finally:
            await transport.disconnect()