"""Unit tests for the MCP transport layer."""

import asyncio
import json
from unittest.mock import MagicMock

import httpx
import pytest

from agentkit.mcp.transport import (
    HTTPTransport,
    SSETransport,
    StdioTransport,
    TransportError,
)


# ── HTTPTransport tests ──────────────────────────────────────────


class TestHTTPTransport:
    """Tests for HTTPTransport.

    Requests are sent to the endpoint root ("/") with a JSON-RPC 2.0 body;
    the HTTP layer is faked with the ``httpx_mock`` fixture.
    """

    async def test_connect_creates_client(self):
        transport = HTTPTransport(endpoint="http://localhost:8080")
        assert not transport.is_connected
        await transport.connect()
        assert transport.is_connected
        await transport.disconnect()
        assert not transport.is_connected

    async def test_disconnect_is_idempotent(self):
        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        await transport.disconnect()
        # A second disconnect must not raise.
        await transport.disconnect()

    async def test_connect_is_idempotent(self):
        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        await transport.connect()  # must not raise
        assert transport.is_connected
        await transport.disconnect()

    async def test_send_request_not_connected_raises(self):
        transport = HTTPTransport(endpoint="http://localhost:8080")
        with pytest.raises(TransportError, match="not connected"):
            await transport.send_request("tools/list")

    async def test_send_request_with_mock_server(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "result": {"tools": [{"name": "echo"}]},
            },
        )
        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        result = await transport.send_request("tools/list")
        assert result == {"tools": [{"name": "echo"}]}
        await transport.disconnect()

    async def test_send_request_with_params(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "result": {"content": [{"type": "text", "text": "hello"}]},
            },
        )
        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        result = await transport.send_request(
            "tools/call", params={"name": "echo", "arguments": {"msg": "hello"}}
        )
        assert result == {"content": [{"type": "text", "text": "hello"}]}

        # Verify the JSON-RPC request body that was actually sent.
        request = httpx_mock.get_request()
        body = json.loads(request.content)
        assert body["jsonrpc"] == "2.0"
        assert body["method"] == "tools/call"
        assert body["params"] == {"name": "echo", "arguments": {"msg": "hello"}}
        await transport.disconnect()

    async def test_send_request_json_rpc_error(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "error": {"code": -32600, "message": "Invalid Request"},
            },
        )
        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        with pytest.raises(TransportError, match="JSON-RPC error"):
            await transport.send_request("invalid/method")
        await transport.disconnect()

    async def test_send_request_http_error(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/",
            status_code=500,
        )
        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        with pytest.raises(TransportError, match="HTTP error 500"):
            await transport.send_request("tools/list")
        await transport.disconnect()

    async def test_send_request_network_error(self, httpx_mock):
        httpx_mock.add_exception(httpx.ConnectError("Connection refused"))
        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        with pytest.raises(TransportError, match="Request failed"):
            await transport.send_request("tools/list")
        await transport.disconnect()

    async def test_send_request_invalid_json_response(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/",
            text="not json",
        )
        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        with pytest.raises(TransportError, match="Invalid JSON response"):
            await transport.send_request("tools/list")
        await transport.disconnect()

    async def test_request_id_increments(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 1, "result": {}},
        )
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 2, "result": {}},
        )
        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        await transport.send_request("method1")
        await transport.send_request("method2")

        # Consecutive requests must carry monotonically increasing ids, starting at 1.
        requests = httpx_mock.get_requests()
        body1 = json.loads(requests[0].content)
        body2 = json.loads(requests[1].content)
        assert body1["id"] == 1
        assert body2["id"] == 2
        await transport.disconnect()

    async def test_receive_response_no_pending_raises(self):
        transport = HTTPTransport(endpoint="http://localhost:8080")
        await transport.connect()
        with pytest.raises(TransportError, match="No pending response"):
            await transport.receive_response()
        await transport.disconnect()

    async def test_custom_headers(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 1, "result": {}},
        )
        transport = HTTPTransport(
            endpoint="http://localhost:8080",
            headers={"Authorization": "Bearer test-token"},
        )
        await transport.connect()
        await transport.send_request("tools/list")

        request = httpx_mock.get_request()
        assert request.headers.get("authorization") == "Bearer test-token"
        await transport.disconnect()

    async def test_custom_timeout(self):
        transport = HTTPTransport(endpoint="http://localhost:8080", timeout=5.0)
        await transport.connect()
        assert transport._client is not None
        assert transport._client.timeout.read == 5.0
        await transport.disconnect()


# ── SSETransport tests ──────────────────────────────────────────


class TestSSETransport:
    """Tests for SSETransport.

    Requests are POSTed to ``message_path`` (default "/message"); messages
    arriving on the SSE stream are delivered through ``_response_queue``.
    """

    async def test_connect_sets_connected(self):
        transport = SSETransport(endpoint="http://localhost:8080")
        assert not transport.is_connected
        await transport.connect()
        assert transport.is_connected
        await transport.disconnect()
        assert not transport.is_connected

    async def test_disconnect_cancels_sse_task(self):
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        assert transport._sse_task is not None
        await transport.disconnect()
        assert transport._sse_task is None

    async def test_disconnect_is_idempotent(self):
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        await transport.disconnect()
        await transport.disconnect()  # must not raise

    async def test_send_request_not_connected_raises(self):
        transport = SSETransport(endpoint="http://localhost:8080")
        with pytest.raises(TransportError, match="not connected"):
            await transport.send_request("tools/list")

    async def test_receive_response_not_connected_raises(self):
        transport = SSETransport(endpoint="http://localhost:8080")
        with pytest.raises(TransportError, match="not connected"):
            await transport.receive_response()

    async def test_send_request_with_mock_server(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/message",
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "result": {"status": "ok"},
            },
        )
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        result = await transport.send_request("initialize", params={"protocol": "2024-11-05"})
        assert result == {"status": "ok"}
        await transport.disconnect()

    async def test_send_request_http_error(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/message",
            status_code=503,
        )
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        with pytest.raises(TransportError, match="HTTP error 503"):
            await transport.send_request("tools/list")
        await transport.disconnect()

    async def test_send_request_network_error(self, httpx_mock):
        httpx_mock.add_exception(httpx.ConnectError("Connection refused"))
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        with pytest.raises(TransportError, match="Request failed"):
            await transport.send_request("tools/list")
        await transport.disconnect()

    async def test_send_request_json_rpc_error(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/message",
            json={
                "jsonrpc": "2.0",
                "id": 1,
                "error": {"code": -32601, "message": "Method not found"},
            },
        )
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()
        with pytest.raises(TransportError, match="JSON-RPC error"):
            await transport.send_request("unknown/method")
        await transport.disconnect()

    async def test_receive_response_from_sse_stream(self):
        """Receive a response from the SSE stream (simulated by injecting into the queue)."""
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()

        # Simulate the SSE listener having received a message and queued it.
        await transport._response_queue.put(
            {"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}
        )

        response = await asyncio.wait_for(
            transport.receive_response(), timeout=2.0
        )
        assert response == {"jsonrpc": "2.0", "id": 1, "result": {"tools": []}}
        await transport.disconnect()

    async def test_receive_response_timeout(self):
        """receive_response() raises TransportError when nothing arrives in time."""
        transport = SSETransport(endpoint="http://localhost:8080", timeout=0.1)
        await transport.connect()
        with pytest.raises(TransportError, match="Timeout"):
            await transport.receive_response()
        await transport.disconnect()

    async def test_custom_paths(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/custom-message",
            json={"jsonrpc": "2.0", "id": 1, "result": {}},
        )
        transport = SSETransport(
            endpoint="http://localhost:8080",
            sse_path="/custom-sse",
            message_path="/custom-message",
        )
        await transport.connect()
        await transport.send_request("test")

        request = httpx_mock.get_request()
        assert request.url.path == "/custom-message"
        await transport.disconnect()

    async def test_custom_headers(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/message",
            json={"jsonrpc": "2.0", "id": 1, "result": {}},
        )
        transport = SSETransport(
            endpoint="http://localhost:8080",
            headers={"Authorization": "Bearer sse-token"},
        )
        await transport.connect()
        await transport.send_request("test")

        request = httpx_mock.get_request()
        assert request.headers.get("authorization") == "Bearer sse-token"
        await transport.disconnect()

    async def test_sse_ignores_comments_and_empty_lines(self):
        """SSE comment lines and empty lines should be ignored.

        NOTE(review): this test only injects an already-parsed message into
        ``_response_queue``; it never feeds raw SSE text (``: comment`` lines,
        blank lines) through the listener, so the filtering it is named after
        is NOT exercised. It is currently equivalent to
        ``test_receive_response_from_sse_stream``.
        TODO: drive the SSE listener with a raw stream containing comments and
        blank lines once a stream-level test hook is available.
        """
        transport = SSETransport(endpoint="http://localhost:8080")
        await transport.connect()

        # Simulate the SSE listener queuing a message after filtering
        # comments and empty lines.
        await transport._response_queue.put(
            {"jsonrpc": "2.0", "id": 1, "result": {"ok": True}}
        )

        response = await asyncio.wait_for(
            transport.receive_response(), timeout=2.0
        )
        assert response == {"jsonrpc": "2.0", "id": 1, "result": {"ok": True}}
        await transport.disconnect()


# ── Transport lifecycle tests ──────────────────────────────────────


class TestTransportLifecycle:
    """End-to-end connect → request → disconnect tests for the transports."""

    async def test_http_transport_full_lifecycle(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 1, "result": {"initialized": True}},
        )
        transport = HTTPTransport(endpoint="http://localhost:8080")

        # 1. Connect
        await transport.connect()
        assert transport.is_connected

        # 2. Send a request
        result = await transport.send_request("initialize")
        assert result == {"initialized": True}

        # 3. Disconnect
        await transport.disconnect()
        assert not transport.is_connected

    async def test_sse_transport_full_lifecycle(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/message",
            json={"jsonrpc": "2.0", "id": 1, "result": {"initialized": True}},
        )
        transport = SSETransport(endpoint="http://localhost:8080")

        # 1. Connect
        await transport.connect()
        assert transport.is_connected

        # 2. Send a request
        result = await transport.send_request("initialize")
        assert result == {"initialized": True}

        # 3. Disconnect
        await transport.disconnect()
        assert not transport.is_connected

    async def test_reconnect_after_disconnect(self, httpx_mock):
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 1, "result": {"first": True}},
        )
        httpx_mock.add_response(
            url="http://localhost:8080/",
            json={"jsonrpc": "2.0", "id": 2, "result": {"second": True}},
        )
        transport = HTTPTransport(endpoint="http://localhost:8080")

        # First connection
        await transport.connect()
        result1 = await transport.send_request("method1")
        assert result1 == {"first": True}
        await transport.disconnect()

        # Reconnect on the same instance
        await transport.connect()
        result2 = await transport.send_request("method2")
        assert result2 == {"second": True}
        await transport.disconnect()


# ── StdioTransport receive_response tests (P0 fix) ──────────────────


class TestStdioTransportReceiveResponse:
    """Tests that StdioTransport.receive_response() awaits instead of failing fast."""

    async def test_awaits_empty_notification_queue(self):
        """On an empty queue, receive_response() must await rather than raise immediately."""
        transport = StdioTransport(command="echo", timeout=2.0)
        # Fake the connected state by hand (no subprocess is actually spawned).
        transport._connected = True
        transport._process = MagicMock()
        transport._process.returncode = None

        # Deliver a notification shortly after we start waiting, to unblock the await.
        # put_nowait is invoked directly by the loop callback, so no untracked
        # fire-and-forget task is created (the event loop keeps only weak
        # references to tasks, which can be collected mid-execution).
        # Assumes _notifications is an asyncio.Queue (it also supports the
        # awaitable put() used in the next test) — TODO confirm.
        notification = {"jsonrpc": "2.0", "method": "notifications/progress", "params": {"progress": 50}}
        asyncio.get_running_loop().call_later(
            0.1, transport._notifications.put_nowait, notification
        )

        result = await asyncio.wait_for(
            transport.receive_response(), timeout=1.0
        )
        assert result == notification

    async def test_immediate_return_when_notification_available(self):
        """Return immediately when a notification is already queued."""
        transport = StdioTransport(command="echo", timeout=2.0)
        transport._connected = True
        transport._process = MagicMock()
        transport._process.returncode = None

        notification = {"jsonrpc": "2.0", "method": "test"}
        await transport._notifications.put(notification)

        result = await transport.receive_response()
        assert result == notification

    async def test_timeout_raises_transport_error(self):
        """Raise TransportError when nothing arrives within the timeout."""
        transport = StdioTransport(command="echo", timeout=0.1)
        transport._connected = True
        transport._process = MagicMock()
        transport._process.returncode = None

        with pytest.raises(TransportError, match="Timeout"):
            await transport.receive_response()

    async def test_not_connected_raises_transport_error(self):
        """Raise TransportError when the transport is not connected."""
        transport = StdioTransport(command="echo")

        with pytest.raises(TransportError, match="not connected"):
            await transport.receive_response()