"""ShellTool - Shell 命令执行工具 支持无会话模式(向后兼容)和有会话模式(跨命令保持状态)。 危险命令通过确认回调请求人工确认,所有操作记录审计日志。 """ from __future__ import annotations import asyncio import logging import os import re import shlex import time import uuid from collections import deque from typing import Callable, Awaitable from agentkit.tools.base import Tool from agentkit.tools.output_parser import OutputParser, ParsedOutput from agentkit.tools.terminal_session import TerminalSessionManager from agentkit.tools.pty_session import PTYSession logger = logging.getLogger(__name__) # 安全白名单:这些命令前缀不需要确认 _SAFE_COMMAND_PREFIXES: tuple[str, ...] = ( # 文件浏览 "cd", "ls", "pwd", "find", "file", "stat", "tree", "du", "wc", # 文本查看/处理 "cat", "head", "tail", "less", "more", "grep", "egrep", "fgrep", "sort", "uniq", "diff", "comm", "cut", "tr", "awk", "sed", # sed without -i is read-only; -i is caught by operator check "tee", "xargs", # 输出 "echo", "printf", # 系统信息 "whoami", "id", "date", "uname", "hostname", "uptime", "df", "free", "top", "htop", "ps", "env", "printenv", "type", "which", "whereis", "arch", "lscpu", "lsblk", "mount", # 网络(只读查询) "curl", # curl GET is safe; POST/PUT caught by operator check "wget", # wget download is generally safe "ping", "traceroute", "dig", "nslookup", "host", "ifconfig", "ip", "netstat", "ss", "lsof", # 版本/帮助 "python --version", "python3 --version", "node --version", "npm --version", "npm list", "npm view", "npm info", "npm search", "npx --version", "pip list", "pip show", "pip search", "java --version", "go version", "rustc --version", "cargo --version", "git --version", "docker --version", "docker ps", "docker images", "docker logs", "docker inspect", # Git 只读 "git status", "git log", "git diff", "git branch", "git remote", "git show", "git tag", "git stash list", "git config --list", # 包管理器(只读查询) "brew list", "brew info", "brew search", "apt list", "apt search", "apt show", "yum list", "yum search", "yum info", # 其他安全命令 "export", "sleep", "true", "false", "test", "seq", "basename", "dirname", "realpath", "readlink", "md5sum", "sha256sum", "shasum", "openssl", # openssl dgst / rand are safe # tar 解压查看 "tar -tf", "tar --list", "zipinfo", "unzip -l", ) # 危险命令检测 — 基于精确 token 匹配,避免子串误判 # 总是危险的二进制命令(无论参数) _DANGEROUS_BINARIES: frozenset[str] = frozenset( { "rm", "rmdir", "mkfs", "dd", "format", "shutdown", "reboot", "halt", "killall", "chown", "fdisk", "parted", } ) # 需要特定参数才危险的二进制命令:binary → 危险 flag/子命令集合 _DANGEROUS_BINARY_FLAGS: dict[str, set[str]] = { "rm": {"-rf", "-fr", "-r", "-f"}, "kill": {"-9", "-kill"}, "chmod": {"777", "000"}, "git": {"push --force", "push -f", "reset --hard", "clean -f"}, "pip": {"uninstall"}, "npm": {"uninstall"}, "docker": {"rm", "rmi", "system prune"}, } # 跨 token 的危险模式(编译后的正则) _DANGEROUS_ARG_PATTERNS: list[re.Pattern[str]] = [ re.compile(r">\s*/dev/", re.IGNORECASE), re.compile(r">\s*/etc/", re.IGNORECASE), re.compile(r"drop\s+table", re.IGNORECASE), re.compile(r"drop\s+database", re.IGNORECASE), re.compile(r"truncate\s+table", re.IGNORECASE), # curl/wget data exfiltration: POST/PUT/upload flags re.compile( r"\bcurl\b.*(-X\s*(POST|PUT|PATCH|DELETE)|--data|--data-binary|--data-raw|--data-urlencode|-d\s|--post\d)", re.IGNORECASE, ), re.compile(r"\bwget\b.*(--post-data|--post-file)", re.IGNORECASE), ] _SHELL_PIPE_OPERATORS = re.compile(r"\|") _SHELL_CHAIN_OPERATORS = re.compile(r"[;&]|\|\||&&|\$\(|\$\{|`|\$<|>|<|\n") class ShellTool(Tool): """Shell 命令执行工具 支持两种模式: 1. 无会话模式(默认):每次命令独立执行,不保持状态 2. 有会话模式:通过 session_id 指定会话,跨命令保持 cwd/env/history 安全控制: - 危险命令通过 confirm_callback 请求人工确认 - 所有操作记录审计日志 Usage: # 无会话模式 tool = ShellTool() result = await tool.execute(command="ls -la") # 有会话模式 result = await tool.execute(command="cd /tmp", session_id="build-01") result = await tool.execute(command="pwd", session_id="build-01") # 输出 /tmp """ def __init__( self, name: str = "shell", description: str = "执行 Shell 命令,支持会话模式保持跨命令状态", input_schema: dict[str, object] | None = None, output_schema: dict[str, object] | None = None, version: str = "1.0.0", tags: list[str] | None = None, confirm_callback: Callable[[str], Awaitable[bool]] | None = None, default_timeout: float = 60.0, max_output_length: int = 50000, ): super().__init__( name=name, description=description, input_schema=input_schema or self._default_input_schema(), output_schema=output_schema or self._default_output_schema(), version=version, tags=tags or ["shell", "terminal", "system"], ) self._session_manager = TerminalSessionManager() self._output_parser = OutputParser() self._confirm_callback = confirm_callback self._default_timeout = default_timeout self._max_output_length = max_output_length self._audit_log: deque[dict[str, object]] = deque(maxlen=10000) @staticmethod def _default_input_schema() -> dict[str, object]: return { "type": "object", "properties": { "command": { "type": "string", "description": "要执行的 Shell 命令", }, "timeout": { "type": "number", "description": "超时时间(秒),默认 60", "default": 60, }, "working_dir": { "type": "string", "description": "工作目录(仅无会话模式有效)", }, "session_id": { "type": "string", "description": "会话 ID,指定后在会话中执行命令,跨命令保持状态", }, "interactive": { "type": "boolean", "description": "是否使用交互式模式(PTY),用于需要用户输入的命令", "default": False, }, }, "required": ["command"], } @staticmethod def _default_output_schema() -> dict[str, object]: return { "type": "object", "properties": { "output": {"type": "string", "description": "命令输出"}, "exit_code": {"type": "integer", "description": "退出码"}, "is_error": {"type": "boolean", "description": "是否为错误"}, "error_type": {"type": "string", "description": "错误类型"}, "message": {"type": "string", "description": "消息摘要"}, "suggestions": { "type": "array", "items": {"type": "string"}, "description": "可操作建议", }, "session_id": {"type": "string", "description": "会话 ID(仅会话模式)"}, }, } async def execute(self, **kwargs) -> dict: """执行 Shell 命令 Args: command: 要执行的命令(必需) timeout: 超时时间(秒) working_dir: 工作目录(仅无会话模式) session_id: 会话 ID(启用会话模式) interactive: 是否使用交互式模式 Returns: 包含 output, exit_code, is_error 等字段的字典 """ command = kwargs.get("command") if not command: return { "output": "", "exit_code": 1, "is_error": True, "error_type": "invalid_argument", "message": "command 参数是必需的", "suggestions": ["提供要执行的 Shell 命令"], } timeout = kwargs.get("timeout", self._default_timeout) working_dir = kwargs.get("working_dir") session_id = kwargs.get("session_id") interactive = kwargs.get("interactive", False) # 安全检查:危险命令需要确认(除非已通过 _skip_dangerous_check 跳过) skip_dangerous = kwargs.get("_skip_dangerous_check", False) if not skip_dangerous and self._is_dangerous(command): confirmed = await self._request_confirmation(command) if not confirmed: self._log_audit(command, None, blocked=True) # 返回确认请求,由上层(ReActEngine/chat)处理 confirmation_id = str(uuid.uuid4()) return { "needs_confirmation": True, "confirmation_id": confirmation_id, "command": command[:500], "reason": "此命令被识别为潜在危险操作,需要用户确认", "suggestions": [ "确认执行此命令", "拒绝执行此命令", ], "output": f"命令被拒绝: {command[:100]}", "exit_code": 126, "is_error": True, } # 根据模式执行 if session_id: result = await self._execute_in_session( command, session_id, timeout, working_dir, interactive ) else: result = await self._execute_standalone(command, timeout, working_dir, interactive) # 审计日志 self._log_audit(command, session_id, exit_code=result.exit_code) # 截断过长输出 output = result.raw_output if len(output) > self._max_output_length: output = output[: self._max_output_length] + "\n... [输出已截断]" # Ensure non-empty output for successful commands (all execution modes) if result.exit_code == 0 and not output.strip(): from agentkit.core.fallback import SHELL_NO_OUTPUT output = SHELL_NO_OUTPUT return { "output": output, "exit_code": result.exit_code, "is_error": result.is_error, "error_type": result.error_type.value, "message": result.message, "suggestions": result.suggestions, "session_id": session_id, } async def _execute_standalone( self, command: str, timeout: float, working_dir: str | None, interactive: bool, ) -> ParsedOutput: """无会话模式执行命令(向后兼容)""" if interactive: return await self._execute_with_pty(command, timeout, working_dir) try: proc = await asyncio.create_subprocess_shell( command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT, cwd=working_dir, ) try: stdout, _ = await asyncio.wait_for( proc.communicate(), timeout=timeout, ) except asyncio.TimeoutError: try: proc.kill() except ProcessLookupError: logger.debug("Process already exited before kill()") except OSError: logger.debug("OSError killing process") await proc.wait() output = f"命令执行超时({timeout}s)" exit_code = proc.returncode if proc.returncode is not None else -1 else: output = stdout.decode("utf-8", errors="replace") if stdout else "" exit_code = proc.returncode if proc.returncode is not None else 0 except Exception as e: output = str(e) exit_code = -1 return self._output_parser.parse(output, exit_code) async def _execute_in_session( self, command: str, session_id: str, timeout: float, working_dir: str | None, interactive: bool, ) -> ParsedOutput: """会话模式执行命令""" session = self._session_manager.get_or_create( session_id, cwd=working_dir, ) if interactive: return await self._execute_with_pty(command, timeout, session.cwd, session.env) return await session.execute(command, timeout=timeout) async def _execute_with_pty( self, command: str, timeout: float, cwd: str | None = None, env: dict[str, str] | None = None, ) -> ParsedOutput: """使用 PTY 执行交互式命令""" pty = PTYSession() try: await pty.start() result = await pty.run_command( command, timeout=timeout, cwd=cwd, env=env, ) output = result.output exit_code = result.exit_code except Exception as e: output = str(e) exit_code = -1 finally: await pty.close() return self._output_parser.parse(output, exit_code) @staticmethod def _is_dangerous(command: str) -> bool: """检查命令是否为危险操作 白名单命令直接放行。管道命令(|)在所有子命令都安全时放行。 其他链式操作符(;、&&、||、$()、>、< 等)一律视为危险。 Static so callers without a ShellTool instance (e.g. PhasePolicy) can reuse the same danger classification. Instance calls still work via Python's descriptor protocol. """ command_stripped = command.strip() # Check for dangerous chain operators (;, &&, ||, $(), backticks, redirections, newlines) if _SHELL_CHAIN_OPERATORS.search(command_stripped): return True # Handle pipe commands: split and check each sub-command if _SHELL_PIPE_OPERATORS.search(command_stripped): parts = command_stripped.split("|") for part in parts: part = part.strip() if not part: continue if ShellTool._is_single_command_dangerous(part): return True return False # All pipe segments are safe # Single command return ShellTool._is_single_command_dangerous(command_stripped) @staticmethod def _is_single_command_dangerous(command: str) -> bool: """Check if a single command (no pipes/chains) is dangerous.""" command_stripped = command.strip() if not command_stripped: return True # Parse the actual binary being invoked try: tokens = shlex.split(command_stripped) if not tokens: return True binary = os.path.basename(tokens[0]) except ValueError: return True # Whitelist check: first try full command prefix match, then binary-only match for prefix in _SAFE_COMMAND_PREFIXES: prefix_stripped = prefix.lower().strip() if " " in prefix_stripped: # Compound prefix like "git status" - match against full command if command_stripped.lower().startswith(prefix_stripped): return False else: # Simple prefix - match against binary name exactly if binary.lower() == prefix_stripped: return False # Dangerous pattern check — token-based matching binary_lower = binary.lower() # 1. Binary is always dangerous regardless of flags if binary_lower in _DANGEROUS_BINARIES: return True # 2. Binary is dangerous with specific flags/subcommands if binary_lower in _DANGEROUS_BINARY_FLAGS: cmd_str = " ".join(tokens).lower() for flag_pattern in _DANGEROUS_BINARY_FLAGS[binary_lower]: if flag_pattern in cmd_str: return True # Binary has dangerous flags but none matched — treat as safe # (e.g., "git add" is safe even though "git push --force" is not) return False # 3. Cross-token dangerous patterns (regex) command_lower = command_stripped.lower() for pattern in _DANGEROUS_ARG_PATTERNS: if pattern.search(command_lower): return True # 4. Unknown binary — check if it looks like a path or known safe pattern # Commands like /usr/bin/python3, ./script.sh, etc. are not in whitelist # but may be safe. Default to requiring confirmation for truly unknown binaries. return True async def _request_confirmation(self, command: str) -> bool: """请求人工确认危险命令 Args: command: 待确认的命令 Returns: 是否确认执行 """ if self._confirm_callback: try: return await self._confirm_callback(command) except Exception as e: logger.warning("确认回调执行失败: %s", e) return False # 无回调时默认拒绝 logger.warning("危险命令被拒绝(无确认回调): %s", command[:100]) return False def _log_audit( self, command: str, session_id: str | None, exit_code: int | None = None, blocked: bool = False, ) -> None: """记录审计日志""" entry = { "timestamp": time.time(), "command": command[:500], "session_id": session_id, "exit_code": exit_code, "blocked": blocked, } self._audit_log.append(entry) logger.info( "Shell audit: command=%r session=%s exit=%s blocked=%s", command[:100], session_id, exit_code, blocked, ) @property def session_manager(self) -> TerminalSessionManager: """获取会话管理器""" return self._session_manager @property def audit_log(self) -> list[dict[str, object]]: """获取审计日志(副本)""" return list(self._audit_log)