diff --git a/src/agentkit/mcp/transport.py b/src/agentkit/mcp/transport.py index cd636fc..32ad36e 100644 --- a/src/agentkit/mcp/transport.py +++ b/src/agentkit/mcp/transport.py @@ -1,11 +1,12 @@ """MCP Transport - 传输层抽象 -提供 MCP 协议的传输层实现,支持 Streamable HTTP 和 SSE 两种传输方式。 +提供 MCP 协议的传输层实现,支持 Streamable HTTP、SSE 和 Stdio 三种传输方式。 """ import asyncio import json import logging +import os from abc import ABC, abstractmethod from typing import Any @@ -352,3 +353,304 @@ class SSETransport(Transport): ) except asyncio.TimeoutError: raise TransportError("Timeout waiting for SSE response") + + +class StdioTransport(Transport): + """Stdio 传输 + + 通过 stdin/stdout 与 MCP Server 子进程通信,使用 newline-delimited JSON-RPC 消息格式。 + """ + + def __init__( + self, + command: str, + args: list[str] | None = None, + env: dict[str, str] | None = None, + timeout: float = 30.0, + ): + self._command = command + self._args = args or [] + self._env = env + self._timeout = timeout + self._process: asyncio.subprocess.Process | None = None + self._request_id = 0 + self._pending: dict[int, asyncio.Future[Any]] = {} + self._reader_task: asyncio.Task[None] | None = None + self._stderr_task: asyncio.Task[None] | None = None + self._connected = False + self._notifications: asyncio.Queue[dict[str, Any]] = asyncio.Queue() + + @property + def is_connected(self) -> bool: + return ( + self._connected + and self._process is not None + and self._process.returncode is None + ) + + def _next_request_id(self) -> int: + """生成下一个请求 ID""" + self._request_id += 1 + return self._request_id + + async def connect(self) -> None: + """启动子进程并完成 MCP 初始化握手 + + Raises: + TransportError: 子进程启动失败或初始化超时 + """ + if self.is_connected: + return + + # 合并环境变量 + merged_env = dict(os.environ) + if self._env: + merged_env.update(self._env) + + try: + self._process = await asyncio.create_subprocess_exec( + self._command, + *self._args, + env=merged_env, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + except OSError as e: + raise TransportError(f"Failed to start process: {self._command}", cause=e) from e + + # 启动 stdout 读取任务 + self._reader_task = asyncio.create_task(self._read_stdout()) + + # 启动 stderr 读取任务 + self._stderr_task = asyncio.create_task(self._read_stderr()) + + # 发送 initialize 请求并等待响应 + try: + init_result = await asyncio.wait_for( + self._send_request_internal( + "initialize", + { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": {"name": "agentkit", "version": "0.1.0"}, + }, + ), + timeout=self._timeout, + ) + except asyncio.TimeoutError: + await self._cleanup() + raise TransportError("Timeout waiting for initialize response") + except TransportError: + await self._cleanup() + raise + + # 发送 initialized 通知 + await self._send_notification("notifications/initialized") + + self._connected = True + logger.info( + "StdioTransport connected to %s %s", + self._command, + " ".join(self._args), + ) + + async def disconnect(self) -> None: + """关闭子进程连接""" + self._connected = False + await self._cleanup() + + async def _cleanup(self) -> None: + """清理子进程和相关资源""" + # 取消读取任务 + for task in (self._reader_task, self._stderr_task): + if task is not None: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + self._reader_task = None + self._stderr_task = None + + # 关闭 stdin + if self._process is not None and self._process.stdin is not None: + self._process.stdin.close() + try: + await self._process.stdin.drain() + except Exception: + pass + + # 等待子进程退出 + if self._process is not None and self._process.returncode is None: + try: + await asyncio.wait_for(self._process.wait(), timeout=5.0) + except asyncio.TimeoutError: + self._process.kill() + await self._process.wait() + + self._process = None + + # 取消所有等待中的 Future + for future in self._pending.values(): + if not future.done(): + future.set_exception(TransportError("Transport disconnected")) + self._pending.clear() + + logger.info("StdioTransport disconnected") + + async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any: + """发送 JSON-RPC 请求并等待响应 + + Args: + method: JSON-RPC 方法名 + params: 请求参数 + + Returns: + JSON-RPC 响应的 result 字段 + + Raises: + TransportError: 连接未建立或请求失败 + """ + if not self.is_connected: + raise TransportError("Transport not connected") + return await self._send_request_internal(method, params) + + async def _send_request_internal( + self, method: str, params: dict[str, Any] | None = None + ) -> Any: + """内部请求发送方法(connect 时也可调用)""" + request_id = self._next_request_id() + message: dict[str, Any] = { + "jsonrpc": "2.0", + "id": request_id, + "method": method, + } + if params is not None: + message["params"] = params + + await self._write_message(message) + + loop = asyncio.get_running_loop() + future: asyncio.Future[Any] = loop.create_future() + self._pending[request_id] = future + + try: + return await asyncio.wait_for(future, timeout=self._timeout) + except asyncio.TimeoutError: + self._pending.pop(request_id, None) + raise TransportError(f"Timeout waiting for response to {method}") + except TransportError: + self._pending.pop(request_id, None) + raise + + async def _send_notification(self, method: str, params: dict[str, Any] | None = None) -> None: + """发送 JSON-RPC 通知(无 id,不期待响应)""" + message: dict[str, Any] = { + "jsonrpc": "2.0", + "method": method, + } + if params is not None: + message["params"] = params + await self._write_message(message) + + async def _write_message(self, message: dict[str, Any]) -> None: + """将 JSON-RPC 消息写入子进程 stdin""" + if self._process is None or self._process.stdin is None: + raise TransportError("Process stdin not available") + data = (json.dumps(message) + "\n").encode("utf-8") + self._process.stdin.write(data) + await self._process.stdin.drain() + + async def receive_response(self) -> dict[str, Any]: + """接收通知消息 + + 对于 StdioTransport,请求响应通过 _pending Future 异步返回。 + 此方法仅用于获取服务端推送的通知消息。 + + Returns: + JSON-RPC 通知消息 + + Raises: + TransportError: 连接未建立或无通知 + """ + if not self.is_connected: + raise TransportError("Transport not connected") + + if not self._notifications.empty(): + return self._notifications.get_nowait() + + raise TransportError("No notification to receive") + + async def _read_stdout(self) -> None: + """持续从子进程 stdout 读取 JSON-RPC 消息""" + if self._process is None or self._process.stdout is None: + return + + try: + while True: + line = await self._process.stdout.readline() + if not line: + # EOF — 子进程退出 + if self._connected: + logger.warning("StdioTransport: subprocess stdout EOF") + break + + line_str = line.decode("utf-8").strip() + if not line_str: + continue + + try: + data = json.loads(line_str) + except json.JSONDecodeError: + logger.warning("StdioTransport: invalid JSON from stdout: %s", line_str) + continue + + # 响应消息(有 id 字段) + if "id" in data: + request_id = data["id"] + future = self._pending.pop(request_id, None) + if future is not None and not future.done(): + if "error" in data: + error = data["error"] + future.set_exception( + TransportError( + f"JSON-RPC error {error.get('code')}: {error.get('message')}" + ) + ) + else: + future.set_result(data.get("result")) + elif future is None: + logger.warning( + "StdioTransport: received response for unknown request id %s", + request_id, + ) + + # 通知消息(有 method 字段,无 id) + elif "method" in data: + await self._notifications.put(data) + + except asyncio.CancelledError: + raise + except Exception as e: + if self._connected: + logger.error("StdioTransport: stdout reader error: %s", e) + + async def _read_stderr(self) -> None: + """持续从子进程 stderr 读取并转发到 logger""" + if self._process is None or self._process.stderr is None: + return + + try: + while True: + line = await self._process.stderr.readline() + if not line: + break + line_str = line.decode("utf-8", errors="replace").rstrip() + if line_str: + logger.debug("StdioTransport stderr: %s", line_str) + except asyncio.CancelledError: + raise + except Exception as e: + if self._connected: + logger.error("StdioTransport: stderr reader error: %s", e) diff --git a/tests/unit/test_stdio_transport.py b/tests/unit/test_stdio_transport.py new file mode 100644 index 0000000..4b3ae65 --- /dev/null +++ b/tests/unit/test_stdio_transport.py @@ -0,0 +1,749 @@ +"""StdioTransport 单元测试 + +使用内联 mock MCP server 子进程进行测试,无需外部依赖。 +""" + +import asyncio +import json +import sys +import textwrap + +import pytest + +from agentkit.mcp.transport import StdioTransport, TransportError + +# 内联 mock MCP server 脚本 +# 读取 stdin 的 JSON-RPC 消息,根据 method 返回对应响应 +MOCK_SERVER_SCRIPT = textwrap.dedent("""\ + import sys + import json + + def handle_request(data): + method = data.get("method", "") + req_id = data.get("id") + params = data.get("params", {}) + + if method == "initialize": + return { + "jsonrpc": "2.0", + "id": req_id, + "result": { + "protocolVersion": "2024-11-05", + "capabilities": {"tools": {"listChanged": True}}, + "serverInfo": {"name": "mock-mcp-server", "version": "0.1.0"}, + }, + } + elif method == "tools/list": + return { + "jsonrpc": "2.0", + "id": req_id, + "result": { + "tools": [ + { + "name": "echo", + "description": "Echo tool", + "inputSchema": {"type": "object", "properties": {"msg": {"type": "string"}}}, + } + ] + }, + } + elif method == "tools/call": + name = params.get("name", "") + arguments = params.get("arguments", {}) + if name == "echo": + return { + "jsonrpc": "2.0", + "id": req_id, + "result": { + "content": [{"type": "text", "text": arguments.get("msg", "")}] + }, + } + else: + return { + "jsonrpc": "2.0", + "id": req_id, + "error": {"code": -32601, "message": f"Unknown tool: {name}"}, + } + else: + return { + "jsonrpc": "2.0", + "id": req_id, + "error": {"code": -32601, "message": f"Method not found: {method}"}, + } + + def main(): + for line in sys.stdin: + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + except json.JSONDecodeError: + continue + # 通知消息(无 id)不回复 + if "id" not in data: + continue + response = handle_request(data) + sys.stdout.write(json.dumps(response) + "\\n") + sys.stdout.flush() + + if __name__ == "__main__": + main() +""") + +# 发送通知后立即退出的 mock server(用于测试子进程退出检测) +EXIT_AFTER_INIT_SCRIPT = textwrap.dedent("""\ + import sys + import json + + for line in sys.stdin: + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + except json.JSONDecodeError: + continue + if "id" not in data: + continue + method = data.get("method", "") + req_id = data.get("id") + if method == "initialize": + response = { + "jsonrpc": "2.0", + "id": req_id, + "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "exit-server", "version": "0.1.0"}}, + } + sys.stdout.write(json.dumps(response) + "\\n") + sys.stdout.flush() + # 初始化后立即退出 + sys.exit(0) + else: + # 不会到达这里 + pass +""") + +# 发送通知的 mock server(用于测试通知接收) +NOTIFICATION_SERVER_SCRIPT = textwrap.dedent("""\ + import sys + import json + + for line in sys.stdin: + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + except json.JSONDecodeError: + continue + if "id" not in data: + continue + method = data.get("method", "") + req_id = data.get("id") + if method == "initialize": + response = { + "jsonrpc": "2.0", + "id": req_id, + "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "notif-server", "version": "0.1.0"}}, + } + sys.stdout.write(json.dumps(response) + "\\n") + sys.stdout.flush() + elif method == "tools/list": + # 先发送一个通知 + notification = { + "jsonrpc": "2.0", + "method": "notifications/tools/list_changed", + } + sys.stdout.write(json.dumps(notification) + "\\n") + sys.stdout.flush() + # 再发送响应 + response = { + "jsonrpc": "2.0", + "id": req_id, + "result": {"tools": [{"name": "updated_tool"}]}, + } + sys.stdout.write(json.dumps(response) + "\\n") + sys.stdout.flush() + else: + response = { + "jsonrpc": "2.0", + "id": req_id, + "error": {"code": -32601, "message": "Method not found"}, + } + sys.stdout.write(json.dumps(response) + "\\n") + sys.stdout.flush() +""") + +# 写入 stderr 的 mock server +STDERR_SERVER_SCRIPT = textwrap.dedent("""\ + import sys + import json + + for line in sys.stdin: + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + except json.JSONDecodeError: + continue + if "id" not in data: + continue + method = data.get("method", "") + req_id = data.get("id") + if method == "initialize": + # 写入 stderr + sys.stderr.write("mock server starting\\n") + sys.stderr.flush() + response = { + "jsonrpc": "2.0", + "id": req_id, + "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "stderr-server", "version": "0.1.0"}}, + } + sys.stdout.write(json.dumps(response) + "\\n") + sys.stdout.flush() + else: + response = { + "jsonrpc": "2.0", + "id": req_id, + "result": {}, + } + sys.stdout.write(json.dumps(response) + "\\n") + sys.stdout.flush() +""") + + +def _make_transport(script: str, **kwargs) -> StdioTransport: + """创建使用内联 mock server 脚本的 StdioTransport""" + return StdioTransport( + command=sys.executable, + args=["-c", script], + **kwargs, + ) + + +# ── 构造测试 ────────────────────────────────────────── + + +class TestStdioTransportConstruction: + """StdioTransport 构造测试""" + + def test_default_args(self): + transport = StdioTransport(command="echo") + assert transport._command == "echo" + assert transport._args == [] + assert transport._env is None + assert transport._timeout == 30.0 + assert transport._process is None + assert transport._request_id == 0 + assert transport._pending == {} + assert transport._reader_task is None + assert transport._stderr_task is None + assert transport._connected is False + + def test_custom_args(self): + transport = StdioTransport( + command="node", + args=["server.js", "--port", "3000"], + env={"NODE_ENV": "test"}, + timeout=10.0, + ) + assert transport._command == "node" + assert transport._args == ["server.js", "--port", "3000"] + assert transport._env == {"NODE_ENV": "test"} + assert transport._timeout == 10.0 + + def test_is_connected_initially_false(self): + transport = StdioTransport(command="echo") + assert not transport.is_connected + + +# ── 连接/断开测试 ────────────────────────────────────────── + + +class TestStdioTransportConnect: + """StdioTransport 连接测试""" + + async def test_connect_starts_subprocess(self): + transport = _make_transport(MOCK_SERVER_SCRIPT) + try: + await transport.connect() + assert transport.is_connected + assert transport._process is not None + assert transport._process.returncode is None + assert transport._reader_task is not None + assert transport._stderr_task is not None + finally: + await transport.disconnect() + + async def test_connect_completes_initialize_handshake(self): + transport = _make_transport(MOCK_SERVER_SCRIPT) + try: + await transport.connect() + # connect 成功说明 initialize 握手完成 + assert transport._connected is True + # request_id 应该至少递增到 1(initialize 请求) + assert transport._request_id >= 1 + finally: + await transport.disconnect() + + async def test_connect_is_idempotent(self): + transport = _make_transport(MOCK_SERVER_SCRIPT) + try: + await transport.connect() + await transport.connect() # 不应报错 + assert transport.is_connected + finally: + await transport.disconnect() + + async def test_connect_with_invalid_command_raises(self): + transport = StdioTransport(command="/nonexistent/command") + with pytest.raises(TransportError, match="Failed to start process"): + await transport.connect() + + async def test_connect_timeout(self): + """使用不响应 initialize 的子进程测试超时""" + # 使用一个只读取 stdin 但不输出任何内容的脚本 + silent_script = "import sys; sys.stdin.read()" + transport = StdioTransport( + command=sys.executable, + args=["-c", silent_script], + timeout=0.5, + ) + with pytest.raises(TransportError, match="Timeout waiting for initialize"): + await transport.connect() + assert not transport.is_connected + + +class TestStdioTransportDisconnect: + """StdioTransport 断开测试""" + + async def test_disconnect_closes_subprocess(self): + transport = _make_transport(MOCK_SERVER_SCRIPT) + await transport.connect() + assert transport.is_connected + + await transport.disconnect() + assert not transport.is_connected + assert transport._process is None + assert transport._reader_task is None + assert transport._stderr_task is None + + async def test_disconnect_is_idempotent(self): + transport = _make_transport(MOCK_SERVER_SCRIPT) + await transport.connect() + await transport.disconnect() + await transport.disconnect() # 不应报错 + + async def test_disconnect_cancels_pending_futures(self): + """断开时所有 pending future 应收到 TransportError""" + transport = _make_transport(MOCK_SERVER_SCRIPT) + await transport.connect() + + # 手动添加一个 pending future + loop = asyncio.get_running_loop() + future = loop.create_future() + transport._pending[999] = future + + await transport.disconnect() + + assert future.done() + with pytest.raises(TransportError, match="Transport disconnected"): + future.result() + + async def test_disconnect_clears_pending(self): + transport = _make_transport(MOCK_SERVER_SCRIPT) + await transport.connect() + transport._pending[1] = asyncio.get_running_loop().create_future() + await transport.disconnect() + assert transport._pending == {} + + +# ── 请求发送测试 ────────────────────────────────────────── + + +class TestStdioTransportSendRequest: + """StdioTransport 请求发送测试""" + + async def test_send_request_not_connected_raises(self): + transport = StdioTransport(command="echo") + with pytest.raises(TransportError, match="not connected"): + await transport.send_request("tools/list") + + async def test_send_request_tools_list(self): + transport = _make_transport(MOCK_SERVER_SCRIPT) + try: + await transport.connect() + result = await transport.send_request("tools/list") + assert "tools" in result + assert len(result["tools"]) == 1 + assert result["tools"][0]["name"] == "echo" + finally: + await transport.disconnect() + + async def test_send_request_tools_call(self): + transport = _make_transport(MOCK_SERVER_SCRIPT) + try: + await transport.connect() + result = await transport.send_request( + "tools/call", + params={"name": "echo", "arguments": {"msg": "hello world"}}, + ) + assert "content" in result + assert result["content"][0]["text"] == "hello world" + finally: + await transport.disconnect() + + async def test_send_request_json_rpc_error(self): + transport = _make_transport(MOCK_SERVER_SCRIPT) + try: + await transport.connect() + with pytest.raises(TransportError, match="JSON-RPC error"): + await transport.send_request("unknown/method") + finally: + await transport.disconnect() + + async def test_send_request_unknown_tool_error(self): + transport = _make_transport(MOCK_SERVER_SCRIPT) + try: + await transport.connect() + with pytest.raises(TransportError, match="Unknown tool"): + await transport.send_request( + "tools/call", + params={"name": "nonexistent", "arguments": {}}, + ) + finally: + await transport.disconnect() + + async def test_request_id_increments(self): + transport = _make_transport(MOCK_SERVER_SCRIPT) + try: + await transport.connect() + # connect 时已经用了 id=1 (initialize) + id_before = transport._request_id + await transport.send_request("tools/list") + id_after_1 = transport._request_id + await transport.send_request("tools/list") + id_after_2 = transport._request_id + assert id_after_1 == id_before + 1 + assert id_after_2 == id_before + 2 + finally: + await transport.disconnect() + + async def test_send_request_timeout(self): + """请求超时测试""" + # 使用一个不响应的脚本 + silent_script = "import sys; sys.stdin.read()" + transport = StdioTransport( + command=sys.executable, + args=["-c", silent_script], + timeout=0.5, + ) + # 手动设置连接状态以绕过 initialize + transport._connected = True + transport._process = await asyncio.create_subprocess_exec( + sys.executable, "-c", silent_script, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + transport._reader_task = asyncio.create_task(transport._read_stdout()) + transport._stderr_task = asyncio.create_task(transport._read_stderr()) + + try: + with pytest.raises(TransportError, match="Timeout"): + await transport.send_request("tools/list") + finally: + transport._connected = False + await transport._cleanup() + + +# ── 并发请求测试 ────────────────────────────────────────── + + +class TestStdioTransportConcurrentRequests: + """StdioTransport 并发请求测试""" + + async def test_concurrent_requests_correct_id_matching(self): + """并发请求的响应应正确匹配到对应的 Future""" + transport = _make_transport(MOCK_SERVER_SCRIPT) + try: + await transport.connect() + + # 并发发送多个请求 + results = await asyncio.gather( + transport.send_request("tools/list"), + transport.send_request( + "tools/call", + params={"name": "echo", "arguments": {"msg": "msg1"}}, + ), + transport.send_request( + "tools/call", + params={"name": "echo", "arguments": {"msg": "msg2"}}, + ), + ) + + # 验证每个请求都得到了正确类型的响应 + assert "tools" in results[0] + assert results[1]["content"][0]["text"] == "msg1" + assert results[2]["content"][0]["text"] == "msg2" + finally: + await transport.disconnect() + + +# ── 通知接收测试 ────────────────────────────────────────── + + +class TestStdioTransportNotifications: + """StdioTransport 通知接收测试""" + + async def test_receive_notification(self): + transport = _make_transport(NOTIFICATION_SERVER_SCRIPT) + try: + await transport.connect() + + # tools/list 会先发送一个通知 + result = await transport.send_request("tools/list") + assert "tools" in result + + # 等待通知到达 + await asyncio.sleep(0.1) + + # 应该能收到通知 + notification = await transport.receive_response() + assert notification["method"] == "notifications/tools/list_changed" + finally: + await transport.disconnect() + + async def test_receive_response_no_notification_raises(self): + transport = _make_transport(MOCK_SERVER_SCRIPT) + try: + await transport.connect() + with pytest.raises(TransportError, match="No notification"): + await transport.receive_response() + finally: + await transport.disconnect() + + async def test_receive_response_not_connected_raises(self): + transport = StdioTransport(command="echo") + with pytest.raises(TransportError, match="not connected"): + await transport.receive_response() + + +# ── 子进程退出检测测试 ────────────────────────────────────────── + + +class TestStdioTransportProcessExit: + """StdioTransport 子进程退出检测测试""" + + async def test_subprocess_exit_detection(self): + """子进程退出后 is_connected 应返回 False""" + transport = _make_transport(EXIT_AFTER_INIT_SCRIPT) + try: + await transport.connect() + assert transport.is_connected + + # 等待子进程退出 + await asyncio.sleep(0.5) + + # 子进程已退出,is_connected 应为 False + assert not transport.is_connected + finally: + if transport._process is not None: + await transport.disconnect() + + async def test_send_request_after_process_exit_raises(self): + """子进程退出后发送请求应抛出 TransportError""" + transport = _make_transport(EXIT_AFTER_INIT_SCRIPT) + try: + await transport.connect() + # 等待子进程退出 + await asyncio.sleep(0.5) + + if not transport.is_connected: + with pytest.raises(TransportError, match="not connected"): + await transport.send_request("tools/list") + finally: + if transport._process is not None: + await transport.disconnect() + + +# ── stderr 转发测试 ────────────────────────────────────────── + + +class TestStdioTransportStderr: + """StdioTransport stderr 转发测试""" + + async def test_stderr_forwarded_to_logger(self, caplog): + """stderr 输出应转发到 logger""" + import logging + + transport = _make_transport(STDERR_SERVER_SCRIPT) + try: + with caplog.at_level(logging.DEBUG, logger="agentkit.mcp.transport"): + await transport.connect() + # 发送一个请求触发 stderr 输出 + await transport.send_request("tools/list") + # 等待 stderr 被读取 + await asyncio.sleep(0.2) + + # 检查日志中包含 stderr 输出 + stderr_logs = [ + r for r in caplog.records + if "mock server starting" in r.message + ] + assert len(stderr_logs) > 0 + finally: + await transport.disconnect() + + +# ── is_connected 属性测试 ────────────────────────────────────────── + + +class TestStdioTransportIsConnected: + """StdioTransport is_connected 属性测试""" + + async def test_is_connected_before_connect(self): + transport = _make_transport(MOCK_SERVER_SCRIPT) + assert not transport.is_connected + + async def test_is_connected_after_connect(self): + transport = _make_transport(MOCK_SERVER_SCRIPT) + try: + await transport.connect() + assert transport.is_connected + finally: + await transport.disconnect() + + async def test_is_connected_after_disconnect(self): + transport = _make_transport(MOCK_SERVER_SCRIPT) + await transport.connect() + await transport.disconnect() + assert not transport.is_connected + + async def test_is_connected_checks_process_returncode(self): + """is_connected 应检查 process.returncode""" + transport = _make_transport(MOCK_SERVER_SCRIPT) + try: + await transport.connect() + assert transport._process is not None + # 模拟进程退出但 _connected 仍为 True + transport._connected = True + # 终止子进程 + transport._process.kill() + await transport._process.wait() + # is_connected 应为 False 因为 returncode 不为 None + assert not transport.is_connected + finally: + transport._connected = False + await transport._cleanup() + + +# ── 完整生命周期测试 ────────────────────────────────────────── + + +class TestStdioTransportLifecycle: + """StdioTransport 完整生命周期测试""" + + async def test_full_lifecycle(self): + """测试完整的 connect → send_request → disconnect 生命周期""" + transport = _make_transport(MOCK_SERVER_SCRIPT) + + # 1. 连接 + await transport.connect() + assert transport.is_connected + + # 2. 发送请求 + result = await transport.send_request("tools/list") + assert "tools" in result + + # 3. 发送带参数的请求 + result = await transport.send_request( + "tools/call", + params={"name": "echo", "arguments": {"msg": "test"}}, + ) + assert result["content"][0]["text"] == "test" + + # 4. 断开 + await transport.disconnect() + assert not transport.is_connected + + async def test_reconnect_after_disconnect(self): + """测试断开后重新连接""" + transport = _make_transport(MOCK_SERVER_SCRIPT) + + # 第一次连接 + await transport.connect() + result1 = await transport.send_request("tools/list") + assert "tools" in result1 + await transport.disconnect() + + # 重新连接 + await transport.connect() + result2 = await transport.send_request("tools/list") + assert "tools" in result2 + await transport.disconnect() + + async def test_env_variables_passed_to_subprocess(self): + """测试环境变量传递给子进程""" + # 使用打印环境变量的脚本 + env_script = textwrap.dedent("""\ + import sys + import json + import os + + for line in sys.stdin: + line = line.strip() + if not line: + continue + try: + data = json.loads(line) + except json.JSONDecodeError: + continue + if "id" not in data: + continue + method = data.get("method", "") + req_id = data.get("id") + if method == "initialize": + response = { + "jsonrpc": "2.0", + "id": req_id, + "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "env-server", "version": "0.1.0"}}, + } + sys.stdout.write(json.dumps(response) + "\\n") + sys.stdout.flush() + elif method == "tools/call": + test_env = os.environ.get("TEST_MCP_VAR", "not_set") + response = { + "jsonrpc": "2.0", + "id": req_id, + "result": { + "content": [{"type": "text", "text": test_env}] + }, + } + sys.stdout.write(json.dumps(response) + "\\n") + sys.stdout.flush() + else: + response = { + "jsonrpc": "2.0", + "id": req_id, + "result": {}, + } + sys.stdout.write(json.dumps(response) + "\\n") + sys.stdout.flush() + """) + + transport = StdioTransport( + command=sys.executable, + args=["-c", env_script], + env={"TEST_MCP_VAR": "hello_from_env"}, + ) + try: + await transport.connect() + result = await transport.send_request( + "tools/call", + params={"name": "check_env", "arguments": {}}, + ) + assert result["content"][0]["text"] == "hello_from_env" + finally: + await transport.disconnect()