feat(phase2): implement self-evolution and smart terminal (U6-U8)
- U6: PitfallDetector - detect historical failure patterns and warn
- U7: PathOptimizer - discover and update optimal execution paths
- U8: TerminalSession - session state, interactive PTY, output parsing

160 new tests passing. ShellTool enhanced with session_id support.
This commit is contained in:
parent
fd4a811929
commit
e3d4f811dd
|
|
@ -0,0 +1,259 @@
|
|||
"""PathOptimizer - 执行路径优化器
|
||||
|
||||
发现更优执行路径时自动更新经验库中的推荐路径。
|
||||
|
||||
核心逻辑:
|
||||
1. 对比新路径与现有最优路径(综合耗时和成功率)
|
||||
2. 新路径成功率更高 → 更新推荐路径
|
||||
3. 成功率相近但耗时更短 → 更新推荐路径
|
||||
4. 样本量不足 → 不更新,记录待观察
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentkit.evolution.experience_store import InMemoryExperienceStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ExecutionPath:
    """Data model describing one execution path for a task type.

    Holds the statistics that are compared when deciding which path to
    recommend. Every field has a default, so ``ExecutionPath()`` is valid.
    """

    # Unique identifier of the path.
    path_id: str = ""
    # Task type the path belongs to.
    task_type: str = ""
    # Names of the execution steps.
    steps: list[str] = field(default_factory=list)
    # Total duration in seconds.
    total_duration: float = 0.0
    # Success rate, documented as 0.0 ~ 1.0 (not validated here).
    success_rate: float = 0.0
    # Number of samples behind the statistics.
    sample_count: int = 0
    # Whether this is the currently recommended path.
    is_recommended: bool = False
    # Creation time; defaults to the current time in UTC.
    created_at: datetime = field(
        default_factory=lambda: datetime.now(tz=timezone.utc)
    )
|
||||
|
||||
|
||||
@dataclass
class PathUpdateResult:
    """Outcome of evaluating a candidate path against the recommended one."""

    # True when the recommended path was replaced (or set for the first time).
    updated: bool = False
    # The path that was recommended before the evaluation, if any.
    old_path: ExecutionPath | None = None
    # The candidate path that was evaluated.
    new_path: ExecutionPath | None = None
    # Human-readable explanation of why the update did or did not happen.
    reason: str = ""
|
||||
|
||||
|
||||
class PathOptimizer:
    """Decide whether a newly observed execution path should become the recommended one.

    Update strategy:
        1. New success rate > current success rate + ``success_rate_threshold``
           -> update.
        2. Success rates are close (absolute difference <= threshold) and the new
           duration is shorter by more than ``duration_improvement_threshold``
           (relative to the current duration) -> update.
        3. Candidate has fewer than ``min_sample_count`` samples -> no update;
           the candidate is kept in a per-task "pending" list.
        4. Anything else -> keep the current recommendation.

    State is held in memory only. The ``experience_store`` passed to the
    constructor is stored but not read by any method of this class.
    """

    def __init__(
        self,
        experience_store: InMemoryExperienceStore | None = None,
        min_sample_count: int = 3,
        success_rate_threshold: float = 0.05,
        duration_improvement_threshold: float = 0.2,
    ):
        """Create an optimizer.

        Args:
            experience_store: Optional experience store instance.
            min_sample_count: Minimum sample count; candidates below it are
                not compared and are recorded as pending instead.
            success_rate_threshold: Success-rate gain above which the
                improvement counts as significant.
            duration_improvement_threshold: Relative duration reduction above
                which the improvement counts as significant.
        """
        self._experience_store = experience_store
        self._min_sample_count = min_sample_count
        self._success_rate_threshold = success_rate_threshold
        self._duration_improvement_threshold = duration_improvement_threshold
        self._recommended_paths: dict[str, ExecutionPath] = {}
        self._pending_paths: dict[str, list[ExecutionPath]] = {}

    def get_recommended_path(self, task_type: str) -> ExecutionPath | None:
        """Return the recommended path for ``task_type``, or ``None`` if there is none."""
        return self._recommended_paths.get(task_type)

    async def evaluate_and_update(
        self,
        task_type: str,
        new_path: ExecutionPath,
    ) -> PathUpdateResult:
        """Evaluate ``new_path`` and update the recommendation if it is better.

        ``new_path`` is mutated: it receives a generated ``path_id`` when it has
        none, its ``task_type`` is overwritten with the argument, and
        ``is_recommended`` is set when it becomes the recommendation.

        Args:
            task_type: Task type being evaluated.
            new_path: Candidate execution path.

        Returns:
            A ``PathUpdateResult`` describing the decision.
        """
        if not new_path.path_id:
            new_path.path_id = str(uuid.uuid4())
        new_path.task_type = task_type

        # Not enough samples: do not compare, just remember the candidate.
        if new_path.sample_count < self._min_sample_count:
            self._remember_pending(task_type, new_path)
            reason = (
                f"样本量不足({new_path.sample_count} < {self._min_sample_count}),"
                f"记录待观察"
            )
            logger.info("Path not updated for '%s': %s", task_type, reason)
            return PathUpdateResult(
                updated=False,
                old_path=None,
                new_path=new_path,
                reason=reason,
            )

        # The candidate now has enough samples, so it is no longer "pending".
        self._forget_pending(task_type, new_path.path_id)

        current = self._recommended_paths.get(task_type)

        # No recommendation yet: the first sufficiently sampled path wins.
        if current is None:
            new_path.is_recommended = True
            self._recommended_paths[task_type] = new_path
            reason = "无现有推荐路径,直接设为推荐"
            logger.info("Path set as recommended for '%s': %s", task_type, reason)
            return PathUpdateResult(
                updated=True,
                old_path=None,
                new_path=new_path,
                reason=reason,
            )

        return self._compare_and_decide(task_type, current, new_path)

    def _remember_pending(self, task_type: str, path: ExecutionPath) -> None:
        """Record ``path`` as pending, replacing an entry with the same ``path_id``.

        Replacing instead of appending keeps repeated evaluations of the same
        under-sampled path from piling up duplicates.
        """
        pending = self._pending_paths.setdefault(task_type, [])
        for index, existing in enumerate(pending):
            if existing.path_id == path.path_id:
                pending[index] = path
                return
        pending.append(path)

    def _forget_pending(self, task_type: str, path_id: str) -> None:
        """Drop pending entries for ``path_id`` under ``task_type``, if any."""
        pending = self._pending_paths.get(task_type)
        if pending:
            self._pending_paths[task_type] = [
                candidate for candidate in pending if candidate.path_id != path_id
            ]

    def _compare_and_decide(
        self,
        task_type: str,
        current: ExecutionPath,
        new: ExecutionPath,
    ) -> PathUpdateResult:
        """Compare ``new`` with ``current`` and apply the update strategy.

        1. Success rate improved by more than the threshold -> update.
        2. Success rates within the threshold and duration significantly
           shorter -> update.
        3. Otherwise -> keep ``current``.
        """
        sr_diff = new.success_rate - current.success_rate

        # Rule 1: significantly higher success rate.
        if sr_diff > self._success_rate_threshold:
            return self._apply_update(
                task_type, current, new,
                f"成功率显著提升({new.success_rate:.2f} > {current.success_rate:.2f},"
                f"提升 {sr_diff:.2f})",
            )

        # Rule 2: comparable success rate, significantly shorter duration.
        # A current duration of 0 gives no baseline for a relative improvement,
        # so the duration rule is skipped in that case.
        if abs(sr_diff) <= self._success_rate_threshold and current.total_duration > 0:
            duration_improvement = (
                (current.total_duration - new.total_duration) / current.total_duration
            )
            if (
                new.total_duration < current.total_duration
                and duration_improvement > self._duration_improvement_threshold
            ):
                return self._apply_update(
                    task_type, current, new,
                    f"成功率相近({new.success_rate:.2f} vs {current.success_rate:.2f}),"
                    f"耗时显著更短({new.total_duration:.1f}s vs {current.total_duration:.1f}s,"
                    f"改善 {duration_improvement:.1%})",
                )

        # Rule 3: no clear advantage, keep the current recommendation.
        reason = (
            f"新路径无明显优势(成功率 {new.success_rate:.2f} vs {current.success_rate:.2f},"
            f"耗时 {new.total_duration:.1f}s vs {current.total_duration:.1f}s),保留现有推荐路径"
        )
        logger.info("Path not updated for '%s': %s", task_type, reason)
        return PathUpdateResult(
            updated=False,
            old_path=current,
            new_path=new,
            reason=reason,
        )

    def _apply_update(
        self,
        task_type: str,
        old: ExecutionPath,
        new: ExecutionPath,
        reason: str,
    ) -> PathUpdateResult:
        """Make ``new`` the recommended path for ``task_type`` and report the change."""
        old.is_recommended = False
        new.is_recommended = True
        self._recommended_paths[task_type] = new
        logger.info("Path updated for '%s': %s", task_type, reason)
        return PathUpdateResult(
            updated=True,
            old_path=old,
            new_path=new,
            reason=reason,
        )

    def get_pending_paths(self, task_type: str) -> list[ExecutionPath]:
        """Return a copy of the pending (under-sampled) paths for ``task_type``."""
        return list(self._pending_paths.get(task_type, []))
|
||||
|
|
@ -0,0 +1,388 @@
|
|||
"""PitfallDetector - 任务避坑预警
|
||||
|
||||
新任务启动时检索历史失败经验,匹配当前计划步骤,自动预警。
|
||||
基于 ExperienceStore 中存储的失败经验,将失败步骤与当前计划步骤
|
||||
进行关键词匹配,计算失败率并按严重程度返回预警列表。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Protocol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WarningLevel(str, Enum):
    """Severity of a pitfall warning.

    Members are also ``str`` instances, so they compare equal to their
    lowercase string values.
    """

    HIGH = "high"      # most severe
    MEDIUM = "medium"
    LOW = "low"        # least severe
|
||||
|
||||
|
||||
@dataclass
class PitfallWarning:
    """A warning attached to one planned step."""

    # Name of the planned step the warning refers to.
    step_name: str
    # Severity (HIGH / MEDIUM / LOW).
    warning_level: WarningLevel
    # Historical failure rate, documented as 0.0 ~ 1.0.
    failure_rate: float
    # Historical failure reasons.
    historical_failures: list[str] = field(default_factory=list)
    # Suggested mitigation text.
    suggestion: str = ""
|
||||
|
||||
|
||||
class ExperienceStoreProtocol(Protocol):
    """Structural interface of an experience store, used for type annotations only."""

    async def search(
        self,
        query: str,
        top_k: int = 5,
        task_type: str | None = None,
        search_multiplier: int = 5,
    ) -> list[Any]:
        """Search stored experiences; implementations provide the behavior."""
        ...
|
||||
|
||||
|
||||
# Failure-rate thresholds for warning levels:
# a rate at or above _HIGH_THRESHOLD is HIGH, at or above _MEDIUM_THRESHOLD is MEDIUM.
_HIGH_THRESHOLD = 0.5
_MEDIUM_THRESHOLD = 0.2
|
||||
|
||||
|
||||
class PitfallDetector:
    """Warn about planned steps that resemble historically failing steps.

    Usage:
        detector = PitfallDetector(experience_store)
        warnings = await detector.check_pitfalls(
            task_type="code_review",
            planned_steps=[plan_step1, plan_step2, ...],
        )

    Matching logic:
        1. Retrieve experiences for the same task type.
        2. Build per-step occurrence/failure statistics from them.
        3. Match each planned step to the most similar historical step by
           keyword overlap.
        4. Compute the failure rate and assign a warning level.

    Warning levels:
        - HIGH:   failure_rate >= 0.5
        - MEDIUM: failure_rate >= 0.2
        - LOW:    any recorded failure
    """

    def __init__(
        self,
        experience_store: ExperienceStoreProtocol,
        similarity_threshold: float = 0.3,
        max_search_results: int = 50,
    ):
        """Create a detector.

        Args:
            experience_store: Store providing an async ``search`` method.
            similarity_threshold: Minimum keyword similarity for a planned step
                to be considered a match for a historical step.
            max_search_results: Maximum number of experiences requested from
                the store.
        """
        self._store = experience_store
        self._similarity_threshold = similarity_threshold
        self._max_search_results = max_search_results

    async def check_pitfalls(
        self,
        task_type: str,
        planned_steps: list[Any],
    ) -> list[PitfallWarning]:
        """Check planned steps for likely pitfalls.

        Args:
            task_type: Task type used to retrieve experiences.
            planned_steps: Objects exposing ``name`` and, optionally,
                ``description`` attributes.

        Returns:
            Warnings sorted by severity (HIGH -> MEDIUM -> LOW), then by
            descending failure rate. Empty when there are no planned steps,
            no experiences, or no matches.
        """
        if not planned_steps:
            return []

        # Both successful and failed experiences are needed to compute a rate.
        all_experiences = await self._search_experiences(task_type)
        if not all_experiences:
            logger.debug("No experiences found for task_type=%s", task_type)
            return []

        step_failure_stats = self._extract_step_failure_stats(all_experiences)
        warnings = self._match_and_warn(planned_steps, step_failure_stats)
        warnings.sort(
            key=lambda w: (_warning_level_order(w.warning_level), -w.failure_rate)
        )

        if warnings:
            logger.info(
                "PitfallDetector found %d warnings for task_type=%s: "
                "%d HIGH, %d MEDIUM, %d LOW",
                len(warnings),
                task_type,
                sum(1 for w in warnings if w.warning_level == WarningLevel.HIGH),
                sum(1 for w in warnings if w.warning_level == WarningLevel.MEDIUM),
                sum(1 for w in warnings if w.warning_level == WarningLevel.LOW),
            )

        return warnings

    async def _search_experiences(self, task_type: str) -> list[Any]:
        """Fetch experiences for ``task_type``; return ``[]`` if the store fails.

        Pitfall detection is advisory, so a store error is logged and treated
        as "no experiences" rather than propagated.
        """
        try:
            return await self._store.search(
                query=task_type,
                top_k=self._max_search_results,
                task_type=task_type,
            )
        except Exception as e:
            logger.error("Failed to search experiences for pitfall detection: %s", e)
            return []

    def _extract_step_failure_stats(
        self, failed_experiences: list[Any]
    ) -> dict[str, _StepFailureStats]:
        """Aggregate per-step occurrence and failure counts.

        Only experiences whose ``steps_summary`` is a list of dicts contribute;
        each dict is read for ``step_name``, ``outcome`` and ``error``. A
        string (or missing) ``steps_summary`` carries no step-level detail and
        is skipped.

        An experience's ``optimization_tips`` are attached only to the steps
        that appear in that same experience.

        Args:
            failed_experiences: Experiences to aggregate (despite the name,
                the caller passes successful ones too).

        Returns:
            Statistics keyed by historical step name.
        """
        stats: dict[str, _StepFailureStats] = {}

        for exp in failed_experiences:
            steps_summary = getattr(exp, "steps_summary", None)
            if not isinstance(steps_summary, list):
                continue

            steps_in_experience: set[str] = set()
            for step in steps_summary:
                if not isinstance(step, dict):
                    continue
                step_name = step.get("step_name", "")
                if not step_name:
                    continue

                entry = stats.get(step_name)
                if entry is None:
                    entry = stats[step_name] = _StepFailureStats(
                        step_name=step_name,
                        total_occurrences=0,
                        failure_occurrences=0,
                        failure_reasons=[],
                        optimization_tips=[],
                    )

                entry.total_occurrences += 1
                steps_in_experience.add(step_name)

                if step.get("outcome", "") in ("failure", "failed", "error"):
                    entry.failure_occurrences += 1
                    error = step.get("error", "")
                    if error:
                        entry.failure_reasons.append(error)

            tips = getattr(exp, "optimization_tips", None)
            if tips:
                for name in steps_in_experience:
                    stats[name].optimization_tips.extend(tips)

        return stats

    def _match_and_warn(
        self,
        planned_steps: list[Any],
        step_failure_stats: dict[str, _StepFailureStats],
    ) -> list[PitfallWarning]:
        """Match planned steps to historical statistics and build warnings.

        A planned step yields a warning only when its best-matching historical
        step reaches the similarity threshold and has at least one recorded
        failure.
        """
        warnings: list[PitfallWarning] = []

        for step in planned_steps:
            step_name = getattr(step, "name", "")
            if not step_name:
                continue
            step_description = getattr(step, "description", "")

            best_match: _StepFailureStats | None = None
            best_similarity = 0.0
            for historical_name, candidate in step_failure_stats.items():
                similarity = _compute_name_similarity(
                    step_name, step_description, historical_name
                )
                if similarity > best_similarity:
                    best_similarity = similarity
                    best_match = candidate

            if best_match is None or best_similarity < self._similarity_threshold:
                continue

            # A step that never failed is not a pitfall; warning about it would
            # claim "failure records" at a 0% failure rate.
            if best_match.failure_occurrences == 0:
                continue

            # total_occurrences >= failure_occurrences > 0, so division is safe.
            failure_rate = best_match.failure_occurrences / best_match.total_occurrences

            warnings.append(
                PitfallWarning(
                    step_name=step_name,
                    warning_level=_determine_warning_level(failure_rate),
                    failure_rate=round(failure_rate, 4),
                    historical_failures=best_match.failure_reasons[:5],  # keep at most 5
                    suggestion=_build_suggestion(best_match, failure_rate),
                )
            )

        return warnings
|
||||
|
||||
|
||||
# ── 内部辅助类 ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class _StepFailureStats:
|
||||
"""步骤级别的失败统计(内部使用)"""
|
||||
|
||||
step_name: str
|
||||
total_occurrences: int
|
||||
failure_occurrences: int
|
||||
failure_reasons: list[str]
|
||||
optimization_tips: list[str]
|
||||
|
||||
|
||||
# ── 辅助函数 ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def _compute_name_similarity(
    step_name: str, step_description: str, historical_step_name: str
) -> float:
    """Return the keyword-overlap (Jaccard) similarity of two steps.

    The planned step contributes the keywords of its name and description
    together; the historical step contributes the keywords of its name.

    Args:
        step_name: Name of the planned step.
        step_description: Description of the planned step.
        historical_step_name: Name of the historical step.

    Returns:
        ``|A & B| / |A | B|`` in the range 0.0 ~ 1.0, or 0.0 when either
        keyword set is empty.
    """
    planned_terms = _extract_keywords(f"{step_name} {step_description}")
    historical_terms = _extract_keywords(historical_step_name)

    # With both sets non-empty the union is non-empty, so the division is safe.
    if planned_terms and historical_terms:
        shared = planned_terms & historical_terms
        combined = planned_terms | historical_terms
        return len(shared) / len(combined)
    return 0.0
|
||||
|
||||
|
||||
_STOP_WORDS = frozenset({
|
||||
"a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for",
|
||||
"of", "with", "by", "from", "is", "are", "was", "were", "be", "been",
|
||||
"being", "have", "has", "had", "do", "does", "did", "will", "would",
|
||||
"could", "should", "may", "might", "can", "shall", "not", "no",
|
||||
})
|
||||
|
||||
|
||||
def _extract_keywords(text: str) -> frozenset[str]:
|
||||
"""从文本中提取关键词集合
|
||||
|
||||
转小写、按空白/下划线/连字符拆分、过滤停用词和单字符词。
|
||||
"""
|
||||
# 统一分隔符
|
||||
normalized = text.lower().replace("_", " ").replace("-", " ")
|
||||
words = normalized.split()
|
||||
return frozenset(
|
||||
w for w in words
|
||||
if len(w) > 1 and w not in _STOP_WORDS
|
||||
)
|
||||
|
||||
|
||||
def _determine_warning_level(failure_rate: float) -> WarningLevel:
    """Map a failure rate to a warning level.

    - HIGH:   failure_rate >= _HIGH_THRESHOLD
    - MEDIUM: failure_rate >= _MEDIUM_THRESHOLD
    - LOW:    everything below that
    """
    # Checked from the most severe level down; the first threshold reached wins.
    ladder = (
        (_HIGH_THRESHOLD, WarningLevel.HIGH),
        (_MEDIUM_THRESHOLD, WarningLevel.MEDIUM),
    )
    for floor, level in ladder:
        if failure_rate >= floor:
            return level
    return WarningLevel.LOW
|
||||
|
||||
|
||||
def _warning_level_order(level: WarningLevel) -> int:
    """Return the sort rank of a warning level (smaller means more severe).

    Raises:
        KeyError: If ``level`` is not one of the known levels.
    """
    by_severity = (WarningLevel.HIGH, WarningLevel.MEDIUM, WarningLevel.LOW)
    ranks = {member: rank for rank, member in enumerate(by_severity)}
    return ranks[level]
|
||||
|
||||
|
||||
def _build_suggestion(stats: _StepFailureStats, failure_rate: float) -> str:
    """Compose the suggestion text for a warning.

    The text consists of a headline chosen by failure-rate band, followed, when
    available, by up to three distinct failure reasons and up to two distinct
    optimization tips (first-seen order). Sentences are joined with "。".
    """
    if failure_rate >= _HIGH_THRESHOLD:
        headline = f"该步骤历史失败率高达 {failure_rate:.0%},建议重点关注"
    elif failure_rate >= _MEDIUM_THRESHOLD:
        headline = f"该步骤历史失败率为 {failure_rate:.0%},需注意风险"
    else:
        headline = f"该步骤有少量失败记录(失败率 {failure_rate:.0%})"
    sentences: list[str] = [headline]

    # dict.fromkeys de-duplicates while preserving first-seen order.
    top_reasons = list(dict.fromkeys(stats.failure_reasons))[:3]
    if top_reasons:
        sentences.append("常见失败原因:" + "、".join(top_reasons))

    top_tips = list(dict.fromkeys(stats.optimization_tips))[:2]
    if top_tips:
        sentences.append("建议:" + ";".join(top_tips))

    return "。".join(sentences)
|
||||
|
|
@ -9,6 +9,10 @@ from agentkit.tools.composition import SequentialChain, ParallelFanOut, DynamicS
|
|||
from agentkit.tools.web_crawl import WebCrawlTool
|
||||
from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool
|
||||
from agentkit.tools.baidu_search import BaiduSearchTool
|
||||
from agentkit.tools.shell import ShellTool
|
||||
from agentkit.tools.terminal_session import TerminalSession, TerminalSessionManager
|
||||
from agentkit.tools.pty_session import PTYSession
|
||||
from agentkit.tools.output_parser import OutputParser, ParsedOutput, ErrorType
|
||||
|
||||
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
||||
try:
|
||||
|
|
@ -30,4 +34,11 @@ __all__ = [
|
|||
"SchemaGenerateTool",
|
||||
"BaiduSearchTool",
|
||||
"HeadroomRetrieveTool",
|
||||
"ShellTool",
|
||||
"TerminalSession",
|
||||
"TerminalSessionManager",
|
||||
"PTYSession",
|
||||
"OutputParser",
|
||||
"ParsedOutput",
|
||||
"ErrorType",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,294 @@
|
|||
"""OutputParser - 结构化解析命令输出
|
||||
|
||||
将命令行输出解析为结构化格式,包含错误类型识别、退出码含义和可操作建议。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ErrorType(Enum):
    """Categories of errors recognised in command output."""

    # No error.
    NONE = "none"

    # File-system and permission problems.
    PERMISSION_DENIED = "permission_denied"
    NOT_FOUND = "not_found"
    ALREADY_EXISTS = "already_exists"
    DISK_FULL = "disk_full"

    # Command usage problems.
    SYNTAX_ERROR = "syntax_error"
    INVALID_ARGUMENT = "invalid_argument"

    # Process and resource problems.
    TIMEOUT = "timeout"
    OUT_OF_MEMORY = "out_of_memory"
    PROCESS_NOT_FOUND = "process_not_found"

    # Network problems.
    CONNECTION_REFUSED = "connection_refused"
    NETWORK_ERROR = "network_error"

    # Error that matched no known category.
    UNKNOWN = "unknown"
|
||||
|
||||
|
||||
@dataclass
class ParsedOutput:
    """Structured view of a command's output."""

    # Exit code of the command.
    exit_code: int
    # Whether the output represents an error.
    is_error: bool
    # Error category; ErrorType.NONE unless an error was classified.
    error_type: ErrorType = ErrorType.NONE
    # Short summary message.
    message: str = ""
    # Raw, unmodified output text.
    raw_output: str = ""
    # Actionable suggestions.
    suggestions: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return a dict of the parsed fields.

        ``error_type`` is emitted as its string value; ``raw_output`` is not
        included.
        """
        return dict(
            exit_code=self.exit_code,
            is_error=self.is_error,
            error_type=self.error_type.value,
            message=self.message,
            suggestions=self.suggestions,
        )
|
||||
|
||||
|
||||
# Error classification rules: (pattern, error_type, message_template, suggestions).
#
# Rules are tried in order and the first match wins, so more specific rules
# must come before generic ones. PROCESS_NOT_FOUND is listed before NOT_FOUND
# because "process not found" / "进程不存在" also match the generic
# "not found" / "不存在" alternatives.
_ERROR_PATTERNS: list[tuple[re.Pattern[str], ErrorType, str, list[str]]] = [
    (
        re.compile(r"permission denied|access denied|权限不足|拒绝访问", re.IGNORECASE),
        ErrorType.PERMISSION_DENIED,
        "权限不足",
        [
            "尝试使用 sudo 执行该命令",
            "检查文件/目录权限: ls -la <path>",
            "确认当前用户是否有所需权限",
        ],
    ),
    (
        re.compile(
            r"no such process|process not found|进程不存在|进程未找到",
            re.IGNORECASE,
        ),
        ErrorType.PROCESS_NOT_FOUND,
        "进程不存在",
        [
            "确认进程 ID 是否正确",
            "使用 ps aux 查看运行中的进程",
            "进程可能已经结束",
        ],
    ),
    (
        re.compile(
            r"not found|no such file|no such directory|找不到|不存在|无法找到",
            re.IGNORECASE,
        ),
        ErrorType.NOT_FOUND,
        "文件或目录不存在",
        [
            "检查路径拼写是否正确",
            "使用 ls 确认文件/目录是否存在",
            "检查是否在正确的工作目录下",
        ],
    ),
    (
        re.compile(r"timed?\s*out|timeout|超时|时间超限", re.IGNORECASE),
        ErrorType.TIMEOUT,
        "命令执行超时",
        [
            "增加超时时间",
            "检查网络连接是否正常",
            "检查目标服务是否可达",
        ],
    ),
    (
        re.compile(
            r"syntax error|syntaxerror|parse error|语法错误|解析错误",
            re.IGNORECASE,
        ),
        ErrorType.SYNTAX_ERROR,
        "语法错误",
        [
            "检查命令语法是否正确",
            "使用 --help 查看命令用法",
            "检查引号和特殊字符是否正确转义",
        ],
    ),
    (
        re.compile(
            r"connection refused|连接被拒绝|无法连接|ECONNREFUSED",
            re.IGNORECASE,
        ),
        ErrorType.CONNECTION_REFUSED,
        "连接被拒绝",
        [
            "检查目标服务是否已启动",
            "确认端口号是否正确",
            "检查防火墙设置是否阻止了连接",
        ],
    ),
    (
        re.compile(
            # "oom" must not be part of a longer word ("room", "zoom"), but
            # "oom-killer", "oom_kill" and "OOMKilled" should still match.
            r"out of memory|(?<![a-z])oom(?:killed)?(?![a-z])|cannot allocate|内存不足|内存溢出",
            re.IGNORECASE,
        ),
        ErrorType.OUT_OF_MEMORY,
        "内存不足",
        [
            "释放不必要的内存占用",
            "增加系统可用内存",
            "检查是否有内存泄漏",
        ],
    ),
    (
        re.compile(
            r"no space left|disk full|磁盘已满|空间不足|ENOSPC",
            re.IGNORECASE,
        ),
        ErrorType.DISK_FULL,
        "磁盘空间不足",
        [
            "清理不必要的文件: du -sh * | sort -rh | head",
            "检查磁盘使用情况: df -h",
            "删除临时文件或日志",
        ],
    ),
    (
        re.compile(
            r"already exists|file exists|已存在|重复|EEXIST",
            re.IGNORECASE,
        ),
        ErrorType.ALREADY_EXISTS,
        "资源已存在",
        [
            "使用 -f 参数强制覆盖(如适用)",
            "先删除已有资源再重新创建",
            "使用不同名称创建",
        ],
    ),
    (
        re.compile(
            r"invalid argument|illegal option|bad option|无效参数|非法选项|invalid option",
            re.IGNORECASE,
        ),
        ErrorType.INVALID_ARGUMENT,
        "无效参数",
        [
            "检查命令参数是否正确",
            "使用 --help 查看支持的参数",
            "确认参数值类型和范围",
        ],
    ),
    (
        re.compile(
            r"network is unreachable|no route to host|name resolution|网络不可达|无法解析|ENETUNREACH",
            re.IGNORECASE,
        ),
        ErrorType.NETWORK_ERROR,
        "网络错误",
        [
            "检查网络连接是否正常",
            "确认 DNS 解析是否正常: nslookup <domain>",
            "检查代理设置",
        ],
    ),
]
|
||||
|
||||
|
||||
class OutputParser:
    """Parse command output and an exit code into a ``ParsedOutput``.

    The result carries an error category, a short message summary and
    actionable suggestions.
    """

    def parse(self, output: str, exit_code: int) -> ParsedOutput:
        """Parse command output.

        A non-zero ``exit_code`` marks the result as an error; only then is the
        output classified and given suggestions.

        Args:
            output: Combined stdout/stderr text of the command.
            exit_code: Exit code of the command.

        Returns:
            The structured ``ParsedOutput``.
        """
        is_error = exit_code != 0
        message = self._extract_message(output)
        error_type = ErrorType.NONE
        suggestions: list[str] = []

        if is_error:
            error_type, suggestions = self._classify_error(output, exit_code)

        return ParsedOutput(
            exit_code=exit_code,
            is_error=is_error,
            error_type=error_type,
            message=message,
            raw_output=output,
            suggestions=suggestions,
        )

    def _extract_message(self, output: str) -> str:
        """Return the last non-blank line of ``output`` as a summary.

        The line is stripped; if it is longer than 200 characters it is cut to
        200 and suffixed with "...". Empty or blank output yields "".
        """
        if not output:
            return ""

        lines = [line.strip() for line in output.splitlines() if line.strip()]
        if not lines:
            return ""

        message = lines[-1]
        if len(message) > 200:
            message = message[:200] + "..."
        return message

    def _classify_error(
        self, output: str, exit_code: int
    ) -> tuple[ErrorType, list[str]]:
        """Classify an error from the output text, falling back to the exit code.

        Args:
            output: Command output.
            exit_code: Exit code of the command.

        Returns:
            ``(error_type, suggestions)``. The suggestions list is a fresh
            copy, so callers may mutate it without affecting the shared rule
            table.
        """
        # Output text takes precedence over the exit code.
        for pattern, error_type, _msg, suggestions in _ERROR_PATTERNS:
            if pattern.search(output):
                return error_type, list(suggestions)

        # Exit-code fallbacks.
        if exit_code == 126:
            return ErrorType.PERMISSION_DENIED, [
                "检查文件是否有执行权限: chmod +x <file>",
                "确认文件格式是否正确(如行尾符)",
            ]
        if exit_code == 127:
            return ErrorType.NOT_FOUND, [
                "检查命令是否已安装",
                "确认命令名称拼写是否正确",
                "检查 PATH 环境变量是否包含命令所在目录",
            ]
        if exit_code == 130:
            return ErrorType.TIMEOUT, [
                "命令被 Ctrl+C 中断",
                "可能需要增加超时时间",
            ]

        return ErrorType.UNKNOWN, [
            "检查命令输出中的错误信息",
            "使用 --verbose 或 --debug 获取更多详情",
        ]
|
||||
|
|
@ -0,0 +1,341 @@
|
|||
"""PTYSession - 伪终端会话,支持交互式命令
|
||||
|
||||
基于 asyncio + os.openpty() 实现伪终端,支持交互式命令和自动应答。
|
||||
不依赖 pexpect,仅使用标准库。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import fcntl
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
import termios
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Auto-respond rules as (prompt_pattern, response) pairs.
# Each pattern is a regular-expression string anchored at the end of the text.
_AUTO_RESPOND_RULES: list[tuple[str, str]] = [
    # Yes/no confirmations.
    (r"\[y/N\]\s*$", "y"),
    (r"\[Y/n\]\s*$", "y"),
    (r"\[yes/no\]\s*$", "yes"),
    (r"\(yes/no\)\s*$", "yes"),
    (r"\(yes/no/\[fingerprint\]\)\s*$", "yes"),
    (r"continue\?\s*$", "y"),
    (r"are you sure\?\s*$", "y"),
    # Secret prompts carry an empty response: they are meant to be left for a
    # human rather than answered automatically.
    # NOTE(review): how an empty response is treated is decided by the code
    # that consumes these rules - confirm there.
    (r"password:\s*$", ""),
    (r"passphrase:\s*$", ""),
]
|
||||
|
||||
|
||||
@dataclass
class PTYOutput:
    """Result of running a command in a PTY."""

    # Captured output text.
    output: str
    # Exit code; -1 means the command timed out or has not finished.
    exit_code: int = -1
    # Whether the command timed out.
    timed_out: bool = False
|
||||
|
||||
|
||||
class PTYSession:
|
||||
"""伪终端会话 - 支持交互式命令
|
||||
|
||||
使用 os.openpty() 创建伪终端对,通过 asyncio 异步读写。
|
||||
支持自动检测提示并应答(如 yes/no 确认)。
|
||||
|
||||
Usage:
|
||||
pty = PTYSession()
|
||||
await pty.start()
|
||||
output = await pty.run_command("ssh-keygen", timeout=10)
|
||||
await pty.close()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
auto_respond: bool = True,
|
||||
custom_rules: list[tuple[str, str]] | None = None,
|
||||
default_timeout: float = 30.0,
|
||||
buffer_size: int = 4096,
|
||||
):
|
||||
"""初始化 PTY 会话
|
||||
|
||||
Args:
|
||||
auto_respond: 是否自动应答已知提示
|
||||
custom_rules: 自定义应答规则列表 [(prompt_pattern, response)]
|
||||
default_timeout: 默认超时时间(秒)
|
||||
buffer_size: 读取缓冲区大小
|
||||
"""
|
||||
self._auto_respond = auto_respond
|
||||
self._respond_rules = list(_AUTO_RESPOND_RULES)
|
||||
if custom_rules:
|
||||
self._respond_rules.extend(custom_rules)
|
||||
self._default_timeout = default_timeout
|
||||
self._buffer_size = buffer_size
|
||||
|
||||
self._master_fd: int | None = None
|
||||
self._slave_fd: int | None = None
|
||||
self._process: asyncio.subprocess.Process | None = None
|
||||
self._running = False
|
||||
self._output_buffer = ""
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""PTY 会话是否已启动(伪终端已创建)"""
|
||||
return self._running
|
||||
|
||||
async def start(self) -> None:
|
||||
"""启动 PTY 会话(创建伪终端对)
|
||||
|
||||
在执行命令前调用,创建 master/slave 文件描述符。
|
||||
"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
self._master_fd, self._slave_fd = os.openpty()
|
||||
|
||||
# 设置终端大小
|
||||
try:
|
||||
winsize = struct.pack("HHHH", 24, 80, 0, 0)
|
||||
fcntl.ioctl(self._slave_fd, termios.TIOCSWINSZ, winsize)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 设置 master fd 为非阻塞
|
||||
flags = fcntl.fcntl(self._master_fd, fcntl.F_GETFL)
|
||||
fcntl.fcntl(self._master_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
|
||||
self._running = True
|
||||
logger.debug("PTY session started: master=%d slave=%d", self._master_fd, self._slave_fd)
|
||||
|
||||
async def run_command(
|
||||
self,
|
||||
command: str,
|
||||
timeout: float | None = None,
|
||||
cwd: str | None = None,
|
||||
env: dict[str, str] | None = None,
|
||||
) -> PTYOutput:
|
||||
"""在 PTY 中运行命令,等待完成
|
||||
|
||||
Args:
|
||||
command: 要执行的命令
|
||||
timeout: 超时时间(秒)
|
||||
cwd: 工作目录
|
||||
env: 环境变量
|
||||
|
||||
Returns:
|
||||
PTYOutput 输出结果
|
||||
"""
|
||||
if not self._running:
|
||||
await self.start()
|
||||
|
||||
timeout = timeout or self._default_timeout
|
||||
self._output_buffer = ""
|
||||
|
||||
# 构建环境
|
||||
cmd_env = dict(os.environ)
|
||||
if env:
|
||||
cmd_env.update(env)
|
||||
|
||||
# 启动子进程,使用 slave 作为 stdin/stdout/stderr
|
||||
self._process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdin=self._slave_fd,
|
||||
stdout=self._slave_fd,
|
||||
stderr=self._slave_fd,
|
||||
cwd=cwd,
|
||||
env=cmd_env,
|
||||
start_new_session=True,
|
||||
)
|
||||
|
||||
# 关闭 slave 端(子进程已继承)
|
||||
os.close(self._slave_fd)
|
||||
self._slave_fd = None
|
||||
|
||||
# 异步读取输出
|
||||
try:
|
||||
exit_code = await self._read_until_exit(timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self._output_buffer += "\n[PTY 命令执行超时]"
|
||||
return PTYOutput(
|
||||
output=self._output_buffer,
|
||||
exit_code=-1,
|
||||
timed_out=True,
|
||||
)
|
||||
|
||||
return PTYOutput(
|
||||
output=self._output_buffer,
|
||||
exit_code=exit_code,
|
||||
timed_out=False,
|
||||
)
|
||||
|
||||
async def send(self, line: str) -> None:
|
||||
"""向 PTY 发送一行输入
|
||||
|
||||
Args:
|
||||
line: 要发送的文本(自动追加换行符)
|
||||
"""
|
||||
if self._master_fd is None:
|
||||
return
|
||||
|
||||
data = (line + "\n").encode("utf-8")
|
||||
try:
|
||||
os.write(self._master_fd, data)
|
||||
except OSError as e:
|
||||
logger.warning("PTY write failed: %s", e)
|
||||
|
||||
async def read_output(self, timeout: float = 1.0) -> str:
    """Collect PTY output that arrives within ``timeout`` seconds.

    The master descriptor is polled until the deadline passes or the
    descriptor reports end-of-file.  Everything read is also appended to the
    session output buffer, and auto-respond rules are applied to each chunk
    when auto-respond is enabled.

    Args:
        timeout: Maximum time to keep polling, in seconds.

    Returns:
        The text read during this call ("" when the master is not open or
        nothing arrived).
    """
    if self._master_fd is None:
        return ""

    collected: list[str] = []
    deadline = time.monotonic() + timeout

    while time.monotonic() < deadline:
        try:
            chunk = os.read(self._master_fd, self._buffer_size)
        except OSError:
            # Nothing readable right now (BlockingIOError is an OSError
            # subclass); back off briefly and poll again.
            await asyncio.sleep(0.05)
            continue

        if not chunk:
            # End-of-file: no more data can ever arrive.  Previously this
            # case re-read in a tight loop (without yielding when nothing had
            # been collected) until the deadline expired.
            break

        text = chunk.decode("utf-8", errors="replace")
        collected.append(text)
        self._output_buffer += text

        # Answer known prompts automatically.
        if self._auto_respond:
            await self._try_auto_respond(text)

        # Short pause after receiving data to see whether more follows.
        await asyncio.sleep(0.05)

    return "".join(collected)
|
||||
|
||||
async def close(self) -> None:
    """Shut the PTY session down and release its resources.

    A still-running child is asked to terminate and given five seconds to
    exit; if it does not (or cannot be signalled) it is killed and then
    awaited so that no zombie process is left behind.  Both PTY descriptors
    are closed, ignoring errors from descriptors that are already closed.
    """
    self._running = False

    proc = self._process
    if proc is not None and proc.returncode is None:
        try:
            proc.terminate()
            await asyncio.wait_for(proc.wait(), timeout=5.0)
        except (asyncio.TimeoutError, ProcessLookupError):
            try:
                proc.kill()
                # Reap the killed child; bounded so close() cannot hang.
                await asyncio.wait_for(proc.wait(), timeout=5.0)
            except ProcessLookupError:
                # Child already gone; nothing left to signal.
                pass
            except asyncio.TimeoutError:
                logger.warning("PTY child did not exit after being killed")

    for attr in ("_master_fd", "_slave_fd"):
        fd = getattr(self, attr)
        if fd is None:
            continue
        try:
            os.close(fd)
        except OSError:
            # Descriptor was already closed elsewhere; best-effort cleanup.
            pass
        setattr(self, attr, None)

    self._process = None
    logger.debug("PTY session closed")
|
||||
|
||||
async def _read_until_exit(self, timeout: float) -> int:
|
||||
"""持续读取输出直到进程退出
|
||||
|
||||
Args:
|
||||
timeout: 超时时间(秒)
|
||||
|
||||
Returns:
|
||||
进程退出码
|
||||
"""
|
||||
deadline = time.monotonic() + timeout
|
||||
|
||||
while True:
|
||||
# 检查超时
|
||||
if time.monotonic() > deadline:
|
||||
raise asyncio.TimeoutError()
|
||||
|
||||
# 检查进程是否已退出
|
||||
if self._process.returncode is not None:
|
||||
# 进程已退出,再读一次剩余输出
|
||||
await self._drain_remaining_output()
|
||||
return self._process.returncode
|
||||
|
||||
# 读取输出
|
||||
try:
|
||||
chunk = os.read(self._master_fd, self._buffer_size)
|
||||
if chunk:
|
||||
text = chunk.decode("utf-8", errors="replace")
|
||||
self._output_buffer += text
|
||||
|
||||
# 自动应答
|
||||
if self._auto_respond:
|
||||
await self._try_auto_respond(text)
|
||||
except (BlockingIOError, OSError):
|
||||
pass
|
||||
|
||||
# 检查进程状态
|
||||
if self._process.returncode is None:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._process.wait(), timeout=0.05
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
else:
|
||||
await self._drain_remaining_output()
|
||||
return self._process.returncode
|
||||
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
async def _drain_remaining_output(self) -> None:
|
||||
"""排空剩余输出"""
|
||||
for _ in range(10): # 最多尝试 10 次
|
||||
try:
|
||||
chunk = os.read(self._master_fd, self._buffer_size)
|
||||
if chunk:
|
||||
text = chunk.decode("utf-8", errors="replace")
|
||||
self._output_buffer += text
|
||||
else:
|
||||
break
|
||||
except (BlockingIOError, OSError):
|
||||
break
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
async def _try_auto_respond(self, recent_output: str) -> None:
|
||||
"""检测提示并自动应答
|
||||
|
||||
Args:
|
||||
recent_output: 最近的输出文本
|
||||
"""
|
||||
import re
|
||||
|
||||
for pattern, response in self._respond_rules:
|
||||
if not response:
|
||||
# 空响应规则(如密码提示)跳过
|
||||
continue
|
||||
if re.search(pattern, recent_output, re.IGNORECASE | re.MULTILINE):
|
||||
logger.debug("Auto-responding to prompt '%s' with '%s'", pattern, response)
|
||||
await self.send(response)
|
||||
break
|
||||
|
|
@ -0,0 +1,432 @@
|
|||
"""ShellTool - Shell 命令执行工具
|
||||
|
||||
支持无会话模式(向后兼容)和有会话模式(跨命令保持状态)。
|
||||
危险命令通过确认回调请求人工确认,所有操作记录审计日志。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Callable, Awaitable
|
||||
|
||||
from agentkit.tools.base import Tool
|
||||
from agentkit.tools.output_parser import OutputParser, ParsedOutput
|
||||
from agentkit.tools.terminal_session import TerminalSession, TerminalSessionManager
|
||||
from agentkit.tools.pty_session import PTYSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Safe whitelist: commands starting with one of these prefixes do not
# require confirmation.
_SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
    "ls",
    "cat",
    "head",
    "tail",
    "grep",
    "find",
    "pwd",
    "echo",
    "which",
    "whoami",
    "id",
    "date",
    "uname",
    "df",
    "du",
    "free",
    "ps",
    "top",
    "env",
    "printenv",
    "type",
    "file",
    "stat",
    "wc",
    "sort",
    "uniq",
    "diff",
    "git status",
    "git log",
    "git diff",
    "git branch",
    "git remote",
    "pip list",
    "pip show",
    "python --version",
    "python3 --version",
    "node --version",
    "npm list",
    "docker ps",
    "docker images",
    "curl",
    "wget",
)
|
||||
|
||||
# Dangerous command patterns: a command containing any of these substrings
# (compared against the lower-cased command) requires human confirmation.
_DANGEROUS_PATTERNS: tuple[str, ...] = (
    "rm ",
    "rm -",
    "rmdir",
    "mkfs",
    "dd ",
    "format",
    "del ",
    "erase",
    "> /dev/",
    "shutdown",
    "reboot",
    "init 0",
    "init 6",
    "kill -9",
    "killall",
    "chmod 777",
    "chown",
    "mv /",
    "pip uninstall",
    "npm uninstall",
    "apt remove",
    "yum remove",
    "brew uninstall",
    "docker rm",
    "docker rmi",
    "git push --force",
    "git reset --hard",
    "git clean -f",
    "drop table",
    "drop database",
    "truncate",
)
|
||||
|
||||
|
||||
class ShellTool(Tool):
    """Shell command execution tool.

    Two modes are supported:
    1. Standalone (default): every command runs in its own subprocess and no
       state is kept.
    2. Session: ``session_id`` selects a terminal session that keeps
       cwd/env/history across commands.

    Safety controls:
    - Commands classified as dangerous must be approved by
      ``confirm_callback``; without a callback they are rejected.
    - Every executed or blocked command is appended to an in-memory audit log.

    Usage:
        # Standalone mode
        tool = ShellTool()
        result = await tool.execute(command="ls -la")

        # Session mode
        result = await tool.execute(command="cd /tmp", session_id="build-01")
        result = await tool.execute(command="pwd", session_id="build-01")  # prints /tmp
    """

    def __init__(
        self,
        name: str = "shell",
        description: str = "执行 Shell 命令,支持会话模式保持跨命令状态",
        input_schema: dict[str, Any] | None = None,
        output_schema: dict[str, Any] | None = None,
        version: str = "1.0.0",
        tags: list[str] | None = None,
        confirm_callback: Callable[[str], Awaitable[bool]] | None = None,
        default_timeout: float = 60.0,
        max_output_length: int = 50000,
    ):
        """Create the tool.

        Args:
            name: Tool name passed to the base class.
            description: Tool description passed to the base class.
            input_schema: Input schema; the built-in default is used when falsy.
            output_schema: Output schema; the built-in default is used when falsy.
            version: Tool version string.
            tags: Tool tags; defaults to shell/terminal/system.
            confirm_callback: Async callable receiving the command and
                returning whether a dangerous command may run.
            default_timeout: Timeout in seconds used when the caller gives none.
            max_output_length: Output longer than this many characters is
                truncated in the returned dict.
        """
        super().__init__(
            name=name,
            description=description,
            input_schema=input_schema or self._default_input_schema(),
            output_schema=output_schema or self._default_output_schema(),
            version=version,
            tags=tags or ["shell", "terminal", "system"],
        )
        self._session_manager = TerminalSessionManager()
        self._output_parser = OutputParser()
        self._confirm_callback = confirm_callback
        self._default_timeout = default_timeout
        self._max_output_length = max_output_length
        self._audit_log: list[dict[str, Any]] = []

    @staticmethod
    def _default_input_schema() -> dict[str, Any]:
        """Return the default JSON schema describing ``execute`` arguments."""
        return {
            "type": "object",
            "properties": {
                "command": {
                    "type": "string",
                    "description": "要执行的 Shell 命令",
                },
                "timeout": {
                    "type": "number",
                    "description": "超时时间(秒),默认 60",
                    "default": 60,
                },
                "working_dir": {
                    "type": "string",
                    "description": "工作目录(仅无会话模式有效)",
                },
                "session_id": {
                    "type": "string",
                    "description": "会话 ID,指定后在会话中执行命令,跨命令保持状态",
                },
                "interactive": {
                    "type": "boolean",
                    "description": "是否使用交互式模式(PTY),用于需要用户输入的命令",
                    "default": False,
                },
            },
            "required": ["command"],
        }

    @staticmethod
    def _default_output_schema() -> dict[str, Any]:
        """Return the default JSON schema describing the ``execute`` result."""
        return {
            "type": "object",
            "properties": {
                "output": {"type": "string", "description": "命令输出"},
                "exit_code": {"type": "integer", "description": "退出码"},
                "is_error": {"type": "boolean", "description": "是否为错误"},
                "error_type": {"type": "string", "description": "错误类型"},
                "message": {"type": "string", "description": "消息摘要"},
                "suggestions": {
                    "type": "array",
                    "items": {"type": "string"},
                    "description": "可操作建议",
                },
                "session_id": {"type": "string", "description": "会话 ID(仅会话模式)"},
            },
        }

    async def execute(self, **kwargs) -> dict:
        """Execute a shell command.

        Args:
            command: Command to run (required).
            timeout: Timeout in seconds.
            working_dir: Working directory (standalone mode; in session mode
                it is only used as the initial cwd of a new session).
            session_id: Session ID; enables session mode when given.
            interactive: Whether to run the command through a PTY.

        Returns:
            Dict with ``output``, ``exit_code``, ``is_error``, ``error_type``,
            ``message`` and ``suggestions``; results of executed commands also
            carry ``session_id``.
        """
        command = kwargs.get("command")
        if not command:
            return {
                "output": "",
                "exit_code": 1,
                "is_error": True,
                "error_type": "invalid_argument",
                "message": "command 参数是必需的",
                "suggestions": ["提供要执行的 Shell 命令"],
            }

        timeout = kwargs.get("timeout", self._default_timeout)
        working_dir = kwargs.get("working_dir")
        session_id = kwargs.get("session_id")
        interactive = kwargs.get("interactive", False)

        # Safety gate: dangerous commands need explicit confirmation.
        if self._is_dangerous(command):
            confirmed = await self._request_confirmation(command)
            if not confirmed:
                # Record the session the blocked command was aimed at.
                self._log_audit(command, session_id, blocked=True)
                return {
                    "output": "",
                    "exit_code": 126,
                    "is_error": True,
                    "error_type": "permission_denied",
                    "message": f"危险命令已被拒绝执行: {command[:100]}",
                    "suggestions": [
                        "如需执行此命令,请手动确认",
                        "考虑使用更安全的替代命令",
                    ],
                }

        # Dispatch on mode.
        if session_id:
            result = await self._execute_in_session(
                command, session_id, timeout, working_dir, interactive
            )
        else:
            result = await self._execute_standalone(command, timeout, working_dir, interactive)

        # Audit trail.
        self._log_audit(command, session_id, exit_code=result.exit_code)

        # Truncate overly long output.
        output = result.raw_output
        if len(output) > self._max_output_length:
            output = output[: self._max_output_length] + "\n... [输出已截断]"

        return {
            "output": output,
            "exit_code": result.exit_code,
            "is_error": result.is_error,
            "error_type": result.error_type.value,
            "message": result.message,
            "suggestions": result.suggestions,
            "session_id": session_id,
        }

    async def _execute_standalone(
        self,
        command: str,
        timeout: float,
        working_dir: str | None,
        interactive: bool,
    ) -> ParsedOutput:
        """Run a command without a session (backward-compatible mode).

        Interactive commands are delegated to the PTY path.  Otherwise stdout
        and stderr are captured together; a timeout kills the process and
        yields exit code -1, as does any exception raised while spawning.
        """
        if interactive:
            return await self._execute_with_pty(command, timeout, working_dir)

        try:
            proc = await asyncio.create_subprocess_shell(
                command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.STDOUT,
                cwd=working_dir,
            )
            try:
                stdout, _ = await asyncio.wait_for(
                    proc.communicate(),
                    timeout=timeout,
                )
            except asyncio.TimeoutError:
                proc.kill()
                await proc.wait()
                output = f"命令执行超时({timeout}s)"
                exit_code = -1
            else:
                output = stdout.decode("utf-8", errors="replace") if stdout else ""
                exit_code = proc.returncode if proc.returncode is not None else 0
        except Exception as e:
            # Boundary: report spawn/IO failures as a failed command result.
            output = str(e)
            exit_code = -1

        return self._output_parser.parse(output, exit_code)

    async def _execute_in_session(
        self,
        command: str,
        session_id: str,
        timeout: float,
        working_dir: str | None,
        interactive: bool,
    ) -> ParsedOutput:
        """Run a command inside the session identified by ``session_id``.

        The session is created on first use (``working_dir`` is passed as its
        initial cwd).  Interactive commands run through a PTY using the
        session's cwd and env; other commands go through the session itself.
        """
        session = self._session_manager.get_or_create(
            session_id,
            cwd=working_dir,
        )

        if interactive:
            return await self._execute_with_pty(
                command, timeout, session.cwd, session.env
            )

        return await session.execute(command, timeout=timeout)

    async def _execute_with_pty(
        self,
        command: str,
        timeout: float,
        cwd: str | None = None,
        env: dict[str, str] | None = None,
    ) -> ParsedOutput:
        """Run an interactive command through a fresh PTY session.

        The PTY is always closed afterwards; any exception is reported as
        output with exit code -1.
        """
        pty = PTYSession()
        try:
            await pty.start()
            result = await pty.run_command(
                command,
                timeout=timeout,
                cwd=cwd,
                env=env,
            )
            output = result.output
            exit_code = result.exit_code
        except Exception as e:
            # Boundary: report PTY failures as a failed command result.
            output = str(e)
            exit_code = -1
        finally:
            await pty.close()

        return self._output_parser.parse(output, exit_code)

    @staticmethod
    def _is_whitelisted(segment: str) -> bool:
        """Return True when ``segment`` starts with a whitelisted command.

        The prefix must end on a word boundary, so ``ls -la`` matches the
        ``ls`` entry but ``lsof`` does not.
        """
        for prefix in _SAFE_COMMAND_PREFIXES:
            if segment.startswith(prefix):
                rest = segment[len(prefix):]
                if not rest or rest[0].isspace():
                    return True
        return False

    def _is_dangerous(self, command: str) -> bool:
        """Decide whether a command needs confirmation before running.

        The command line is split at shell chaining operators (``;``, ``&``,
        ``&&``, ``|``, ``||``, newline, backtick, ``$(``) and each segment is
        judged on its own: a segment is dangerous when it contains a dangerous
        pattern and does not start with a whitelisted command.  Judging the
        whole line by its first word would let ``ls; rm -rf /`` through.

        The split ignores quoting, so a separator inside quotes can only
        cause an extra confirmation, never a missed one.
        """
        import re

        segments = re.split(r"\|\||&&|[;|&\n`]|\$\(", command)
        for segment in segments:
            lowered = segment.lower()
            if not any(pattern in lowered for pattern in _DANGEROUS_PATTERNS):
                continue
            if not self._is_whitelisted(segment.strip()):
                return True
        return False

    async def _request_confirmation(self, command: str) -> bool:
        """Ask for human confirmation of a dangerous command.

        Args:
            command: The command awaiting confirmation.

        Returns:
            The callback's answer; False when the callback raises or when no
            callback is configured.
        """
        if self._confirm_callback:
            try:
                return await self._confirm_callback(command)
            except Exception as e:
                # A failing callback must never approve a dangerous command.
                logger.warning("确认回调执行失败: %s", e)
                return False

        # Deny by default when nobody can confirm.
        logger.warning("危险命令被拒绝(无确认回调): %s", command[:100])
        return False

    def _log_audit(
        self,
        command: str,
        session_id: str | None,
        exit_code: int | None = None,
        blocked: bool = False,
    ) -> None:
        """Append an audit entry and mirror it to the logger.

        The stored command is capped at 500 characters, the logged one at 100.
        """
        entry = {
            "timestamp": time.time(),
            "command": command[:500],
            "session_id": session_id,
            "exit_code": exit_code,
            "blocked": blocked,
        }
        self._audit_log.append(entry)
        logger.info(
            "Shell audit: command=%r session=%s exit=%s blocked=%s",
            command[:100],
            session_id,
            exit_code,
            blocked,
        )

    @property
    def session_manager(self) -> TerminalSessionManager:
        """The session manager backing session mode."""
        return self._session_manager

    @property
    def audit_log(self) -> list[dict[str, Any]]:
        """A copy of the audit log entries."""
        return list(self._audit_log)
|
||||
|
|
@ -0,0 +1,352 @@
|
|||
"""TerminalSession - 终端会话状态管理
|
||||
|
||||
维护 cwd、env、history,支持跨命令保持状态。
|
||||
通过在命令前注入 cd 和 export 语句实现跨命令状态持久化。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from agentkit.tools.output_parser import OutputParser, ParsedOutput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class CommandRecord:
    """Record of one executed command.

    Attributes:
        command: The command that was executed.
        exit_code: Exit code of the command.
        output: Combined standard output and standard error.
        cwd: Working directory associated with the execution.
        timestamp: Timestamp of the execution.
        duration_ms: Execution time in milliseconds.
    """

    command: str
    exit_code: int
    output: str
    cwd: str
    timestamp: float
    duration_ms: int
|
||||
|
||||
|
||||
class TerminalSession:
    """Terminal session that keeps cwd/env/history across commands.

    Every command runs in a fresh shell.  State is carried over by prefixing
    the command with ``cd {cwd} && `` and ``export K=V && `` statements and by
    passing the session environment to the subprocess.  After a successful
    command, ``cd`` and ``export`` statements found in it update the session
    state.

    Usage:
        session = TerminalSession(session_id="build-01")
        result = await session.execute("cd /tmp")
        result = await session.execute("pwd")  # prints /tmp
    """

    def __init__(
        self,
        session_id: str,
        cwd: str | None = None,
        env: dict[str, str] | None = None,
        max_history: int = 1000,
    ):
        """Create a session.

        Args:
            session_id: Identifier of the session.
            cwd: Initial working directory; the process cwd when falsy.
            env: Initial environment; a copy of ``os.environ`` when falsy.
            max_history: Maximum number of history records kept.
        """
        self.session_id = session_id
        self._cwd = cwd or os.getcwd()
        self._env: dict[str, str] = dict(env or os.environ)
        self._history: list[CommandRecord] = []
        self._max_history = max_history
        self._output_parser = OutputParser()
        self._created_at = time.time()

    @property
    def cwd(self) -> str:
        """Current working directory."""
        return self._cwd

    @property
    def env(self) -> dict[str, str]:
        """Current environment variables (copy)."""
        return dict(self._env)

    @property
    def history(self) -> list[CommandRecord]:
        """Command history (copy)."""
        return list(self._history)

    @property
    def created_at(self) -> float:
        """Timestamp at which the session was created."""
        return self._created_at

    def get_cwd(self) -> str:
        """Return the current working directory."""
        return self._cwd

    def set_cwd(self, cwd: str) -> None:
        """Set the current working directory manually."""
        self._cwd = cwd

    def get_env(self) -> dict[str, str]:
        """Return the current environment variables (copy)."""
        return dict(self._env)

    def set_env(self, key: str, value: str) -> None:
        """Set a single environment variable."""
        self._env[key] = value

    def update_env(self, env: dict[str, str]) -> None:
        """Update several environment variables at once."""
        self._env.update(env)

    def get_history(self) -> list[CommandRecord]:
        """Return the command history (copy)."""
        return list(self._history)

    async def execute(
        self,
        command: str,
        timeout: float | None = None,
    ) -> ParsedOutput:
        """Execute a command in the session context.

        The command is prefixed with ``cd``/``export`` statements that restore
        the session state; afterwards cwd and env are updated from the
        command and a history record is stored.

        Args:
            command: Command to execute.
            timeout: Timeout in seconds; ``None`` disables the timeout.

        Returns:
            ParsedOutput produced by the output parser.
        """
        full_command = self._build_command(command)
        # Remember where the command actually ran: the state update below may
        # change ``_cwd`` (e.g. for ``cd``), but the record documents the
        # directory at execution time.
        cwd_at_start = self._cwd
        start = time.monotonic()

        try:
            proc = await asyncio.create_subprocess_shell(
                full_command,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.STDOUT,
                env=self._env,
            )
            try:
                stdout, _ = await asyncio.wait_for(
                    proc.communicate(),
                    timeout=timeout,
                )
            except asyncio.TimeoutError:
                proc.kill()
                await proc.wait()
                output = f"命令执行超时({timeout}s)"
                exit_code = -1
            else:
                output = stdout.decode("utf-8", errors="replace") if stdout else ""
                exit_code = proc.returncode if proc.returncode is not None else 0
        except Exception as e:
            # Boundary: report spawn/IO failures as a failed command result.
            output = str(e)
            exit_code = -1

        duration_ms = int((time.monotonic() - start) * 1000)

        # Update session state.
        self._update_state_after_execution(command, output, exit_code)

        # Record history.
        record = CommandRecord(
            command=command,
            exit_code=exit_code,
            output=output,
            cwd=cwd_at_start,
            timestamp=time.time(),
            duration_ms=duration_ms,
        )
        self._add_history(record)

        # Parse output.
        parsed = self._output_parser.parse(output, exit_code)
        logger.debug(
            "Session %s: command=%r exit_code=%d duration=%dms",
            self.session_id,
            command,
            exit_code,
            duration_ms,
        )
        return parsed

    def _build_command(self, command: str) -> str:
        """Build the full command line carrying the session state.

        A ``cd`` to the session cwd and ``export`` statements are chained in
        front of the original command with ``&&``.  Only variables whose
        value differs from ``os.environ`` are exported: the whole session
        environment is already handed to the subprocess, and exporting every
        inherited variable both bloats the command line and lets a single
        name that is not a valid shell identifier abort the ``&&`` chain.
        """
        import re

        parts: list[str] = []

        # Inject cd (single-quote escaping).
        if self._cwd:
            cwd_escaped = self._cwd.replace("'", "'\\''")
            parts.append(f"cd '{cwd_escaped}'")

        # Inject changed environment variables.
        valid_name = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")
        for key, value in self._env.items():
            if os.environ.get(key) == value:
                # Unchanged relative to the parent environment.
                continue
            if not valid_name.fullmatch(key):
                # ``export`` would fail on this name and break the chain.
                continue
            val_escaped = value.replace("'", "'\\''")
            parts.append(f"export {key}='{val_escaped}'")

        parts.append(command)
        return " && ".join(parts)

    def _update_state_after_execution(
        self, command: str, output: str, exit_code: int
    ) -> None:
        """Update the session state after a command ran.

        Nothing changes when the command failed; otherwise ``cd`` and
        ``export`` statements in the command update cwd and env.
        """
        if exit_code != 0:
            return

        # cd statements update cwd.
        self._parse_cd_commands(command, output)

        # export statements update env.
        self._parse_export_commands(command)

    def _parse_cd_commands(self, command: str, output: str) -> None:
        """Find ``cd`` statements in the command and update cwd.

        Supported forms:
        - cd /path
        - cd dir   (relative; joined with the current cwd and normalised)
        - cd ~ and cd ~/sub   (resolved against HOME)
        - cd -     (switch to the previous directory)

        Each directory change stores the directory being left in the session
        variable ``OLDPWD`` so that a later ``cd -`` can return to it.
        """
        import re

        # cd statements, possibly inside an && chain.
        cd_pattern = re.compile(r"(?:^|\s|&&\s*)cd\s+(.+?)(?:\s*&&|\s*$)")

        for target in cd_pattern.findall(command):
            target = target.strip().strip("'\"")
            if not target:
                continue

            if target == "-":
                new_cwd = self._env.get("OLDPWD")
                if not new_cwd:
                    # No previous directory known; leave cwd untouched.
                    continue
            else:
                if target == "~" or target.startswith("~/"):
                    home = self._env.get("HOME") or os.path.expanduser("~")
                    target = home + target[1:]
                if os.path.isabs(target):
                    new_cwd = target
                else:
                    # Relative path: join, then normalise.
                    new_cwd = os.path.normpath(os.path.join(self._cwd, target))

            self._env["OLDPWD"] = self._cwd
            self._cwd = new_cwd

    def _parse_export_commands(self, command: str) -> None:
        """Find ``export`` statements in the command and update env.

        Supported forms:
        - export KEY=VALUE
        - export KEY="VALUE WITH SPACES"
        """
        import re

        export_pattern = re.compile(
            r"(?:^|\s|&&\s*)export\s+(\w+)=(.+?)(?:\s*&&|\s*$)"
        )
        for key, value in export_pattern.findall(command):
            self._env[key] = value.strip().strip("'\"")

    def _add_history(self, record: CommandRecord) -> None:
        """Append a record to the history, dropping the oldest beyond the cap."""
        self._history.append(record)
        excess = len(self._history) - self._max_history
        if excess > 0:
            del self._history[:excess]

    def close(self) -> None:
        """Close the session; logs how many commands were executed."""
        logger.info(
            "Session %s closed: %d commands executed",
            self.session_id,
            len(self._history),
        )
|
||||
|
||||
|
||||
class TerminalSessionManager:
    """Registry of TerminalSession objects keyed by session ID.

    Usage:
        manager = TerminalSessionManager()
        session = manager.get_or_create("build-01")
        session = manager.get("build-01")
        manager.remove("build-01")
    """

    def __init__(self, max_sessions: int = 100):
        self._sessions: dict[str, TerminalSession] = {}
        self._max_sessions = max_sessions

    def get_or_create(
        self,
        session_id: str,
        cwd: str | None = None,
        env: dict[str, str] | None = None,
    ) -> TerminalSession:
        """Return the session with this ID, creating it when missing.

        When the registry is full, the session with the smallest
        ``created_at`` is removed first.

        Args:
            session_id: Session ID.
            cwd: Initial working directory (used only on creation).
            env: Initial environment variables (used only on creation).

        Returns:
            The TerminalSession instance.
        """
        if session_id in self._sessions:
            return self._sessions[session_id]

        if len(self._sessions) >= self._max_sessions:
            # Evict the oldest session to make room.
            oldest_id, _ = min(
                self._sessions.items(), key=lambda item: item[1].created_at
            )
            self.remove(oldest_id)

        created = TerminalSession(
            session_id=session_id,
            cwd=cwd,
            env=env,
        )
        self._sessions[session_id] = created
        logger.info("Session created: %s", session_id)
        return created

    def get(self, session_id: str) -> TerminalSession | None:
        """Return the session, or None when it does not exist."""
        return self._sessions.get(session_id)

    def remove(self, session_id: str) -> None:
        """Remove the session from the registry and close it."""
        removed = self._sessions.pop(session_id, None)
        if not removed:
            return
        removed.close()

    def list_sessions(self) -> list[str]:
        """Return the IDs of all registered sessions."""
        return [*self._sessions]

    def has_session(self, session_id: str) -> bool:
        """Tell whether a session with this ID is registered."""
        return session_id in self._sessions

    def close_all(self) -> None:
        """Remove and close every registered session."""
        for known_id in tuple(self._sessions):
            self.remove(known_id)
|
||||
|
|
@ -0,0 +1,512 @@
|
|||
"""Tests for PathOptimizer - 执行路径优化器"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.evolution.path_optimizer import ExecutionPath, PathOptimizer, PathUpdateResult
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
def optimizer():
    """PathOptimizer configured with the default test thresholds."""
    settings = {
        "min_sample_count": 3,
        "success_rate_threshold": 0.05,
        "duration_improvement_threshold": 0.2,
    }
    return PathOptimizer(**settings)
|
||||
|
||||
|
||||
@pytest.fixture
def optimizer_custom_thresholds():
    """PathOptimizer configured with stricter, non-default thresholds."""
    settings = {
        "min_sample_count": 5,
        "success_rate_threshold": 0.1,
        "duration_improvement_threshold": 0.3,
    }
    return PathOptimizer(**settings)
|
||||
|
||||
|
||||
def _make_path(
    task_type: str = "code_review",
    steps: list[str] | None = None,
    total_duration: float = 10.0,
    success_rate: float = 0.8,
    sample_count: int = 5,
    is_recommended: bool = False,
    path_id: str = "",
    created_at: datetime | None = None,
) -> ExecutionPath:
    """Build an ExecutionPath for tests, filling in steps and timestamp."""
    if not steps:
        steps = ["step1", "step2", "step3"]
    if not created_at:
        created_at = datetime.now(timezone.utc)

    return ExecutionPath(
        path_id=path_id,
        task_type=task_type,
        steps=steps,
        total_duration=total_duration,
        success_rate=success_rate,
        sample_count=sample_count,
        is_recommended=is_recommended,
        created_at=created_at,
    )
|
||||
|
||||
|
||||
# ── ExecutionPath 数据模型测试 ────────────────────────────
|
||||
|
||||
|
||||
class TestExecutionPath:
    """Checks for the ExecutionPath data model."""

    def test_default_values(self):
        """A bare ExecutionPath carries empty/zero defaults and a timestamp."""
        blank = ExecutionPath()

        assert blank.path_id == ""
        assert blank.task_type == ""
        assert blank.steps == []
        assert blank.total_duration == 0.0
        assert blank.success_rate == 0.0
        assert blank.sample_count == 0
        assert blank.is_recommended is False
        assert isinstance(blank.created_at, datetime)

    def test_custom_values(self):
        """Explicitly supplied field values are stored unchanged."""
        stamp = datetime.now(timezone.utc)
        custom = ExecutionPath(
            path_id="p1",
            task_type="code_review",
            steps=["analyze", "review", "report"],
            total_duration=15.5,
            success_rate=0.9,
            sample_count=10,
            is_recommended=True,
            created_at=stamp,
        )

        assert custom.path_id == "p1"
        assert custom.task_type == "code_review"
        assert custom.steps == ["analyze", "review", "report"]
        assert custom.total_duration == 15.5
        assert custom.success_rate == 0.9
        assert custom.sample_count == 10
        assert custom.is_recommended is True
        assert custom.created_at == stamp
|
||||
|
||||
|
||||
# ── PathUpdateResult 数据模型测试 ─────────────────────────
|
||||
|
||||
|
||||
class TestPathUpdateResult:
    """Checks for the PathUpdateResult data model."""

    def test_default_values(self):
        """A bare result reports no update, no paths and an empty reason."""
        outcome = PathUpdateResult()

        assert outcome.updated is False
        assert outcome.old_path is None
        assert outcome.new_path is None
        assert outcome.reason == ""

    def test_updated_result(self):
        """A populated result exposes both paths and the reason text."""
        previous = _make_path(success_rate=0.7)
        replacement = _make_path(success_rate=0.9)

        outcome = PathUpdateResult(
            updated=True,
            old_path=previous,
            new_path=replacement,
            reason="成功率显著提升",
        )

        assert outcome.updated is True
        assert outcome.old_path.success_rate == 0.7
        assert outcome.new_path.success_rate == 0.9
        assert "成功率" in outcome.reason
|
||||
|
||||
|
||||
# ── get_recommended_path 测试 ─────────────────────────────
|
||||
|
||||
|
||||
class TestGetRecommendedPath:
    """Checks for PathOptimizer.get_recommended_path."""

    async def test_no_recommended_path(self, optimizer):
        """Nothing is recommended before any path was evaluated."""
        assert optimizer.get_recommended_path("code_review") is None

    async def test_returns_recommended_path(self, optimizer):
        """An evaluated path with enough samples becomes the recommendation."""
        candidate = _make_path(task_type="code_review", success_rate=0.8, sample_count=5)
        await optimizer.evaluate_and_update("code_review", candidate)

        recommended = optimizer.get_recommended_path("code_review")

        assert recommended is not None
        assert recommended.success_rate == 0.8
        assert recommended.is_recommended is True

    async def test_different_task_types_independent(self, optimizer):
        """Recommendations are tracked separately per task type."""
        review_path = _make_path(task_type="code_review", success_rate=0.8, sample_count=5)
        analysis_path = _make_path(task_type="data_analysis", success_rate=0.9, sample_count=5)
        await optimizer.evaluate_and_update("code_review", review_path)
        await optimizer.evaluate_and_update("data_analysis", analysis_path)

        for_review = optimizer.get_recommended_path("code_review")
        for_analysis = optimizer.get_recommended_path("data_analysis")

        assert for_review is not None
        assert for_analysis is not None
        assert for_review.success_rate == 0.8
        assert for_analysis.success_rate == 0.9
|
||||
|
||||
|
||||
# ── 样本量不足测试 ────────────────────────────────────────
|
||||
|
||||
|
||||
class TestInsufficientSamples:
    """Checks for the minimum-sample-count rule."""

    async def test_insufficient_samples_no_update(self, optimizer):
        """Too few samples: no update, and no recommendation is set."""
        candidate = _make_path(sample_count=2, success_rate=0.9)

        outcome = await optimizer.evaluate_and_update("code_review", candidate)

        assert outcome.updated is False
        assert "样本量不足" in outcome.reason
        assert optimizer.get_recommended_path("code_review") is None

    async def test_insufficient_samples_recorded_as_pending(self, optimizer):
        """A path with too few samples lands in the pending list."""
        candidate = _make_path(sample_count=2, success_rate=0.9)
        await optimizer.evaluate_and_update("code_review", candidate)

        waiting = optimizer.get_pending_paths("code_review")

        assert len(waiting) == 1
        assert waiting[0].success_rate == 0.9

    async def test_exact_min_samples_updates(self, optimizer):
        """Exactly the minimum sample count is enough to update."""
        candidate = _make_path(sample_count=3, success_rate=0.8)

        outcome = await optimizer.evaluate_and_update("code_review", candidate)

        assert outcome.updated is True
        assert outcome.reason == "无现有推荐路径,直接设为推荐"

    async def test_custom_min_sample_count(self, optimizer_custom_thresholds):
        """A custom minimum sample count is honoured."""
        candidate = _make_path(sample_count=4, success_rate=0.9)

        outcome = await optimizer_custom_thresholds.evaluate_and_update(
            "code_review", candidate
        )

        assert outcome.updated is False
        assert "样本量不足" in outcome.reason
|
||||
|
||||
|
||||
# ── 首次设置推荐路径测试 ──────────────────────────────────
|
||||
|
||||
|
||||
class TestFirstRecommendation:
    """Checks for the very first recommendation of a task type."""

    async def test_first_path_becomes_recommended(self, optimizer):
        """With no existing recommendation the new path is adopted directly."""
        candidate = _make_path(success_rate=0.7, sample_count=5)

        outcome = await optimizer.evaluate_and_update("code_review", candidate)

        assert outcome.updated is True
        assert outcome.old_path is None
        assert outcome.new_path is not None
        assert outcome.new_path.is_recommended is True
        assert "无现有推荐路径" in outcome.reason

    async def test_auto_generates_path_id(self, optimizer):
        """A path submitted without path_id receives a generated one."""
        candidate = _make_path(path_id="", sample_count=5)

        outcome = await optimizer.evaluate_and_update("code_review", candidate)

        assert outcome.updated is True
        assert outcome.new_path is not None
        assert len(outcome.new_path.path_id) > 0
|
||||
|
||||
|
||||
# ── 成功率显著提升测试 ────────────────────────────────────
|
||||
|
||||
|
||||
class TestSuccessRateImprovement:
    """Decisions driven by the success-rate gap between candidate and current path."""

    async def test_higher_success_rate_updates(self, optimizer):
        """A higher success rate replaces the recommended path."""
        baseline = _make_path(success_rate=0.7, sample_count=5)
        await optimizer.evaluate_and_update("code_review", baseline)

        challenger = _make_path(success_rate=0.85, sample_count=5)
        outcome = await optimizer.evaluate_and_update("code_review", challenger)

        assert outcome.updated is True
        assert outcome.old_path.success_rate == 0.7
        assert outcome.new_path.success_rate == 0.85
        assert "成功率显著提升" in outcome.reason

    async def test_marginal_success_rate_no_update(self, optimizer):
        """A success-rate gain below the threshold does not trigger an update."""
        baseline = _make_path(success_rate=0.8, sample_count=5)
        await optimizer.evaluate_and_update("code_review", baseline)

        # Gain is only 0.03, below the default threshold of 0.05.
        challenger = _make_path(success_rate=0.83, sample_count=5)
        outcome = await optimizer.evaluate_and_update("code_review", challenger)

        assert outcome.updated is False
        assert "无明显优势" in outcome.reason

    async def test_custom_success_rate_threshold(self, optimizer_custom_thresholds):
        """A custom success-rate threshold is honoured."""
        baseline = _make_path(success_rate=0.7, sample_count=10)
        await optimizer_custom_thresholds.evaluate_and_update("code_review", baseline)

        # Gain is 0.08, below the custom threshold of 0.1.
        challenger = _make_path(success_rate=0.78, sample_count=10)
        outcome = await optimizer_custom_thresholds.evaluate_and_update(
            "code_review", challenger
        )

        assert outcome.updated is False

    async def test_lower_success_rate_no_update(self, optimizer):
        """A lower success rate never replaces the recommended path."""
        baseline = _make_path(success_rate=0.9, sample_count=5)
        await optimizer.evaluate_and_update("code_review", baseline)

        challenger = _make_path(success_rate=0.6, sample_count=5)
        outcome = await optimizer.evaluate_and_update("code_review", challenger)

        assert outcome.updated is False
|
||||
|
||||
|
||||
# ── 耗时显著更短测试 ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestDurationImprovement:
    """Decisions driven by total duration when success rates are similar."""

    async def test_shorter_duration_with_similar_success_rate_updates(self, optimizer):
        """Similar success rate but clearly shorter duration replaces the path."""
        baseline = _make_path(total_duration=100.0, success_rate=0.8, sample_count=5)
        await optimizer.evaluate_and_update("code_review", baseline)

        # Duration drops 30% (> 20% threshold) with a similar success rate.
        challenger = _make_path(total_duration=70.0, success_rate=0.82, sample_count=5)
        outcome = await optimizer.evaluate_and_update("code_review", challenger)

        assert outcome.updated is True
        assert "耗时显著更短" in outcome.reason

    async def test_marginal_duration_improvement_no_update(self, optimizer):
        """A duration gain below the threshold does not trigger an update."""
        baseline = _make_path(total_duration=100.0, success_rate=0.8, sample_count=5)
        await optimizer.evaluate_and_update("code_review", baseline)

        # Duration drops only 10% (< 20% threshold).
        challenger = _make_path(total_duration=90.0, success_rate=0.82, sample_count=5)
        outcome = await optimizer.evaluate_and_update("code_review", challenger)

        assert outcome.updated is False
        assert "无明显优势" in outcome.reason

    async def test_longer_duration_no_update(self, optimizer):
        """A slower path does not replace the recommended one."""
        baseline = _make_path(total_duration=50.0, success_rate=0.8, sample_count=5)
        await optimizer.evaluate_and_update("code_review", baseline)

        challenger = _make_path(total_duration=80.0, success_rate=0.82, sample_count=5)
        outcome = await optimizer.evaluate_and_update("code_review", challenger)

        assert outcome.updated is False

    async def test_custom_duration_improvement_threshold(self, optimizer_custom_thresholds):
        """A custom duration-improvement threshold is honoured."""
        baseline = _make_path(total_duration=100.0, success_rate=0.8, sample_count=10)
        await optimizer_custom_thresholds.evaluate_and_update("code_review", baseline)

        # Duration drops 25% (< 30% custom threshold).
        challenger = _make_path(total_duration=75.0, success_rate=0.82, sample_count=10)
        outcome = await optimizer_custom_thresholds.evaluate_and_update(
            "code_review", challenger
        )

        assert outcome.updated is False

    async def test_zero_duration_current_path(self, optimizer):
        """A current path with zero duration is not replaced on duration grounds."""
        baseline = _make_path(total_duration=0.0, success_rate=0.8, sample_count=5)
        await optimizer.evaluate_and_update("code_review", baseline)

        challenger = _make_path(total_duration=10.0, success_rate=0.82, sample_count=5)
        outcome = await optimizer.evaluate_and_update("code_review", challenger)

        assert outcome.updated is False

    async def test_both_zero_duration(self, optimizer):
        """When both durations are zero, duration does not trigger an update."""
        baseline = _make_path(total_duration=0.0, success_rate=0.8, sample_count=5)
        await optimizer.evaluate_and_update("code_review", baseline)

        challenger = _make_path(total_duration=0.0, success_rate=0.82, sample_count=5)
        outcome = await optimizer.evaluate_and_update("code_review", challenger)

        assert outcome.updated is False
|
||||
|
||||
|
||||
# ── 保留现有推荐路径测试 ──────────────────────────────────
|
||||
|
||||
|
||||
class TestKeepCurrentPath:
    """The current recommendation is kept when the candidate has no clear edge."""

    async def test_no_advantage_keeps_current(self, optimizer):
        """A candidate without a clear advantage leaves the recommendation as is."""
        baseline = _make_path(total_duration=50.0, success_rate=0.8, sample_count=5)
        await optimizer.evaluate_and_update("code_review", baseline)

        challenger = _make_path(total_duration=48.0, success_rate=0.79, sample_count=5)
        outcome = await optimizer.evaluate_and_update("code_review", challenger)

        assert outcome.updated is False
        assert outcome.old_path.success_rate == 0.8

        # The recommended path is unchanged.
        current = optimizer.get_recommended_path("code_review")
        assert current is not None
        assert current.success_rate == 0.8

    async def test_is_recommended_flag_preserved(self, optimizer):
        """Without an update, the current path keeps is_recommended == True."""
        baseline = _make_path(success_rate=0.8, sample_count=5)
        await optimizer.evaluate_and_update("code_review", baseline)

        challenger = _make_path(success_rate=0.79, sample_count=5)
        await optimizer.evaluate_and_update("code_review", challenger)

        current = optimizer.get_recommended_path("code_review")
        assert current is not None
        assert current.is_recommended is True
|
||||
|
||||
|
||||
# ── is_recommended 标志管理测试 ────────────────────────────
|
||||
|
||||
|
||||
class TestIsRecommendedFlag:
    """Management of the is_recommended flag across an update."""

    async def test_old_path_loses_recommended_flag(self, optimizer):
        """After an update, the replaced path has is_recommended == False."""
        old_path = _make_path(success_rate=0.7, sample_count=5)
        await optimizer.evaluate_and_update("code_review", old_path)
        # First submission: the submitted object itself is flagged as recommended.
        assert old_path.is_recommended is True

        challenger = _make_path(success_rate=0.9, sample_count=5)
        outcome = await optimizer.evaluate_and_update("code_review", challenger)

        assert outcome.updated is True
        # The replaced path loses the flag; the new one carries it.
        assert outcome.old_path.is_recommended is False
        assert outcome.new_path.is_recommended is True
|
||||
|
||||
|
||||
# ── 多次迭代优化测试 ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestIterativeOptimization:
    """Repeated submissions and independence between task types."""

    async def test_multiple_updates_converge_to_best(self, optimizer):
        """After several rounds the recommendation settles on the best path."""
        task = "code_review"

        # Round 1: initial path.
        first = _make_path(success_rate=0.6, total_duration=100.0, sample_count=5)
        await optimizer.evaluate_and_update(task, first)
        assert optimizer.get_recommended_path("code_review").success_rate == 0.6

        # Round 2: success rate clearly higher.
        second = _make_path(success_rate=0.8, total_duration=90.0, sample_count=5)
        await optimizer.evaluate_and_update(task, second)
        assert optimizer.get_recommended_path("code_review").success_rate == 0.8

        # Round 3: similar success rate but shorter duration.
        third = _make_path(success_rate=0.82, total_duration=50.0, sample_count=5)
        await optimizer.evaluate_and_update(task, third)
        assert optimizer.get_recommended_path("code_review").total_duration == 50.0

        # Round 4: no clear advantage, so the round-3 path stays.
        fourth = _make_path(success_rate=0.81, total_duration=48.0, sample_count=5)
        outcome = await optimizer.evaluate_and_update(task, fourth)
        assert outcome.updated is False
        assert optimizer.get_recommended_path("code_review").total_duration == 50.0

    async def test_different_task_types_evolve_independently(self, optimizer):
        """Recommendations for different task types evolve independently."""
        review_v1 = _make_path(task_type="code_review", success_rate=0.7, sample_count=5)
        analysis_v1 = _make_path(task_type="data_analysis", success_rate=0.6, sample_count=5)
        await optimizer.evaluate_and_update("code_review", review_v1)
        await optimizer.evaluate_and_update("data_analysis", analysis_v1)

        review_v2 = _make_path(task_type="code_review", success_rate=0.9, sample_count=5)
        await optimizer.evaluate_and_update("code_review", review_v2)

        # code_review was updated; data_analysis is unaffected.
        assert optimizer.get_recommended_path("code_review").success_rate == 0.9
        assert optimizer.get_recommended_path("data_analysis").success_rate == 0.6
|
||||
|
||||
|
||||
# ── 待观察路径管理测试 ────────────────────────────────────
|
||||
|
||||
|
||||
class TestPendingPaths:
    """Bookkeeping of pending (to-observe) paths."""

    async def test_pending_paths_empty_initially(self, optimizer):
        """A fresh optimizer reports no pending paths."""
        assert optimizer.get_pending_paths("code_review") == []

    async def test_pending_paths_accumulate(self, optimizer):
        """Each under-sampled submission adds another pending entry."""
        sparse_a = _make_path(sample_count=1, success_rate=0.9)
        sparse_b = _make_path(sample_count=2, success_rate=0.85)
        await optimizer.evaluate_and_update("code_review", sparse_a)
        await optimizer.evaluate_and_update("code_review", sparse_b)

        waiting = optimizer.get_pending_paths("code_review")
        assert len(waiting) == 2

    async def test_pending_paths_isolated_by_task_type(self, optimizer):
        """Pending paths of different task types are kept apart."""
        review_path = _make_path(task_type="code_review", sample_count=1, success_rate=0.9)
        analysis_path = _make_path(task_type="data_analysis", sample_count=1, success_rate=0.8)
        await optimizer.evaluate_and_update("code_review", review_path)
        await optimizer.evaluate_and_update("data_analysis", analysis_path)

        assert len(optimizer.get_pending_paths("code_review")) == 1
        assert len(optimizer.get_pending_paths("data_analysis")) == 1

    async def test_sufficient_samples_not_pending(self, optimizer):
        """A path with enough samples does not enter the pending list."""
        candidate = _make_path(sample_count=5, success_rate=0.8)
        await optimizer.evaluate_and_update("code_review", candidate)

        assert optimizer.get_pending_paths("code_review") == []
|
||||
|
||||
|
||||
# ── ExperienceStore 集成测试 ──────────────────────────────
|
||||
|
||||
|
||||
class TestExperienceStoreIntegration:
    """PathOptimizer with and without an ExperienceStore."""

    async def test_with_experience_store(self):
        """PathOptimizer accepts an ExperienceStore instance."""
        from agentkit.evolution.experience_store import InMemoryExperienceStore

        backing_store = InMemoryExperienceStore()
        optimizer = PathOptimizer(experience_store=backing_store, min_sample_count=3)

        candidate = _make_path(success_rate=0.8, sample_count=5)
        outcome = await optimizer.evaluate_and_update("code_review", candidate)

        assert outcome.updated is True

    async def test_without_experience_store(self, optimizer):
        """PathOptimizer works via the plain ``optimizer`` fixture as well."""
        candidate = _make_path(success_rate=0.8, sample_count=5)
        outcome = await optimizer.evaluate_and_update("code_review", candidate)

        assert outcome.updated is True
|
||||
|
||||
|
||||
# ── 边界条件测试 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEdgeCases:
    """Boundary conditions of evaluate_and_update."""

    async def test_same_path_twice(self, optimizer):
        """Submitting an equivalent path a second time does not update."""
        original = _make_path(success_rate=0.8, sample_count=5)
        first_outcome = await optimizer.evaluate_and_update("code_review", original)
        assert first_outcome.updated is True

        # Second submission: same parameters, but a distinct instance.
        duplicate = _make_path(success_rate=0.8, sample_count=5)
        second_outcome = await optimizer.evaluate_and_update("code_review", duplicate)

        # Same success rate and same duration, so no clear advantage.
        assert second_outcome.updated is False

    async def test_success_rate_at_boundary(self, optimizer):
        """A success-rate gain right at the threshold does not update."""
        baseline = _make_path(success_rate=0.8, sample_count=5)
        await optimizer.evaluate_and_update("code_review", baseline)

        # Nominal gain equals the 0.05 threshold, which must not satisfy "> threshold".
        # NOTE(review): in binary floating point 0.85 - 0.8 evaluates to
        # 0.04999999999999993, so if the implementation subtracts the two rates
        # this case sits just below the threshold rather than exactly on it.
        challenger = _make_path(success_rate=0.85, sample_count=5)
        outcome = await optimizer.evaluate_and_update("code_review", challenger)

        assert outcome.updated is False

    async def test_duration_improvement_at_boundary(self, optimizer):
        """A duration gain right at the threshold does not update."""
        baseline = _make_path(total_duration=100.0, success_rate=0.8, sample_count=5)
        await optimizer.evaluate_and_update("code_review", baseline)

        # Improvement equals the 20% threshold, which must not satisfy "> threshold".
        challenger = _make_path(total_duration=80.0, success_rate=0.82, sample_count=5)
        outcome = await optimizer.evaluate_and_update("code_review", challenger)

        assert outcome.updated is False

    async def test_zero_sample_count(self, optimizer):
        """A path with zero samples is rejected as under-sampled."""
        candidate = _make_path(sample_count=0, success_rate=0.9)
        outcome = await optimizer.evaluate_and_update("code_review", candidate)

        assert outcome.updated is False
        assert "样本量不足" in outcome.reason

    async def test_path_task_type_override(self, optimizer):
        """evaluate_and_update overwrites the path's task_type with its argument."""
        path = _make_path(task_type="wrong_type", success_rate=0.8, sample_count=5)
        outcome = await optimizer.evaluate_and_update("code_review", path)

        assert outcome.updated is True
        # The submitted object itself now carries the task type passed to the call.
        assert path.task_type == "code_review"
        current = optimizer.get_recommended_path("code_review")
        assert current is not None
|
||||
|
|
@ -0,0 +1,595 @@
|
|||
"""Tests for PitfallDetector - 任务避坑预警检测"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.plan_schema import PlanStep, PlanStepStatus
|
||||
from agentkit.evolution.experience_schema import TaskExperience
|
||||
from agentkit.evolution.experience_store import InMemoryExperienceStore
|
||||
from agentkit.evolution.pitfall_detector import (
|
||||
PitfallDetector,
|
||||
PitfallWarning,
|
||||
WarningLevel,
|
||||
_compute_name_similarity,
|
||||
_determine_warning_level,
|
||||
_extract_keywords,
|
||||
)
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
def store():
    """Provide an InMemoryExperienceStore constructed without an embedder."""
    experience_store = InMemoryExperienceStore(decay_rate=0.01, alpha=0.7)
    return experience_store
|
||||
|
||||
|
||||
@pytest.fixture
def detector(store):
    """Provide a PitfallDetector backed by the ``store`` fixture (threshold 0.3)."""
    pitfall_detector = PitfallDetector(experience_store=store, similarity_threshold=0.3)
    return pitfall_detector
|
||||
|
||||
|
||||
def _make_experience(
    task_type: str = "code_review",
    goal: str = "Review the PR",
    outcome: str = "success",
    steps_summary: str | list[dict] = "",
    failure_reasons: list[str] | None = None,
    optimization_tips: list[str] | None = None,
    success_rate: float = 1.0,
) -> TaskExperience:
    """Build a TaskExperience for tests.

    The experience_id is left empty, the duration is fixed at 10 seconds and
    created_at is the current UTC time. ``failure_reasons`` and
    ``optimization_tips`` fall back to fresh empty lists when omitted.
    """
    reasons = failure_reasons or []
    tips = optimization_tips or []
    return TaskExperience(
        experience_id="",
        task_type=task_type,
        goal=goal,
        steps_summary=steps_summary,
        outcome=outcome,
        duration_seconds=10.0,
        success_rate=success_rate,
        failure_reasons=reasons,
        optimization_tips=tips,
        created_at=datetime.now(timezone.utc),
    )
|
||||
|
||||
|
||||
def _make_step(
    name: str = "step",
    description: str = "do something",
    step_id: str = "s1",
) -> PlanStep:
    """Build a PlanStep in PENDING status for tests."""
    planned = PlanStep(
        step_id=step_id,
        name=name,
        description=description,
        status=PlanStepStatus.PENDING,
    )
    return planned
|
||||
|
||||
|
||||
# ── 辅助函数测试 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestExtractKeywords:
    """Tests for the _extract_keywords helper."""

    def test_basic_extraction(self):
        """Words are extracted in lower case."""
        found = _extract_keywords("Call API Gateway")
        assert "call" in found
        assert "api" in found
        assert "gateway" in found

    def test_stop_words_filtered(self):
        """Stop words such as 'the' and 'and' are dropped."""
        found = _extract_keywords("Call the API and check the result")
        assert "the" not in found
        assert "and" not in found
        assert "call" in found
        assert "api" in found

    def test_underscore_and_hyphen(self):
        """Underscores and hyphens act as word separators."""
        found = _extract_keywords("call_api-gateway")
        assert "call" in found
        assert "api" in found
        assert "gateway" in found

    def test_single_char_filtered(self):
        """Single-character tokens are dropped; two-character tokens are kept."""
        found = _extract_keywords("a b cd")
        assert "a" not in found
        assert "b" not in found
        assert "cd" in found

    def test_empty_string(self):
        """An empty string yields no keywords."""
        found = _extract_keywords("")
        assert len(found) == 0
|
||||
|
||||
|
||||
class TestComputeNameSimilarity:
    """Tests for the _compute_name_similarity helper."""

    def test_identical_names(self):
        """Identical names (with an empty description) score 1.0."""
        score = _compute_name_similarity("Call API Gateway", "", "Call API Gateway")
        assert score == pytest.approx(1.0)

    def test_partial_overlap(self):
        """Partially overlapping names score strictly between 0 and 1."""
        score = _compute_name_similarity("Call API Gateway", "", "Call External API")
        # Shared: call, api; union: call, api, gateway, external.
        assert 0.0 < score < 1.0

    def test_no_overlap(self):
        """Names without a common keyword score 0.0."""
        score = _compute_name_similarity("Deploy Service", "", "Analyze Data")
        assert score == 0.0

    def test_description_contributes(self):
        """Matching keywords in the description do not lower the score."""
        without_description = _compute_name_similarity("Deploy", "", "Deploy Service")
        with_description = _compute_name_similarity(
            "Deploy", "Deploy Service", "Deploy Service"
        )
        # The description contains matching keywords, so similarity should not drop.
        assert with_description >= without_description

    def test_empty_inputs(self):
        """An empty name and description score 0.0."""
        score = _compute_name_similarity("", "", "Call API")
        assert score == 0.0
|
||||
|
||||
|
||||
class TestDetermineWarningLevel:
    """Tests for the failure-rate to WarningLevel mapping."""

    def test_high_threshold(self):
        """Failure rates of 0.5 and above map to HIGH."""
        for rate in (0.6, 0.5):
            assert _determine_warning_level(rate) == WarningLevel.HIGH

    def test_medium_threshold(self):
        """Failure rates of 0.3 and 0.2 map to MEDIUM."""
        for rate in (0.3, 0.2):
            assert _determine_warning_level(rate) == WarningLevel.MEDIUM

    def test_low_threshold(self):
        """Failure rates of 0.1 and 0.01 map to LOW."""
        for rate in (0.1, 0.01):
            assert _determine_warning_level(rate) == WarningLevel.LOW
|
||||
|
||||
|
||||
# ── PitfallDetector.check_pitfalls 测试 ──────────────────
|
||||
|
||||
|
||||
class TestCheckPitfalls:
    """Tests for PitfallDetector.check_pitfalls."""

    async def test_no_planned_steps_returns_empty(self, detector):
        """An empty plan produces no warnings."""
        alerts = await detector.check_pitfalls(task_type="code_review", planned_steps=[])
        assert alerts == []

    async def test_no_failed_experiences_returns_empty(self, detector, store):
        """With no recorded failures, the result is an empty list."""
        # Only a successful experience is recorded.
        success_only = _make_experience(task_type="code_review", outcome="success")
        await store.record_experience(success_only)

        plan = [_make_step(name="Review Code")]
        alerts = await detector.check_pitfalls(task_type="code_review", planned_steps=plan)

        assert alerts == []

    async def test_high_failure_rate_returns_high_warning(self, detector, store):
        """A planned step with a high historical failure rate yields a HIGH warning."""
        # Six failures in which the "Call API Gateway" step failed.
        for _ in range(6):
            failed_run = _make_experience(
                task_type="deployment",
                outcome="failure",
                success_rate=0.0,
                steps_summary=[
                    {"step_name": "Call API Gateway", "outcome": "failure", "error": "Timeout"},
                    {"step_name": "Deploy Container", "outcome": "success"},
                ],
                failure_reasons=["API Gateway timeout"],
            )
            await store.record_experience(failed_run)
        # Four successful runs.
        for _ in range(4):
            good_run = _make_experience(
                task_type="deployment",
                outcome="success",
                success_rate=1.0,
                steps_summary=[
                    {"step_name": "Call API Gateway", "outcome": "success"},
                    {"step_name": "Deploy Container", "outcome": "success"},
                ],
            )
            await store.record_experience(good_run)

        plan = [_make_step(name="Call API Gateway", description="Invoke API Gateway endpoint")]
        alerts = await detector.check_pitfalls(task_type="deployment", planned_steps=plan)

        assert len(alerts) == 1
        alert = alerts[0]
        assert alert.step_name == "Call API Gateway"
        assert alert.warning_level == WarningLevel.HIGH
        assert alert.failure_rate >= 0.5
        assert "Timeout" in alert.historical_failures

    async def test_medium_failure_rate(self, detector, store):
        """A medium failure rate yields a MEDIUM warning."""
        # 3 failures and 7 successes: failure rate 0.3.
        for _ in range(3):
            failed_run = _make_experience(
                task_type="data_analysis",
                outcome="failure",
                success_rate=0.0,
                steps_summary=[
                    {"step_name": "Fetch Data", "outcome": "failure", "error": "Connection refused"},
                ],
            )
            await store.record_experience(failed_run)
        for _ in range(7):
            good_run = _make_experience(
                task_type="data_analysis",
                outcome="success",
                success_rate=1.0,
                steps_summary=[
                    {"step_name": "Fetch Data", "outcome": "success"},
                ],
            )
            await store.record_experience(good_run)

        plan = [_make_step(name="Fetch Data", description="Fetch data from source")]
        alerts = await detector.check_pitfalls(task_type="data_analysis", planned_steps=plan)

        assert len(alerts) == 1
        assert alerts[0].warning_level == WarningLevel.MEDIUM
        assert 0.2 <= alerts[0].failure_rate < 0.5

    async def test_low_failure_rate(self, detector, store):
        """A low failure rate yields a LOW warning."""
        # 1 failure and 9 successes: failure rate 0.1.
        failed_run = _make_experience(
            task_type="testing",
            outcome="failure",
            success_rate=0.0,
            steps_summary=[
                {"step_name": "Run Unit Tests", "outcome": "failure", "error": "Flaky test"},
            ],
        )
        await store.record_experience(failed_run)
        for _ in range(9):
            good_run = _make_experience(
                task_type="testing",
                outcome="success",
                success_rate=1.0,
                steps_summary=[
                    {"step_name": "Run Unit Tests", "outcome": "success"},
                ],
            )
            await store.record_experience(good_run)

        plan = [_make_step(name="Run Unit Tests", description="Execute unit test suite")]
        alerts = await detector.check_pitfalls(task_type="testing", planned_steps=plan)

        assert len(alerts) == 1
        assert alerts[0].warning_level == WarningLevel.LOW

    async def test_multiple_steps_with_risks_sorted_by_severity(self, detector, store):
        """When several steps are risky, warnings come back ordered by severity."""
        # "Call API" fails often; "Validate Input" fails rarely.
        for _ in range(6):
            failed_run = _make_experience(
                task_type="integration",
                outcome="failure",
                success_rate=0.0,
                steps_summary=[
                    {"step_name": "Call API", "outcome": "failure", "error": "Timeout"},
                    {"step_name": "Validate Input", "outcome": "success"},
                ],
            )
            await store.record_experience(failed_run)
        for _ in range(4):
            good_run = _make_experience(
                task_type="integration",
                outcome="success",
                success_rate=1.0,
                steps_summary=[
                    {"step_name": "Call API", "outcome": "success"},
                    {"step_name": "Validate Input", "outcome": "success"},
                ],
            )
            await store.record_experience(good_run)
        # One extra record in which only "Validate Input" failed.
        partial_run = _make_experience(
            task_type="integration",
            outcome="partial",
            success_rate=0.5,
            steps_summary=[
                {"step_name": "Call API", "outcome": "success"},
                {"step_name": "Validate Input", "outcome": "failure", "error": "Invalid schema"},
            ],
        )
        await store.record_experience(partial_run)

        plan = [
            _make_step(name="Validate Input", description="Validate input data", step_id="s1"),
            _make_step(name="Call API", description="Call external API", step_id="s2"),
        ]
        alerts = await detector.check_pitfalls(task_type="integration", planned_steps=plan)

        assert len(alerts) == 2
        # HIGH must come before MEDIUM/LOW.
        assert alerts[0].warning_level == WarningLevel.HIGH
        assert alerts[0].step_name == "Call API"

    async def test_no_matching_steps_returns_empty(self, detector, store):
        """Planned steps that match no historical failing step yield no warnings."""
        failed_run = _make_experience(
            task_type="code_review",
            outcome="failure",
            success_rate=0.0,
            steps_summary=[
                {"step_name": "Run Linter", "outcome": "failure", "error": "Config error"},
            ],
        )
        await store.record_experience(failed_run)

        # The planned step name is entirely different from the historical one.
        plan = [_make_step(name="Deploy Application", description="Deploy to production")]
        alerts = await detector.check_pitfalls(task_type="code_review", planned_steps=plan)

        assert alerts == []

    async def test_different_task_type_no_cross_contamination(self, detector, store):
        """Failures recorded for one task_type do not warn for another."""
        failed_run = _make_experience(
            task_type="deployment",
            outcome="failure",
            success_rate=0.0,
            steps_summary=[
                {"step_name": "Deploy Service", "outcome": "failure", "error": "OOM"},
            ],
        )
        await store.record_experience(failed_run)

        # Querying code_review must not surface the deployment failure.
        plan = [_make_step(name="Deploy Service", description="Deploy the service")]
        alerts = await detector.check_pitfalls(task_type="code_review", planned_steps=plan)

        assert alerts == []

    async def test_partial_outcome_included(self, detector, store):
        """Experiences with a 'partial' outcome are considered as well."""
        partial_run = _make_experience(
            task_type="migration",
            outcome="partial",
            success_rate=0.5,
            steps_summary=[
                {"step_name": "Migrate Database", "outcome": "failure", "error": "Schema mismatch"},
            ],
        )
        await store.record_experience(partial_run)

        plan = [_make_step(name="Migrate Database", description="Migrate DB schema")]
        alerts = await detector.check_pitfalls(task_type="migration", planned_steps=plan)

        assert len(alerts) == 1

    async def test_steps_summary_as_string_ignored(self, detector, store):
        """A string steps_summary carries no step-level data, so no warning results."""
        failed_run = _make_experience(
            task_type="code_review",
            outcome="failure",
            success_rate=0.0,
            steps_summary="Executed code_review task",  # string form, not a step list
        )
        await store.record_experience(failed_run)

        plan = [_make_step(name="Review Code", description="Review the code")]
        alerts = await detector.check_pitfalls(task_type="code_review", planned_steps=plan)

        assert alerts == []
|
||||
|
||||
|
||||
# ── AE3 场景测试 ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAE3Scenario:
    """AE3: the X system API times out 60% of the time at peak hours, so new tasks calling it are warned."""

    async def test_api_timeout_high_failure_rate_warning(self, detector, store):
        """A 60% historical timeout rate on the X system API triggers a HIGH warning."""
        # History: 10 calls, 6 of them timed out (60% failure rate).
        for _ in range(6):
            timed_out = _make_experience(
                task_type="order_processing",
                goal="Process orders via X system",
                outcome="failure",
                success_rate=0.0,
                steps_summary=[
                    {"step_name": "Call X System API", "outcome": "failure", "error": "高峰期超时"},
                    {"step_name": "Process Order", "outcome": "success"},
                ],
                failure_reasons=["X System API timeout during peak hours"],
                optimization_tips=["Avoid peak hours", "Add retry logic"],
            )
            await store.record_experience(timed_out)
        for _ in range(4):
            completed = _make_experience(
                task_type="order_processing",
                goal="Process orders via X system",
                outcome="success",
                success_rate=1.0,
                steps_summary=[
                    {"step_name": "Call X System API", "outcome": "success"},
                    {"step_name": "Process Order", "outcome": "success"},
                ],
            )
            await store.record_experience(completed)

        # The new task's plan includes the X system API call.
        plan = [
            _make_step(name="Call X System API", description="Invoke X system API for orders"),
        ]
        alerts = await detector.check_pitfalls(task_type="order_processing", planned_steps=plan)

        assert len(alerts) == 1
        alert = alerts[0]
        assert alert.warning_level == WarningLevel.HIGH
        assert alert.failure_rate >= 0.5
        assert any("超时" in reason for reason in alert.historical_failures)
        assert alert.suggestion  # a suggestion must be present
|
||||
|
||||
|
||||
# ── PitfallWarning 数据模型测试 ───────────────────────────
|
||||
|
||||
|
||||
class TestPitfallWarning:
    """Tests for the PitfallWarning data model."""

    def test_creation(self):
        """All fields passed to the constructor are stored unchanged."""
        alert = PitfallWarning(
            step_name="Call API",
            warning_level=WarningLevel.HIGH,
            failure_rate=0.6,
            historical_failures=["Timeout", "Connection refused"],
            suggestion="Add retry logic",
        )

        assert alert.step_name == "Call API"
        assert alert.warning_level == WarningLevel.HIGH
        assert alert.failure_rate == 0.6
        assert alert.historical_failures == ["Timeout", "Connection refused"]
        assert alert.suggestion == "Add retry logic"

    def test_default_values(self):
        """Omitted optional fields default to an empty list and an empty string."""
        alert = PitfallWarning(
            step_name="Test",
            warning_level=WarningLevel.LOW,
            failure_rate=0.1,
        )

        assert alert.historical_failures == []
        assert alert.suggestion == ""
|
||||
|
||||
|
||||
# ── WarningLevel 枚举测试 ─────────────────────────────────
|
||||
|
||||
|
||||
class TestWarningLevel:
    """Tests for the WarningLevel enum."""

    def test_values(self):
        """Each member's value is its lower-case name."""
        assert WarningLevel.HIGH.value == "high"
        assert WarningLevel.MEDIUM.value == "medium"
        assert WarningLevel.LOW.value == "low"

    def test_string_comparison(self):
        """Members compare equal to their plain string values."""
        assert WarningLevel.HIGH == "high"
        assert WarningLevel.MEDIUM == "medium"
        assert WarningLevel.LOW == "low"
|
||||
|
||||
|
||||
# ── 相似度阈值配置测试 ─────────────────────────────────────
|
||||
|
||||
|
||||
class TestSimilarityThreshold:
    """Tests for the configurable similarity threshold."""

    async def test_custom_threshold(self, store):
        """A lower threshold never matches fewer steps than a higher one."""
        # Low threshold: matches more easily.
        lenient = PitfallDetector(experience_store=store, similarity_threshold=0.1)
        # High threshold: matches less easily.
        strict = PitfallDetector(experience_store=store, similarity_threshold=0.8)

        failed_run = _make_experience(
            task_type="testing",
            outcome="failure",
            success_rate=0.0,
            steps_summary=[
                {"step_name": "Run Integration Tests", "outcome": "failure", "error": "Timeout"},
            ],
        )
        await store.record_experience(failed_run)

        plan = [_make_step(name="Run Unit Tests", description="Execute tests")]
        # The lenient detector may match where the strict one does not.
        lenient_alerts = await lenient.check_pitfalls(task_type="testing", planned_steps=plan)
        strict_alerts = await strict.check_pitfalls(task_type="testing", planned_steps=plan)

        # Lenient match count >= strict match count.
        assert len(lenient_alerts) >= len(strict_alerts)
|
||||
|
||||
|
||||
# ── 端到端流程测试 ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEndToEnd:
    """End-to-end scenarios: record experiences, then query for pitfalls."""

    async def test_full_pitfall_detection_flow(self, detector, store):
        """Two failures and one success on a step yield a HIGH warning."""
        # 1. Record a mix of failed and successful deployments.
        oom_failure = _make_experience(
            task_type="deployment",
            outcome="failure",
            success_rate=0.0,
            steps_summary=[
                {"step_name": "Build Docker Image", "outcome": "failure", "error": "OOM"},
                {"step_name": "Push to Registry", "outcome": "success"},
            ],
            failure_reasons=["Docker build OOM"],
            optimization_tips=["Increase memory limit"],
        )
        await store.record_experience(oom_failure)

        dependency_failure = _make_experience(
            task_type="deployment",
            outcome="failure",
            success_rate=0.0,
            steps_summary=[
                {"step_name": "Build Docker Image", "outcome": "failure", "error": "Dependency conflict"},
                {"step_name": "Push to Registry", "outcome": "success"},
            ],
            failure_reasons=["Dependency conflict"],
        )
        await store.record_experience(dependency_failure)

        clean_run = _make_experience(
            task_type="deployment",
            outcome="success",
            success_rate=1.0,
            steps_summary=[
                {"step_name": "Build Docker Image", "outcome": "success"},
                {"step_name": "Push to Registry", "outcome": "success"},
            ],
        )
        await store.record_experience(clean_run)

        # 2. Plan a new task that repeats the same two steps.
        build_step = _make_step(
            name="Build Docker Image", description="Build the container image", step_id="s1"
        )
        push_step = _make_step(
            name="Push to Registry", description="Push image to container registry", step_id="s2"
        )
        planned = [build_step, push_step]

        # 3. Ask the detector for pitfalls.
        warnings = await detector.check_pitfalls(task_type="deployment", planned_steps=planned)

        # 4. Verify: the build step failed in 2 of 3 runs (~0.667), expected HIGH.
        assert len(warnings) >= 1
        build_matches = [w for w in warnings if w.step_name == "Build Docker Image"]
        build_warning = build_matches[0] if build_matches else None
        assert build_warning is not None
        assert build_warning.warning_level == WarningLevel.HIGH
        assert build_warning.failure_rate == pytest.approx(2.0 / 3.0, abs=0.01)

    async def test_suggestion_contains_useful_info(self, detector, store):
        """The warning's suggestion mentions the recorded failure reason."""
        auth_failure = _make_experience(
            task_type="api_integration",
            outcome="failure",
            success_rate=0.0,
            steps_summary=[
                {"step_name": "Authenticate", "outcome": "failure", "error": "Token expired"},
            ],
            failure_reasons=["Token expired"],
            optimization_tips=["Refresh token before expiry"],
        )
        await store.record_experience(auth_failure)

        planned = [_make_step(name="Authenticate", description="Authenticate with API")]
        warnings = await detector.check_pitfalls(task_type="api_integration", planned_steps=planned)

        assert len(warnings) == 1
        only_warning = warnings[0]
        assert "Token expired" in only_warning.suggestion
|
||||
|
|
@ -0,0 +1,217 @@
|
|||
"""PTYSession 单元测试
|
||||
|
||||
测试场景:
|
||||
- PTY 会话启动和关闭
|
||||
- 交互式命令执行
|
||||
- 自动应答提示
|
||||
- 超时处理
|
||||
- 自定义应答规则
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.tools.pty_session import PTYSession, PTYOutput
|
||||
from agentkit.tools.shell import ShellTool
|
||||
|
||||
|
||||
class TestPTYSessionConstruction:
    """Constructor behaviour of PTYSession."""

    def test_default_construction(self):
        """A fresh session is stopped, auto-responds, and uses a 30 s timeout."""
        pty = PTYSession()
        assert pty.is_running is False
        assert pty._auto_respond is True
        assert pty._default_timeout == 30.0

    def test_custom_construction(self):
        """Constructor arguments are reflected in the session's state."""
        pty = PTYSession(
            auto_respond=False,
            custom_rules=[(r"confirm\?", "yes")],
            default_timeout=60.0,
        )
        assert pty._auto_respond is False
        # The previous check (`len(rules) > len(rules) - 1`) was always true and
        # verified nothing. Check that the custom pattern was registered instead.
        # NOTE(review): assumes custom rules are stored even when
        # auto_respond=False, and that each rule's first item is the pattern
        # string (as the auto-respond tests in this file rely on) - confirm.
        rule_patterns = [rule[0] for rule in pty._respond_rules]
        assert r"confirm\?" in rule_patterns
        assert pty._default_timeout == 60.0
|
||||
|
||||
|
||||
class TestPTYSessionLifecycle:
    """Start/close lifecycle of PTYSession."""

    @pytest.mark.asyncio
    async def test_start_and_close(self):
        """start() marks the session running; close() marks it stopped."""
        pty = PTYSession()
        await pty.start()
        assert pty.is_running is True
        await pty.close()
        assert pty.is_running is False

    @pytest.mark.asyncio
    async def test_start_idempotent(self):
        """Calling start() twice raises nothing and leaves the session running."""
        pty = PTYSession()
        # try/finally so a failing assertion cannot leave the PTY running,
        # matching the pattern used by the execution tests in this file.
        try:
            await pty.start()
            await pty.start()  # must not raise
            assert pty.is_running is True
        finally:
            await pty.close()

    @pytest.mark.asyncio
    async def test_close_without_start(self):
        """close() on a never-started session raises nothing."""
        pty = PTYSession()
        await pty.close()  # must not raise
|
||||
|
||||
|
||||
class TestPTYSessionExecution:
    """Command execution through a live PTYSession."""

    @pytest.mark.asyncio
    async def test_run_simple_command(self):
        """A plain echo succeeds and its text appears in the output."""
        session = PTYSession(default_timeout=10.0)
        try:
            await session.start()
            outcome = await session.run_command("echo hello_pty")
            assert "hello_pty" in outcome.output
            assert outcome.exit_code == 0
            assert outcome.timed_out is False
        finally:
            await session.close()

    @pytest.mark.asyncio
    async def test_run_command_with_cwd(self):
        """The ``cwd`` argument sets the directory the command runs in."""
        session = PTYSession(default_timeout=10.0)
        try:
            await session.start()
            outcome = await session.run_command("pwd", cwd="/tmp")
            assert "/tmp" in outcome.output
        finally:
            await session.close()

    @pytest.mark.asyncio
    async def test_run_command_with_env(self):
        """Variables passed via ``env`` are visible to the command."""
        session = PTYSession(default_timeout=10.0)
        try:
            await session.start()
            extra_env = {"PTY_TEST_VAR": "pty_value"}
            outcome = await session.run_command("echo $PTY_TEST_VAR", env=extra_env)
            assert "pty_value" in outcome.output
        finally:
            await session.close()

    @pytest.mark.asyncio
    async def test_run_failing_command(self):
        """A failing command reports a non-zero exit code."""
        session = PTYSession(default_timeout=10.0)
        try:
            await session.start()
            outcome = await session.run_command("ls /nonexistent_dir_xyz_12345")
            assert outcome.exit_code != 0
        finally:
            await session.close()

    @pytest.mark.asyncio
    async def test_run_command_timeout(self):
        """Exceeding the per-call timeout sets ``timed_out`` and exit code -1."""
        session = PTYSession(default_timeout=10.0)
        try:
            await session.start()
            outcome = await session.run_command("sleep 30", timeout=0.5)
            assert outcome.timed_out is True
            assert outcome.exit_code == -1
        finally:
            await session.close()
|
||||
|
||||
|
||||
class TestPTYSessionAutoRespond:
    """Auto-respond rule configuration of PTYSession."""

    @pytest.mark.asyncio
    async def test_auto_respond_yes_no(self):
        """The rule set contains a pattern for [y/N]-style prompts."""
        # Only checks that such a rule is loaded; no command is run.
        session = PTYSession(auto_respond=True)
        patterns = [rule[0] for rule in session._respond_rules]
        assert any(
            marker in pattern
            for pattern in patterns
            for marker in ("y/N", "Y/n")
        )

    @pytest.mark.asyncio
    async def test_auto_respond_disabled(self):
        """``auto_respond=False`` is stored on the session."""
        session = PTYSession(auto_respond=False)
        assert session._auto_respond is False

    @pytest.mark.asyncio
    async def test_custom_respond_rules(self):
        """A custom rule's pattern appears among the session's rules."""
        custom = [(r"continue\?\s*$", "yes")]
        session = PTYSession(auto_respond=True, custom_rules=custom)
        patterns = [rule[0] for rule in session._respond_rules]
        assert r"continue\?\s*$" in patterns
|
||||
|
||||
|
||||
class TestPTYSessionSendAndRead:
    """send() / read_output() on a session that was never started."""

    @pytest.mark.asyncio
    async def test_send_without_start(self):
        """send() before start() raises nothing."""
        idle = PTYSession()
        await idle.send("test")  # must not raise

    @pytest.mark.asyncio
    async def test_read_output_without_start(self):
        """read_output() before start() returns an empty string."""
        idle = PTYSession()
        captured = await idle.read_output()
        assert captured == ""
|
||||
|
||||
|
||||
class TestPTYOutput:
    """Field defaults and overrides of the PTYOutput data class."""

    def test_default_values(self):
        """Only ``output`` is required; exit code is -1 and timed_out False."""
        record = PTYOutput(output="test")
        assert record.output == "test"
        assert record.exit_code == -1
        assert record.timed_out is False

    def test_custom_values(self):
        """Explicit ``exit_code`` and ``timed_out`` values are kept."""
        record = PTYOutput(output="error", exit_code=1, timed_out=True)
        assert record.exit_code == 1
        assert record.timed_out is True
|
||||
|
||||
|
||||
class TestShellToolInteractiveMode:
    """ShellTool executed with ``interactive=True``."""

    @pytest.mark.asyncio
    async def test_interactive_mode(self):
        """An interactive echo succeeds and returns its text."""
        shell = ShellTool()
        response = await shell.execute(command="echo interactive_test", interactive=True)
        assert response["exit_code"] == 0
        assert "interactive_test" in response["output"]

    @pytest.mark.asyncio
    async def test_interactive_mode_with_session(self):
        """Interactive mode also works when a session id is supplied."""
        shell = ShellTool()
        call_args = {
            "command": "echo session_interactive",
            "session_id": "int-session",
            "interactive": True,
        }
        response = await shell.execute(**call_args)
        assert response["exit_code"] == 0
        assert "session_interactive" in response["output"]
|
||||
|
|
@ -0,0 +1,600 @@
|
|||
"""TerminalSession 和 ShellTool 单元测试
|
||||
|
||||
测试场景:
|
||||
- 跨命令保持 cwd → cd 后执行 pwd 返回正确目录
|
||||
- 跨命令保持 env → export 后执行 echo 返回正确值
|
||||
- 危险命令需确认 → rm 命令触发确认回调
|
||||
- 输出解析 → 错误输出结构化为错误类型+建议
|
||||
- 无 session_id 时保持现有行为
|
||||
- 会话管理器功能
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.tools.terminal_session import TerminalSession, TerminalSessionManager, CommandRecord
|
||||
from agentkit.tools.shell import ShellTool
|
||||
from agentkit.tools.output_parser import OutputParser, ParsedOutput, ErrorType
|
||||
|
||||
|
||||
# ============================================================
|
||||
# OutputParser 测试
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestOutputParser:
    """Structured parsing of command output by OutputParser."""

    def setup_method(self):
        """Create a fresh parser for every test."""
        self.parser = OutputParser()

    def _expect_error(self, output, exit_code, expected_type):
        """Parse, assert an error of ``expected_type``, and return the result."""
        parsed = self.parser.parse(output, exit_code)
        assert parsed.is_error is True
        assert parsed.error_type == expected_type
        return parsed

    def test_parse_success_output(self):
        """Exit code 0 yields a non-error result carrying the output text."""
        parsed = self.parser.parse("hello world", 0)
        assert parsed.exit_code == 0
        assert parsed.is_error is False
        assert parsed.error_type == ErrorType.NONE
        assert parsed.message == "hello world"
        assert parsed.suggestions == []

    def test_parse_empty_output(self):
        """Empty output with exit code 0 is a non-error with an empty message."""
        parsed = self.parser.parse("", 0)
        assert parsed.exit_code == 0
        assert parsed.is_error is False
        assert parsed.message == ""

    def test_parse_permission_denied(self):
        """Permission errors are classified and suggest using sudo."""
        parsed = self._expect_error(
            "permission denied: /root/secret", 1, ErrorType.PERMISSION_DENIED
        )
        assert len(parsed.suggestions) > 0
        assert any("sudo" in hint for hint in parsed.suggestions)

    def test_parse_not_found(self):
        """Missing file or directory messages map to NOT_FOUND."""
        self._expect_error("No such file or directory: /tmp/missing", 1, ErrorType.NOT_FOUND)

    def test_parse_timeout(self):
        """Timed-out messages map to TIMEOUT."""
        self._expect_error("Connection timed out", 1, ErrorType.TIMEOUT)

    def test_parse_syntax_error(self):
        """Shell syntax errors map to SYNTAX_ERROR."""
        self._expect_error("syntax error near unexpected token", 2, ErrorType.SYNTAX_ERROR)

    def test_parse_connection_refused(self):
        """Refused connections map to CONNECTION_REFUSED."""
        self._expect_error("Connection refused on port 8080", 1, ErrorType.CONNECTION_REFUSED)

    def test_parse_out_of_memory(self):
        """Out-of-memory messages map to OUT_OF_MEMORY."""
        self._expect_error("Out of memory: cannot allocate", 1, ErrorType.OUT_OF_MEMORY)

    def test_parse_disk_full(self):
        """No-space-left messages map to DISK_FULL."""
        self._expect_error("No space left on device", 1, ErrorType.DISK_FULL)

    def test_parse_already_exists(self):
        """Already-exists messages map to ALREADY_EXISTS."""
        self._expect_error("File already exists: /tmp/test", 1, ErrorType.ALREADY_EXISTS)

    def test_parse_invalid_argument(self):
        """Invalid-argument messages map to INVALID_ARGUMENT."""
        self._expect_error("invalid argument: --unknown-flag", 1, ErrorType.INVALID_ARGUMENT)

    def test_parse_network_error(self):
        """Unreachable-network messages map to NETWORK_ERROR."""
        self._expect_error("Network is unreachable", 1, ErrorType.NETWORK_ERROR)

    def test_parse_exit_code_126(self):
        """Exit code 126 with unrecognised text maps to PERMISSION_DENIED."""
        self._expect_error("some unknown error", 126, ErrorType.PERMISSION_DENIED)

    def test_parse_exit_code_127(self):
        """Exit code 127 with unrecognised text maps to NOT_FOUND."""
        self._expect_error("some unknown error", 127, ErrorType.NOT_FOUND)

    def test_parse_exit_code_130(self):
        """Exit code 130 (interrupted) with unrecognised text maps to TIMEOUT."""
        self._expect_error("some unknown error", 130, ErrorType.TIMEOUT)

    def test_parse_unknown_error(self):
        """Unrecognised text with a generic exit code maps to UNKNOWN."""
        self._expect_error("something went wrong", 1, ErrorType.UNKNOWN)

    def test_parse_long_message_truncated(self):
        """A 300-character output produces a message of at most 203 characters."""
        oversized = "x" * 300
        parsed = self.parser.parse(oversized, 0)
        assert len(parsed.message) <= 203  # 200 + "..."

    def test_parsed_output_to_dict(self):
        """to_dict() exposes exit code, error flag, type string and suggestions."""
        as_dict = self.parser.parse("permission denied", 1).to_dict()
        assert as_dict["exit_code"] == 1
        assert as_dict["is_error"] is True
        assert as_dict["error_type"] == "permission_denied"
        assert isinstance(as_dict["suggestions"], list)

    def test_parse_chinese_error_messages(self):
        """A Chinese permission-denied message is classified as well."""
        self._expect_error("权限不足: 无法访问", 1, ErrorType.PERMISSION_DENIED)

    def test_parse_multiline_output_message_is_last_line(self):
        """For multi-line output the message is the final line."""
        lines = "line1\nline2\nline3"
        parsed = self.parser.parse(lines, 0)
        assert parsed.message == "line3"
|
||||
|
||||
|
||||
# ============================================================
|
||||
# TerminalSession 测试
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestTerminalSession:
    """State management of a single TerminalSession."""

    def test_construction_default(self):
        """Defaults: cwd is the process cwd, env is a dict, history is empty."""
        term = TerminalSession(session_id="test")
        assert term.session_id == "test"
        assert term.cwd == os.getcwd()
        assert isinstance(term.env, dict)
        assert term.history == []

    def test_construction_custom_cwd(self):
        """An explicit ``cwd`` is stored on the session."""
        term = TerminalSession(session_id="test", cwd="/tmp")
        assert term.cwd == "/tmp"

    def test_construction_custom_env(self):
        """Variables passed via ``env`` are present in the session env."""
        term = TerminalSession(session_id="test", env={"FOO": "bar"})
        assert term.env.get("FOO") == "bar"

    def test_set_cwd(self):
        """set_cwd() replaces the working directory."""
        term = TerminalSession(session_id="test")
        term.set_cwd("/usr/local")
        assert term.cwd == "/usr/local"

    def test_set_env(self):
        """set_env() stores a single variable."""
        term = TerminalSession(session_id="test")
        term.set_env("MY_VAR", "hello")
        assert term.env.get("MY_VAR") == "hello"

    def test_update_env(self):
        """update_env() stores several variables at once."""
        term = TerminalSession(session_id="test")
        term.update_env({"A": "1", "B": "2"})
        assert term.env.get("A") == "1"
        assert term.env.get("B") == "2"

    def test_get_env_returns_copy(self):
        """Mutating the dict from get_env() leaves the session env untouched."""
        term = TerminalSession(session_id="test")
        snapshot = term.get_env()
        snapshot["HACKED"] = "yes"
        assert "HACKED" not in term.env

    def test_get_history_returns_copy(self):
        """get_history() returns a different object than the internal list."""
        term = TerminalSession(session_id="test")
        snapshot = term.get_history()
        assert snapshot is not term._history

    @pytest.mark.asyncio
    async def test_execute_simple_command(self):
        """A plain echo succeeds and its text is in the raw output."""
        term = TerminalSession(session_id="test")
        outcome = await term.execute("echo hello")
        assert outcome.exit_code == 0
        assert "hello" in outcome.raw_output

    @pytest.mark.asyncio
    async def test_execute_records_history(self):
        """Every executed command is appended to the history, in order."""
        term = TerminalSession(session_id="test")
        await term.execute("echo first")
        await term.execute("echo second")
        commands = [entry.command for entry in term.history]
        assert commands == ["echo first", "echo second"]

    @pytest.mark.asyncio
    async def test_cross_command_cwd(self):
        """cwd persists across commands: pwd after cd reports the new directory."""
        term = TerminalSession(session_id="test")
        await term.execute("cd /tmp")
        assert term.cwd == "/tmp"

        outcome = await term.execute("pwd")
        assert "/tmp" in outcome.raw_output

    @pytest.mark.asyncio
    async def test_cross_command_env(self):
        """env persists across commands: echo after export prints the value."""
        term = TerminalSession(session_id="test")
        await term.execute("export MY_TEST_VAR=hello123")
        assert term.env.get("MY_TEST_VAR") == "hello123"

        outcome = await term.execute("echo $MY_TEST_VAR")
        assert "hello123" in outcome.raw_output

    @pytest.mark.asyncio
    async def test_cd_relative_path(self):
        """A relative cd is resolved against the current cwd."""
        # Starts in /usr and changes into "local"; relies on /usr/local existing.
        term = TerminalSession(session_id="test", cwd="/usr")
        await term.execute("cd local")
        assert term.cwd == "/usr/local"

    @pytest.mark.asyncio
    async def test_cd_absolute_path(self):
        """An absolute cd replaces the cwd."""
        term = TerminalSession(session_id="test")
        await term.execute("cd /usr")
        assert term.cwd == "/usr"

    @pytest.mark.asyncio
    async def test_failed_command_no_state_update(self):
        """A failed cd leaves the cwd unchanged."""
        term = TerminalSession(session_id="test", cwd="/tmp")
        await term.execute("cd /nonexistent_dir_xyz")
        assert term.cwd == "/tmp"

    @pytest.mark.asyncio
    async def test_timeout(self):
        """A command exceeding its timeout reports exit code -1 and an error."""
        term = TerminalSession(session_id="test")
        outcome = await term.execute("sleep 10", timeout=0.5)
        assert outcome.exit_code == -1
        assert outcome.is_error is True

    @pytest.mark.asyncio
    async def test_max_history(self):
        """History is capped at ``max_history``; the oldest entries are dropped."""
        term = TerminalSession(session_id="test", max_history=3)
        for index in range(5):
            await term.execute(f"echo {index}")
        assert len(term.history) == 3
        oldest_kept = term.history[0]
        assert oldest_kept.command == "echo 2"

    def test_close(self):
        """close() raises nothing."""
        term = TerminalSession(session_id="test")
        term.close()  # must not raise
|
||||
|
||||
|
||||
# ============================================================
|
||||
# TerminalSessionManager 测试
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestTerminalSessionManager:
    """Session bookkeeping of TerminalSessionManager."""

    def test_get_or_create_new(self):
        """An unknown id creates a session carrying that id."""
        registry = TerminalSessionManager()
        created = registry.get_or_create("s1")
        assert created.session_id == "s1"

    def test_get_or_create_existing(self):
        """A known id returns the session with its earlier state."""
        registry = TerminalSessionManager()
        first = registry.get_or_create("s1")
        first.set_cwd("/tmp")
        again = registry.get_or_create("s1")
        assert again.cwd == "/tmp"

    def test_get_existing(self):
        """get() returns a session that was created before."""
        registry = TerminalSessionManager()
        registry.get_or_create("s1")
        found = registry.get("s1")
        assert found is not None

    def test_get_nonexistent(self):
        """get() returns None for an unknown id."""
        registry = TerminalSessionManager()
        assert registry.get("nonexistent") is None

    def test_remove(self):
        """After remove(), get() no longer finds the session."""
        registry = TerminalSessionManager()
        registry.get_or_create("s1")
        registry.remove("s1")
        assert registry.get("s1") is None

    def test_list_sessions(self):
        """list_sessions() reports the ids of all created sessions."""
        registry = TerminalSessionManager()
        for session_id in ("s1", "s2"):
            registry.get_or_create(session_id)
        assert sorted(registry.list_sessions()) == ["s1", "s2"]

    def test_has_session(self):
        """has_session() is True for known ids and False otherwise."""
        registry = TerminalSessionManager()
        registry.get_or_create("s1")
        assert registry.has_session("s1") is True
        assert registry.has_session("s2") is False

    def test_max_sessions_eviction(self):
        """Exceeding ``max_sessions`` evicts the oldest session."""
        registry = TerminalSessionManager(max_sessions=2)
        for session_id in ("s1", "s2", "s3"):
            # Creating "s3" is expected to push out "s1".
            registry.get_or_create(session_id)
        assert not registry.has_session("s1")
        assert registry.has_session("s2")
        assert registry.has_session("s3")

    def test_close_all(self):
        """close_all() leaves no sessions registered."""
        registry = TerminalSessionManager()
        for session_id in ("s1", "s2"):
            registry.get_or_create(session_id)
        registry.close_all()
        assert registry.list_sessions() == []
|
||||
|
||||
|
||||
# ============================================================
|
||||
# ShellTool 测试
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestShellToolConstruction:
    """Construction and introspection of ShellTool."""

    def test_default_construction(self):
        """Defaults: name "shell" and a schema requiring only ``command``."""
        shell = ShellTool()
        assert shell.name == "shell"
        schema = shell.input_schema
        assert schema is not None
        properties = shell.input_schema["properties"]
        assert "command" in properties
        assert "session_id" in properties
        assert shell.input_schema["required"] == ["command"]

    def test_custom_construction(self):
        """Explicit ``name`` and ``version`` are stored on the tool."""
        shell = ShellTool(name="my_shell", version="2.0.0")
        assert shell.name == "my_shell"
        assert shell.version == "2.0.0"

    def test_to_dict(self):
        """to_dict() includes the tool name and its input schema."""
        described = ShellTool().to_dict()
        assert described["name"] == "shell"
        assert "input_schema" in described

    def test_repr(self):
        """repr() mentions both the class name and the tool name."""
        text = repr(ShellTool())
        assert "ShellTool" in text
        assert "shell" in text
|
||||
|
||||
|
||||
class TestShellToolExecution:
    """Command execution through ShellTool, with and without sessions."""

    @pytest.mark.asyncio
    async def test_execute_simple_command(self):
        """Without a session id a plain echo succeeds and reports no session."""
        shell = ShellTool()
        response = await shell.execute(command="echo hello")
        assert response["exit_code"] == 0
        assert "hello" in response["output"]
        assert response["is_error"] is False
        assert response["session_id"] is None

    @pytest.mark.asyncio
    async def test_execute_missing_command(self):
        """Calling execute() without ``command`` yields an error with exit code 1."""
        shell = ShellTool()
        response = await shell.execute()
        assert response["is_error"] is True
        assert response["exit_code"] == 1

    @pytest.mark.asyncio
    async def test_execute_with_working_dir(self):
        """``working_dir`` sets the directory the command runs in."""
        shell = ShellTool()
        response = await shell.execute(command="pwd", working_dir="/tmp")
        assert response["exit_code"] == 0
        assert "/tmp" in response["output"]

    @pytest.mark.asyncio
    async def test_execute_with_session(self):
        """With a session id the result echoes that id back."""
        shell = ShellTool()
        response = await shell.execute(command="echo session_test", session_id="s1")
        assert response["exit_code"] == 0
        assert "session_test" in response["output"]
        assert response["session_id"] == "s1"

    @pytest.mark.asyncio
    async def test_session_preserves_cwd(self):
        """A cd in one call affects pwd in the next call of the same session."""
        shell = ShellTool()
        session = "cwd-test"
        await shell.execute(command="cd /tmp", session_id=session)
        response = await shell.execute(command="pwd", session_id=session)
        assert "/tmp" in response["output"]

    @pytest.mark.asyncio
    async def test_session_preserves_env(self):
        """An export in one call is visible in the next call of the same session."""
        shell = ShellTool()
        session = "env-test"
        await shell.execute(command="export SHELL_TEST_VAR=world", session_id=session)
        response = await shell.execute(command="echo $SHELL_TEST_VAR", session_id=session)
        assert "world" in response["output"]

    @pytest.mark.asyncio
    async def test_no_session_id_backward_compatible(self):
        """Omitting ``session_id`` keeps the pre-session behaviour."""
        shell = ShellTool()
        response = await shell.execute(command="echo no_session")
        assert response["exit_code"] == 0
        assert "no_session" in response["output"]
        assert response["session_id"] is None

    @pytest.mark.asyncio
    async def test_different_sessions_independent(self):
        """Two sessions keep separate working directories."""
        shell = ShellTool()
        await shell.execute(command="cd /tmp", session_id="s1")
        await shell.execute(command="cd /usr", session_id="s2")

        first = await shell.execute(command="pwd", session_id="s1")
        second = await shell.execute(command="pwd", session_id="s2")

        assert "/tmp" in first["output"]
        assert "/usr" in second["output"]
|
||||
|
||||
|
||||
class TestShellToolSecurity:
    """Dangerous-command gating and audit logging of ShellTool."""

    @pytest.mark.asyncio
    async def test_safe_command_allowed(self):
        """A harmless command runs without confirmation."""
        shell = ShellTool()
        response = await shell.execute(command="ls /tmp")
        assert response["exit_code"] == 0

    @pytest.mark.asyncio
    async def test_dangerous_command_blocked_without_callback(self):
        """Without a confirm callback a dangerous command is refused (exit 126)."""
        shell = ShellTool()
        response = await shell.execute(command="rm -rf /tmp/test")
        assert response["is_error"] is True
        assert response["exit_code"] == 126

    @pytest.mark.asyncio
    async def test_dangerous_command_confirmed(self):
        """An approving callback is consulted and the command is not refused."""
        approve = AsyncMock(return_value=True)
        shell = ShellTool(confirm_callback=approve)
        response = await shell.execute(command="rm -rf /tmp/nonexistent_test_dir")
        assert approve.called
        # The command itself may fail (the directory does not exist), but it
        # must not come back as a security refusal (error with exit code 126).
        refused = response["exit_code"] == 126 and response["is_error"]
        assert not refused

    @pytest.mark.asyncio
    async def test_dangerous_command_rejected_by_callback(self):
        """A rejecting callback causes the command to be refused (exit 126)."""
        reject = AsyncMock(return_value=False)
        shell = ShellTool(confirm_callback=reject)
        response = await shell.execute(command="rm -rf /tmp/test")
        assert response["is_error"] is True
        assert response["exit_code"] == 126

    @pytest.mark.asyncio
    async def test_audit_log_recorded(self):
        """An executed command is the first entry of the audit log."""
        shell = ShellTool()
        await shell.execute(command="echo audit_test")
        assert len(shell.audit_log) > 0
        first_entry = shell.audit_log[0]
        assert first_entry["command"] == "echo audit_test"

    @pytest.mark.asyncio
    async def test_blocked_command_in_audit_log(self):
        """A refused command leaves an audit entry flagged as blocked."""
        shell = ShellTool()
        await shell.execute(command="rm -rf /tmp/test")
        flagged = [entry for entry in shell.audit_log if entry.get("blocked")]
        assert len(flagged) > 0

    @pytest.mark.asyncio
    async def test_git_push_force_is_dangerous(self):
        """``git push --force`` is treated as dangerous and refused."""
        shell = ShellTool()
        response = await shell.execute(command="git push --force origin main")
        assert response["is_error"] is True
        assert response["exit_code"] == 126

    @pytest.mark.asyncio
    async def test_git_status_is_safe(self):
        """``git status`` is not refused by the safety check."""
        shell = ShellTool()
        response = await shell.execute(command="git status")
        # It may fail outside a git repository, but never with the refusal code.
        assert response["exit_code"] != 126
|
||||
|
||||
|
||||
class TestShellToolOutputParsing:
    """Integration of output parsing into ShellTool results."""

    @pytest.mark.asyncio
    async def test_error_output_structured(self):
        """A failing command yields an error type and a suggestions list."""
        shell = ShellTool()
        response = await shell.execute(command="ls /nonexistent_dir_xyz_12345")
        assert response["is_error"] is True
        accepted_types = ("not_found", "unknown")
        assert response["error_type"] in accepted_types
        assert isinstance(response["suggestions"], list)

    @pytest.mark.asyncio
    async def test_success_output_not_error(self):
        """A successful command is not flagged and has error type "none"."""
        shell = ShellTool()
        response = await shell.execute(command="echo success")
        assert response["is_error"] is False
        assert response["error_type"] == "none"
|
||||
|
||||
|
||||
class TestShellToolSessionManager:
    """Access to the session manager owned by ShellTool."""

    def test_session_manager_accessible(self):
        """The tool exposes a non-None ``session_manager``."""
        shell = ShellTool()
        assert shell.session_manager is not None

    @pytest.mark.asyncio
    async def test_session_created_on_first_use(self):
        """A session appears in the manager after its id is first used."""
        shell = ShellTool()
        session = "new-session"
        assert not shell.session_manager.has_session(session)
        await shell.execute(command="echo test", session_id=session)
        assert shell.session_manager.has_session(session)
|
||||
Loading…
Reference in New Issue