fischer-agentkit/src/agentkit/tools/shell.py

619 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""ShellTool - Shell 命令执行工具
支持无会话模式(向后兼容)和有会话模式(跨命令保持状态)。
危险命令通过确认回调请求人工确认,所有操作记录审计日志。
"""
from __future__ import annotations
import asyncio
import logging
import os
import re
import shlex
import time
import uuid
from collections import deque
from typing import Callable, Awaitable
from agentkit.tools.base import Tool
from agentkit.tools.output_parser import OutputParser, ParsedOutput
from agentkit.tools.terminal_session import TerminalSessionManager
from agentkit.tools.pty_session import PTYSession
# Module-level logger; used below for audit records, confirmation warnings
# and diagnostics when killing a timed-out subprocess.
logger = logging.getLogger(__name__)
# 安全白名单:这些命令前缀不需要确认
_SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
# 文件浏览
"cd",
"ls",
"pwd",
"find",
"file",
"stat",
"tree",
"du",
"wc",
# 文本查看/处理
"cat",
"head",
"tail",
"less",
"more",
"grep",
"egrep",
"fgrep",
"sort",
"uniq",
"diff",
"comm",
"cut",
"tr",
"awk",
"sed", # sed without -i is read-only; -i is caught by operator check
"tee",
"xargs",
# 输出
"echo",
"printf",
# 系统信息
"whoami",
"id",
"date",
"uname",
"hostname",
"uptime",
"df",
"free",
"top",
"htop",
"ps",
"env",
"printenv",
"type",
"which",
"whereis",
"arch",
"lscpu",
"lsblk",
"mount",
# 网络(只读查询)
"curl", # curl GET is safe; POST/PUT caught by operator check
"wget", # wget download is generally safe
"ping",
"traceroute",
"dig",
"nslookup",
"host",
"ifconfig",
"ip",
"netstat",
"ss",
"lsof",
# 版本/帮助
"python --version",
"python3 --version",
"node --version",
"npm --version",
"npm list",
"npm view",
"npm info",
"npm search",
"npx --version",
"pip list",
"pip show",
"pip search",
"java --version",
"go version",
"rustc --version",
"cargo --version",
"git --version",
"docker --version",
"docker ps",
"docker images",
"docker logs",
"docker inspect",
# Git 只读
"git status",
"git log",
"git diff",
"git branch",
"git remote",
"git show",
"git tag",
"git stash list",
"git config --list",
# 包管理器(只读查询)
"brew list",
"brew info",
"brew search",
"apt list",
"apt search",
"apt show",
"yum list",
"yum search",
"yum info",
# 其他安全命令
"export",
"sleep",
"true",
"false",
"test",
"seq",
"basename",
"dirname",
"realpath",
"readlink",
"md5sum",
"sha256sum",
"shasum",
"openssl", # openssl dgst / rand are safe
# tar 解压查看
"tar -tf",
"tar --list",
"zipinfo",
"unzip -l",
)
# 危险命令检测 — 基于精确 token 匹配,避免子串误判
# 总是危险的二进制命令(无论参数)
_DANGEROUS_BINARIES: frozenset[str] = frozenset(
{
"rm",
"rmdir",
"mkfs",
"dd",
"format",
"shutdown",
"reboot",
"halt",
"killall",
"chown",
"fdisk",
"parted",
}
)
# 需要特定参数才危险的二进制命令binary → 危险 flag/子命令集合
_DANGEROUS_BINARY_FLAGS: dict[str, set[str]] = {
"rm": {"-rf", "-fr", "-r", "-f"},
"kill": {"-9", "-kill"},
"chmod": {"777", "000"},
"git": {"push --force", "push -f", "reset --hard", "clean -f"},
"pip": {"uninstall"},
"npm": {"uninstall"},
"docker": {"rm", "rmi", "system prune"},
}
# 跨 token 的危险模式(编译后的正则)
_DANGEROUS_ARG_PATTERNS: list[re.Pattern[str]] = [
re.compile(r">\s*/dev/", re.IGNORECASE),
re.compile(r">\s*/etc/", re.IGNORECASE),
re.compile(r"drop\s+table", re.IGNORECASE),
re.compile(r"drop\s+database", re.IGNORECASE),
re.compile(r"truncate\s+table", re.IGNORECASE),
# curl/wget data exfiltration: POST/PUT/upload flags
re.compile(
r"\bcurl\b.*(-X\s*(POST|PUT|PATCH|DELETE)|--data|--data-binary|--data-raw|--data-urlencode|-d\s|--post\d)",
re.IGNORECASE,
),
re.compile(r"\bwget\b.*(--post-data|--post-file)", re.IGNORECASE),
]
_SHELL_PIPE_OPERATORS = re.compile(r"\|")
_SHELL_CHAIN_OPERATORS = re.compile(r"[;&]|\|\||&&|\$\(|\$\{|`|\$<|>|<|\n")
class ShellTool(Tool):
    """Shell command execution tool.

    Two execution modes are supported:

    1. Stateless (default): every command runs on its own and no state is
       kept between calls.
    2. Session: passing ``session_id`` runs the command inside a session from
       ``TerminalSessionManager`` so cwd/env/history persist across commands.

    Safety controls:

    - Commands classified as dangerous are sent to ``confirm_callback`` for
      human confirmation; without a callback they are refused.
    - Every executed or blocked command is recorded in a bounded audit log.

    Usage::

        # Stateless mode
        tool = ShellTool()
        result = await tool.execute(command="ls -la")

        # Session mode
        result = await tool.execute(command="cd /tmp", session_id="build-01")
        result = await tool.execute(command="pwd", session_id="build-01")  # prints /tmp
    """

    # Binaries that launch another command or overwrite an arbitrary file.
    # They always need confirmation, even when a whitelist entry names them.
    _ALWAYS_CONFIRM_BINARIES: frozenset[str] = frozenset({"xargs", "tee"})

    # `find` actions that execute commands, delete files or write files.
    _FIND_ACTION_FLAGS: frozenset[str] = frozenset(
        {
            "-delete",
            "-exec",
            "-execdir",
            "-ok",
            "-okdir",
            "-fprint",
            "-fprint0",
            "-fprintf",
            "-fls",
        }
    )

    # awk's system() runs an arbitrary shell command from inside the program.
    _AWK_SYSTEM_CALL: re.Pattern[str] = re.compile(r"\bsystem\s*\(")

    def __init__(
        self,
        name: str = "shell",
        description: str = "执行 Shell 命令,支持会话模式保持跨命令状态",
        input_schema: dict[str, object] | None = None,
        output_schema: dict[str, object] | None = None,
        version: str = "1.0.0",
        tags: list[str] | None = None,
        confirm_callback: Callable[[str], Awaitable[bool]] | None = None,
        default_timeout: float = 60.0,
        max_output_length: int = 50000,
    ):
        """Create the tool.

        Args:
            name: Tool name passed to the ``Tool`` base class.
            description: Tool description passed to the base class.
            input_schema: Input JSON schema; the built-in default is used when falsy.
            output_schema: Output JSON schema; the built-in default is used when falsy.
            version: Tool version string.
            tags: Tool tags; defaults to ``["shell", "terminal", "system"]``.
            confirm_callback: Async callable receiving the command and returning
                whether a dangerous command may run.  Without it, dangerous
                commands are refused.
            default_timeout: Timeout in seconds used when the caller gives none.
            max_output_length: Maximum number of output characters returned
                before truncation.
        """
        super().__init__(
            name=name,
            description=description,
            input_schema=input_schema or self._default_input_schema(),
            output_schema=output_schema or self._default_output_schema(),
            version=version,
            tags=tags or ["shell", "terminal", "system"],
        )
        self._session_manager = TerminalSessionManager()
        self._output_parser = OutputParser()
        self._confirm_callback = confirm_callback
        self._default_timeout = default_timeout
        self._max_output_length = max_output_length
        # Bounded so a long-running process cannot grow the log without limit.
        self._audit_log: deque[dict[str, object]] = deque(maxlen=10000)

    @staticmethod
    def _default_input_schema() -> dict[str, object]:
        """Return the default JSON schema describing ``execute`` arguments."""
        return {
            "type": "object",
            "properties": {
                "command": {
                    "type": "string",
                    "description": "要执行的 Shell 命令",
                },
                "timeout": {
                    "type": "number",
                    "description": "超时时间(秒),默认 60",
                    "default": 60,
                },
                "working_dir": {
                    "type": "string",
                    "description": "工作目录(仅无会话模式有效)",
                },
                "session_id": {
                    "type": "string",
                    "description": "会话 ID指定后在会话中执行命令跨命令保持状态",
                },
                "interactive": {
                    "type": "boolean",
                    "description": "是否使用交互式模式PTY用于需要用户输入的命令",
                    "default": False,
                },
            },
            "required": ["command"],
        }

    @staticmethod
    def _default_output_schema() -> dict[str, object]:
        """Return the default JSON schema describing the ``execute`` result."""
        return {
            "type": "object",
            "properties": {
                "output": {"type": "string", "description": "命令输出"},
                "exit_code": {"type": "integer", "description": "退出码"},
                "is_error": {"type": "boolean", "description": "是否为错误"},
                "error_type": {"type": "string", "description": "错误类型"},
                "message": {"type": "string", "description": "消息摘要"},
                "suggestions": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "可操作建议",
                },
                "session_id": {"type": "string", "description": "会话 ID仅会话模式"},
            },
        }

    async def execute(self, **kwargs) -> dict:
        """Execute a shell command.

        Keyword Args:
            command: Command to run (required, non-blank string).
            timeout: Timeout in seconds.  ``None`` or an invalid value falls
                back to the configured default instead of disabling the timeout.
            working_dir: Working directory (stateless mode; also used as the
                initial cwd when a session is created).
            session_id: Session ID; enables session mode when given.
            interactive: Run through a PTY for commands that need a terminal.
            _skip_dangerous_check: Skip the danger classification.

        Returns:
            A dict with ``output``, ``exit_code``, ``is_error``, ``error_type``,
            ``message``, ``suggestions`` and ``session_id``.  When a dangerous
            command is not confirmed, the dict instead carries
            ``needs_confirmation`` / ``confirmation_id`` and exit code 126.
        """
        command = kwargs.get("command")
        # Reject missing, non-string and whitespace-only commands up front;
        # the classifier and the subprocess layer both require a real string.
        if not isinstance(command, str) or not command.strip():
            return {
                "output": "",
                "exit_code": 1,
                "is_error": True,
                "error_type": "invalid_argument",
                "message": "command 参数是必需的",
                "suggestions": ["提供要执行的 Shell 命令"],
            }

        timeout = self._resolve_timeout(kwargs.get("timeout"))
        working_dir = kwargs.get("working_dir")
        session_id = kwargs.get("session_id")
        interactive = kwargs.get("interactive", False)

        # Safety gate: dangerous commands need confirmation.
        # NOTE(review): this flag arrives through the same kwargs as the tool
        # arguments; confirm that callers strip it from untrusted input,
        # otherwise it bypasses the check entirely.
        skip_dangerous = kwargs.get("_skip_dangerous_check", False)
        if not skip_dangerous and self._is_dangerous(command):
            confirmed = await self._request_confirmation(command)
            if not confirmed:
                # Keep the session ID in the audit trail for blocked commands too.
                self._log_audit(command, session_id, blocked=True)
                # Hand the confirmation request to the upper layer.
                confirmation_id = str(uuid.uuid4())
                return {
                    "needs_confirmation": True,
                    "confirmation_id": confirmation_id,
                    "command": command[:500],
                    "reason": "此命令被识别为潜在危险操作,需要用户确认",
                    "suggestions": [
                        "确认执行此命令",
                        "拒绝执行此命令",
                    ],
                    "output": f"命令被拒绝: {command[:100]}",
                    "exit_code": 126,
                    "is_error": True,
                }

        # Dispatch by mode.
        if session_id:
            result = await self._execute_in_session(
                command, session_id, timeout, working_dir, interactive
            )
        else:
            result = await self._execute_standalone(command, timeout, working_dir, interactive)

        self._log_audit(command, session_id, exit_code=result.exit_code)

        # Truncate overly long output.
        output = result.raw_output
        if len(output) > self._max_output_length:
            output = output[: self._max_output_length] + "\n... [输出已截断]"

        # Successful commands always report some output, in every mode.
        if result.exit_code == 0 and not output.strip():
            from agentkit.core.fallback import SHELL_NO_OUTPUT

            output = SHELL_NO_OUTPUT

        return {
            "output": output,
            "exit_code": result.exit_code,
            "is_error": result.is_error,
            "error_type": result.error_type.value,
            "message": result.message,
            "suggestions": result.suggestions,
            "session_id": session_id,
        }

    def _resolve_timeout(self, requested: object) -> float:
        """Return a usable timeout, falling back to the configured default.

        ``None`` (e.g. an explicit JSON ``null``) would otherwise disable the
        timeout in ``asyncio.wait_for``; non-numeric, NaN and non-positive
        values are rejected for the same reason.
        """
        if requested is None:
            return self._default_timeout
        is_number = isinstance(requested, (int, float)) and not isinstance(requested, bool)
        # `not x > 0` (rather than `x <= 0`) also rejects NaN.
        if not is_number or not requested > 0:
            logger.warning(
                "Ignoring invalid timeout %r; using default %s",
                requested,
                self._default_timeout,
            )
            return self._default_timeout
        return requested

    @staticmethod
    def _kill_quietly(proc: asyncio.subprocess.Process) -> None:
        """Kill ``proc``, tolerating a process that is already gone."""
        try:
            proc.kill()
        except ProcessLookupError:
            logger.debug("Process already exited before kill()")
        except OSError:
            logger.debug("OSError killing process")

    async def _execute_standalone(
        self,
        command: str,
        timeout: float,
        working_dir: str | None,
        interactive: bool,
    ) -> ParsedOutput:
        """Run a command without a session (backward-compatible mode).

        stderr is merged into stdout.  On timeout the subprocess is killed and
        a timeout message becomes the output; any other failure is reported as
        exit code -1 with the exception text as output.
        """
        if interactive:
            return await self._execute_with_pty(command, timeout, working_dir)

        try:
            proc = await asyncio.create_subprocess_shell(
                command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.STDOUT,
                cwd=working_dir,
            )
            try:
                stdout, _ = await asyncio.wait_for(
                    proc.communicate(),
                    timeout=timeout,
                )
            except asyncio.TimeoutError:
                self._kill_quietly(proc)
                await proc.wait()
                output = f"命令执行超时({timeout}s)"
                exit_code = proc.returncode if proc.returncode is not None else -1
            except asyncio.CancelledError:
                # Do not leave the child running when the caller is cancelled.
                self._kill_quietly(proc)
                raise
            else:
                output = stdout.decode("utf-8", errors="replace") if stdout else ""
                exit_code = proc.returncode if proc.returncode is not None else 0
        except Exception as e:
            output = str(e)
            exit_code = -1

        return self._output_parser.parse(output, exit_code)

    async def _execute_in_session(
        self,
        command: str,
        session_id: str,
        timeout: float,
        working_dir: str | None,
        interactive: bool,
    ) -> ParsedOutput:
        """Run a command inside the session identified by ``session_id``.

        The session is fetched or created through the session manager;
        interactive commands run in a PTY using the session's cwd and env.
        """
        session = self._session_manager.get_or_create(
            session_id,
            cwd=working_dir,
        )
        if interactive:
            return await self._execute_with_pty(command, timeout, session.cwd, session.env)
        return await session.execute(command, timeout=timeout)

    async def _execute_with_pty(
        self,
        command: str,
        timeout: float,
        cwd: str | None = None,
        env: dict[str, str] | None = None,
    ) -> ParsedOutput:
        """Run an interactive command through a fresh ``PTYSession``.

        The PTY is always closed; failures are reported as exit code -1 with
        the exception text as output.
        """
        pty_session = PTYSession()
        try:
            await pty_session.start()
            result = await pty_session.run_command(
                command,
                timeout=timeout,
                cwd=cwd,
                env=env,
            )
            output = result.output
            exit_code = result.exit_code
        except Exception as e:
            output = str(e)
            exit_code = -1
        finally:
            await pty_session.close()

        return self._output_parser.parse(output, exit_code)

    @staticmethod
    def _is_dangerous(command: str) -> bool:
        """Return True when ``command`` needs confirmation before running.

        Whitelisted commands pass.  A pipeline (``|``) passes only when every
        segment is safe.  Any other chaining construct (``;``, ``&&``, ``||``,
        ``$()``, backticks, redirections, newlines, ...) is always dangerous.

        Static so callers without a ShellTool instance (e.g. PhasePolicy) can
        reuse the same classification; instance calls work as well.
        """
        stripped = command.strip()

        # Chain operators are checked on the raw text, so quoted occurrences
        # are flagged too (conservative by design).
        if _SHELL_CHAIN_OPERATORS.search(stripped):
            return True

        if _SHELL_PIPE_OPERATORS.search(stripped):
            segments = (segment.strip() for segment in stripped.split("|"))
            return any(
                ShellTool._is_single_command_dangerous(segment)
                for segment in segments
                if segment
            )

        return ShellTool._is_single_command_dangerous(stripped)

    @staticmethod
    def _has_unsafe_arguments(binary: str, args: list[str]) -> bool:
        """Return True when a normally read-only binary is used destructively.

        ``binary`` and ``args`` are expected in lowercase.  This covers the
        invocations a name-only whitelist cannot vouch for: command launchers,
        file writers, ``find`` actions, in-place ``sed`` and awk ``system()``.
        It is a best-effort filter, not a complete one (for example the GNU
        sed ``e``/``w`` script commands are not detected).
        """
        if binary in ShellTool._ALWAYS_CONFIRM_BINARIES:
            return True
        if binary == "env":
            # Bare `env` only prints the environment; with arguments it can
            # launch an arbitrary command.
            return bool(args)
        if binary == "find":
            return any(arg in ShellTool._FIND_ACTION_FLAGS for arg in args)
        if binary == "sed":
            return any(
                arg.startswith("--in-place")
                # Short option clusters such as -i, -i.bak or -ni.
                or (arg.startswith("-") and not arg.startswith("--") and "i" in arg)
                for arg in args
            )
        if binary == "awk":
            return any(ShellTool._AWK_SYSTEM_CALL.search(arg) for arg in args)
        return False

    @staticmethod
    def _matches_flag_pattern(lowered_tokens: list[str], flag_pattern: str) -> bool:
        """Return True when ``flag_pattern`` occurs in the tokenised command.

        A pattern matches as a substring of the space-joined command, or, for
        multi-word patterns, when its words appear in order as token prefixes
        after the binary.  The second rule catches forms such as
        ``git push origin main --force`` for the pattern ``push --force``.
        """
        if flag_pattern in " ".join(lowered_tokens):
            return True
        words = flag_pattern.split()
        if len(words) < 2:
            return False
        # A shared iterator makes each word search resume after the previous
        # match, i.e. an ordered-subsequence test.
        remaining = iter(lowered_tokens[1:])
        return all(any(token.startswith(word) for token in remaining) for word in words)

    @staticmethod
    def _is_single_command_dangerous(command: str) -> bool:
        """Classify one command that contains no pipes or chain operators.

        Checks run from most to least specific: cross-token danger patterns,
        always-dangerous binaries, unsafe use of whitelisted binaries, the
        whitelist, flag-dependent binaries, and finally "unknown binary",
        which requires confirmation.  Empty or unparsable input is dangerous.
        """
        command_stripped = command.strip()
        if not command_stripped:
            return True

        try:
            tokens = shlex.split(command_stripped)
        except ValueError:
            # Unbalanced quotes etc.: cannot tell what would run.
            return True
        if not tokens:
            return True

        # NOTE(review): only the basename is compared, so `./ls` is treated
        # like `ls`; confirm whether path-qualified binaries should be trusted.
        binary_lower = os.path.basename(tokens[0]).lower()
        lowered_tokens = [binary_lower, *(token.lower() for token in tokens[1:])]

        # 1. Cross-token patterns first, so they also apply to whitelisted
        #    binaries such as curl/wget (otherwise they would never fire).
        command_lower = command_stripped.lower()
        if any(pattern.search(command_lower) for pattern in _DANGEROUS_ARG_PATTERNS):
            return True

        # 2. Binary is dangerous regardless of flags.
        if binary_lower in _DANGEROUS_BINARIES:
            return True

        # 3. Read-only binary used in a way that executes or writes.
        if ShellTool._has_unsafe_arguments(binary_lower, lowered_tokens[1:]):
            return True

        # 4. Whitelist, matched token-wise so "git status" does not also
        #    accept "git statusx"; single-word prefixes match the binary name.
        for prefix in _SAFE_COMMAND_PREFIXES:
            prefix_tokens = prefix.lower().split()
            if prefix_tokens and lowered_tokens[: len(prefix_tokens)] == prefix_tokens:
                return False

        # 5. Binary is dangerous only with specific flags/sub-commands; when
        #    none match it is treated as safe (e.g. "git add").
        flag_patterns = _DANGEROUS_BINARY_FLAGS.get(binary_lower)
        if flag_patterns is not None:
            return any(
                ShellTool._matches_flag_pattern(lowered_tokens, flag_pattern)
                for flag_pattern in flag_patterns
            )

        # 6. Unknown binary (e.g. /usr/bin/python3, ./script.sh): require
        #    confirmation by default.
        return True

    async def _request_confirmation(self, command: str) -> bool:
        """Ask for human confirmation of a dangerous command.

        Args:
            command: The command awaiting confirmation.

        Returns:
            True when the callback approves.  False when it declines, raises,
            or when no callback is configured.
        """
        if self._confirm_callback:
            try:
                return await self._confirm_callback(command)
            except Exception as e:
                logger.warning("确认回调执行失败: %s", e)
                return False

        # Without a callback the default is to refuse.
        logger.warning("危险命令被拒绝(无确认回调): %s", command[:100])
        return False

    def _log_audit(
        self,
        command: str,
        session_id: str | None,
        exit_code: int | None = None,
        blocked: bool = False,
    ) -> None:
        """Append an audit entry and mirror it to the module logger.

        The stored command is capped at 500 characters and the logged one at
        100 characters.
        """
        entry = {
            "timestamp": time.time(),
            "command": command[:500],
            "session_id": session_id,
            "exit_code": exit_code,
            "blocked": blocked,
        }
        self._audit_log.append(entry)
        logger.info(
            "Shell audit: command=%r session=%s exit=%s blocked=%s",
            command[:100],
            session_id,
            exit_code,
            blocked,
        )

    @property
    def session_manager(self) -> TerminalSessionManager:
        """The session manager backing session mode."""
        return self._session_manager

    @property
    def audit_log(self) -> list[dict[str, object]]:
        """A copy of the audit log entries, oldest first."""
        return list(self._audit_log)