feat(mcp): U1 StdioTransport for subprocess-based MCP communication
Add StdioTransport class supporting stdio JSON-RPC over subprocess stdin/stdout with asyncio.create_subprocess_exec, pending futures for request/response matching, and stderr forwarding.
This commit is contained in:
parent
9b6c0230c0
commit
66b9217569
|
|
@ -1,11 +1,12 @@
|
|||
"""MCP Transport - 传输层抽象
|
||||
|
||||
提供 MCP 协议的传输层实现,支持 Streamable HTTP 和 SSE 两种传输方式。
|
||||
提供 MCP 协议的传输层实现,支持 Streamable HTTP、SSE 和 Stdio 三种传输方式。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -352,3 +353,304 @@ class SSETransport(Transport):
|
|||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise TransportError("Timeout waiting for SSE response")
|
||||
|
||||
|
||||
class StdioTransport(Transport):
|
||||
"""Stdio 传输
|
||||
|
||||
通过 stdin/stdout 与 MCP Server 子进程通信,使用 newline-delimited JSON-RPC 消息格式。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
command: str,
|
||||
args: list[str] | None = None,
|
||||
env: dict[str, str] | None = None,
|
||||
timeout: float = 30.0,
|
||||
):
|
||||
self._command = command
|
||||
self._args = args or []
|
||||
self._env = env
|
||||
self._timeout = timeout
|
||||
self._process: asyncio.subprocess.Process | None = None
|
||||
self._request_id = 0
|
||||
self._pending: dict[int, asyncio.Future[Any]] = {}
|
||||
self._reader_task: asyncio.Task[None] | None = None
|
||||
self._stderr_task: asyncio.Task[None] | None = None
|
||||
self._connected = False
|
||||
self._notifications: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return (
|
||||
self._connected
|
||||
and self._process is not None
|
||||
and self._process.returncode is None
|
||||
)
|
||||
|
||||
def _next_request_id(self) -> int:
|
||||
"""生成下一个请求 ID"""
|
||||
self._request_id += 1
|
||||
return self._request_id
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""启动子进程并完成 MCP 初始化握手
|
||||
|
||||
Raises:
|
||||
TransportError: 子进程启动失败或初始化超时
|
||||
"""
|
||||
if self.is_connected:
|
||||
return
|
||||
|
||||
# 合并环境变量
|
||||
merged_env = dict(os.environ)
|
||||
if self._env:
|
||||
merged_env.update(self._env)
|
||||
|
||||
try:
|
||||
self._process = await asyncio.create_subprocess_exec(
|
||||
self._command,
|
||||
*self._args,
|
||||
env=merged_env,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
except OSError as e:
|
||||
raise TransportError(f"Failed to start process: {self._command}", cause=e) from e
|
||||
|
||||
# 启动 stdout 读取任务
|
||||
self._reader_task = asyncio.create_task(self._read_stdout())
|
||||
|
||||
# 启动 stderr 读取任务
|
||||
self._stderr_task = asyncio.create_task(self._read_stderr())
|
||||
|
||||
# 发送 initialize 请求并等待响应
|
||||
try:
|
||||
init_result = await asyncio.wait_for(
|
||||
self._send_request_internal(
|
||||
"initialize",
|
||||
{
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {"name": "agentkit", "version": "0.1.0"},
|
||||
},
|
||||
),
|
||||
timeout=self._timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
await self._cleanup()
|
||||
raise TransportError("Timeout waiting for initialize response")
|
||||
except TransportError:
|
||||
await self._cleanup()
|
||||
raise
|
||||
|
||||
# 发送 initialized 通知
|
||||
await self._send_notification("notifications/initialized")
|
||||
|
||||
self._connected = True
|
||||
logger.info(
|
||||
"StdioTransport connected to %s %s",
|
||||
self._command,
|
||||
" ".join(self._args),
|
||||
)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""关闭子进程连接"""
|
||||
self._connected = False
|
||||
await self._cleanup()
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
"""清理子进程和相关资源"""
|
||||
# 取消读取任务
|
||||
for task in (self._reader_task, self._stderr_task):
|
||||
if task is not None:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._reader_task = None
|
||||
self._stderr_task = None
|
||||
|
||||
# 关闭 stdin
|
||||
if self._process is not None and self._process.stdin is not None:
|
||||
self._process.stdin.close()
|
||||
try:
|
||||
await self._process.stdin.drain()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 等待子进程退出
|
||||
if self._process is not None and self._process.returncode is None:
|
||||
try:
|
||||
await asyncio.wait_for(self._process.wait(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
self._process.kill()
|
||||
await self._process.wait()
|
||||
|
||||
self._process = None
|
||||
|
||||
# 取消所有等待中的 Future
|
||||
for future in self._pending.values():
|
||||
if not future.done():
|
||||
future.set_exception(TransportError("Transport disconnected"))
|
||||
self._pending.clear()
|
||||
|
||||
logger.info("StdioTransport disconnected")
|
||||
|
||||
async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
|
||||
"""发送 JSON-RPC 请求并等待响应
|
||||
|
||||
Args:
|
||||
method: JSON-RPC 方法名
|
||||
params: 请求参数
|
||||
|
||||
Returns:
|
||||
JSON-RPC 响应的 result 字段
|
||||
|
||||
Raises:
|
||||
TransportError: 连接未建立或请求失败
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise TransportError("Transport not connected")
|
||||
return await self._send_request_internal(method, params)
|
||||
|
||||
async def _send_request_internal(
|
||||
self, method: str, params: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""内部请求发送方法(connect 时也可调用)"""
|
||||
request_id = self._next_request_id()
|
||||
message: dict[str, Any] = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"method": method,
|
||||
}
|
||||
if params is not None:
|
||||
message["params"] = params
|
||||
|
||||
await self._write_message(message)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
future: asyncio.Future[Any] = loop.create_future()
|
||||
self._pending[request_id] = future
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(future, timeout=self._timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self._pending.pop(request_id, None)
|
||||
raise TransportError(f"Timeout waiting for response to {method}")
|
||||
except TransportError:
|
||||
self._pending.pop(request_id, None)
|
||||
raise
|
||||
|
||||
async def _send_notification(self, method: str, params: dict[str, Any] | None = None) -> None:
|
||||
"""发送 JSON-RPC 通知(无 id,不期待响应)"""
|
||||
message: dict[str, Any] = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": method,
|
||||
}
|
||||
if params is not None:
|
||||
message["params"] = params
|
||||
await self._write_message(message)
|
||||
|
||||
async def _write_message(self, message: dict[str, Any]) -> None:
|
||||
"""将 JSON-RPC 消息写入子进程 stdin"""
|
||||
if self._process is None or self._process.stdin is None:
|
||||
raise TransportError("Process stdin not available")
|
||||
data = (json.dumps(message) + "\n").encode("utf-8")
|
||||
self._process.stdin.write(data)
|
||||
await self._process.stdin.drain()
|
||||
|
||||
async def receive_response(self) -> dict[str, Any]:
|
||||
"""接收通知消息
|
||||
|
||||
对于 StdioTransport,请求响应通过 _pending Future 异步返回。
|
||||
此方法仅用于获取服务端推送的通知消息。
|
||||
|
||||
Returns:
|
||||
JSON-RPC 通知消息
|
||||
|
||||
Raises:
|
||||
TransportError: 连接未建立或无通知
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise TransportError("Transport not connected")
|
||||
|
||||
if not self._notifications.empty():
|
||||
return self._notifications.get_nowait()
|
||||
|
||||
raise TransportError("No notification to receive")
|
||||
|
||||
async def _read_stdout(self) -> None:
|
||||
"""持续从子进程 stdout 读取 JSON-RPC 消息"""
|
||||
if self._process is None or self._process.stdout is None:
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
line = await self._process.stdout.readline()
|
||||
if not line:
|
||||
# EOF — 子进程退出
|
||||
if self._connected:
|
||||
logger.warning("StdioTransport: subprocess stdout EOF")
|
||||
break
|
||||
|
||||
line_str = line.decode("utf-8").strip()
|
||||
if not line_str:
|
||||
continue
|
||||
|
||||
try:
|
||||
data = json.loads(line_str)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("StdioTransport: invalid JSON from stdout: %s", line_str)
|
||||
continue
|
||||
|
||||
# 响应消息(有 id 字段)
|
||||
if "id" in data:
|
||||
request_id = data["id"]
|
||||
future = self._pending.pop(request_id, None)
|
||||
if future is not None and not future.done():
|
||||
if "error" in data:
|
||||
error = data["error"]
|
||||
future.set_exception(
|
||||
TransportError(
|
||||
f"JSON-RPC error {error.get('code')}: {error.get('message')}"
|
||||
)
|
||||
)
|
||||
else:
|
||||
future.set_result(data.get("result"))
|
||||
elif future is None:
|
||||
logger.warning(
|
||||
"StdioTransport: received response for unknown request id %s",
|
||||
request_id,
|
||||
)
|
||||
|
||||
# 通知消息(有 method 字段,无 id)
|
||||
elif "method" in data:
|
||||
await self._notifications.put(data)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
if self._connected:
|
||||
logger.error("StdioTransport: stdout reader error: %s", e)
|
||||
|
||||
async def _read_stderr(self) -> None:
|
||||
"""持续从子进程 stderr 读取并转发到 logger"""
|
||||
if self._process is None or self._process.stderr is None:
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
line = await self._process.stderr.readline()
|
||||
if not line:
|
||||
break
|
||||
line_str = line.decode("utf-8", errors="replace").rstrip()
|
||||
if line_str:
|
||||
logger.debug("StdioTransport stderr: %s", line_str)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
except Exception as e:
|
||||
if self._connected:
|
||||
logger.error("StdioTransport: stderr reader error: %s", e)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,749 @@
|
|||
"""StdioTransport 单元测试
|
||||
|
||||
使用内联 mock MCP server 子进程进行测试,无需外部依赖。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.mcp.transport import StdioTransport, TransportError
|
||||
|
||||
# 内联 mock MCP server 脚本
|
||||
# 读取 stdin 的 JSON-RPC 消息,根据 method 返回对应响应
|
||||
MOCK_SERVER_SCRIPT = textwrap.dedent("""\
|
||||
import sys
|
||||
import json
|
||||
|
||||
def handle_request(data):
|
||||
method = data.get("method", "")
|
||||
req_id = data.get("id")
|
||||
params = data.get("params", {})
|
||||
|
||||
if method == "initialize":
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {"tools": {"listChanged": True}},
|
||||
"serverInfo": {"name": "mock-mcp-server", "version": "0.1.0"},
|
||||
},
|
||||
}
|
||||
elif method == "tools/list":
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": {
|
||||
"tools": [
|
||||
{
|
||||
"name": "echo",
|
||||
"description": "Echo tool",
|
||||
"inputSchema": {"type": "object", "properties": {"msg": {"type": "string"}}},
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
elif method == "tools/call":
|
||||
name = params.get("name", "")
|
||||
arguments = params.get("arguments", {})
|
||||
if name == "echo":
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": {
|
||||
"content": [{"type": "text", "text": arguments.get("msg", "")}]
|
||||
},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"error": {"code": -32601, "message": f"Unknown tool: {name}"},
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"error": {"code": -32601, "message": f"Method not found: {method}"},
|
||||
}
|
||||
|
||||
def main():
|
||||
for line in sys.stdin:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
# 通知消息(无 id)不回复
|
||||
if "id" not in data:
|
||||
continue
|
||||
response = handle_request(data)
|
||||
sys.stdout.write(json.dumps(response) + "\\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
""")
|
||||
|
||||
# 发送通知后立即退出的 mock server(用于测试子进程退出检测)
|
||||
EXIT_AFTER_INIT_SCRIPT = textwrap.dedent("""\
|
||||
import sys
|
||||
import json
|
||||
|
||||
for line in sys.stdin:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if "id" not in data:
|
||||
continue
|
||||
method = data.get("method", "")
|
||||
req_id = data.get("id")
|
||||
if method == "initialize":
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "exit-server", "version": "0.1.0"}},
|
||||
}
|
||||
sys.stdout.write(json.dumps(response) + "\\n")
|
||||
sys.stdout.flush()
|
||||
# 初始化后立即退出
|
||||
sys.exit(0)
|
||||
else:
|
||||
# 不会到达这里
|
||||
pass
|
||||
""")
|
||||
|
||||
# 发送通知的 mock server(用于测试通知接收)
|
||||
NOTIFICATION_SERVER_SCRIPT = textwrap.dedent("""\
|
||||
import sys
|
||||
import json
|
||||
|
||||
for line in sys.stdin:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if "id" not in data:
|
||||
continue
|
||||
method = data.get("method", "")
|
||||
req_id = data.get("id")
|
||||
if method == "initialize":
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "notif-server", "version": "0.1.0"}},
|
||||
}
|
||||
sys.stdout.write(json.dumps(response) + "\\n")
|
||||
sys.stdout.flush()
|
||||
elif method == "tools/list":
|
||||
# 先发送一个通知
|
||||
notification = {
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/tools/list_changed",
|
||||
}
|
||||
sys.stdout.write(json.dumps(notification) + "\\n")
|
||||
sys.stdout.flush()
|
||||
# 再发送响应
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": {"tools": [{"name": "updated_tool"}]},
|
||||
}
|
||||
sys.stdout.write(json.dumps(response) + "\\n")
|
||||
sys.stdout.flush()
|
||||
else:
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"error": {"code": -32601, "message": "Method not found"},
|
||||
}
|
||||
sys.stdout.write(json.dumps(response) + "\\n")
|
||||
sys.stdout.flush()
|
||||
""")
|
||||
|
||||
# 写入 stderr 的 mock server
|
||||
STDERR_SERVER_SCRIPT = textwrap.dedent("""\
|
||||
import sys
|
||||
import json
|
||||
|
||||
for line in sys.stdin:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if "id" not in data:
|
||||
continue
|
||||
method = data.get("method", "")
|
||||
req_id = data.get("id")
|
||||
if method == "initialize":
|
||||
# 写入 stderr
|
||||
sys.stderr.write("mock server starting\\n")
|
||||
sys.stderr.flush()
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "stderr-server", "version": "0.1.0"}},
|
||||
}
|
||||
sys.stdout.write(json.dumps(response) + "\\n")
|
||||
sys.stdout.flush()
|
||||
else:
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": {},
|
||||
}
|
||||
sys.stdout.write(json.dumps(response) + "\\n")
|
||||
sys.stdout.flush()
|
||||
""")
|
||||
|
||||
|
||||
def _make_transport(script: str, **kwargs) -> StdioTransport:
|
||||
"""创建使用内联 mock server 脚本的 StdioTransport"""
|
||||
return StdioTransport(
|
||||
command=sys.executable,
|
||||
args=["-c", script],
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# ── 构造测试 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStdioTransportConstruction:
|
||||
"""StdioTransport 构造测试"""
|
||||
|
||||
def test_default_args(self):
|
||||
transport = StdioTransport(command="echo")
|
||||
assert transport._command == "echo"
|
||||
assert transport._args == []
|
||||
assert transport._env is None
|
||||
assert transport._timeout == 30.0
|
||||
assert transport._process is None
|
||||
assert transport._request_id == 0
|
||||
assert transport._pending == {}
|
||||
assert transport._reader_task is None
|
||||
assert transport._stderr_task is None
|
||||
assert transport._connected is False
|
||||
|
||||
def test_custom_args(self):
|
||||
transport = StdioTransport(
|
||||
command="node",
|
||||
args=["server.js", "--port", "3000"],
|
||||
env={"NODE_ENV": "test"},
|
||||
timeout=10.0,
|
||||
)
|
||||
assert transport._command == "node"
|
||||
assert transport._args == ["server.js", "--port", "3000"]
|
||||
assert transport._env == {"NODE_ENV": "test"}
|
||||
assert transport._timeout == 10.0
|
||||
|
||||
def test_is_connected_initially_false(self):
|
||||
transport = StdioTransport(command="echo")
|
||||
assert not transport.is_connected
|
||||
|
||||
|
||||
# ── 连接/断开测试 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStdioTransportConnect:
|
||||
"""StdioTransport 连接测试"""
|
||||
|
||||
async def test_connect_starts_subprocess(self):
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
try:
|
||||
await transport.connect()
|
||||
assert transport.is_connected
|
||||
assert transport._process is not None
|
||||
assert transport._process.returncode is None
|
||||
assert transport._reader_task is not None
|
||||
assert transport._stderr_task is not None
|
||||
finally:
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_connect_completes_initialize_handshake(self):
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
try:
|
||||
await transport.connect()
|
||||
# connect 成功说明 initialize 握手完成
|
||||
assert transport._connected is True
|
||||
# request_id 应该至少递增到 1(initialize 请求)
|
||||
assert transport._request_id >= 1
|
||||
finally:
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_connect_is_idempotent(self):
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
try:
|
||||
await transport.connect()
|
||||
await transport.connect() # 不应报错
|
||||
assert transport.is_connected
|
||||
finally:
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_connect_with_invalid_command_raises(self):
|
||||
transport = StdioTransport(command="/nonexistent/command")
|
||||
with pytest.raises(TransportError, match="Failed to start process"):
|
||||
await transport.connect()
|
||||
|
||||
async def test_connect_timeout(self):
|
||||
"""使用不响应 initialize 的子进程测试超时"""
|
||||
# 使用一个只读取 stdin 但不输出任何内容的脚本
|
||||
silent_script = "import sys; sys.stdin.read()"
|
||||
transport = StdioTransport(
|
||||
command=sys.executable,
|
||||
args=["-c", silent_script],
|
||||
timeout=0.5,
|
||||
)
|
||||
with pytest.raises(TransportError, match="Timeout waiting for initialize"):
|
||||
await transport.connect()
|
||||
assert not transport.is_connected
|
||||
|
||||
|
||||
class TestStdioTransportDisconnect:
|
||||
"""StdioTransport 断开测试"""
|
||||
|
||||
async def test_disconnect_closes_subprocess(self):
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
await transport.connect()
|
||||
assert transport.is_connected
|
||||
|
||||
await transport.disconnect()
|
||||
assert not transport.is_connected
|
||||
assert transport._process is None
|
||||
assert transport._reader_task is None
|
||||
assert transport._stderr_task is None
|
||||
|
||||
async def test_disconnect_is_idempotent(self):
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
await transport.connect()
|
||||
await transport.disconnect()
|
||||
await transport.disconnect() # 不应报错
|
||||
|
||||
async def test_disconnect_cancels_pending_futures(self):
|
||||
"""断开时所有 pending future 应收到 TransportError"""
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
await transport.connect()
|
||||
|
||||
# 手动添加一个 pending future
|
||||
loop = asyncio.get_running_loop()
|
||||
future = loop.create_future()
|
||||
transport._pending[999] = future
|
||||
|
||||
await transport.disconnect()
|
||||
|
||||
assert future.done()
|
||||
with pytest.raises(TransportError, match="Transport disconnected"):
|
||||
future.result()
|
||||
|
||||
async def test_disconnect_clears_pending(self):
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
await transport.connect()
|
||||
transport._pending[1] = asyncio.get_running_loop().create_future()
|
||||
await transport.disconnect()
|
||||
assert transport._pending == {}
|
||||
|
||||
|
||||
# ── 请求发送测试 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStdioTransportSendRequest:
|
||||
"""StdioTransport 请求发送测试"""
|
||||
|
||||
async def test_send_request_not_connected_raises(self):
|
||||
transport = StdioTransport(command="echo")
|
||||
with pytest.raises(TransportError, match="not connected"):
|
||||
await transport.send_request("tools/list")
|
||||
|
||||
async def test_send_request_tools_list(self):
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
try:
|
||||
await transport.connect()
|
||||
result = await transport.send_request("tools/list")
|
||||
assert "tools" in result
|
||||
assert len(result["tools"]) == 1
|
||||
assert result["tools"][0]["name"] == "echo"
|
||||
finally:
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_send_request_tools_call(self):
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
try:
|
||||
await transport.connect()
|
||||
result = await transport.send_request(
|
||||
"tools/call",
|
||||
params={"name": "echo", "arguments": {"msg": "hello world"}},
|
||||
)
|
||||
assert "content" in result
|
||||
assert result["content"][0]["text"] == "hello world"
|
||||
finally:
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_send_request_json_rpc_error(self):
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
try:
|
||||
await transport.connect()
|
||||
with pytest.raises(TransportError, match="JSON-RPC error"):
|
||||
await transport.send_request("unknown/method")
|
||||
finally:
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_send_request_unknown_tool_error(self):
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
try:
|
||||
await transport.connect()
|
||||
with pytest.raises(TransportError, match="Unknown tool"):
|
||||
await transport.send_request(
|
||||
"tools/call",
|
||||
params={"name": "nonexistent", "arguments": {}},
|
||||
)
|
||||
finally:
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_request_id_increments(self):
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
try:
|
||||
await transport.connect()
|
||||
# connect 时已经用了 id=1 (initialize)
|
||||
id_before = transport._request_id
|
||||
await transport.send_request("tools/list")
|
||||
id_after_1 = transport._request_id
|
||||
await transport.send_request("tools/list")
|
||||
id_after_2 = transport._request_id
|
||||
assert id_after_1 == id_before + 1
|
||||
assert id_after_2 == id_before + 2
|
||||
finally:
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_send_request_timeout(self):
|
||||
"""请求超时测试"""
|
||||
# 使用一个不响应的脚本
|
||||
silent_script = "import sys; sys.stdin.read()"
|
||||
transport = StdioTransport(
|
||||
command=sys.executable,
|
||||
args=["-c", silent_script],
|
||||
timeout=0.5,
|
||||
)
|
||||
# 手动设置连接状态以绕过 initialize
|
||||
transport._connected = True
|
||||
transport._process = await asyncio.create_subprocess_exec(
|
||||
sys.executable, "-c", silent_script,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
transport._reader_task = asyncio.create_task(transport._read_stdout())
|
||||
transport._stderr_task = asyncio.create_task(transport._read_stderr())
|
||||
|
||||
try:
|
||||
with pytest.raises(TransportError, match="Timeout"):
|
||||
await transport.send_request("tools/list")
|
||||
finally:
|
||||
transport._connected = False
|
||||
await transport._cleanup()
|
||||
|
||||
|
||||
# ── 并发请求测试 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStdioTransportConcurrentRequests:
|
||||
"""StdioTransport 并发请求测试"""
|
||||
|
||||
async def test_concurrent_requests_correct_id_matching(self):
|
||||
"""并发请求的响应应正确匹配到对应的 Future"""
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
try:
|
||||
await transport.connect()
|
||||
|
||||
# 并发发送多个请求
|
||||
results = await asyncio.gather(
|
||||
transport.send_request("tools/list"),
|
||||
transport.send_request(
|
||||
"tools/call",
|
||||
params={"name": "echo", "arguments": {"msg": "msg1"}},
|
||||
),
|
||||
transport.send_request(
|
||||
"tools/call",
|
||||
params={"name": "echo", "arguments": {"msg": "msg2"}},
|
||||
),
|
||||
)
|
||||
|
||||
# 验证每个请求都得到了正确类型的响应
|
||||
assert "tools" in results[0]
|
||||
assert results[1]["content"][0]["text"] == "msg1"
|
||||
assert results[2]["content"][0]["text"] == "msg2"
|
||||
finally:
|
||||
await transport.disconnect()
|
||||
|
||||
|
||||
# ── 通知接收测试 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStdioTransportNotifications:
|
||||
"""StdioTransport 通知接收测试"""
|
||||
|
||||
async def test_receive_notification(self):
|
||||
transport = _make_transport(NOTIFICATION_SERVER_SCRIPT)
|
||||
try:
|
||||
await transport.connect()
|
||||
|
||||
# tools/list 会先发送一个通知
|
||||
result = await transport.send_request("tools/list")
|
||||
assert "tools" in result
|
||||
|
||||
# 等待通知到达
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
# 应该能收到通知
|
||||
notification = await transport.receive_response()
|
||||
assert notification["method"] == "notifications/tools/list_changed"
|
||||
finally:
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_receive_response_no_notification_raises(self):
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
try:
|
||||
await transport.connect()
|
||||
with pytest.raises(TransportError, match="No notification"):
|
||||
await transport.receive_response()
|
||||
finally:
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_receive_response_not_connected_raises(self):
|
||||
transport = StdioTransport(command="echo")
|
||||
with pytest.raises(TransportError, match="not connected"):
|
||||
await transport.receive_response()
|
||||
|
||||
|
||||
# ── 子进程退出检测测试 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStdioTransportProcessExit:
|
||||
"""StdioTransport 子进程退出检测测试"""
|
||||
|
||||
async def test_subprocess_exit_detection(self):
|
||||
"""子进程退出后 is_connected 应返回 False"""
|
||||
transport = _make_transport(EXIT_AFTER_INIT_SCRIPT)
|
||||
try:
|
||||
await transport.connect()
|
||||
assert transport.is_connected
|
||||
|
||||
# 等待子进程退出
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# 子进程已退出,is_connected 应为 False
|
||||
assert not transport.is_connected
|
||||
finally:
|
||||
if transport._process is not None:
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_send_request_after_process_exit_raises(self):
|
||||
"""子进程退出后发送请求应抛出 TransportError"""
|
||||
transport = _make_transport(EXIT_AFTER_INIT_SCRIPT)
|
||||
try:
|
||||
await transport.connect()
|
||||
# 等待子进程退出
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
if not transport.is_connected:
|
||||
with pytest.raises(TransportError, match="not connected"):
|
||||
await transport.send_request("tools/list")
|
||||
finally:
|
||||
if transport._process is not None:
|
||||
await transport.disconnect()
|
||||
|
||||
|
||||
# ── stderr 转发测试 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStdioTransportStderr:
|
||||
"""StdioTransport stderr 转发测试"""
|
||||
|
||||
async def test_stderr_forwarded_to_logger(self, caplog):
|
||||
"""stderr 输出应转发到 logger"""
|
||||
import logging
|
||||
|
||||
transport = _make_transport(STDERR_SERVER_SCRIPT)
|
||||
try:
|
||||
with caplog.at_level(logging.DEBUG, logger="agentkit.mcp.transport"):
|
||||
await transport.connect()
|
||||
# 发送一个请求触发 stderr 输出
|
||||
await transport.send_request("tools/list")
|
||||
# 等待 stderr 被读取
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# 检查日志中包含 stderr 输出
|
||||
stderr_logs = [
|
||||
r for r in caplog.records
|
||||
if "mock server starting" in r.message
|
||||
]
|
||||
assert len(stderr_logs) > 0
|
||||
finally:
|
||||
await transport.disconnect()
|
||||
|
||||
|
||||
# ── is_connected 属性测试 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStdioTransportIsConnected:
|
||||
"""StdioTransport is_connected 属性测试"""
|
||||
|
||||
async def test_is_connected_before_connect(self):
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
assert not transport.is_connected
|
||||
|
||||
async def test_is_connected_after_connect(self):
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
try:
|
||||
await transport.connect()
|
||||
assert transport.is_connected
|
||||
finally:
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_is_connected_after_disconnect(self):
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
await transport.connect()
|
||||
await transport.disconnect()
|
||||
assert not transport.is_connected
|
||||
|
||||
async def test_is_connected_checks_process_returncode(self):
|
||||
"""is_connected 应检查 process.returncode"""
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
try:
|
||||
await transport.connect()
|
||||
assert transport._process is not None
|
||||
# 模拟进程退出但 _connected 仍为 True
|
||||
transport._connected = True
|
||||
# 终止子进程
|
||||
transport._process.kill()
|
||||
await transport._process.wait()
|
||||
# is_connected 应为 False 因为 returncode 不为 None
|
||||
assert not transport.is_connected
|
||||
finally:
|
||||
transport._connected = False
|
||||
await transport._cleanup()
|
||||
|
||||
|
||||
# ── 完整生命周期测试 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStdioTransportLifecycle:
|
||||
"""StdioTransport 完整生命周期测试"""
|
||||
|
||||
async def test_full_lifecycle(self):
|
||||
"""测试完整的 connect → send_request → disconnect 生命周期"""
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
|
||||
# 1. 连接
|
||||
await transport.connect()
|
||||
assert transport.is_connected
|
||||
|
||||
# 2. 发送请求
|
||||
result = await transport.send_request("tools/list")
|
||||
assert "tools" in result
|
||||
|
||||
# 3. 发送带参数的请求
|
||||
result = await transport.send_request(
|
||||
"tools/call",
|
||||
params={"name": "echo", "arguments": {"msg": "test"}},
|
||||
)
|
||||
assert result["content"][0]["text"] == "test"
|
||||
|
||||
# 4. 断开
|
||||
await transport.disconnect()
|
||||
assert not transport.is_connected
|
||||
|
||||
async def test_reconnect_after_disconnect(self):
|
||||
"""测试断开后重新连接"""
|
||||
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||
|
||||
# 第一次连接
|
||||
await transport.connect()
|
||||
result1 = await transport.send_request("tools/list")
|
||||
assert "tools" in result1
|
||||
await transport.disconnect()
|
||||
|
||||
# 重新连接
|
||||
await transport.connect()
|
||||
result2 = await transport.send_request("tools/list")
|
||||
assert "tools" in result2
|
||||
await transport.disconnect()
|
||||
|
||||
async def test_env_variables_passed_to_subprocess(self):
|
||||
"""测试环境变量传递给子进程"""
|
||||
# 使用打印环境变量的脚本
|
||||
env_script = textwrap.dedent("""\
|
||||
import sys
|
||||
import json
|
||||
import os
|
||||
|
||||
for line in sys.stdin:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
data = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if "id" not in data:
|
||||
continue
|
||||
method = data.get("method", "")
|
||||
req_id = data.get("id")
|
||||
if method == "initialize":
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "env-server", "version": "0.1.0"}},
|
||||
}
|
||||
sys.stdout.write(json.dumps(response) + "\\n")
|
||||
sys.stdout.flush()
|
||||
elif method == "tools/call":
|
||||
test_env = os.environ.get("TEST_MCP_VAR", "not_set")
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": {
|
||||
"content": [{"type": "text", "text": test_env}]
|
||||
},
|
||||
}
|
||||
sys.stdout.write(json.dumps(response) + "\\n")
|
||||
sys.stdout.flush()
|
||||
else:
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": {},
|
||||
}
|
||||
sys.stdout.write(json.dumps(response) + "\\n")
|
||||
sys.stdout.flush()
|
||||
""")
|
||||
|
||||
transport = StdioTransport(
|
||||
command=sys.executable,
|
||||
args=["-c", env_script],
|
||||
env={"TEST_MCP_VAR": "hello_from_env"},
|
||||
)
|
||||
try:
|
||||
await transport.connect()
|
||||
result = await transport.send_request(
|
||||
"tools/call",
|
||||
params={"name": "check_env", "arguments": {}},
|
||||
)
|
||||
assert result["content"][0]["text"] == "hello_from_env"
|
||||
finally:
|
||||
await transport.disconnect()
|
||||
Loading…
Reference in New Issue