"""StdioTransport 单元测试 使用内联 mock MCP server 子进程进行测试,无需外部依赖。 """ import asyncio import json import sys import textwrap import pytest from agentkit.mcp.transport import StdioTransport, TransportError # 内联 mock MCP server 脚本 # 读取 stdin 的 JSON-RPC 消息,根据 method 返回对应响应 MOCK_SERVER_SCRIPT = textwrap.dedent("""\ import sys import json def handle_request(data): method = data.get("method", "") req_id = data.get("id") params = data.get("params", {}) if method == "initialize": return { "jsonrpc": "2.0", "id": req_id, "result": { "protocolVersion": "2024-11-05", "capabilities": {"tools": {"listChanged": True}}, "serverInfo": {"name": "mock-mcp-server", "version": "0.1.0"}, }, } elif method == "tools/list": return { "jsonrpc": "2.0", "id": req_id, "result": { "tools": [ { "name": "echo", "description": "Echo tool", "inputSchema": {"type": "object", "properties": {"msg": {"type": "string"}}}, } ] }, } elif method == "tools/call": name = params.get("name", "") arguments = params.get("arguments", {}) if name == "echo": return { "jsonrpc": "2.0", "id": req_id, "result": { "content": [{"type": "text", "text": arguments.get("msg", "")}] }, } else: return { "jsonrpc": "2.0", "id": req_id, "error": {"code": -32601, "message": f"Unknown tool: {name}"}, } else: return { "jsonrpc": "2.0", "id": req_id, "error": {"code": -32601, "message": f"Method not found: {method}"}, } def main(): for line in sys.stdin: line = line.strip() if not line: continue try: data = json.loads(line) except json.JSONDecodeError: continue # 通知消息(无 id)不回复 if "id" not in data: continue response = handle_request(data) sys.stdout.write(json.dumps(response) + "\\n") sys.stdout.flush() if __name__ == "__main__": main() """) # 发送通知后立即退出的 mock server(用于测试子进程退出检测) EXIT_AFTER_INIT_SCRIPT = textwrap.dedent("""\ import sys import json for line in sys.stdin: line = line.strip() if not line: continue try: data = json.loads(line) except json.JSONDecodeError: continue if "id" not in data: continue method = data.get("method", "") req_id = data.get("id") if method == "initialize": response = { "jsonrpc": "2.0", "id": req_id, "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "exit-server", "version": "0.1.0"}}, } sys.stdout.write(json.dumps(response) + "\\n") sys.stdout.flush() # 初始化后立即退出 sys.exit(0) else: # 不会到达这里 pass """) # 发送通知的 mock server(用于测试通知接收) NOTIFICATION_SERVER_SCRIPT = textwrap.dedent("""\ import sys import json for line in sys.stdin: line = line.strip() if not line: continue try: data = json.loads(line) except json.JSONDecodeError: continue if "id" not in data: continue method = data.get("method", "") req_id = data.get("id") if method == "initialize": response = { "jsonrpc": "2.0", "id": req_id, "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "notif-server", "version": "0.1.0"}}, } sys.stdout.write(json.dumps(response) + "\\n") sys.stdout.flush() elif method == "tools/list": # 先发送一个通知 notification = { "jsonrpc": "2.0", "method": "notifications/tools/list_changed", } sys.stdout.write(json.dumps(notification) + "\\n") sys.stdout.flush() # 再发送响应 response = { "jsonrpc": "2.0", "id": req_id, "result": {"tools": [{"name": "updated_tool"}]}, } sys.stdout.write(json.dumps(response) + "\\n") sys.stdout.flush() else: response = { "jsonrpc": "2.0", "id": req_id, "error": {"code": -32601, "message": "Method not found"}, } sys.stdout.write(json.dumps(response) + "\\n") sys.stdout.flush() """) # 写入 stderr 的 mock server STDERR_SERVER_SCRIPT = textwrap.dedent("""\ import sys import json for line in sys.stdin: line = line.strip() if not line: continue try: data = json.loads(line) except json.JSONDecodeError: continue if "id" not in data: continue method = data.get("method", "") req_id = data.get("id") if method == "initialize": # 写入 stderr sys.stderr.write("mock server starting\\n") sys.stderr.flush() response = { "jsonrpc": "2.0", "id": req_id, "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "stderr-server", "version": "0.1.0"}}, } sys.stdout.write(json.dumps(response) + "\\n") sys.stdout.flush() else: response = { "jsonrpc": "2.0", "id": req_id, "result": {}, } sys.stdout.write(json.dumps(response) + "\\n") sys.stdout.flush() """) def _make_transport(script: str, **kwargs) -> StdioTransport: """创建使用内联 mock server 脚本的 StdioTransport""" return StdioTransport( command=sys.executable, args=["-c", script], **kwargs, ) # ── 构造测试 ────────────────────────────────────────── class TestStdioTransportConstruction: """StdioTransport 构造测试""" def test_default_args(self): transport = StdioTransport(command="echo") assert transport._command == "echo" assert transport._args == [] assert transport._env is None assert transport._timeout == 30.0 assert transport._process is None assert transport._request_id == 0 assert transport._pending == {} assert transport._reader_task is None assert transport._stderr_task is None assert transport._connected is False def test_custom_args(self): transport = StdioTransport( command="node", args=["server.js", "--port", "3000"], env={"NODE_ENV": "test"}, timeout=10.0, ) assert transport._command == "node" assert transport._args == ["server.js", "--port", "3000"] assert transport._env == {"NODE_ENV": "test"} assert transport._timeout == 10.0 def test_is_connected_initially_false(self): transport = StdioTransport(command="echo") assert not transport.is_connected # ── 连接/断开测试 ────────────────────────────────────────── class TestStdioTransportConnect: """StdioTransport 连接测试""" async def test_connect_starts_subprocess(self): transport = _make_transport(MOCK_SERVER_SCRIPT) try: await transport.connect() assert transport.is_connected assert transport._process is not None assert transport._process.returncode is None assert transport._reader_task is not None assert transport._stderr_task is not None finally: await transport.disconnect() async def test_connect_completes_initialize_handshake(self): transport = _make_transport(MOCK_SERVER_SCRIPT) try: await transport.connect() # connect 成功说明 initialize 握手完成 assert transport._connected is True # request_id 应该至少递增到 1(initialize 请求) assert transport._request_id >= 1 finally: await transport.disconnect() async def test_connect_is_idempotent(self): transport = _make_transport(MOCK_SERVER_SCRIPT) try: await transport.connect() await transport.connect() # 不应报错 assert transport.is_connected finally: await transport.disconnect() async def test_connect_with_invalid_command_raises(self): transport = StdioTransport(command="/nonexistent/command") with pytest.raises(TransportError, match="Failed to start process"): await transport.connect() async def test_connect_timeout(self): """使用不响应 initialize 的子进程测试超时""" # 使用一个只读取 stdin 但不输出任何内容的脚本 silent_script = "import sys; sys.stdin.read()" transport = StdioTransport( command=sys.executable, args=["-c", silent_script], timeout=0.5, ) with pytest.raises(TransportError, match="Timeout waiting for initialize"): await transport.connect() assert not transport.is_connected class TestStdioTransportDisconnect: """StdioTransport 断开测试""" async def test_disconnect_closes_subprocess(self): transport = _make_transport(MOCK_SERVER_SCRIPT) await transport.connect() assert transport.is_connected await transport.disconnect() assert not transport.is_connected assert transport._process is None assert transport._reader_task is None assert transport._stderr_task is None async def test_disconnect_is_idempotent(self): transport = _make_transport(MOCK_SERVER_SCRIPT) await transport.connect() await transport.disconnect() await transport.disconnect() # 不应报错 async def test_disconnect_cancels_pending_futures(self): """断开时所有 pending future 应收到 TransportError""" transport = _make_transport(MOCK_SERVER_SCRIPT) await transport.connect() # 手动添加一个 pending future loop = asyncio.get_running_loop() future = loop.create_future() transport._pending[999] = future await transport.disconnect() assert future.done() with pytest.raises(TransportError, match="Transport disconnected"): future.result() async def test_disconnect_clears_pending(self): transport = _make_transport(MOCK_SERVER_SCRIPT) await transport.connect() transport._pending[1] = asyncio.get_running_loop().create_future() await transport.disconnect() assert transport._pending == {} # ── 请求发送测试 ────────────────────────────────────────── class TestStdioTransportSendRequest: """StdioTransport 请求发送测试""" async def test_send_request_not_connected_raises(self): transport = StdioTransport(command="echo") with pytest.raises(TransportError, match="not connected"): await transport.send_request("tools/list") async def test_send_request_tools_list(self): transport = _make_transport(MOCK_SERVER_SCRIPT) try: await transport.connect() result = await transport.send_request("tools/list") assert "tools" in result assert len(result["tools"]) == 1 assert result["tools"][0]["name"] == "echo" finally: await transport.disconnect() async def test_send_request_tools_call(self): transport = _make_transport(MOCK_SERVER_SCRIPT) try: await transport.connect() result = await transport.send_request( "tools/call", params={"name": "echo", "arguments": {"msg": "hello world"}}, ) assert "content" in result assert result["content"][0]["text"] == "hello world" finally: await transport.disconnect() async def test_send_request_json_rpc_error(self): transport = _make_transport(MOCK_SERVER_SCRIPT) try: await transport.connect() with pytest.raises(TransportError, match="JSON-RPC error"): await transport.send_request("unknown/method") finally: await transport.disconnect() async def test_send_request_unknown_tool_error(self): transport = _make_transport(MOCK_SERVER_SCRIPT) try: await transport.connect() with pytest.raises(TransportError, match="Unknown tool"): await transport.send_request( "tools/call", params={"name": "nonexistent", "arguments": {}}, ) finally: await transport.disconnect() async def test_request_id_increments(self): transport = _make_transport(MOCK_SERVER_SCRIPT) try: await transport.connect() # connect 时已经用了 id=1 (initialize) id_before = transport._request_id await transport.send_request("tools/list") id_after_1 = transport._request_id await transport.send_request("tools/list") id_after_2 = transport._request_id assert id_after_1 == id_before + 1 assert id_after_2 == id_before + 2 finally: await transport.disconnect() async def test_send_request_timeout(self): """请求超时测试""" # 使用一个不响应的脚本 silent_script = "import sys; sys.stdin.read()" transport = StdioTransport( command=sys.executable, args=["-c", silent_script], timeout=0.5, ) # 手动设置连接状态以绕过 initialize transport._connected = True transport._process = await asyncio.create_subprocess_exec( sys.executable, "-c", silent_script, stdin=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, ) transport._reader_task = asyncio.create_task(transport._read_stdout()) transport._stderr_task = asyncio.create_task(transport._read_stderr()) try: with pytest.raises(TransportError, match="Timeout"): await transport.send_request("tools/list") finally: transport._connected = False await transport._cleanup() # ── 并发请求测试 ────────────────────────────────────────── class TestStdioTransportConcurrentRequests: """StdioTransport 并发请求测试""" async def test_concurrent_requests_correct_id_matching(self): """并发请求的响应应正确匹配到对应的 Future""" transport = _make_transport(MOCK_SERVER_SCRIPT) try: await transport.connect() # 并发发送多个请求 results = await asyncio.gather( transport.send_request("tools/list"), transport.send_request( "tools/call", params={"name": "echo", "arguments": {"msg": "msg1"}}, ), transport.send_request( "tools/call", params={"name": "echo", "arguments": {"msg": "msg2"}}, ), ) # 验证每个请求都得到了正确类型的响应 assert "tools" in results[0] assert results[1]["content"][0]["text"] == "msg1" assert results[2]["content"][0]["text"] == "msg2" finally: await transport.disconnect() # ── 通知接收测试 ────────────────────────────────────────── class TestStdioTransportNotifications: """StdioTransport 通知接收测试""" async def test_receive_notification(self): transport = _make_transport(NOTIFICATION_SERVER_SCRIPT) try: await transport.connect() # tools/list 会先发送一个通知 result = await transport.send_request("tools/list") assert "tools" in result # 等待通知到达 await asyncio.sleep(0.1) # 应该能收到通知 notification = await transport.receive_response() assert notification["method"] == "notifications/tools/list_changed" finally: await transport.disconnect() async def test_receive_response_no_notification_raises(self): """空通知队列时 receive_response 超时抛出 TransportError""" transport = _make_transport(MOCK_SERVER_SCRIPT) try: await transport.connect() # 临时缩短 receive_response 超时 transport._timeout = 0.1 with pytest.raises(TransportError, match="Timeout"): await transport.receive_response() finally: await transport.disconnect() async def test_receive_response_not_connected_raises(self): transport = StdioTransport(command="echo") with pytest.raises(TransportError, match="not connected"): await transport.receive_response() # ── 子进程退出检测测试 ────────────────────────────────────────── class TestStdioTransportProcessExit: """StdioTransport 子进程退出检测测试""" async def test_subprocess_exit_detection(self): """子进程退出后 is_connected 应返回 False""" transport = _make_transport(EXIT_AFTER_INIT_SCRIPT) try: await transport.connect() assert transport.is_connected # 等待子进程退出 await asyncio.sleep(0.5) # 子进程已退出,is_connected 应为 False assert not transport.is_connected finally: if transport._process is not None: await transport.disconnect() async def test_send_request_after_process_exit_raises(self): """子进程退出后发送请求应抛出 TransportError""" transport = _make_transport(EXIT_AFTER_INIT_SCRIPT) try: await transport.connect() # 等待子进程退出 await asyncio.sleep(0.5) if not transport.is_connected: with pytest.raises(TransportError, match="not connected"): await transport.send_request("tools/list") finally: if transport._process is not None: await transport.disconnect() # ── stderr 转发测试 ────────────────────────────────────────── class TestStdioTransportStderr: """StdioTransport stderr 转发测试""" async def test_stderr_forwarded_to_logger(self, caplog): """stderr 输出应转发到 logger""" import logging transport = _make_transport(STDERR_SERVER_SCRIPT) try: with caplog.at_level(logging.DEBUG, logger="agentkit.mcp.transport"): await transport.connect() # 发送一个请求触发 stderr 输出 await transport.send_request("tools/list") # 等待 stderr 被读取 await asyncio.sleep(0.2) # 检查日志中包含 stderr 输出 stderr_logs = [ r for r in caplog.records if "mock server starting" in r.message ] assert len(stderr_logs) > 0 finally: await transport.disconnect() # ── is_connected 属性测试 ────────────────────────────────────────── class TestStdioTransportIsConnected: """StdioTransport is_connected 属性测试""" async def test_is_connected_before_connect(self): transport = _make_transport(MOCK_SERVER_SCRIPT) assert not transport.is_connected async def test_is_connected_after_connect(self): transport = _make_transport(MOCK_SERVER_SCRIPT) try: await transport.connect() assert transport.is_connected finally: await transport.disconnect() async def test_is_connected_after_disconnect(self): transport = _make_transport(MOCK_SERVER_SCRIPT) await transport.connect() await transport.disconnect() assert not transport.is_connected async def test_is_connected_checks_process_returncode(self): """is_connected 应检查 process.returncode""" transport = _make_transport(MOCK_SERVER_SCRIPT) try: await transport.connect() assert transport._process is not None # 模拟进程退出但 _connected 仍为 True transport._connected = True # 终止子进程 transport._process.kill() await transport._process.wait() # is_connected 应为 False 因为 returncode 不为 None assert not transport.is_connected finally: transport._connected = False await transport._cleanup() # ── 完整生命周期测试 ────────────────────────────────────────── class TestStdioTransportLifecycle: """StdioTransport 完整生命周期测试""" async def test_full_lifecycle(self): """测试完整的 connect → send_request → disconnect 生命周期""" transport = _make_transport(MOCK_SERVER_SCRIPT) # 1. 连接 await transport.connect() assert transport.is_connected # 2. 发送请求 result = await transport.send_request("tools/list") assert "tools" in result # 3. 发送带参数的请求 result = await transport.send_request( "tools/call", params={"name": "echo", "arguments": {"msg": "test"}}, ) assert result["content"][0]["text"] == "test" # 4. 断开 await transport.disconnect() assert not transport.is_connected async def test_reconnect_after_disconnect(self): """测试断开后重新连接""" transport = _make_transport(MOCK_SERVER_SCRIPT) # 第一次连接 await transport.connect() result1 = await transport.send_request("tools/list") assert "tools" in result1 await transport.disconnect() # 重新连接 await transport.connect() result2 = await transport.send_request("tools/list") assert "tools" in result2 await transport.disconnect() async def test_env_variables_passed_to_subprocess(self): """测试环境变量传递给子进程""" # 使用打印环境变量的脚本 env_script = textwrap.dedent("""\ import sys import json import os for line in sys.stdin: line = line.strip() if not line: continue try: data = json.loads(line) except json.JSONDecodeError: continue if "id" not in data: continue method = data.get("method", "") req_id = data.get("id") if method == "initialize": response = { "jsonrpc": "2.0", "id": req_id, "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "env-server", "version": "0.1.0"}}, } sys.stdout.write(json.dumps(response) + "\\n") sys.stdout.flush() elif method == "tools/call": test_env = os.environ.get("TEST_MCP_VAR", "not_set") response = { "jsonrpc": "2.0", "id": req_id, "result": { "content": [{"type": "text", "text": test_env}] }, } sys.stdout.write(json.dumps(response) + "\\n") sys.stdout.flush() else: response = { "jsonrpc": "2.0", "id": req_id, "result": {}, } sys.stdout.write(json.dumps(response) + "\\n") sys.stdout.flush() """) transport = StdioTransport( command=sys.executable, args=["-c", env_script], env={"TEST_MCP_VAR": "hello_from_env"}, ) try: await transport.connect() result = await transport.send_request( "tools/call", params={"name": "check_env", "arguments": {}}, ) assert result["content"][0]["text"] == "hello_from_env" finally: await transport.disconnect()