feat(mcp): U1 StdioTransport for subprocess-based MCP communication
Add StdioTransport class supporting stdio JSON-RPC over subprocess stdin/stdout with asyncio.create_subprocess_exec, pending futures for request/response matching, and stderr forwarding.
This commit is contained in:
parent
9b6c0230c0
commit
66b9217569
|
|
@ -1,11 +1,12 @@
|
||||||
"""MCP Transport - 传输层抽象
|
"""MCP Transport - 传输层抽象
|
||||||
|
|
||||||
提供 MCP 协议的传输层实现,支持 Streamable HTTP 和 SSE 两种传输方式。
|
提供 MCP 协议的传输层实现,支持 Streamable HTTP、SSE 和 Stdio 三种传输方式。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
@ -352,3 +353,304 @@ class SSETransport(Transport):
|
||||||
)
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
raise TransportError("Timeout waiting for SSE response")
|
raise TransportError("Timeout waiting for SSE response")
|
||||||
|
|
||||||
|
|
||||||
|
class StdioTransport(Transport):
|
||||||
|
"""Stdio 传输
|
||||||
|
|
||||||
|
通过 stdin/stdout 与 MCP Server 子进程通信,使用 newline-delimited JSON-RPC 消息格式。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
command: str,
|
||||||
|
args: list[str] | None = None,
|
||||||
|
env: dict[str, str] | None = None,
|
||||||
|
timeout: float = 30.0,
|
||||||
|
):
|
||||||
|
self._command = command
|
||||||
|
self._args = args or []
|
||||||
|
self._env = env
|
||||||
|
self._timeout = timeout
|
||||||
|
self._process: asyncio.subprocess.Process | None = None
|
||||||
|
self._request_id = 0
|
||||||
|
self._pending: dict[int, asyncio.Future[Any]] = {}
|
||||||
|
self._reader_task: asyncio.Task[None] | None = None
|
||||||
|
self._stderr_task: asyncio.Task[None] | None = None
|
||||||
|
self._connected = False
|
||||||
|
self._notifications: asyncio.Queue[dict[str, Any]] = asyncio.Queue()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return (
|
||||||
|
self._connected
|
||||||
|
and self._process is not None
|
||||||
|
and self._process.returncode is None
|
||||||
|
)
|
||||||
|
|
||||||
|
def _next_request_id(self) -> int:
|
||||||
|
"""生成下一个请求 ID"""
|
||||||
|
self._request_id += 1
|
||||||
|
return self._request_id
|
||||||
|
|
||||||
|
async def connect(self) -> None:
|
||||||
|
"""启动子进程并完成 MCP 初始化握手
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TransportError: 子进程启动失败或初始化超时
|
||||||
|
"""
|
||||||
|
if self.is_connected:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 合并环境变量
|
||||||
|
merged_env = dict(os.environ)
|
||||||
|
if self._env:
|
||||||
|
merged_env.update(self._env)
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._process = await asyncio.create_subprocess_exec(
|
||||||
|
self._command,
|
||||||
|
*self._args,
|
||||||
|
env=merged_env,
|
||||||
|
stdin=asyncio.subprocess.PIPE,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
except OSError as e:
|
||||||
|
raise TransportError(f"Failed to start process: {self._command}", cause=e) from e
|
||||||
|
|
||||||
|
# 启动 stdout 读取任务
|
||||||
|
self._reader_task = asyncio.create_task(self._read_stdout())
|
||||||
|
|
||||||
|
# 启动 stderr 读取任务
|
||||||
|
self._stderr_task = asyncio.create_task(self._read_stderr())
|
||||||
|
|
||||||
|
# 发送 initialize 请求并等待响应
|
||||||
|
try:
|
||||||
|
init_result = await asyncio.wait_for(
|
||||||
|
self._send_request_internal(
|
||||||
|
"initialize",
|
||||||
|
{
|
||||||
|
"protocolVersion": "2024-11-05",
|
||||||
|
"capabilities": {},
|
||||||
|
"clientInfo": {"name": "agentkit", "version": "0.1.0"},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
timeout=self._timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
await self._cleanup()
|
||||||
|
raise TransportError("Timeout waiting for initialize response")
|
||||||
|
except TransportError:
|
||||||
|
await self._cleanup()
|
||||||
|
raise
|
||||||
|
|
||||||
|
# 发送 initialized 通知
|
||||||
|
await self._send_notification("notifications/initialized")
|
||||||
|
|
||||||
|
self._connected = True
|
||||||
|
logger.info(
|
||||||
|
"StdioTransport connected to %s %s",
|
||||||
|
self._command,
|
||||||
|
" ".join(self._args),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
"""关闭子进程连接"""
|
||||||
|
self._connected = False
|
||||||
|
await self._cleanup()
|
||||||
|
|
||||||
|
async def _cleanup(self) -> None:
|
||||||
|
"""清理子进程和相关资源"""
|
||||||
|
# 取消读取任务
|
||||||
|
for task in (self._reader_task, self._stderr_task):
|
||||||
|
if task is not None:
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
self._reader_task = None
|
||||||
|
self._stderr_task = None
|
||||||
|
|
||||||
|
# 关闭 stdin
|
||||||
|
if self._process is not None and self._process.stdin is not None:
|
||||||
|
self._process.stdin.close()
|
||||||
|
try:
|
||||||
|
await self._process.stdin.drain()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 等待子进程退出
|
||||||
|
if self._process is not None and self._process.returncode is None:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(self._process.wait(), timeout=5.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self._process.kill()
|
||||||
|
await self._process.wait()
|
||||||
|
|
||||||
|
self._process = None
|
||||||
|
|
||||||
|
# 取消所有等待中的 Future
|
||||||
|
for future in self._pending.values():
|
||||||
|
if not future.done():
|
||||||
|
future.set_exception(TransportError("Transport disconnected"))
|
||||||
|
self._pending.clear()
|
||||||
|
|
||||||
|
logger.info("StdioTransport disconnected")
|
||||||
|
|
||||||
|
async def send_request(self, method: str, params: dict[str, Any] | None = None) -> Any:
|
||||||
|
"""发送 JSON-RPC 请求并等待响应
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: JSON-RPC 方法名
|
||||||
|
params: 请求参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON-RPC 响应的 result 字段
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TransportError: 连接未建立或请求失败
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise TransportError("Transport not connected")
|
||||||
|
return await self._send_request_internal(method, params)
|
||||||
|
|
||||||
|
async def _send_request_internal(
|
||||||
|
self, method: str, params: dict[str, Any] | None = None
|
||||||
|
) -> Any:
|
||||||
|
"""内部请求发送方法(connect 时也可调用)"""
|
||||||
|
request_id = self._next_request_id()
|
||||||
|
message: dict[str, Any] = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": request_id,
|
||||||
|
"method": method,
|
||||||
|
}
|
||||||
|
if params is not None:
|
||||||
|
message["params"] = params
|
||||||
|
|
||||||
|
await self._write_message(message)
|
||||||
|
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
future: asyncio.Future[Any] = loop.create_future()
|
||||||
|
self._pending[request_id] = future
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await asyncio.wait_for(future, timeout=self._timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self._pending.pop(request_id, None)
|
||||||
|
raise TransportError(f"Timeout waiting for response to {method}")
|
||||||
|
except TransportError:
|
||||||
|
self._pending.pop(request_id, None)
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _send_notification(self, method: str, params: dict[str, Any] | None = None) -> None:
|
||||||
|
"""发送 JSON-RPC 通知(无 id,不期待响应)"""
|
||||||
|
message: dict[str, Any] = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": method,
|
||||||
|
}
|
||||||
|
if params is not None:
|
||||||
|
message["params"] = params
|
||||||
|
await self._write_message(message)
|
||||||
|
|
||||||
|
async def _write_message(self, message: dict[str, Any]) -> None:
|
||||||
|
"""将 JSON-RPC 消息写入子进程 stdin"""
|
||||||
|
if self._process is None or self._process.stdin is None:
|
||||||
|
raise TransportError("Process stdin not available")
|
||||||
|
data = (json.dumps(message) + "\n").encode("utf-8")
|
||||||
|
self._process.stdin.write(data)
|
||||||
|
await self._process.stdin.drain()
|
||||||
|
|
||||||
|
async def receive_response(self) -> dict[str, Any]:
|
||||||
|
"""接收通知消息
|
||||||
|
|
||||||
|
对于 StdioTransport,请求响应通过 _pending Future 异步返回。
|
||||||
|
此方法仅用于获取服务端推送的通知消息。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON-RPC 通知消息
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TransportError: 连接未建立或无通知
|
||||||
|
"""
|
||||||
|
if not self.is_connected:
|
||||||
|
raise TransportError("Transport not connected")
|
||||||
|
|
||||||
|
if not self._notifications.empty():
|
||||||
|
return self._notifications.get_nowait()
|
||||||
|
|
||||||
|
raise TransportError("No notification to receive")
|
||||||
|
|
||||||
|
async def _read_stdout(self) -> None:
|
||||||
|
"""持续从子进程 stdout 读取 JSON-RPC 消息"""
|
||||||
|
if self._process is None or self._process.stdout is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
line = await self._process.stdout.readline()
|
||||||
|
if not line:
|
||||||
|
# EOF — 子进程退出
|
||||||
|
if self._connected:
|
||||||
|
logger.warning("StdioTransport: subprocess stdout EOF")
|
||||||
|
break
|
||||||
|
|
||||||
|
line_str = line.decode("utf-8").strip()
|
||||||
|
if not line_str:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(line_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("StdioTransport: invalid JSON from stdout: %s", line_str)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 响应消息(有 id 字段)
|
||||||
|
if "id" in data:
|
||||||
|
request_id = data["id"]
|
||||||
|
future = self._pending.pop(request_id, None)
|
||||||
|
if future is not None and not future.done():
|
||||||
|
if "error" in data:
|
||||||
|
error = data["error"]
|
||||||
|
future.set_exception(
|
||||||
|
TransportError(
|
||||||
|
f"JSON-RPC error {error.get('code')}: {error.get('message')}"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
future.set_result(data.get("result"))
|
||||||
|
elif future is None:
|
||||||
|
logger.warning(
|
||||||
|
"StdioTransport: received response for unknown request id %s",
|
||||||
|
request_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 通知消息(有 method 字段,无 id)
|
||||||
|
elif "method" in data:
|
||||||
|
await self._notifications.put(data)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
if self._connected:
|
||||||
|
logger.error("StdioTransport: stdout reader error: %s", e)
|
||||||
|
|
||||||
|
async def _read_stderr(self) -> None:
|
||||||
|
"""持续从子进程 stderr 读取并转发到 logger"""
|
||||||
|
if self._process is None or self._process.stderr is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
line = await self._process.stderr.readline()
|
||||||
|
if not line:
|
||||||
|
break
|
||||||
|
line_str = line.decode("utf-8", errors="replace").rstrip()
|
||||||
|
if line_str:
|
||||||
|
logger.debug("StdioTransport stderr: %s", line_str)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
if self._connected:
|
||||||
|
logger.error("StdioTransport: stderr reader error: %s", e)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,749 @@
|
||||||
|
"""StdioTransport 单元测试
|
||||||
|
|
||||||
|
使用内联 mock MCP server 子进程进行测试,无需外部依赖。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import textwrap
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.mcp.transport import StdioTransport, TransportError
|
||||||
|
|
||||||
|
# 内联 mock MCP server 脚本
|
||||||
|
# 读取 stdin 的 JSON-RPC 消息,根据 method 返回对应响应
|
||||||
|
MOCK_SERVER_SCRIPT = textwrap.dedent("""\
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
|
||||||
|
def handle_request(data):
|
||||||
|
method = data.get("method", "")
|
||||||
|
req_id = data.get("id")
|
||||||
|
params = data.get("params", {})
|
||||||
|
|
||||||
|
if method == "initialize":
|
||||||
|
return {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"result": {
|
||||||
|
"protocolVersion": "2024-11-05",
|
||||||
|
"capabilities": {"tools": {"listChanged": True}},
|
||||||
|
"serverInfo": {"name": "mock-mcp-server", "version": "0.1.0"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
elif method == "tools/list":
|
||||||
|
return {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"result": {
|
||||||
|
"tools": [
|
||||||
|
{
|
||||||
|
"name": "echo",
|
||||||
|
"description": "Echo tool",
|
||||||
|
"inputSchema": {"type": "object", "properties": {"msg": {"type": "string"}}},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
elif method == "tools/call":
|
||||||
|
name = params.get("name", "")
|
||||||
|
arguments = params.get("arguments", {})
|
||||||
|
if name == "echo":
|
||||||
|
return {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"result": {
|
||||||
|
"content": [{"type": "text", "text": arguments.get("msg", "")}]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"error": {"code": -32601, "message": f"Unknown tool: {name}"},
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"error": {"code": -32601, "message": f"Method not found: {method}"},
|
||||||
|
}
|
||||||
|
|
||||||
|
def main():
|
||||||
|
for line in sys.stdin:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
# 通知消息(无 id)不回复
|
||||||
|
if "id" not in data:
|
||||||
|
continue
|
||||||
|
response = handle_request(data)
|
||||||
|
sys.stdout.write(json.dumps(response) + "\\n")
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
""")
|
||||||
|
|
||||||
|
# 发送通知后立即退出的 mock server(用于测试子进程退出检测)
|
||||||
|
EXIT_AFTER_INIT_SCRIPT = textwrap.dedent("""\
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
|
||||||
|
for line in sys.stdin:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
if "id" not in data:
|
||||||
|
continue
|
||||||
|
method = data.get("method", "")
|
||||||
|
req_id = data.get("id")
|
||||||
|
if method == "initialize":
|
||||||
|
response = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "exit-server", "version": "0.1.0"}},
|
||||||
|
}
|
||||||
|
sys.stdout.write(json.dumps(response) + "\\n")
|
||||||
|
sys.stdout.flush()
|
||||||
|
# 初始化后立即退出
|
||||||
|
sys.exit(0)
|
||||||
|
else:
|
||||||
|
# 不会到达这里
|
||||||
|
pass
|
||||||
|
""")
|
||||||
|
|
||||||
|
# 发送通知的 mock server(用于测试通知接收)
|
||||||
|
NOTIFICATION_SERVER_SCRIPT = textwrap.dedent("""\
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
|
||||||
|
for line in sys.stdin:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
if "id" not in data:
|
||||||
|
continue
|
||||||
|
method = data.get("method", "")
|
||||||
|
req_id = data.get("id")
|
||||||
|
if method == "initialize":
|
||||||
|
response = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "notif-server", "version": "0.1.0"}},
|
||||||
|
}
|
||||||
|
sys.stdout.write(json.dumps(response) + "\\n")
|
||||||
|
sys.stdout.flush()
|
||||||
|
elif method == "tools/list":
|
||||||
|
# 先发送一个通知
|
||||||
|
notification = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "notifications/tools/list_changed",
|
||||||
|
}
|
||||||
|
sys.stdout.write(json.dumps(notification) + "\\n")
|
||||||
|
sys.stdout.flush()
|
||||||
|
# 再发送响应
|
||||||
|
response = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"result": {"tools": [{"name": "updated_tool"}]},
|
||||||
|
}
|
||||||
|
sys.stdout.write(json.dumps(response) + "\\n")
|
||||||
|
sys.stdout.flush()
|
||||||
|
else:
|
||||||
|
response = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"error": {"code": -32601, "message": "Method not found"},
|
||||||
|
}
|
||||||
|
sys.stdout.write(json.dumps(response) + "\\n")
|
||||||
|
sys.stdout.flush()
|
||||||
|
""")
|
||||||
|
|
||||||
|
# 写入 stderr 的 mock server
|
||||||
|
STDERR_SERVER_SCRIPT = textwrap.dedent("""\
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
|
||||||
|
for line in sys.stdin:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
if "id" not in data:
|
||||||
|
continue
|
||||||
|
method = data.get("method", "")
|
||||||
|
req_id = data.get("id")
|
||||||
|
if method == "initialize":
|
||||||
|
# 写入 stderr
|
||||||
|
sys.stderr.write("mock server starting\\n")
|
||||||
|
sys.stderr.flush()
|
||||||
|
response = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "stderr-server", "version": "0.1.0"}},
|
||||||
|
}
|
||||||
|
sys.stdout.write(json.dumps(response) + "\\n")
|
||||||
|
sys.stdout.flush()
|
||||||
|
else:
|
||||||
|
response = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"result": {},
|
||||||
|
}
|
||||||
|
sys.stdout.write(json.dumps(response) + "\\n")
|
||||||
|
sys.stdout.flush()
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def _make_transport(script: str, **kwargs) -> StdioTransport:
|
||||||
|
"""创建使用内联 mock server 脚本的 StdioTransport"""
|
||||||
|
return StdioTransport(
|
||||||
|
command=sys.executable,
|
||||||
|
args=["-c", script],
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── 构造测试 ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestStdioTransportConstruction:
|
||||||
|
"""StdioTransport 构造测试"""
|
||||||
|
|
||||||
|
def test_default_args(self):
|
||||||
|
transport = StdioTransport(command="echo")
|
||||||
|
assert transport._command == "echo"
|
||||||
|
assert transport._args == []
|
||||||
|
assert transport._env is None
|
||||||
|
assert transport._timeout == 30.0
|
||||||
|
assert transport._process is None
|
||||||
|
assert transport._request_id == 0
|
||||||
|
assert transport._pending == {}
|
||||||
|
assert transport._reader_task is None
|
||||||
|
assert transport._stderr_task is None
|
||||||
|
assert transport._connected is False
|
||||||
|
|
||||||
|
def test_custom_args(self):
|
||||||
|
transport = StdioTransport(
|
||||||
|
command="node",
|
||||||
|
args=["server.js", "--port", "3000"],
|
||||||
|
env={"NODE_ENV": "test"},
|
||||||
|
timeout=10.0,
|
||||||
|
)
|
||||||
|
assert transport._command == "node"
|
||||||
|
assert transport._args == ["server.js", "--port", "3000"]
|
||||||
|
assert transport._env == {"NODE_ENV": "test"}
|
||||||
|
assert transport._timeout == 10.0
|
||||||
|
|
||||||
|
def test_is_connected_initially_false(self):
|
||||||
|
transport = StdioTransport(command="echo")
|
||||||
|
assert not transport.is_connected
|
||||||
|
|
||||||
|
|
||||||
|
# ── 连接/断开测试 ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestStdioTransportConnect:
|
||||||
|
"""StdioTransport 连接测试"""
|
||||||
|
|
||||||
|
async def test_connect_starts_subprocess(self):
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
try:
|
||||||
|
await transport.connect()
|
||||||
|
assert transport.is_connected
|
||||||
|
assert transport._process is not None
|
||||||
|
assert transport._process.returncode is None
|
||||||
|
assert transport._reader_task is not None
|
||||||
|
assert transport._stderr_task is not None
|
||||||
|
finally:
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_connect_completes_initialize_handshake(self):
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
try:
|
||||||
|
await transport.connect()
|
||||||
|
# connect 成功说明 initialize 握手完成
|
||||||
|
assert transport._connected is True
|
||||||
|
# request_id 应该至少递增到 1(initialize 请求)
|
||||||
|
assert transport._request_id >= 1
|
||||||
|
finally:
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_connect_is_idempotent(self):
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
try:
|
||||||
|
await transport.connect()
|
||||||
|
await transport.connect() # 不应报错
|
||||||
|
assert transport.is_connected
|
||||||
|
finally:
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_connect_with_invalid_command_raises(self):
|
||||||
|
transport = StdioTransport(command="/nonexistent/command")
|
||||||
|
with pytest.raises(TransportError, match="Failed to start process"):
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
async def test_connect_timeout(self):
|
||||||
|
"""使用不响应 initialize 的子进程测试超时"""
|
||||||
|
# 使用一个只读取 stdin 但不输出任何内容的脚本
|
||||||
|
silent_script = "import sys; sys.stdin.read()"
|
||||||
|
transport = StdioTransport(
|
||||||
|
command=sys.executable,
|
||||||
|
args=["-c", silent_script],
|
||||||
|
timeout=0.5,
|
||||||
|
)
|
||||||
|
with pytest.raises(TransportError, match="Timeout waiting for initialize"):
|
||||||
|
await transport.connect()
|
||||||
|
assert not transport.is_connected
|
||||||
|
|
||||||
|
|
||||||
|
class TestStdioTransportDisconnect:
|
||||||
|
"""StdioTransport 断开测试"""
|
||||||
|
|
||||||
|
async def test_disconnect_closes_subprocess(self):
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
await transport.connect()
|
||||||
|
assert transport.is_connected
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
assert not transport.is_connected
|
||||||
|
assert transport._process is None
|
||||||
|
assert transport._reader_task is None
|
||||||
|
assert transport._stderr_task is None
|
||||||
|
|
||||||
|
async def test_disconnect_is_idempotent(self):
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
await transport.connect()
|
||||||
|
await transport.disconnect()
|
||||||
|
await transport.disconnect() # 不应报错
|
||||||
|
|
||||||
|
async def test_disconnect_cancels_pending_futures(self):
|
||||||
|
"""断开时所有 pending future 应收到 TransportError"""
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
# 手动添加一个 pending future
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
future = loop.create_future()
|
||||||
|
transport._pending[999] = future
|
||||||
|
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
assert future.done()
|
||||||
|
with pytest.raises(TransportError, match="Transport disconnected"):
|
||||||
|
future.result()
|
||||||
|
|
||||||
|
async def test_disconnect_clears_pending(self):
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
await transport.connect()
|
||||||
|
transport._pending[1] = asyncio.get_running_loop().create_future()
|
||||||
|
await transport.disconnect()
|
||||||
|
assert transport._pending == {}
|
||||||
|
|
||||||
|
|
||||||
|
# ── 请求发送测试 ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestStdioTransportSendRequest:
|
||||||
|
"""StdioTransport 请求发送测试"""
|
||||||
|
|
||||||
|
async def test_send_request_not_connected_raises(self):
|
||||||
|
transport = StdioTransport(command="echo")
|
||||||
|
with pytest.raises(TransportError, match="not connected"):
|
||||||
|
await transport.send_request("tools/list")
|
||||||
|
|
||||||
|
async def test_send_request_tools_list(self):
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
try:
|
||||||
|
await transport.connect()
|
||||||
|
result = await transport.send_request("tools/list")
|
||||||
|
assert "tools" in result
|
||||||
|
assert len(result["tools"]) == 1
|
||||||
|
assert result["tools"][0]["name"] == "echo"
|
||||||
|
finally:
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_send_request_tools_call(self):
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
try:
|
||||||
|
await transport.connect()
|
||||||
|
result = await transport.send_request(
|
||||||
|
"tools/call",
|
||||||
|
params={"name": "echo", "arguments": {"msg": "hello world"}},
|
||||||
|
)
|
||||||
|
assert "content" in result
|
||||||
|
assert result["content"][0]["text"] == "hello world"
|
||||||
|
finally:
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_send_request_json_rpc_error(self):
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
try:
|
||||||
|
await transport.connect()
|
||||||
|
with pytest.raises(TransportError, match="JSON-RPC error"):
|
||||||
|
await transport.send_request("unknown/method")
|
||||||
|
finally:
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_send_request_unknown_tool_error(self):
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
try:
|
||||||
|
await transport.connect()
|
||||||
|
with pytest.raises(TransportError, match="Unknown tool"):
|
||||||
|
await transport.send_request(
|
||||||
|
"tools/call",
|
||||||
|
params={"name": "nonexistent", "arguments": {}},
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_request_id_increments(self):
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
try:
|
||||||
|
await transport.connect()
|
||||||
|
# connect 时已经用了 id=1 (initialize)
|
||||||
|
id_before = transport._request_id
|
||||||
|
await transport.send_request("tools/list")
|
||||||
|
id_after_1 = transport._request_id
|
||||||
|
await transport.send_request("tools/list")
|
||||||
|
id_after_2 = transport._request_id
|
||||||
|
assert id_after_1 == id_before + 1
|
||||||
|
assert id_after_2 == id_before + 2
|
||||||
|
finally:
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_send_request_timeout(self):
|
||||||
|
"""请求超时测试"""
|
||||||
|
# 使用一个不响应的脚本
|
||||||
|
silent_script = "import sys; sys.stdin.read()"
|
||||||
|
transport = StdioTransport(
|
||||||
|
command=sys.executable,
|
||||||
|
args=["-c", silent_script],
|
||||||
|
timeout=0.5,
|
||||||
|
)
|
||||||
|
# 手动设置连接状态以绕过 initialize
|
||||||
|
transport._connected = True
|
||||||
|
transport._process = await asyncio.create_subprocess_exec(
|
||||||
|
sys.executable, "-c", silent_script,
|
||||||
|
stdin=asyncio.subprocess.PIPE,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.PIPE,
|
||||||
|
)
|
||||||
|
transport._reader_task = asyncio.create_task(transport._read_stdout())
|
||||||
|
transport._stderr_task = asyncio.create_task(transport._read_stderr())
|
||||||
|
|
||||||
|
try:
|
||||||
|
with pytest.raises(TransportError, match="Timeout"):
|
||||||
|
await transport.send_request("tools/list")
|
||||||
|
finally:
|
||||||
|
transport._connected = False
|
||||||
|
await transport._cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
# ── 并发请求测试 ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestStdioTransportConcurrentRequests:
|
||||||
|
"""StdioTransport 并发请求测试"""
|
||||||
|
|
||||||
|
async def test_concurrent_requests_correct_id_matching(self):
|
||||||
|
"""并发请求的响应应正确匹配到对应的 Future"""
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
try:
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
# 并发发送多个请求
|
||||||
|
results = await asyncio.gather(
|
||||||
|
transport.send_request("tools/list"),
|
||||||
|
transport.send_request(
|
||||||
|
"tools/call",
|
||||||
|
params={"name": "echo", "arguments": {"msg": "msg1"}},
|
||||||
|
),
|
||||||
|
transport.send_request(
|
||||||
|
"tools/call",
|
||||||
|
params={"name": "echo", "arguments": {"msg": "msg2"}},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 验证每个请求都得到了正确类型的响应
|
||||||
|
assert "tools" in results[0]
|
||||||
|
assert results[1]["content"][0]["text"] == "msg1"
|
||||||
|
assert results[2]["content"][0]["text"] == "msg2"
|
||||||
|
finally:
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
# ── 通知接收测试 ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestStdioTransportNotifications:
|
||||||
|
"""StdioTransport 通知接收测试"""
|
||||||
|
|
||||||
|
async def test_receive_notification(self):
|
||||||
|
transport = _make_transport(NOTIFICATION_SERVER_SCRIPT)
|
||||||
|
try:
|
||||||
|
await transport.connect()
|
||||||
|
|
||||||
|
# tools/list 会先发送一个通知
|
||||||
|
result = await transport.send_request("tools/list")
|
||||||
|
assert "tools" in result
|
||||||
|
|
||||||
|
# 等待通知到达
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
# 应该能收到通知
|
||||||
|
notification = await transport.receive_response()
|
||||||
|
assert notification["method"] == "notifications/tools/list_changed"
|
||||||
|
finally:
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_receive_response_no_notification_raises(self):
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
try:
|
||||||
|
await transport.connect()
|
||||||
|
with pytest.raises(TransportError, match="No notification"):
|
||||||
|
await transport.receive_response()
|
||||||
|
finally:
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_receive_response_not_connected_raises(self):
|
||||||
|
transport = StdioTransport(command="echo")
|
||||||
|
with pytest.raises(TransportError, match="not connected"):
|
||||||
|
await transport.receive_response()
|
||||||
|
|
||||||
|
|
||||||
|
# ── 子进程退出检测测试 ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestStdioTransportProcessExit:
|
||||||
|
"""StdioTransport 子进程退出检测测试"""
|
||||||
|
|
||||||
|
async def test_subprocess_exit_detection(self):
|
||||||
|
"""子进程退出后 is_connected 应返回 False"""
|
||||||
|
transport = _make_transport(EXIT_AFTER_INIT_SCRIPT)
|
||||||
|
try:
|
||||||
|
await transport.connect()
|
||||||
|
assert transport.is_connected
|
||||||
|
|
||||||
|
# 等待子进程退出
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
# 子进程已退出,is_connected 应为 False
|
||||||
|
assert not transport.is_connected
|
||||||
|
finally:
|
||||||
|
if transport._process is not None:
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_send_request_after_process_exit_raises(self):
|
||||||
|
"""子进程退出后发送请求应抛出 TransportError"""
|
||||||
|
transport = _make_transport(EXIT_AFTER_INIT_SCRIPT)
|
||||||
|
try:
|
||||||
|
await transport.connect()
|
||||||
|
# 等待子进程退出
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
if not transport.is_connected:
|
||||||
|
with pytest.raises(TransportError, match="not connected"):
|
||||||
|
await transport.send_request("tools/list")
|
||||||
|
finally:
|
||||||
|
if transport._process is not None:
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
# ── stderr 转发测试 ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestStdioTransportStderr:
|
||||||
|
"""StdioTransport stderr 转发测试"""
|
||||||
|
|
||||||
|
async def test_stderr_forwarded_to_logger(self, caplog):
|
||||||
|
"""stderr 输出应转发到 logger"""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
transport = _make_transport(STDERR_SERVER_SCRIPT)
|
||||||
|
try:
|
||||||
|
with caplog.at_level(logging.DEBUG, logger="agentkit.mcp.transport"):
|
||||||
|
await transport.connect()
|
||||||
|
# 发送一个请求触发 stderr 输出
|
||||||
|
await transport.send_request("tools/list")
|
||||||
|
# 等待 stderr 被读取
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
|
||||||
|
# 检查日志中包含 stderr 输出
|
||||||
|
stderr_logs = [
|
||||||
|
r for r in caplog.records
|
||||||
|
if "mock server starting" in r.message
|
||||||
|
]
|
||||||
|
assert len(stderr_logs) > 0
|
||||||
|
finally:
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
|
||||||
|
# ── is_connected 属性测试 ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestStdioTransportIsConnected:
|
||||||
|
"""StdioTransport is_connected 属性测试"""
|
||||||
|
|
||||||
|
async def test_is_connected_before_connect(self):
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
assert not transport.is_connected
|
||||||
|
|
||||||
|
async def test_is_connected_after_connect(self):
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
try:
|
||||||
|
await transport.connect()
|
||||||
|
assert transport.is_connected
|
||||||
|
finally:
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_is_connected_after_disconnect(self):
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
await transport.connect()
|
||||||
|
await transport.disconnect()
|
||||||
|
assert not transport.is_connected
|
||||||
|
|
||||||
|
async def test_is_connected_checks_process_returncode(self):
|
||||||
|
"""is_connected 应检查 process.returncode"""
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
try:
|
||||||
|
await transport.connect()
|
||||||
|
assert transport._process is not None
|
||||||
|
# 模拟进程退出但 _connected 仍为 True
|
||||||
|
transport._connected = True
|
||||||
|
# 终止子进程
|
||||||
|
transport._process.kill()
|
||||||
|
await transport._process.wait()
|
||||||
|
# is_connected 应为 False 因为 returncode 不为 None
|
||||||
|
assert not transport.is_connected
|
||||||
|
finally:
|
||||||
|
transport._connected = False
|
||||||
|
await transport._cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
# ── 完整生命周期测试 ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestStdioTransportLifecycle:
|
||||||
|
"""StdioTransport 完整生命周期测试"""
|
||||||
|
|
||||||
|
async def test_full_lifecycle(self):
|
||||||
|
"""测试完整的 connect → send_request → disconnect 生命周期"""
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
|
||||||
|
# 1. 连接
|
||||||
|
await transport.connect()
|
||||||
|
assert transport.is_connected
|
||||||
|
|
||||||
|
# 2. 发送请求
|
||||||
|
result = await transport.send_request("tools/list")
|
||||||
|
assert "tools" in result
|
||||||
|
|
||||||
|
# 3. 发送带参数的请求
|
||||||
|
result = await transport.send_request(
|
||||||
|
"tools/call",
|
||||||
|
params={"name": "echo", "arguments": {"msg": "test"}},
|
||||||
|
)
|
||||||
|
assert result["content"][0]["text"] == "test"
|
||||||
|
|
||||||
|
# 4. 断开
|
||||||
|
await transport.disconnect()
|
||||||
|
assert not transport.is_connected
|
||||||
|
|
||||||
|
async def test_reconnect_after_disconnect(self):
|
||||||
|
"""测试断开后重新连接"""
|
||||||
|
transport = _make_transport(MOCK_SERVER_SCRIPT)
|
||||||
|
|
||||||
|
# 第一次连接
|
||||||
|
await transport.connect()
|
||||||
|
result1 = await transport.send_request("tools/list")
|
||||||
|
assert "tools" in result1
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
# 重新连接
|
||||||
|
await transport.connect()
|
||||||
|
result2 = await transport.send_request("tools/list")
|
||||||
|
assert "tools" in result2
|
||||||
|
await transport.disconnect()
|
||||||
|
|
||||||
|
async def test_env_variables_passed_to_subprocess(self):
|
||||||
|
"""测试环境变量传递给子进程"""
|
||||||
|
# 使用打印环境变量的脚本
|
||||||
|
env_script = textwrap.dedent("""\
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
for line in sys.stdin:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
data = json.loads(line)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
if "id" not in data:
|
||||||
|
continue
|
||||||
|
method = data.get("method", "")
|
||||||
|
req_id = data.get("id")
|
||||||
|
if method == "initialize":
|
||||||
|
response = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "env-server", "version": "0.1.0"}},
|
||||||
|
}
|
||||||
|
sys.stdout.write(json.dumps(response) + "\\n")
|
||||||
|
sys.stdout.flush()
|
||||||
|
elif method == "tools/call":
|
||||||
|
test_env = os.environ.get("TEST_MCP_VAR", "not_set")
|
||||||
|
response = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"result": {
|
||||||
|
"content": [{"type": "text", "text": test_env}]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sys.stdout.write(json.dumps(response) + "\\n")
|
||||||
|
sys.stdout.flush()
|
||||||
|
else:
|
||||||
|
response = {
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"id": req_id,
|
||||||
|
"result": {},
|
||||||
|
}
|
||||||
|
sys.stdout.write(json.dumps(response) + "\\n")
|
||||||
|
sys.stdout.flush()
|
||||||
|
""")
|
||||||
|
|
||||||
|
transport = StdioTransport(
|
||||||
|
command=sys.executable,
|
||||||
|
args=["-c", env_script],
|
||||||
|
env={"TEST_MCP_VAR": "hello_from_env"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await transport.connect()
|
||||||
|
result = await transport.send_request(
|
||||||
|
"tools/call",
|
||||||
|
params={"name": "check_env", "arguments": {}},
|
||||||
|
)
|
||||||
|
assert result["content"][0]["text"] == "hello_from_env"
|
||||||
|
finally:
|
||||||
|
await transport.disconnect()
|
||||||
Loading…
Reference in New Issue