fischer-agentkit/tests/unit/test_stdio_transport.py

753 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""StdioTransport 单元测试
使用内联 mock MCP server 子进程进行测试,无需外部依赖。
"""
import asyncio
import json
import sys
import textwrap
import pytest
from agentkit.mcp.transport import StdioTransport, TransportError
# Inline mock MCP server script.
# Reads newline-delimited JSON-RPC messages from stdin and answers each
# request according to its "method"; messages without an "id" are ignored.
MOCK_SERVER_SCRIPT = textwrap.dedent("""\
    import sys
    import json


    def handle_request(data):
        method = data.get("method", "")
        req_id = data.get("id")
        params = data.get("params", {})
        if method == "initialize":
            return {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {
                    "protocolVersion": "2024-11-05",
                    "capabilities": {"tools": {"listChanged": True}},
                    "serverInfo": {"name": "mock-mcp-server", "version": "0.1.0"},
                },
            }
        elif method == "tools/list":
            return {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {
                    "tools": [
                        {
                            "name": "echo",
                            "description": "Echo tool",
                            "inputSchema": {"type": "object", "properties": {"msg": {"type": "string"}}},
                        }
                    ]
                },
            }
        elif method == "tools/call":
            name = params.get("name", "")
            arguments = params.get("arguments", {})
            if name == "echo":
                return {
                    "jsonrpc": "2.0",
                    "id": req_id,
                    "result": {
                        "content": [{"type": "text", "text": arguments.get("msg", "")}]
                    },
                }
            else:
                return {
                    "jsonrpc": "2.0",
                    "id": req_id,
                    "error": {"code": -32601, "message": f"Unknown tool: {name}"},
                }
        else:
            return {
                "jsonrpc": "2.0",
                "id": req_id,
                "error": {"code": -32601, "message": f"Method not found: {method}"},
            }


    def main():
        for line in sys.stdin:
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue
            # Notifications (no "id") never get a reply.
            if "id" not in data:
                continue
            response = handle_request(data)
            sys.stdout.write(json.dumps(response) + "\\n")
            sys.stdout.flush()


    if __name__ == "__main__":
        main()
""")
# Mock server that answers ``initialize`` and then exits immediately;
# used to test detection of a subprocess that has exited.
EXIT_AFTER_INIT_SCRIPT = textwrap.dedent("""\
    import sys
    import json

    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue
        try:
            data = json.loads(line)
        except json.JSONDecodeError:
            continue
        if "id" not in data:
            continue
        method = data.get("method", "")
        req_id = data.get("id")
        if method == "initialize":
            response = {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "exit-server", "version": "0.1.0"}},
            }
            sys.stdout.write(json.dumps(response) + "\\n")
            sys.stdout.flush()
            # Exit right after the handshake reply.
            sys.exit(0)
        # Requests other than initialize are deliberately left unanswered;
        # in these tests initialize is always the first request.
""")
# Mock server that emits a notification right before its tools/list
# response; used to test notification delivery.
NOTIFICATION_SERVER_SCRIPT = textwrap.dedent("""\
    import sys
    import json

    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue
        try:
            data = json.loads(line)
        except json.JSONDecodeError:
            continue
        if "id" not in data:
            continue
        method = data.get("method", "")
        req_id = data.get("id")
        if method == "initialize":
            response = {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "notif-server", "version": "0.1.0"}},
            }
            sys.stdout.write(json.dumps(response) + "\\n")
            sys.stdout.flush()
        elif method == "tools/list":
            # Send a notification first...
            notification = {
                "jsonrpc": "2.0",
                "method": "notifications/tools/list_changed",
            }
            sys.stdout.write(json.dumps(notification) + "\\n")
            sys.stdout.flush()
            # ...then the actual response.
            response = {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {"tools": [{"name": "updated_tool"}]},
            }
            sys.stdout.write(json.dumps(response) + "\\n")
            sys.stdout.flush()
        else:
            response = {
                "jsonrpc": "2.0",
                "id": req_id,
                "error": {"code": -32601, "message": "Method not found"},
            }
            sys.stdout.write(json.dumps(response) + "\\n")
            sys.stdout.flush()
""")
# Mock server that writes a line to stderr while handling initialize;
# used to test stderr forwarding.
STDERR_SERVER_SCRIPT = textwrap.dedent("""\
    import sys
    import json

    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue
        try:
            data = json.loads(line)
        except json.JSONDecodeError:
            continue
        if "id" not in data:
            continue
        method = data.get("method", "")
        req_id = data.get("id")
        if method == "initialize":
            # Write to stderr before answering the handshake.
            sys.stderr.write("mock server starting\\n")
            sys.stderr.flush()
            response = {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "stderr-server", "version": "0.1.0"}},
            }
            sys.stdout.write(json.dumps(response) + "\\n")
            sys.stdout.flush()
        else:
            response = {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {},
            }
            sys.stdout.write(json.dumps(response) + "\\n")
            sys.stdout.flush()
""")
def _make_transport(script: str, **kwargs) -> StdioTransport:
    """Return a StdioTransport that runs *script* with the current interpreter.

    The script is handed to the interpreter via ``-c``; any extra keyword
    arguments (e.g. ``timeout``, ``env``) are forwarded to StdioTransport.
    """
    interpreter_args = ["-c", script]
    return StdioTransport(command=sys.executable, args=interpreter_args, **kwargs)
# ── 构造测试 ──────────────────────────────────────────
class TestStdioTransportConstruction:
    """Construction-time defaults and argument storage of StdioTransport."""

    def test_default_args(self):
        transport = StdioTransport(command="echo")
        # Attributes that must equal a specific default value.
        expected_values = {
            "_command": "echo",
            "_args": [],
            "_timeout": 30.0,
            "_request_id": 0,
            "_pending": {},
        }
        for attr, expected in expected_values.items():
            assert getattr(transport, attr) == expected, attr
        # Handles that must not exist before connect().
        for attr in ("_env", "_process", "_reader_task", "_stderr_task"):
            assert getattr(transport, attr) is None, attr
        assert transport._connected is False

    def test_custom_args(self):
        transport = StdioTransport(
            command="node",
            args=["server.js", "--port", "3000"],
            env={"NODE_ENV": "test"},
            timeout=10.0,
        )
        assert (transport._command, transport._timeout) == ("node", 10.0)
        assert transport._args == ["server.js", "--port", "3000"]
        assert transport._env == {"NODE_ENV": "test"}

    def test_is_connected_initially_false(self):
        fresh = StdioTransport(command="echo")
        assert not fresh.is_connected
# ── 连接/断开测试 ──────────────────────────────────────────
class TestStdioTransportConnect:
    """connect(): subprocess start-up and the initialize handshake."""

    async def test_connect_starts_subprocess(self):
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await transport.connect()
            assert transport.is_connected
            assert transport._process is not None
            # returncode stays None while the child is still running.
            assert transport._process.returncode is None
            assert transport._reader_task is not None
            assert transport._stderr_task is not None
        finally:
            await transport.disconnect()

    async def test_connect_completes_initialize_handshake(self):
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await transport.connect()
            # A successful connect() means the initialize handshake finished.
            assert transport._connected is True
            # The initialize request consumed at least one request id.
            assert transport._request_id >= 1
        finally:
            await transport.disconnect()

    async def test_connect_is_idempotent(self):
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await transport.connect()
            await transport.connect()  # a second call must not raise
            assert transport.is_connected
        finally:
            await transport.disconnect()

    async def test_connect_with_invalid_command_raises(self):
        transport = StdioTransport(command="/nonexistent/command")
        with pytest.raises(TransportError, match="Failed to start process"):
            await transport.connect()

    async def test_connect_timeout(self):
        """connect() times out when the child never answers initialize."""
        # The child only drains stdin and never writes anything back.
        silent_script = "import sys; sys.stdin.read()"
        transport = StdioTransport(
            command=sys.executable,
            args=["-c", silent_script],
            timeout=0.5,
        )
        try:
            with pytest.raises(TransportError, match="Timeout waiting for initialize"):
                await transport.connect()
            assert not transport.is_connected
        finally:
            # Reap the silent child in case the failed connect() left it
            # running; disconnect() tolerates an already-torn-down transport.
            await transport.disconnect()
class TestStdioTransportDisconnect:
    """disconnect(): child teardown and failing of in-flight requests."""

    async def test_disconnect_closes_subprocess(self):
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await transport.connect()
            assert transport.is_connected
            await transport.disconnect()
            assert not transport.is_connected
            # All per-connection resources are released.
            assert transport._process is None
            assert transport._reader_task is None
            assert transport._stderr_task is None
        finally:
            # A repeated disconnect is a no-op (see test_disconnect_is_idempotent);
            # it only matters if an assertion above fired before teardown.
            await transport.disconnect()

    async def test_disconnect_is_idempotent(self):
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        await transport.connect()
        await transport.disconnect()
        await transport.disconnect()  # must not raise

    async def test_disconnect_cancels_pending_futures(self):
        """Every pending future receives a TransportError on disconnect."""
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await transport.connect()
            # Inject a fake in-flight request under an id the test never sends.
            future = asyncio.get_running_loop().create_future()
            transport._pending[999] = future
            await transport.disconnect()
            assert future.done()
            with pytest.raises(TransportError, match="Transport disconnected"):
                future.result()
        finally:
            await transport.disconnect()

    async def test_disconnect_clears_pending(self):
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await transport.connect()
            future = asyncio.get_running_loop().create_future()
            transport._pending[1] = future
            await transport.disconnect()
            assert transport._pending == {}
            # Retrieve the exception disconnect() set on the orphan future so
            # asyncio does not log "Future exception was never retrieved".
            assert isinstance(future.exception(), TransportError)
        finally:
            await transport.disconnect()
# ── 请求发送测试 ──────────────────────────────────────────
class TestStdioTransportSendRequest:
    """send_request(): request/response round trips and error mapping."""

    async def test_send_request_not_connected_raises(self):
        idle = StdioTransport(command="echo")
        with pytest.raises(TransportError, match="not connected"):
            await idle.send_request("tools/list")

    async def test_send_request_tools_list(self):
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await transport.connect()
            listing = await transport.send_request("tools/list")
            assert "tools" in listing
            # The mock server advertises exactly one tool: "echo".
            assert [tool["name"] for tool in listing["tools"]] == ["echo"]
        finally:
            await transport.disconnect()

    async def test_send_request_tools_call(self):
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await transport.connect()
            call_params = {"name": "echo", "arguments": {"msg": "hello world"}}
            reply = await transport.send_request("tools/call", params=call_params)
            assert "content" in reply
            first_item = reply["content"][0]
            assert first_item["text"] == "hello world"
        finally:
            await transport.disconnect()

    async def test_send_request_json_rpc_error(self):
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await transport.connect()
            # An unknown method makes the server reply with a JSON-RPC error.
            with pytest.raises(TransportError, match="JSON-RPC error"):
                await transport.send_request("unknown/method")
        finally:
            await transport.disconnect()

    async def test_send_request_unknown_tool_error(self):
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await transport.connect()
            bad_call = {"name": "nonexistent", "arguments": {}}
            with pytest.raises(TransportError, match="Unknown tool"):
                await transport.send_request("tools/call", params=bad_call)
        finally:
            await transport.disconnect()

    async def test_request_id_increments(self):
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await transport.connect()
            # connect() already consumed an id for initialize.
            baseline = transport._request_id
            seen = []
            for _ in range(2):
                await transport.send_request("tools/list")
                seen.append(transport._request_id)
            assert seen == [baseline + 1, baseline + 2]
        finally:
            await transport.disconnect()

    async def test_send_request_timeout(self):
        """send_request() raises TransportError when no response arrives in time."""
        # A child that reads stdin forever and never replies.
        child_argv = [sys.executable, "-c", "import sys; sys.stdin.read()"]
        transport = StdioTransport(
            command=child_argv[0],
            args=child_argv[1:],
            timeout=0.5,
        )
        # connect() would stall on the initialize handshake this child never
        # answers, so wire the transport up by hand instead.
        transport._connected = True
        transport._process = await asyncio.create_subprocess_exec(
            *child_argv,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        transport._reader_task = asyncio.create_task(transport._read_stdout())
        transport._stderr_task = asyncio.create_task(transport._read_stderr())
        try:
            with pytest.raises(TransportError, match="Timeout"):
                await transport.send_request("tools/list")
        finally:
            transport._connected = False
            await transport._cleanup()
# ── 并发请求测试 ──────────────────────────────────────────
class TestStdioTransportConcurrentRequests:
    """Concurrent send_request() calls sharing one transport."""

    async def test_concurrent_requests_correct_id_matching(self):
        """Each concurrent response must resolve the future of its own request."""
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await transport.connect()
            echo_messages = ["msg1", "msg2"]
            # Fire one tools/list plus one echo call per message, all at once.
            listing, *echo_replies = await asyncio.gather(
                transport.send_request("tools/list"),
                *[
                    transport.send_request(
                        "tools/call",
                        params={"name": "echo", "arguments": {"msg": message}},
                    )
                    for message in echo_messages
                ],
            )
            # Every reply must carry the payload of the request that caused it.
            assert "tools" in listing
            assert [reply["content"][0]["text"] for reply in echo_replies] == echo_messages
        finally:
            await transport.disconnect()
# ── 通知接收测试 ──────────────────────────────────────────
class TestStdioTransportNotifications:
    """Server-initiated notifications surfaced through receive_response()."""

    async def test_receive_notification(self):
        transport = _make_transport(NOTIFICATION_SERVER_SCRIPT)
        try:
            await transport.connect()
            # The server emits notifications/tools/list_changed just before
            # it answers tools/list.
            listing = await transport.send_request("tools/list")
            assert "tools" in listing
            # Give the reader task a moment to queue the notification.
            await asyncio.sleep(0.1)
            received = await transport.receive_response()
            assert received["method"] == "notifications/tools/list_changed"
        finally:
            await transport.disconnect()

    async def test_receive_response_no_notification_raises(self):
        """receive_response() times out with TransportError when nothing is queued."""
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await transport.connect()
            # Shrink the transport timeout so the expected failure comes quickly.
            transport._timeout = 0.1
            with pytest.raises(TransportError, match="Timeout"):
                await transport.receive_response()
        finally:
            await transport.disconnect()

    async def test_receive_response_not_connected_raises(self):
        idle = StdioTransport(command="echo")
        with pytest.raises(TransportError, match="not connected"):
            await idle.receive_response()
# ── 子进程退出检测测试 ──────────────────────────────────────────
class TestStdioTransportProcessExit:
    """Detection of a mock server subprocess that exits on its own."""

    @staticmethod
    async def _wait_for_server_exit(transport, timeout=5.0):
        """Wait until *transport*'s child process has exited.

        Used instead of a fixed sleep. It returns as soon as the child is gone.
        If the child is still alive after *timeout* seconds it raises a timeout
        error instead of letting the test race.
        """
        process = transport._process
        # The transport may already have released its process handle.
        if process is not None:
            await asyncio.wait_for(process.wait(), timeout=timeout)

    async def test_subprocess_exit_detection(self):
        """is_connected turns False once the child process has exited."""
        transport = _make_transport(EXIT_AFTER_INIT_SCRIPT)
        try:
            await transport.connect()
            assert transport.is_connected
            # EXIT_AFTER_INIT_SCRIPT calls sys.exit(0) right after answering initialize.
            await self._wait_for_server_exit(transport)
            assert not transport.is_connected
        finally:
            if transport._process is not None:
                await transport.disconnect()

    async def test_send_request_after_process_exit_raises(self):
        """send_request() fails with TransportError once the child has exited."""
        transport = _make_transport(EXIT_AFTER_INIT_SCRIPT)
        try:
            await transport.connect()
            await self._wait_for_server_exit(transport)
            # Assert the precondition rather than branching on it, so the test
            # cannot pass vacuously when the child is slow to exit.
            assert not transport.is_connected
            with pytest.raises(TransportError, match="not connected"):
                await transport.send_request("tools/list")
        finally:
            if transport._process is not None:
                await transport.disconnect()
# ── stderr 转发测试 ──────────────────────────────────────────
class TestStdioTransportStderr:
    """Forwarding of the child's stderr output to the transport logger."""

    async def test_stderr_forwarded_to_logger(self, caplog):
        """Lines the child writes to stderr show up as log records."""
        import logging

        transport = _make_transport(STDERR_SERVER_SCRIPT)

        def forwarded_records():
            # Records whose rendered message carries the child's stderr line.
            return [
                r for r in caplog.records
                if "mock server starting" in r.getMessage()
            ]

        try:
            with caplog.at_level(logging.DEBUG, logger="agentkit.mcp.transport"):
                # The mock server writes to stderr while handling initialize,
                # i.e. during connect(); the extra round trip only adds slack.
                await transport.connect()
                await transport.send_request("tools/list")
                # stderr is drained by a separate reader task, so poll with a
                # deadline instead of relying on a single fixed sleep.
                loop = asyncio.get_running_loop()
                deadline = loop.time() + 2.0
                while not forwarded_records() and loop.time() < deadline:
                    await asyncio.sleep(0.02)
                assert len(forwarded_records()) > 0
        finally:
            await transport.disconnect()
# ── is_connected 属性测试 ──────────────────────────────────────────
class TestStdioTransportIsConnected:
    """The is_connected property across the connection lifecycle."""

    async def test_is_connected_before_connect(self):
        fresh = _make_transport(MOCK_SERVER_SCRIPT)
        assert not fresh.is_connected

    async def test_is_connected_after_connect(self):
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await transport.connect()
            assert transport.is_connected
        finally:
            await transport.disconnect()

    async def test_is_connected_after_disconnect(self):
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        await transport.connect()
        await transport.disconnect()
        assert not transport.is_connected

    async def test_is_connected_checks_process_returncode(self):
        """is_connected must consult process.returncode, not just the _connected flag."""
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await transport.connect()
            child = transport._process
            assert child is not None
            # Keep the flag set so only the process state can flip the result.
            transport._connected = True
            child.kill()
            await child.wait()
            # The child is gone (returncode is set), so is_connected must be False.
            assert not transport.is_connected
        finally:
            transport._connected = False
            await transport._cleanup()
# ── 完整生命周期测试 ──────────────────────────────────────────
class TestStdioTransportLifecycle:
    """End-to-end connect -> request -> disconnect scenarios."""

    async def test_full_lifecycle(self):
        """Exercise the full connect -> send_request -> disconnect lifecycle."""
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            # 1. Connect
            await transport.connect()
            assert transport.is_connected
            # 2. Request without params
            result = await transport.send_request("tools/list")
            assert "tools" in result
            # 3. Request with params
            result = await transport.send_request(
                "tools/call",
                params={"name": "echo", "arguments": {"msg": "test"}},
            )
            assert result["content"][0]["text"] == "test"
            # 4. Disconnect
            await transport.disconnect()
            assert not transport.is_connected
        finally:
            # A no-op after step 4 (disconnect is idempotent); reaps the child
            # if an assertion failed before the explicit disconnect.
            await transport.disconnect()

    async def test_reconnect_after_disconnect(self):
        """The same transport can connect again after a clean disconnect."""
        transport = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            # First session
            await transport.connect()
            result1 = await transport.send_request("tools/list")
            assert "tools" in result1
            await transport.disconnect()
            # Second session on the same transport object
            await transport.connect()
            result2 = await transport.send_request("tools/list")
            assert "tools" in result2
        finally:
            await transport.disconnect()

    async def test_env_variables_passed_to_subprocess(self):
        """Variables given via ``env=`` are visible to the child process."""
        # The server answers any tools/call with the value of TEST_MCP_VAR
        # ("not_set" if absent) as the response text.
        env_script = textwrap.dedent("""\
            import sys
            import json
            import os

            for line in sys.stdin:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if "id" not in data:
                    continue
                method = data.get("method", "")
                req_id = data.get("id")
                if method == "initialize":
                    response = {
                        "jsonrpc": "2.0",
                        "id": req_id,
                        "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "env-server", "version": "0.1.0"}},
                    }
                    sys.stdout.write(json.dumps(response) + "\\n")
                    sys.stdout.flush()
                elif method == "tools/call":
                    test_env = os.environ.get("TEST_MCP_VAR", "not_set")
                    response = {
                        "jsonrpc": "2.0",
                        "id": req_id,
                        "result": {
                            "content": [{"type": "text", "text": test_env}]
                        },
                    }
                    sys.stdout.write(json.dumps(response) + "\\n")
                    sys.stdout.flush()
                else:
                    response = {
                        "jsonrpc": "2.0",
                        "id": req_id,
                        "result": {},
                    }
                    sys.stdout.write(json.dumps(response) + "\\n")
                    sys.stdout.flush()
        """)
        transport = StdioTransport(
            command=sys.executable,
            args=["-c", env_script],
            env={"TEST_MCP_VAR": "hello_from_env"},
        )
        try:
            await transport.connect()
            result = await transport.send_request(
                "tools/call",
                params={"name": "check_env", "arguments": {}},
            )
            assert result["content"][0]["text"] == "hello_from_env"
        finally:
            await transport.disconnect()