feat(mcp): U1 StdioTransport for subprocess-based MCP communication

Add StdioTransport class supporting stdio JSON-RPC over subprocess
stdin/stdout with asyncio.create_subprocess_exec, pending futures
for request/response matching, and stderr forwarding.
This commit is contained in:
chiguyong 2026-06-07 17:24:52 +08:00
parent 9b6c0230c0
commit 66b9217569
2 changed files with 1052 additions and 1 deletions

View File

@ -1,11 +1,12 @@
"""MCP Transport - 传输层抽象 """MCP Transport - 传输层抽象
提供 MCP 协议的传输层实现支持 Streamable HTTP SSE 种传输方式 提供 MCP 协议的传输层实现支持 Streamable HTTPSSE Stdio 种传输方式
""" """
import asyncio import asyncio
import json import json
import logging import logging
import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Any
@ -352,3 +353,304 @@ class SSETransport(Transport):
) )
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise TransportError("Timeout waiting for SSE response") raise TransportError("Timeout waiting for SSE response")
class StdioTransport(Transport):
"""Stdio 传输
通过 stdin/stdout MCP Server 子进程通信使用 newline-delimited JSON-RPC 消息格式
"""
def __init__(
self,
command: str,
args: list[str] | None = None,
env: dict[str, str] | None = None,
timeout: float = 30.0,
):
self._command = command
self._args = args or []
self._env = env
self._timeout = timeout
self._process: asyncio.subprocess.Process | None = None
self._request_id = 0
self._pending: dict[int, asyncio.Future[Any]] = {}
self._reader_task: asyncio.Task[None] | None = None
self._stderr_task: asyncio.Task[None] | None = None
self._connected = False
self._notifications: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
@property
def is_connected(self) -> bool:
return (
self._connected
and self._process is not None
and self._process.returncode is None
)
def _next_request_id(self) -> int:
"""生成下一个请求 ID"""
self._request_id += 1
return self._request_id
async def connect(self) -> None:
"""启动子进程并完成 MCP 初始化握手
Raises:
TransportError: 子进程启动失败或初始化超时
"""
if self.is_connected:
return
# 合并环境变量
merged_env = dict(os.environ)
if self._env:
merged_env.update(self._env)
try:
self._process = await asyncio.create_subprocess_exec(
self._command,
*self._args,
env=merged_env,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
except OSError as e:
raise TransportError(f"Failed to start process: {self._command}", cause=e) from e
# 启动 stdout 读取任务
self._reader_task = asyncio.create_task(self._read_stdout())
# 启动 stderr 读取任务
self._stderr_task = asyncio.create_task(self._read_stderr())
# 发送 initialize 请求并等待响应
try:
init_result = await asyncio.wait_for(
self._send_request_internal(
"initialize",
{
"protocolVersion": "2024-11-05",
"capabilities": {},
"clientInfo": {"name": "agentkit", "version": "0.1.0"},
},
),
timeout=self._timeout,
)
except asyncio.TimeoutError:
await self._cleanup()
raise TransportError("Timeout waiting for initialize response")
except TransportError:
await self._cleanup()
raise
# 发送 initialized 通知
await self._send_notification("notifications/initialized")
self._connected = True
logger.info(
"StdioTransport connected to %s %s",
self._command,
" ".join(self._args),
)
async def disconnect(self) -> None:
"""关闭子进程连接"""
self._connected = False
await self._cleanup()
async def _cleanup(self) -> None:
"""清理子进程和相关资源"""
# 取消读取任务
for task in (self._reader_task, self._stderr_task):
if task is not None:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
self._reader_task = None
self._stderr_task = None
# 关闭 stdin
if self._process is not None and self._process.stdin is not None:
self._process.stdin.close()
try:
await self._process.stdin.drain()
except Exception:
pass
# 等待子进程退出
if self._process is not None and self._process.returncode is None:
try:
await asyncio.wait_for(self._process.wait(), timeout=5.0)
except asyncio.TimeoutError:
self._process.kill()
await self._process.wait()
self._process = None
# 取消所有等待中的 Future
for future in self._pending.values():
if not future.done():
future.set_exception(TransportError("Transport disconnected"))
self._pending.clear()
logger.info("StdioTransport disconnected")
async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
"""发送 JSON-RPC 请求并等待响应
Args:
method: JSON-RPC 方法名
params: 请求参数
Returns:
JSON-RPC 响应的 result 字段
Raises:
TransportError: 连接未建立或请求失败
"""
if not self.is_connected:
raise TransportError("Transport not connected")
return await self._send_request_internal(method, params)
async def _send_request_internal(
self, method: str, params: dict[str, Any] | None = None
) -> Any:
"""内部请求发送方法connect 时也可调用)"""
request_id = self._next_request_id()
message: dict[str, Any] = {
"jsonrpc": "2.0",
"id": request_id,
"method": method,
}
if params is not None:
message["params"] = params
await self._write_message(message)
loop = asyncio.get_running_loop()
future: asyncio.Future[Any] = loop.create_future()
self._pending[request_id] = future
try:
return await asyncio.wait_for(future, timeout=self._timeout)
except asyncio.TimeoutError:
self._pending.pop(request_id, None)
raise TransportError(f"Timeout waiting for response to {method}")
except TransportError:
self._pending.pop(request_id, None)
raise
async def _send_notification(self, method: str, params: dict[str, Any] | None = None) -> None:
"""发送 JSON-RPC 通知(无 id不期待响应"""
message: dict[str, Any] = {
"jsonrpc": "2.0",
"method": method,
}
if params is not None:
message["params"] = params
await self._write_message(message)
async def _write_message(self, message: dict[str, Any]) -> None:
"""将 JSON-RPC 消息写入子进程 stdin"""
if self._process is None or self._process.stdin is None:
raise TransportError("Process stdin not available")
data = (json.dumps(message) + "\n").encode("utf-8")
self._process.stdin.write(data)
await self._process.stdin.drain()
async def receive_response(self) -> dict[str, Any]:
"""接收通知消息
对于 StdioTransport请求响应通过 _pending Future 异步返回
此方法仅用于获取服务端推送的通知消息
Returns:
JSON-RPC 通知消息
Raises:
TransportError: 连接未建立或无通知
"""
if not self.is_connected:
raise TransportError("Transport not connected")
if not self._notifications.empty():
return self._notifications.get_nowait()
raise TransportError("No notification to receive")
async def _read_stdout(self) -> None:
"""持续从子进程 stdout 读取 JSON-RPC 消息"""
if self._process is None or self._process.stdout is None:
return
try:
while True:
line = await self._process.stdout.readline()
if not line:
# EOF — 子进程退出
if self._connected:
logger.warning("StdioTransport: subprocess stdout EOF")
break
line_str = line.decode("utf-8").strip()
if not line_str:
continue
try:
data = json.loads(line_str)
except json.JSONDecodeError:
logger.warning("StdioTransport: invalid JSON from stdout: %s", line_str)
continue
# 响应消息(有 id 字段)
if "id" in data:
request_id = data["id"]
future = self._pending.pop(request_id, None)
if future is not None and not future.done():
if "error" in data:
error = data["error"]
future.set_exception(
TransportError(
f"JSON-RPC error {error.get('code')}: {error.get('message')}"
)
)
else:
future.set_result(data.get("result"))
elif future is None:
logger.warning(
"StdioTransport: received response for unknown request id %s",
request_id,
)
# 通知消息(有 method 字段,无 id
elif "method" in data:
await self._notifications.put(data)
except asyncio.CancelledError:
raise
except Exception as e:
if self._connected:
logger.error("StdioTransport: stdout reader error: %s", e)
async def _read_stderr(self) -> None:
"""持续从子进程 stderr 读取并转发到 logger"""
if self._process is None or self._process.stderr is None:
return
try:
while True:
line = await self._process.stderr.readline()
if not line:
break
line_str = line.decode("utf-8", errors="replace").rstrip()
if line_str:
logger.debug("StdioTransport stderr: %s", line_str)
except asyncio.CancelledError:
raise
except Exception as e:
if self._connected:
logger.error("StdioTransport: stderr reader error: %s", e)

View File

@ -0,0 +1,749 @@
"""StdioTransport 单元测试
使用内联 mock MCP server 子进程进行测试无需外部依赖
"""
import asyncio
import json
import sys
import textwrap
import pytest
from agentkit.mcp.transport import StdioTransport, TransportError
# 内联 mock MCP server 脚本
# 读取 stdin 的 JSON-RPC 消息,根据 method 返回对应响应
MOCK_SERVER_SCRIPT = textwrap.dedent("""\
import sys
import json
def handle_request(data):
method = data.get("method", "")
req_id = data.get("id")
params = data.get("params", {})
if method == "initialize":
return {
"jsonrpc": "2.0",
"id": req_id,
"result": {
"protocolVersion": "2024-11-05",
"capabilities": {"tools": {"listChanged": True}},
"serverInfo": {"name": "mock-mcp-server", "version": "0.1.0"},
},
}
elif method == "tools/list":
return {
"jsonrpc": "2.0",
"id": req_id,
"result": {
"tools": [
{
"name": "echo",
"description": "Echo tool",
"inputSchema": {"type": "object", "properties": {"msg": {"type": "string"}}},
}
]
},
}
elif method == "tools/call":
name = params.get("name", "")
arguments = params.get("arguments", {})
if name == "echo":
return {
"jsonrpc": "2.0",
"id": req_id,
"result": {
"content": [{"type": "text", "text": arguments.get("msg", "")}]
},
}
else:
return {
"jsonrpc": "2.0",
"id": req_id,
"error": {"code": -32601, "message": f"Unknown tool: {name}"},
}
else:
return {
"jsonrpc": "2.0",
"id": req_id,
"error": {"code": -32601, "message": f"Method not found: {method}"},
}
def main():
for line in sys.stdin:
line = line.strip()
if not line:
continue
try:
data = json.loads(line)
except json.JSONDecodeError:
continue
# 通知消息(无 id不回复
if "id" not in data:
continue
response = handle_request(data)
sys.stdout.write(json.dumps(response) + "\\n")
sys.stdout.flush()
if __name__ == "__main__":
main()
""")
# 发送通知后立即退出的 mock server用于测试子进程退出检测
EXIT_AFTER_INIT_SCRIPT = textwrap.dedent("""\
import sys
import json
for line in sys.stdin:
line = line.strip()
if not line:
continue
try:
data = json.loads(line)
except json.JSONDecodeError:
continue
if "id" not in data:
continue
method = data.get("method", "")
req_id = data.get("id")
if method == "initialize":
response = {
"jsonrpc": "2.0",
"id": req_id,
"result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "exit-server", "version": "0.1.0"}},
}
sys.stdout.write(json.dumps(response) + "\\n")
sys.stdout.flush()
# 初始化后立即退出
sys.exit(0)
else:
# 不会到达这里
pass
""")
# 发送通知的 mock server用于测试通知接收
NOTIFICATION_SERVER_SCRIPT = textwrap.dedent("""\
import sys
import json
for line in sys.stdin:
line = line.strip()
if not line:
continue
try:
data = json.loads(line)
except json.JSONDecodeError:
continue
if "id" not in data:
continue
method = data.get("method", "")
req_id = data.get("id")
if method == "initialize":
response = {
"jsonrpc": "2.0",
"id": req_id,
"result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "notif-server", "version": "0.1.0"}},
}
sys.stdout.write(json.dumps(response) + "\\n")
sys.stdout.flush()
elif method == "tools/list":
# 先发送一个通知
notification = {
"jsonrpc": "2.0",
"method": "notifications/tools/list_changed",
}
sys.stdout.write(json.dumps(notification) + "\\n")
sys.stdout.flush()
# 再发送响应
response = {
"jsonrpc": "2.0",
"id": req_id,
"result": {"tools": [{"name": "updated_tool"}]},
}
sys.stdout.write(json.dumps(response) + "\\n")
sys.stdout.flush()
else:
response = {
"jsonrpc": "2.0",
"id": req_id,
"error": {"code": -32601, "message": "Method not found"},
}
sys.stdout.write(json.dumps(response) + "\\n")
sys.stdout.flush()
""")
# 写入 stderr 的 mock server
STDERR_SERVER_SCRIPT = textwrap.dedent("""\
import sys
import json
for line in sys.stdin:
line = line.strip()
if not line:
continue
try:
data = json.loads(line)
except json.JSONDecodeError:
continue
if "id" not in data:
continue
method = data.get("method", "")
req_id = data.get("id")
if method == "initialize":
# 写入 stderr
sys.stderr.write("mock server starting\\n")
sys.stderr.flush()
response = {
"jsonrpc": "2.0",
"id": req_id,
"result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "stderr-server", "version": "0.1.0"}},
}
sys.stdout.write(json.dumps(response) + "\\n")
sys.stdout.flush()
else:
response = {
"jsonrpc": "2.0",
"id": req_id,
"result": {},
}
sys.stdout.write(json.dumps(response) + "\\n")
sys.stdout.flush()
""")
def _make_transport(script: str, **kwargs) -> StdioTransport:
"""创建使用内联 mock server 脚本的 StdioTransport"""
return StdioTransport(
command=sys.executable,
args=["-c", script],
**kwargs,
)
# ── 构造测试 ──────────────────────────────────────────
class TestStdioTransportConstruction:
"""StdioTransport 构造测试"""
def test_default_args(self):
transport = StdioTransport(command="echo")
assert transport._command == "echo"
assert transport._args == []
assert transport._env is None
assert transport._timeout == 30.0
assert transport._process is None
assert transport._request_id == 0
assert transport._pending == {}
assert transport._reader_task is None
assert transport._stderr_task is None
assert transport._connected is False
def test_custom_args(self):
transport = StdioTransport(
command="node",
args=["server.js", "--port", "3000"],
env={"NODE_ENV": "test"},
timeout=10.0,
)
assert transport._command == "node"
assert transport._args == ["server.js", "--port", "3000"]
assert transport._env == {"NODE_ENV": "test"}
assert transport._timeout == 10.0
def test_is_connected_initially_false(self):
transport = StdioTransport(command="echo")
assert not transport.is_connected
# ── 连接/断开测试 ──────────────────────────────────────────
class TestStdioTransportConnect:
"""StdioTransport 连接测试"""
async def test_connect_starts_subprocess(self):
transport = _make_transport(MOCK_SERVER_SCRIPT)
try:
await transport.connect()
assert transport.is_connected
assert transport._process is not None
assert transport._process.returncode is None
assert transport._reader_task is not None
assert transport._stderr_task is not None
finally:
await transport.disconnect()
async def test_connect_completes_initialize_handshake(self):
transport = _make_transport(MOCK_SERVER_SCRIPT)
try:
await transport.connect()
# connect 成功说明 initialize 握手完成
assert transport._connected is True
# request_id 应该至少递增到 1initialize 请求)
assert transport._request_id >= 1
finally:
await transport.disconnect()
async def test_connect_is_idempotent(self):
transport = _make_transport(MOCK_SERVER_SCRIPT)
try:
await transport.connect()
await transport.connect() # 不应报错
assert transport.is_connected
finally:
await transport.disconnect()
async def test_connect_with_invalid_command_raises(self):
transport = StdioTransport(command="/nonexistent/command")
with pytest.raises(TransportError, match="Failed to start process"):
await transport.connect()
async def test_connect_timeout(self):
"""使用不响应 initialize 的子进程测试超时"""
# 使用一个只读取 stdin 但不输出任何内容的脚本
silent_script = "import sys; sys.stdin.read()"
transport = StdioTransport(
command=sys.executable,
args=["-c", silent_script],
timeout=0.5,
)
with pytest.raises(TransportError, match="Timeout waiting for initialize"):
await transport.connect()
assert not transport.is_connected
class TestStdioTransportDisconnect:
"""StdioTransport 断开测试"""
async def test_disconnect_closes_subprocess(self):
transport = _make_transport(MOCK_SERVER_SCRIPT)
await transport.connect()
assert transport.is_connected
await transport.disconnect()
assert not transport.is_connected
assert transport._process is None
assert transport._reader_task is None
assert transport._stderr_task is None
async def test_disconnect_is_idempotent(self):
transport = _make_transport(MOCK_SERVER_SCRIPT)
await transport.connect()
await transport.disconnect()
await transport.disconnect() # 不应报错
async def test_disconnect_cancels_pending_futures(self):
"""断开时所有 pending future 应收到 TransportError"""
transport = _make_transport(MOCK_SERVER_SCRIPT)
await transport.connect()
# 手动添加一个 pending future
loop = asyncio.get_running_loop()
future = loop.create_future()
transport._pending[999] = future
await transport.disconnect()
assert future.done()
with pytest.raises(TransportError, match="Transport disconnected"):
future.result()
async def test_disconnect_clears_pending(self):
transport = _make_transport(MOCK_SERVER_SCRIPT)
await transport.connect()
transport._pending[1] = asyncio.get_running_loop().create_future()
await transport.disconnect()
assert transport._pending == {}
# ── 请求发送测试 ──────────────────────────────────────────
class TestStdioTransportSendRequest:
"""StdioTransport 请求发送测试"""
async def test_send_request_not_connected_raises(self):
transport = StdioTransport(command="echo")
with pytest.raises(TransportError, match="not connected"):
await transport.send_request("tools/list")
async def test_send_request_tools_list(self):
transport = _make_transport(MOCK_SERVER_SCRIPT)
try:
await transport.connect()
result = await transport.send_request("tools/list")
assert "tools" in result
assert len(result["tools"]) == 1
assert result["tools"][0]["name"] == "echo"
finally:
await transport.disconnect()
async def test_send_request_tools_call(self):
transport = _make_transport(MOCK_SERVER_SCRIPT)
try:
await transport.connect()
result = await transport.send_request(
"tools/call",
params={"name": "echo", "arguments": {"msg": "hello world"}},
)
assert "content" in result
assert result["content"][0]["text"] == "hello world"
finally:
await transport.disconnect()
async def test_send_request_json_rpc_error(self):
transport = _make_transport(MOCK_SERVER_SCRIPT)
try:
await transport.connect()
with pytest.raises(TransportError, match="JSON-RPC error"):
await transport.send_request("unknown/method")
finally:
await transport.disconnect()
async def test_send_request_unknown_tool_error(self):
transport = _make_transport(MOCK_SERVER_SCRIPT)
try:
await transport.connect()
with pytest.raises(TransportError, match="Unknown tool"):
await transport.send_request(
"tools/call",
params={"name": "nonexistent", "arguments": {}},
)
finally:
await transport.disconnect()
async def test_request_id_increments(self):
transport = _make_transport(MOCK_SERVER_SCRIPT)
try:
await transport.connect()
# connect 时已经用了 id=1 (initialize)
id_before = transport._request_id
await transport.send_request("tools/list")
id_after_1 = transport._request_id
await transport.send_request("tools/list")
id_after_2 = transport._request_id
assert id_after_1 == id_before + 1
assert id_after_2 == id_before + 2
finally:
await transport.disconnect()
async def test_send_request_timeout(self):
"""请求超时测试"""
# 使用一个不响应的脚本
silent_script = "import sys; sys.stdin.read()"
transport = StdioTransport(
command=sys.executable,
args=["-c", silent_script],
timeout=0.5,
)
# 手动设置连接状态以绕过 initialize
transport._connected = True
transport._process = await asyncio.create_subprocess_exec(
sys.executable, "-c", silent_script,
stdin=asyncio.subprocess.PIPE,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
transport._reader_task = asyncio.create_task(transport._read_stdout())
transport._stderr_task = asyncio.create_task(transport._read_stderr())
try:
with pytest.raises(TransportError, match="Timeout"):
await transport.send_request("tools/list")
finally:
transport._connected = False
await transport._cleanup()
# ── 并发请求测试 ──────────────────────────────────────────
class TestStdioTransportConcurrentRequests:
"""StdioTransport 并发请求测试"""
async def test_concurrent_requests_correct_id_matching(self):
"""并发请求的响应应正确匹配到对应的 Future"""
transport = _make_transport(MOCK_SERVER_SCRIPT)
try:
await transport.connect()
# 并发发送多个请求
results = await asyncio.gather(
transport.send_request("tools/list"),
transport.send_request(
"tools/call",
params={"name": "echo", "arguments": {"msg": "msg1"}},
),
transport.send_request(
"tools/call",
params={"name": "echo", "arguments": {"msg": "msg2"}},
),
)
# 验证每个请求都得到了正确类型的响应
assert "tools" in results[0]
assert results[1]["content"][0]["text"] == "msg1"
assert results[2]["content"][0]["text"] == "msg2"
finally:
await transport.disconnect()
# ── 通知接收测试 ──────────────────────────────────────────
class TestStdioTransportNotifications:
"""StdioTransport 通知接收测试"""
async def test_receive_notification(self):
transport = _make_transport(NOTIFICATION_SERVER_SCRIPT)
try:
await transport.connect()
# tools/list 会先发送一个通知
result = await transport.send_request("tools/list")
assert "tools" in result
# 等待通知到达
await asyncio.sleep(0.1)
# 应该能收到通知
notification = await transport.receive_response()
assert notification["method"] == "notifications/tools/list_changed"
finally:
await transport.disconnect()
async def test_receive_response_no_notification_raises(self):
transport = _make_transport(MOCK_SERVER_SCRIPT)
try:
await transport.connect()
with pytest.raises(TransportError, match="No notification"):
await transport.receive_response()
finally:
await transport.disconnect()
async def test_receive_response_not_connected_raises(self):
transport = StdioTransport(command="echo")
with pytest.raises(TransportError, match="not connected"):
await transport.receive_response()
# ── 子进程退出检测测试 ──────────────────────────────────────────
class TestStdioTransportProcessExit:
"""StdioTransport 子进程退出检测测试"""
async def test_subprocess_exit_detection(self):
"""子进程退出后 is_connected 应返回 False"""
transport = _make_transport(EXIT_AFTER_INIT_SCRIPT)
try:
await transport.connect()
assert transport.is_connected
# 等待子进程退出
await asyncio.sleep(0.5)
# 子进程已退出is_connected 应为 False
assert not transport.is_connected
finally:
if transport._process is not None:
await transport.disconnect()
async def test_send_request_after_process_exit_raises(self):
"""子进程退出后发送请求应抛出 TransportError"""
transport = _make_transport(EXIT_AFTER_INIT_SCRIPT)
try:
await transport.connect()
# 等待子进程退出
await asyncio.sleep(0.5)
if not transport.is_connected:
with pytest.raises(TransportError, match="not connected"):
await transport.send_request("tools/list")
finally:
if transport._process is not None:
await transport.disconnect()
# ── stderr 转发测试 ──────────────────────────────────────────
class TestStdioTransportStderr:
"""StdioTransport stderr 转发测试"""
async def test_stderr_forwarded_to_logger(self, caplog):
"""stderr 输出应转发到 logger"""
import logging
transport = _make_transport(STDERR_SERVER_SCRIPT)
try:
with caplog.at_level(logging.DEBUG, logger="agentkit.mcp.transport"):
await transport.connect()
# 发送一个请求触发 stderr 输出
await transport.send_request("tools/list")
# 等待 stderr 被读取
await asyncio.sleep(0.2)
# 检查日志中包含 stderr 输出
stderr_logs = [
r for r in caplog.records
if "mock server starting" in r.message
]
assert len(stderr_logs) > 0
finally:
await transport.disconnect()
# ── is_connected 属性测试 ──────────────────────────────────────────
class TestStdioTransportIsConnected:
"""StdioTransport is_connected 属性测试"""
async def test_is_connected_before_connect(self):
transport = _make_transport(MOCK_SERVER_SCRIPT)
assert not transport.is_connected
async def test_is_connected_after_connect(self):
transport = _make_transport(MOCK_SERVER_SCRIPT)
try:
await transport.connect()
assert transport.is_connected
finally:
await transport.disconnect()
async def test_is_connected_after_disconnect(self):
transport = _make_transport(MOCK_SERVER_SCRIPT)
await transport.connect()
await transport.disconnect()
assert not transport.is_connected
async def test_is_connected_checks_process_returncode(self):
"""is_connected 应检查 process.returncode"""
transport = _make_transport(MOCK_SERVER_SCRIPT)
try:
await transport.connect()
assert transport._process is not None
# 模拟进程退出但 _connected 仍为 True
transport._connected = True
# 终止子进程
transport._process.kill()
await transport._process.wait()
# is_connected 应为 False 因为 returncode 不为 None
assert not transport.is_connected
finally:
transport._connected = False
await transport._cleanup()
# ── 完整生命周期测试 ──────────────────────────────────────────
class TestStdioTransportLifecycle:
"""StdioTransport 完整生命周期测试"""
async def test_full_lifecycle(self):
"""测试完整的 connect → send_request → disconnect 生命周期"""
transport = _make_transport(MOCK_SERVER_SCRIPT)
# 1. 连接
await transport.connect()
assert transport.is_connected
# 2. 发送请求
result = await transport.send_request("tools/list")
assert "tools" in result
# 3. 发送带参数的请求
result = await transport.send_request(
"tools/call",
params={"name": "echo", "arguments": {"msg": "test"}},
)
assert result["content"][0]["text"] == "test"
# 4. 断开
await transport.disconnect()
assert not transport.is_connected
async def test_reconnect_after_disconnect(self):
"""测试断开后重新连接"""
transport = _make_transport(MOCK_SERVER_SCRIPT)
# 第一次连接
await transport.connect()
result1 = await transport.send_request("tools/list")
assert "tools" in result1
await transport.disconnect()
# 重新连接
await transport.connect()
result2 = await transport.send_request("tools/list")
assert "tools" in result2
await transport.disconnect()
async def test_env_variables_passed_to_subprocess(self):
"""测试环境变量传递给子进程"""
# 使用打印环境变量的脚本
env_script = textwrap.dedent("""\
import sys
import json
import os
for line in sys.stdin:
line = line.strip()
if not line:
continue
try:
data = json.loads(line)
except json.JSONDecodeError:
continue
if "id" not in data:
continue
method = data.get("method", "")
req_id = data.get("id")
if method == "initialize":
response = {
"jsonrpc": "2.0",
"id": req_id,
"result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "env-server", "version": "0.1.0"}},
}
sys.stdout.write(json.dumps(response) + "\\n")
sys.stdout.flush()
elif method == "tools/call":
test_env = os.environ.get("TEST_MCP_VAR", "not_set")
response = {
"jsonrpc": "2.0",
"id": req_id,
"result": {
"content": [{"type": "text", "text": test_env}]
},
}
sys.stdout.write(json.dumps(response) + "\\n")
sys.stdout.flush()
else:
response = {
"jsonrpc": "2.0",
"id": req_id,
"result": {},
}
sys.stdout.write(json.dumps(response) + "\\n")
sys.stdout.flush()
""")
transport = StdioTransport(
command=sys.executable,
args=["-c", env_script],
env={"TEST_MCP_VAR": "hello_from_env"},
)
try:
await transport.connect()
result = await transport.send_request(
"tools/call",
params={"name": "check_env", "arguments": {}},
)
assert result["content"][0]["text"] == "hello_from_env"
finally:
await transport.disconnect()