463 lines
15 KiB
Python
463 lines
15 KiB
Python
"""MCP Transport 层单元测试"""
|
|
|
|
import asyncio
|
|
import json
|
|
|
|
import httpx
|
|
import pytest
|
|
|
|
from agentkit.mcp.transport import HTTPTransport, SSETransport, TransportError
|
|
|
|
|
|
# ── HTTPTransport 测试 ──────────────────────────────────────────
|
|
|
|
|
|
class TestHTTPTransport:
    """Unit tests for HTTPTransport.

    Every test that connects a transport releases it in a ``finally`` block so
    that a failing assertion cannot leave a connected transport behind for the
    tests that run afterwards.
    """

    async def test_connect_creates_client(self):
        """connect() / disconnect() toggle ``is_connected``."""
        transport = HTTPTransport(endpoint="http://localhost:8080")
        assert not transport.is_connected

        await transport.connect()
        try:
            assert transport.is_connected
        finally:
            await transport.disconnect()
        assert not transport.is_connected

    async def test_disconnect_is_idempotent(self):
        """Calling disconnect() twice in a row must not raise."""
        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        await transport.disconnect()
        # A second disconnect must not raise.
        await transport.disconnect()

    async def test_connect_is_idempotent(self):
        """Calling connect() twice in a row must not raise."""
        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            await transport.connect()  # must not raise
            assert transport.is_connected
        finally:
            await transport.disconnect()

    async def test_send_request_not_connected_raises(self):
        """send_request() before connect() raises TransportError."""
        transport = HTTPTransport(endpoint="http://localhost:8080")
        with pytest.raises(TransportError, match="not connected"):
            await transport.send_request("tools/list")

    async def test_send_request_with_mock_server(self, httpx_mock):
        """send_request() returns the ``result`` member of the JSON-RPC reply."""
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "result": {"tools": [{"name": "echo"}]},
            },
        )

        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            result = await transport.send_request("tools/list")
            assert result == {"tools": [{"name": "echo"}]}
        finally:
            await transport.disconnect()

    async def test_send_request_with_params(self, httpx_mock):
        """Params are forwarded unchanged in the JSON-RPC request body."""
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "result": {"content": [{"type": "text", "text": "hello"}]},
            },
        )

        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            result = await transport.send_request(
                "tools/call", params={"name": "echo", "arguments": {"msg": "hello"}}
            )
            assert result == {"content": [{"type": "text", "text": "hello"}]}

            # Verify the request body that was actually sent.
            request = httpx_mock.get_request()
            body = json.loads(request.content)
            assert body["jsonrpc"] == "2.0"
            assert body["method"] == "tools/call"
            assert body["params"] == {"name": "echo", "arguments": {"msg": "hello"}}
        finally:
            await transport.disconnect()

    async def test_send_request_json_rpc_error(self, httpx_mock):
        """A JSON-RPC ``error`` member in the reply raises TransportError."""
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "error": {"code": -32600, "message": "Invalid Request"},
            },
        )

        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            with pytest.raises(TransportError, match="JSON-RPC error"):
                await transport.send_request("invalid/method")
        finally:
            await transport.disconnect()

    async def test_send_request_http_error(self, httpx_mock):
        """An HTTP 500 reply raises TransportError naming the status code."""
        httpx_mock.add_response(
            url="http://localhost:8080/",
            status_code=500,
        )

        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            with pytest.raises(TransportError, match="HTTP error 500"):
                await transport.send_request("tools/list")
        finally:
            await transport.disconnect()

    async def test_send_request_network_error(self, httpx_mock):
        """An httpx.ConnectError during the request raises TransportError."""
        httpx_mock.add_exception(httpx.ConnectError("Connection refused"))

        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            with pytest.raises(TransportError, match="Request failed"):
                await transport.send_request("tools/list")
        finally:
            await transport.disconnect()

    async def test_send_request_invalid_json_response(self, httpx_mock):
        """A reply body that is not JSON raises TransportError."""
        httpx_mock.add_response(
            url="http://localhost:8080/",
            text="not json",
        )

        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            with pytest.raises(TransportError, match="Invalid JSON response"):
                await transport.send_request("tools/list")
        finally:
            await transport.disconnect()

    async def test_request_id_increments(self, httpx_mock):
        """Two consecutive requests carry the ids 1 and 2."""
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 1, "result": {}},
        )
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 2, "result": {}},
        )

        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            await transport.send_request("method1")
            await transport.send_request("method2")

            requests = httpx_mock.get_requests()
            body1 = json.loads(requests[0].content)
            body2 = json.loads(requests[1].content)
            assert body1["id"] == 1
            assert body2["id"] == 2
        finally:
            await transport.disconnect()

    async def test_receive_response_no_pending_raises(self):
        """receive_response() with nothing sent raises TransportError."""
        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            with pytest.raises(TransportError, match="No pending response"):
                await transport.receive_response()
        finally:
            await transport.disconnect()

    async def test_custom_headers(self, httpx_mock):
        """Headers given to the constructor are sent with the request."""
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 1, "result": {}},
        )

        transport = HTTPTransport(
            endpoint="http://localhost:8080",
            headers={"Authorization": "Bearer test-token"},
        )
        await transport.connect()
        try:
            await transport.send_request("tools/list")

            request = httpx_mock.get_request()
            assert request.headers.get("authorization") == "Bearer test-token"
        finally:
            await transport.disconnect()

    async def test_custom_timeout(self):
        """The ``timeout`` argument becomes the client's read timeout."""
        transport = HTTPTransport(endpoint="http://localhost:8080", timeout=5.0)
        await transport.connect()
        try:
            assert transport._client is not None
            assert transport._client.timeout.read == 5.0
        finally:
            await transport.disconnect()
|
|
|
|
|
|
# ── SSETransport 测试 ──────────────────────────────────────────
|
|
|
|
|
|
class TestSSETransport:
    """Unit tests for SSETransport.

    Every test that connects a transport releases it in a ``finally`` block so
    that a failing assertion cannot skip disconnect() and leave the transport
    (and the ``_sse_task`` it holds after connect()) alive for later tests.
    """

    async def test_connect_sets_connected(self):
        """connect() / disconnect() toggle ``is_connected``."""
        transport = SSETransport(endpoint="http://localhost:8080")
        assert not transport.is_connected

        await transport.connect()
        try:
            assert transport.is_connected
        finally:
            await transport.disconnect()
        assert not transport.is_connected

    async def test_disconnect_cancels_sse_task(self):
        """``_sse_task`` is set after connect() and cleared by disconnect()."""
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            assert transport._sse_task is not None
        finally:
            await transport.disconnect()
        assert transport._sse_task is None

    async def test_disconnect_is_idempotent(self):
        """Calling disconnect() twice in a row must not raise."""
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        await transport.disconnect()
        await transport.disconnect()  # must not raise

    async def test_send_request_not_connected_raises(self):
        """send_request() before connect() raises TransportError."""
        transport = SSETransport(endpoint="http://localhost:8080")
        with pytest.raises(TransportError, match="not connected"):
            await transport.send_request("tools/list")

    async def test_receive_response_not_connected_raises(self):
        """receive_response() before connect() raises TransportError."""
        transport = SSETransport(endpoint="http://localhost:8080")
        with pytest.raises(TransportError, match="not connected"):
            await transport.receive_response()

    async def test_send_request_with_mock_server(self, httpx_mock):
        """send_request() posts to ``/message`` and returns the ``result`` member."""
        httpx_mock.add_response(
            url="http://localhost:8080/message",
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "result": {"status": "ok"},
            },
        )

        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            result = await transport.send_request(
                "initialize", params={"protocol": "2024-11-05"}
            )
            assert result == {"status": "ok"}
        finally:
            await transport.disconnect()

    async def test_send_request_http_error(self, httpx_mock):
        """An HTTP 503 reply raises TransportError naming the status code."""
        httpx_mock.add_response(
            url="http://localhost:8080/message",
            status_code=503,
        )

        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            with pytest.raises(TransportError, match="HTTP error 503"):
                await transport.send_request("tools/list")
        finally:
            await transport.disconnect()

    async def test_send_request_network_error(self, httpx_mock):
        """An httpx.ConnectError during the request raises TransportError."""
        httpx_mock.add_exception(httpx.ConnectError("Connection refused"))

        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            with pytest.raises(TransportError, match="Request failed"):
                await transport.send_request("tools/list")
        finally:
            await transport.disconnect()

    async def test_send_request_json_rpc_error(self, httpx_mock):
        """A JSON-RPC ``error`` member in the reply raises TransportError."""
        httpx_mock.add_response(
            url="http://localhost:8080/message",
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "error": {"code": -32601, "message": "Method not found"},
            },
        )

        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            with pytest.raises(TransportError, match="JSON-RPC error"):
                await transport.send_request("unknown/method")
        finally:
            await transport.disconnect()

    async def test_receive_response_from_sse_stream(self):
        """receive_response() returns a message placed on the response queue.

        The SSE stream is simulated by putting the message directly on
        ``_response_queue``.
        """
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            # Simulate the SSE listener having received data and queued it.
            await transport._response_queue.put(
                {"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}
            )

            response = await asyncio.wait_for(
                transport.receive_response(), timeout=2.0
            )
            assert response == {"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}
        finally:
            await transport.disconnect()

    async def test_receive_response_timeout(self):
        """receive_response() raises TransportError when nothing arrives in time."""
        transport = SSETransport(endpoint="http://localhost:8080", timeout=0.1)
        await transport.connect()
        try:
            with pytest.raises(TransportError, match="Timeout"):
                await transport.receive_response()
        finally:
            await transport.disconnect()

    async def test_custom_paths(self, httpx_mock):
        """A custom ``message_path`` is used as the request path."""
        httpx_mock.add_response(
            url="http://localhost:8080/custom-message",
            json={"jsonrpc": "2.0", "id": 1, "result": {}},
        )

        transport = SSETransport(
            endpoint="http://localhost:8080",
            sse_path="/custom-sse",
            message_path="/custom-message",
        )
        await transport.connect()
        try:
            await transport.send_request("test")

            request = httpx_mock.get_request()
            assert request.url.path == "/custom-message"
        finally:
            await transport.disconnect()

    async def test_custom_headers(self, httpx_mock):
        """Headers given to the constructor are sent with the request."""
        httpx_mock.add_response(
            url="http://localhost:8080/message",
            json={"jsonrpc": "2.0", "id": 1, "result": {}},
        )

        transport = SSETransport(
            endpoint="http://localhost:8080",
            headers={"Authorization": "Bearer sse-token"},
        )
        await transport.connect()
        try:
            await transport.send_request("test")

            request = httpx_mock.get_request()
            assert request.headers.get("authorization") == "Bearer sse-token"
        finally:
            await transport.disconnect()

    async def test_sse_ignores_comments_and_empty_lines(self):
        """A queued message is delivered intact by receive_response().

        NOTE(review): despite its name, this test puts an already-parsed
        message directly on ``_response_queue``; no SSE comment or blank line
        is ever fed to the listener, so the filtering itself is not exercised
        here. Consider driving the listener with a raw SSE stream instead.
        """
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        try:
            # Stand-in for what the listener would queue after filtering.
            await transport._response_queue.put(
                {"jsonrpc": "2.0", "id": 1, "result": {"ok": True}}
            )

            response = await asyncio.wait_for(
                transport.receive_response(), timeout=2.0
            )
            assert response == {"jsonrpc": "2.0", "id": 1, "result": {"ok": True}}
        finally:
            await transport.disconnect()
|
|
|
|
|
|
# ── Transport 生命周期测试 ──────────────────────────────────────
|
|
|
|
|
|
class TestTransportLifecycle:
    """Transport lifecycle tests: connect, send a request, disconnect.

    disconnect() runs in a ``finally`` block so that a failure in the middle
    of a lifecycle cannot leave a connected transport behind; the
    post-disconnect state is asserted after that block.
    """

    async def test_http_transport_full_lifecycle(self, httpx_mock):
        """HTTPTransport: connect, send one request, disconnect."""
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 1, "result": {"initialized": True}},
        )

        transport = HTTPTransport(endpoint="http://localhost:8080")

        # 1. Connect
        await transport.connect()
        try:
            assert transport.is_connected

            # 2. Send a request
            result = await transport.send_request("initialize")
            assert result == {"initialized": True}
        finally:
            # 3. Disconnect
            await transport.disconnect()
        assert not transport.is_connected

    async def test_sse_transport_full_lifecycle(self, httpx_mock):
        """SSETransport: connect, send one request, disconnect."""
        httpx_mock.add_response(
            url="http://localhost:8080/message",
            json={"jsonrpc": "2.0", "id": 1, "result": {"initialized": True}},
        )

        transport = SSETransport(endpoint="http://localhost:8080")

        # 1. Connect
        await transport.connect()
        try:
            assert transport.is_connected

            # 2. Send a request
            result = await transport.send_request("initialize")
            assert result == {"initialized": True}
        finally:
            # 3. Disconnect
            await transport.disconnect()
        assert not transport.is_connected

    async def test_reconnect_after_disconnect(self, httpx_mock):
        """The same HTTPTransport instance works again after a reconnect."""
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 1, "result": {"first": True}},
        )
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 2, "result": {"second": True}},
        )

        transport = HTTPTransport(endpoint="http://localhost:8080")

        # First connection
        await transport.connect()
        try:
            result1 = await transport.send_request("method1")
            assert result1 == {"first": True}
        finally:
            await transport.disconnect()

        # Reconnect
        await transport.connect()
        try:
            result2 = await transport.send_request("method2")
            assert result2 == {"second": True}
        finally:
            await transport.disconnect()
|