@dataclass
class PathUpdateResult:
    """Outcome of evaluating a candidate path against the current recommendation.

    Attributes:
        updated: True when the recommended path was replaced, or set for the
            first time.
        old_path: The recommendation that was in place before the evaluation.
            None when no recommendation existed, and also None in the
            "insufficient samples" result, which is produced before the
            current recommendation is looked up.
        new_path: The candidate path that was evaluated. It is populated even
            when no update happened, so callers should test ``updated``
            rather than checking this field for None.
        reason: Human-readable explanation of why the update did or did not
            happen.
    """

    updated: bool = False
    old_path: ExecutionPath | None = None
    new_path: ExecutionPath | None = None
    reason: str = ""
其他情况 → 保留现有推荐路径 + """ + + def __init__( + self, + experience_store: InMemoryExperienceStore | None = None, + min_sample_count: int = 3, + success_rate_threshold: float = 0.05, + duration_improvement_threshold: float = 0.2, + ): + """初始化 PathOptimizer + + Args: + experience_store: 经验存储实例(可选) + min_sample_count: 最小样本量,低于此值不做决策 + success_rate_threshold: 成功率提升阈值,超过此值视为显著提升 + duration_improvement_threshold: 耗时改善比例阈值,超过此值视为显著改善 + """ + self._experience_store = experience_store + self._min_sample_count = min_sample_count + self._success_rate_threshold = success_rate_threshold + self._duration_improvement_threshold = duration_improvement_threshold + self._recommended_paths: dict[str, ExecutionPath] = {} + self._pending_paths: dict[str, list[ExecutionPath]] = {} + + def get_recommended_path(self, task_type: str) -> ExecutionPath | None: + """获取指定任务类型的当前推荐路径 + + Args: + task_type: 任务类型 + + Returns: + 推荐路径,若无则返回 None + """ + return self._recommended_paths.get(task_type) + + async def evaluate_and_update( + self, + task_type: str, + new_path: ExecutionPath, + ) -> PathUpdateResult: + """评估新路径并决定是否更新推荐路径 + + Args: + task_type: 任务类型 + new_path: 新的执行路径 + + Returns: + 路径更新结果 + """ + # 确保新路径有 path_id + if not new_path.path_id: + new_path.path_id = str(uuid.uuid4()) + + new_path.task_type = task_type + + # 样本量不足 → 不更新,记录待观察 + if new_path.sample_count < self._min_sample_count: + self._pending_paths.setdefault(task_type, []).append(new_path) + reason = ( + f"样本量不足({new_path.sample_count} < {self._min_sample_count})," + f"记录待观察" + ) + logger.info( + f"Path not updated for '{task_type}': {reason}" + ) + return PathUpdateResult( + updated=False, + old_path=None, + new_path=new_path, + reason=reason, + ) + + current = self._recommended_paths.get(task_type) + + # 无现有推荐路径 → 直接设为推荐 + if current is None: + new_path.is_recommended = True + self._recommended_paths[task_type] = new_path + reason = "无现有推荐路径,直接设为推荐" + logger.info(f"Path set as recommended for '{task_type}': {reason}") + return 
PathUpdateResult( + updated=True, + old_path=None, + new_path=new_path, + reason=reason, + ) + + # 比较新路径与现有推荐路径 + return self._compare_and_decide(task_type, current, new_path) + + def _compare_and_decide( + self, + task_type: str, + current: ExecutionPath, + new: ExecutionPath, + ) -> PathUpdateResult: + """比较新旧路径并决策 + + 比较逻辑: + 1. 新路径成功率 > 现有成功率 + threshold → 更新 + 2. 成功率相近(差值 ≤ threshold)且新耗时显著更短 → 更新 + 3. 其他 → 保留现有 + """ + sr_diff = new.success_rate - current.success_rate + + # 条件 1:成功率显著提升 + if sr_diff > self._success_rate_threshold: + return self._apply_update( + task_type, current, new, + f"成功率显著提升({new.success_rate:.2f} > {current.success_rate:.2f}," + f"提升 {sr_diff:.2f})", + ) + + # 条件 2:成功率相近但耗时显著更短 + if abs(sr_diff) <= self._success_rate_threshold: + if current.total_duration > 0: + duration_improvement = ( + (current.total_duration - new.total_duration) / current.total_duration + ) + if ( + new.total_duration < current.total_duration + and duration_improvement > self._duration_improvement_threshold + ): + return self._apply_update( + task_type, current, new, + f"成功率相近({new.success_rate:.2f} vs {current.success_rate:.2f})," + f"耗时显著更短({new.total_duration:.1f}s vs {current.total_duration:.1f}s," + f"改善 {duration_improvement:.1%})", + ) + elif current.total_duration == 0 and new.total_duration > 0: + # 现有路径耗时为 0(不太可能),不更新 + pass + elif current.total_duration == 0 and new.total_duration == 0: + # 两者耗时均为 0,不更新 + pass + + # 条件 3:无明显优势 → 保留现有 + reason = ( + f"新路径无明显优势(成功率 {new.success_rate:.2f} vs {current.success_rate:.2f}," + f"耗时 {new.total_duration:.1f}s vs {current.total_duration:.1f}s),保留现有推荐路径" + ) + logger.info(f"Path not updated for '{task_type}': {reason}") + return PathUpdateResult( + updated=False, + old_path=current, + new_path=new, + reason=reason, + ) + + def _apply_update( + self, + task_type: str, + old: ExecutionPath, + new: ExecutionPath, + reason: str, + ) -> PathUpdateResult: + """应用路径更新""" + old.is_recommended = False + new.is_recommended = 
class PitfallDetector:
    """Warn about planned steps that resemble steps which failed in the past.

    Workflow of :meth:`check_pitfalls`:
        1. Fetch stored experiences for the task type from the experience
           store.
        2. Aggregate per-step occurrence and failure counts from each
           experience's ``steps_summary``.
        3. Match every planned step to the most similar historical step name
           using keyword-overlap similarity.
        4. Convert the historical failure rate into a warning level.

    Warning levels:
        - HIGH: failure_rate >= 0.5
        - MEDIUM: failure_rate >= 0.2
        - LOW: at least one recorded failure, below the MEDIUM threshold

    Usage:
        detector = PitfallDetector(experience_store)
        warnings = await detector.check_pitfalls(
            task_type="code_review",
            planned_steps=[plan_step1, plan_step2],
        )
    """

    def __init__(
        self,
        experience_store: ExperienceStoreProtocol,
        similarity_threshold: float = 0.3,
        max_search_results: int = 50,
    ):
        """
        Args:
            experience_store: Store exposing an async ``search`` method.
            similarity_threshold: Minimum keyword similarity for a planned
                step to be considered a match for a historical step.
            max_search_results: Maximum number of experiences requested from
                the store.
        """
        self._store = experience_store
        self._similarity_threshold = similarity_threshold
        self._max_search_results = max_search_results

    async def check_pitfalls(
        self,
        task_type: str,
        planned_steps: list[Any],
    ) -> list[PitfallWarning]:
        """Check the planned steps against historical failures.

        Args:
            task_type: Task type used to look up experiences.
            planned_steps: Objects exposing ``name`` and optionally
                ``description`` attributes.

        Returns:
            Warnings ordered HIGH, MEDIUM, LOW; within a level the higher
            failure rate comes first. Empty when there are no planned steps,
            no experiences, or no matching step with recorded failures.
        """
        if not planned_steps:
            return []

        # Successes and failures are both needed to compute a failure rate.
        all_experiences = await self._search_experiences(task_type)
        if not all_experiences:
            logger.debug(f"No experiences found for task_type={task_type}")
            return []

        step_failure_stats = self._extract_step_failure_stats(all_experiences)
        warnings = self._match_and_warn(planned_steps, step_failure_stats)
        warnings.sort(key=lambda w: (_warning_level_order(w.warning_level), -w.failure_rate))

        if warnings:
            logger.info(
                f"PitfallDetector found {len(warnings)} warnings for task_type={task_type}: "
                f"{sum(1 for w in warnings if w.warning_level == WarningLevel.HIGH)} HIGH, "
                f"{sum(1 for w in warnings if w.warning_level == WarningLevel.MEDIUM)} MEDIUM, "
                f"{sum(1 for w in warnings if w.warning_level == WarningLevel.LOW)} LOW"
            )

        return warnings

    async def _search_experiences(self, task_type: str) -> list[Any]:
        """Fetch experiences for ``task_type``; return [] if the store fails.

        Detection is advisory, so a store error is logged and treated as
        "no history" instead of being propagated to the caller.
        """
        try:
            results = await self._store.search(
                query=task_type,
                top_k=self._max_search_results,
                task_type=task_type,
            )
            return results
        except Exception as e:
            logger.error(f"Failed to search experiences for pitfall detection: {e}")
            return []

    def _extract_step_failure_stats(
        self, failed_experiences: list[Any]
    ) -> dict[str, _StepFailureStats]:
        """Aggregate per-step occurrence and failure counts.

        Only experiences whose ``steps_summary`` is a list of dicts contribute;
        string summaries carry no step-level detail and are skipped. Each dict
        is read for ``step_name``, ``outcome`` and ``error``.

        An experience's ``optimization_tips`` are attached only to the steps
        that appear in that same experience, so the result does not depend on
        the order of the experiences and tips do not leak onto unrelated steps.

        Args:
            failed_experiences: Experiences to aggregate. Despite the name,
                successful experiences are expected too; they supply the
                denominator of the failure rate.

        Returns:
            Mapping of step name to its aggregated statistics.
        """
        stats: dict[str, _StepFailureStats] = {}

        for exp in failed_experiences:
            steps_summary = getattr(exp, "steps_summary", None)
            if not isinstance(steps_summary, list):
                continue

            # Ordered, de-duplicated names of the steps in this experience.
            steps_in_experience: dict[str, None] = {}

            for step in steps_summary:
                if not isinstance(step, dict):
                    continue

                step_name = step.get("step_name", "")
                if not step_name:
                    continue

                entry = stats.get(step_name)
                if entry is None:
                    entry = stats[step_name] = _StepFailureStats(
                        step_name=step_name,
                        total_occurrences=0,
                        failure_occurrences=0,
                        failure_reasons=[],
                        optimization_tips=[],
                    )

                entry.total_occurrences += 1
                steps_in_experience[step_name] = None

                if step.get("outcome", "") in ("failure", "failed", "error"):
                    entry.failure_occurrences += 1
                    error = step.get("error", "")
                    if error:
                        entry.failure_reasons.append(error)

            tips = getattr(exp, "optimization_tips", None)
            if tips:
                for name in steps_in_experience:
                    stats[name].optimization_tips.extend(tips)

        return stats

    def _match_and_warn(
        self,
        planned_steps: list[Any],
        step_failure_stats: dict[str, _StepFailureStats],
    ) -> list[PitfallWarning]:
        """Match planned steps to historical statistics and build warnings.

        A planned step yields a warning only when its best-matching
        historical step reaches the similarity threshold and has at least one
        recorded failure. Steps that have only ever succeeded are not
        reported.
        """
        warnings: list[PitfallWarning] = []

        for step in planned_steps:
            step_name = getattr(step, "name", "")
            # A None description would otherwise become the keyword "none".
            step_description = getattr(step, "description", "") or ""

            if not step_name:
                continue

            best_match: _StepFailureStats | None = None
            best_similarity = 0.0
            for stats_step_name, stats in step_failure_stats.items():
                similarity = _compute_name_similarity(
                    step_name, step_description, stats_step_name
                )
                if similarity > best_similarity:
                    best_similarity = similarity
                    best_match = stats

            if best_match is None or best_similarity < self._similarity_threshold:
                continue

            # No recorded failure means there is nothing to warn about.
            if best_match.failure_occurrences <= 0 or best_match.total_occurrences <= 0:
                continue

            failure_rate = best_match.failure_occurrences / best_match.total_occurrences

            warnings.append(
                PitfallWarning(
                    step_name=step_name,
                    warning_level=_determine_warning_level(failure_rate),
                    failure_rate=round(failure_rate, 4),
                    # Keep at most five historical reasons per warning.
                    historical_failures=best_match.failure_reasons[:5],
                    suggestion=_build_suggestion(best_match, failure_rate),
                )
            )

        return warnings
Jaccard 相似度,同时考虑 step_name 和 step_description。 + + Args: + step_name: 当前计划步骤名称 + step_description: 当前计划步骤描述 + historical_step_name: 历史步骤名称 + + Returns: + 相似度分数(0.0 ~ 1.0) + """ + # 提取关键词:将名称拆分为词,过滤掉常见停用词 + current_keywords = _extract_keywords(f"{step_name} {step_description}") + historical_keywords = _extract_keywords(historical_step_name) + + if not current_keywords or not historical_keywords: + return 0.0 + + # Jaccard 相似度 + intersection = current_keywords & historical_keywords + union = current_keywords | historical_keywords + + if not union: + return 0.0 + + return len(intersection) / len(union) + + +_STOP_WORDS = frozenset({ + "a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for", + "of", "with", "by", "from", "is", "are", "was", "were", "be", "been", + "being", "have", "has", "had", "do", "does", "did", "will", "would", + "could", "should", "may", "might", "can", "shall", "not", "no", +}) + + +def _extract_keywords(text: str) -> frozenset[str]: + """从文本中提取关键词集合 + + 转小写、按空白/下划线/连字符拆分、过滤停用词和单字符词。 + """ + # 统一分隔符 + normalized = text.lower().replace("_", " ").replace("-", " ") + words = normalized.split() + return frozenset( + w for w in words + if len(w) > 1 and w not in _STOP_WORDS + ) + + +def _determine_warning_level(failure_rate: float) -> WarningLevel: + """根据失败率确定预警级别 + + - HIGH: failure_rate >= 0.5 + - MEDIUM: failure_rate >= 0.2 + - LOW: 有任何失败记录 + """ + if failure_rate >= _HIGH_THRESHOLD: + return WarningLevel.HIGH + if failure_rate >= _MEDIUM_THRESHOLD: + return WarningLevel.MEDIUM + return WarningLevel.LOW + + +def _warning_level_order(level: WarningLevel) -> int: + """预警级别排序值(越小越严重)""" + return { + WarningLevel.HIGH: 0, + WarningLevel.MEDIUM: 1, + WarningLevel.LOW: 2, + }[level] + + +def _build_suggestion(stats: _StepFailureStats, failure_rate: float) -> str: + """根据失败统计生成优化建议""" + parts: list[str] = [] + + if failure_rate >= _HIGH_THRESHOLD: + parts.append(f"该步骤历史失败率高达 {failure_rate:.0%},建议重点关注") + elif failure_rate >= 
class ErrorType(Enum):
    """Categories of command failure recognised by :class:`OutputParser`."""

    NONE = "none"
    PERMISSION_DENIED = "permission_denied"
    NOT_FOUND = "not_found"
    TIMEOUT = "timeout"
    SYNTAX_ERROR = "syntax_error"
    CONNECTION_REFUSED = "connection_refused"
    OUT_OF_MEMORY = "out_of_memory"
    DISK_FULL = "disk_full"
    ALREADY_EXISTS = "already_exists"
    INVALID_ARGUMENT = "invalid_argument"
    PROCESS_NOT_FOUND = "process_not_found"
    NETWORK_ERROR = "network_error"
    UNKNOWN = "unknown"


@dataclass
class ParsedOutput:
    """Structured view of a finished command.

    Attributes:
        exit_code: Exit code of the command.
        is_error: True when the exit code is non-zero.
        error_type: Failure category; ``ErrorType.NONE`` when not an error.
        message: Short summary (the last non-empty output line).
        raw_output: Unmodified output text.
        suggestions: Actionable hints for the detected failure.
    """

    exit_code: int
    is_error: bool
    error_type: ErrorType = ErrorType.NONE
    message: str = ""
    raw_output: str = ""
    suggestions: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return a plain-dict summary; ``raw_output`` is not included."""
        return {
            "exit_code": self.exit_code,
            "is_error": self.is_error,
            "error_type": self.error_type.value,
            "message": self.message,
            "suggestions": self.suggestions,
        }


# Matching rules: (pattern, error_type, message_template, suggestions).
# The first matching rule wins, so specific rules must precede generic ones:
# PROCESS_NOT_FOUND comes before NOT_FOUND because "process not found" and
# "进程不存在" also contain the generic "not found" / "不存在" alternatives.
_ERROR_PATTERNS: list[tuple[re.Pattern, ErrorType, str, list[str]]] = [
    (
        re.compile(r"permission denied|access denied|权限不足|拒绝访问", re.IGNORECASE),
        ErrorType.PERMISSION_DENIED,
        "权限不足",
        [
            "尝试使用 sudo 执行该命令",
            "检查文件/目录权限: ls -la ",
            "确认当前用户是否有所需权限",
        ],
    ),
    (
        re.compile(
            r"no such process|process not found|进程不存在|进程未找到",
            re.IGNORECASE,
        ),
        ErrorType.PROCESS_NOT_FOUND,
        "进程不存在",
        [
            "确认进程 ID 是否正确",
            "使用 ps aux 查看运行中的进程",
            "进程可能已经结束",
        ],
    ),
    (
        re.compile(
            r"not found|no such file|no such directory|找不到|不存在|无法找到",
            re.IGNORECASE,
        ),
        ErrorType.NOT_FOUND,
        "文件或目录不存在",
        [
            "检查路径拼写是否正确",
            "使用 ls 确认文件/目录是否存在",
            "检查是否在正确的工作目录下",
        ],
    ),
    (
        re.compile(r"timed?\s*out|timeout|超时|时间超限", re.IGNORECASE),
        ErrorType.TIMEOUT,
        "命令执行超时",
        [
            "增加超时时间",
            "检查网络连接是否正常",
            "检查目标服务是否可达",
        ],
    ),
    (
        re.compile(
            r"syntax error|syntaxerror|parse error|语法错误|解析错误",
            re.IGNORECASE,
        ),
        ErrorType.SYNTAX_ERROR,
        "语法错误",
        [
            "检查命令语法是否正确",
            "使用 --help 查看命令用法",
            "检查引号和特殊字符是否正确转义",
        ],
    ),
    (
        re.compile(
            r"connection refused|连接被拒绝|无法连接|ECONNREFUSED",
            re.IGNORECASE,
        ),
        ErrorType.CONNECTION_REFUSED,
        "连接被拒绝",
        [
            "检查目标服务是否已启动",
            "确认端口号是否正确",
            "检查防火墙设置是否阻止了连接",
        ],
    ),
    (
        # "oom" must not be embedded in a longer word ("room", "zoom"), but
        # may be followed by "-" or "_" as in "oom-killer" / "oom_kill".
        re.compile(
            r"out of memory|(?<![a-z])oom(?![a-z])|cannot allocate|内存不足|内存溢出",
            re.IGNORECASE,
        ),
        ErrorType.OUT_OF_MEMORY,
        "内存不足",
        [
            "释放不必要的内存占用",
            "增加系统可用内存",
            "检查是否有内存泄漏",
        ],
    ),
    (
        re.compile(
            r"no space left|disk full|磁盘已满|空间不足|ENOSPC",
            re.IGNORECASE,
        ),
        ErrorType.DISK_FULL,
        "磁盘空间不足",
        [
            "清理不必要的文件: du -sh * | sort -rh | head",
            "检查磁盘使用情况: df -h",
            "删除临时文件或日志",
        ],
    ),
    (
        re.compile(
            r"already exists|file exists|已存在|重复|EEXIST",
            re.IGNORECASE,
        ),
        ErrorType.ALREADY_EXISTS,
        "资源已存在",
        [
            "使用 -f 参数强制覆盖(如适用)",
            "先删除已有资源再重新创建",
            "使用不同名称创建",
        ],
    ),
    (
        re.compile(
            r"invalid argument|illegal option|bad option|无效参数|非法选项|invalid option",
            re.IGNORECASE,
        ),
        ErrorType.INVALID_ARGUMENT,
        "无效参数",
        [
            "检查命令参数是否正确",
            "使用 --help 查看支持的参数",
            "确认参数值类型和范围",
        ],
    ),
    (
        re.compile(
            r"network is unreachable|no route to host|name resolution|网络不可达|无法解析|ENETUNREACH",
            re.IGNORECASE,
        ),
        ErrorType.NETWORK_ERROR,
        "网络错误",
        [
            "检查网络连接是否正常",
            "确认 DNS 解析是否正常: nslookup ",
            "检查代理设置",
        ],
    ),
]


class OutputParser:
    """Turn raw command output plus an exit code into a :class:`ParsedOutput`.

    A non-zero exit code marks the result as an error; the error category
    and suggestions are then derived from the output text, falling back to
    well-known exit codes.
    """

    def parse(self, output: str, exit_code: int) -> ParsedOutput:
        """Parse a finished command.

        Args:
            output: Combined stdout and stderr text.
            exit_code: Exit code of the command.

        Returns:
            The structured result. ``error_type`` is ``ErrorType.NONE`` and
            ``suggestions`` is empty when ``exit_code`` is 0.
        """
        is_error = exit_code != 0
        error_type = ErrorType.NONE
        suggestions: list[str] = []

        if is_error:
            error_type, suggestions = self._classify_error(output, exit_code)

        return ParsedOutput(
            exit_code=exit_code,
            is_error=is_error,
            error_type=error_type,
            message=self._extract_message(output),
            raw_output=output,
            suggestions=suggestions,
        )

    def _extract_message(self, output: str) -> str:
        """Return the last non-empty output line, truncated to 200 characters.

        Truncated messages get a trailing ``...``. Empty or whitespace-only
        output yields an empty string.
        """
        if not output:
            return ""

        lines = [line.strip() for line in output.strip().splitlines() if line.strip()]
        if not lines:
            return ""

        message = lines[-1]
        if len(message) > 200:
            message = message[:200] + "..."
        return message

    def _classify_error(
        self, output: str, exit_code: int
    ) -> tuple[ErrorType, list[str]]:
        """Classify a failed command.

        The output text is checked against ``_ERROR_PATTERNS`` first; if
        nothing matches, exit codes 126, 127 and 130 are mapped to a
        category, and everything else is ``ErrorType.UNKNOWN``.

        Args:
            output: Command output.
            exit_code: Exit code of the command.

        Returns:
            ``(error_type, suggestions)``. The list is always a fresh copy,
            so callers may modify it without affecting later results.
        """
        for pattern, error_type, _msg, suggestions in _ERROR_PATTERNS:
            if pattern.search(output):
                # Copy: the rule table is shared module state.
                return error_type, list(suggestions)

        if exit_code == 126:
            return ErrorType.PERMISSION_DENIED, [
                "检查文件是否有执行权限: chmod +x ",
                "确认文件格式是否正确(如行尾符)",
            ]
        if exit_code == 127:
            return ErrorType.NOT_FOUND, [
                "检查命令是否已安装",
                "确认命令名称拼写是否正确",
                "检查 PATH 环境变量是否包含命令所在目录",
            ]
        if exit_code == 130:
            # 130 is reported by shells for a command interrupted by Ctrl+C.
            return ErrorType.TIMEOUT, [
                "命令被 Ctrl+C 中断",
                "可能需要增加超时时间",
            ]

        return ErrorType.UNKNOWN, [
            "检查命令输出中的错误信息",
            "使用 --verbose 或 --debug 获取更多详情",
        ]
async def start(self) -> None:
    """Create the pseudo-terminal pair used by this session.

    Does nothing when the session is already running. Otherwise a
    master/slave pair is opened, the slave is given a 24x80 window size
    (best effort) and the master is switched to non-blocking mode.

    Raises:
        OSError: If the pair cannot be created or the master cannot be
            configured. Both descriptors are closed before the error
            propagates, so a failed start leaks nothing.
    """
    if self._running:
        return

    master_fd, slave_fd = os.openpty()
    try:
        # The window size is cosmetic; failing to set it must not prevent
        # the session from starting, but it is worth a debug trace.
        try:
            winsize = struct.pack("HHHH", 24, 80, 0, 0)
            fcntl.ioctl(slave_fd, termios.TIOCSWINSZ, winsize)
        except Exception as exc:
            logger.debug("PTY window size could not be set: %s", exc)

        # Reads on the master must never block the event loop.
        flags = fcntl.fcntl(master_fd, fcntl.F_GETFL)
        fcntl.fcntl(master_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
    except BaseException:
        os.close(master_fd)
        os.close(slave_fd)
        raise

    self._master_fd, self._slave_fd = master_fd, slave_fd
    self._running = True
    logger.debug("PTY session started: master=%d slave=%d", self._master_fd, self._slave_fd)
async def run_command(
    self,
    command: str,
    timeout: float | None = None,
    cwd: str | None = None,
    env: dict[str, str] | None = None,
) -> PTYOutput:
    """Run ``command`` on the pseudo-terminal and wait for it to finish.

    The session is started on demand. The slave end of the terminal is
    handed to the child and then closed in this process, so every command
    after the first gets a freshly created terminal pair.

    Args:
        command: Shell command line to execute.
        timeout: Seconds to wait; a falsy value selects the session default.
        cwd: Working directory for the command.
        env: Extra environment variables, merged over ``os.environ``.

    Returns:
        The collected output with the exit code. On timeout ``exit_code`` is
        -1, ``timed_out`` is True and the process is left running so the
        caller can still interact with it or call ``close()``.
    """
    # A previous command consumed the slave end. Without a new pair the
    # child would be spawned with stdin/stdout/stderr set to None and
    # os.close(None) below would raise TypeError.
    # NOTE(review): a process left over from a timed-out command loses its
    # terminal here — confirm that is acceptable for callers.
    if self._running and self._slave_fd is None:
        if self._master_fd is not None:
            try:
                os.close(self._master_fd)
            except OSError:
                pass
            self._master_fd = None
        self._running = False

    if not self._running:
        await self.start()

    timeout = timeout or self._default_timeout
    self._output_buffer = ""

    cmd_env = dict(os.environ)
    if env:
        cmd_env.update(env)

    # The child uses the slave end for all three standard streams.
    self._process = await asyncio.create_subprocess_shell(
        command,
        stdin=self._slave_fd,
        stdout=self._slave_fd,
        stderr=self._slave_fd,
        cwd=cwd,
        env=cmd_env,
        start_new_session=True,
    )

    # The child inherited its own copy; ours must be closed so that the
    # master reports end-of-output once the child exits.
    os.close(self._slave_fd)
    self._slave_fd = None

    try:
        exit_code = await self._read_until_exit(timeout)
    except asyncio.TimeoutError:
        self._output_buffer += "\n[PTY 命令执行超时]"
        return PTYOutput(
            output=self._output_buffer,
            exit_code=-1,
            timed_out=True,
        )

    return PTYOutput(
        output=self._output_buffer,
        exit_code=exit_code,
        timed_out=False,
    )
async def close(self) -> None:
    """Stop the session and release the process and both descriptors.

    A still-running process is asked to terminate and given five seconds;
    after that it is killed and reaped. Errors while closing descriptors
    are ignored because the descriptors may already be closed.
    """
    self._running = False

    proc = self._process
    if proc is not None and proc.returncode is None:
        try:
            proc.terminate()
            await asyncio.wait_for(proc.wait(), timeout=5.0)
        except ProcessLookupError:
            # The process exited on its own between the check and the signal.
            pass
        except asyncio.TimeoutError:
            try:
                proc.kill()
            except ProcessLookupError:
                pass
            else:
                # Reap the killed process so it does not linger as a zombie;
                # bounded so close() cannot hang indefinitely.
                try:
                    await asyncio.wait_for(proc.wait(), timeout=5.0)
                except asyncio.TimeoutError:
                    logger.warning("PTY process did not exit after kill")

    if self._master_fd is not None:
        try:
            os.close(self._master_fd)
        except OSError:
            pass
        self._master_fd = None

    if self._slave_fd is not None:
        try:
            os.close(self._slave_fd)
        except OSError:
            pass
        self._slave_fd = None

    self._process = None
    logger.debug("PTY session closed")
chunk.decode("utf-8", errors="replace") + self._output_buffer += text + else: + break + except (BlockingIOError, OSError): + break + await asyncio.sleep(0.01) + + async def _try_auto_respond(self, recent_output: str) -> None: + """检测提示并自动应答 + + Args: + recent_output: 最近的输出文本 + """ + import re + + for pattern, response in self._respond_rules: + if not response: + # 空响应规则(如密码提示)跳过 + continue + if re.search(pattern, recent_output, re.IGNORECASE | re.MULTILINE): + logger.debug("Auto-responding to prompt '%s' with '%s'", pattern, response) + await self.send(response) + break diff --git a/src/agentkit/tools/shell.py b/src/agentkit/tools/shell.py new file mode 100644 index 0000000..2d9bfdb --- /dev/null +++ b/src/agentkit/tools/shell.py @@ -0,0 +1,432 @@ +"""ShellTool - Shell 命令执行工具 + +支持无会话模式(向后兼容)和有会话模式(跨命令保持状态)。 +危险命令通过确认回调请求人工确认,所有操作记录审计日志。 +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import time +from typing import Any, Callable, Awaitable + +from agentkit.tools.base import Tool +from agentkit.tools.output_parser import OutputParser, ParsedOutput +from agentkit.tools.terminal_session import TerminalSession, TerminalSessionManager +from agentkit.tools.pty_session import PTYSession + +logger = logging.getLogger(__name__) + +# 安全白名单:这些命令前缀不需要确认 +_SAFE_COMMAND_PREFIXES: tuple[str, ...] = ( + "ls", + "cat", + "head", + "tail", + "grep", + "find", + "pwd", + "echo", + "which", + "whoami", + "id", + "date", + "uname", + "df", + "du", + "free", + "ps", + "top", + "env", + "printenv", + "type", + "file", + "stat", + "wc", + "sort", + "uniq", + "diff", + "git status", + "git log", + "git diff", + "git branch", + "git remote", + "pip list", + "pip show", + "python --version", + "python3 --version", + "node --version", + "npm list", + "docker ps", + "docker images", + "curl", + "wget", +) + +# 危险命令模式:这些命令需要人工确认 +_DANGEROUS_PATTERNS: tuple[str, ...] 
= ( + "rm ", + "rm -", + "rmdir", + "mkfs", + "dd ", + "format", + "del ", + "erase", + "> /dev/", + "shutdown", + "reboot", + "init 0", + "init 6", + "kill -9", + "killall", + "chmod 777", + "chown", + "mv /", + "pip uninstall", + "npm uninstall", + "apt remove", + "yum remove", + "brew uninstall", + "docker rm", + "docker rmi", + "git push --force", + "git reset --hard", + "git clean -f", + "drop table", + "drop database", + "truncate", +) + + +class ShellTool(Tool): + """Shell 命令执行工具 + + 支持两种模式: + 1. 无会话模式(默认):每次命令独立执行,不保持状态 + 2. 有会话模式:通过 session_id 指定会话,跨命令保持 cwd/env/history + + 安全控制: + - 危险命令通过 confirm_callback 请求人工确认 + - 所有操作记录审计日志 + + Usage: + # 无会话模式 + tool = ShellTool() + result = await tool.execute(command="ls -la") + + # 有会话模式 + result = await tool.execute(command="cd /tmp", session_id="build-01") + result = await tool.execute(command="pwd", session_id="build-01") # 输出 /tmp + """ + + def __init__( + self, + name: str = "shell", + description: str = "执行 Shell 命令,支持会话模式保持跨命令状态", + input_schema: dict[str, Any] | None = None, + output_schema: dict[str, Any] | None = None, + version: str = "1.0.0", + tags: list[str] | None = None, + confirm_callback: Callable[[str], Awaitable[bool]] | None = None, + default_timeout: float = 60.0, + max_output_length: int = 50000, + ): + super().__init__( + name=name, + description=description, + input_schema=input_schema or self._default_input_schema(), + output_schema=output_schema or self._default_output_schema(), + version=version, + tags=tags or ["shell", "terminal", "system"], + ) + self._session_manager = TerminalSessionManager() + self._output_parser = OutputParser() + self._confirm_callback = confirm_callback + self._default_timeout = default_timeout + self._max_output_length = max_output_length + self._audit_log: list[dict[str, Any]] = [] + + @staticmethod + def _default_input_schema() -> dict[str, Any]: + return { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "要执行的 Shell 
命令", + }, + "timeout": { + "type": "number", + "description": "超时时间(秒),默认 60", + "default": 60, + }, + "working_dir": { + "type": "string", + "description": "工作目录(仅无会话模式有效)", + }, + "session_id": { + "type": "string", + "description": "会话 ID,指定后在会话中执行命令,跨命令保持状态", + }, + "interactive": { + "type": "boolean", + "description": "是否使用交互式模式(PTY),用于需要用户输入的命令", + "default": False, + }, + }, + "required": ["command"], + } + + @staticmethod + def _default_output_schema() -> dict[str, Any]: + return { + "type": "object", + "properties": { + "output": {"type": "string", "description": "命令输出"}, + "exit_code": {"type": "integer", "description": "退出码"}, + "is_error": {"type": "boolean", "description": "是否为错误"}, + "error_type": {"type": "string", "description": "错误类型"}, + "message": {"type": "string", "description": "消息摘要"}, + "suggestions": { + "type": "array", + "items": {"type": "string"}, + "description": "可操作建议", + }, + "session_id": {"type": "string", "description": "会话 ID(仅会话模式)"}, + }, + } + + async def execute(self, **kwargs) -> dict: + """执行 Shell 命令 + + Args: + command: 要执行的命令(必需) + timeout: 超时时间(秒) + working_dir: 工作目录(仅无会话模式) + session_id: 会话 ID(启用会话模式) + interactive: 是否使用交互式模式 + + Returns: + 包含 output, exit_code, is_error 等字段的字典 + """ + command = kwargs.get("command") + if not command: + return { + "output": "", + "exit_code": 1, + "is_error": True, + "error_type": "invalid_argument", + "message": "command 参数是必需的", + "suggestions": ["提供要执行的 Shell 命令"], + } + + timeout = kwargs.get("timeout", self._default_timeout) + working_dir = kwargs.get("working_dir") + session_id = kwargs.get("session_id") + interactive = kwargs.get("interactive", False) + + # 安全检查:危险命令需要确认 + if self._is_dangerous(command): + confirmed = await self._request_confirmation(command) + if not confirmed: + self._log_audit(command, None, blocked=True) + return { + "output": "", + "exit_code": 126, + "is_error": True, + "error_type": "permission_denied", + "message": f"危险命令已被拒绝执行: {command[:100]}", + 
"suggestions": [ + "如需执行此命令,请手动确认", + "考虑使用更安全的替代命令", + ], + } + + # 根据模式执行 + if session_id: + result = await self._execute_in_session( + command, session_id, timeout, working_dir, interactive + ) + else: + result = await self._execute_standalone(command, timeout, working_dir, interactive) + + # 审计日志 + self._log_audit(command, session_id, exit_code=result.exit_code) + + # 截断过长输出 + output = result.raw_output + if len(output) > self._max_output_length: + output = output[: self._max_output_length] + "\n... [输出已截断]" + + return { + "output": output, + "exit_code": result.exit_code, + "is_error": result.is_error, + "error_type": result.error_type.value, + "message": result.message, + "suggestions": result.suggestions, + "session_id": session_id, + } + + async def _execute_standalone( + self, + command: str, + timeout: float, + working_dir: str | None, + interactive: bool, + ) -> ParsedOutput: + """无会话模式执行命令(向后兼容)""" + if interactive: + return await self._execute_with_pty(command, timeout, working_dir) + + start = time.monotonic() + try: + proc = await asyncio.create_subprocess_shell( + command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + cwd=working_dir, + ) + try: + stdout, _ = await asyncio.wait_for( + proc.communicate(), + timeout=timeout, + ) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + output = f"命令执行超时({timeout}s)" + exit_code = -1 + else: + output = stdout.decode("utf-8", errors="replace") if stdout else "" + exit_code = proc.returncode if proc.returncode is not None else 0 + except Exception as e: + output = str(e) + exit_code = -1 + + return self._output_parser.parse(output, exit_code) + + async def _execute_in_session( + self, + command: str, + session_id: str, + timeout: float, + working_dir: str | None, + interactive: bool, + ) -> ParsedOutput: + """会话模式执行命令""" + session = self._session_manager.get_or_create( + session_id, + cwd=working_dir, + ) + + if interactive: + return await self._execute_with_pty( + 
command, timeout, session.cwd, session.env + ) + + return await session.execute(command, timeout=timeout) + + async def _execute_with_pty( + self, + command: str, + timeout: float, + cwd: str | None = None, + env: dict[str, str] | None = None, + ) -> ParsedOutput: + """使用 PTY 执行交互式命令""" + pty = PTYSession() + try: + await pty.start() + result = await pty.run_command( + command, + timeout=timeout, + cwd=cwd, + env=env, + ) + output = result.output + exit_code = result.exit_code + except Exception as e: + output = str(e) + exit_code = -1 + finally: + await pty.close() + + return self._output_parser.parse(output, exit_code) + + def _is_dangerous(self, command: str) -> bool: + """检查命令是否为危险操作 + + 白名单命令直接放行,其他命令检查是否匹配危险模式。 + """ + command_stripped = command.strip() + + # 白名单检查 + for prefix in _SAFE_COMMAND_PREFIXES: + if command_stripped.startswith(prefix): + return False + + # 危险模式检查 + command_lower = command_stripped.lower() + for pattern in _DANGEROUS_PATTERNS: + if pattern in command_lower: + return True + + return False + + async def _request_confirmation(self, command: str) -> bool: + """请求人工确认危险命令 + + Args: + command: 待确认的命令 + + Returns: + 是否确认执行 + """ + if self._confirm_callback: + try: + return await self._confirm_callback(command) + except Exception as e: + logger.warning("确认回调执行失败: %s", e) + return False + + # 无回调时默认拒绝 + logger.warning("危险命令被拒绝(无确认回调): %s", command[:100]) + return False + + def _log_audit( + self, + command: str, + session_id: str | None, + exit_code: int | None = None, + blocked: bool = False, + ) -> None: + """记录审计日志""" + entry = { + "timestamp": time.time(), + "command": command[:500], + "session_id": session_id, + "exit_code": exit_code, + "blocked": blocked, + } + self._audit_log.append(entry) + logger.info( + "Shell audit: command=%r session=%s exit=%s blocked=%s", + command[:100], + session_id, + exit_code, + blocked, + ) + + @property + def session_manager(self) -> TerminalSessionManager: + """获取会话管理器""" + return self._session_manager 
+ + @property + def audit_log(self) -> list[dict[str, Any]]: + """获取审计日志(副本)""" + return list(self._audit_log) diff --git a/src/agentkit/tools/terminal_session.py b/src/agentkit/tools/terminal_session.py new file mode 100644 index 0000000..6ab72db --- /dev/null +++ b/src/agentkit/tools/terminal_session.py @@ -0,0 +1,352 @@ +"""TerminalSession - 终端会话状态管理 + +维护 cwd、env、history,支持跨命令保持状态。 +通过在命令前注入 cd 和 export 语句实现跨命令状态持久化。 +""" + +from __future__ import annotations + +import asyncio +import logging +import os +import time +from dataclasses import dataclass, field +from typing import Any + +from agentkit.tools.output_parser import OutputParser, ParsedOutput + +logger = logging.getLogger(__name__) + + +@dataclass +class CommandRecord: + """命令执行记录 + + Attributes: + command: 执行的命令 + exit_code: 退出码 + output: 标准输出+错误输出 + cwd: 执行时的工作目录 + timestamp: 执行时间戳 + duration_ms: 执行耗时(毫秒) + """ + + command: str + exit_code: int + output: str + cwd: str + timestamp: float + duration_ms: int + + +class TerminalSession: + """终端会话 - 跨命令保持 cwd/env/history 状态 + + 通过在命令前注入 `cd {cwd} && ` 和 `export K=V && ` 实现跨命令状态持久化。 + 每次命令执行后自动更新 cwd 和 env 状态。 + + Usage: + session = TerminalSession(session_id="build-01") + result = await session.execute("cd /tmp") + result = await session.execute("pwd") # 输出 /tmp + """ + + def __init__( + self, + session_id: str, + cwd: str | None = None, + env: dict[str, str] | None = None, + max_history: int = 1000, + ): + self.session_id = session_id + self._cwd = cwd or os.getcwd() + self._env: dict[str, str] = dict(env or os.environ) + self._history: list[CommandRecord] = [] + self._max_history = max_history + self._output_parser = OutputParser() + self._created_at = time.time() + + @property + def cwd(self) -> str: + """当前工作目录""" + return self._cwd + + @property + def env(self) -> dict[str, str]: + """当前环境变量(副本)""" + return dict(self._env) + + @property + def history(self) -> list[CommandRecord]: + """命令执行历史(副本)""" + return list(self._history) + + @property + def 
created_at(self) -> float: + """会话创建时间戳""" + return self._created_at + + def get_cwd(self) -> str: + """获取当前工作目录""" + return self._cwd + + def set_cwd(self, cwd: str) -> None: + """手动设置当前工作目录""" + self._cwd = cwd + + def get_env(self) -> dict[str, str]: + """获取当前环境变量(副本)""" + return dict(self._env) + + def set_env(self, key: str, value: str) -> None: + """设置单个环境变量""" + self._env[key] = value + + def update_env(self, env: dict[str, str]) -> None: + """批量更新环境变量""" + self._env.update(env) + + def get_history(self) -> list[CommandRecord]: + """获取命令执行历史(副本)""" + return list(self._history) + + async def execute( + self, + command: str, + timeout: float | None = None, + ) -> ParsedOutput: + """在会话上下文中执行命令 + + 自动在命令前注入 cd 和 export 语句以保持会话状态。 + 执行后自动更新 cwd 和 env。 + + Args: + command: 要执行的命令 + timeout: 超时时间(秒),None 表示不超时 + + Returns: + ParsedOutput 结构化解析结果 + """ + full_command = self._build_command(command) + start = time.monotonic() + + try: + proc = await asyncio.create_subprocess_shell( + full_command, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + env=self._env, + ) + try: + stdout, _ = await asyncio.wait_for( + proc.communicate(), + timeout=timeout, + ) + except asyncio.TimeoutError: + proc.kill() + await proc.wait() + output = f"命令执行超时({timeout}s)" + exit_code = -1 + else: + output = stdout.decode("utf-8", errors="replace") if stdout else "" + exit_code = proc.returncode if proc.returncode is not None else 0 + except Exception as e: + output = str(e) + exit_code = -1 + + duration_ms = int((time.monotonic() - start) * 1000) + + # 更新会话状态 + self._update_state_after_execution(command, output, exit_code) + + # 记录历史 + record = CommandRecord( + command=command, + exit_code=exit_code, + output=output, + cwd=self._cwd, + timestamp=time.time(), + duration_ms=duration_ms, + ) + self._add_history(record) + + # 解析输出 + parsed = self._output_parser.parse(output, exit_code) + logger.debug( + "Session %s: command=%r exit_code=%d duration=%dms", + 
self.session_id, + command, + exit_code, + duration_ms, + ) + return parsed + + def _build_command(self, command: str) -> str: + """构建带会话状态的完整命令 + + 在原始命令前注入 cd 和 export 语句。 + """ + parts: list[str] = [] + + # 注入 cd + if self._cwd: + # 使用 shlex.quote 风格的简单转义 + cwd_escaped = self._cwd.replace("'", "'\\''") + parts.append(f"cd '{cwd_escaped}'") + + # 注入环境变量 + for key, value in self._env.items(): + # 跳过 os.environ 中已有的且值未变的变量,减少命令长度 + val_escaped = value.replace("'", "'\\''") + parts.append(f"export {key}='{val_escaped}'") + + parts.append(command) + return " && ".join(parts) + + def _update_state_after_execution( + self, command: str, output: str, exit_code: int + ) -> None: + """命令执行后更新会话状态 + + 解析 cd 和 export 命令更新 cwd 和 env。 + """ + if exit_code != 0: + return + + # 解析 cd 命令更新 cwd + self._parse_cd_commands(command, output) + + # 解析 export 命令更新 env + self._parse_export_commands(command) + + def _parse_cd_commands(self, command: str, output: str) -> None: + """从命令中解析 cd 并更新 cwd + + 支持: + - cd /path + - cd dir (相对路径,需要通过 pwd 获取实际路径) + - cd - (切换到上一个目录) + """ + import re + + # 匹配 cd 命令(可能出现在 && 链中) + cd_pattern = re.compile(r"(?:^|\s|&&\s*)cd\s+(.+?)(?:\s*&&|\s*$)") + matches = cd_pattern.findall(command) + + for target in matches: + target = target.strip().strip("'\"") + if not target: + continue + + if target == "-": + # cd - 切换到 OLDPWD + old_pwd = self._env.get("OLDPWD") + if old_pwd: + self._cwd = old_pwd + elif os.path.isabs(target): + self._cwd = target + else: + # 相对路径:拼接后规范化 + new_cwd = os.path.normpath(os.path.join(self._cwd, target)) + self._cwd = new_cwd + + def _parse_export_commands(self, command: str) -> None: + """从命令中解析 export 并更新 env + + 支持: + - export KEY=VALUE + - export KEY="VALUE WITH SPACES" + """ + import re + + export_pattern = re.compile( + r"(?:^|\s|&&\s*)export\s+(\w+)=(.+?)(?:\s*&&|\s*$)" + ) + matches = export_pattern.findall(command) + + for key, value in matches: + value = value.strip().strip("'\"") + self._env[key] = value + + def 
_add_history(self, record: CommandRecord) -> None: + """添加命令记录到历史,超出上限时移除最旧记录""" + self._history.append(record) + while len(self._history) > self._max_history: + self._history.pop(0) + + def close(self) -> None: + """关闭会话,清理资源""" + logger.info( + "Session %s closed: %d commands executed", + self.session_id, + len(self._history), + ) + + +class TerminalSessionManager: + """终端会话管理器 - 按 ID 管理多个 TerminalSession + + Usage: + manager = TerminalSessionManager() + session = manager.get_or_create("build-01") + session = manager.get("build-01") + manager.remove("build-01") + """ + + def __init__(self, max_sessions: int = 100): + self._sessions: dict[str, TerminalSession] = {} + self._max_sessions = max_sessions + + def get_or_create( + self, + session_id: str, + cwd: str | None = None, + env: dict[str, str] | None = None, + ) -> TerminalSession: + """获取或创建会话 + + Args: + session_id: 会话 ID + cwd: 初始工作目录(仅创建时使用) + env: 初始环境变量(仅创建时使用) + + Returns: + TerminalSession 实例 + """ + if session_id not in self._sessions: + if len(self._sessions) >= self._max_sessions: + # 移除最旧的会话 + oldest_id = min( + self._sessions, key=lambda k: self._sessions[k].created_at + ) + self.remove(oldest_id) + self._sessions[session_id] = TerminalSession( + session_id=session_id, + cwd=cwd, + env=env, + ) + logger.info("Session created: %s", session_id) + return self._sessions[session_id] + + def get(self, session_id: str) -> TerminalSession | None: + """获取会话,不存在返回 None""" + return self._sessions.get(session_id) + + def remove(self, session_id: str) -> None: + """移除并关闭会话""" + session = self._sessions.pop(session_id, None) + if session: + session.close() + + def list_sessions(self) -> list[str]: + """列出所有会话 ID""" + return list(self._sessions.keys()) + + def has_session(self, session_id: str) -> bool: + """检查会话是否存在""" + return session_id in self._sessions + + def close_all(self) -> None: + """关闭所有会话""" + for session_id in list(self._sessions.keys()): + self.remove(session_id) diff --git 
a/tests/unit/evolution/test_path_optimizer.py b/tests/unit/evolution/test_path_optimizer.py new file mode 100644 index 0000000..61ff91e --- /dev/null +++ b/tests/unit/evolution/test_path_optimizer.py @@ -0,0 +1,512 @@ +"""Tests for PathOptimizer - 执行路径优化器""" + +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest + +from agentkit.evolution.path_optimizer import ExecutionPath, PathOptimizer, PathUpdateResult + + +# ── Fixtures ────────────────────────────────────────────── + + +@pytest.fixture +def optimizer(): + """默认 PathOptimizer 实例""" + return PathOptimizer(min_sample_count=3, success_rate_threshold=0.05, duration_improvement_threshold=0.2) + + +@pytest.fixture +def optimizer_custom_thresholds(): + """自定义阈值的 PathOptimizer""" + return PathOptimizer( + min_sample_count=5, + success_rate_threshold=0.1, + duration_improvement_threshold=0.3, + ) + + +def _make_path( + task_type: str = "code_review", + steps: list[str] | None = None, + total_duration: float = 10.0, + success_rate: float = 0.8, + sample_count: int = 5, + is_recommended: bool = False, + path_id: str = "", + created_at: datetime | None = None, +) -> ExecutionPath: + """创建测试用 ExecutionPath""" + return ExecutionPath( + path_id=path_id, + task_type=task_type, + steps=steps or ["step1", "step2", "step3"], + total_duration=total_duration, + success_rate=success_rate, + sample_count=sample_count, + is_recommended=is_recommended, + created_at=created_at or datetime.now(timezone.utc), + ) + + +# ── ExecutionPath 数据模型测试 ──────────────────────────── + + +class TestExecutionPath: + def test_default_values(self): + path = ExecutionPath() + assert path.path_id == "" + assert path.task_type == "" + assert path.steps == [] + assert path.total_duration == 0.0 + assert path.success_rate == 0.0 + assert path.sample_count == 0 + assert path.is_recommended is False + assert isinstance(path.created_at, datetime) + + def test_custom_values(self): + now = datetime.now(timezone.utc) + 
path = ExecutionPath( + path_id="p1", + task_type="code_review", + steps=["analyze", "review", "report"], + total_duration=15.5, + success_rate=0.9, + sample_count=10, + is_recommended=True, + created_at=now, + ) + assert path.path_id == "p1" + assert path.task_type == "code_review" + assert path.steps == ["analyze", "review", "report"] + assert path.total_duration == 15.5 + assert path.success_rate == 0.9 + assert path.sample_count == 10 + assert path.is_recommended is True + assert path.created_at == now + + +# ── PathUpdateResult 数据模型测试 ───────────────────────── + + +class TestPathUpdateResult: + def test_default_values(self): + result = PathUpdateResult() + assert result.updated is False + assert result.old_path is None + assert result.new_path is None + assert result.reason == "" + + def test_updated_result(self): + old = _make_path(success_rate=0.7) + new = _make_path(success_rate=0.9) + result = PathUpdateResult( + updated=True, + old_path=old, + new_path=new, + reason="成功率显著提升", + ) + assert result.updated is True + assert result.old_path.success_rate == 0.7 + assert result.new_path.success_rate == 0.9 + assert "成功率" in result.reason + + +# ── get_recommended_path 测试 ───────────────────────────── + + +class TestGetRecommendedPath: + async def test_no_recommended_path(self, optimizer): + result = optimizer.get_recommended_path("code_review") + assert result is None + + async def test_returns_recommended_path(self, optimizer): + path = _make_path(task_type="code_review", success_rate=0.8, sample_count=5) + await optimizer.evaluate_and_update("code_review", path) + result = optimizer.get_recommended_path("code_review") + assert result is not None + assert result.success_rate == 0.8 + assert result.is_recommended is True + + async def test_different_task_types_independent(self, optimizer): + path_a = _make_path(task_type="code_review", success_rate=0.8, sample_count=5) + path_b = _make_path(task_type="data_analysis", success_rate=0.9, sample_count=5) + await 
optimizer.evaluate_and_update("code_review", path_a) + await optimizer.evaluate_and_update("data_analysis", path_b) + + result_a = optimizer.get_recommended_path("code_review") + result_b = optimizer.get_recommended_path("data_analysis") + assert result_a is not None + assert result_b is not None + assert result_a.success_rate == 0.8 + assert result_b.success_rate == 0.9 + + +# ── 样本量不足测试 ──────────────────────────────────────── + + +class TestInsufficientSamples: + async def test_insufficient_samples_no_update(self, optimizer): + """样本量不足 → 不更新,记录待观察""" + path = _make_path(sample_count=2, success_rate=0.9) + result = await optimizer.evaluate_and_update("code_review", path) + assert result.updated is False + assert "样本量不足" in result.reason + assert optimizer.get_recommended_path("code_review") is None + + async def test_insufficient_samples_recorded_as_pending(self, optimizer): + """样本量不足的路径被记录到待观察列表""" + path = _make_path(sample_count=2, success_rate=0.9) + await optimizer.evaluate_and_update("code_review", path) + pending = optimizer.get_pending_paths("code_review") + assert len(pending) == 1 + assert pending[0].success_rate == 0.9 + + async def test_exact_min_samples_updates(self, optimizer): + """刚好达到最小样本量 → 可以更新""" + path = _make_path(sample_count=3, success_rate=0.8) + result = await optimizer.evaluate_and_update("code_review", path) + assert result.updated is True + assert result.reason == "无现有推荐路径,直接设为推荐" + + async def test_custom_min_sample_count(self, optimizer_custom_thresholds): + """自定义最小样本量""" + path = _make_path(sample_count=4, success_rate=0.9) + result = await optimizer_custom_thresholds.evaluate_and_update("code_review", path) + assert result.updated is False + assert "样本量不足" in result.reason + + +# ── 首次设置推荐路径测试 ────────────────────────────────── + + +class TestFirstRecommendation: + async def test_first_path_becomes_recommended(self, optimizer): + """无现有推荐路径时,新路径直接设为推荐""" + path = _make_path(success_rate=0.7, sample_count=5) + result = await 
optimizer.evaluate_and_update("code_review", path) + assert result.updated is True + assert result.old_path is None + assert result.new_path is not None + assert result.new_path.is_recommended is True + assert "无现有推荐路径" in result.reason + + async def test_auto_generates_path_id(self, optimizer): + """未提供 path_id 时自动生成""" + path = _make_path(path_id="", sample_count=5) + result = await optimizer.evaluate_and_update("code_review", path) + assert result.updated is True + assert result.new_path is not None + assert len(result.new_path.path_id) > 0 + + +# ── 成功率显著提升测试 ──────────────────────────────────── + + +class TestSuccessRateImprovement: + async def test_higher_success_rate_updates(self, optimizer): + """新路径成功率更高 → 更新推荐路径""" + old_path = _make_path(success_rate=0.7, sample_count=5) + await optimizer.evaluate_and_update("code_review", old_path) + + new_path = _make_path(success_rate=0.85, sample_count=5) + result = await optimizer.evaluate_and_update("code_review", new_path) + assert result.updated is True + assert result.old_path.success_rate == 0.7 + assert result.new_path.success_rate == 0.85 + assert "成功率显著提升" in result.reason + + async def test_marginal_success_rate_no_update(self, optimizer): + """成功率提升不足阈值 → 不更新""" + old_path = _make_path(success_rate=0.8, sample_count=5) + await optimizer.evaluate_and_update("code_review", old_path) + + # 提升仅 0.03,低于默认阈值 0.05 + new_path = _make_path(success_rate=0.83, sample_count=5) + result = await optimizer.evaluate_and_update("code_review", new_path) + assert result.updated is False + assert "无明显优势" in result.reason + + async def test_custom_success_rate_threshold(self, optimizer_custom_thresholds): + """自定义成功率阈值""" + old_path = _make_path(success_rate=0.7, sample_count=10) + await optimizer_custom_thresholds.evaluate_and_update("code_review", old_path) + + # 提升 0.08,低于自定义阈值 0.1 + new_path = _make_path(success_rate=0.78, sample_count=10) + result = await optimizer_custom_thresholds.evaluate_and_update("code_review", 
new_path) + assert result.updated is False + + async def test_lower_success_rate_no_update(self, optimizer): + """新路径成功率更低 → 不更新""" + old_path = _make_path(success_rate=0.9, sample_count=5) + await optimizer.evaluate_and_update("code_review", old_path) + + new_path = _make_path(success_rate=0.6, sample_count=5) + result = await optimizer.evaluate_and_update("code_review", new_path) + assert result.updated is False + + +# ── 耗时显著更短测试 ────────────────────────────────────── + + +class TestDurationImprovement: + async def test_shorter_duration_with_similar_success_rate_updates(self, optimizer): + """成功率相近但耗时显著更短 → 更新推荐路径""" + old_path = _make_path(total_duration=100.0, success_rate=0.8, sample_count=5) + await optimizer.evaluate_and_update("code_review", old_path) + + # 耗时减少 30%(> 20% 阈值),成功率相近 + new_path = _make_path(total_duration=70.0, success_rate=0.82, sample_count=5) + result = await optimizer.evaluate_and_update("code_review", new_path) + assert result.updated is True + assert "耗时显著更短" in result.reason + + async def test_marginal_duration_improvement_no_update(self, optimizer): + """耗时改善不足阈值 → 不更新""" + old_path = _make_path(total_duration=100.0, success_rate=0.8, sample_count=5) + await optimizer.evaluate_and_update("code_review", old_path) + + # 耗时减少仅 10%(< 20% 阈值) + new_path = _make_path(total_duration=90.0, success_rate=0.82, sample_count=5) + result = await optimizer.evaluate_and_update("code_review", new_path) + assert result.updated is False + assert "无明显优势" in result.reason + + async def test_longer_duration_no_update(self, optimizer): + """耗时更长 → 不更新""" + old_path = _make_path(total_duration=50.0, success_rate=0.8, sample_count=5) + await optimizer.evaluate_and_update("code_review", old_path) + + new_path = _make_path(total_duration=80.0, success_rate=0.82, sample_count=5) + result = await optimizer.evaluate_and_update("code_review", new_path) + assert result.updated is False + + async def test_custom_duration_improvement_threshold(self, 
optimizer_custom_thresholds): + """自定义耗时改善阈值""" + old_path = _make_path(total_duration=100.0, success_rate=0.8, sample_count=10) + await optimizer_custom_thresholds.evaluate_and_update("code_review", old_path) + + # 耗时减少 25%(< 30% 自定义阈值) + new_path = _make_path(total_duration=75.0, success_rate=0.82, sample_count=10) + result = await optimizer_custom_thresholds.evaluate_and_update("code_review", new_path) + assert result.updated is False + + async def test_zero_duration_current_path(self, optimizer): + """现有路径耗时为 0 → 不因耗时更新""" + old_path = _make_path(total_duration=0.0, success_rate=0.8, sample_count=5) + await optimizer.evaluate_and_update("code_review", old_path) + + new_path = _make_path(total_duration=10.0, success_rate=0.82, sample_count=5) + result = await optimizer.evaluate_and_update("code_review", new_path) + assert result.updated is False + + async def test_both_zero_duration(self, optimizer): + """两者耗时均为 0 → 不因耗时更新""" + old_path = _make_path(total_duration=0.0, success_rate=0.8, sample_count=5) + await optimizer.evaluate_and_update("code_review", old_path) + + new_path = _make_path(total_duration=0.0, success_rate=0.82, sample_count=5) + result = await optimizer.evaluate_and_update("code_review", new_path) + assert result.updated is False + + +# ── 保留现有推荐路径测试 ────────────────────────────────── + + +class TestKeepCurrentPath: + async def test_no_advantage_keeps_current(self, optimizer): + """新路径无明显优势 → 保留现有推荐路径""" + old_path = _make_path(total_duration=50.0, success_rate=0.8, sample_count=5) + await optimizer.evaluate_and_update("code_review", old_path) + + new_path = _make_path(total_duration=48.0, success_rate=0.79, sample_count=5) + result = await optimizer.evaluate_and_update("code_review", new_path) + assert result.updated is False + assert result.old_path.success_rate == 0.8 + # 推荐路径不变 + recommended = optimizer.get_recommended_path("code_review") + assert recommended is not None + assert recommended.success_rate == 0.8 + + async def 
test_is_recommended_flag_preserved(self, optimizer): + """未更新时,现有路径的 is_recommended 标志保持为 True""" + old_path = _make_path(success_rate=0.8, sample_count=5) + await optimizer.evaluate_and_update("code_review", old_path) + + new_path = _make_path(success_rate=0.79, sample_count=5) + await optimizer.evaluate_and_update("code_review", new_path) + + recommended = optimizer.get_recommended_path("code_review") + assert recommended is not None + assert recommended.is_recommended is True + + +# ── is_recommended 标志管理测试 ──────────────────────────── + + +class TestIsRecommendedFlag: + async def test_old_path_loses_recommended_flag(self, optimizer): + """更新后旧路径的 is_recommended 变为 False""" + old_path = _make_path(success_rate=0.7, sample_count=5) + await optimizer.evaluate_and_update("code_review", old_path) + assert old_path.is_recommended is True # 首次设置,is_recommended 为 True + + new_path = _make_path(success_rate=0.9, sample_count=5) + result = await optimizer.evaluate_and_update("code_review", new_path) + assert result.updated is True + assert result.old_path.is_recommended is False # 更新后旧路径失去标志 + assert result.new_path.is_recommended is True + + +# ── 多次迭代优化测试 ────────────────────────────────────── + + +class TestIterativeOptimization: + async def test_multiple_updates_converge_to_best(self, optimizer): + """多次迭代后推荐路径收敛到最优""" + # 第一次:初始路径 + path1 = _make_path(success_rate=0.6, total_duration=100.0, sample_count=5) + await optimizer.evaluate_and_update("code_review", path1) + assert optimizer.get_recommended_path("code_review").success_rate == 0.6 + + # 第二次:成功率显著提升 + path2 = _make_path(success_rate=0.8, total_duration=90.0, sample_count=5) + await optimizer.evaluate_and_update("code_review", path2) + assert optimizer.get_recommended_path("code_review").success_rate == 0.8 + + # 第三次:成功率相近但耗时更短 + path3 = _make_path(success_rate=0.82, total_duration=50.0, sample_count=5) + await optimizer.evaluate_and_update("code_review", path3) + assert 
optimizer.get_recommended_path("code_review").total_duration == 50.0 + + # 第四次:无明显优势 + path4 = _make_path(success_rate=0.81, total_duration=48.0, sample_count=5) + result = await optimizer.evaluate_and_update("code_review", path4) + assert result.updated is False + assert optimizer.get_recommended_path("code_review").total_duration == 50.0 + + async def test_different_task_types_evolve_independently(self, optimizer): + """不同任务类型的推荐路径独立进化""" + path_a1 = _make_path(task_type="code_review", success_rate=0.7, sample_count=5) + path_b1 = _make_path(task_type="data_analysis", success_rate=0.6, sample_count=5) + await optimizer.evaluate_and_update("code_review", path_a1) + await optimizer.evaluate_and_update("data_analysis", path_b1) + + path_a2 = _make_path(task_type="code_review", success_rate=0.9, sample_count=5) + await optimizer.evaluate_and_update("code_review", path_a2) + + # code_review 更新了,data_analysis 不受影响 + assert optimizer.get_recommended_path("code_review").success_rate == 0.9 + assert optimizer.get_recommended_path("data_analysis").success_rate == 0.6 + + +# ── 待观察路径管理测试 ──────────────────────────────────── + + +class TestPendingPaths: + async def test_pending_paths_empty_initially(self, optimizer): + assert optimizer.get_pending_paths("code_review") == [] + + async def test_pending_paths_accumulate(self, optimizer): + """多次样本不足的路径会累积""" + path1 = _make_path(sample_count=1, success_rate=0.9) + path2 = _make_path(sample_count=2, success_rate=0.85) + await optimizer.evaluate_and_update("code_review", path1) + await optimizer.evaluate_and_update("code_review", path2) + + pending = optimizer.get_pending_paths("code_review") + assert len(pending) == 2 + + async def test_pending_paths_isolated_by_task_type(self, optimizer): + """不同任务类型的待观察路径相互隔离""" + path_a = _make_path(task_type="code_review", sample_count=1, success_rate=0.9) + path_b = _make_path(task_type="data_analysis", sample_count=1, success_rate=0.8) + await optimizer.evaluate_and_update("code_review", 
path_a) + await optimizer.evaluate_and_update("data_analysis", path_b) + + assert len(optimizer.get_pending_paths("code_review")) == 1 + assert len(optimizer.get_pending_paths("data_analysis")) == 1 + + async def test_sufficient_samples_not_pending(self, optimizer): + """样本量充足的路径不会进入待观察列表""" + path = _make_path(sample_count=5, success_rate=0.8) + await optimizer.evaluate_and_update("code_review", path) + assert optimizer.get_pending_paths("code_review") == [] + + +# ── ExperienceStore 集成测试 ────────────────────────────── + + +class TestExperienceStoreIntegration: + async def test_with_experience_store(self): + """PathOptimizer 可以接受 ExperienceStore 实例""" + from agentkit.evolution.experience_store import InMemoryExperienceStore + + store = InMemoryExperienceStore() + optimizer = PathOptimizer(experience_store=store, min_sample_count=3) + + path = _make_path(success_rate=0.8, sample_count=5) + result = await optimizer.evaluate_and_update("code_review", path) + assert result.updated is True + + async def test_without_experience_store(self, optimizer): + """PathOptimizer 可以不依赖 ExperienceStore 独立运行""" + path = _make_path(success_rate=0.8, sample_count=5) + result = await optimizer.evaluate_and_update("code_review", path) + assert result.updated is True + + +# ── 边界条件测试 ────────────────────────────────────────── + + +class TestEdgeCases: + async def test_same_path_twice(self, optimizer): + """提交相同路径两次""" + path = _make_path(success_rate=0.8, sample_count=5) + result1 = await optimizer.evaluate_and_update("code_review", path) + assert result1.updated is True + + # 第二次提交相同参数的路径(但不同实例) + path2 = _make_path(success_rate=0.8, sample_count=5) + result2 = await optimizer.evaluate_and_update("code_review", path2) + # 成功率相同,耗时相同 → 无明显优势 + assert result2.updated is False + + async def test_success_rate_at_boundary(self, optimizer): + """成功率刚好在阈值边界""" + old_path = _make_path(success_rate=0.8, sample_count=5) + await optimizer.evaluate_and_update("code_review", old_path) + + # 
提升恰好等于阈值 0.05,不满足 > threshold + new_path = _make_path(success_rate=0.85, sample_count=5) + result = await optimizer.evaluate_and_update("code_review", new_path) + assert result.updated is False + + async def test_duration_improvement_at_boundary(self, optimizer): + """耗时改善刚好在阈值边界""" + old_path = _make_path(total_duration=100.0, success_rate=0.8, sample_count=5) + await optimizer.evaluate_and_update("code_review", old_path) + + # 改善恰好等于阈值 20%,不满足 > threshold + new_path = _make_path(total_duration=80.0, success_rate=0.82, sample_count=5) + result = await optimizer.evaluate_and_update("code_review", new_path) + assert result.updated is False + + async def test_zero_sample_count(self, optimizer): + """样本量为 0""" + path = _make_path(sample_count=0, success_rate=0.9) + result = await optimizer.evaluate_and_update("code_review", path) + assert result.updated is False + assert "样本量不足" in result.reason + + async def test_path_task_type_override(self, optimizer): + """evaluate_and_update 会用传入的 task_type 覆盖路径的 task_type""" + path = _make_path(task_type="wrong_type", success_rate=0.8, sample_count=5) + result = await optimizer.evaluate_and_update("code_review", path) + assert result.updated is True + assert path.task_type == "code_review" + recommended = optimizer.get_recommended_path("code_review") + assert recommended is not None diff --git a/tests/unit/evolution/test_pitfall_detector.py b/tests/unit/evolution/test_pitfall_detector.py new file mode 100644 index 0000000..31c9e09 --- /dev/null +++ b/tests/unit/evolution/test_pitfall_detector.py @@ -0,0 +1,595 @@ +"""Tests for PitfallDetector - 任务避坑预警检测""" + +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest + +from agentkit.core.plan_schema import PlanStep, PlanStepStatus +from agentkit.evolution.experience_schema import TaskExperience +from agentkit.evolution.experience_store import InMemoryExperienceStore +from agentkit.evolution.pitfall_detector import ( + PitfallDetector, + 
PitfallWarning, + WarningLevel, + _compute_name_similarity, + _determine_warning_level, + _extract_keywords, +) + + +# ── Fixtures ────────────────────────────────────────────── + + +@pytest.fixture +def store(): + """无 embedder 的 InMemoryExperienceStore""" + return InMemoryExperienceStore(decay_rate=0.01, alpha=0.7) + + +@pytest.fixture +def detector(store): + """基于 InMemoryExperienceStore 的 PitfallDetector""" + return PitfallDetector(experience_store=store, similarity_threshold=0.3) + + +def _make_experience( + task_type: str = "code_review", + goal: str = "Review the PR", + outcome: str = "success", + steps_summary: str | list[dict] = "", + failure_reasons: list[str] | None = None, + optimization_tips: list[str] | None = None, + success_rate: float = 1.0, +) -> TaskExperience: + """创建测试用 TaskExperience""" + return TaskExperience( + experience_id="", + task_type=task_type, + goal=goal, + steps_summary=steps_summary, + outcome=outcome, + duration_seconds=10.0, + success_rate=success_rate, + failure_reasons=failure_reasons or [], + optimization_tips=optimization_tips or [], + created_at=datetime.now(timezone.utc), + ) + + +def _make_step( + name: str = "step", + description: str = "do something", + step_id: str = "s1", +) -> PlanStep: + """创建测试用 PlanStep""" + return PlanStep( + step_id=step_id, + name=name, + description=description, + status=PlanStepStatus.PENDING, + ) + + +# ── 辅助函数测试 ────────────────────────────────────────── + + +class TestExtractKeywords: + def test_basic_extraction(self): + keywords = _extract_keywords("Call API Gateway") + assert "call" in keywords + assert "api" in keywords + assert "gateway" in keywords + + def test_stop_words_filtered(self): + keywords = _extract_keywords("Call the API and check the result") + assert "the" not in keywords + assert "and" not in keywords + assert "call" in keywords + assert "api" in keywords + + def test_underscore_and_hyphen(self): + keywords = _extract_keywords("call_api-gateway") + assert "call" in 
keywords + assert "api" in keywords + assert "gateway" in keywords + + def test_single_char_filtered(self): + keywords = _extract_keywords("a b cd") + assert "a" not in keywords + assert "b" not in keywords + assert "cd" in keywords + + def test_empty_string(self): + keywords = _extract_keywords("") + assert len(keywords) == 0 + + +class TestComputeNameSimilarity: + def test_identical_names(self): + sim = _compute_name_similarity("Call API Gateway", "", "Call API Gateway") + assert sim == pytest.approx(1.0) + + def test_partial_overlap(self): + sim = _compute_name_similarity("Call API Gateway", "", "Call External API") + # 共享: call, api; 并集: call, api, gateway, external + assert 0.0 < sim < 1.0 + + def test_no_overlap(self): + sim = _compute_name_similarity("Deploy Service", "", "Analyze Data") + assert sim == 0.0 + + def test_description_contributes(self): + sim_no_desc = _compute_name_similarity("Deploy", "", "Deploy Service") + sim_with_desc = _compute_name_similarity("Deploy", "Deploy Service", "Deploy Service") + # description 中包含匹配关键词,应提高相似度 + assert sim_with_desc >= sim_no_desc + + def test_empty_inputs(self): + sim = _compute_name_similarity("", "", "Call API") + assert sim == 0.0 + + +class TestDetermineWarningLevel: + def test_high_threshold(self): + assert _determine_warning_level(0.6) == WarningLevel.HIGH + assert _determine_warning_level(0.5) == WarningLevel.HIGH + + def test_medium_threshold(self): + assert _determine_warning_level(0.3) == WarningLevel.MEDIUM + assert _determine_warning_level(0.2) == WarningLevel.MEDIUM + + def test_low_threshold(self): + assert _determine_warning_level(0.1) == WarningLevel.LOW + assert _determine_warning_level(0.01) == WarningLevel.LOW + + +# ── PitfallDetector.check_pitfalls 测试 ────────────────── + + +class TestCheckPitfalls: + async def test_no_planned_steps_returns_empty(self, detector): + warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=[]) + assert warnings == [] + + async def 
test_no_failed_experiences_returns_empty(self, detector, store): + """无历史失败记录 → 返回空列表""" + # 只记录成功经验 + await store.record_experience( + _make_experience(task_type="code_review", outcome="success") + ) + steps = [_make_step(name="Review Code")] + warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps) + assert warnings == [] + + async def test_high_failure_rate_returns_high_warning(self, detector, store): + """计划包含历史高失败率步骤 → 返回 HIGH 级别预警""" + # 记录多次失败经验,其中 "Call API Gateway" 步骤失败率高 + for _ in range(6): + await store.record_experience( + _make_experience( + task_type="deployment", + outcome="failure", + success_rate=0.0, + steps_summary=[ + {"step_name": "Call API Gateway", "outcome": "failure", "error": "Timeout"}, + {"step_name": "Deploy Container", "outcome": "success"}, + ], + failure_reasons=["API Gateway timeout"], + ) + ) + # 记录少数成功经验 + for _ in range(4): + await store.record_experience( + _make_experience( + task_type="deployment", + outcome="success", + success_rate=1.0, + steps_summary=[ + {"step_name": "Call API Gateway", "outcome": "success"}, + {"step_name": "Deploy Container", "outcome": "success"}, + ], + ) + ) + + steps = [_make_step(name="Call API Gateway", description="Invoke API Gateway endpoint")] + warnings = await detector.check_pitfalls(task_type="deployment", planned_steps=steps) + + assert len(warnings) == 1 + warning = warnings[0] + assert warning.step_name == "Call API Gateway" + assert warning.warning_level == WarningLevel.HIGH + assert warning.failure_rate >= 0.5 + assert "Timeout" in warning.historical_failures + + async def test_medium_failure_rate(self, detector, store): + """中等失败率 → MEDIUM 级别预警""" + # 3 次失败,7 次成功 → 失败率 0.3 + for _ in range(3): + await store.record_experience( + _make_experience( + task_type="data_analysis", + outcome="failure", + success_rate=0.0, + steps_summary=[ + {"step_name": "Fetch Data", "outcome": "failure", "error": "Connection refused"}, + ], + ) + ) + for _ in range(7): + 
await store.record_experience( + _make_experience( + task_type="data_analysis", + outcome="success", + success_rate=1.0, + steps_summary=[ + {"step_name": "Fetch Data", "outcome": "success"}, + ], + ) + ) + + steps = [_make_step(name="Fetch Data", description="Fetch data from source")] + warnings = await detector.check_pitfalls(task_type="data_analysis", planned_steps=steps) + + assert len(warnings) == 1 + assert warnings[0].warning_level == WarningLevel.MEDIUM + assert 0.2 <= warnings[0].failure_rate < 0.5 + + async def test_low_failure_rate(self, detector, store): + """低失败率 → LOW 级别预警""" + # 1 次失败,9 次成功 → 失败率 0.1 + await store.record_experience( + _make_experience( + task_type="testing", + outcome="failure", + success_rate=0.0, + steps_summary=[ + {"step_name": "Run Unit Tests", "outcome": "failure", "error": "Flaky test"}, + ], + ) + ) + for _ in range(9): + await store.record_experience( + _make_experience( + task_type="testing", + outcome="success", + success_rate=1.0, + steps_summary=[ + {"step_name": "Run Unit Tests", "outcome": "success"}, + ], + ) + ) + + steps = [_make_step(name="Run Unit Tests", description="Execute unit test suite")] + warnings = await detector.check_pitfalls(task_type="testing", planned_steps=steps) + + assert len(warnings) == 1 + assert warnings[0].warning_level == WarningLevel.LOW + + async def test_multiple_steps_with_risks_sorted_by_severity(self, detector, store): + """多个步骤有风险 → 按严重程度排序返回""" + # "Call API" 高失败率,"Validate Input" 低失败率 + for _ in range(6): + await store.record_experience( + _make_experience( + task_type="integration", + outcome="failure", + success_rate=0.0, + steps_summary=[ + {"step_name": "Call API", "outcome": "failure", "error": "Timeout"}, + {"step_name": "Validate Input", "outcome": "success"}, + ], + ) + ) + for _ in range(4): + await store.record_experience( + _make_experience( + task_type="integration", + outcome="success", + success_rate=1.0, + steps_summary=[ + {"step_name": "Call API", "outcome": 
"success"}, + {"step_name": "Validate Input", "outcome": "success"}, + ], + ) + ) + # 单独给 Validate Input 加一条失败记录 + await store.record_experience( + _make_experience( + task_type="integration", + outcome="partial", + success_rate=0.5, + steps_summary=[ + {"step_name": "Call API", "outcome": "success"}, + {"step_name": "Validate Input", "outcome": "failure", "error": "Invalid schema"}, + ], + ) + ) + + steps = [ + _make_step(name="Validate Input", description="Validate input data", step_id="s1"), + _make_step(name="Call API", description="Call external API", step_id="s2"), + ] + warnings = await detector.check_pitfalls(task_type="integration", planned_steps=steps) + + assert len(warnings) == 2 + # HIGH 应排在 MEDIUM/LOW 之前 + assert warnings[0].warning_level == WarningLevel.HIGH + assert warnings[0].step_name == "Call API" + + async def test_no_matching_steps_returns_empty(self, detector, store): + """计划步骤与历史失败步骤无匹配 → 返回空列表""" + await store.record_experience( + _make_experience( + task_type="code_review", + outcome="failure", + success_rate=0.0, + steps_summary=[ + {"step_name": "Run Linter", "outcome": "failure", "error": "Config error"}, + ], + ) + ) + + # 计划步骤名称与历史步骤完全不同 + steps = [_make_step(name="Deploy Application", description="Deploy to production")] + warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps) + assert warnings == [] + + async def test_different_task_type_no_cross_contamination(self, detector, store): + """不同 task_type 的失败经验不会跨类型预警""" + await store.record_experience( + _make_experience( + task_type="deployment", + outcome="failure", + success_rate=0.0, + steps_summary=[ + {"step_name": "Deploy Service", "outcome": "failure", "error": "OOM"}, + ], + ) + ) + + # 查询 code_review 类型,不应返回 deployment 的失败经验 + steps = [_make_step(name="Deploy Service", description="Deploy the service")] + warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps) + assert warnings == [] + + async def 
test_partial_outcome_included(self, detector, store): + """partial 结果的经验也应被检索""" + await store.record_experience( + _make_experience( + task_type="migration", + outcome="partial", + success_rate=0.5, + steps_summary=[ + {"step_name": "Migrate Database", "outcome": "failure", "error": "Schema mismatch"}, + ], + ) + ) + + steps = [_make_step(name="Migrate Database", description="Migrate DB schema")] + warnings = await detector.check_pitfalls(task_type="migration", planned_steps=steps) + assert len(warnings) == 1 + + async def test_steps_summary_as_string_ignored(self, detector, store): + """steps_summary 为字符串时无法提取步骤级信息,不产生预警""" + await store.record_experience( + _make_experience( + task_type="code_review", + outcome="failure", + success_rate=0.0, + steps_summary="Executed code_review task", # 字符串格式 + ) + ) + + steps = [_make_step(name="Review Code", description="Review the code")] + warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps) + assert warnings == [] + + +# ── AE3 场景测试 ────────────────────────────────────────── + + +class TestAE3Scenario: + """AE3: "调用 X 系统 API 在高峰期超时率 60%" → 新任务调用时自动预警""" + + async def test_api_timeout_high_failure_rate_warning(self, detector, store): + """调用 X 系统 API 在高峰期超时率 60% → 新任务调用时自动预警""" + # 模拟历史:10 次调用,6 次超时 → 60% 失败率 + for _ in range(6): + await store.record_experience( + _make_experience( + task_type="order_processing", + goal="Process orders via X system", + outcome="failure", + success_rate=0.0, + steps_summary=[ + {"step_name": "Call X System API", "outcome": "failure", "error": "高峰期超时"}, + {"step_name": "Process Order", "outcome": "success"}, + ], + failure_reasons=["X System API timeout during peak hours"], + optimization_tips=["Avoid peak hours", "Add retry logic"], + ) + ) + for _ in range(4): + await store.record_experience( + _make_experience( + task_type="order_processing", + goal="Process orders via X system", + outcome="success", + success_rate=1.0, + steps_summary=[ + {"step_name": 
"Call X System API", "outcome": "success"}, + {"step_name": "Process Order", "outcome": "success"}, + ], + ) + ) + + # 新任务计划包含调用 X 系统 API + steps = [ + _make_step(name="Call X System API", description="Invoke X system API for orders"), + ] + warnings = await detector.check_pitfalls(task_type="order_processing", planned_steps=steps) + + assert len(warnings) == 1 + warning = warnings[0] + assert warning.warning_level == WarningLevel.HIGH + assert warning.failure_rate >= 0.5 + assert any("超时" in reason for reason in warning.historical_failures) + assert warning.suggestion # 应有建议 + + +# ── PitfallWarning 数据模型测试 ─────────────────────────── + + +class TestPitfallWarning: + def test_creation(self): + warning = PitfallWarning( + step_name="Call API", + warning_level=WarningLevel.HIGH, + failure_rate=0.6, + historical_failures=["Timeout", "Connection refused"], + suggestion="Add retry logic", + ) + assert warning.step_name == "Call API" + assert warning.warning_level == WarningLevel.HIGH + assert warning.failure_rate == 0.6 + assert warning.historical_failures == ["Timeout", "Connection refused"] + assert warning.suggestion == "Add retry logic" + + def test_default_values(self): + warning = PitfallWarning( + step_name="Test", + warning_level=WarningLevel.LOW, + failure_rate=0.1, + ) + assert warning.historical_failures == [] + assert warning.suggestion == "" + + +# ── WarningLevel 枚举测试 ───────────────────────────────── + + +class TestWarningLevel: + def test_values(self): + assert WarningLevel.HIGH.value == "high" + assert WarningLevel.MEDIUM.value == "medium" + assert WarningLevel.LOW.value == "low" + + def test_string_comparison(self): + assert WarningLevel.HIGH == "high" + assert WarningLevel.MEDIUM == "medium" + assert WarningLevel.LOW == "low" + + +# ── 相似度阈值配置测试 ───────────────────────────────────── + + +class TestSimilarityThreshold: + async def test_custom_threshold(self, store): + """自定义相似度阈值""" + # 低阈值:更容易匹配 + detector_low = PitfallDetector(experience_store=store, 
similarity_threshold=0.1) + # 高阈值:更难匹配 + detector_high = PitfallDetector(experience_store=store, similarity_threshold=0.8) + + await store.record_experience( + _make_experience( + task_type="testing", + outcome="failure", + success_rate=0.0, + steps_summary=[ + {"step_name": "Run Integration Tests", "outcome": "failure", "error": "Timeout"}, + ], + ) + ) + + steps = [_make_step(name="Run Unit Tests", description="Execute tests")] + # 低阈值可能匹配,高阈值可能不匹配 + warnings_low = await detector_low.check_pitfalls(task_type="testing", planned_steps=steps) + warnings_high = await detector_high.check_pitfalls(task_type="testing", planned_steps=steps) + # 低阈值匹配数 >= 高阈值匹配数 + assert len(warnings_low) >= len(warnings_high) + + +# ── 端到端流程测试 ───────────────────────────────────────── + + +class TestEndToEnd: + async def test_full_pitfall_detection_flow(self, detector, store): + """完整的避坑检测流程""" + # 1. 记录多种失败经验 + await store.record_experience( + _make_experience( + task_type="deployment", + outcome="failure", + success_rate=0.0, + steps_summary=[ + {"step_name": "Build Docker Image", "outcome": "failure", "error": "OOM"}, + {"step_name": "Push to Registry", "outcome": "success"}, + ], + failure_reasons=["Docker build OOM"], + optimization_tips=["Increase memory limit"], + ) + ) + await store.record_experience( + _make_experience( + task_type="deployment", + outcome="failure", + success_rate=0.0, + steps_summary=[ + {"step_name": "Build Docker Image", "outcome": "failure", "error": "Dependency conflict"}, + {"step_name": "Push to Registry", "outcome": "success"}, + ], + failure_reasons=["Dependency conflict"], + ) + ) + await store.record_experience( + _make_experience( + task_type="deployment", + outcome="success", + success_rate=1.0, + steps_summary=[ + {"step_name": "Build Docker Image", "outcome": "success"}, + {"step_name": "Push to Registry", "outcome": "success"}, + ], + ) + ) + + # 2. 
新任务计划 + steps = [ + _make_step(name="Build Docker Image", description="Build the container image", step_id="s1"), + _make_step(name="Push to Registry", description="Push image to container registry", step_id="s2"), + ] + + # 3. 检测避坑 + warnings = await detector.check_pitfalls(task_type="deployment", planned_steps=steps) + + # 4. 验证结果 + assert len(warnings) >= 1 + # Build Docker Image 失败率 2/3 ≈ 0.667,应为 HIGH + build_warning = next((w for w in warnings if w.step_name == "Build Docker Image"), None) + assert build_warning is not None + assert build_warning.warning_level == WarningLevel.HIGH + assert build_warning.failure_rate == pytest.approx(2.0 / 3.0, abs=0.01) + + async def test_suggestion_contains_useful_info(self, detector, store): + """预警建议应包含有用的失败原因和优化建议""" + await store.record_experience( + _make_experience( + task_type="api_integration", + outcome="failure", + success_rate=0.0, + steps_summary=[ + {"step_name": "Authenticate", "outcome": "failure", "error": "Token expired"}, + ], + failure_reasons=["Token expired"], + optimization_tips=["Refresh token before expiry"], + ) + ) + + steps = [_make_step(name="Authenticate", description="Authenticate with API")] + warnings = await detector.check_pitfalls(task_type="api_integration", planned_steps=steps) + + assert len(warnings) == 1 + assert "Token expired" in warnings[0].suggestion diff --git a/tests/unit/tools/test_pty_session.py b/tests/unit/tools/test_pty_session.py new file mode 100644 index 0000000..ce76382 --- /dev/null +++ b/tests/unit/tools/test_pty_session.py @@ -0,0 +1,217 @@ +"""PTYSession 单元测试 + +测试场景: +- PTY 会话启动和关闭 +- 交互式命令执行 +- 自动应答提示 +- 超时处理 +- 自定义应答规则 +""" + +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentkit.tools.pty_session import PTYSession, PTYOutput +from agentkit.tools.shell import ShellTool + + +class TestPTYSessionConstruction: + """测试 PTYSession 构造""" + + def test_default_construction(self): + pty 
= PTYSession() + assert pty.is_running is False + assert pty._auto_respond is True + assert pty._default_timeout == 30.0 + + def test_custom_construction(self): + pty = PTYSession( + auto_respond=False, + custom_rules=[(r"confirm\?", "yes")], + default_timeout=60.0, + ) + assert pty._auto_respond is False + assert len(pty._respond_rules) > len(pty._respond_rules) - 1 + assert pty._default_timeout == 60.0 + + +class TestPTYSessionLifecycle: + """测试 PTYSession 生命周期""" + + @pytest.mark.asyncio + async def test_start_and_close(self): + """启动和关闭 PTY 会话""" + pty = PTYSession() + await pty.start() + assert pty.is_running is True + await pty.close() + assert pty.is_running is False + + @pytest.mark.asyncio + async def test_start_idempotent(self): + """重复启动不报错""" + pty = PTYSession() + await pty.start() + await pty.start() # 不应抛出异常 + assert pty.is_running is True + await pty.close() + + @pytest.mark.asyncio + async def test_close_without_start(self): + """未启动时关闭不报错""" + pty = PTYSession() + await pty.close() # 不应抛出异常 + + +class TestPTYSessionExecution: + """测试 PTYSession 命令执行""" + + @pytest.mark.asyncio + async def test_run_simple_command(self): + """执行简单命令""" + pty = PTYSession(default_timeout=10.0) + try: + await pty.start() + result = await pty.run_command("echo hello_pty") + assert "hello_pty" in result.output + assert result.exit_code == 0 + assert result.timed_out is False + finally: + await pty.close() + + @pytest.mark.asyncio + async def test_run_command_with_cwd(self): + """指定工作目录执行命令""" + pty = PTYSession(default_timeout=10.0) + try: + await pty.start() + result = await pty.run_command("pwd", cwd="/tmp") + assert "/tmp" in result.output + finally: + await pty.close() + + @pytest.mark.asyncio + async def test_run_command_with_env(self): + """指定环境变量执行命令""" + pty = PTYSession(default_timeout=10.0) + try: + await pty.start() + result = await pty.run_command( + "echo $PTY_TEST_VAR", + env={"PTY_TEST_VAR": "pty_value"}, + ) + assert "pty_value" in result.output + 
finally: + await pty.close() + + @pytest.mark.asyncio + async def test_run_failing_command(self): + """执行失败命令""" + pty = PTYSession(default_timeout=10.0) + try: + await pty.start() + result = await pty.run_command("ls /nonexistent_dir_xyz_12345") + assert result.exit_code != 0 + finally: + await pty.close() + + @pytest.mark.asyncio + async def test_run_command_timeout(self): + """命令超时""" + pty = PTYSession(default_timeout=10.0) + try: + await pty.start() + result = await pty.run_command("sleep 30", timeout=0.5) + assert result.timed_out is True + assert result.exit_code == -1 + finally: + await pty.close() + + +class TestPTYSessionAutoRespond: + """测试 PTYSession 自动应答""" + + @pytest.mark.asyncio + async def test_auto_respond_yes_no(self): + """自动应答 [y/N] 提示""" + # 使用 echo 模拟包含提示的输出,然后验证自动应答规则存在 + pty = PTYSession(auto_respond=True) + # 验证规则已加载 + rule_patterns = [r[0] for r in pty._respond_rules] + assert any("y/N" in p or "Y/n" in p for p in rule_patterns) + + @pytest.mark.asyncio + async def test_auto_respond_disabled(self): + """禁用自动应答""" + pty = PTYSession(auto_respond=False) + assert pty._auto_respond is False + + @pytest.mark.asyncio + async def test_custom_respond_rules(self): + """自定义应答规则""" + pty = PTYSession( + auto_respond=True, + custom_rules=[(r"continue\?\s*$", "yes")], + ) + rule_patterns = [r[0] for r in pty._respond_rules] + assert r"continue\?\s*$" in rule_patterns + + +class TestPTYSessionSendAndRead: + """测试 PTYSession 发送和读取""" + + @pytest.mark.asyncio + async def test_send_without_start(self): + """未启动时发送不报错""" + pty = PTYSession() + await pty.send("test") # 不应抛出异常 + + @pytest.mark.asyncio + async def test_read_output_without_start(self): + """未启动时读取返回空""" + pty = PTYSession() + output = await pty.read_output() + assert output == "" + + +class TestPTYOutput: + """测试 PTYOutput 数据类""" + + def test_default_values(self): + output = PTYOutput(output="test") + assert output.output == "test" + assert output.exit_code == -1 + assert output.timed_out is 
False + + def test_custom_values(self): + output = PTYOutput(output="error", exit_code=1, timed_out=True) + assert output.exit_code == 1 + assert output.timed_out is True + + +class TestShellToolInteractiveMode: + """测试 ShellTool 交互式模式""" + + @pytest.mark.asyncio + async def test_interactive_mode(self): + """ShellTool interactive 模式执行命令""" + tool = ShellTool() + result = await tool.execute(command="echo interactive_test", interactive=True) + assert result["exit_code"] == 0 + assert "interactive_test" in result["output"] + + @pytest.mark.asyncio + async def test_interactive_mode_with_session(self): + """ShellTool 会话模式 + 交互式""" + tool = ShellTool() + result = await tool.execute( + command="echo session_interactive", + session_id="int-session", + interactive=True, + ) + assert result["exit_code"] == 0 + assert "session_interactive" in result["output"] diff --git a/tests/unit/tools/test_terminal_session.py b/tests/unit/tools/test_terminal_session.py new file mode 100644 index 0000000..f88d7cf --- /dev/null +++ b/tests/unit/tools/test_terminal_session.py @@ -0,0 +1,600 @@ +"""TerminalSession 和 ShellTool 单元测试 + +测试场景: +- 跨命令保持 cwd → cd 后执行 pwd 返回正确目录 +- 跨命令保持 env → export 后执行 echo 返回正确值 +- 危险命令需确认 → rm 命令触发确认回调 +- 输出解析 → 错误输出结构化为错误类型+建议 +- 无 session_id 时保持现有行为 +- 会话管理器功能 +""" + +from __future__ import annotations + +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from agentkit.tools.terminal_session import TerminalSession, TerminalSessionManager, CommandRecord +from agentkit.tools.shell import ShellTool +from agentkit.tools.output_parser import OutputParser, ParsedOutput, ErrorType + + +# ============================================================ +# OutputParser 测试 +# ============================================================ + + +class TestOutputParser: + """测试 OutputParser 结构化解析""" + + def setup_method(self): + self.parser = OutputParser() + + def test_parse_success_output(self): + """成功输出解析""" + result = 
self.parser.parse("hello world", 0) + assert result.exit_code == 0 + assert result.is_error is False + assert result.error_type == ErrorType.NONE + assert result.message == "hello world" + assert result.suggestions == [] + + def test_parse_empty_output(self): + """空输出解析""" + result = self.parser.parse("", 0) + assert result.exit_code == 0 + assert result.is_error is False + assert result.message == "" + + def test_parse_permission_denied(self): + """权限不足错误解析""" + result = self.parser.parse("permission denied: /root/secret", 1) + assert result.is_error is True + assert result.error_type == ErrorType.PERMISSION_DENIED + assert len(result.suggestions) > 0 + assert any("sudo" in s for s in result.suggestions) + + def test_parse_not_found(self): + """文件不存在错误解析""" + result = self.parser.parse("No such file or directory: /tmp/missing", 1) + assert result.is_error is True + assert result.error_type == ErrorType.NOT_FOUND + + def test_parse_timeout(self): + """超时错误解析""" + result = self.parser.parse("Connection timed out", 1) + assert result.is_error is True + assert result.error_type == ErrorType.TIMEOUT + + def test_parse_syntax_error(self): + """语法错误解析""" + result = self.parser.parse("syntax error near unexpected token", 2) + assert result.is_error is True + assert result.error_type == ErrorType.SYNTAX_ERROR + + def test_parse_connection_refused(self): + """连接被拒绝解析""" + result = self.parser.parse("Connection refused on port 8080", 1) + assert result.is_error is True + assert result.error_type == ErrorType.CONNECTION_REFUSED + + def test_parse_out_of_memory(self): + """内存不足解析""" + result = self.parser.parse("Out of memory: cannot allocate", 1) + assert result.is_error is True + assert result.error_type == ErrorType.OUT_OF_MEMORY + + def test_parse_disk_full(self): + """磁盘满解析""" + result = self.parser.parse("No space left on device", 1) + assert result.is_error is True + assert result.error_type == ErrorType.DISK_FULL + + def test_parse_already_exists(self): + """已存在解析""" + 
result = self.parser.parse("File already exists: /tmp/test", 1) + assert result.is_error is True + assert result.error_type == ErrorType.ALREADY_EXISTS + + def test_parse_invalid_argument(self): + """无效参数解析""" + result = self.parser.parse("invalid argument: --unknown-flag", 1) + assert result.is_error is True + assert result.error_type == ErrorType.INVALID_ARGUMENT + + def test_parse_network_error(self): + """网络错误解析""" + result = self.parser.parse("Network is unreachable", 1) + assert result.is_error is True + assert result.error_type == ErrorType.NETWORK_ERROR + + def test_parse_exit_code_126(self): + """退出码 126 → 权限不足""" + result = self.parser.parse("some unknown error", 126) + assert result.is_error is True + assert result.error_type == ErrorType.PERMISSION_DENIED + + def test_parse_exit_code_127(self): + """退出码 127 → 命令未找到""" + result = self.parser.parse("some unknown error", 127) + assert result.is_error is True + assert result.error_type == ErrorType.NOT_FOUND + + def test_parse_exit_code_130(self): + """退出码 130 → 被中断""" + result = self.parser.parse("some unknown error", 130) + assert result.is_error is True + assert result.error_type == ErrorType.TIMEOUT + + def test_parse_unknown_error(self): + """未知错误""" + result = self.parser.parse("something went wrong", 1) + assert result.is_error is True + assert result.error_type == ErrorType.UNKNOWN + + def test_parse_long_message_truncated(self): + """长消息截断""" + long_output = "x" * 300 + result = self.parser.parse(long_output, 0) + assert len(result.message) <= 203 # 200 + "..." 
+ + def test_parsed_output_to_dict(self): + """ParsedOutput.to_dict()""" + result = self.parser.parse("permission denied", 1) + d = result.to_dict() + assert d["exit_code"] == 1 + assert d["is_error"] is True + assert d["error_type"] == "permission_denied" + assert isinstance(d["suggestions"], list) + + def test_parse_chinese_error_messages(self): + """中文错误消息解析""" + result = self.parser.parse("权限不足: 无法访问", 1) + assert result.is_error is True + assert result.error_type == ErrorType.PERMISSION_DENIED + + def test_parse_multiline_output_message_is_last_line(self): + """多行输出取最后一行作为消息""" + output = "line1\nline2\nline3" + result = self.parser.parse(output, 0) + assert result.message == "line3" + + +# ============================================================ +# TerminalSession 测试 +# ============================================================ + + +class TestTerminalSession: + """测试 TerminalSession 会话状态管理""" + + def test_construction_default(self): + """默认构造""" + session = TerminalSession(session_id="test") + assert session.session_id == "test" + assert session.cwd == os.getcwd() + assert isinstance(session.env, dict) + assert session.history == [] + + def test_construction_custom_cwd(self): + """自定义工作目录""" + session = TerminalSession(session_id="test", cwd="/tmp") + assert session.cwd == "/tmp" + + def test_construction_custom_env(self): + """自定义环境变量""" + session = TerminalSession(session_id="test", env={"FOO": "bar"}) + assert session.env.get("FOO") == "bar" + + def test_set_cwd(self): + """手动设置 cwd""" + session = TerminalSession(session_id="test") + session.set_cwd("/usr/local") + assert session.cwd == "/usr/local" + + def test_set_env(self): + """手动设置环境变量""" + session = TerminalSession(session_id="test") + session.set_env("MY_VAR", "hello") + assert session.env.get("MY_VAR") == "hello" + + def test_update_env(self): + """批量更新环境变量""" + session = TerminalSession(session_id="test") + session.update_env({"A": "1", "B": "2"}) + assert session.env.get("A") == "1" + 
assert session.env.get("B") == "2" + + def test_get_env_returns_copy(self): + """get_env 返回副本,修改不影响原数据""" + session = TerminalSession(session_id="test") + env = session.get_env() + env["HACKED"] = "yes" + assert "HACKED" not in session.env + + def test_get_history_returns_copy(self): + """get_history 返回副本""" + session = TerminalSession(session_id="test") + history = session.get_history() + assert history is not session._history + + @pytest.mark.asyncio + async def test_execute_simple_command(self): + """执行简单命令""" + session = TerminalSession(session_id="test") + result = await session.execute("echo hello") + assert result.exit_code == 0 + assert "hello" in result.raw_output + + @pytest.mark.asyncio + async def test_execute_records_history(self): + """执行命令记录历史""" + session = TerminalSession(session_id="test") + await session.execute("echo first") + await session.execute("echo second") + assert len(session.history) == 2 + assert session.history[0].command == "echo first" + assert session.history[1].command == "echo second" + + @pytest.mark.asyncio + async def test_cross_command_cwd(self): + """跨命令保持 cwd:cd 后 pwd 返回正确目录""" + session = TerminalSession(session_id="test") + await session.execute("cd /tmp") + assert session.cwd == "/tmp" + + result = await session.execute("pwd") + assert "/tmp" in result.raw_output + + @pytest.mark.asyncio + async def test_cross_command_env(self): + """跨命令保持 env:export 后 echo 返回正确值""" + session = TerminalSession(session_id="test") + await session.execute("export MY_TEST_VAR=hello123") + assert session.env.get("MY_TEST_VAR") == "hello123" + + result = await session.execute("echo $MY_TEST_VAR") + assert "hello123" in result.raw_output + + @pytest.mark.asyncio + async def test_cd_relative_path(self): + """cd 相对路径(目录存在时更新 cwd)""" + # 使用 /usr 作为基础目录,cd local(/usr/local 存在) + session = TerminalSession(session_id="test", cwd="/usr") + await session.execute("cd local") + assert session.cwd == "/usr/local" + + @pytest.mark.asyncio + async def 
test_cd_absolute_path(self): + """cd 绝对路径""" + session = TerminalSession(session_id="test") + await session.execute("cd /usr") + assert session.cwd == "/usr" + + @pytest.mark.asyncio + async def test_failed_command_no_state_update(self): + """失败命令不更新状态""" + session = TerminalSession(session_id="test", cwd="/tmp") + await session.execute("cd /nonexistent_dir_xyz") + # cd 失败,cwd 不应更新 + assert session.cwd == "/tmp" + + @pytest.mark.asyncio + async def test_timeout(self): + """命令超时""" + session = TerminalSession(session_id="test") + result = await session.execute("sleep 10", timeout=0.5) + assert result.exit_code == -1 + assert result.is_error is True + + @pytest.mark.asyncio + async def test_max_history(self): + """历史记录上限""" + session = TerminalSession(session_id="test", max_history=3) + for i in range(5): + await session.execute(f"echo {i}") + assert len(session.history) == 3 + assert session.history[0].command == "echo 2" + + def test_close(self): + """关闭会话""" + session = TerminalSession(session_id="test") + session.close() # 不应抛出异常 + + +# ============================================================ +# TerminalSessionManager 测试 +# ============================================================ + + +class TestTerminalSessionManager: + """测试 TerminalSessionManager 会话管理""" + + def test_get_or_create_new(self): + """创建新会话""" + manager = TerminalSessionManager() + session = manager.get_or_create("s1") + assert session.session_id == "s1" + + def test_get_or_create_existing(self): + """获取已有会话""" + manager = TerminalSessionManager() + s1 = manager.get_or_create("s1") + s1.set_cwd("/tmp") + s2 = manager.get_or_create("s1") + assert s2.cwd == "/tmp" + + def test_get_existing(self): + """get 获取已有会话""" + manager = TerminalSessionManager() + manager.get_or_create("s1") + session = manager.get("s1") + assert session is not None + + def test_get_nonexistent(self): + """get 不存在的会话返回 None""" + manager = TerminalSessionManager() + assert manager.get("nonexistent") is None + + def 
test_remove(self): + """移除会话""" + manager = TerminalSessionManager() + manager.get_or_create("s1") + manager.remove("s1") + assert manager.get("s1") is None + + def test_list_sessions(self): + """列出会话""" + manager = TerminalSessionManager() + manager.get_or_create("s1") + manager.get_or_create("s2") + assert sorted(manager.list_sessions()) == ["s1", "s2"] + + def test_has_session(self): + """检查会话是否存在""" + manager = TerminalSessionManager() + manager.get_or_create("s1") + assert manager.has_session("s1") is True + assert manager.has_session("s2") is False + + def test_max_sessions_eviction(self): + """超过最大会话数时移除最旧会话""" + manager = TerminalSessionManager(max_sessions=2) + manager.get_or_create("s1") + manager.get_or_create("s2") + manager.get_or_create("s3") # 应该移除 s1 + assert not manager.has_session("s1") + assert manager.has_session("s2") + assert manager.has_session("s3") + + def test_close_all(self): + """关闭所有会话""" + manager = TerminalSessionManager() + manager.get_or_create("s1") + manager.get_or_create("s2") + manager.close_all() + assert manager.list_sessions() == [] + + +# ============================================================ +# ShellTool 测试 +# ============================================================ + + +class TestShellToolConstruction: + """测试 ShellTool 构造""" + + def test_default_construction(self): + tool = ShellTool() + assert tool.name == "shell" + assert tool.input_schema is not None + assert "command" in tool.input_schema["properties"] + assert "session_id" in tool.input_schema["properties"] + assert tool.input_schema["required"] == ["command"] + + def test_custom_construction(self): + tool = ShellTool(name="my_shell", version="2.0.0") + assert tool.name == "my_shell" + assert tool.version == "2.0.0" + + def test_to_dict(self): + tool = ShellTool() + d = tool.to_dict() + assert d["name"] == "shell" + assert "input_schema" in d + + def test_repr(self): + tool = ShellTool() + r = repr(tool) + assert "ShellTool" in r + assert "shell" in r + + 
class TestShellToolExecution:
    """Command execution tests for ShellTool."""

    @pytest.mark.asyncio
    async def test_execute_simple_command(self):
        """A plain command without a session succeeds and reports no session id."""
        shell = ShellTool()
        outcome = await shell.execute(command="echo hello")
        assert outcome["exit_code"] == 0
        assert "hello" in outcome["output"]
        assert outcome["is_error"] is False
        assert outcome["session_id"] is None

    @pytest.mark.asyncio
    async def test_execute_missing_command(self):
        """Calling execute() without `command` yields an error result with exit code 1."""
        shell = ShellTool()
        outcome = await shell.execute()
        assert outcome["is_error"] is True
        assert outcome["exit_code"] == 1

    @pytest.mark.asyncio
    async def test_execute_with_working_dir(self):
        """`working_dir` is the directory the command runs in."""
        shell = ShellTool()
        outcome = await shell.execute(command="pwd", working_dir="/tmp")
        assert outcome["exit_code"] == 0
        assert "/tmp" in outcome["output"]

    @pytest.mark.asyncio
    async def test_execute_with_session(self):
        """In session mode the result echoes back the session id."""
        shell = ShellTool()
        outcome = await shell.execute(
            command="echo session_test", session_id="s1"
        )
        assert outcome["exit_code"] == 0
        assert "session_test" in outcome["output"]
        assert outcome["session_id"] == "s1"

    @pytest.mark.asyncio
    async def test_session_preserves_cwd(self):
        """Within one session, a `cd` is still in effect for the next command."""
        shell = ShellTool()
        await shell.execute(command="cd /tmp", session_id="cwd-test")
        outcome = await shell.execute(command="pwd", session_id="cwd-test")
        assert "/tmp" in outcome["output"]

    @pytest.mark.asyncio
    async def test_session_preserves_env(self):
        """Within one session, an exported variable is visible to the next command."""
        shell = ShellTool()
        await shell.execute(
            command="export SHELL_TEST_VAR=world", session_id="env-test"
        )
        outcome = await shell.execute(
            command="echo $SHELL_TEST_VAR", session_id="env-test"
        )
        assert "world" in outcome["output"]

    @pytest.mark.asyncio
    async def test_no_session_id_backward_compatible(self):
        """Omitting session_id keeps the session-less behaviour."""
        shell = ShellTool()
        outcome = await shell.execute(command="echo no_session")
        assert outcome["exit_code"] == 0
        assert "no_session" in outcome["output"]
        assert outcome["session_id"] is None

    @pytest.mark.asyncio
    async def test_different_sessions_independent(self):
        """Two sessions keep separate working directories."""
        shell = ShellTool()
        await shell.execute(command="cd /tmp", session_id="s1")
        await shell.execute(command="cd /usr", session_id="s2")

        first = await shell.execute(command="pwd", session_id="s1")
        second = await shell.execute(command="pwd", session_id="s2")

        assert "/tmp" in first["output"]
        assert "/usr" in second["output"]


class TestShellToolSecurity:
    """Safety-control tests for ShellTool."""

    @pytest.mark.asyncio
    async def test_safe_command_allowed(self):
        """A harmless command runs straight through."""
        shell = ShellTool()
        outcome = await shell.execute(command="ls /tmp")
        assert outcome["exit_code"] == 0

    @pytest.mark.asyncio
    async def test_dangerous_command_blocked_without_callback(self):
        """With no confirm callback, a dangerous command is refused (exit code 126)."""
        shell = ShellTool()
        outcome = await shell.execute(command="rm -rf /tmp/test")
        assert outcome["is_error"] is True
        assert outcome["exit_code"] == 126

    @pytest.mark.asyncio
    async def test_dangerous_command_confirmed(self):
        """A confirm callback returning True lets a dangerous command through."""
        approve = AsyncMock(return_value=True)
        shell = ShellTool(confirm_callback=approve)
        outcome = await shell.execute(command="rm -rf /tmp/nonexistent_test_dir")
        assert approve.called
        # The command itself may fail (the directory does not exist), but it
        # must not come back as a security rejection.
        assert not (outcome["exit_code"] == 126 and outcome["is_error"])

    @pytest.mark.asyncio
    async def test_dangerous_command_rejected_by_callback(self):
        """A confirm callback returning False refuses the command (exit code 126)."""
        deny = AsyncMock(return_value=False)
        shell = ShellTool(confirm_callback=deny)
        outcome = await shell.execute(command="rm -rf /tmp/test")
        assert outcome["is_error"] is True
        assert outcome["exit_code"] == 126

    @pytest.mark.asyncio
    async def test_audit_log_recorded(self):
        """An executed command shows up as the first audit-log entry."""
        shell = ShellTool()
        await shell.execute(command="echo audit_test")
        assert len(shell.audit_log) > 0
        first_entry = shell.audit_log[0]
        assert first_entry["command"] == "echo audit_test"

    @pytest.mark.asyncio
    async def test_blocked_command_in_audit_log(self):
        """A refused command is recorded in the audit log with a truthy `blocked`."""
        shell = ShellTool()
        await shell.execute(command="rm -rf /tmp/test")
        flagged = [entry for entry in shell.audit_log if entry.get("blocked")]
        assert len(flagged) > 0

    @pytest.mark.asyncio
    async def test_git_push_force_is_dangerous(self):
        """`git push --force` is treated as dangerous and refused."""
        shell = ShellTool()
        outcome = await shell.execute(command="git push --force origin main")
        assert outcome["is_error"] is True
        assert outcome["exit_code"] == 126

    @pytest.mark.asyncio
    async def test_git_status_is_safe(self):
        """`git status` is treated as safe."""
        shell = ShellTool()
        outcome = await shell.execute(command="git status")
        # It may fail outside a git repository, but it must not be refused
        # by the safety check.
        assert outcome["exit_code"] != 126


class TestShellToolOutputParsing:
    """Tests for how ShellTool structures parsed output."""

    @pytest.mark.asyncio
    async def test_error_output_structured(self):
        """A failing command yields an error type and a list of suggestions."""
        shell = ShellTool()
        outcome = await shell.execute(command="ls /nonexistent_dir_xyz_12345")
        assert outcome["is_error"] is True
        assert outcome["error_type"] in ("not_found", "unknown")
        assert isinstance(outcome["suggestions"], list)

    @pytest.mark.asyncio
    async def test_success_output_not_error(self):
        """A successful command is not flagged as an error."""
        shell = ShellTool()
        outcome = await shell.execute(command="echo success")
        assert outcome["is_error"] is False
        assert outcome["error_type"] == "none"


class TestShellToolSessionManager:
    """Tests for reaching the session manager through ShellTool."""

    def test_session_manager_accessible(self):
        """A ShellTool exposes a session manager that is not None."""
        shell = ShellTool()
        assert shell.session_manager is not None

    @pytest.mark.asyncio
    async def test_session_created_on_first_use(self):
        """The first command with a new session_id creates that session."""
        shell = ShellTool()
        sessions = shell.session_manager
        assert not sessions.has_session("new-session")
        await shell.execute(command="echo test", session_id="new-session")
        assert shell.session_manager.has_session("new-session")