feat(phase2): implement self-evolution and smart terminal (U6-U8)

- U6: PitfallDetector - detect historical failure patterns and warn
- U7: PathOptimizer - discover and update optimal execution paths
- U8: TerminalSession - session state, PTY interactive, output parsing

160 new tests passing. ShellTool enhanced with session_id support.
This commit is contained in:
chiguyong 2026-06-10 00:22:36 +08:00
parent fd4a811929
commit e3d4f811dd
11 changed files with 4001 additions and 0 deletions

View File

@ -0,0 +1,259 @@
"""PathOptimizer - 执行路径优化器
发现更优执行路径时自动更新经验库中的推荐路径
核心逻辑
1. 对比新路径与现有最优路径综合耗时和成功率
2. 新路径成功率更高 更新推荐路径
3. 成功率相近但耗时更短 更新推荐路径
4. 样本量不足 不更新记录待观察
"""
from __future__ import annotations
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from agentkit.evolution.experience_store import InMemoryExperienceStore
logger = logging.getLogger(__name__)
@dataclass
class ExecutionPath:
"""执行路径数据模型
记录特定任务类型的执行路径信息用于路径优化比较
Attributes:
path_id: 路径唯一标识
task_type: 任务类型
steps: 执行步骤名称列表
total_duration: 总耗时
success_rate: 成功率0.0 ~ 1.0
sample_count: 样本数量
is_recommended: 是否为当前推荐路径
created_at: 创建时间
"""
path_id: str = ""
task_type: str = ""
steps: list[str] = field(default_factory=list)
total_duration: float = 0.0
success_rate: float = 0.0
sample_count: int = 0
is_recommended: bool = False
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
@dataclass
class PathUpdateResult:
"""路径更新结果
Attributes:
updated: 是否更新了推荐路径
old_path: 更新前的推荐路径未更新时为 None
new_path: 更新后的推荐路径未更新时为 None
reason: 更新/未更新的原因说明
"""
updated: bool = False
old_path: ExecutionPath | None = None
new_path: ExecutionPath | None = None
reason: str = ""
class PathOptimizer:
"""执行路径优化器
对比新路径与现有最优路径决定是否更新推荐路径
可独立使用也可集成到 PlanChecker 的复盘中
更新策略
1. 新路径成功率 > 现有成功率 + success_rate_threshold 更新
2. 成功率相近差值 threshold但耗时显著更短
duration 改善比例 > duration_improvement_threshold 更新
3. 样本量不足< min_sample_count 不更新
4. 其他情况 保留现有推荐路径
"""
def __init__(
self,
experience_store: InMemoryExperienceStore | None = None,
min_sample_count: int = 3,
success_rate_threshold: float = 0.05,
duration_improvement_threshold: float = 0.2,
):
"""初始化 PathOptimizer
Args:
experience_store: 经验存储实例可选
min_sample_count: 最小样本量低于此值不做决策
success_rate_threshold: 成功率提升阈值超过此值视为显著提升
duration_improvement_threshold: 耗时改善比例阈值超过此值视为显著改善
"""
self._experience_store = experience_store
self._min_sample_count = min_sample_count
self._success_rate_threshold = success_rate_threshold
self._duration_improvement_threshold = duration_improvement_threshold
self._recommended_paths: dict[str, ExecutionPath] = {}
self._pending_paths: dict[str, list[ExecutionPath]] = {}
def get_recommended_path(self, task_type: str) -> ExecutionPath | None:
"""获取指定任务类型的当前推荐路径
Args:
task_type: 任务类型
Returns:
推荐路径若无则返回 None
"""
return self._recommended_paths.get(task_type)
async def evaluate_and_update(
self,
task_type: str,
new_path: ExecutionPath,
) -> PathUpdateResult:
"""评估新路径并决定是否更新推荐路径
Args:
task_type: 任务类型
new_path: 新的执行路径
Returns:
路径更新结果
"""
# 确保新路径有 path_id
if not new_path.path_id:
new_path.path_id = str(uuid.uuid4())
new_path.task_type = task_type
# 样本量不足 → 不更新,记录待观察
if new_path.sample_count < self._min_sample_count:
self._pending_paths.setdefault(task_type, []).append(new_path)
reason = (
f"样本量不足({new_path.sample_count} < {self._min_sample_count}"
f"记录待观察"
)
logger.info(
f"Path not updated for '{task_type}': {reason}"
)
return PathUpdateResult(
updated=False,
old_path=None,
new_path=new_path,
reason=reason,
)
current = self._recommended_paths.get(task_type)
# 无现有推荐路径 → 直接设为推荐
if current is None:
new_path.is_recommended = True
self._recommended_paths[task_type] = new_path
reason = "无现有推荐路径,直接设为推荐"
logger.info(f"Path set as recommended for '{task_type}': {reason}")
return PathUpdateResult(
updated=True,
old_path=None,
new_path=new_path,
reason=reason,
)
# 比较新路径与现有推荐路径
return self._compare_and_decide(task_type, current, new_path)
def _compare_and_decide(
self,
task_type: str,
current: ExecutionPath,
new: ExecutionPath,
) -> PathUpdateResult:
"""比较新旧路径并决策
比较逻辑
1. 新路径成功率 > 现有成功率 + threshold 更新
2. 成功率相近差值 threshold且新耗时显著更短 更新
3. 其他 保留现有
"""
sr_diff = new.success_rate - current.success_rate
# 条件 1成功率显著提升
if sr_diff > self._success_rate_threshold:
return self._apply_update(
task_type, current, new,
f"成功率显著提升({new.success_rate:.2f} > {current.success_rate:.2f}"
f"提升 {sr_diff:.2f}",
)
# 条件 2成功率相近但耗时显著更短
if abs(sr_diff) <= self._success_rate_threshold:
if current.total_duration > 0:
duration_improvement = (
(current.total_duration - new.total_duration) / current.total_duration
)
if (
new.total_duration < current.total_duration
and duration_improvement > self._duration_improvement_threshold
):
return self._apply_update(
task_type, current, new,
f"成功率相近({new.success_rate:.2f} vs {current.success_rate:.2f}"
f"耗时显著更短({new.total_duration:.1f}s vs {current.total_duration:.1f}s"
f"改善 {duration_improvement:.1%}",
)
elif current.total_duration == 0 and new.total_duration > 0:
# 现有路径耗时为 0不太可能不更新
pass
elif current.total_duration == 0 and new.total_duration == 0:
# 两者耗时均为 0不更新
pass
# 条件 3无明显优势 → 保留现有
reason = (
f"新路径无明显优势(成功率 {new.success_rate:.2f} vs {current.success_rate:.2f}"
f"耗时 {new.total_duration:.1f}s vs {current.total_duration:.1f}s保留现有推荐路径"
)
logger.info(f"Path not updated for '{task_type}': {reason}")
return PathUpdateResult(
updated=False,
old_path=current,
new_path=new,
reason=reason,
)
def _apply_update(
self,
task_type: str,
old: ExecutionPath,
new: ExecutionPath,
reason: str,
) -> PathUpdateResult:
"""应用路径更新"""
old.is_recommended = False
new.is_recommended = True
self._recommended_paths[task_type] = new
logger.info(f"Path updated for '{task_type}': {reason}")
return PathUpdateResult(
updated=True,
old_path=old,
new_path=new,
reason=reason,
)
def get_pending_paths(self, task_type: str) -> list[ExecutionPath]:
"""获取指定任务类型的待观察路径
Args:
task_type: 任务类型
Returns:
待观察路径列表
"""
return list(self._pending_paths.get(task_type, []))

View File

@ -0,0 +1,388 @@
"""PitfallDetector - 任务避坑预警
新任务启动时检索历史失败经验匹配当前计划步骤自动预警
基于 ExperienceStore 中存储的失败经验将失败步骤与当前计划步骤
进行关键词匹配计算失败率并按严重程度返回预警列表
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Protocol
logger = logging.getLogger(__name__)
class WarningLevel(str, Enum):
"""预警级别"""
HIGH = "high"
MEDIUM = "medium"
LOW = "low"
@dataclass
class PitfallWarning:
"""避坑预警
Attributes:
step_name: 计划步骤名称
warning_level: 预警级别HIGH/MEDIUM/LOW
failure_rate: 历史失败率0.0 ~ 1.0
historical_failures: 历史失败原因列表
suggestion: 优化建议
"""
step_name: str
warning_level: WarningLevel
failure_rate: float
historical_failures: list[str] = field(default_factory=list)
suggestion: str = ""
class ExperienceStoreProtocol(Protocol):
"""ExperienceStore 协议接口,用于类型标注"""
async def search(
self,
query: str,
top_k: int = 5,
task_type: str | None = None,
search_multiplier: int = 5,
) -> list[Any]:
...
# 预警级别阈值
_HIGH_THRESHOLD = 0.5
_MEDIUM_THRESHOLD = 0.2
class PitfallDetector:
"""避坑检测器
新任务启动时检索历史失败经验匹配当前计划步骤自动预警
使用方式
detector = PitfallDetector(experience_store)
warnings = await detector.check_pitfalls(
task_type="code_review",
planned_steps=[plan_step1, plan_step2, ...],
)
匹配逻辑
1. 检索同类任务的失败经验
2. 从失败经验中提取失败步骤
3. 将失败步骤与当前计划步骤进行关键词匹配
4. 计算失败率并分配预警级别
预警级别
- HIGH: failure_rate >= 0.5历史高失败率步骤
- MEDIUM: failure_rate >= 0.2有失败记录但频率低
- LOW: 有任何失败记录
"""
def __init__(
self,
experience_store: ExperienceStoreProtocol,
similarity_threshold: float = 0.3,
max_search_results: int = 50,
):
"""
Args:
experience_store: 经验存储实例ExperienceStore InMemoryExperienceStore
similarity_threshold: 步骤名称关键词匹配的最小相似度阈值
max_search_results: 从经验存储检索的最大结果数
"""
self._store = experience_store
self._similarity_threshold = similarity_threshold
self._max_search_results = max_search_results
async def check_pitfalls(
self,
task_type: str,
planned_steps: list[Any],
) -> list[PitfallWarning]:
"""检查计划步骤中的潜在陷阱
Args:
task_type: 任务类型
planned_steps: 计划步骤列表PlanStep 对象或具有 name/description 属性的对象
Returns:
按严重程度排序的预警列表HIGH MEDIUM LOW
"""
if not planned_steps:
return []
# 1. 检索同类任务的所有经验(包含成功和失败,用于计算步骤级失败率)
all_experiences = await self._search_experiences(task_type)
if not all_experiences:
logger.debug(f"No experiences found for task_type={task_type}")
return []
# 2. 从经验中提取步骤级别的失败统计
step_failure_stats = self._extract_step_failure_stats(all_experiences)
# 3. 匹配当前计划步骤并生成预警
warnings = self._match_and_warn(planned_steps, step_failure_stats)
# 4. 按严重程度排序HIGH → MEDIUM → LOW同级别按失败率降序
warnings.sort(key=lambda w: (_warning_level_order(w.warning_level), -w.failure_rate))
if warnings:
logger.info(
f"PitfallDetector found {len(warnings)} warnings for task_type={task_type}: "
f"{sum(1 for w in warnings if w.warning_level == WarningLevel.HIGH)} HIGH, "
f"{sum(1 for w in warnings if w.warning_level == WarningLevel.MEDIUM)} MEDIUM, "
f"{sum(1 for w in warnings if w.warning_level == WarningLevel.LOW)} LOW"
)
return warnings
async def _search_experiences(self, task_type: str) -> list[Any]:
"""检索指定任务类型的所有经验(包含成功和失败)"""
try:
results = await self._store.search(
query=task_type,
top_k=self._max_search_results,
task_type=task_type,
)
return results
except Exception as e:
logger.error(f"Failed to search experiences for pitfall detection: {e}")
return []
def _extract_step_failure_stats(
self, failed_experiences: list[Any]
) -> dict[str, _StepFailureStats]:
"""从失败经验中提取步骤级别的失败统计
steps_summary 可以是 str list[dict]
- list[dict]: 每个字典包含 step_name, outcome, duration_seconds, error
- str: 退化为整体统计
Returns:
以步骤名称为 key 的失败统计字典
"""
stats: dict[str, _StepFailureStats] = {}
for exp in failed_experiences:
steps_summary = exp.steps_summary
# 如果 steps_summary 是字符串,无法提取步骤级信息
if isinstance(steps_summary, str):
continue
if not isinstance(steps_summary, list):
continue
for step in steps_summary:
if not isinstance(step, dict):
continue
step_name = step.get("step_name", "")
if not step_name:
continue
outcome = step.get("outcome", "")
error = step.get("error", "")
if step_name not in stats:
stats[step_name] = _StepFailureStats(
step_name=step_name,
total_occurrences=0,
failure_occurrences=0,
failure_reasons=[],
optimization_tips=[],
)
s = stats[step_name]
s.total_occurrences += 1
if outcome in ("failure", "failed", "error"):
s.failure_occurrences += 1
if error:
s.failure_reasons.append(error)
# 收集优化建议
if hasattr(exp, "optimization_tips") and exp.optimization_tips:
for step_name, s in stats.items():
s.optimization_tips.extend(exp.optimization_tips)
return stats
def _match_and_warn(
self,
planned_steps: list[Any],
step_failure_stats: dict[str, _StepFailureStats],
) -> list[PitfallWarning]:
"""将计划步骤与失败统计进行匹配,生成预警"""
warnings: list[PitfallWarning] = []
for step in planned_steps:
step_name = getattr(step, "name", "")
step_description = getattr(step, "description", "")
if not step_name:
continue
# 查找最佳匹配的失败步骤
best_match: _StepFailureStats | None = None
best_similarity = 0.0
for stats_step_name, stats in step_failure_stats.items():
similarity = _compute_name_similarity(
step_name, step_description, stats_step_name
)
if similarity > best_similarity:
best_similarity = similarity
best_match = stats
# 相似度低于阈值,跳过
if best_match is None or best_similarity < self._similarity_threshold:
continue
# 计算失败率
failure_rate = (
best_match.failure_occurrences / best_match.total_occurrences
if best_match.total_occurrences > 0
else 0.0
)
# 分配预警级别
warning_level = _determine_warning_level(failure_rate)
# 生成建议
suggestion = _build_suggestion(best_match, failure_rate)
warning = PitfallWarning(
step_name=step_name,
warning_level=warning_level,
failure_rate=round(failure_rate, 4),
historical_failures=best_match.failure_reasons[:5], # 最多保留 5 条
suggestion=suggestion,
)
warnings.append(warning)
return warnings
# ── 内部辅助类 ──────────────────────────────────────────────
@dataclass
class _StepFailureStats:
"""步骤级别的失败统计(内部使用)"""
step_name: str
total_occurrences: int
failure_occurrences: int
failure_reasons: list[str]
optimization_tips: list[str]
# ── 辅助函数 ──────────────────────────────────────────────
def _compute_name_similarity(
step_name: str, step_description: str, historical_step_name: str
) -> float:
"""计算步骤名称的关键词重叠相似度
基于关键词集合的 Jaccard 相似度同时考虑 step_name step_description
Args:
step_name: 当前计划步骤名称
step_description: 当前计划步骤描述
historical_step_name: 历史步骤名称
Returns:
相似度分数0.0 ~ 1.0
"""
# 提取关键词:将名称拆分为词,过滤掉常见停用词
current_keywords = _extract_keywords(f"{step_name} {step_description}")
historical_keywords = _extract_keywords(historical_step_name)
if not current_keywords or not historical_keywords:
return 0.0
# Jaccard 相似度
intersection = current_keywords & historical_keywords
union = current_keywords | historical_keywords
if not union:
return 0.0
return len(intersection) / len(union)
_STOP_WORDS = frozenset({
"a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for",
"of", "with", "by", "from", "is", "are", "was", "were", "be", "been",
"being", "have", "has", "had", "do", "does", "did", "will", "would",
"could", "should", "may", "might", "can", "shall", "not", "no",
})
def _extract_keywords(text: str) -> frozenset[str]:
"""从文本中提取关键词集合
转小写按空白/下划线/连字符拆分过滤停用词和单字符词
"""
# 统一分隔符
normalized = text.lower().replace("_", " ").replace("-", " ")
words = normalized.split()
return frozenset(
w for w in words
if len(w) > 1 and w not in _STOP_WORDS
)
def _determine_warning_level(failure_rate: float) -> WarningLevel:
"""根据失败率确定预警级别
- HIGH: failure_rate >= 0.5
- MEDIUM: failure_rate >= 0.2
- LOW: 有任何失败记录
"""
if failure_rate >= _HIGH_THRESHOLD:
return WarningLevel.HIGH
if failure_rate >= _MEDIUM_THRESHOLD:
return WarningLevel.MEDIUM
return WarningLevel.LOW
def _warning_level_order(level: WarningLevel) -> int:
"""预警级别排序值(越小越严重)"""
return {
WarningLevel.HIGH: 0,
WarningLevel.MEDIUM: 1,
WarningLevel.LOW: 2,
}[level]
def _build_suggestion(stats: _StepFailureStats, failure_rate: float) -> str:
"""根据失败统计生成优化建议"""
parts: list[str] = []
if failure_rate >= _HIGH_THRESHOLD:
parts.append(f"该步骤历史失败率高达 {failure_rate:.0%},建议重点关注")
elif failure_rate >= _MEDIUM_THRESHOLD:
parts.append(f"该步骤历史失败率为 {failure_rate:.0%},需注意风险")
else:
parts.append(f"该步骤有少量失败记录(失败率 {failure_rate:.0%}")
if stats.failure_reasons:
unique_reasons = list(dict.fromkeys(stats.failure_reasons))[:3]
reasons_str = "".join(unique_reasons)
parts.append(f"常见失败原因:{reasons_str}")
if stats.optimization_tips:
unique_tips = list(dict.fromkeys(stats.optimization_tips))[:2]
tips_str = "".join(unique_tips)
parts.append(f"建议:{tips_str}")
return "".join(parts)

View File

@ -9,6 +9,10 @@ from agentkit.tools.composition import SequentialChain, ParallelFanOut, DynamicS
from agentkit.tools.web_crawl import WebCrawlTool from agentkit.tools.web_crawl import WebCrawlTool
from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool
from agentkit.tools.baidu_search import BaiduSearchTool from agentkit.tools.baidu_search import BaiduSearchTool
from agentkit.tools.shell import ShellTool
from agentkit.tools.terminal_session import TerminalSession, TerminalSessionManager
from agentkit.tools.pty_session import PTYSession
from agentkit.tools.output_parser import OutputParser, ParsedOutput, ErrorType
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor # Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
try: try:
@ -30,4 +34,11 @@ __all__ = [
"SchemaGenerateTool", "SchemaGenerateTool",
"BaiduSearchTool", "BaiduSearchTool",
"HeadroomRetrieveTool", "HeadroomRetrieveTool",
"ShellTool",
"TerminalSession",
"TerminalSessionManager",
"PTYSession",
"OutputParser",
"ParsedOutput",
"ErrorType",
] ]

View File

@ -0,0 +1,294 @@
"""OutputParser - 结构化解析命令输出
将命令行输出解析为结构化格式包含错误类型识别退出码含义和可操作建议
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
class ErrorType(Enum):
"""命令输出错误类型"""
NONE = "none"
PERMISSION_DENIED = "permission_denied"
NOT_FOUND = "not_found"
TIMEOUT = "timeout"
SYNTAX_ERROR = "syntax_error"
CONNECTION_REFUSED = "connection_refused"
OUT_OF_MEMORY = "out_of_memory"
DISK_FULL = "disk_full"
ALREADY_EXISTS = "already_exists"
INVALID_ARGUMENT = "invalid_argument"
PROCESS_NOT_FOUND = "process_not_found"
NETWORK_ERROR = "network_error"
UNKNOWN = "unknown"
@dataclass
class ParsedOutput:
"""结构化命令输出
Attributes:
exit_code: 命令退出码
is_error: 是否为错误输出
error_type: 错误类型仅当 is_error=True 时有值
message: 输出消息摘要
raw_output: 原始输出文本
suggestions: 可操作建议列表
"""
exit_code: int
is_error: bool
error_type: ErrorType = ErrorType.NONE
message: str = ""
raw_output: str = ""
suggestions: list[str] = field(default_factory=list)
def to_dict(self) -> dict[str, Any]:
return {
"exit_code": self.exit_code,
"is_error": self.is_error,
"error_type": self.error_type.value,
"message": self.message,
"suggestions": self.suggestions,
}
# 错误模式匹配规则:(pattern, error_type, message_template, suggestions)
_ERROR_PATTERNS: list[tuple[re.Pattern, ErrorType, str, list[str]]] = [
(
re.compile(r"permission denied|access denied|权限不足|拒绝访问", re.IGNORECASE),
ErrorType.PERMISSION_DENIED,
"权限不足",
[
"尝试使用 sudo 执行该命令",
"检查文件/目录权限: ls -la <path>",
"确认当前用户是否有所需权限",
],
),
(
re.compile(
r"not found|no such file|no such directory|找不到|不存在|无法找到",
re.IGNORECASE,
),
ErrorType.NOT_FOUND,
"文件或目录不存在",
[
"检查路径拼写是否正确",
"使用 ls 确认文件/目录是否存在",
"检查是否在正确的工作目录下",
],
),
(
re.compile(r"timed?\s*out|timeout|超时|时间超限", re.IGNORECASE),
ErrorType.TIMEOUT,
"命令执行超时",
[
"增加超时时间",
"检查网络连接是否正常",
"检查目标服务是否可达",
],
),
(
re.compile(
r"syntax error|syntaxerror|parse error|语法错误|解析错误",
re.IGNORECASE,
),
ErrorType.SYNTAX_ERROR,
"语法错误",
[
"检查命令语法是否正确",
"使用 --help 查看命令用法",
"检查引号和特殊字符是否正确转义",
],
),
(
re.compile(
r"connection refused|连接被拒绝|无法连接|ECONNREFUSED",
re.IGNORECASE,
),
ErrorType.CONNECTION_REFUSED,
"连接被拒绝",
[
"检查目标服务是否已启动",
"确认端口号是否正确",
"检查防火墙设置是否阻止了连接",
],
),
(
re.compile(
r"out of memory|oom|cannot allocate|内存不足|内存溢出",
re.IGNORECASE,
),
ErrorType.OUT_OF_MEMORY,
"内存不足",
[
"释放不必要的内存占用",
"增加系统可用内存",
"检查是否有内存泄漏",
],
),
(
re.compile(
r"no space left|disk full|磁盘已满|空间不足|ENOSPC",
re.IGNORECASE,
),
ErrorType.DISK_FULL,
"磁盘空间不足",
[
"清理不必要的文件: du -sh * | sort -rh | head",
"检查磁盘使用情况: df -h",
"删除临时文件或日志",
],
),
(
re.compile(
r"already exists|file exists|已存在|重复|EEXIST",
re.IGNORECASE,
),
ErrorType.ALREADY_EXISTS,
"资源已存在",
[
"使用 -f 参数强制覆盖(如适用)",
"先删除已有资源再重新创建",
"使用不同名称创建",
],
),
(
re.compile(
r"invalid argument|illegal option|bad option|无效参数|非法选项|invalid option",
re.IGNORECASE,
),
ErrorType.INVALID_ARGUMENT,
"无效参数",
[
"检查命令参数是否正确",
"使用 --help 查看支持的参数",
"确认参数值类型和范围",
],
),
(
re.compile(
r"no such process|process not found|进程不存在|进程未找到",
re.IGNORECASE,
),
ErrorType.PROCESS_NOT_FOUND,
"进程不存在",
[
"确认进程 ID 是否正确",
"使用 ps aux 查看运行中的进程",
"进程可能已经结束",
],
),
(
re.compile(
r"network is unreachable|no route to host|name resolution|网络不可达|无法解析|ENETUNREACH",
re.IGNORECASE,
),
ErrorType.NETWORK_ERROR,
"网络错误",
[
"检查网络连接是否正常",
"确认 DNS 解析是否正常: nslookup <domain>",
"检查代理设置",
],
),
]
class OutputParser:
"""命令输出结构化解析器
将命令行输出stdout + stderr和退出码解析为结构化的 ParsedOutput
包含错误类型识别消息摘要和可操作建议
"""
def parse(self, output: str, exit_code: int) -> ParsedOutput:
"""解析命令输出
Args:
output: 命令的标准输出和错误输出合并文本
exit_code: 命令退出码
Returns:
ParsedOutput 结构化解析结果
"""
is_error = exit_code != 0
message = self._extract_message(output)
error_type = ErrorType.NONE
suggestions: list[str] = []
if is_error:
error_type, suggestions = self._classify_error(output, exit_code)
return ParsedOutput(
exit_code=exit_code,
is_error=is_error,
error_type=error_type,
message=message,
raw_output=output,
suggestions=suggestions,
)
def _extract_message(self, output: str) -> str:
"""从输出中提取关键消息
取最后几行非空输出中的关键行作为消息摘要
"""
if not output:
return ""
lines = [line.strip() for line in output.strip().splitlines() if line.strip()]
if not lines:
return ""
# 取最后一行作为摘要,如果太长则截断
message = lines[-1]
if len(message) > 200:
message = message[:200] + "..."
return message
def _classify_error(
self, output: str, exit_code: int
) -> tuple[ErrorType, list[str]]:
"""根据输出内容和退出码分类错误类型
Args:
output: 命令输出
exit_code: 退出码
Returns:
(error_type, suggestions) 元组
"""
# 优先根据输出内容匹配
for pattern, error_type, _msg, suggestions in _ERROR_PATTERNS:
if pattern.search(output):
return error_type, suggestions
# 退出码兜底分类
if exit_code == 126:
return ErrorType.PERMISSION_DENIED, [
"检查文件是否有执行权限: chmod +x <file>",
"确认文件格式是否正确(如行尾符)",
]
if exit_code == 127:
return ErrorType.NOT_FOUND, [
"检查命令是否已安装",
"确认命令名称拼写是否正确",
"检查 PATH 环境变量是否包含命令所在目录",
]
if exit_code == 130:
return ErrorType.TIMEOUT, [
"命令被 Ctrl+C 中断",
"可能需要增加超时时间",
]
return ErrorType.UNKNOWN, [
"检查命令输出中的错误信息",
"使用 --verbose 或 --debug 获取更多详情",
]

View File

@ -0,0 +1,341 @@
"""PTYSession - 伪终端会话,支持交互式命令
基于 asyncio + os.openpty() 实现伪终端支持交互式命令和自动应答
不依赖 pexpect仅使用标准库
"""
from __future__ import annotations
import asyncio
import fcntl
import logging
import os
import struct
import termios
import time
from dataclasses import dataclass
logger = logging.getLogger(__name__)
# 自动应答规则:(prompt_pattern, response)
_AUTO_RESPOND_RULES: list[tuple[str, str]] = [
(r"\[y/N\]\s*$", "y"),
(r"\[Y/n\]\s*$", "y"),
(r"\[yes/no\]\s*$", "yes"),
(r"\(yes/no\)\s*$", "yes"),
(r"\(yes/no/\[fingerprint\]\)\s*$", "yes"),
(r"continue\?\s*$", "y"),
(r"are you sure\?\s*$", "y"),
(r"password:\s*$", ""), # 密码提示不自动应答,需要人工介入
(r"passphrase:\s*$", ""),
]
@dataclass
class PTYOutput:
"""PTY 输出结果
Attributes:
output: 输出文本
exit_code: 退出码-1 表示超时或未结束
timed_out: 是否超时
"""
output: str
exit_code: int = -1
timed_out: bool = False
class PTYSession:
"""伪终端会话 - 支持交互式命令
使用 os.openpty() 创建伪终端对通过 asyncio 异步读写
支持自动检测提示并应答 yes/no 确认
Usage:
pty = PTYSession()
await pty.start()
output = await pty.run_command("ssh-keygen", timeout=10)
await pty.close()
"""
def __init__(
self,
auto_respond: bool = True,
custom_rules: list[tuple[str, str]] | None = None,
default_timeout: float = 30.0,
buffer_size: int = 4096,
):
"""初始化 PTY 会话
Args:
auto_respond: 是否自动应答已知提示
custom_rules: 自定义应答规则列表 [(prompt_pattern, response)]
default_timeout: 默认超时时间
buffer_size: 读取缓冲区大小
"""
self._auto_respond = auto_respond
self._respond_rules = list(_AUTO_RESPOND_RULES)
if custom_rules:
self._respond_rules.extend(custom_rules)
self._default_timeout = default_timeout
self._buffer_size = buffer_size
self._master_fd: int | None = None
self._slave_fd: int | None = None
self._process: asyncio.subprocess.Process | None = None
self._running = False
self._output_buffer = ""
@property
def is_running(self) -> bool:
"""PTY 会话是否已启动(伪终端已创建)"""
return self._running
async def start(self) -> None:
"""启动 PTY 会话(创建伪终端对)
在执行命令前调用创建 master/slave 文件描述符
"""
if self._running:
return
self._master_fd, self._slave_fd = os.openpty()
# 设置终端大小
try:
winsize = struct.pack("HHHH", 24, 80, 0, 0)
fcntl.ioctl(self._slave_fd, termios.TIOCSWINSZ, winsize)
except Exception:
pass
# 设置 master fd 为非阻塞
flags = fcntl.fcntl(self._master_fd, fcntl.F_GETFL)
fcntl.fcntl(self._master_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
self._running = True
logger.debug("PTY session started: master=%d slave=%d", self._master_fd, self._slave_fd)
async def run_command(
self,
command: str,
timeout: float | None = None,
cwd: str | None = None,
env: dict[str, str] | None = None,
) -> PTYOutput:
"""在 PTY 中运行命令,等待完成
Args:
command: 要执行的命令
timeout: 超时时间
cwd: 工作目录
env: 环境变量
Returns:
PTYOutput 输出结果
"""
if not self._running:
await self.start()
timeout = timeout or self._default_timeout
self._output_buffer = ""
# 构建环境
cmd_env = dict(os.environ)
if env:
cmd_env.update(env)
# 启动子进程,使用 slave 作为 stdin/stdout/stderr
self._process = await asyncio.create_subprocess_shell(
command,
stdin=self._slave_fd,
stdout=self._slave_fd,
stderr=self._slave_fd,
cwd=cwd,
env=cmd_env,
start_new_session=True,
)
# 关闭 slave 端(子进程已继承)
os.close(self._slave_fd)
self._slave_fd = None
# 异步读取输出
try:
exit_code = await self._read_until_exit(timeout)
except asyncio.TimeoutError:
self._output_buffer += "\n[PTY 命令执行超时]"
return PTYOutput(
output=self._output_buffer,
exit_code=-1,
timed_out=True,
)
return PTYOutput(
output=self._output_buffer,
exit_code=exit_code,
timed_out=False,
)
async def send(self, line: str) -> None:
"""向 PTY 发送一行输入
Args:
line: 要发送的文本自动追加换行符
"""
if self._master_fd is None:
return
data = (line + "\n").encode("utf-8")
try:
os.write(self._master_fd, data)
except OSError as e:
logger.warning("PTY write failed: %s", e)
async def read_output(self, timeout: float = 1.0) -> str:
"""读取当前可用的 PTY 输出
Args:
timeout: 读取超时
Returns:
读取到的输出文本
"""
if self._master_fd is None:
return ""
output = ""
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
try:
chunk = os.read(self._master_fd, self._buffer_size)
if chunk:
text = chunk.decode("utf-8", errors="replace")
output += text
self._output_buffer += text
# 自动应答
if self._auto_respond:
await self._try_auto_respond(text)
except (BlockingIOError, OSError):
# 没有数据可读
await asyncio.sleep(0.05)
continue
if output:
# 读取到数据后短暂等待看是否还有更多
await asyncio.sleep(0.05)
return output
async def close(self) -> None:
"""关闭 PTY 会话,清理资源"""
self._running = False
if self._process is not None and self._process.returncode is None:
try:
self._process.terminate()
await asyncio.wait_for(self._process.wait(), timeout=5.0)
except (asyncio.TimeoutError, ProcessLookupError):
try:
self._process.kill()
except ProcessLookupError:
pass
if self._master_fd is not None:
try:
os.close(self._master_fd)
except OSError:
pass
self._master_fd = None
if self._slave_fd is not None:
try:
os.close(self._slave_fd)
except OSError:
pass
self._slave_fd = None
self._process = None
logger.debug("PTY session closed")
async def _read_until_exit(self, timeout: float) -> int:
"""持续读取输出直到进程退出
Args:
timeout: 超时时间
Returns:
进程退出码
"""
deadline = time.monotonic() + timeout
while True:
# 检查超时
if time.monotonic() > deadline:
raise asyncio.TimeoutError()
# 检查进程是否已退出
if self._process.returncode is not None:
# 进程已退出,再读一次剩余输出
await self._drain_remaining_output()
return self._process.returncode
# 读取输出
try:
chunk = os.read(self._master_fd, self._buffer_size)
if chunk:
text = chunk.decode("utf-8", errors="replace")
self._output_buffer += text
# 自动应答
if self._auto_respond:
await self._try_auto_respond(text)
except (BlockingIOError, OSError):
pass
# 检查进程状态
if self._process.returncode is None:
try:
await asyncio.wait_for(
self._process.wait(), timeout=0.05
)
except asyncio.TimeoutError:
pass
else:
await self._drain_remaining_output()
return self._process.returncode
await asyncio.sleep(0.02)
async def _drain_remaining_output(self) -> None:
"""排空剩余输出"""
for _ in range(10): # 最多尝试 10 次
try:
chunk = os.read(self._master_fd, self._buffer_size)
if chunk:
text = chunk.decode("utf-8", errors="replace")
self._output_buffer += text
else:
break
except (BlockingIOError, OSError):
break
await asyncio.sleep(0.01)
async def _try_auto_respond(self, recent_output: str) -> None:
"""检测提示并自动应答
Args:
recent_output: 最近的输出文本
"""
import re
for pattern, response in self._respond_rules:
if not response:
# 空响应规则(如密码提示)跳过
continue
if re.search(pattern, recent_output, re.IGNORECASE | re.MULTILINE):
logger.debug("Auto-responding to prompt '%s' with '%s'", pattern, response)
await self.send(response)
break

432
src/agentkit/tools/shell.py Normal file
View File

@ -0,0 +1,432 @@
"""ShellTool - Shell 命令执行工具
支持无会话模式向后兼容和有会话模式跨命令保持状态
危险命令通过确认回调请求人工确认所有操作记录审计日志
"""
from __future__ import annotations
import asyncio
import logging
import os
import time
from typing import Any, Callable, Awaitable
from agentkit.tools.base import Tool
from agentkit.tools.output_parser import OutputParser, ParsedOutput
from agentkit.tools.terminal_session import TerminalSession, TerminalSessionManager
from agentkit.tools.pty_session import PTYSession
logger = logging.getLogger(__name__)
# 安全白名单:这些命令前缀不需要确认
_SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
"ls",
"cat",
"head",
"tail",
"grep",
"find",
"pwd",
"echo",
"which",
"whoami",
"id",
"date",
"uname",
"df",
"du",
"free",
"ps",
"top",
"env",
"printenv",
"type",
"file",
"stat",
"wc",
"sort",
"uniq",
"diff",
"git status",
"git log",
"git diff",
"git branch",
"git remote",
"pip list",
"pip show",
"python --version",
"python3 --version",
"node --version",
"npm list",
"docker ps",
"docker images",
"curl",
"wget",
)
# 危险命令模式:这些命令需要人工确认
_DANGEROUS_PATTERNS: tuple[str, ...] = (
"rm ",
"rm -",
"rmdir",
"mkfs",
"dd ",
"format",
"del ",
"erase",
"> /dev/",
"shutdown",
"reboot",
"init 0",
"init 6",
"kill -9",
"killall",
"chmod 777",
"chown",
"mv /",
"pip uninstall",
"npm uninstall",
"apt remove",
"yum remove",
"brew uninstall",
"docker rm",
"docker rmi",
"git push --force",
"git reset --hard",
"git clean -f",
"drop table",
"drop database",
"truncate",
)
class ShellTool(Tool):
"""Shell 命令执行工具
支持两种模式
1. 无会话模式默认每次命令独立执行不保持状态
2. 有会话模式通过 session_id 指定会话跨命令保持 cwd/env/history
安全控制
- 危险命令通过 confirm_callback 请求人工确认
- 所有操作记录审计日志
Usage:
# 无会话模式
tool = ShellTool()
result = await tool.execute(command="ls -la")
# 有会话模式
result = await tool.execute(command="cd /tmp", session_id="build-01")
result = await tool.execute(command="pwd", session_id="build-01") # 输出 /tmp
"""
def __init__(
self,
name: str = "shell",
description: str = "执行 Shell 命令,支持会话模式保持跨命令状态",
input_schema: dict[str, Any] | None = None,
output_schema: dict[str, Any] | None = None,
version: str = "1.0.0",
tags: list[str] | None = None,
confirm_callback: Callable[[str], Awaitable[bool]] | None = None,
default_timeout: float = 60.0,
max_output_length: int = 50000,
):
super().__init__(
name=name,
description=description,
input_schema=input_schema or self._default_input_schema(),
output_schema=output_schema or self._default_output_schema(),
version=version,
tags=tags or ["shell", "terminal", "system"],
)
self._session_manager = TerminalSessionManager()
self._output_parser = OutputParser()
self._confirm_callback = confirm_callback
self._default_timeout = default_timeout
self._max_output_length = max_output_length
self._audit_log: list[dict[str, Any]] = []
@staticmethod
def _default_input_schema() -> dict[str, Any]:
return {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "要执行的 Shell 命令",
},
"timeout": {
"type": "number",
"description": "超时时间(秒),默认 60",
"default": 60,
},
"working_dir": {
"type": "string",
"description": "工作目录(仅无会话模式有效)",
},
"session_id": {
"type": "string",
"description": "会话 ID指定后在会话中执行命令跨命令保持状态",
},
"interactive": {
"type": "boolean",
"description": "是否使用交互式模式PTY用于需要用户输入的命令",
"default": False,
},
},
"required": ["command"],
}
@staticmethod
def _default_output_schema() -> dict[str, Any]:
return {
"type": "object",
"properties": {
"output": {"type": "string", "description": "命令输出"},
"exit_code": {"type": "integer", "description": "退出码"},
"is_error": {"type": "boolean", "description": "是否为错误"},
"error_type": {"type": "string", "description": "错误类型"},
"message": {"type": "string", "description": "消息摘要"},
"suggestions": {
"type": "array",
"items": {"type": "string"},
"description": "可操作建议",
},
"session_id": {"type": "string", "description": "会话 ID仅会话模式"},
},
}
async def execute(self, **kwargs) -> dict:
"""执行 Shell 命令
Args:
command: 要执行的命令必需
timeout: 超时时间
working_dir: 工作目录仅无会话模式
session_id: 会话 ID启用会话模式
interactive: 是否使用交互式模式
Returns:
包含 output, exit_code, is_error 等字段的字典
"""
command = kwargs.get("command")
if not command:
return {
"output": "",
"exit_code": 1,
"is_error": True,
"error_type": "invalid_argument",
"message": "command 参数是必需的",
"suggestions": ["提供要执行的 Shell 命令"],
}
timeout = kwargs.get("timeout", self._default_timeout)
working_dir = kwargs.get("working_dir")
session_id = kwargs.get("session_id")
interactive = kwargs.get("interactive", False)
# 安全检查:危险命令需要确认
if self._is_dangerous(command):
confirmed = await self._request_confirmation(command)
if not confirmed:
self._log_audit(command, None, blocked=True)
return {
"output": "",
"exit_code": 126,
"is_error": True,
"error_type": "permission_denied",
"message": f"危险命令已被拒绝执行: {command[:100]}",
"suggestions": [
"如需执行此命令,请手动确认",
"考虑使用更安全的替代命令",
],
}
# 根据模式执行
if session_id:
result = await self._execute_in_session(
command, session_id, timeout, working_dir, interactive
)
else:
result = await self._execute_standalone(command, timeout, working_dir, interactive)
# 审计日志
self._log_audit(command, session_id, exit_code=result.exit_code)
# 截断过长输出
output = result.raw_output
if len(output) > self._max_output_length:
output = output[: self._max_output_length] + "\n... [输出已截断]"
return {
"output": output,
"exit_code": result.exit_code,
"is_error": result.is_error,
"error_type": result.error_type.value,
"message": result.message,
"suggestions": result.suggestions,
"session_id": session_id,
}
async def _execute_standalone(
self,
command: str,
timeout: float,
working_dir: str | None,
interactive: bool,
) -> ParsedOutput:
"""无会话模式执行命令(向后兼容)"""
if interactive:
return await self._execute_with_pty(command, timeout, working_dir)
start = time.monotonic()
try:
proc = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
cwd=working_dir,
)
try:
stdout, _ = await asyncio.wait_for(
proc.communicate(),
timeout=timeout,
)
except asyncio.TimeoutError:
proc.kill()
await proc.wait()
output = f"命令执行超时({timeout}s"
exit_code = -1
else:
output = stdout.decode("utf-8", errors="replace") if stdout else ""
exit_code = proc.returncode if proc.returncode is not None else 0
except Exception as e:
output = str(e)
exit_code = -1
return self._output_parser.parse(output, exit_code)
async def _execute_in_session(
self,
command: str,
session_id: str,
timeout: float,
working_dir: str | None,
interactive: bool,
) -> ParsedOutput:
"""会话模式执行命令"""
session = self._session_manager.get_or_create(
session_id,
cwd=working_dir,
)
if interactive:
return await self._execute_with_pty(
command, timeout, session.cwd, session.env
)
return await session.execute(command, timeout=timeout)
async def _execute_with_pty(
self,
command: str,
timeout: float,
cwd: str | None = None,
env: dict[str, str] | None = None,
) -> ParsedOutput:
"""使用 PTY 执行交互式命令"""
pty = PTYSession()
try:
await pty.start()
result = await pty.run_command(
command,
timeout=timeout,
cwd=cwd,
env=env,
)
output = result.output
exit_code = result.exit_code
except Exception as e:
output = str(e)
exit_code = -1
finally:
await pty.close()
return self._output_parser.parse(output, exit_code)
def _is_dangerous(self, command: str) -> bool:
"""检查命令是否为危险操作
白名单命令直接放行其他命令检查是否匹配危险模式
"""
command_stripped = command.strip()
# 白名单检查
for prefix in _SAFE_COMMAND_PREFIXES:
if command_stripped.startswith(prefix):
return False
# 危险模式检查
command_lower = command_stripped.lower()
for pattern in _DANGEROUS_PATTERNS:
if pattern in command_lower:
return True
return False
async def _request_confirmation(self, command: str) -> bool:
"""请求人工确认危险命令
Args:
command: 待确认的命令
Returns:
是否确认执行
"""
if self._confirm_callback:
try:
return await self._confirm_callback(command)
except Exception as e:
logger.warning("确认回调执行失败: %s", e)
return False
# 无回调时默认拒绝
logger.warning("危险命令被拒绝(无确认回调): %s", command[:100])
return False
def _log_audit(
self,
command: str,
session_id: str | None,
exit_code: int | None = None,
blocked: bool = False,
) -> None:
"""记录审计日志"""
entry = {
"timestamp": time.time(),
"command": command[:500],
"session_id": session_id,
"exit_code": exit_code,
"blocked": blocked,
}
self._audit_log.append(entry)
logger.info(
"Shell audit: command=%r session=%s exit=%s blocked=%s",
command[:100],
session_id,
exit_code,
blocked,
)
@property
def session_manager(self) -> TerminalSessionManager:
"""获取会话管理器"""
return self._session_manager
@property
def audit_log(self) -> list[dict[str, Any]]:
"""获取审计日志(副本)"""
return list(self._audit_log)

View File

@ -0,0 +1,352 @@
"""TerminalSession - 终端会话状态管理
维护 cwdenvhistory支持跨命令保持状态
通过在命令前注入 cd export 语句实现跨命令状态持久化
"""
from __future__ import annotations
import asyncio
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Any
from agentkit.tools.output_parser import OutputParser, ParsedOutput
logger = logging.getLogger(__name__)
@dataclass
class CommandRecord:
"""命令执行记录
Attributes:
command: 执行的命令
exit_code: 退出码
output: 标准输出+错误输出
cwd: 执行时的工作目录
timestamp: 执行时间戳
duration_ms: 执行耗时毫秒
"""
command: str
exit_code: int
output: str
cwd: str
timestamp: float
duration_ms: int
class TerminalSession:
"""终端会话 - 跨命令保持 cwd/env/history 状态
通过在命令前注入 `cd {cwd} && ` `export K=V && ` 实现跨命令状态持久化
每次命令执行后自动更新 cwd env 状态
Usage:
session = TerminalSession(session_id="build-01")
result = await session.execute("cd /tmp")
result = await session.execute("pwd") # 输出 /tmp
"""
def __init__(
self,
session_id: str,
cwd: str | None = None,
env: dict[str, str] | None = None,
max_history: int = 1000,
):
self.session_id = session_id
self._cwd = cwd or os.getcwd()
self._env: dict[str, str] = dict(env or os.environ)
self._history: list[CommandRecord] = []
self._max_history = max_history
self._output_parser = OutputParser()
self._created_at = time.time()
@property
def cwd(self) -> str:
"""当前工作目录"""
return self._cwd
@property
def env(self) -> dict[str, str]:
"""当前环境变量(副本)"""
return dict(self._env)
@property
def history(self) -> list[CommandRecord]:
"""命令执行历史(副本)"""
return list(self._history)
@property
def created_at(self) -> float:
"""会话创建时间戳"""
return self._created_at
def get_cwd(self) -> str:
"""获取当前工作目录"""
return self._cwd
def set_cwd(self, cwd: str) -> None:
"""手动设置当前工作目录"""
self._cwd = cwd
def get_env(self) -> dict[str, str]:
"""获取当前环境变量(副本)"""
return dict(self._env)
def set_env(self, key: str, value: str) -> None:
"""设置单个环境变量"""
self._env[key] = value
def update_env(self, env: dict[str, str]) -> None:
"""批量更新环境变量"""
self._env.update(env)
def get_history(self) -> list[CommandRecord]:
"""获取命令执行历史(副本)"""
return list(self._history)
async def execute(
self,
command: str,
timeout: float | None = None,
) -> ParsedOutput:
"""在会话上下文中执行命令
自动在命令前注入 cd export 语句以保持会话状态
执行后自动更新 cwd env
Args:
command: 要执行的命令
timeout: 超时时间None 表示不超时
Returns:
ParsedOutput 结构化解析结果
"""
full_command = self._build_command(command)
start = time.monotonic()
try:
proc = await asyncio.create_subprocess_shell(
full_command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
env=self._env,
)
try:
stdout, _ = await asyncio.wait_for(
proc.communicate(),
timeout=timeout,
)
except asyncio.TimeoutError:
proc.kill()
await proc.wait()
output = f"命令执行超时({timeout}s"
exit_code = -1
else:
output = stdout.decode("utf-8", errors="replace") if stdout else ""
exit_code = proc.returncode if proc.returncode is not None else 0
except Exception as e:
output = str(e)
exit_code = -1
duration_ms = int((time.monotonic() - start) * 1000)
# 更新会话状态
self._update_state_after_execution(command, output, exit_code)
# 记录历史
record = CommandRecord(
command=command,
exit_code=exit_code,
output=output,
cwd=self._cwd,
timestamp=time.time(),
duration_ms=duration_ms,
)
self._add_history(record)
# 解析输出
parsed = self._output_parser.parse(output, exit_code)
logger.debug(
"Session %s: command=%r exit_code=%d duration=%dms",
self.session_id,
command,
exit_code,
duration_ms,
)
return parsed
def _build_command(self, command: str) -> str:
"""构建带会话状态的完整命令
在原始命令前注入 cd export 语句
"""
parts: list[str] = []
# 注入 cd
if self._cwd:
# 使用 shlex.quote 风格的简单转义
cwd_escaped = self._cwd.replace("'", "'\\''")
parts.append(f"cd '{cwd_escaped}'")
# 注入环境变量
for key, value in self._env.items():
# 跳过 os.environ 中已有的且值未变的变量,减少命令长度
val_escaped = value.replace("'", "'\\''")
parts.append(f"export {key}='{val_escaped}'")
parts.append(command)
return " && ".join(parts)
def _update_state_after_execution(
self, command: str, output: str, exit_code: int
) -> None:
"""命令执行后更新会话状态
解析 cd export 命令更新 cwd env
"""
if exit_code != 0:
return
# 解析 cd 命令更新 cwd
self._parse_cd_commands(command, output)
# 解析 export 命令更新 env
self._parse_export_commands(command)
def _parse_cd_commands(self, command: str, output: str) -> None:
"""从命令中解析 cd 并更新 cwd
支持:
- cd /path
- cd dir (相对路径需要通过 pwd 获取实际路径)
- cd - (切换到上一个目录)
"""
import re
# 匹配 cd 命令(可能出现在 && 链中)
cd_pattern = re.compile(r"(?:^|\s|&&\s*)cd\s+(.+?)(?:\s*&&|\s*$)")
matches = cd_pattern.findall(command)
for target in matches:
target = target.strip().strip("'\"")
if not target:
continue
if target == "-":
# cd - 切换到 OLDPWD
old_pwd = self._env.get("OLDPWD")
if old_pwd:
self._cwd = old_pwd
elif os.path.isabs(target):
self._cwd = target
else:
# 相对路径:拼接后规范化
new_cwd = os.path.normpath(os.path.join(self._cwd, target))
self._cwd = new_cwd
def _parse_export_commands(self, command: str) -> None:
"""从命令中解析 export 并更新 env
支持:
- export KEY=VALUE
- export KEY="VALUE WITH SPACES"
"""
import re
export_pattern = re.compile(
r"(?:^|\s|&&\s*)export\s+(\w+)=(.+?)(?:\s*&&|\s*$)"
)
matches = export_pattern.findall(command)
for key, value in matches:
value = value.strip().strip("'\"")
self._env[key] = value
def _add_history(self, record: CommandRecord) -> None:
"""添加命令记录到历史,超出上限时移除最旧记录"""
self._history.append(record)
while len(self._history) > self._max_history:
self._history.pop(0)
def close(self) -> None:
"""关闭会话,清理资源"""
logger.info(
"Session %s closed: %d commands executed",
self.session_id,
len(self._history),
)
class TerminalSessionManager:
"""终端会话管理器 - 按 ID 管理多个 TerminalSession
Usage:
manager = TerminalSessionManager()
session = manager.get_or_create("build-01")
session = manager.get("build-01")
manager.remove("build-01")
"""
def __init__(self, max_sessions: int = 100):
self._sessions: dict[str, TerminalSession] = {}
self._max_sessions = max_sessions
def get_or_create(
self,
session_id: str,
cwd: str | None = None,
env: dict[str, str] | None = None,
) -> TerminalSession:
"""获取或创建会话
Args:
session_id: 会话 ID
cwd: 初始工作目录仅创建时使用
env: 初始环境变量仅创建时使用
Returns:
TerminalSession 实例
"""
if session_id not in self._sessions:
if len(self._sessions) >= self._max_sessions:
# 移除最旧的会话
oldest_id = min(
self._sessions, key=lambda k: self._sessions[k].created_at
)
self.remove(oldest_id)
self._sessions[session_id] = TerminalSession(
session_id=session_id,
cwd=cwd,
env=env,
)
logger.info("Session created: %s", session_id)
return self._sessions[session_id]
def get(self, session_id: str) -> TerminalSession | None:
"""获取会话,不存在返回 None"""
return self._sessions.get(session_id)
def remove(self, session_id: str) -> None:
"""移除并关闭会话"""
session = self._sessions.pop(session_id, None)
if session:
session.close()
def list_sessions(self) -> list[str]:
"""列出所有会话 ID"""
return list(self._sessions.keys())
def has_session(self, session_id: str) -> bool:
"""检查会话是否存在"""
return session_id in self._sessions
def close_all(self) -> None:
"""关闭所有会话"""
for session_id in list(self._sessions.keys()):
self.remove(session_id)

View File

@ -0,0 +1,512 @@
"""Tests for PathOptimizer - 执行路径优化器"""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from agentkit.evolution.path_optimizer import ExecutionPath, PathOptimizer, PathUpdateResult
# ── Fixtures ──────────────────────────────────────────────
@pytest.fixture
def optimizer():
"""默认 PathOptimizer 实例"""
return PathOptimizer(min_sample_count=3, success_rate_threshold=0.05, duration_improvement_threshold=0.2)
@pytest.fixture
def optimizer_custom_thresholds():
"""自定义阈值的 PathOptimizer"""
return PathOptimizer(
min_sample_count=5,
success_rate_threshold=0.1,
duration_improvement_threshold=0.3,
)
def _make_path(
task_type: str = "code_review",
steps: list[str] | None = None,
total_duration: float = 10.0,
success_rate: float = 0.8,
sample_count: int = 5,
is_recommended: bool = False,
path_id: str = "",
created_at: datetime | None = None,
) -> ExecutionPath:
"""创建测试用 ExecutionPath"""
return ExecutionPath(
path_id=path_id,
task_type=task_type,
steps=steps or ["step1", "step2", "step3"],
total_duration=total_duration,
success_rate=success_rate,
sample_count=sample_count,
is_recommended=is_recommended,
created_at=created_at or datetime.now(timezone.utc),
)
# ── ExecutionPath 数据模型测试 ────────────────────────────
class TestExecutionPath:
def test_default_values(self):
path = ExecutionPath()
assert path.path_id == ""
assert path.task_type == ""
assert path.steps == []
assert path.total_duration == 0.0
assert path.success_rate == 0.0
assert path.sample_count == 0
assert path.is_recommended is False
assert isinstance(path.created_at, datetime)
def test_custom_values(self):
now = datetime.now(timezone.utc)
path = ExecutionPath(
path_id="p1",
task_type="code_review",
steps=["analyze", "review", "report"],
total_duration=15.5,
success_rate=0.9,
sample_count=10,
is_recommended=True,
created_at=now,
)
assert path.path_id == "p1"
assert path.task_type == "code_review"
assert path.steps == ["analyze", "review", "report"]
assert path.total_duration == 15.5
assert path.success_rate == 0.9
assert path.sample_count == 10
assert path.is_recommended is True
assert path.created_at == now
# ── PathUpdateResult 数据模型测试 ─────────────────────────
class TestPathUpdateResult:
def test_default_values(self):
result = PathUpdateResult()
assert result.updated is False
assert result.old_path is None
assert result.new_path is None
assert result.reason == ""
def test_updated_result(self):
old = _make_path(success_rate=0.7)
new = _make_path(success_rate=0.9)
result = PathUpdateResult(
updated=True,
old_path=old,
new_path=new,
reason="成功率显著提升",
)
assert result.updated is True
assert result.old_path.success_rate == 0.7
assert result.new_path.success_rate == 0.9
assert "成功率" in result.reason
# ── get_recommended_path 测试 ─────────────────────────────
class TestGetRecommendedPath:
async def test_no_recommended_path(self, optimizer):
result = optimizer.get_recommended_path("code_review")
assert result is None
async def test_returns_recommended_path(self, optimizer):
path = _make_path(task_type="code_review", success_rate=0.8, sample_count=5)
await optimizer.evaluate_and_update("code_review", path)
result = optimizer.get_recommended_path("code_review")
assert result is not None
assert result.success_rate == 0.8
assert result.is_recommended is True
async def test_different_task_types_independent(self, optimizer):
path_a = _make_path(task_type="code_review", success_rate=0.8, sample_count=5)
path_b = _make_path(task_type="data_analysis", success_rate=0.9, sample_count=5)
await optimizer.evaluate_and_update("code_review", path_a)
await optimizer.evaluate_and_update("data_analysis", path_b)
result_a = optimizer.get_recommended_path("code_review")
result_b = optimizer.get_recommended_path("data_analysis")
assert result_a is not None
assert result_b is not None
assert result_a.success_rate == 0.8
assert result_b.success_rate == 0.9
# ── 样本量不足测试 ────────────────────────────────────────
class TestInsufficientSamples:
async def test_insufficient_samples_no_update(self, optimizer):
"""样本量不足 → 不更新,记录待观察"""
path = _make_path(sample_count=2, success_rate=0.9)
result = await optimizer.evaluate_and_update("code_review", path)
assert result.updated is False
assert "样本量不足" in result.reason
assert optimizer.get_recommended_path("code_review") is None
async def test_insufficient_samples_recorded_as_pending(self, optimizer):
"""样本量不足的路径被记录到待观察列表"""
path = _make_path(sample_count=2, success_rate=0.9)
await optimizer.evaluate_and_update("code_review", path)
pending = optimizer.get_pending_paths("code_review")
assert len(pending) == 1
assert pending[0].success_rate == 0.9
async def test_exact_min_samples_updates(self, optimizer):
"""刚好达到最小样本量 → 可以更新"""
path = _make_path(sample_count=3, success_rate=0.8)
result = await optimizer.evaluate_and_update("code_review", path)
assert result.updated is True
assert result.reason == "无现有推荐路径,直接设为推荐"
async def test_custom_min_sample_count(self, optimizer_custom_thresholds):
"""自定义最小样本量"""
path = _make_path(sample_count=4, success_rate=0.9)
result = await optimizer_custom_thresholds.evaluate_and_update("code_review", path)
assert result.updated is False
assert "样本量不足" in result.reason
# ── 首次设置推荐路径测试 ──────────────────────────────────
class TestFirstRecommendation:
async def test_first_path_becomes_recommended(self, optimizer):
"""无现有推荐路径时,新路径直接设为推荐"""
path = _make_path(success_rate=0.7, sample_count=5)
result = await optimizer.evaluate_and_update("code_review", path)
assert result.updated is True
assert result.old_path is None
assert result.new_path is not None
assert result.new_path.is_recommended is True
assert "无现有推荐路径" in result.reason
async def test_auto_generates_path_id(self, optimizer):
"""未提供 path_id 时自动生成"""
path = _make_path(path_id="", sample_count=5)
result = await optimizer.evaluate_and_update("code_review", path)
assert result.updated is True
assert result.new_path is not None
assert len(result.new_path.path_id) > 0
# ── 成功率显著提升测试 ────────────────────────────────────
class TestSuccessRateImprovement:
async def test_higher_success_rate_updates(self, optimizer):
"""新路径成功率更高 → 更新推荐路径"""
old_path = _make_path(success_rate=0.7, sample_count=5)
await optimizer.evaluate_and_update("code_review", old_path)
new_path = _make_path(success_rate=0.85, sample_count=5)
result = await optimizer.evaluate_and_update("code_review", new_path)
assert result.updated is True
assert result.old_path.success_rate == 0.7
assert result.new_path.success_rate == 0.85
assert "成功率显著提升" in result.reason
async def test_marginal_success_rate_no_update(self, optimizer):
"""成功率提升不足阈值 → 不更新"""
old_path = _make_path(success_rate=0.8, sample_count=5)
await optimizer.evaluate_and_update("code_review", old_path)
# 提升仅 0.03,低于默认阈值 0.05
new_path = _make_path(success_rate=0.83, sample_count=5)
result = await optimizer.evaluate_and_update("code_review", new_path)
assert result.updated is False
assert "无明显优势" in result.reason
async def test_custom_success_rate_threshold(self, optimizer_custom_thresholds):
"""自定义成功率阈值"""
old_path = _make_path(success_rate=0.7, sample_count=10)
await optimizer_custom_thresholds.evaluate_and_update("code_review", old_path)
# 提升 0.08,低于自定义阈值 0.1
new_path = _make_path(success_rate=0.78, sample_count=10)
result = await optimizer_custom_thresholds.evaluate_and_update("code_review", new_path)
assert result.updated is False
async def test_lower_success_rate_no_update(self, optimizer):
"""新路径成功率更低 → 不更新"""
old_path = _make_path(success_rate=0.9, sample_count=5)
await optimizer.evaluate_and_update("code_review", old_path)
new_path = _make_path(success_rate=0.6, sample_count=5)
result = await optimizer.evaluate_and_update("code_review", new_path)
assert result.updated is False
# ── 耗时显著更短测试 ──────────────────────────────────────
class TestDurationImprovement:
async def test_shorter_duration_with_similar_success_rate_updates(self, optimizer):
"""成功率相近但耗时显著更短 → 更新推荐路径"""
old_path = _make_path(total_duration=100.0, success_rate=0.8, sample_count=5)
await optimizer.evaluate_and_update("code_review", old_path)
# 耗时减少 30%> 20% 阈值),成功率相近
new_path = _make_path(total_duration=70.0, success_rate=0.82, sample_count=5)
result = await optimizer.evaluate_and_update("code_review", new_path)
assert result.updated is True
assert "耗时显著更短" in result.reason
async def test_marginal_duration_improvement_no_update(self, optimizer):
"""耗时改善不足阈值 → 不更新"""
old_path = _make_path(total_duration=100.0, success_rate=0.8, sample_count=5)
await optimizer.evaluate_and_update("code_review", old_path)
# 耗时减少仅 10%< 20% 阈值)
new_path = _make_path(total_duration=90.0, success_rate=0.82, sample_count=5)
result = await optimizer.evaluate_and_update("code_review", new_path)
assert result.updated is False
assert "无明显优势" in result.reason
async def test_longer_duration_no_update(self, optimizer):
"""耗时更长 → 不更新"""
old_path = _make_path(total_duration=50.0, success_rate=0.8, sample_count=5)
await optimizer.evaluate_and_update("code_review", old_path)
new_path = _make_path(total_duration=80.0, success_rate=0.82, sample_count=5)
result = await optimizer.evaluate_and_update("code_review", new_path)
assert result.updated is False
async def test_custom_duration_improvement_threshold(self, optimizer_custom_thresholds):
"""自定义耗时改善阈值"""
old_path = _make_path(total_duration=100.0, success_rate=0.8, sample_count=10)
await optimizer_custom_thresholds.evaluate_and_update("code_review", old_path)
# 耗时减少 25%< 30% 自定义阈值)
new_path = _make_path(total_duration=75.0, success_rate=0.82, sample_count=10)
result = await optimizer_custom_thresholds.evaluate_and_update("code_review", new_path)
assert result.updated is False
async def test_zero_duration_current_path(self, optimizer):
"""现有路径耗时为 0 → 不因耗时更新"""
old_path = _make_path(total_duration=0.0, success_rate=0.8, sample_count=5)
await optimizer.evaluate_and_update("code_review", old_path)
new_path = _make_path(total_duration=10.0, success_rate=0.82, sample_count=5)
result = await optimizer.evaluate_and_update("code_review", new_path)
assert result.updated is False
async def test_both_zero_duration(self, optimizer):
"""两者耗时均为 0 → 不因耗时更新"""
old_path = _make_path(total_duration=0.0, success_rate=0.8, sample_count=5)
await optimizer.evaluate_and_update("code_review", old_path)
new_path = _make_path(total_duration=0.0, success_rate=0.82, sample_count=5)
result = await optimizer.evaluate_and_update("code_review", new_path)
assert result.updated is False
# ── 保留现有推荐路径测试 ──────────────────────────────────
class TestKeepCurrentPath:
async def test_no_advantage_keeps_current(self, optimizer):
"""新路径无明显优势 → 保留现有推荐路径"""
old_path = _make_path(total_duration=50.0, success_rate=0.8, sample_count=5)
await optimizer.evaluate_and_update("code_review", old_path)
new_path = _make_path(total_duration=48.0, success_rate=0.79, sample_count=5)
result = await optimizer.evaluate_and_update("code_review", new_path)
assert result.updated is False
assert result.old_path.success_rate == 0.8
# 推荐路径不变
recommended = optimizer.get_recommended_path("code_review")
assert recommended is not None
assert recommended.success_rate == 0.8
async def test_is_recommended_flag_preserved(self, optimizer):
"""未更新时,现有路径的 is_recommended 标志保持为 True"""
old_path = _make_path(success_rate=0.8, sample_count=5)
await optimizer.evaluate_and_update("code_review", old_path)
new_path = _make_path(success_rate=0.79, sample_count=5)
await optimizer.evaluate_and_update("code_review", new_path)
recommended = optimizer.get_recommended_path("code_review")
assert recommended is not None
assert recommended.is_recommended is True
# ── is_recommended 标志管理测试 ────────────────────────────
class TestIsRecommendedFlag:
async def test_old_path_loses_recommended_flag(self, optimizer):
"""更新后旧路径的 is_recommended 变为 False"""
old_path = _make_path(success_rate=0.7, sample_count=5)
await optimizer.evaluate_and_update("code_review", old_path)
assert old_path.is_recommended is True # 首次设置is_recommended 为 True
new_path = _make_path(success_rate=0.9, sample_count=5)
result = await optimizer.evaluate_and_update("code_review", new_path)
assert result.updated is True
assert result.old_path.is_recommended is False # 更新后旧路径失去标志
assert result.new_path.is_recommended is True
# ── 多次迭代优化测试 ──────────────────────────────────────
class TestIterativeOptimization:
async def test_multiple_updates_converge_to_best(self, optimizer):
"""多次迭代后推荐路径收敛到最优"""
# 第一次:初始路径
path1 = _make_path(success_rate=0.6, total_duration=100.0, sample_count=5)
await optimizer.evaluate_and_update("code_review", path1)
assert optimizer.get_recommended_path("code_review").success_rate == 0.6
# 第二次:成功率显著提升
path2 = _make_path(success_rate=0.8, total_duration=90.0, sample_count=5)
await optimizer.evaluate_and_update("code_review", path2)
assert optimizer.get_recommended_path("code_review").success_rate == 0.8
# 第三次:成功率相近但耗时更短
path3 = _make_path(success_rate=0.82, total_duration=50.0, sample_count=5)
await optimizer.evaluate_and_update("code_review", path3)
assert optimizer.get_recommended_path("code_review").total_duration == 50.0
# 第四次:无明显优势
path4 = _make_path(success_rate=0.81, total_duration=48.0, sample_count=5)
result = await optimizer.evaluate_and_update("code_review", path4)
assert result.updated is False
assert optimizer.get_recommended_path("code_review").total_duration == 50.0
async def test_different_task_types_evolve_independently(self, optimizer):
"""不同任务类型的推荐路径独立进化"""
path_a1 = _make_path(task_type="code_review", success_rate=0.7, sample_count=5)
path_b1 = _make_path(task_type="data_analysis", success_rate=0.6, sample_count=5)
await optimizer.evaluate_and_update("code_review", path_a1)
await optimizer.evaluate_and_update("data_analysis", path_b1)
path_a2 = _make_path(task_type="code_review", success_rate=0.9, sample_count=5)
await optimizer.evaluate_and_update("code_review", path_a2)
# code_review 更新了data_analysis 不受影响
assert optimizer.get_recommended_path("code_review").success_rate == 0.9
assert optimizer.get_recommended_path("data_analysis").success_rate == 0.6
# ── 待观察路径管理测试 ────────────────────────────────────
class TestPendingPaths:
async def test_pending_paths_empty_initially(self, optimizer):
assert optimizer.get_pending_paths("code_review") == []
async def test_pending_paths_accumulate(self, optimizer):
"""多次样本不足的路径会累积"""
path1 = _make_path(sample_count=1, success_rate=0.9)
path2 = _make_path(sample_count=2, success_rate=0.85)
await optimizer.evaluate_and_update("code_review", path1)
await optimizer.evaluate_and_update("code_review", path2)
pending = optimizer.get_pending_paths("code_review")
assert len(pending) == 2
async def test_pending_paths_isolated_by_task_type(self, optimizer):
"""不同任务类型的待观察路径相互隔离"""
path_a = _make_path(task_type="code_review", sample_count=1, success_rate=0.9)
path_b = _make_path(task_type="data_analysis", sample_count=1, success_rate=0.8)
await optimizer.evaluate_and_update("code_review", path_a)
await optimizer.evaluate_and_update("data_analysis", path_b)
assert len(optimizer.get_pending_paths("code_review")) == 1
assert len(optimizer.get_pending_paths("data_analysis")) == 1
async def test_sufficient_samples_not_pending(self, optimizer):
"""样本量充足的路径不会进入待观察列表"""
path = _make_path(sample_count=5, success_rate=0.8)
await optimizer.evaluate_and_update("code_review", path)
assert optimizer.get_pending_paths("code_review") == []
# ── ExperienceStore 集成测试 ──────────────────────────────
class TestExperienceStoreIntegration:
async def test_with_experience_store(self):
"""PathOptimizer 可以接受 ExperienceStore 实例"""
from agentkit.evolution.experience_store import InMemoryExperienceStore
store = InMemoryExperienceStore()
optimizer = PathOptimizer(experience_store=store, min_sample_count=3)
path = _make_path(success_rate=0.8, sample_count=5)
result = await optimizer.evaluate_and_update("code_review", path)
assert result.updated is True
async def test_without_experience_store(self, optimizer):
"""PathOptimizer 可以不依赖 ExperienceStore 独立运行"""
path = _make_path(success_rate=0.8, sample_count=5)
result = await optimizer.evaluate_and_update("code_review", path)
assert result.updated is True
# ── 边界条件测试 ──────────────────────────────────────────
class TestEdgeCases:
async def test_same_path_twice(self, optimizer):
"""提交相同路径两次"""
path = _make_path(success_rate=0.8, sample_count=5)
result1 = await optimizer.evaluate_and_update("code_review", path)
assert result1.updated is True
# 第二次提交相同参数的路径(但不同实例)
path2 = _make_path(success_rate=0.8, sample_count=5)
result2 = await optimizer.evaluate_and_update("code_review", path2)
# 成功率相同,耗时相同 → 无明显优势
assert result2.updated is False
async def test_success_rate_at_boundary(self, optimizer):
"""成功率刚好在阈值边界"""
old_path = _make_path(success_rate=0.8, sample_count=5)
await optimizer.evaluate_and_update("code_review", old_path)
# 提升恰好等于阈值 0.05,不满足 > threshold
new_path = _make_path(success_rate=0.85, sample_count=5)
result = await optimizer.evaluate_and_update("code_review", new_path)
assert result.updated is False
async def test_duration_improvement_at_boundary(self, optimizer):
"""耗时改善刚好在阈值边界"""
old_path = _make_path(total_duration=100.0, success_rate=0.8, sample_count=5)
await optimizer.evaluate_and_update("code_review", old_path)
# 改善恰好等于阈值 20%,不满足 > threshold
new_path = _make_path(total_duration=80.0, success_rate=0.82, sample_count=5)
result = await optimizer.evaluate_and_update("code_review", new_path)
assert result.updated is False
async def test_zero_sample_count(self, optimizer):
"""样本量为 0"""
path = _make_path(sample_count=0, success_rate=0.9)
result = await optimizer.evaluate_and_update("code_review", path)
assert result.updated is False
assert "样本量不足" in result.reason
async def test_path_task_type_override(self, optimizer):
"""evaluate_and_update 会用传入的 task_type 覆盖路径的 task_type"""
path = _make_path(task_type="wrong_type", success_rate=0.8, sample_count=5)
result = await optimizer.evaluate_and_update("code_review", path)
assert result.updated is True
assert path.task_type == "code_review"
recommended = optimizer.get_recommended_path("code_review")
assert recommended is not None

View File

@ -0,0 +1,595 @@
"""Tests for PitfallDetector - 任务避坑预警检测"""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from agentkit.core.plan_schema import PlanStep, PlanStepStatus
from agentkit.evolution.experience_schema import TaskExperience
from agentkit.evolution.experience_store import InMemoryExperienceStore
from agentkit.evolution.pitfall_detector import (
PitfallDetector,
PitfallWarning,
WarningLevel,
_compute_name_similarity,
_determine_warning_level,
_extract_keywords,
)
# ── Fixtures ──────────────────────────────────────────────
@pytest.fixture
def store():
"""无 embedder 的 InMemoryExperienceStore"""
return InMemoryExperienceStore(decay_rate=0.01, alpha=0.7)
@pytest.fixture
def detector(store):
"""基于 InMemoryExperienceStore 的 PitfallDetector"""
return PitfallDetector(experience_store=store, similarity_threshold=0.3)
def _make_experience(
task_type: str = "code_review",
goal: str = "Review the PR",
outcome: str = "success",
steps_summary: str | list[dict] = "",
failure_reasons: list[str] | None = None,
optimization_tips: list[str] | None = None,
success_rate: float = 1.0,
) -> TaskExperience:
"""创建测试用 TaskExperience"""
return TaskExperience(
experience_id="",
task_type=task_type,
goal=goal,
steps_summary=steps_summary,
outcome=outcome,
duration_seconds=10.0,
success_rate=success_rate,
failure_reasons=failure_reasons or [],
optimization_tips=optimization_tips or [],
created_at=datetime.now(timezone.utc),
)
def _make_step(
name: str = "step",
description: str = "do something",
step_id: str = "s1",
) -> PlanStep:
"""创建测试用 PlanStep"""
return PlanStep(
step_id=step_id,
name=name,
description=description,
status=PlanStepStatus.PENDING,
)
# ── 辅助函数测试 ──────────────────────────────────────────
class TestExtractKeywords:
def test_basic_extraction(self):
keywords = _extract_keywords("Call API Gateway")
assert "call" in keywords
assert "api" in keywords
assert "gateway" in keywords
def test_stop_words_filtered(self):
keywords = _extract_keywords("Call the API and check the result")
assert "the" not in keywords
assert "and" not in keywords
assert "call" in keywords
assert "api" in keywords
def test_underscore_and_hyphen(self):
keywords = _extract_keywords("call_api-gateway")
assert "call" in keywords
assert "api" in keywords
assert "gateway" in keywords
def test_single_char_filtered(self):
keywords = _extract_keywords("a b cd")
assert "a" not in keywords
assert "b" not in keywords
assert "cd" in keywords
def test_empty_string(self):
keywords = _extract_keywords("")
assert len(keywords) == 0
class TestComputeNameSimilarity:
def test_identical_names(self):
sim = _compute_name_similarity("Call API Gateway", "", "Call API Gateway")
assert sim == pytest.approx(1.0)
def test_partial_overlap(self):
sim = _compute_name_similarity("Call API Gateway", "", "Call External API")
# 共享: call, api; 并集: call, api, gateway, external
assert 0.0 < sim < 1.0
def test_no_overlap(self):
sim = _compute_name_similarity("Deploy Service", "", "Analyze Data")
assert sim == 0.0
def test_description_contributes(self):
sim_no_desc = _compute_name_similarity("Deploy", "", "Deploy Service")
sim_with_desc = _compute_name_similarity("Deploy", "Deploy Service", "Deploy Service")
# description 中包含匹配关键词,应提高相似度
assert sim_with_desc >= sim_no_desc
def test_empty_inputs(self):
sim = _compute_name_similarity("", "", "Call API")
assert sim == 0.0
class TestDetermineWarningLevel:
def test_high_threshold(self):
assert _determine_warning_level(0.6) == WarningLevel.HIGH
assert _determine_warning_level(0.5) == WarningLevel.HIGH
def test_medium_threshold(self):
assert _determine_warning_level(0.3) == WarningLevel.MEDIUM
assert _determine_warning_level(0.2) == WarningLevel.MEDIUM
def test_low_threshold(self):
assert _determine_warning_level(0.1) == WarningLevel.LOW
assert _determine_warning_level(0.01) == WarningLevel.LOW
# ── PitfallDetector.check_pitfalls 测试 ──────────────────
class TestCheckPitfalls:
async def test_no_planned_steps_returns_empty(self, detector):
warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=[])
assert warnings == []
async def test_no_failed_experiences_returns_empty(self, detector, store):
"""无历史失败记录 → 返回空列表"""
# 只记录成功经验
await store.record_experience(
_make_experience(task_type="code_review", outcome="success")
)
steps = [_make_step(name="Review Code")]
warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps)
assert warnings == []
async def test_high_failure_rate_returns_high_warning(self, detector, store):
"""计划包含历史高失败率步骤 → 返回 HIGH 级别预警"""
# 记录多次失败经验,其中 "Call API Gateway" 步骤失败率高
for _ in range(6):
await store.record_experience(
_make_experience(
task_type="deployment",
outcome="failure",
success_rate=0.0,
steps_summary=[
{"step_name": "Call API Gateway", "outcome": "failure", "error": "Timeout"},
{"step_name": "Deploy Container", "outcome": "success"},
],
failure_reasons=["API Gateway timeout"],
)
)
# 记录少数成功经验
for _ in range(4):
await store.record_experience(
_make_experience(
task_type="deployment",
outcome="success",
success_rate=1.0,
steps_summary=[
{"step_name": "Call API Gateway", "outcome": "success"},
{"step_name": "Deploy Container", "outcome": "success"},
],
)
)
steps = [_make_step(name="Call API Gateway", description="Invoke API Gateway endpoint")]
warnings = await detector.check_pitfalls(task_type="deployment", planned_steps=steps)
assert len(warnings) == 1
warning = warnings[0]
assert warning.step_name == "Call API Gateway"
assert warning.warning_level == WarningLevel.HIGH
assert warning.failure_rate >= 0.5
assert "Timeout" in warning.historical_failures
async def test_medium_failure_rate(self, detector, store):
"""中等失败率 → MEDIUM 级别预警"""
# 3 次失败7 次成功 → 失败率 0.3
for _ in range(3):
await store.record_experience(
_make_experience(
task_type="data_analysis",
outcome="failure",
success_rate=0.0,
steps_summary=[
{"step_name": "Fetch Data", "outcome": "failure", "error": "Connection refused"},
],
)
)
for _ in range(7):
await store.record_experience(
_make_experience(
task_type="data_analysis",
outcome="success",
success_rate=1.0,
steps_summary=[
{"step_name": "Fetch Data", "outcome": "success"},
],
)
)
steps = [_make_step(name="Fetch Data", description="Fetch data from source")]
warnings = await detector.check_pitfalls(task_type="data_analysis", planned_steps=steps)
assert len(warnings) == 1
assert warnings[0].warning_level == WarningLevel.MEDIUM
assert 0.2 <= warnings[0].failure_rate < 0.5
async def test_low_failure_rate(self, detector, store):
"""低失败率 → LOW 级别预警"""
# 1 次失败9 次成功 → 失败率 0.1
await store.record_experience(
_make_experience(
task_type="testing",
outcome="failure",
success_rate=0.0,
steps_summary=[
{"step_name": "Run Unit Tests", "outcome": "failure", "error": "Flaky test"},
],
)
)
for _ in range(9):
await store.record_experience(
_make_experience(
task_type="testing",
outcome="success",
success_rate=1.0,
steps_summary=[
{"step_name": "Run Unit Tests", "outcome": "success"},
],
)
)
steps = [_make_step(name="Run Unit Tests", description="Execute unit test suite")]
warnings = await detector.check_pitfalls(task_type="testing", planned_steps=steps)
assert len(warnings) == 1
assert warnings[0].warning_level == WarningLevel.LOW
async def test_multiple_steps_with_risks_sorted_by_severity(self, detector, store):
"""多个步骤有风险 → 按严重程度排序返回"""
# "Call API" 高失败率,"Validate Input" 低失败率
for _ in range(6):
await store.record_experience(
_make_experience(
task_type="integration",
outcome="failure",
success_rate=0.0,
steps_summary=[
{"step_name": "Call API", "outcome": "failure", "error": "Timeout"},
{"step_name": "Validate Input", "outcome": "success"},
],
)
)
for _ in range(4):
await store.record_experience(
_make_experience(
task_type="integration",
outcome="success",
success_rate=1.0,
steps_summary=[
{"step_name": "Call API", "outcome": "success"},
{"step_name": "Validate Input", "outcome": "success"},
],
)
)
# 单独给 Validate Input 加一条失败记录
await store.record_experience(
_make_experience(
task_type="integration",
outcome="partial",
success_rate=0.5,
steps_summary=[
{"step_name": "Call API", "outcome": "success"},
{"step_name": "Validate Input", "outcome": "failure", "error": "Invalid schema"},
],
)
)
steps = [
_make_step(name="Validate Input", description="Validate input data", step_id="s1"),
_make_step(name="Call API", description="Call external API", step_id="s2"),
]
warnings = await detector.check_pitfalls(task_type="integration", planned_steps=steps)
assert len(warnings) == 2
# HIGH 应排在 MEDIUM/LOW 之前
assert warnings[0].warning_level == WarningLevel.HIGH
assert warnings[0].step_name == "Call API"
async def test_no_matching_steps_returns_empty(self, detector, store):
"""计划步骤与历史失败步骤无匹配 → 返回空列表"""
await store.record_experience(
_make_experience(
task_type="code_review",
outcome="failure",
success_rate=0.0,
steps_summary=[
{"step_name": "Run Linter", "outcome": "failure", "error": "Config error"},
],
)
)
# 计划步骤名称与历史步骤完全不同
steps = [_make_step(name="Deploy Application", description="Deploy to production")]
warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps)
assert warnings == []
async def test_different_task_type_no_cross_contamination(self, detector, store):
"""不同 task_type 的失败经验不会跨类型预警"""
await store.record_experience(
_make_experience(
task_type="deployment",
outcome="failure",
success_rate=0.0,
steps_summary=[
{"step_name": "Deploy Service", "outcome": "failure", "error": "OOM"},
],
)
)
# 查询 code_review 类型,不应返回 deployment 的失败经验
steps = [_make_step(name="Deploy Service", description="Deploy the service")]
warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps)
assert warnings == []
async def test_partial_outcome_included(self, detector, store):
"""partial 结果的经验也应被检索"""
await store.record_experience(
_make_experience(
task_type="migration",
outcome="partial",
success_rate=0.5,
steps_summary=[
{"step_name": "Migrate Database", "outcome": "failure", "error": "Schema mismatch"},
],
)
)
steps = [_make_step(name="Migrate Database", description="Migrate DB schema")]
warnings = await detector.check_pitfalls(task_type="migration", planned_steps=steps)
assert len(warnings) == 1
async def test_steps_summary_as_string_ignored(self, detector, store):
"""steps_summary 为字符串时无法提取步骤级信息,不产生预警"""
await store.record_experience(
_make_experience(
task_type="code_review",
outcome="failure",
success_rate=0.0,
steps_summary="Executed code_review task", # 字符串格式
)
)
steps = [_make_step(name="Review Code", description="Review the code")]
warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps)
assert warnings == []
# ── AE3 场景测试 ──────────────────────────────────────────
class TestAE3Scenario:
"""AE3: "调用 X 系统 API 在高峰期超时率 60%" → 新任务调用时自动预警"""
async def test_api_timeout_high_failure_rate_warning(self, detector, store):
"""调用 X 系统 API 在高峰期超时率 60% → 新任务调用时自动预警"""
# 模拟历史10 次调用6 次超时 → 60% 失败率
for _ in range(6):
await store.record_experience(
_make_experience(
task_type="order_processing",
goal="Process orders via X system",
outcome="failure",
success_rate=0.0,
steps_summary=[
{"step_name": "Call X System API", "outcome": "failure", "error": "高峰期超时"},
{"step_name": "Process Order", "outcome": "success"},
],
failure_reasons=["X System API timeout during peak hours"],
optimization_tips=["Avoid peak hours", "Add retry logic"],
)
)
for _ in range(4):
await store.record_experience(
_make_experience(
task_type="order_processing",
goal="Process orders via X system",
outcome="success",
success_rate=1.0,
steps_summary=[
{"step_name": "Call X System API", "outcome": "success"},
{"step_name": "Process Order", "outcome": "success"},
],
)
)
# 新任务计划包含调用 X 系统 API
steps = [
_make_step(name="Call X System API", description="Invoke X system API for orders"),
]
warnings = await detector.check_pitfalls(task_type="order_processing", planned_steps=steps)
assert len(warnings) == 1
warning = warnings[0]
assert warning.warning_level == WarningLevel.HIGH
assert warning.failure_rate >= 0.5
assert any("超时" in reason for reason in warning.historical_failures)
assert warning.suggestion # 应有建议
# ── PitfallWarning 数据模型测试 ───────────────────────────
class TestPitfallWarning:
def test_creation(self):
warning = PitfallWarning(
step_name="Call API",
warning_level=WarningLevel.HIGH,
failure_rate=0.6,
historical_failures=["Timeout", "Connection refused"],
suggestion="Add retry logic",
)
assert warning.step_name == "Call API"
assert warning.warning_level == WarningLevel.HIGH
assert warning.failure_rate == 0.6
assert warning.historical_failures == ["Timeout", "Connection refused"]
assert warning.suggestion == "Add retry logic"
def test_default_values(self):
warning = PitfallWarning(
step_name="Test",
warning_level=WarningLevel.LOW,
failure_rate=0.1,
)
assert warning.historical_failures == []
assert warning.suggestion == ""
# ── WarningLevel 枚举测试 ─────────────────────────────────
class TestWarningLevel:
def test_values(self):
assert WarningLevel.HIGH.value == "high"
assert WarningLevel.MEDIUM.value == "medium"
assert WarningLevel.LOW.value == "low"
def test_string_comparison(self):
assert WarningLevel.HIGH == "high"
assert WarningLevel.MEDIUM == "medium"
assert WarningLevel.LOW == "low"
# ── 相似度阈值配置测试 ─────────────────────────────────────
class TestSimilarityThreshold:
async def test_custom_threshold(self, store):
"""自定义相似度阈值"""
# 低阈值:更容易匹配
detector_low = PitfallDetector(experience_store=store, similarity_threshold=0.1)
# 高阈值:更难匹配
detector_high = PitfallDetector(experience_store=store, similarity_threshold=0.8)
await store.record_experience(
_make_experience(
task_type="testing",
outcome="failure",
success_rate=0.0,
steps_summary=[
{"step_name": "Run Integration Tests", "outcome": "failure", "error": "Timeout"},
],
)
)
steps = [_make_step(name="Run Unit Tests", description="Execute tests")]
# 低阈值可能匹配,高阈值可能不匹配
warnings_low = await detector_low.check_pitfalls(task_type="testing", planned_steps=steps)
warnings_high = await detector_high.check_pitfalls(task_type="testing", planned_steps=steps)
# 低阈值匹配数 >= 高阈值匹配数
assert len(warnings_low) >= len(warnings_high)
# ── 端到端流程测试 ─────────────────────────────────────────
class TestEndToEnd:
async def test_full_pitfall_detection_flow(self, detector, store):
"""完整的避坑检测流程"""
# 1. 记录多种失败经验
await store.record_experience(
_make_experience(
task_type="deployment",
outcome="failure",
success_rate=0.0,
steps_summary=[
{"step_name": "Build Docker Image", "outcome": "failure", "error": "OOM"},
{"step_name": "Push to Registry", "outcome": "success"},
],
failure_reasons=["Docker build OOM"],
optimization_tips=["Increase memory limit"],
)
)
await store.record_experience(
_make_experience(
task_type="deployment",
outcome="failure",
success_rate=0.0,
steps_summary=[
{"step_name": "Build Docker Image", "outcome": "failure", "error": "Dependency conflict"},
{"step_name": "Push to Registry", "outcome": "success"},
],
failure_reasons=["Dependency conflict"],
)
)
await store.record_experience(
_make_experience(
task_type="deployment",
outcome="success",
success_rate=1.0,
steps_summary=[
{"step_name": "Build Docker Image", "outcome": "success"},
{"step_name": "Push to Registry", "outcome": "success"},
],
)
)
# 2. 新任务计划
steps = [
_make_step(name="Build Docker Image", description="Build the container image", step_id="s1"),
_make_step(name="Push to Registry", description="Push image to container registry", step_id="s2"),
]
# 3. 检测避坑
warnings = await detector.check_pitfalls(task_type="deployment", planned_steps=steps)
# 4. 验证结果
assert len(warnings) >= 1
# Build Docker Image 失败率 2/3 ≈ 0.667,应为 HIGH
build_warning = next((w for w in warnings if w.step_name == "Build Docker Image"), None)
assert build_warning is not None
assert build_warning.warning_level == WarningLevel.HIGH
assert build_warning.failure_rate == pytest.approx(2.0 / 3.0, abs=0.01)
async def test_suggestion_contains_useful_info(self, detector, store):
"""预警建议应包含有用的失败原因和优化建议"""
await store.record_experience(
_make_experience(
task_type="api_integration",
outcome="failure",
success_rate=0.0,
steps_summary=[
{"step_name": "Authenticate", "outcome": "failure", "error": "Token expired"},
],
failure_reasons=["Token expired"],
optimization_tips=["Refresh token before expiry"],
)
)
steps = [_make_step(name="Authenticate", description="Authenticate with API")]
warnings = await detector.check_pitfalls(task_type="api_integration", planned_steps=steps)
assert len(warnings) == 1
assert "Token expired" in warnings[0].suggestion

View File

@ -0,0 +1,217 @@
"""PTYSession 单元测试
测试场景
- PTY 会话启动和关闭
- 交互式命令执行
- 自动应答提示
- 超时处理
- 自定义应答规则
"""
from __future__ import annotations
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.tools.pty_session import PTYSession, PTYOutput
from agentkit.tools.shell import ShellTool
class TestPTYSessionConstruction:
"""测试 PTYSession 构造"""
def test_default_construction(self):
pty = PTYSession()
assert pty.is_running is False
assert pty._auto_respond is True
assert pty._default_timeout == 30.0
def test_custom_construction(self):
pty = PTYSession(
auto_respond=False,
custom_rules=[(r"confirm\?", "yes")],
default_timeout=60.0,
)
assert pty._auto_respond is False
assert len(pty._respond_rules) > len(pty._respond_rules) - 1
assert pty._default_timeout == 60.0
class TestPTYSessionLifecycle:
"""测试 PTYSession 生命周期"""
@pytest.mark.asyncio
async def test_start_and_close(self):
"""启动和关闭 PTY 会话"""
pty = PTYSession()
await pty.start()
assert pty.is_running is True
await pty.close()
assert pty.is_running is False
@pytest.mark.asyncio
async def test_start_idempotent(self):
"""重复启动不报错"""
pty = PTYSession()
await pty.start()
await pty.start() # 不应抛出异常
assert pty.is_running is True
await pty.close()
@pytest.mark.asyncio
async def test_close_without_start(self):
"""未启动时关闭不报错"""
pty = PTYSession()
await pty.close() # 不应抛出异常
class TestPTYSessionExecution:
"""测试 PTYSession 命令执行"""
@pytest.mark.asyncio
async def test_run_simple_command(self):
"""执行简单命令"""
pty = PTYSession(default_timeout=10.0)
try:
await pty.start()
result = await pty.run_command("echo hello_pty")
assert "hello_pty" in result.output
assert result.exit_code == 0
assert result.timed_out is False
finally:
await pty.close()
@pytest.mark.asyncio
async def test_run_command_with_cwd(self):
"""指定工作目录执行命令"""
pty = PTYSession(default_timeout=10.0)
try:
await pty.start()
result = await pty.run_command("pwd", cwd="/tmp")
assert "/tmp" in result.output
finally:
await pty.close()
@pytest.mark.asyncio
async def test_run_command_with_env(self):
"""指定环境变量执行命令"""
pty = PTYSession(default_timeout=10.0)
try:
await pty.start()
result = await pty.run_command(
"echo $PTY_TEST_VAR",
env={"PTY_TEST_VAR": "pty_value"},
)
assert "pty_value" in result.output
finally:
await pty.close()
@pytest.mark.asyncio
async def test_run_failing_command(self):
"""执行失败命令"""
pty = PTYSession(default_timeout=10.0)
try:
await pty.start()
result = await pty.run_command("ls /nonexistent_dir_xyz_12345")
assert result.exit_code != 0
finally:
await pty.close()
@pytest.mark.asyncio
async def test_run_command_timeout(self):
"""命令超时"""
pty = PTYSession(default_timeout=10.0)
try:
await pty.start()
result = await pty.run_command("sleep 30", timeout=0.5)
assert result.timed_out is True
assert result.exit_code == -1
finally:
await pty.close()
class TestPTYSessionAutoRespond:
"""测试 PTYSession 自动应答"""
@pytest.mark.asyncio
async def test_auto_respond_yes_no(self):
"""自动应答 [y/N] 提示"""
# 使用 echo 模拟包含提示的输出,然后验证自动应答规则存在
pty = PTYSession(auto_respond=True)
# 验证规则已加载
rule_patterns = [r[0] for r in pty._respond_rules]
assert any("y/N" in p or "Y/n" in p for p in rule_patterns)
@pytest.mark.asyncio
async def test_auto_respond_disabled(self):
"""禁用自动应答"""
pty = PTYSession(auto_respond=False)
assert pty._auto_respond is False
@pytest.mark.asyncio
async def test_custom_respond_rules(self):
"""自定义应答规则"""
pty = PTYSession(
auto_respond=True,
custom_rules=[(r"continue\?\s*$", "yes")],
)
rule_patterns = [r[0] for r in pty._respond_rules]
assert r"continue\?\s*$" in rule_patterns
class TestPTYSessionSendAndRead:
"""测试 PTYSession 发送和读取"""
@pytest.mark.asyncio
async def test_send_without_start(self):
"""未启动时发送不报错"""
pty = PTYSession()
await pty.send("test") # 不应抛出异常
@pytest.mark.asyncio
async def test_read_output_without_start(self):
"""未启动时读取返回空"""
pty = PTYSession()
output = await pty.read_output()
assert output == ""
class TestPTYOutput:
"""测试 PTYOutput 数据类"""
def test_default_values(self):
output = PTYOutput(output="test")
assert output.output == "test"
assert output.exit_code == -1
assert output.timed_out is False
def test_custom_values(self):
output = PTYOutput(output="error", exit_code=1, timed_out=True)
assert output.exit_code == 1
assert output.timed_out is True
class TestShellToolInteractiveMode:
"""测试 ShellTool 交互式模式"""
@pytest.mark.asyncio
async def test_interactive_mode(self):
"""ShellTool interactive 模式执行命令"""
tool = ShellTool()
result = await tool.execute(command="echo interactive_test", interactive=True)
assert result["exit_code"] == 0
assert "interactive_test" in result["output"]
@pytest.mark.asyncio
async def test_interactive_mode_with_session(self):
"""ShellTool 会话模式 + 交互式"""
tool = ShellTool()
result = await tool.execute(
command="echo session_interactive",
session_id="int-session",
interactive=True,
)
assert result["exit_code"] == 0
assert "session_interactive" in result["output"]

View File

@ -0,0 +1,600 @@
"""TerminalSession 和 ShellTool 单元测试
测试场景
- 跨命令保持 cwd cd 后执行 pwd 返回正确目录
- 跨命令保持 env export 后执行 echo 返回正确值
- 危险命令需确认 rm 命令触发确认回调
- 输出解析 错误输出结构化为错误类型+建议
- session_id 时保持现有行为
- 会话管理器功能
"""
from __future__ import annotations
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.tools.terminal_session import TerminalSession, TerminalSessionManager, CommandRecord
from agentkit.tools.shell import ShellTool
from agentkit.tools.output_parser import OutputParser, ParsedOutput, ErrorType
# ============================================================
# OutputParser 测试
# ============================================================
class TestOutputParser:
"""测试 OutputParser 结构化解析"""
def setup_method(self):
self.parser = OutputParser()
def test_parse_success_output(self):
"""成功输出解析"""
result = self.parser.parse("hello world", 0)
assert result.exit_code == 0
assert result.is_error is False
assert result.error_type == ErrorType.NONE
assert result.message == "hello world"
assert result.suggestions == []
def test_parse_empty_output(self):
"""空输出解析"""
result = self.parser.parse("", 0)
assert result.exit_code == 0
assert result.is_error is False
assert result.message == ""
def test_parse_permission_denied(self):
"""权限不足错误解析"""
result = self.parser.parse("permission denied: /root/secret", 1)
assert result.is_error is True
assert result.error_type == ErrorType.PERMISSION_DENIED
assert len(result.suggestions) > 0
assert any("sudo" in s for s in result.suggestions)
def test_parse_not_found(self):
"""文件不存在错误解析"""
result = self.parser.parse("No such file or directory: /tmp/missing", 1)
assert result.is_error is True
assert result.error_type == ErrorType.NOT_FOUND
def test_parse_timeout(self):
"""超时错误解析"""
result = self.parser.parse("Connection timed out", 1)
assert result.is_error is True
assert result.error_type == ErrorType.TIMEOUT
def test_parse_syntax_error(self):
"""语法错误解析"""
result = self.parser.parse("syntax error near unexpected token", 2)
assert result.is_error is True
assert result.error_type == ErrorType.SYNTAX_ERROR
def test_parse_connection_refused(self):
"""连接被拒绝解析"""
result = self.parser.parse("Connection refused on port 8080", 1)
assert result.is_error is True
assert result.error_type == ErrorType.CONNECTION_REFUSED
def test_parse_out_of_memory(self):
"""内存不足解析"""
result = self.parser.parse("Out of memory: cannot allocate", 1)
assert result.is_error is True
assert result.error_type == ErrorType.OUT_OF_MEMORY
def test_parse_disk_full(self):
"""磁盘满解析"""
result = self.parser.parse("No space left on device", 1)
assert result.is_error is True
assert result.error_type == ErrorType.DISK_FULL
def test_parse_already_exists(self):
"""已存在解析"""
result = self.parser.parse("File already exists: /tmp/test", 1)
assert result.is_error is True
assert result.error_type == ErrorType.ALREADY_EXISTS
def test_parse_invalid_argument(self):
"""无效参数解析"""
result = self.parser.parse("invalid argument: --unknown-flag", 1)
assert result.is_error is True
assert result.error_type == ErrorType.INVALID_ARGUMENT
def test_parse_network_error(self):
"""网络错误解析"""
result = self.parser.parse("Network is unreachable", 1)
assert result.is_error is True
assert result.error_type == ErrorType.NETWORK_ERROR
def test_parse_exit_code_126(self):
"""退出码 126 → 权限不足"""
result = self.parser.parse("some unknown error", 126)
assert result.is_error is True
assert result.error_type == ErrorType.PERMISSION_DENIED
def test_parse_exit_code_127(self):
"""退出码 127 → 命令未找到"""
result = self.parser.parse("some unknown error", 127)
assert result.is_error is True
assert result.error_type == ErrorType.NOT_FOUND
def test_parse_exit_code_130(self):
"""退出码 130 → 被中断"""
result = self.parser.parse("some unknown error", 130)
assert result.is_error is True
assert result.error_type == ErrorType.TIMEOUT
def test_parse_unknown_error(self):
"""未知错误"""
result = self.parser.parse("something went wrong", 1)
assert result.is_error is True
assert result.error_type == ErrorType.UNKNOWN
def test_parse_long_message_truncated(self):
"""长消息截断"""
long_output = "x" * 300
result = self.parser.parse(long_output, 0)
assert len(result.message) <= 203 # 200 + "..."
def test_parsed_output_to_dict(self):
"""ParsedOutput.to_dict()"""
result = self.parser.parse("permission denied", 1)
d = result.to_dict()
assert d["exit_code"] == 1
assert d["is_error"] is True
assert d["error_type"] == "permission_denied"
assert isinstance(d["suggestions"], list)
def test_parse_chinese_error_messages(self):
"""中文错误消息解析"""
result = self.parser.parse("权限不足: 无法访问", 1)
assert result.is_error is True
assert result.error_type == ErrorType.PERMISSION_DENIED
def test_parse_multiline_output_message_is_last_line(self):
"""多行输出取最后一行作为消息"""
output = "line1\nline2\nline3"
result = self.parser.parse(output, 0)
assert result.message == "line3"
# ============================================================
# TerminalSession 测试
# ============================================================
class TestTerminalSession:
"""测试 TerminalSession 会话状态管理"""
def test_construction_default(self):
"""默认构造"""
session = TerminalSession(session_id="test")
assert session.session_id == "test"
assert session.cwd == os.getcwd()
assert isinstance(session.env, dict)
assert session.history == []
def test_construction_custom_cwd(self):
"""自定义工作目录"""
session = TerminalSession(session_id="test", cwd="/tmp")
assert session.cwd == "/tmp"
def test_construction_custom_env(self):
"""自定义环境变量"""
session = TerminalSession(session_id="test", env={"FOO": "bar"})
assert session.env.get("FOO") == "bar"
def test_set_cwd(self):
"""手动设置 cwd"""
session = TerminalSession(session_id="test")
session.set_cwd("/usr/local")
assert session.cwd == "/usr/local"
def test_set_env(self):
"""手动设置环境变量"""
session = TerminalSession(session_id="test")
session.set_env("MY_VAR", "hello")
assert session.env.get("MY_VAR") == "hello"
def test_update_env(self):
"""批量更新环境变量"""
session = TerminalSession(session_id="test")
session.update_env({"A": "1", "B": "2"})
assert session.env.get("A") == "1"
assert session.env.get("B") == "2"
def test_get_env_returns_copy(self):
"""get_env 返回副本,修改不影响原数据"""
session = TerminalSession(session_id="test")
env = session.get_env()
env["HACKED"] = "yes"
assert "HACKED" not in session.env
def test_get_history_returns_copy(self):
"""get_history 返回副本"""
session = TerminalSession(session_id="test")
history = session.get_history()
assert history is not session._history
@pytest.mark.asyncio
async def test_execute_simple_command(self):
"""执行简单命令"""
session = TerminalSession(session_id="test")
result = await session.execute("echo hello")
assert result.exit_code == 0
assert "hello" in result.raw_output
@pytest.mark.asyncio
async def test_execute_records_history(self):
"""执行命令记录历史"""
session = TerminalSession(session_id="test")
await session.execute("echo first")
await session.execute("echo second")
assert len(session.history) == 2
assert session.history[0].command == "echo first"
assert session.history[1].command == "echo second"
@pytest.mark.asyncio
async def test_cross_command_cwd(self):
"""跨命令保持 cwdcd 后 pwd 返回正确目录"""
session = TerminalSession(session_id="test")
await session.execute("cd /tmp")
assert session.cwd == "/tmp"
result = await session.execute("pwd")
assert "/tmp" in result.raw_output
@pytest.mark.asyncio
async def test_cross_command_env(self):
"""跨命令保持 envexport 后 echo 返回正确值"""
session = TerminalSession(session_id="test")
await session.execute("export MY_TEST_VAR=hello123")
assert session.env.get("MY_TEST_VAR") == "hello123"
result = await session.execute("echo $MY_TEST_VAR")
assert "hello123" in result.raw_output
@pytest.mark.asyncio
async def test_cd_relative_path(self):
"""cd 相对路径(目录存在时更新 cwd"""
# 使用 /usr 作为基础目录cd local/usr/local 存在)
session = TerminalSession(session_id="test", cwd="/usr")
await session.execute("cd local")
assert session.cwd == "/usr/local"
@pytest.mark.asyncio
async def test_cd_absolute_path(self):
"""cd 绝对路径"""
session = TerminalSession(session_id="test")
await session.execute("cd /usr")
assert session.cwd == "/usr"
@pytest.mark.asyncio
async def test_failed_command_no_state_update(self):
"""失败命令不更新状态"""
session = TerminalSession(session_id="test", cwd="/tmp")
await session.execute("cd /nonexistent_dir_xyz")
# cd 失败cwd 不应更新
assert session.cwd == "/tmp"
@pytest.mark.asyncio
async def test_timeout(self):
"""命令超时"""
session = TerminalSession(session_id="test")
result = await session.execute("sleep 10", timeout=0.5)
assert result.exit_code == -1
assert result.is_error is True
@pytest.mark.asyncio
async def test_max_history(self):
"""历史记录上限"""
session = TerminalSession(session_id="test", max_history=3)
for i in range(5):
await session.execute(f"echo {i}")
assert len(session.history) == 3
assert session.history[0].command == "echo 2"
def test_close(self):
"""关闭会话"""
session = TerminalSession(session_id="test")
session.close() # 不应抛出异常
# ============================================================
# TerminalSessionManager 测试
# ============================================================
class TestTerminalSessionManager:
"""测试 TerminalSessionManager 会话管理"""
def test_get_or_create_new(self):
"""创建新会话"""
manager = TerminalSessionManager()
session = manager.get_or_create("s1")
assert session.session_id == "s1"
def test_get_or_create_existing(self):
"""获取已有会话"""
manager = TerminalSessionManager()
s1 = manager.get_or_create("s1")
s1.set_cwd("/tmp")
s2 = manager.get_or_create("s1")
assert s2.cwd == "/tmp"
def test_get_existing(self):
"""get 获取已有会话"""
manager = TerminalSessionManager()
manager.get_or_create("s1")
session = manager.get("s1")
assert session is not None
def test_get_nonexistent(self):
"""get 不存在的会话返回 None"""
manager = TerminalSessionManager()
assert manager.get("nonexistent") is None
def test_remove(self):
"""移除会话"""
manager = TerminalSessionManager()
manager.get_or_create("s1")
manager.remove("s1")
assert manager.get("s1") is None
def test_list_sessions(self):
"""列出会话"""
manager = TerminalSessionManager()
manager.get_or_create("s1")
manager.get_or_create("s2")
assert sorted(manager.list_sessions()) == ["s1", "s2"]
def test_has_session(self):
"""检查会话是否存在"""
manager = TerminalSessionManager()
manager.get_or_create("s1")
assert manager.has_session("s1") is True
assert manager.has_session("s2") is False
def test_max_sessions_eviction(self):
"""超过最大会话数时移除最旧会话"""
manager = TerminalSessionManager(max_sessions=2)
manager.get_or_create("s1")
manager.get_or_create("s2")
manager.get_or_create("s3") # 应该移除 s1
assert not manager.has_session("s1")
assert manager.has_session("s2")
assert manager.has_session("s3")
def test_close_all(self):
"""关闭所有会话"""
manager = TerminalSessionManager()
manager.get_or_create("s1")
manager.get_or_create("s2")
manager.close_all()
assert manager.list_sessions() == []
# ============================================================
# ShellTool 测试
# ============================================================
class TestShellToolConstruction:
"""测试 ShellTool 构造"""
def test_default_construction(self):
tool = ShellTool()
assert tool.name == "shell"
assert tool.input_schema is not None
assert "command" in tool.input_schema["properties"]
assert "session_id" in tool.input_schema["properties"]
assert tool.input_schema["required"] == ["command"]
def test_custom_construction(self):
tool = ShellTool(name="my_shell", version="2.0.0")
assert tool.name == "my_shell"
assert tool.version == "2.0.0"
def test_to_dict(self):
tool = ShellTool()
d = tool.to_dict()
assert d["name"] == "shell"
assert "input_schema" in d
def test_repr(self):
tool = ShellTool()
r = repr(tool)
assert "ShellTool" in r
assert "shell" in r
class TestShellToolExecution:
"""测试 ShellTool 命令执行"""
@pytest.mark.asyncio
async def test_execute_simple_command(self):
"""执行简单命令(无会话模式)"""
tool = ShellTool()
result = await tool.execute(command="echo hello")
assert result["exit_code"] == 0
assert "hello" in result["output"]
assert result["is_error"] is False
assert result["session_id"] is None
@pytest.mark.asyncio
async def test_execute_missing_command(self):
"""缺少 command 参数"""
tool = ShellTool()
result = await tool.execute()
assert result["is_error"] is True
assert result["exit_code"] == 1
@pytest.mark.asyncio
async def test_execute_with_working_dir(self):
"""指定工作目录"""
tool = ShellTool()
result = await tool.execute(command="pwd", working_dir="/tmp")
assert result["exit_code"] == 0
assert "/tmp" in result["output"]
@pytest.mark.asyncio
async def test_execute_with_session(self):
"""会话模式执行命令"""
tool = ShellTool()
result = await tool.execute(command="echo session_test", session_id="s1")
assert result["exit_code"] == 0
assert "session_test" in result["output"]
assert result["session_id"] == "s1"
@pytest.mark.asyncio
async def test_session_preserves_cwd(self):
"""会话模式保持 cwd"""
tool = ShellTool()
await tool.execute(command="cd /tmp", session_id="cwd-test")
result = await tool.execute(command="pwd", session_id="cwd-test")
assert "/tmp" in result["output"]
@pytest.mark.asyncio
async def test_session_preserves_env(self):
"""会话模式保持 env"""
tool = ShellTool()
await tool.execute(
command="export SHELL_TEST_VAR=world", session_id="env-test"
)
result = await tool.execute(
command="echo $SHELL_TEST_VAR", session_id="env-test"
)
assert "world" in result["output"]
@pytest.mark.asyncio
async def test_no_session_id_backward_compatible(self):
"""无 session_id 时保持现有行为"""
tool = ShellTool()
result = await tool.execute(command="echo no_session")
assert result["exit_code"] == 0
assert "no_session" in result["output"]
assert result["session_id"] is None
@pytest.mark.asyncio
async def test_different_sessions_independent(self):
"""不同会话互不影响"""
tool = ShellTool()
await tool.execute(command="cd /tmp", session_id="s1")
await tool.execute(command="cd /usr", session_id="s2")
r1 = await tool.execute(command="pwd", session_id="s1")
r2 = await tool.execute(command="pwd", session_id="s2")
assert "/tmp" in r1["output"]
assert "/usr" in r2["output"]
class TestShellToolSecurity:
"""测试 ShellTool 安全控制"""
@pytest.mark.asyncio
async def test_safe_command_allowed(self):
"""安全命令直接执行"""
tool = ShellTool()
result = await tool.execute(command="ls /tmp")
assert result["exit_code"] == 0
@pytest.mark.asyncio
async def test_dangerous_command_blocked_without_callback(self):
"""危险命令无确认回调时被拒绝"""
tool = ShellTool()
result = await tool.execute(command="rm -rf /tmp/test")
assert result["is_error"] is True
assert result["exit_code"] == 126
@pytest.mark.asyncio
async def test_dangerous_command_confirmed(self):
"""危险命令通过确认回调允许执行"""
confirm = AsyncMock(return_value=True)
tool = ShellTool(confirm_callback=confirm)
result = await tool.execute(command="rm -rf /tmp/nonexistent_test_dir")
assert confirm.called
# 命令本身可能失败(目录不存在),但不应被安全机制拒绝
assert result["exit_code"] != 126 or not result["is_error"]
@pytest.mark.asyncio
async def test_dangerous_command_rejected_by_callback(self):
"""确认回调拒绝危险命令"""
confirm = AsyncMock(return_value=False)
tool = ShellTool(confirm_callback=confirm)
result = await tool.execute(command="rm -rf /tmp/test")
assert result["is_error"] is True
assert result["exit_code"] == 126
@pytest.mark.asyncio
async def test_audit_log_recorded(self):
"""审计日志记录"""
tool = ShellTool()
await tool.execute(command="echo audit_test")
assert len(tool.audit_log) > 0
assert tool.audit_log[0]["command"] == "echo audit_test"
@pytest.mark.asyncio
async def test_blocked_command_in_audit_log(self):
"""被阻止的命令记录在审计日志"""
tool = ShellTool()
await tool.execute(command="rm -rf /tmp/test")
blocked_entries = [e for e in tool.audit_log if e.get("blocked")]
assert len(blocked_entries) > 0
@pytest.mark.asyncio
async def test_git_push_force_is_dangerous(self):
"""git push --force 是危险命令"""
tool = ShellTool()
result = await tool.execute(command="git push --force origin main")
assert result["is_error"] is True
assert result["exit_code"] == 126
@pytest.mark.asyncio
async def test_git_status_is_safe(self):
"""git status 是安全命令"""
tool = ShellTool()
result = await tool.execute(command="git status")
# git status 可能在非 git 目录失败,但不应被安全机制拒绝
assert result["exit_code"] != 126
class TestShellToolOutputParsing:
"""测试 ShellTool 输出解析集成"""
@pytest.mark.asyncio
async def test_error_output_structured(self):
"""错误输出结构化"""
tool = ShellTool()
result = await tool.execute(command="ls /nonexistent_dir_xyz_12345")
assert result["is_error"] is True
assert result["error_type"] in ("not_found", "unknown")
assert isinstance(result["suggestions"], list)
@pytest.mark.asyncio
async def test_success_output_not_error(self):
"""成功输出不标记为错误"""
tool = ShellTool()
result = await tool.execute(command="echo success")
assert result["is_error"] is False
assert result["error_type"] == "none"
class TestShellToolSessionManager:
"""测试 ShellTool 会话管理器访问"""
def test_session_manager_accessible(self):
tool = ShellTool()
assert tool.session_manager is not None
@pytest.mark.asyncio
async def test_session_created_on_first_use(self):
"""首次使用 session_id 时创建会话"""
tool = ShellTool()
assert not tool.session_manager.has_session("new-session")
await tool.execute(command="echo test", session_id="new-session")
assert tool.session_manager.has_session("new-session")