fischer-agentkit/tests/unit/tools/test_terminal_session.py

601 lines
21 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""TerminalSession 和 ShellTool 单元测试
测试场景:
- 跨命令保持 cwd → cd 后执行 pwd 返回正确目录
- 跨命令保持 env → export 后执行 echo 返回正确值
- 危险命令需确认 → rm 命令触发确认回调
- 输出解析 → 错误输出结构化为错误类型+建议
- 无 session_id 时保持现有行为
- 会话管理器功能
"""
from __future__ import annotations
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.tools.terminal_session import TerminalSession, TerminalSessionManager, CommandRecord
from agentkit.tools.shell import ShellTool
from agentkit.tools.output_parser import OutputParser, ParsedOutput, ErrorType
# ============================================================
# OutputParser 测试
# ============================================================
class TestOutputParser:
    """Unit tests for the structured results produced by ``OutputParser.parse``."""

    def setup_method(self):
        # pytest invokes this before every test method, so each test
        # works with its own parser instance.
        self.parser = OutputParser()

    def _assert_classified(self, output, exit_code, expected_type):
        """Parse *output* and require an error result of *expected_type*.

        Returns the parsed result so callers can make further assertions.
        """
        parsed = self.parser.parse(output, exit_code)
        assert parsed.is_error is True
        assert parsed.error_type == expected_type
        return parsed

    def test_parse_success_output(self):
        """A zero exit code with plain output is reported as a non-error."""
        parsed = self.parser.parse("hello world", 0)
        assert parsed.exit_code == 0
        assert parsed.is_error is False
        assert parsed.error_type == ErrorType.NONE
        assert parsed.message == "hello world"
        assert parsed.suggestions == []

    def test_parse_empty_output(self):
        """Empty output with exit code 0 yields an empty, non-error message."""
        parsed = self.parser.parse("", 0)
        assert parsed.exit_code == 0
        assert parsed.is_error is False
        assert parsed.message == ""

    def test_parse_permission_denied(self):
        """Permission errors are classified and come with a ``sudo`` suggestion."""
        parsed = self._assert_classified(
            "permission denied: /root/secret", 1, ErrorType.PERMISSION_DENIED
        )
        assert len(parsed.suggestions) > 0
        assert any("sudo" in hint for hint in parsed.suggestions)

    def test_parse_not_found(self):
        """A missing file or directory is classified as NOT_FOUND."""
        self._assert_classified(
            "No such file or directory: /tmp/missing", 1, ErrorType.NOT_FOUND
        )

    def test_parse_timeout(self):
        """A "timed out" message is classified as TIMEOUT."""
        self._assert_classified("Connection timed out", 1, ErrorType.TIMEOUT)

    def test_parse_syntax_error(self):
        """A shell syntax error message is classified as SYNTAX_ERROR."""
        self._assert_classified(
            "syntax error near unexpected token", 2, ErrorType.SYNTAX_ERROR
        )

    def test_parse_connection_refused(self):
        """A refused connection is classified as CONNECTION_REFUSED."""
        self._assert_classified(
            "Connection refused on port 8080", 1, ErrorType.CONNECTION_REFUSED
        )

    def test_parse_out_of_memory(self):
        """An out-of-memory message is classified as OUT_OF_MEMORY."""
        self._assert_classified(
            "Out of memory: cannot allocate", 1, ErrorType.OUT_OF_MEMORY
        )

    def test_parse_disk_full(self):
        """A "no space left" message is classified as DISK_FULL."""
        self._assert_classified("No space left on device", 1, ErrorType.DISK_FULL)

    def test_parse_already_exists(self):
        """An "already exists" message is classified as ALREADY_EXISTS."""
        self._assert_classified(
            "File already exists: /tmp/test", 1, ErrorType.ALREADY_EXISTS
        )

    def test_parse_invalid_argument(self):
        """An "invalid argument" message is classified as INVALID_ARGUMENT."""
        self._assert_classified(
            "invalid argument: --unknown-flag", 1, ErrorType.INVALID_ARGUMENT
        )

    def test_parse_network_error(self):
        """An unreachable-network message is classified as NETWORK_ERROR."""
        self._assert_classified("Network is unreachable", 1, ErrorType.NETWORK_ERROR)

    def test_parse_exit_code_126(self):
        """Exit code 126 with an unrecognised message maps to PERMISSION_DENIED."""
        self._assert_classified("some unknown error", 126, ErrorType.PERMISSION_DENIED)

    def test_parse_exit_code_127(self):
        """Exit code 127 (command not found) maps to NOT_FOUND."""
        self._assert_classified("some unknown error", 127, ErrorType.NOT_FOUND)

    def test_parse_exit_code_130(self):
        """Exit code 130 (interrupted) is expected to map to TIMEOUT."""
        self._assert_classified("some unknown error", 130, ErrorType.TIMEOUT)

    def test_parse_unknown_error(self):
        """An unrecognised failure falls back to UNKNOWN."""
        self._assert_classified("something went wrong", 1, ErrorType.UNKNOWN)

    def test_parse_long_message_truncated(self):
        """Long output is truncated in ``message``."""
        oversized = "x" * 300
        parsed = self.parser.parse(oversized, 0)
        # Upper bound: 200 characters plus a trailing "...".
        assert len(parsed.message) <= 203

    def test_parsed_output_to_dict(self):
        """``ParsedOutput.to_dict()`` exposes the parsed fields as plain values."""
        as_dict = self.parser.parse("permission denied", 1).to_dict()
        assert as_dict["exit_code"] == 1
        assert as_dict["is_error"] is True
        assert as_dict["error_type"] == "permission_denied"
        assert isinstance(as_dict["suggestions"], list)

    def test_parse_chinese_error_messages(self):
        """A Chinese permission error message is classified as well."""
        self._assert_classified("权限不足: 无法访问", 1, ErrorType.PERMISSION_DENIED)

    def test_parse_multiline_output_message_is_last_line(self):
        """For multi-line output the message is the last line."""
        multiline = "line1\nline2\nline3"
        parsed = self.parser.parse(multiline, 0)
        assert parsed.message == "line3"
# ============================================================
# TerminalSession 测试
# ============================================================
class TestTerminalSession:
    """Unit tests for the state kept by a ``TerminalSession``."""

    @staticmethod
    def _make(**overrides):
        """Build a session with id ``"test"`` plus any extra constructor kwargs."""
        return TerminalSession(session_id="test", **overrides)

    def test_construction_default(self):
        """Defaults: process cwd, a dict env and an empty history."""
        term = self._make()
        assert term.session_id == "test"
        assert term.cwd == os.getcwd()
        assert isinstance(term.env, dict)
        assert term.history == []

    def test_construction_custom_cwd(self):
        """A working directory passed to the constructor is used as-is."""
        term = self._make(cwd="/tmp")
        assert term.cwd == "/tmp"

    def test_construction_custom_env(self):
        """Environment variables passed to the constructor are visible in ``env``."""
        term = self._make(env={"FOO": "bar"})
        assert term.env.get("FOO") == "bar"

    def test_set_cwd(self):
        """``set_cwd`` replaces the working directory."""
        term = self._make()
        term.set_cwd("/usr/local")
        assert term.cwd == "/usr/local"

    def test_set_env(self):
        """``set_env`` stores a single environment variable."""
        term = self._make()
        term.set_env("MY_VAR", "hello")
        assert term.env.get("MY_VAR") == "hello"

    def test_update_env(self):
        """``update_env`` stores several environment variables at once."""
        term = self._make()
        term.update_env({"A": "1", "B": "2"})
        assert term.env.get("A") == "1"
        assert term.env.get("B") == "2"

    def test_get_env_returns_copy(self):
        """Mutating the dict returned by ``get_env`` leaves the session untouched."""
        term = self._make()
        snapshot = term.get_env()
        snapshot["HACKED"] = "yes"
        assert "HACKED" not in term.env

    def test_get_history_returns_copy(self):
        """``get_history`` hands out a different object than the internal list."""
        term = self._make()
        snapshot = term.get_history()
        assert snapshot is not term._history

    @pytest.mark.asyncio
    async def test_execute_simple_command(self):
        """A simple command succeeds and its output is captured."""
        term = self._make()
        outcome = await term.execute("echo hello")
        assert outcome.exit_code == 0
        assert "hello" in outcome.raw_output

    @pytest.mark.asyncio
    async def test_execute_records_history(self):
        """Executed commands are appended to the history in order."""
        term = self._make()
        await term.execute("echo first")
        await term.execute("echo second")
        assert len(term.history) == 2
        assert term.history[0].command == "echo first"
        assert term.history[1].command == "echo second"

    @pytest.mark.asyncio
    async def test_cross_command_cwd(self):
        """cwd persists across commands: ``pwd`` after ``cd`` reports the new dir."""
        term = self._make()
        await term.execute("cd /tmp")
        assert term.cwd == "/tmp"
        outcome = await term.execute("pwd")
        assert "/tmp" in outcome.raw_output

    @pytest.mark.asyncio
    async def test_cross_command_env(self):
        """env persists across commands: ``echo`` after ``export`` sees the value."""
        term = self._make()
        await term.execute("export MY_TEST_VAR=hello123")
        assert term.env.get("MY_TEST_VAR") == "hello123"
        outcome = await term.execute("echo $MY_TEST_VAR")
        assert "hello123" in outcome.raw_output

    @pytest.mark.asyncio
    async def test_cd_relative_path(self):
        """``cd`` with a relative path updates cwd (when the directory exists)."""
        # Start in /usr and cd into "local"; this relies on /usr/local existing.
        term = self._make(cwd="/usr")
        await term.execute("cd local")
        assert term.cwd == "/usr/local"

    @pytest.mark.asyncio
    async def test_cd_absolute_path(self):
        """``cd`` with an absolute path updates cwd."""
        term = self._make()
        await term.execute("cd /usr")
        assert term.cwd == "/usr"

    @pytest.mark.asyncio
    async def test_failed_command_no_state_update(self):
        """A failing command leaves the session state unchanged."""
        term = self._make(cwd="/tmp")
        await term.execute("cd /nonexistent_dir_xyz")
        # The cd failed, so cwd must still be the original directory.
        assert term.cwd == "/tmp"

    @pytest.mark.asyncio
    async def test_timeout(self):
        """A command exceeding its timeout is reported as an error with exit code -1."""
        term = self._make()
        outcome = await term.execute("sleep 10", timeout=0.5)
        assert outcome.exit_code == -1
        assert outcome.is_error is True

    @pytest.mark.asyncio
    async def test_max_history(self):
        """History is capped at ``max_history``; the oldest entries are dropped."""
        term = self._make(max_history=3)
        for n in range(5):
            await term.execute(f"echo {n}")
        assert len(term.history) == 3
        assert term.history[0].command == "echo 2"

    def test_close(self):
        """Closing a session must not raise."""
        term = self._make()
        term.close()  # passes as long as no exception is raised
# ============================================================
# TerminalSessionManager 测试
# ============================================================
class TestTerminalSessionManager:
    """Unit tests for session bookkeeping in ``TerminalSessionManager``."""

    @staticmethod
    def _manager_with(*session_ids, **manager_kwargs):
        """Return a manager on which ``get_or_create`` was called for each id, in order."""
        mgr = TerminalSessionManager(**manager_kwargs)
        for sid in session_ids:
            mgr.get_or_create(sid)
        return mgr

    def test_get_or_create_new(self):
        """``get_or_create`` builds a session carrying the requested id."""
        mgr = TerminalSessionManager()
        created = mgr.get_or_create("s1")
        assert created.session_id == "s1"

    def test_get_or_create_existing(self):
        """A second ``get_or_create`` with the same id sees earlier state changes."""
        mgr = TerminalSessionManager()
        first = mgr.get_or_create("s1")
        first.set_cwd("/tmp")
        again = mgr.get_or_create("s1")
        assert again.cwd == "/tmp"

    def test_get_existing(self):
        """``get`` returns a previously created session."""
        mgr = self._manager_with("s1")
        found = mgr.get("s1")
        assert found is not None

    def test_get_nonexistent(self):
        """``get`` returns None for an unknown id."""
        mgr = TerminalSessionManager()
        assert mgr.get("nonexistent") is None

    def test_remove(self):
        """A removed session can no longer be fetched."""
        mgr = self._manager_with("s1")
        mgr.remove("s1")
        assert mgr.get("s1") is None

    def test_list_sessions(self):
        """``list_sessions`` reports every created session id."""
        mgr = self._manager_with("s1", "s2")
        assert sorted(mgr.list_sessions()) == ["s1", "s2"]

    def test_has_session(self):
        """``has_session`` is True only for ids that were created."""
        mgr = self._manager_with("s1")
        assert mgr.has_session("s1") is True
        assert mgr.has_session("s2") is False

    def test_max_sessions_eviction(self):
        """Going past ``max_sessions`` evicts the oldest session."""
        # Creating s3 with a limit of 2 is expected to push out s1.
        mgr = self._manager_with("s1", "s2", "s3", max_sessions=2)
        assert not mgr.has_session("s1")
        assert mgr.has_session("s2")
        assert mgr.has_session("s3")

    def test_close_all(self):
        """``close_all`` leaves the manager without sessions."""
        mgr = self._manager_with("s1", "s2")
        mgr.close_all()
        assert mgr.list_sessions() == []
# ============================================================
# ShellTool 测试
# ============================================================
class TestShellToolConstruction:
    """Unit tests for constructing a ``ShellTool`` and its metadata."""

    def test_default_construction(self):
        """Default tool is named "shell" and declares its input schema."""
        shell = ShellTool()
        assert shell.name == "shell"
        assert shell.input_schema is not None
        assert "command" in shell.input_schema["properties"]
        assert "session_id" in shell.input_schema["properties"]
        assert shell.input_schema["required"] == ["command"]

    def test_custom_construction(self):
        """Name and version passed to the constructor are kept."""
        shell = ShellTool(name="my_shell", version="2.0.0")
        assert shell.name == "my_shell"
        assert shell.version == "2.0.0"

    def test_to_dict(self):
        """``to_dict`` includes the tool name and the input schema."""
        described = ShellTool().to_dict()
        assert described["name"] == "shell"
        assert "input_schema" in described

    def test_repr(self):
        """``repr`` mentions both the class name and the tool name."""
        text = repr(ShellTool())
        assert "ShellTool" in text
        assert "shell" in text
class TestShellToolExecution:
    """Unit tests for running commands through ``ShellTool.execute``."""

    @pytest.mark.asyncio
    async def test_execute_simple_command(self):
        """A simple command runs without a session."""
        shell = ShellTool()
        outcome = await shell.execute(command="echo hello")
        assert outcome["exit_code"] == 0
        assert "hello" in outcome["output"]
        assert outcome["is_error"] is False
        assert outcome["session_id"] is None

    @pytest.mark.asyncio
    async def test_execute_missing_command(self):
        """Calling ``execute`` without ``command`` yields an error result."""
        shell = ShellTool()
        outcome = await shell.execute()
        assert outcome["is_error"] is True
        assert outcome["exit_code"] == 1

    @pytest.mark.asyncio
    async def test_execute_with_working_dir(self):
        """``working_dir`` determines where the command runs."""
        shell = ShellTool()
        outcome = await shell.execute(command="pwd", working_dir="/tmp")
        assert outcome["exit_code"] == 0
        assert "/tmp" in outcome["output"]

    @pytest.mark.asyncio
    async def test_execute_with_session(self):
        """Passing ``session_id`` runs the command in session mode."""
        shell = ShellTool()
        outcome = await shell.execute(command="echo session_test", session_id="s1")
        assert outcome["exit_code"] == 0
        assert "session_test" in outcome["output"]
        assert outcome["session_id"] == "s1"

    @pytest.mark.asyncio
    async def test_session_preserves_cwd(self):
        """Within one session, a ``cd`` is still in effect for the next command."""
        shell = ShellTool()
        sid = "cwd-test"
        await shell.execute(command="cd /tmp", session_id=sid)
        outcome = await shell.execute(command="pwd", session_id=sid)
        assert "/tmp" in outcome["output"]

    @pytest.mark.asyncio
    async def test_session_preserves_env(self):
        """Within one session, an ``export`` is visible to the next command."""
        shell = ShellTool()
        sid = "env-test"
        await shell.execute(command="export SHELL_TEST_VAR=world", session_id=sid)
        outcome = await shell.execute(command="echo $SHELL_TEST_VAR", session_id=sid)
        assert "world" in outcome["output"]

    @pytest.mark.asyncio
    async def test_no_session_id_backward_compatible(self):
        """Without ``session_id`` the existing (session-less) behaviour is kept."""
        shell = ShellTool()
        outcome = await shell.execute(command="echo no_session")
        assert outcome["exit_code"] == 0
        assert "no_session" in outcome["output"]
        assert outcome["session_id"] is None

    @pytest.mark.asyncio
    async def test_different_sessions_independent(self):
        """Two sessions keep separate working directories."""
        shell = ShellTool()
        await shell.execute(command="cd /tmp", session_id="s1")
        await shell.execute(command="cd /usr", session_id="s2")
        first = await shell.execute(command="pwd", session_id="s1")
        second = await shell.execute(command="pwd", session_id="s2")
        assert "/tmp" in first["output"]
        assert "/usr" in second["output"]
class TestShellToolSecurity:
    """Unit tests for the safety controls of ``ShellTool``."""

    @pytest.mark.asyncio
    async def test_safe_command_allowed(self):
        """A safe command is executed directly."""
        shell = ShellTool()
        outcome = await shell.execute(command="ls /tmp")
        assert outcome["exit_code"] == 0

    @pytest.mark.asyncio
    async def test_dangerous_command_blocked_without_callback(self):
        """Without a confirm callback a dangerous command is refused (exit 126)."""
        shell = ShellTool()
        outcome = await shell.execute(command="rm -rf /tmp/test")
        assert outcome["is_error"] is True
        assert outcome["exit_code"] == 126

    @pytest.mark.asyncio
    async def test_dangerous_command_confirmed(self):
        """A dangerous command goes ahead when the confirm callback approves."""
        approve = AsyncMock(return_value=True)
        shell = ShellTool(confirm_callback=approve)
        outcome = await shell.execute(command="rm -rf /tmp/nonexistent_test_dir")
        assert approve.called
        # The command itself may fail (the directory does not exist), but it
        # must not come back as an exit-126 error from the safety mechanism.
        assert not (outcome["exit_code"] == 126 and outcome["is_error"])

    @pytest.mark.asyncio
    async def test_dangerous_command_rejected_by_callback(self):
        """A dangerous command is refused when the confirm callback declines."""
        decline = AsyncMock(return_value=False)
        shell = ShellTool(confirm_callback=decline)
        outcome = await shell.execute(command="rm -rf /tmp/test")
        assert outcome["is_error"] is True
        assert outcome["exit_code"] == 126

    @pytest.mark.asyncio
    async def test_audit_log_recorded(self):
        """An executed command shows up in the audit log."""
        shell = ShellTool()
        await shell.execute(command="echo audit_test")
        assert len(shell.audit_log) > 0
        assert shell.audit_log[0]["command"] == "echo audit_test"

    @pytest.mark.asyncio
    async def test_blocked_command_in_audit_log(self):
        """A refused command is recorded in the audit log as blocked."""
        shell = ShellTool()
        await shell.execute(command="rm -rf /tmp/test")
        blocked = [entry for entry in shell.audit_log if entry.get("blocked")]
        assert len(blocked) > 0

    @pytest.mark.asyncio
    async def test_git_push_force_is_dangerous(self):
        """``git push --force`` is treated as a dangerous command."""
        shell = ShellTool()
        outcome = await shell.execute(command="git push --force origin main")
        assert outcome["is_error"] is True
        assert outcome["exit_code"] == 126

    @pytest.mark.asyncio
    async def test_git_status_is_safe(self):
        """``git status`` is treated as a safe command."""
        shell = ShellTool()
        outcome = await shell.execute(command="git status")
        # git status may fail outside a git repository, but it must not be
        # refused by the safety mechanism.
        assert outcome["exit_code"] != 126
class TestShellToolOutputParsing:
    """Unit tests for the output-parsing fields in ``ShellTool`` results."""

    @pytest.mark.asyncio
    async def test_error_output_structured(self):
        """A failing command yields a structured error with type and suggestions."""
        shell = ShellTool()
        outcome = await shell.execute(command="ls /nonexistent_dir_xyz_12345")
        assert outcome["is_error"] is True
        assert outcome["error_type"] in ("not_found", "unknown")
        assert isinstance(outcome["suggestions"], list)

    @pytest.mark.asyncio
    async def test_success_output_not_error(self):
        """A successful command is not flagged as an error."""
        shell = ShellTool()
        outcome = await shell.execute(command="echo success")
        assert outcome["is_error"] is False
        assert outcome["error_type"] == "none"
class TestShellToolSessionManager:
    """Unit tests for reaching the session manager through ``ShellTool``."""

    def test_session_manager_accessible(self):
        """The tool exposes a session manager."""
        assert ShellTool().session_manager is not None

    @pytest.mark.asyncio
    async def test_session_created_on_first_use(self):
        """A session is created the first time its ``session_id`` is used."""
        shell = ShellTool()
        sid = "new-session"
        assert not shell.session_manager.has_session(sid)
        await shell.execute(command="echo test", session_id=sid)
        assert shell.session_manager.has_session(sid)