619 lines
19 KiB
Python
619 lines
19 KiB
Python
"""ShellTool - Shell 命令执行工具
|
||
|
||
支持无会话模式(向后兼容)和有会话模式(跨命令保持状态)。
|
||
危险命令通过确认回调请求人工确认,所有操作记录审计日志。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import logging
|
||
import os
|
||
import re
|
||
import shlex
|
||
import time
|
||
import uuid
|
||
from collections import deque
|
||
from typing import Callable, Awaitable
|
||
|
||
from agentkit.tools.base import Tool
|
||
from agentkit.tools.output_parser import OutputParser, ParsedOutput
|
||
from agentkit.tools.terminal_session import TerminalSessionManager
|
||
from agentkit.tools.pty_session import PTYSession
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# Safety whitelist: commands starting with one of these prefixes are meant to
# run without asking for confirmation.  Single-word entries are compared with
# the invoked binary name; multi-word entries are compared with the start of
# the command line.
_SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
    # File browsing
    "cd",
    "ls",
    "pwd",
    "find",
    "file",
    "stat",
    "tree",
    "du",
    "wc",
    # Text viewing / processing
    "cat",
    "head",
    "tail",
    "less",
    "more",
    "grep",
    "egrep",
    "fgrep",
    "sort",
    "uniq",
    "diff",
    "comm",
    "cut",
    "tr",
    "awk",
    "sed",  # read-only unless -i / --in-place is given
    "tee",
    "xargs",
    # Output
    "echo",
    "printf",
    # System information
    "whoami",
    "id",
    "date",
    "uname",
    "hostname",
    "uptime",
    "df",
    "free",
    "top",
    "htop",
    "ps",
    "env",
    "printenv",
    "type",
    "which",
    "whereis",
    "arch",
    "lscpu",
    "lsblk",
    "mount",
    # Network (read-only queries)
    "curl",  # read-only for plain GET requests
    "wget",  # plain downloads are generally safe
    "ping",
    "traceroute",
    "dig",
    "nslookup",
    "host",
    "ifconfig",
    "ip",
    "netstat",
    "ss",
    "lsof",
    # Version / help / read-only tool queries
    "python --version",
    "python3 --version",
    "node --version",
    "npm --version",
    "npm list",
    "npm view",
    "npm info",
    "npm search",
    "npx --version",
    "pip list",
    "pip show",
    "pip search",
    "java --version",
    "go version",
    "rustc --version",
    "cargo --version",
    "git --version",
    "docker --version",
    "docker ps",
    "docker images",
    "docker logs",
    "docker inspect",
    # Git (read-only)
    "git status",
    "git log",
    "git diff",
    "git branch",
    "git remote",
    "git show",
    "git tag",
    "git stash list",
    "git config --list",
    # Package managers (read-only queries)
    "brew list",
    "brew info",
    "brew search",
    "apt list",
    "apt search",
    "apt show",
    "yum list",
    "yum search",
    "yum info",
    # Other safe commands
    "export",
    "sleep",
    "true",
    "false",
    "test",
    "seq",
    "basename",
    "dirname",
    "realpath",
    "readlink",
    "md5sum",
    "sha256sum",
    "shasum",
    "openssl",  # e.g. `openssl dgst` / `openssl rand`
    # Archive listing (no extraction)
    "tar -tf",
    "tar --list",
    "zipinfo",
    "unzip -l",
)
|
||
|
||
# Dangerous-command detection: based on token matching rather than raw
# substring search, to avoid false hits inside unrelated words.

# Binaries that are always treated as dangerous, whatever their arguments.
_DANGEROUS_BINARIES: frozenset[str] = frozenset(
    (
        "rm",
        "rmdir",
        "mkfs",
        "dd",
        "format",
        "shutdown",
        "reboot",
        "halt",
        "killall",
        "chown",
        "fdisk",
        "parted",
    )
)
|
||
|
||
# Binaries that are dangerous only with particular flags or sub-commands:
# binary name -> set of dangerous flag / sub-command patterns.
# NOTE: "rm" is also listed in _DANGEROUS_BINARIES, which flags it regardless
# of arguments, so its entry here is redundant but kept for completeness.
_DANGEROUS_BINARY_FLAGS: dict[str, set[str]] = dict(
    rm={"-rf", "-fr", "-r", "-f"},
    kill={"-9", "-kill"},
    chmod={"777", "000"},
    git={"push --force", "push -f", "reset --hard", "clean -f"},
    pip={"uninstall"},
    npm={"uninstall"},
    docker={"rm", "rmi", "system prune"},
)
|
||
|
||
# Dangerous patterns that span several tokens (pre-compiled, case-insensitive).
_DANGEROUS_ARG_PATTERNS: list[re.Pattern[str]] = [
    re.compile(_source, re.IGNORECASE)
    for _source in (
        # Redirections into device or system-configuration paths
        r">\s*/dev/",
        r">\s*/etc/",
        # Destructive SQL statements
        r"drop\s+table",
        r"drop\s+database",
        r"truncate\s+table",
        # curl/wget data exfiltration: POST/PUT/upload flags
        r"\bcurl\b.*(-X\s*(POST|PUT|PATCH|DELETE)|--data|--data-binary|--data-raw|--data-urlencode|-d\s|--post\d)",
        r"\bwget\b.*(--post-data|--post-file)",
    )
]
|
||
|
||
|
||
# A single pipe character: used to detect pipelines.
_SHELL_PIPE_OPERATORS = re.compile(r"\|")
# Any other construct that chains, substitutes or redirects:
# ; & || && $( ${ ` $< > < and embedded newlines.
_SHELL_CHAIN_OPERATORS = re.compile(r"[;&]|\|\||&&|\$\(|\$\{|`|\$<|>|<|\n")
|
||
|
||
|
||
class ShellTool(Tool):
    """Shell command execution tool.

    Two modes are supported:

    1. Session-less mode (default): every command runs independently and no
       state is kept.
    2. Session mode: a ``session_id`` selects a session so that cwd/env/history
       are kept across commands.

    Safety controls:

    - Dangerous commands need human approval through ``confirm_callback``.
    - Every operation is recorded in the audit log.

    Usage:
        # Session-less mode
        tool = ShellTool()
        result = await tool.execute(command="ls -la")

        # Session mode
        result = await tool.execute(command="cd /tmp", session_id="build-01")
        result = await tool.execute(command="pwd", session_id="build-01")  # prints /tmp
    """

    # Whitelisted binaries that run *another* command given in their arguments;
    # the wrapped command is what actually has to be classified.
    _COMMAND_WRAPPERS = frozenset({"env", "xargs"})

    # Wrapper options that never consume the following token.  Any other option
    # may take a value, which makes the start of the wrapped command ambiguous.
    _WRAPPER_FLAGS_WITHOUT_VALUE = {
        "env": frozenset({"-i", "--ignore-environment", "-0", "--null"}),
        "xargs": frozenset(
            {"-0", "--null", "-r", "--no-run-if-empty", "-t", "--verbose", "-x", "--exit"}
        ),
    }

    # ``find`` actions that delete files, write files or execute commands.
    _FIND_MUTATING_ACTIONS = frozenset(
        {
            "-delete",
            "-exec",
            "-execdir",
            "-ok",
            "-okdir",
            "-fprint",
            "-fprint0",
            "-fprintf",
            "-fls",
        }
    )

    def __init__(
        self,
        name: str = "shell",
        description: str = "执行 Shell 命令,支持会话模式保持跨命令状态",
        input_schema: dict[str, object] | None = None,
        output_schema: dict[str, object] | None = None,
        version: str = "1.0.0",
        tags: list[str] | None = None,
        confirm_callback: Callable[[str], Awaitable[bool]] | None = None,
        default_timeout: float = 60.0,
        max_output_length: int = 50000,
    ):
        """Create the tool.

        Args:
            name: Tool name passed to the ``Tool`` base class.
            description: Tool description passed to the base class.
            input_schema: Input JSON schema; the built-in default is used when falsy.
            output_schema: Output JSON schema; the built-in default is used when falsy.
            version: Tool version string.
            tags: Tool tags; defaults to ``["shell", "terminal", "system"]``.
            confirm_callback: Async callable receiving the command and returning
                whether a dangerous command may run.  Without it, dangerous
                commands are always rejected.
            default_timeout: Timeout in seconds used when a call gives none.
            max_output_length: Maximum number of output characters returned
                before truncation.
        """
        super().__init__(
            name=name,
            description=description,
            input_schema=input_schema or self._default_input_schema(),
            output_schema=output_schema or self._default_output_schema(),
            version=version,
            tags=tags or ["shell", "terminal", "system"],
        )
        self._session_manager = TerminalSessionManager()
        self._output_parser = OutputParser()
        self._confirm_callback = confirm_callback
        self._default_timeout = default_timeout
        self._max_output_length = max_output_length
        # Bounded so a long-running tool cannot grow the log without limit.
        self._audit_log: deque[dict[str, object]] = deque(maxlen=10000)

    @staticmethod
    def _default_input_schema() -> dict[str, object]:
        """Return the default JSON schema describing ``execute`` arguments."""
        return {
            "type": "object",
            "properties": {
                "command": {
                    "type": "string",
                    "description": "要执行的 Shell 命令",
                },
                "timeout": {
                    "type": "number",
                    "description": "超时时间(秒),默认 60",
                    "default": 60,
                },
                "working_dir": {
                    "type": "string",
                    "description": "工作目录(仅无会话模式有效)",
                },
                "session_id": {
                    "type": "string",
                    "description": "会话 ID,指定后在会话中执行命令,跨命令保持状态",
                },
                "interactive": {
                    "type": "boolean",
                    "description": "是否使用交互式模式(PTY),用于需要用户输入的命令",
                    "default": False,
                },
            },
            "required": ["command"],
        }

    @staticmethod
    def _default_output_schema() -> dict[str, object]:
        """Return the default JSON schema describing the ``execute`` result."""
        return {
            "type": "object",
            "properties": {
                "output": {"type": "string", "description": "命令输出"},
                "exit_code": {"type": "integer", "description": "退出码"},
                "is_error": {"type": "boolean", "description": "是否为错误"},
                "error_type": {"type": "string", "description": "错误类型"},
                "message": {"type": "string", "description": "消息摘要"},
                "suggestions": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "可操作建议",
                },
                "session_id": {"type": "string", "description": "会话 ID(仅会话模式)"},
            },
        }

    async def execute(self, **kwargs) -> dict:
        """Execute a shell command.

        Args:
            command: Command to execute (required, non-empty string).
            timeout: Timeout in seconds; the instance default is used when the
                key is missing or ``None``.
            working_dir: Working directory (session-less mode only).
            session_id: Session ID (enables session mode).
            interactive: Whether to run the command through a PTY.
            _skip_dangerous_check: When true, the danger classification and
                confirmation step are skipped (used by callers that have
                already obtained confirmation).

        Returns:
            A dict with ``output``, ``exit_code``, ``is_error`` and related
            fields.  When a dangerous command is not confirmed, the dict
            carries ``needs_confirmation`` and exit code 126 instead.
        """
        command = kwargs.get("command")
        # A non-string command would otherwise crash later on ``.strip()``.
        if not command or not isinstance(command, str):
            return {
                "output": "",
                "exit_code": 1,
                "is_error": True,
                "error_type": "invalid_argument",
                "message": "command 参数是必需的",
                "suggestions": ["提供要执行的 Shell 命令"],
            }

        timeout = kwargs.get("timeout")
        if timeout is None:
            # An explicit ``timeout=None`` must not disable the timeout.
            timeout = self._default_timeout
        working_dir = kwargs.get("working_dir")
        session_id = kwargs.get("session_id")
        interactive = kwargs.get("interactive", False)

        # Safety check: dangerous commands need confirmation unless the caller
        # opted out through _skip_dangerous_check.
        skip_dangerous = kwargs.get("_skip_dangerous_check", False)
        if not skip_dangerous and self._is_dangerous(command):
            confirmed = await self._request_confirmation(command)
            if not confirmed:
                # Keep the session in the audit entry so blocked attempts can
                # be attributed.
                self._log_audit(command, session_id, blocked=True)
                # Hand a confirmation request back to the upper layer
                # (ReActEngine/chat) to deal with.
                confirmation_id = str(uuid.uuid4())
                return {
                    "needs_confirmation": True,
                    "confirmation_id": confirmation_id,
                    "command": command[:500],
                    "reason": "此命令被识别为潜在危险操作,需要用户确认",
                    "suggestions": [
                        "确认执行此命令",
                        "拒绝执行此命令",
                    ],
                    "output": f"命令被拒绝: {command[:100]}",
                    "exit_code": 126,
                    "is_error": True,
                }

        # Dispatch on execution mode.
        if session_id:
            result = await self._execute_in_session(
                command, session_id, timeout, working_dir, interactive
            )
        else:
            result = await self._execute_standalone(command, timeout, working_dir, interactive)

        self._log_audit(command, session_id, exit_code=result.exit_code)

        # Truncate overly long output.
        output = result.raw_output
        if len(output) > self._max_output_length:
            output = output[: self._max_output_length] + "\n... [输出已截断]"

        # Ensure non-empty output for successful commands (all execution modes).
        if result.exit_code == 0 and not output.strip():
            from agentkit.core.fallback import SHELL_NO_OUTPUT

            output = SHELL_NO_OUTPUT

        return {
            "output": output,
            "exit_code": result.exit_code,
            "is_error": result.is_error,
            "error_type": result.error_type.value,
            "message": result.message,
            "suggestions": result.suggestions,
            "session_id": session_id,
        }

    async def _execute_standalone(
        self,
        command: str,
        timeout: float,
        working_dir: str | None,
        interactive: bool,
    ) -> ParsedOutput:
        """Run a command without a session (backward-compatible path).

        stderr is merged into stdout.  On timeout the process is killed and a
        timeout message replaces the output.  Any other exception is turned
        into an error result with exit code -1.
        """
        if interactive:
            return await self._execute_with_pty(command, timeout, working_dir)

        try:
            proc = await asyncio.create_subprocess_shell(
                command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.STDOUT,
                cwd=working_dir,
            )
            try:
                stdout, _ = await asyncio.wait_for(
                    proc.communicate(),
                    timeout=timeout,
                )
            except asyncio.TimeoutError:
                try:
                    proc.kill()
                except ProcessLookupError:
                    logger.debug("Process already exited before kill()")
                except OSError:
                    logger.debug("OSError killing process")
                # Reap the process so it does not linger as a zombie.
                await proc.wait()
                output = f"命令执行超时({timeout}s)"
                exit_code = proc.returncode if proc.returncode is not None else -1
            else:
                output = stdout.decode("utf-8", errors="replace") if stdout else ""
                exit_code = proc.returncode if proc.returncode is not None else 0
        except Exception as e:
            # Tool boundary: report the failure as a result instead of raising.
            output = str(e)
            exit_code = -1

        return self._output_parser.parse(output, exit_code)

    async def _execute_in_session(
        self,
        command: str,
        session_id: str,
        timeout: float,
        working_dir: str | None,
        interactive: bool,
    ) -> ParsedOutput:
        """Run a command inside the session identified by ``session_id``.

        The session is fetched or created through the session manager;
        ``working_dir`` is forwarded to it as ``cwd``.  Interactive commands
        run in a fresh PTY using the session's cwd and env.
        """
        session = self._session_manager.get_or_create(
            session_id,
            cwd=working_dir,
        )

        if interactive:
            return await self._execute_with_pty(command, timeout, session.cwd, session.env)

        return await session.execute(command, timeout=timeout)

    async def _execute_with_pty(
        self,
        command: str,
        timeout: float,
        cwd: str | None = None,
        env: dict[str, str] | None = None,
    ) -> ParsedOutput:
        """Run an interactive command through a PTY session.

        The PTY is always closed afterwards; exceptions from starting or
        running are turned into an error result with exit code -1.
        """
        pty = PTYSession()
        try:
            await pty.start()
            result = await pty.run_command(
                command,
                timeout=timeout,
                cwd=cwd,
                env=env,
            )
            output = result.output
            exit_code = result.exit_code
        except Exception as e:
            output = str(e)
            exit_code = -1
        finally:
            await pty.close()

        return self._output_parser.parse(output, exit_code)

    @staticmethod
    def _is_dangerous(command: str) -> bool:
        """Return whether a command line needs confirmation before running.

        Pipelines (``|``) pass only when every segment is safe.  Any other
        chaining operator (``;``, ``&&``, ``||``, ``$()``, backticks, ``>``,
        ``<``, newlines, ...) is always treated as dangerous.

        Static so callers without a ShellTool instance (e.g. PhasePolicy) can
        reuse the same danger classification. Instance calls still work via
        Python's descriptor protocol.
        """
        command_stripped = command.strip()

        # Chain operators are matched on the raw text, so quoted occurrences
        # are flagged too; this errs on the side of asking.
        if _SHELL_CHAIN_OPERATORS.search(command_stripped):
            return True

        # Pipelines: every non-empty segment has to be safe on its own.
        if _SHELL_PIPE_OPERATORS.search(command_stripped):
            segments = (part.strip() for part in command_stripped.split("|"))
            return any(
                ShellTool._is_single_command_dangerous(segment)
                for segment in segments
                if segment
            )

        return ShellTool._is_single_command_dangerous(command_stripped)

    @staticmethod
    def _is_single_command_dangerous(command: str) -> bool:
        """Classify a single command (no pipes or chains).

        Checks run in this order:

        1. Cross-token dangerous patterns.  These run before the whitelist so
           that whitelisted binaries such as ``curl``/``wget`` are still
           flagged when used with upload/POST options.
        2. Whitelist.  A whitelisted command is safe unless it is used in a
           way that mutates files or runs another command (``sed -i``,
           ``find -delete/-exec``, ``env``/``xargs`` wrapping a dangerous
           command).
        3. Always-dangerous binaries.
        4. Binaries that are dangerous only with specific flags.
        5. Anything unknown requires confirmation.

        Empty or unparsable input is treated as dangerous.
        """
        command_stripped = command.strip()
        if not command_stripped:
            return True

        try:
            tokens = shlex.split(command_stripped)
        except ValueError:
            # Unbalanced quotes etc.: cannot be analysed, so ask.
            return True
        if not tokens:
            return True

        binary = os.path.basename(tokens[0]).lower()
        args = tokens[1:]
        lowered_tokens = [token.lower() for token in tokens]
        command_lower = command_stripped.lower()

        # 1. Cross-token dangerous patterns (regex).
        if any(pattern.search(command_lower) for pattern in _DANGEROUS_ARG_PATTERNS):
            return True

        # 2. Whitelist, minus the known ways of escaping it.
        # NOTE(review): other whitelisted entries can still write or execute
        # (e.g. `tee FILE`, `awk 'BEGIN{system(...)}'`, `git branch -D`,
        # `mount`); they are left as-is here and should be reviewed separately.
        if ShellTool._is_whitelisted(binary, lowered_tokens):
            return ShellTool._whitelisted_command_is_unsafe(binary, args)

        # 3. Binary is always dangerous regardless of flags.
        if binary in _DANGEROUS_BINARIES:
            return True

        # 4. Binary is dangerous with specific flags/subcommands.
        flag_patterns = _DANGEROUS_BINARY_FLAGS.get(binary)
        if flag_patterns is not None:
            cmd_str = " ".join(tokens).lower()
            # Substring match keeps e.g. "push --force-with-lease" flagged;
            # the ordered-token match additionally catches pattern words
            # separated by other arguments ("git push origin main --force").
            # If nothing matches the command is treated as safe
            # (e.g. "git add" is fine even though "git push --force" is not).
            return any(
                flag_pattern in cmd_str
                or ShellTool._contains_in_order(lowered_tokens[1:], flag_pattern.split())
                for flag_pattern in flag_patterns
            )

        # 5. Unknown binary (including paths such as ./script.sh): default to
        # requiring confirmation.
        return True

    @staticmethod
    def _is_whitelisted(binary: str, lowered_tokens: list[str]) -> bool:
        """Return whether the command matches an entry of the safe whitelist.

        Single-word entries must equal the binary name.  Multi-word entries
        (e.g. ``git status``) must equal the leading tokens of the command, so
        ``pip listx`` does not match ``pip list``.
        """
        for prefix in _SAFE_COMMAND_PREFIXES:
            prefix_tokens = prefix.lower().split()
            if len(prefix_tokens) > 1:
                if lowered_tokens[: len(prefix_tokens)] == prefix_tokens:
                    return True
            elif prefix_tokens and binary == prefix_tokens[0]:
                return True
        return False

    @staticmethod
    def _whitelisted_command_is_unsafe(binary: str, args: list[str]) -> bool:
        """Return whether a whitelisted command is used in a non-read-only way."""
        if binary in ShellTool._COMMAND_WRAPPERS:
            return ShellTool._is_wrapped_command_dangerous(binary, args)
        if binary == "find":
            return any(arg.lower() in ShellTool._FIND_MUTATING_ACTIONS for arg in args)
        if binary == "sed":
            return any(ShellTool._is_sed_in_place_flag(arg) for arg in args)
        return False

    @staticmethod
    def _is_sed_in_place_flag(arg: str) -> bool:
        """Return whether a ``sed`` argument requests in-place editing.

        Matches ``--in-place[=SUFFIX]`` and any short-option cluster containing
        ``i`` (``-i``, ``-i.bak``, ``-ni``).  A cluster with an attached script
        or file name may match too; that only causes an extra confirmation.
        """
        if arg.startswith("--"):
            return arg.startswith("--in-place")
        return arg.startswith("-") and "i" in arg[1:].lower()

    @staticmethod
    def _is_wrapped_command_dangerous(binary: str, args: list[str]) -> bool:
        """Classify the command that ``env`` / ``xargs`` would execute.

        Leading options and (for ``env``) ``NAME=value`` assignments are
        skipped.  If an option that may take a value was seen, the start of
        the wrapped command cannot be located reliably, so any following
        command is treated as dangerous.  With no wrapped command at all the
        invocation is safe.
        """
        valueless = ShellTool._WRAPPER_FLAGS_WITHOUT_VALUE.get(binary, frozenset())
        ambiguous = False
        for index, arg in enumerate(args):
            if arg.startswith("-") and arg != "-":
                if arg not in valueless:
                    ambiguous = True
                continue
            if binary == "env" and "=" in arg:
                continue
            if ambiguous:
                return True
            wrapped = " ".join(shlex.quote(token) for token in args[index:])
            return ShellTool._is_single_command_dangerous(wrapped)
        return False

    @staticmethod
    def _contains_in_order(tokens: list[str], words: list[str]) -> bool:
        """Return whether ``words`` occur in ``tokens`` in order (gaps allowed)."""
        remaining = iter(tokens)
        # ``word in remaining`` consumes the iterator up to the match, which
        # enforces the ordering.
        return all(word in remaining for word in words)

    async def _request_confirmation(self, command: str) -> bool:
        """Ask for human confirmation of a dangerous command.

        Args:
            command: Command awaiting confirmation.

        Returns:
            Whether execution was confirmed.  ``False`` when no callback is
            configured or the callback raises.
        """
        if self._confirm_callback:
            try:
                return await self._confirm_callback(command)
            except Exception as e:
                logger.warning("确认回调执行失败: %s", e)
                return False

        # Without a callback the default is to reject.
        logger.warning("危险命令被拒绝(无确认回调): %s", command[:100])
        return False

    def _log_audit(
        self,
        command: str,
        session_id: str | None,
        exit_code: int | None = None,
        blocked: bool = False,
    ) -> None:
        """Append an audit entry and emit it to the module logger.

        The stored command is capped at 500 characters, the logged one at 100.
        """
        entry = {
            "timestamp": time.time(),
            "command": command[:500],
            "session_id": session_id,
            "exit_code": exit_code,
            "blocked": blocked,
        }
        self._audit_log.append(entry)
        logger.info(
            "Shell audit: command=%r session=%s exit=%s blocked=%s",
            command[:100],
            session_id,
            exit_code,
            blocked,
        )

    @property
    def session_manager(self) -> TerminalSessionManager:
        """Return the session manager."""
        return self._session_manager

    @property
    def audit_log(self) -> list[dict[str, object]]:
        """Return a copy of the audit log."""
        return list(self._audit_log)
|