750 lines
26 KiB
Python
750 lines
26 KiB
Python
"""StdioTransport 单元测试
|
||
|
||
使用内联 mock MCP server 子进程进行测试,无需外部依赖。
|
||
"""
|
||
|
||
import asyncio
|
||
import json
|
||
import sys
|
||
import textwrap
|
||
|
||
import pytest
|
||
|
||
from agentkit.mcp.transport import StdioTransport, TransportError
|
||
|
||
# Inline mock MCP server script.
# It reads newline-delimited JSON-RPC messages from stdin and answers each
# request according to its method:
#   * "initialize"  -> protocol version, capabilities and server info
#   * "tools/list"  -> a single tool named "echo"
#   * "tools/call"  -> echoes arguments["msg"] for the "echo" tool, otherwise
#                      a -32601 "Unknown tool" error
#   * anything else -> a -32601 "Method not found" error
# Blank lines, lines that are not valid JSON, and messages without an "id"
# (notifications) are skipped without a reply.
MOCK_SERVER_SCRIPT = textwrap.dedent("""\
    import sys
    import json

    def handle_request(data):
        method = data.get("method", "")
        req_id = data.get("id")
        params = data.get("params", {})

        if method == "initialize":
            return {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {
                    "protocolVersion": "2024-11-05",
                    "capabilities": {"tools": {"listChanged": True}},
                    "serverInfo": {"name": "mock-mcp-server", "version": "0.1.0"},
                },
            }
        elif method == "tools/list":
            return {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {
                    "tools": [
                        {
                            "name": "echo",
                            "description": "Echo tool",
                            "inputSchema": {"type": "object", "properties": {"msg": {"type": "string"}}},
                        }
                    ]
                },
            }
        elif method == "tools/call":
            name = params.get("name", "")
            arguments = params.get("arguments", {})
            if name == "echo":
                return {
                    "jsonrpc": "2.0",
                    "id": req_id,
                    "result": {
                        "content": [{"type": "text", "text": arguments.get("msg", "")}]
                    },
                }
            else:
                return {
                    "jsonrpc": "2.0",
                    "id": req_id,
                    "error": {"code": -32601, "message": f"Unknown tool: {name}"},
                }
        else:
            return {
                "jsonrpc": "2.0",
                "id": req_id,
                "error": {"code": -32601, "message": f"Method not found: {method}"},
            }

    def main():
        for line in sys.stdin:
            line = line.strip()
            if not line:
                continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue
            # 通知消息(无 id)不回复
            if "id" not in data:
                continue
            response = handle_request(data)
            sys.stdout.write(json.dumps(response) + "\\n")
            sys.stdout.flush()

    if __name__ == "__main__":
        main()
""")
|
||
|
||
# Mock server that answers the "initialize" request and then exits with
# status 0 (used to test detection of a subprocess that has exited).
# Messages without an "id" and requests for any other method get no reply.
EXIT_AFTER_INIT_SCRIPT = textwrap.dedent("""\
    import sys
    import json

    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue
        try:
            data = json.loads(line)
        except json.JSONDecodeError:
            continue
        if "id" not in data:
            continue
        method = data.get("method", "")
        req_id = data.get("id")
        if method == "initialize":
            response = {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "exit-server", "version": "0.1.0"}},
            }
            sys.stdout.write(json.dumps(response) + "\\n")
            sys.stdout.flush()
            # 初始化后立即退出
            sys.exit(0)
        else:
            # 不会到达这里
            pass
""")
|
||
|
||
# Mock server that emits a notification (used to test notification delivery).
# On "tools/list" it first writes a "notifications/tools/list_changed"
# message (no "id") and then the actual response; "initialize" gets a normal
# result and every other method gets a -32601 "Method not found" error.
NOTIFICATION_SERVER_SCRIPT = textwrap.dedent("""\
    import sys
    import json

    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue
        try:
            data = json.loads(line)
        except json.JSONDecodeError:
            continue
        if "id" not in data:
            continue
        method = data.get("method", "")
        req_id = data.get("id")
        if method == "initialize":
            response = {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "notif-server", "version": "0.1.0"}},
            }
            sys.stdout.write(json.dumps(response) + "\\n")
            sys.stdout.flush()
        elif method == "tools/list":
            # 先发送一个通知
            notification = {
                "jsonrpc": "2.0",
                "method": "notifications/tools/list_changed",
            }
            sys.stdout.write(json.dumps(notification) + "\\n")
            sys.stdout.flush()
            # 再发送响应
            response = {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {"tools": [{"name": "updated_tool"}]},
            }
            sys.stdout.write(json.dumps(response) + "\\n")
            sys.stdout.flush()
        else:
            response = {
                "jsonrpc": "2.0",
                "id": req_id,
                "error": {"code": -32601, "message": "Method not found"},
            }
            sys.stdout.write(json.dumps(response) + "\\n")
            sys.stdout.flush()
""")
|
||
|
||
# Mock server that writes to stderr.
# While handling "initialize" it writes "mock server starting" to stderr
# before sending the response; every other request gets an empty result.
STDERR_SERVER_SCRIPT = textwrap.dedent("""\
    import sys
    import json

    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue
        try:
            data = json.loads(line)
        except json.JSONDecodeError:
            continue
        if "id" not in data:
            continue
        method = data.get("method", "")
        req_id = data.get("id")
        if method == "initialize":
            # 写入 stderr
            sys.stderr.write("mock server starting\\n")
            sys.stderr.flush()
            response = {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "stderr-server", "version": "0.1.0"}},
            }
            sys.stdout.write(json.dumps(response) + "\\n")
            sys.stdout.flush()
        else:
            response = {
                "jsonrpc": "2.0",
                "id": req_id,
                "result": {},
            }
            sys.stdout.write(json.dumps(response) + "\\n")
            sys.stdout.flush()
""")
|
||
|
||
|
||
def _make_transport(script: str, **kwargs) -> StdioTransport:
    """Build a StdioTransport whose server is an inline Python script.

    The script is passed to the current interpreter (``sys.executable``) via
    ``-c``; any extra keyword arguments are forwarded to ``StdioTransport``
    unchanged.
    """
    interpreter_args = ["-c", script]
    return StdioTransport(command=sys.executable, args=interpreter_args, **kwargs)
|
||
|
||
|
||
# ── Construction tests ──────────────────────────────────────────


class TestStdioTransportConstruction:
    """Tests for the state of a freshly constructed StdioTransport."""

    def test_default_args(self):
        """Only ``command`` is given: every other attribute holds its default."""
        stdio = StdioTransport(command="echo")

        # Constructor arguments and their defaults.
        assert stdio._command == "echo"
        assert stdio._args == []
        assert stdio._env is None
        assert stdio._timeout == 30.0

        # Runtime state before any connection attempt.
        assert stdio._process is None
        assert stdio._request_id == 0
        assert stdio._pending == {}
        assert stdio._reader_task is None
        assert stdio._stderr_task is None
        assert stdio._connected is False

    def test_custom_args(self):
        """Explicit constructor arguments are stored as given."""
        stdio = StdioTransport(
            command="node",
            args=["server.js", "--port", "3000"],
            env={"NODE_ENV": "test"},
            timeout=10.0,
        )

        expected_args = ["server.js", "--port", "3000"]
        expected_env = {"NODE_ENV": "test"}

        assert stdio._command == "node"
        assert stdio._args == expected_args
        assert stdio._env == expected_env
        assert stdio._timeout == 10.0

    def test_is_connected_initially_false(self):
        """A transport that was never connected reports itself as disconnected."""
        stdio = StdioTransport(command="echo")
        assert not stdio.is_connected
|
||
|
||
|
||
# ── Connect / disconnect tests ──────────────────────────────────────────


class TestStdioTransportConnect:
    """Tests for StdioTransport.connect()."""

    async def test_connect_starts_subprocess(self):
        """connect() leaves a running child process and both reader tasks set."""
        stdio = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await stdio.connect()
            assert stdio.is_connected

            child = stdio._process
            assert child is not None
            assert child.returncode is None

            assert stdio._reader_task is not None
            assert stdio._stderr_task is not None
        finally:
            await stdio.disconnect()

    async def test_connect_completes_initialize_handshake(self):
        """A successful connect() implies the initialize handshake finished."""
        stdio = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await stdio.connect()
            # connect() returning without error means initialize succeeded.
            assert stdio._connected is True
            # The initialize request must have consumed at least one id.
            assert stdio._request_id >= 1
        finally:
            await stdio.disconnect()

    async def test_connect_is_idempotent(self):
        """Calling connect() on an already connected transport must not raise."""
        stdio = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            for _ in range(2):
                await stdio.connect()
            assert stdio.is_connected
        finally:
            await stdio.disconnect()

    async def test_connect_with_invalid_command_raises(self):
        """A command that cannot be started surfaces as TransportError."""
        stdio = StdioTransport(command="/nonexistent/command")
        with pytest.raises(TransportError, match="Failed to start process"):
            await stdio.connect()

    async def test_connect_timeout(self):
        """connect() times out when the child never answers initialize."""
        # The child only consumes stdin and never writes anything back.
        silent_script = "import sys; sys.stdin.read()"
        stdio = _make_transport(silent_script, timeout=0.5)
        with pytest.raises(TransportError, match="Timeout waiting for initialize"):
            await stdio.connect()
        assert not stdio.is_connected
|
||
|
||
|
||
class TestStdioTransportDisconnect:
    """Tests for StdioTransport.disconnect()."""

    async def test_disconnect_closes_subprocess(self):
        """disconnect() resets the process handle and both reader tasks."""
        stdio = _make_transport(MOCK_SERVER_SCRIPT)
        await stdio.connect()
        assert stdio.is_connected

        await stdio.disconnect()

        assert not stdio.is_connected
        assert stdio._process is None
        assert stdio._reader_task is None
        assert stdio._stderr_task is None

    async def test_disconnect_is_idempotent(self):
        """A second disconnect() after the first one must not raise."""
        stdio = _make_transport(MOCK_SERVER_SCRIPT)
        await stdio.connect()
        for _ in range(2):
            await stdio.disconnect()

    async def test_disconnect_cancels_pending_futures(self):
        """Every pending future receives a TransportError on disconnect."""
        stdio = _make_transport(MOCK_SERVER_SCRIPT)
        await stdio.connect()

        # Register a pending future by hand, as if a request were in flight.
        orphan = asyncio.get_running_loop().create_future()
        stdio._pending[999] = orphan

        await stdio.disconnect()

        assert orphan.done()
        with pytest.raises(TransportError, match="Transport disconnected"):
            orphan.result()

    async def test_disconnect_clears_pending(self):
        """The pending-request table is empty after disconnect()."""
        stdio = _make_transport(MOCK_SERVER_SCRIPT)
        await stdio.connect()
        running_loop = asyncio.get_running_loop()
        stdio._pending[1] = running_loop.create_future()
        await stdio.disconnect()
        assert stdio._pending == {}
|
||
|
||
|
||
# ── Request sending tests ──────────────────────────────────────────


class TestStdioTransportSendRequest:
    """Tests for StdioTransport.send_request()."""

    async def test_send_request_not_connected_raises(self):
        """Sending on a transport that was never connected raises."""
        stdio = StdioTransport(command="echo")
        with pytest.raises(TransportError, match="not connected"):
            await stdio.send_request("tools/list")

    async def test_send_request_tools_list(self):
        """tools/list returns the single "echo" tool of the mock server."""
        stdio = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await stdio.connect()
            reply = await stdio.send_request("tools/list")
            assert "tools" in reply
            tools = reply["tools"]
            assert len(tools) == 1
            assert tools[0]["name"] == "echo"
        finally:
            await stdio.disconnect()

    async def test_send_request_tools_call(self):
        """tools/call on "echo" returns the message that was sent."""
        stdio = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await stdio.connect()
            call_params = {"name": "echo", "arguments": {"msg": "hello world"}}
            reply = await stdio.send_request("tools/call", params=call_params)
            assert "content" in reply
            assert reply["content"][0]["text"] == "hello world"
        finally:
            await stdio.disconnect()

    async def test_send_request_json_rpc_error(self):
        """A JSON-RPC error response is raised as TransportError."""
        stdio = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await stdio.connect()
            with pytest.raises(TransportError, match="JSON-RPC error"):
                await stdio.send_request("unknown/method")
        finally:
            await stdio.disconnect()

    async def test_send_request_unknown_tool_error(self):
        """Calling a tool the server does not know raises with its message."""
        stdio = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await stdio.connect()
            call_params = {"name": "nonexistent", "arguments": {}}
            with pytest.raises(TransportError, match="Unknown tool"):
                await stdio.send_request("tools/call", params=call_params)
        finally:
            await stdio.disconnect()

    async def test_request_id_increments(self):
        """Each request advances the request id by exactly one."""
        stdio = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await stdio.connect()
            # connect() has already consumed an id for the initialize request.
            seen_ids = [stdio._request_id]
            for _ in range(2):
                await stdio.send_request("tools/list")
                seen_ids.append(stdio._request_id)

            baseline, after_first, after_second = seen_ids
            assert after_first == baseline + 1
            assert after_second == baseline + 2
        finally:
            await stdio.disconnect()

    async def test_send_request_timeout(self):
        """A request the server never answers fails with a timeout error."""
        # The child only consumes stdin and never writes anything back.
        silent_script = "import sys; sys.stdin.read()"
        stdio = _make_transport(silent_script, timeout=0.5)

        # Wire up the connection state by hand to bypass the initialize
        # handshake, which this silent child would never complete.
        stdio._connected = True
        pipe = asyncio.subprocess.PIPE
        stdio._process = await asyncio.create_subprocess_exec(
            sys.executable,
            "-c",
            silent_script,
            stdin=pipe,
            stdout=pipe,
            stderr=pipe,
        )
        stdio._reader_task = asyncio.create_task(stdio._read_stdout())
        stdio._stderr_task = asyncio.create_task(stdio._read_stderr())

        try:
            with pytest.raises(TransportError, match="Timeout"):
                await stdio.send_request("tools/list")
        finally:
            stdio._connected = False
            await stdio._cleanup()
|
||
|
||
|
||
# ── Concurrent request tests ──────────────────────────────────────────


class TestStdioTransportConcurrentRequests:
    """Tests for several StdioTransport requests in flight at once."""

    async def test_concurrent_requests_correct_id_matching(self):
        """Each concurrent request receives the response that belongs to it."""
        stdio = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await stdio.connect()

            # Issue all three requests at the same time.
            in_flight = [
                stdio.send_request("tools/list"),
                stdio.send_request(
                    "tools/call",
                    params={"name": "echo", "arguments": {"msg": "msg1"}},
                ),
                stdio.send_request(
                    "tools/call",
                    params={"name": "echo", "arguments": {"msg": "msg2"}},
                ),
            ]
            listing, first_echo, second_echo = await asyncio.gather(*in_flight)

            # Every caller must see the response type and payload it asked for.
            assert "tools" in listing
            assert first_echo["content"][0]["text"] == "msg1"
            assert second_echo["content"][0]["text"] == "msg2"
        finally:
            await stdio.disconnect()
|
||
|
||
|
||
# ── Notification receiving tests ──────────────────────────────────────────


class TestStdioTransportNotifications:
    """Tests for receiving server notifications through StdioTransport."""

    async def test_receive_notification(self):
        """A notification sent by the server can be read via receive_response()."""
        stdio = _make_transport(NOTIFICATION_SERVER_SCRIPT)
        try:
            await stdio.connect()

            # This mock server writes a notification before the tools/list reply.
            reply = await stdio.send_request("tools/list")
            assert "tools" in reply

            # Give the notification a moment to arrive.
            await asyncio.sleep(0.1)

            message = await stdio.receive_response()
            assert message["method"] == "notifications/tools/list_changed"
        finally:
            await stdio.disconnect()

    async def test_receive_response_no_notification_raises(self):
        """receive_response() raises when no notification is available."""
        stdio = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await stdio.connect()
            with pytest.raises(TransportError, match="No notification"):
                await stdio.receive_response()
        finally:
            await stdio.disconnect()

    async def test_receive_response_not_connected_raises(self):
        """receive_response() raises on a transport that was never connected."""
        stdio = StdioTransport(command="echo")
        with pytest.raises(TransportError, match="not connected"):
            await stdio.receive_response()
|
||
|
||
|
||
# ── Subprocess exit detection tests ──────────────────────────────────────────


class TestStdioTransportProcessExit:
    """Tests for detecting that the server subprocess has exited."""

    async def test_subprocess_exit_detection(self):
        """is_connected becomes False once the subprocess has exited."""
        transport = _make_transport(EXIT_AFTER_INIT_SCRIPT)
        try:
            await transport.connect()
            assert transport.is_connected

            # Give the subprocess time to exit (it quits right after initialize).
            await asyncio.sleep(0.5)

            # The subprocess is gone, so the transport must report disconnected.
            assert not transport.is_connected
        finally:
            if transport._process is not None:
                await transport.disconnect()

    async def test_send_request_after_process_exit_raises(self):
        """Sending a request after the subprocess exited raises TransportError."""
        transport = _make_transport(EXIT_AFTER_INIT_SCRIPT)
        try:
            await transport.connect()
            # Give the subprocess time to exit (it quits right after initialize).
            await asyncio.sleep(0.5)

            # Assert the precondition instead of guarding on it: a conditional
            # here would let the test pass without checking anything whenever
            # the exit goes undetected.
            assert not transport.is_connected, "subprocess exit was not detected"
            with pytest.raises(TransportError, match="not connected"):
                await transport.send_request("tools/list")
        finally:
            if transport._process is not None:
                await transport.disconnect()
|
||
|
||
|
||
# ── stderr forwarding tests ──────────────────────────────────────────


class TestStdioTransportStderr:
    """Tests for forwarding the subprocess's stderr to the logger."""

    async def test_stderr_forwarded_to_logger(self, caplog):
        """Text the server writes to stderr shows up in the transport's log."""
        import logging

        stdio = _make_transport(STDERR_SERVER_SCRIPT)
        try:
            with caplog.at_level(logging.DEBUG, logger="agentkit.mcp.transport"):
                await stdio.connect()
                # Send one more request, then pause briefly so the stderr
                # output has time to be read.
                await stdio.send_request("tools/list")
                await asyncio.sleep(0.2)

            # Collect the log records that carry the server's stderr line.
            forwarded = []
            for record in caplog.records:
                if "mock server starting" in record.message:
                    forwarded.append(record)
            assert forwarded
        finally:
            await stdio.disconnect()
|
||
|
||
|
||
# ── is_connected property tests ──────────────────────────────────────────


class TestStdioTransportIsConnected:
    """Tests for the StdioTransport.is_connected property."""

    async def test_is_connected_before_connect(self):
        """The property is false before connect() has been called."""
        stdio = _make_transport(MOCK_SERVER_SCRIPT)
        assert not stdio.is_connected

    async def test_is_connected_after_connect(self):
        """The property is true after a successful connect()."""
        stdio = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await stdio.connect()
            assert stdio.is_connected
        finally:
            await stdio.disconnect()

    async def test_is_connected_after_disconnect(self):
        """The property is false again after disconnect()."""
        stdio = _make_transport(MOCK_SERVER_SCRIPT)
        await stdio.connect()
        await stdio.disconnect()
        assert not stdio.is_connected

    async def test_is_connected_checks_process_returncode(self):
        """is_connected is false for a dead process even if _connected is True."""
        stdio = _make_transport(MOCK_SERVER_SCRIPT)
        try:
            await stdio.connect()
            child = stdio._process
            assert child is not None

            # Simulate "process gone, flag still set": force the flag, then
            # kill the child and wait until its return code is available.
            stdio._connected = True
            child.kill()
            await child.wait()

            # The return code is no longer None, so the property must be false.
            assert not stdio.is_connected
        finally:
            stdio._connected = False
            await stdio._cleanup()
|
||
|
||
|
||
# ── Full lifecycle tests ──────────────────────────────────────────


class TestStdioTransportLifecycle:
    """End-to-end lifecycle tests for StdioTransport."""

    async def test_full_lifecycle(self):
        """Exercise the complete connect -> send_request -> disconnect cycle."""
        transport = _make_transport(MOCK_SERVER_SCRIPT)

        # The try/finally matches the other tests in this file: a failing
        # assertion must not leave the server subprocess running.
        try:
            # 1. Connect
            await transport.connect()
            assert transport.is_connected

            # 2. Send a plain request
            result = await transport.send_request("tools/list")
            assert "tools" in result

            # 3. Send a request with parameters
            result = await transport.send_request(
                "tools/call",
                params={"name": "echo", "arguments": {"msg": "test"}},
            )
            assert result["content"][0]["text"] == "test"
        finally:
            # 4. Disconnect
            await transport.disconnect()
        assert not transport.is_connected

    async def test_reconnect_after_disconnect(self):
        """A transport can be connected again after it was disconnected."""
        transport = _make_transport(MOCK_SERVER_SCRIPT)

        try:
            # First connection
            await transport.connect()
            result1 = await transport.send_request("tools/list")
            assert "tools" in result1
            await transport.disconnect()

            # Reconnect
            await transport.connect()
            result2 = await transport.send_request("tools/list")
            assert "tools" in result2
        finally:
            # Runs on success and on failure, so the subprocess of whichever
            # connection is still open gets shut down.
            await transport.disconnect()

    async def test_env_variables_passed_to_subprocess(self):
        """Environment variables given to the transport reach the subprocess."""
        # The server answers tools/call with the value of TEST_MCP_VAR from
        # its own environment ("not_set" when the variable is missing).
        env_script = textwrap.dedent("""\
            import sys
            import json
            import os

            for line in sys.stdin:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if "id" not in data:
                    continue
                method = data.get("method", "")
                req_id = data.get("id")
                if method == "initialize":
                    response = {
                        "jsonrpc": "2.0",
                        "id": req_id,
                        "result": {"protocolVersion": "2024-11-05", "capabilities": {}, "serverInfo": {"name": "env-server", "version": "0.1.0"}},
                    }
                    sys.stdout.write(json.dumps(response) + "\\n")
                    sys.stdout.flush()
                elif method == "tools/call":
                    test_env = os.environ.get("TEST_MCP_VAR", "not_set")
                    response = {
                        "jsonrpc": "2.0",
                        "id": req_id,
                        "result": {
                            "content": [{"type": "text", "text": test_env}]
                        },
                    }
                    sys.stdout.write(json.dumps(response) + "\\n")
                    sys.stdout.flush()
                else:
                    response = {
                        "jsonrpc": "2.0",
                        "id": req_id,
                        "result": {},
                    }
                    sys.stdout.write(json.dumps(response) + "\\n")
                    sys.stdout.flush()
        """)

        transport = StdioTransport(
            command=sys.executable,
            args=["-c", env_script],
            env={"TEST_MCP_VAR": "hello_from_env"},
        )
        try:
            await transport.connect()
            result = await transport.send_request(
                "tools/call",
                params={"name": "check_env", "arguments": {}},
            )
            assert result["content"][0]["text"] == "hello_from_env"
        finally:
            await transport.disconnect()
|