"""TerminalSession、ShellTool 和 OutputParser 单元测试

测试场景:
- 跨命令保持 cwd → cd 后执行 pwd 返回正确目录
- 跨命令保持 env → export 后执行 echo 返回正确值
- 危险命令需确认 → rm 命令触发确认回调
- 输出解析 → 错误输出结构化为错误类型+建议
- 无 session_id 时保持现有行为
- 会话管理器功能
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import os
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from agentkit.tools.terminal_session import TerminalSession, TerminalSessionManager, CommandRecord
|
||
from agentkit.tools.shell import ShellTool
|
||
from agentkit.tools.output_parser import OutputParser, ParsedOutput, ErrorType
|
||
|
||
|
||
# ============================================================
|
||
# OutputParser 测试
|
||
# ============================================================
|
||
|
||
|
||
class TestOutputParser:
    """Unit tests for OutputParser's structured classification of command output."""

    def setup_method(self):
        # Build a fresh parser for every test so cases cannot leak state.
        self.parser = OutputParser()

    def _expect_error(self, text, exit_code, expected_type):
        """Parse ``text``/``exit_code`` and check it is an error of ``expected_type``.

        Returns the parsed result so the caller can make further assertions.
        """
        parsed = self.parser.parse(text, exit_code)
        assert parsed.is_error is True
        assert parsed.error_type == expected_type
        return parsed

    def test_parse_success_output(self):
        """Plain text with exit code 0 is a non-error result with no hints."""
        parsed = self.parser.parse("hello world", 0)
        assert parsed.exit_code == 0
        assert parsed.is_error is False
        assert parsed.error_type == ErrorType.NONE
        assert parsed.message == "hello world"
        assert parsed.suggestions == []

    def test_parse_empty_output(self):
        """Empty output with exit code 0 yields an empty, non-error message."""
        parsed = self.parser.parse("", 0)
        assert parsed.exit_code == 0
        assert parsed.is_error is False
        assert parsed.message == ""

    def test_parse_permission_denied(self):
        """Permission errors are classified and come with a sudo hint."""
        parsed = self._expect_error(
            "permission denied: /root/secret", 1, ErrorType.PERMISSION_DENIED
        )
        assert len(parsed.suggestions) > 0
        assert any("sudo" in hint for hint in parsed.suggestions)

    def test_parse_not_found(self):
        """Missing files are classified as NOT_FOUND."""
        self._expect_error(
            "No such file or directory: /tmp/missing", 1, ErrorType.NOT_FOUND
        )

    def test_parse_timeout(self):
        """Timed-out operations are classified as TIMEOUT."""
        self._expect_error("Connection timed out", 1, ErrorType.TIMEOUT)

    def test_parse_syntax_error(self):
        """Shell syntax errors are classified as SYNTAX_ERROR."""
        self._expect_error(
            "syntax error near unexpected token", 2, ErrorType.SYNTAX_ERROR
        )

    def test_parse_connection_refused(self):
        """Refused connections are classified as CONNECTION_REFUSED."""
        self._expect_error(
            "Connection refused on port 8080", 1, ErrorType.CONNECTION_REFUSED
        )

    def test_parse_out_of_memory(self):
        """Allocation failures are classified as OUT_OF_MEMORY."""
        self._expect_error(
            "Out of memory: cannot allocate", 1, ErrorType.OUT_OF_MEMORY
        )

    def test_parse_disk_full(self):
        """A full device is classified as DISK_FULL."""
        self._expect_error("No space left on device", 1, ErrorType.DISK_FULL)

    def test_parse_already_exists(self):
        """Existing-target errors are classified as ALREADY_EXISTS."""
        self._expect_error(
            "File already exists: /tmp/test", 1, ErrorType.ALREADY_EXISTS
        )

    def test_parse_invalid_argument(self):
        """Bad flags/arguments are classified as INVALID_ARGUMENT."""
        self._expect_error(
            "invalid argument: --unknown-flag", 1, ErrorType.INVALID_ARGUMENT
        )

    def test_parse_network_error(self):
        """Unreachable networks are classified as NETWORK_ERROR."""
        self._expect_error("Network is unreachable", 1, ErrorType.NETWORK_ERROR)

    def test_parse_exit_code_126(self):
        """Exit code 126 maps to PERMISSION_DENIED even for unknown text."""
        self._expect_error("some unknown error", 126, ErrorType.PERMISSION_DENIED)

    def test_parse_exit_code_127(self):
        """Exit code 127 (command not found) maps to NOT_FOUND."""
        self._expect_error("some unknown error", 127, ErrorType.NOT_FOUND)

    def test_parse_exit_code_130(self):
        """Exit code 130 (interrupted) maps to TIMEOUT."""
        self._expect_error("some unknown error", 130, ErrorType.TIMEOUT)

    def test_parse_unknown_error(self):
        """Unrecognised failure text is classified as UNKNOWN."""
        self._expect_error("something went wrong", 1, ErrorType.UNKNOWN)

    def test_parse_long_message_truncated(self):
        """Overly long output is truncated in the summary message."""
        parsed = self.parser.parse("x" * 300, 0)
        # At most 200 characters plus a "..." ellipsis.
        assert len(parsed.message) <= 203

    def test_parsed_output_to_dict(self):
        """ParsedOutput.to_dict() exposes plain, serialisable fields."""
        as_dict = self.parser.parse("permission denied", 1).to_dict()
        assert as_dict["exit_code"] == 1
        assert as_dict["is_error"] is True
        assert as_dict["error_type"] == "permission_denied"
        assert isinstance(as_dict["suggestions"], list)

    def test_parse_chinese_error_messages(self):
        """Chinese-language error text is classified as well."""
        self._expect_error("权限不足: 无法访问", 1, ErrorType.PERMISSION_DENIED)

    def test_parse_multiline_output_message_is_last_line(self):
        """For multi-line output the last line becomes the message."""
        parsed = self.parser.parse("line1\nline2\nline3", 0)
        assert parsed.message == "line3"
|
||
|
||
|
||
# ============================================================
|
||
# TerminalSession 测试
|
||
# ============================================================
|
||
|
||
|
||
class TestTerminalSession:
    """Tests for TerminalSession state handling: cwd, env and command history."""

    def test_construction_default(self):
        """A default session uses the process cwd, a dict env and no history."""
        session = TerminalSession(session_id="test")
        assert session.session_id == "test"
        assert session.cwd == os.getcwd()
        assert isinstance(session.env, dict)
        assert session.history == []

    def test_construction_custom_cwd(self):
        """An explicit cwd is stored as given."""
        session = TerminalSession(session_id="test", cwd="/tmp")
        assert session.cwd == "/tmp"

    def test_construction_custom_env(self):
        """Explicit env entries are visible on the session."""
        session = TerminalSession(session_id="test", env={"FOO": "bar"})
        assert session.env.get("FOO") == "bar"

    def test_set_cwd(self):
        """set_cwd replaces the working directory."""
        session = TerminalSession(session_id="test")
        session.set_cwd("/usr/local")
        assert session.cwd == "/usr/local"

    def test_set_env(self):
        """set_env adds a single environment variable."""
        session = TerminalSession(session_id="test")
        session.set_env("MY_VAR", "hello")
        assert session.env.get("MY_VAR") == "hello"

    def test_update_env(self):
        """update_env merges several environment variables at once."""
        session = TerminalSession(session_id="test")
        session.update_env({"A": "1", "B": "2"})
        assert session.env.get("A") == "1"
        assert session.env.get("B") == "2"

    def test_get_env_returns_copy(self):
        """Mutating the dict returned by get_env must not affect the session."""
        session = TerminalSession(session_id="test")
        env = session.get_env()
        env["HACKED"] = "yes"
        assert "HACKED" not in session.env

    def test_get_history_returns_copy(self):
        """Mutating the list returned by get_history must not affect the session."""
        session = TerminalSession(session_id="test")
        history = session.get_history()
        # Verify copy semantics through the public API (mirroring
        # test_get_env_returns_copy) instead of comparing identity with the
        # private ``_history`` attribute, which proves nothing about whether
        # the caller's list is actually detached.
        history.append("injected")
        assert session.get_history() == []
        assert session.history == []

    @pytest.mark.asyncio
    async def test_execute_simple_command(self):
        """A simple command succeeds and its stdout is captured."""
        session = TerminalSession(session_id="test")
        result = await session.execute("echo hello")
        assert result.exit_code == 0
        assert "hello" in result.raw_output

    @pytest.mark.asyncio
    async def test_execute_records_history(self):
        """Each executed command is appended to the history in order."""
        session = TerminalSession(session_id="test")
        await session.execute("echo first")
        await session.execute("echo second")
        assert len(session.history) == 2
        assert session.history[0].command == "echo first"
        assert session.history[1].command == "echo second"

    @pytest.mark.asyncio
    async def test_cross_command_cwd(self):
        """cwd persists across commands: pwd after cd reports the new dir."""
        session = TerminalSession(session_id="test")
        await session.execute("cd /tmp")
        assert session.cwd == "/tmp"

        result = await session.execute("pwd")
        assert "/tmp" in result.raw_output

    @pytest.mark.asyncio
    async def test_cross_command_env(self):
        """env persists across commands: echo after export sees the value."""
        session = TerminalSession(session_id="test")
        await session.execute("export MY_TEST_VAR=hello123")
        assert session.env.get("MY_TEST_VAR") == "hello123"

        result = await session.execute("echo $MY_TEST_VAR")
        assert "hello123" in result.raw_output

    @pytest.mark.asyncio
    async def test_cd_relative_path(self):
        """A relative cd is resolved against the current cwd."""
        # Relies on /usr/local existing on the test host.
        session = TerminalSession(session_id="test", cwd="/usr")
        await session.execute("cd local")
        assert session.cwd == "/usr/local"

    @pytest.mark.asyncio
    async def test_cd_absolute_path(self):
        """An absolute cd replaces the cwd."""
        session = TerminalSession(session_id="test")
        await session.execute("cd /usr")
        assert session.cwd == "/usr"

    @pytest.mark.asyncio
    async def test_failed_command_no_state_update(self):
        """A failing command leaves the session state unchanged."""
        session = TerminalSession(session_id="test", cwd="/tmp")
        await session.execute("cd /nonexistent_dir_xyz")
        # The cd failed, so the cwd must not have moved.
        assert session.cwd == "/tmp"

    @pytest.mark.asyncio
    async def test_timeout(self):
        """A command exceeding the timeout is reported as an error with code -1."""
        session = TerminalSession(session_id="test")
        result = await session.execute("sleep 10", timeout=0.5)
        assert result.exit_code == -1
        assert result.is_error is True

    @pytest.mark.asyncio
    async def test_max_history(self):
        """History is capped at max_history, keeping the newest entries."""
        session = TerminalSession(session_id="test", max_history=3)
        for i in range(5):
            await session.execute(f"echo {i}")
        assert len(session.history) == 3
        assert session.history[0].command == "echo 2"

    def test_close(self):
        """close() on a fresh session must not raise."""
        session = TerminalSession(session_id="test")
        session.close()
|
||
|
||
|
||
# ============================================================
|
||
# TerminalSessionManager 测试
|
||
# ============================================================
|
||
|
||
|
||
class TestTerminalSessionManager:
    """Tests for TerminalSessionManager's registry of named sessions."""

    def test_get_or_create_new(self):
        """An unknown id produces a new session carrying that id."""
        mgr = TerminalSessionManager()
        created = mgr.get_or_create("s1")
        assert created.session_id == "s1"

    def test_get_or_create_existing(self):
        """A known id returns the existing session with its state intact."""
        mgr = TerminalSessionManager()
        first = mgr.get_or_create("s1")
        first.set_cwd("/tmp")
        again = mgr.get_or_create("s1")
        assert again.cwd == "/tmp"

    def test_get_existing(self):
        """get() finds a session that was created earlier."""
        mgr = TerminalSessionManager()
        mgr.get_or_create("s1")
        found = mgr.get("s1")
        assert found is not None

    def test_get_nonexistent(self):
        """get() returns None for an id that was never created."""
        mgr = TerminalSessionManager()
        assert mgr.get("nonexistent") is None

    def test_remove(self):
        """remove() makes the session unreachable via get()."""
        mgr = TerminalSessionManager()
        mgr.get_or_create("s1")
        mgr.remove("s1")
        assert mgr.get("s1") is None

    def test_list_sessions(self):
        """list_sessions() reports every registered id."""
        mgr = TerminalSessionManager()
        for sid in ("s1", "s2"):
            mgr.get_or_create(sid)
        assert sorted(mgr.list_sessions()) == ["s1", "s2"]

    def test_has_session(self):
        """has_session() distinguishes known from unknown ids."""
        mgr = TerminalSessionManager()
        mgr.get_or_create("s1")
        assert mgr.has_session("s1") is True
        assert mgr.has_session("s2") is False

    def test_max_sessions_eviction(self):
        """Exceeding max_sessions evicts the oldest session."""
        mgr = TerminalSessionManager(max_sessions=2)
        # Creating the third session should push out "s1".
        for sid in ("s1", "s2", "s3"):
            mgr.get_or_create(sid)
        assert not mgr.has_session("s1")
        assert mgr.has_session("s2")
        assert mgr.has_session("s3")

    def test_close_all(self):
        """close_all() empties the registry."""
        mgr = TerminalSessionManager()
        for sid in ("s1", "s2"):
            mgr.get_or_create(sid)
        mgr.close_all()
        assert mgr.list_sessions() == []
|
||
|
||
|
||
# ============================================================
|
||
# ShellTool 测试
|
||
# ============================================================
|
||
|
||
|
||
class TestShellToolConstruction:
    """Tests for ShellTool construction, serialisation and repr."""

    def test_default_construction(self):
        """The default tool is named "shell" and exposes the expected schema."""
        tool = ShellTool()
        assert tool.name == "shell"
        schema = tool.input_schema
        assert schema is not None
        props = schema["properties"]
        assert "command" in props
        assert "session_id" in props
        assert schema["required"] == ["command"]

    def test_custom_construction(self):
        """Name and version passed to the constructor are kept."""
        custom = ShellTool(name="my_shell", version="2.0.0")
        assert custom.name == "my_shell"
        assert custom.version == "2.0.0"

    def test_to_dict(self):
        """to_dict() includes the name and the input schema."""
        payload = ShellTool().to_dict()
        assert payload["name"] == "shell"
        assert "input_schema" in payload

    def test_repr(self):
        """repr() mentions both the class and the tool name."""
        text = repr(ShellTool())
        assert "ShellTool" in text
        assert "shell" in text
|
||
|
||
|
||
class TestShellToolExecution:
    """Tests for ShellTool command execution, with and without sessions."""

    @pytest.mark.asyncio
    async def test_execute_simple_command(self):
        """Without a session_id a command runs statelessly and succeeds."""
        tool = ShellTool()
        reply = await tool.execute(command="echo hello")
        assert reply["exit_code"] == 0
        assert "hello" in reply["output"]
        assert reply["is_error"] is False
        assert reply["session_id"] is None

    @pytest.mark.asyncio
    async def test_execute_missing_command(self):
        """Omitting the command argument yields an error reply."""
        tool = ShellTool()
        reply = await tool.execute()
        assert reply["is_error"] is True
        assert reply["exit_code"] == 1

    @pytest.mark.asyncio
    async def test_execute_with_working_dir(self):
        """working_dir controls where a stateless command runs."""
        tool = ShellTool()
        reply = await tool.execute(command="pwd", working_dir="/tmp")
        assert reply["exit_code"] == 0
        assert "/tmp" in reply["output"]

    @pytest.mark.asyncio
    async def test_execute_with_session(self):
        """With a session_id the reply echoes that id back."""
        tool = ShellTool()
        reply = await tool.execute(command="echo session_test", session_id="s1")
        assert reply["exit_code"] == 0
        assert "session_test" in reply["output"]
        assert reply["session_id"] == "s1"

    @pytest.mark.asyncio
    async def test_session_preserves_cwd(self):
        """A cd inside a session affects later commands in that session."""
        tool = ShellTool()
        sid = "cwd-test"
        await tool.execute(command="cd /tmp", session_id=sid)
        reply = await tool.execute(command="pwd", session_id=sid)
        assert "/tmp" in reply["output"]

    @pytest.mark.asyncio
    async def test_session_preserves_env(self):
        """An export inside a session is visible to later commands."""
        tool = ShellTool()
        sid = "env-test"
        await tool.execute(command="export SHELL_TEST_VAR=world", session_id=sid)
        reply = await tool.execute(command="echo $SHELL_TEST_VAR", session_id=sid)
        assert "world" in reply["output"]

    @pytest.mark.asyncio
    async def test_no_session_id_backward_compatible(self):
        """Omitting session_id keeps the original stateless behaviour."""
        tool = ShellTool()
        reply = await tool.execute(command="echo no_session")
        assert reply["exit_code"] == 0
        assert "no_session" in reply["output"]
        assert reply["session_id"] is None

    @pytest.mark.asyncio
    async def test_different_sessions_independent(self):
        """State changes in one session do not leak into another."""
        tool = ShellTool()
        cwd_by_session = {"s1": "/tmp", "s2": "/usr"}

        for sid, path in cwd_by_session.items():
            await tool.execute(command=f"cd {path}", session_id=sid)

        replies = {}
        for sid in cwd_by_session:
            replies[sid] = await tool.execute(command="pwd", session_id=sid)

        for sid, path in cwd_by_session.items():
            assert path in replies[sid]["output"]
|
||
|
||
|
||
class TestShellToolSecurity:
    """Tests for ShellTool's dangerous-command gate and audit logging."""

    @pytest.mark.asyncio
    async def test_safe_command_allowed(self):
        """A harmless command runs without needing confirmation."""
        tool = ShellTool()
        result = await tool.execute(command="ls /tmp")
        assert result["exit_code"] == 0

    @pytest.mark.asyncio
    async def test_dangerous_command_blocked_without_callback(self):
        """Without a confirm callback a dangerous command is refused (126)."""
        tool = ShellTool()
        result = await tool.execute(command="rm -rf /tmp/test")
        assert result["is_error"] is True
        assert result["exit_code"] == 126

    @pytest.mark.asyncio
    async def test_dangerous_command_confirmed(self):
        """A dangerous command runs once the confirm callback approves it."""
        confirm = AsyncMock(return_value=True)
        tool = ShellTool(confirm_callback=confirm)
        result = await tool.execute(command="rm -rf /tmp/nonexistent_test_dir")
        # The callback must actually be awaited, not merely called.
        confirm.assert_awaited_once()
        # An approved command must not be recorded as blocked.
        assert not [entry for entry in tool.audit_log if entry.get("blocked")]
        # ``rm -f`` exits 0 when its operand does not exist, so the approved
        # command itself succeeds; this also rules out the 126 refusal code.
        assert result["exit_code"] == 0
        assert result["is_error"] is False

    @pytest.mark.asyncio
    async def test_dangerous_command_rejected_by_callback(self):
        """A confirm callback returning False refuses the command (126)."""
        confirm = AsyncMock(return_value=False)
        tool = ShellTool(confirm_callback=confirm)
        result = await tool.execute(command="rm -rf /tmp/test")
        # The refusal must come from consulting the callback.
        confirm.assert_awaited_once()
        assert result["is_error"] is True
        assert result["exit_code"] == 126

    @pytest.mark.asyncio
    async def test_audit_log_recorded(self):
        """Executed commands are appended to the audit log."""
        tool = ShellTool()
        await tool.execute(command="echo audit_test")
        assert len(tool.audit_log) > 0
        assert tool.audit_log[0]["command"] == "echo audit_test"

    @pytest.mark.asyncio
    async def test_blocked_command_in_audit_log(self):
        """Refused commands are recorded in the audit log as blocked."""
        tool = ShellTool()
        await tool.execute(command="rm -rf /tmp/test")
        blocked_entries = [e for e in tool.audit_log if e.get("blocked")]
        assert len(blocked_entries) > 0

    @pytest.mark.asyncio
    async def test_git_push_force_is_dangerous(self):
        """git push --force is treated as dangerous and refused by default."""
        tool = ShellTool()
        result = await tool.execute(command="git push --force origin main")
        assert result["is_error"] is True
        assert result["exit_code"] == 126

    @pytest.mark.asyncio
    async def test_git_status_is_safe(self):
        """git status is not gated by the security check."""
        tool = ShellTool()
        result = await tool.execute(command="git status")
        # git status may fail outside a repository (or if git is missing),
        # but it must not be refused by the security layer.
        assert result["exit_code"] != 126
|
||
|
||
|
||
class TestShellToolOutputParsing:
    """Tests that ShellTool replies carry OutputParser's structured fields."""

    @pytest.mark.asyncio
    async def test_error_output_structured(self):
        """A failing command yields an error type and a suggestion list."""
        reply = await ShellTool().execute(command="ls /nonexistent_dir_xyz_12345")
        assert reply["is_error"] is True
        assert reply["error_type"] in ("not_found", "unknown")
        assert isinstance(reply["suggestions"], list)

    @pytest.mark.asyncio
    async def test_success_output_not_error(self):
        """A successful command is reported with error_type "none"."""
        reply = await ShellTool().execute(command="echo success")
        assert reply["is_error"] is False
        assert reply["error_type"] == "none"
|
||
|
||
|
||
class TestShellToolSessionManager:
    """Tests for ShellTool's exposed terminal session manager."""

    def test_session_manager_accessible(self):
        """Every ShellTool exposes a session manager."""
        assert ShellTool().session_manager is not None

    @pytest.mark.asyncio
    async def test_session_created_on_first_use(self):
        """The first command with a new session_id registers that session."""
        tool = ShellTool()
        registry = tool.session_manager
        sid = "new-session"
        assert not registry.has_session(sid)
        await tool.execute(command="echo test", session_id=sid)
        assert registry.has_session(sid)
|