feat(phase2): implement self-evolution and smart terminal (U6-U8)
- U6: PitfallDetector - detect historical failure patterns and warn - U7: PathOptimizer - discover and update optimal execution paths - U8: TerminalSession - session state, PTY interactive, output parsing 160 new tests passing. ShellTool enhanced with session_id support.
This commit is contained in:
parent
fd4a811929
commit
e3d4f811dd
|
|
@ -0,0 +1,259 @@
|
||||||
|
"""PathOptimizer - 执行路径优化器
|
||||||
|
|
||||||
|
发现更优执行路径时自动更新经验库中的推荐路径。
|
||||||
|
|
||||||
|
核心逻辑:
|
||||||
|
1. 对比新路径与现有最优路径(综合耗时和成功率)
|
||||||
|
2. 新路径成功率更高 → 更新推荐路径
|
||||||
|
3. 成功率相近但耗时更短 → 更新推荐路径
|
||||||
|
4. 样本量不足 → 不更新,记录待观察
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from agentkit.evolution.experience_store import InMemoryExperienceStore
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExecutionPath:
|
||||||
|
"""执行路径数据模型
|
||||||
|
|
||||||
|
记录特定任务类型的执行路径信息,用于路径优化比较。
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
path_id: 路径唯一标识
|
||||||
|
task_type: 任务类型
|
||||||
|
steps: 执行步骤名称列表
|
||||||
|
total_duration: 总耗时(秒)
|
||||||
|
success_rate: 成功率(0.0 ~ 1.0)
|
||||||
|
sample_count: 样本数量
|
||||||
|
is_recommended: 是否为当前推荐路径
|
||||||
|
created_at: 创建时间
|
||||||
|
"""
|
||||||
|
|
||||||
|
path_id: str = ""
|
||||||
|
task_type: str = ""
|
||||||
|
steps: list[str] = field(default_factory=list)
|
||||||
|
total_duration: float = 0.0
|
||||||
|
success_rate: float = 0.0
|
||||||
|
sample_count: int = 0
|
||||||
|
is_recommended: bool = False
|
||||||
|
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PathUpdateResult:
|
||||||
|
"""路径更新结果
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
updated: 是否更新了推荐路径
|
||||||
|
old_path: 更新前的推荐路径(未更新时为 None)
|
||||||
|
new_path: 更新后的推荐路径(未更新时为 None)
|
||||||
|
reason: 更新/未更新的原因说明
|
||||||
|
"""
|
||||||
|
|
||||||
|
updated: bool = False
|
||||||
|
old_path: ExecutionPath | None = None
|
||||||
|
new_path: ExecutionPath | None = None
|
||||||
|
reason: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class PathOptimizer:
|
||||||
|
"""执行路径优化器
|
||||||
|
|
||||||
|
对比新路径与现有最优路径,决定是否更新推荐路径。
|
||||||
|
可独立使用,也可集成到 PlanChecker 的复盘中。
|
||||||
|
|
||||||
|
更新策略:
|
||||||
|
1. 新路径成功率 > 现有成功率 + success_rate_threshold → 更新
|
||||||
|
2. 成功率相近(差值 ≤ threshold)但耗时显著更短
|
||||||
|
(duration 改善比例 > duration_improvement_threshold)→ 更新
|
||||||
|
3. 样本量不足(< min_sample_count)→ 不更新
|
||||||
|
4. 其他情况 → 保留现有推荐路径
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
experience_store: InMemoryExperienceStore | None = None,
|
||||||
|
min_sample_count: int = 3,
|
||||||
|
success_rate_threshold: float = 0.05,
|
||||||
|
duration_improvement_threshold: float = 0.2,
|
||||||
|
):
|
||||||
|
"""初始化 PathOptimizer
|
||||||
|
|
||||||
|
Args:
|
||||||
|
experience_store: 经验存储实例(可选)
|
||||||
|
min_sample_count: 最小样本量,低于此值不做决策
|
||||||
|
success_rate_threshold: 成功率提升阈值,超过此值视为显著提升
|
||||||
|
duration_improvement_threshold: 耗时改善比例阈值,超过此值视为显著改善
|
||||||
|
"""
|
||||||
|
self._experience_store = experience_store
|
||||||
|
self._min_sample_count = min_sample_count
|
||||||
|
self._success_rate_threshold = success_rate_threshold
|
||||||
|
self._duration_improvement_threshold = duration_improvement_threshold
|
||||||
|
self._recommended_paths: dict[str, ExecutionPath] = {}
|
||||||
|
self._pending_paths: dict[str, list[ExecutionPath]] = {}
|
||||||
|
|
||||||
|
def get_recommended_path(self, task_type: str) -> ExecutionPath | None:
|
||||||
|
"""获取指定任务类型的当前推荐路径
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_type: 任务类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
推荐路径,若无则返回 None
|
||||||
|
"""
|
||||||
|
return self._recommended_paths.get(task_type)
|
||||||
|
|
||||||
|
async def evaluate_and_update(
|
||||||
|
self,
|
||||||
|
task_type: str,
|
||||||
|
new_path: ExecutionPath,
|
||||||
|
) -> PathUpdateResult:
|
||||||
|
"""评估新路径并决定是否更新推荐路径
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_type: 任务类型
|
||||||
|
new_path: 新的执行路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
路径更新结果
|
||||||
|
"""
|
||||||
|
# 确保新路径有 path_id
|
||||||
|
if not new_path.path_id:
|
||||||
|
new_path.path_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
new_path.task_type = task_type
|
||||||
|
|
||||||
|
# 样本量不足 → 不更新,记录待观察
|
||||||
|
if new_path.sample_count < self._min_sample_count:
|
||||||
|
self._pending_paths.setdefault(task_type, []).append(new_path)
|
||||||
|
reason = (
|
||||||
|
f"样本量不足({new_path.sample_count} < {self._min_sample_count}),"
|
||||||
|
f"记录待观察"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Path not updated for '{task_type}': {reason}"
|
||||||
|
)
|
||||||
|
return PathUpdateResult(
|
||||||
|
updated=False,
|
||||||
|
old_path=None,
|
||||||
|
new_path=new_path,
|
||||||
|
reason=reason,
|
||||||
|
)
|
||||||
|
|
||||||
|
current = self._recommended_paths.get(task_type)
|
||||||
|
|
||||||
|
# 无现有推荐路径 → 直接设为推荐
|
||||||
|
if current is None:
|
||||||
|
new_path.is_recommended = True
|
||||||
|
self._recommended_paths[task_type] = new_path
|
||||||
|
reason = "无现有推荐路径,直接设为推荐"
|
||||||
|
logger.info(f"Path set as recommended for '{task_type}': {reason}")
|
||||||
|
return PathUpdateResult(
|
||||||
|
updated=True,
|
||||||
|
old_path=None,
|
||||||
|
new_path=new_path,
|
||||||
|
reason=reason,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 比较新路径与现有推荐路径
|
||||||
|
return self._compare_and_decide(task_type, current, new_path)
|
||||||
|
|
||||||
|
def _compare_and_decide(
|
||||||
|
self,
|
||||||
|
task_type: str,
|
||||||
|
current: ExecutionPath,
|
||||||
|
new: ExecutionPath,
|
||||||
|
) -> PathUpdateResult:
|
||||||
|
"""比较新旧路径并决策
|
||||||
|
|
||||||
|
比较逻辑:
|
||||||
|
1. 新路径成功率 > 现有成功率 + threshold → 更新
|
||||||
|
2. 成功率相近(差值 ≤ threshold)且新耗时显著更短 → 更新
|
||||||
|
3. 其他 → 保留现有
|
||||||
|
"""
|
||||||
|
sr_diff = new.success_rate - current.success_rate
|
||||||
|
|
||||||
|
# 条件 1:成功率显著提升
|
||||||
|
if sr_diff > self._success_rate_threshold:
|
||||||
|
return self._apply_update(
|
||||||
|
task_type, current, new,
|
||||||
|
f"成功率显著提升({new.success_rate:.2f} > {current.success_rate:.2f},"
|
||||||
|
f"提升 {sr_diff:.2f})",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 条件 2:成功率相近但耗时显著更短
|
||||||
|
if abs(sr_diff) <= self._success_rate_threshold:
|
||||||
|
if current.total_duration > 0:
|
||||||
|
duration_improvement = (
|
||||||
|
(current.total_duration - new.total_duration) / current.total_duration
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
new.total_duration < current.total_duration
|
||||||
|
and duration_improvement > self._duration_improvement_threshold
|
||||||
|
):
|
||||||
|
return self._apply_update(
|
||||||
|
task_type, current, new,
|
||||||
|
f"成功率相近({new.success_rate:.2f} vs {current.success_rate:.2f}),"
|
||||||
|
f"耗时显著更短({new.total_duration:.1f}s vs {current.total_duration:.1f}s,"
|
||||||
|
f"改善 {duration_improvement:.1%})",
|
||||||
|
)
|
||||||
|
elif current.total_duration == 0 and new.total_duration > 0:
|
||||||
|
# 现有路径耗时为 0(不太可能),不更新
|
||||||
|
pass
|
||||||
|
elif current.total_duration == 0 and new.total_duration == 0:
|
||||||
|
# 两者耗时均为 0,不更新
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 条件 3:无明显优势 → 保留现有
|
||||||
|
reason = (
|
||||||
|
f"新路径无明显优势(成功率 {new.success_rate:.2f} vs {current.success_rate:.2f},"
|
||||||
|
f"耗时 {new.total_duration:.1f}s vs {current.total_duration:.1f}s),保留现有推荐路径"
|
||||||
|
)
|
||||||
|
logger.info(f"Path not updated for '{task_type}': {reason}")
|
||||||
|
return PathUpdateResult(
|
||||||
|
updated=False,
|
||||||
|
old_path=current,
|
||||||
|
new_path=new,
|
||||||
|
reason=reason,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _apply_update(
|
||||||
|
self,
|
||||||
|
task_type: str,
|
||||||
|
old: ExecutionPath,
|
||||||
|
new: ExecutionPath,
|
||||||
|
reason: str,
|
||||||
|
) -> PathUpdateResult:
|
||||||
|
"""应用路径更新"""
|
||||||
|
old.is_recommended = False
|
||||||
|
new.is_recommended = True
|
||||||
|
self._recommended_paths[task_type] = new
|
||||||
|
logger.info(f"Path updated for '{task_type}': {reason}")
|
||||||
|
return PathUpdateResult(
|
||||||
|
updated=True,
|
||||||
|
old_path=old,
|
||||||
|
new_path=new,
|
||||||
|
reason=reason,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_pending_paths(self, task_type: str) -> list[ExecutionPath]:
|
||||||
|
"""获取指定任务类型的待观察路径
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_type: 任务类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
待观察路径列表
|
||||||
|
"""
|
||||||
|
return list(self._pending_paths.get(task_type, []))
|
||||||
|
|
@ -0,0 +1,388 @@
|
||||||
|
"""PitfallDetector - 任务避坑预警
|
||||||
|
|
||||||
|
新任务启动时检索历史失败经验,匹配当前计划步骤,自动预警。
|
||||||
|
基于 ExperienceStore 中存储的失败经验,将失败步骤与当前计划步骤
|
||||||
|
进行关键词匹配,计算失败率并按严重程度返回预警列表。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Protocol
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class WarningLevel(str, Enum):
|
||||||
|
"""预警级别"""
|
||||||
|
|
||||||
|
HIGH = "high"
|
||||||
|
MEDIUM = "medium"
|
||||||
|
LOW = "low"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PitfallWarning:
|
||||||
|
"""避坑预警
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
step_name: 计划步骤名称
|
||||||
|
warning_level: 预警级别(HIGH/MEDIUM/LOW)
|
||||||
|
failure_rate: 历史失败率(0.0 ~ 1.0)
|
||||||
|
historical_failures: 历史失败原因列表
|
||||||
|
suggestion: 优化建议
|
||||||
|
"""
|
||||||
|
|
||||||
|
step_name: str
|
||||||
|
warning_level: WarningLevel
|
||||||
|
failure_rate: float
|
||||||
|
historical_failures: list[str] = field(default_factory=list)
|
||||||
|
suggestion: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class ExperienceStoreProtocol(Protocol):
|
||||||
|
"""ExperienceStore 协议接口,用于类型标注"""
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
top_k: int = 5,
|
||||||
|
task_type: str | None = None,
|
||||||
|
search_multiplier: int = 5,
|
||||||
|
) -> list[Any]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# 预警级别阈值
|
||||||
|
_HIGH_THRESHOLD = 0.5
|
||||||
|
_MEDIUM_THRESHOLD = 0.2
|
||||||
|
|
||||||
|
|
||||||
|
class PitfallDetector:
|
||||||
|
"""避坑检测器
|
||||||
|
|
||||||
|
新任务启动时检索历史失败经验,匹配当前计划步骤,自动预警。
|
||||||
|
|
||||||
|
使用方式:
|
||||||
|
detector = PitfallDetector(experience_store)
|
||||||
|
warnings = await detector.check_pitfalls(
|
||||||
|
task_type="code_review",
|
||||||
|
planned_steps=[plan_step1, plan_step2, ...],
|
||||||
|
)
|
||||||
|
|
||||||
|
匹配逻辑:
|
||||||
|
1. 检索同类任务的失败经验
|
||||||
|
2. 从失败经验中提取失败步骤
|
||||||
|
3. 将失败步骤与当前计划步骤进行关键词匹配
|
||||||
|
4. 计算失败率并分配预警级别
|
||||||
|
|
||||||
|
预警级别:
|
||||||
|
- HIGH: failure_rate >= 0.5(历史高失败率步骤)
|
||||||
|
- MEDIUM: failure_rate >= 0.2(有失败记录但频率低)
|
||||||
|
- LOW: 有任何失败记录
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
experience_store: ExperienceStoreProtocol,
|
||||||
|
similarity_threshold: float = 0.3,
|
||||||
|
max_search_results: int = 50,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
experience_store: 经验存储实例(ExperienceStore 或 InMemoryExperienceStore)
|
||||||
|
similarity_threshold: 步骤名称关键词匹配的最小相似度阈值
|
||||||
|
max_search_results: 从经验存储检索的最大结果数
|
||||||
|
"""
|
||||||
|
self._store = experience_store
|
||||||
|
self._similarity_threshold = similarity_threshold
|
||||||
|
self._max_search_results = max_search_results
|
||||||
|
|
||||||
|
async def check_pitfalls(
|
||||||
|
self,
|
||||||
|
task_type: str,
|
||||||
|
planned_steps: list[Any],
|
||||||
|
) -> list[PitfallWarning]:
|
||||||
|
"""检查计划步骤中的潜在陷阱
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task_type: 任务类型
|
||||||
|
planned_steps: 计划步骤列表(PlanStep 对象或具有 name/description 属性的对象)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
按严重程度排序的预警列表(HIGH → MEDIUM → LOW)
|
||||||
|
"""
|
||||||
|
if not planned_steps:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 1. 检索同类任务的所有经验(包含成功和失败,用于计算步骤级失败率)
|
||||||
|
all_experiences = await self._search_experiences(task_type)
|
||||||
|
if not all_experiences:
|
||||||
|
logger.debug(f"No experiences found for task_type={task_type}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 2. 从经验中提取步骤级别的失败统计
|
||||||
|
step_failure_stats = self._extract_step_failure_stats(all_experiences)
|
||||||
|
|
||||||
|
# 3. 匹配当前计划步骤并生成预警
|
||||||
|
warnings = self._match_and_warn(planned_steps, step_failure_stats)
|
||||||
|
|
||||||
|
# 4. 按严重程度排序(HIGH → MEDIUM → LOW),同级别按失败率降序
|
||||||
|
warnings.sort(key=lambda w: (_warning_level_order(w.warning_level), -w.failure_rate))
|
||||||
|
|
||||||
|
if warnings:
|
||||||
|
logger.info(
|
||||||
|
f"PitfallDetector found {len(warnings)} warnings for task_type={task_type}: "
|
||||||
|
f"{sum(1 for w in warnings if w.warning_level == WarningLevel.HIGH)} HIGH, "
|
||||||
|
f"{sum(1 for w in warnings if w.warning_level == WarningLevel.MEDIUM)} MEDIUM, "
|
||||||
|
f"{sum(1 for w in warnings if w.warning_level == WarningLevel.LOW)} LOW"
|
||||||
|
)
|
||||||
|
|
||||||
|
return warnings
|
||||||
|
|
||||||
|
async def _search_experiences(self, task_type: str) -> list[Any]:
|
||||||
|
"""检索指定任务类型的所有经验(包含成功和失败)"""
|
||||||
|
try:
|
||||||
|
results = await self._store.search(
|
||||||
|
query=task_type,
|
||||||
|
top_k=self._max_search_results,
|
||||||
|
task_type=task_type,
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to search experiences for pitfall detection: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _extract_step_failure_stats(
|
||||||
|
self, failed_experiences: list[Any]
|
||||||
|
) -> dict[str, _StepFailureStats]:
|
||||||
|
"""从失败经验中提取步骤级别的失败统计
|
||||||
|
|
||||||
|
steps_summary 可以是 str 或 list[dict]:
|
||||||
|
- list[dict]: 每个字典包含 step_name, outcome, duration_seconds, error
|
||||||
|
- str: 退化为整体统计
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
以步骤名称为 key 的失败统计字典
|
||||||
|
"""
|
||||||
|
stats: dict[str, _StepFailureStats] = {}
|
||||||
|
|
||||||
|
for exp in failed_experiences:
|
||||||
|
steps_summary = exp.steps_summary
|
||||||
|
|
||||||
|
# 如果 steps_summary 是字符串,无法提取步骤级信息
|
||||||
|
if isinstance(steps_summary, str):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not isinstance(steps_summary, list):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for step in steps_summary:
|
||||||
|
if not isinstance(step, dict):
|
||||||
|
continue
|
||||||
|
|
||||||
|
step_name = step.get("step_name", "")
|
||||||
|
if not step_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
outcome = step.get("outcome", "")
|
||||||
|
error = step.get("error", "")
|
||||||
|
|
||||||
|
if step_name not in stats:
|
||||||
|
stats[step_name] = _StepFailureStats(
|
||||||
|
step_name=step_name,
|
||||||
|
total_occurrences=0,
|
||||||
|
failure_occurrences=0,
|
||||||
|
failure_reasons=[],
|
||||||
|
optimization_tips=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
s = stats[step_name]
|
||||||
|
s.total_occurrences += 1
|
||||||
|
|
||||||
|
if outcome in ("failure", "failed", "error"):
|
||||||
|
s.failure_occurrences += 1
|
||||||
|
if error:
|
||||||
|
s.failure_reasons.append(error)
|
||||||
|
|
||||||
|
# 收集优化建议
|
||||||
|
if hasattr(exp, "optimization_tips") and exp.optimization_tips:
|
||||||
|
for step_name, s in stats.items():
|
||||||
|
s.optimization_tips.extend(exp.optimization_tips)
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def _match_and_warn(
|
||||||
|
self,
|
||||||
|
planned_steps: list[Any],
|
||||||
|
step_failure_stats: dict[str, _StepFailureStats],
|
||||||
|
) -> list[PitfallWarning]:
|
||||||
|
"""将计划步骤与失败统计进行匹配,生成预警"""
|
||||||
|
warnings: list[PitfallWarning] = []
|
||||||
|
|
||||||
|
for step in planned_steps:
|
||||||
|
step_name = getattr(step, "name", "")
|
||||||
|
step_description = getattr(step, "description", "")
|
||||||
|
|
||||||
|
if not step_name:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 查找最佳匹配的失败步骤
|
||||||
|
best_match: _StepFailureStats | None = None
|
||||||
|
best_similarity = 0.0
|
||||||
|
|
||||||
|
for stats_step_name, stats in step_failure_stats.items():
|
||||||
|
similarity = _compute_name_similarity(
|
||||||
|
step_name, step_description, stats_step_name
|
||||||
|
)
|
||||||
|
if similarity > best_similarity:
|
||||||
|
best_similarity = similarity
|
||||||
|
best_match = stats
|
||||||
|
|
||||||
|
# 相似度低于阈值,跳过
|
||||||
|
if best_match is None or best_similarity < self._similarity_threshold:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 计算失败率
|
||||||
|
failure_rate = (
|
||||||
|
best_match.failure_occurrences / best_match.total_occurrences
|
||||||
|
if best_match.total_occurrences > 0
|
||||||
|
else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# 分配预警级别
|
||||||
|
warning_level = _determine_warning_level(failure_rate)
|
||||||
|
|
||||||
|
# 生成建议
|
||||||
|
suggestion = _build_suggestion(best_match, failure_rate)
|
||||||
|
|
||||||
|
warning = PitfallWarning(
|
||||||
|
step_name=step_name,
|
||||||
|
warning_level=warning_level,
|
||||||
|
failure_rate=round(failure_rate, 4),
|
||||||
|
historical_failures=best_match.failure_reasons[:5], # 最多保留 5 条
|
||||||
|
suggestion=suggestion,
|
||||||
|
)
|
||||||
|
warnings.append(warning)
|
||||||
|
|
||||||
|
return warnings
|
||||||
|
|
||||||
|
|
||||||
|
# ── 内部辅助类 ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _StepFailureStats:
|
||||||
|
"""步骤级别的失败统计(内部使用)"""
|
||||||
|
|
||||||
|
step_name: str
|
||||||
|
total_occurrences: int
|
||||||
|
failure_occurrences: int
|
||||||
|
failure_reasons: list[str]
|
||||||
|
optimization_tips: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
# ── 辅助函数 ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_name_similarity(
|
||||||
|
step_name: str, step_description: str, historical_step_name: str
|
||||||
|
) -> float:
|
||||||
|
"""计算步骤名称的关键词重叠相似度
|
||||||
|
|
||||||
|
基于关键词集合的 Jaccard 相似度,同时考虑 step_name 和 step_description。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
step_name: 当前计划步骤名称
|
||||||
|
step_description: 当前计划步骤描述
|
||||||
|
historical_step_name: 历史步骤名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
相似度分数(0.0 ~ 1.0)
|
||||||
|
"""
|
||||||
|
# 提取关键词:将名称拆分为词,过滤掉常见停用词
|
||||||
|
current_keywords = _extract_keywords(f"{step_name} {step_description}")
|
||||||
|
historical_keywords = _extract_keywords(historical_step_name)
|
||||||
|
|
||||||
|
if not current_keywords or not historical_keywords:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Jaccard 相似度
|
||||||
|
intersection = current_keywords & historical_keywords
|
||||||
|
union = current_keywords | historical_keywords
|
||||||
|
|
||||||
|
if not union:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
return len(intersection) / len(union)
|
||||||
|
|
||||||
|
|
||||||
|
_STOP_WORDS = frozenset({
|
||||||
|
"a", "an", "the", "and", "or", "but", "in", "on", "at", "to", "for",
|
||||||
|
"of", "with", "by", "from", "is", "are", "was", "were", "be", "been",
|
||||||
|
"being", "have", "has", "had", "do", "does", "did", "will", "would",
|
||||||
|
"could", "should", "may", "might", "can", "shall", "not", "no",
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_keywords(text: str) -> frozenset[str]:
|
||||||
|
"""从文本中提取关键词集合
|
||||||
|
|
||||||
|
转小写、按空白/下划线/连字符拆分、过滤停用词和单字符词。
|
||||||
|
"""
|
||||||
|
# 统一分隔符
|
||||||
|
normalized = text.lower().replace("_", " ").replace("-", " ")
|
||||||
|
words = normalized.split()
|
||||||
|
return frozenset(
|
||||||
|
w for w in words
|
||||||
|
if len(w) > 1 and w not in _STOP_WORDS
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _determine_warning_level(failure_rate: float) -> WarningLevel:
|
||||||
|
"""根据失败率确定预警级别
|
||||||
|
|
||||||
|
- HIGH: failure_rate >= 0.5
|
||||||
|
- MEDIUM: failure_rate >= 0.2
|
||||||
|
- LOW: 有任何失败记录
|
||||||
|
"""
|
||||||
|
if failure_rate >= _HIGH_THRESHOLD:
|
||||||
|
return WarningLevel.HIGH
|
||||||
|
if failure_rate >= _MEDIUM_THRESHOLD:
|
||||||
|
return WarningLevel.MEDIUM
|
||||||
|
return WarningLevel.LOW
|
||||||
|
|
||||||
|
|
||||||
|
def _warning_level_order(level: WarningLevel) -> int:
|
||||||
|
"""预警级别排序值(越小越严重)"""
|
||||||
|
return {
|
||||||
|
WarningLevel.HIGH: 0,
|
||||||
|
WarningLevel.MEDIUM: 1,
|
||||||
|
WarningLevel.LOW: 2,
|
||||||
|
}[level]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_suggestion(stats: _StepFailureStats, failure_rate: float) -> str:
|
||||||
|
"""根据失败统计生成优化建议"""
|
||||||
|
parts: list[str] = []
|
||||||
|
|
||||||
|
if failure_rate >= _HIGH_THRESHOLD:
|
||||||
|
parts.append(f"该步骤历史失败率高达 {failure_rate:.0%},建议重点关注")
|
||||||
|
elif failure_rate >= _MEDIUM_THRESHOLD:
|
||||||
|
parts.append(f"该步骤历史失败率为 {failure_rate:.0%},需注意风险")
|
||||||
|
else:
|
||||||
|
parts.append(f"该步骤有少量失败记录(失败率 {failure_rate:.0%})")
|
||||||
|
|
||||||
|
if stats.failure_reasons:
|
||||||
|
unique_reasons = list(dict.fromkeys(stats.failure_reasons))[:3]
|
||||||
|
reasons_str = "、".join(unique_reasons)
|
||||||
|
parts.append(f"常见失败原因:{reasons_str}")
|
||||||
|
|
||||||
|
if stats.optimization_tips:
|
||||||
|
unique_tips = list(dict.fromkeys(stats.optimization_tips))[:2]
|
||||||
|
tips_str = ";".join(unique_tips)
|
||||||
|
parts.append(f"建议:{tips_str}")
|
||||||
|
|
||||||
|
return "。".join(parts)
|
||||||
|
|
@ -9,6 +9,10 @@ from agentkit.tools.composition import SequentialChain, ParallelFanOut, DynamicS
|
||||||
from agentkit.tools.web_crawl import WebCrawlTool
|
from agentkit.tools.web_crawl import WebCrawlTool
|
||||||
from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool
|
from agentkit.tools.schema_tools import SchemaExtractTool, SchemaGenerateTool
|
||||||
from agentkit.tools.baidu_search import BaiduSearchTool
|
from agentkit.tools.baidu_search import BaiduSearchTool
|
||||||
|
from agentkit.tools.shell import ShellTool
|
||||||
|
from agentkit.tools.terminal_session import TerminalSession, TerminalSessionManager
|
||||||
|
from agentkit.tools.pty_session import PTYSession
|
||||||
|
from agentkit.tools.output_parser import OutputParser, ParsedOutput, ErrorType
|
||||||
|
|
||||||
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
# Conditional import: HeadroomRetrieveTool requires HeadroomCompressor
|
||||||
try:
|
try:
|
||||||
|
|
@ -30,4 +34,11 @@ __all__ = [
|
||||||
"SchemaGenerateTool",
|
"SchemaGenerateTool",
|
||||||
"BaiduSearchTool",
|
"BaiduSearchTool",
|
||||||
"HeadroomRetrieveTool",
|
"HeadroomRetrieveTool",
|
||||||
|
"ShellTool",
|
||||||
|
"TerminalSession",
|
||||||
|
"TerminalSessionManager",
|
||||||
|
"PTYSession",
|
||||||
|
"OutputParser",
|
||||||
|
"ParsedOutput",
|
||||||
|
"ErrorType",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,294 @@
|
||||||
|
"""OutputParser - 结构化解析命令输出
|
||||||
|
|
||||||
|
将命令行输出解析为结构化格式,包含错误类型识别、退出码含义和可操作建议。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class ErrorType(Enum):
|
||||||
|
"""命令输出错误类型"""
|
||||||
|
|
||||||
|
NONE = "none"
|
||||||
|
PERMISSION_DENIED = "permission_denied"
|
||||||
|
NOT_FOUND = "not_found"
|
||||||
|
TIMEOUT = "timeout"
|
||||||
|
SYNTAX_ERROR = "syntax_error"
|
||||||
|
CONNECTION_REFUSED = "connection_refused"
|
||||||
|
OUT_OF_MEMORY = "out_of_memory"
|
||||||
|
DISK_FULL = "disk_full"
|
||||||
|
ALREADY_EXISTS = "already_exists"
|
||||||
|
INVALID_ARGUMENT = "invalid_argument"
|
||||||
|
PROCESS_NOT_FOUND = "process_not_found"
|
||||||
|
NETWORK_ERROR = "network_error"
|
||||||
|
UNKNOWN = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ParsedOutput:
|
||||||
|
"""结构化命令输出
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
exit_code: 命令退出码
|
||||||
|
is_error: 是否为错误输出
|
||||||
|
error_type: 错误类型(仅当 is_error=True 时有值)
|
||||||
|
message: 输出消息摘要
|
||||||
|
raw_output: 原始输出文本
|
||||||
|
suggestions: 可操作建议列表
|
||||||
|
"""
|
||||||
|
|
||||||
|
exit_code: int
|
||||||
|
is_error: bool
|
||||||
|
error_type: ErrorType = ErrorType.NONE
|
||||||
|
message: str = ""
|
||||||
|
raw_output: str = ""
|
||||||
|
suggestions: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"exit_code": self.exit_code,
|
||||||
|
"is_error": self.is_error,
|
||||||
|
"error_type": self.error_type.value,
|
||||||
|
"message": self.message,
|
||||||
|
"suggestions": self.suggestions,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 错误模式匹配规则:(pattern, error_type, message_template, suggestions)
|
||||||
|
_ERROR_PATTERNS: list[tuple[re.Pattern, ErrorType, str, list[str]]] = [
|
||||||
|
(
|
||||||
|
re.compile(r"permission denied|access denied|权限不足|拒绝访问", re.IGNORECASE),
|
||||||
|
ErrorType.PERMISSION_DENIED,
|
||||||
|
"权限不足",
|
||||||
|
[
|
||||||
|
"尝试使用 sudo 执行该命令",
|
||||||
|
"检查文件/目录权限: ls -la <path>",
|
||||||
|
"确认当前用户是否有所需权限",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
re.compile(
|
||||||
|
r"not found|no such file|no such directory|找不到|不存在|无法找到",
|
||||||
|
re.IGNORECASE,
|
||||||
|
),
|
||||||
|
ErrorType.NOT_FOUND,
|
||||||
|
"文件或目录不存在",
|
||||||
|
[
|
||||||
|
"检查路径拼写是否正确",
|
||||||
|
"使用 ls 确认文件/目录是否存在",
|
||||||
|
"检查是否在正确的工作目录下",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
re.compile(r"timed?\s*out|timeout|超时|时间超限", re.IGNORECASE),
|
||||||
|
ErrorType.TIMEOUT,
|
||||||
|
"命令执行超时",
|
||||||
|
[
|
||||||
|
"增加超时时间",
|
||||||
|
"检查网络连接是否正常",
|
||||||
|
"检查目标服务是否可达",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
re.compile(
|
||||||
|
r"syntax error|syntaxerror|parse error|语法错误|解析错误",
|
||||||
|
re.IGNORECASE,
|
||||||
|
),
|
||||||
|
ErrorType.SYNTAX_ERROR,
|
||||||
|
"语法错误",
|
||||||
|
[
|
||||||
|
"检查命令语法是否正确",
|
||||||
|
"使用 --help 查看命令用法",
|
||||||
|
"检查引号和特殊字符是否正确转义",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
re.compile(
|
||||||
|
r"connection refused|连接被拒绝|无法连接|ECONNREFUSED",
|
||||||
|
re.IGNORECASE,
|
||||||
|
),
|
||||||
|
ErrorType.CONNECTION_REFUSED,
|
||||||
|
"连接被拒绝",
|
||||||
|
[
|
||||||
|
"检查目标服务是否已启动",
|
||||||
|
"确认端口号是否正确",
|
||||||
|
"检查防火墙设置是否阻止了连接",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
re.compile(
|
||||||
|
r"out of memory|oom|cannot allocate|内存不足|内存溢出",
|
||||||
|
re.IGNORECASE,
|
||||||
|
),
|
||||||
|
ErrorType.OUT_OF_MEMORY,
|
||||||
|
"内存不足",
|
||||||
|
[
|
||||||
|
"释放不必要的内存占用",
|
||||||
|
"增加系统可用内存",
|
||||||
|
"检查是否有内存泄漏",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
re.compile(
|
||||||
|
r"no space left|disk full|磁盘已满|空间不足|ENOSPC",
|
||||||
|
re.IGNORECASE,
|
||||||
|
),
|
||||||
|
ErrorType.DISK_FULL,
|
||||||
|
"磁盘空间不足",
|
||||||
|
[
|
||||||
|
"清理不必要的文件: du -sh * | sort -rh | head",
|
||||||
|
"检查磁盘使用情况: df -h",
|
||||||
|
"删除临时文件或日志",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
re.compile(
|
||||||
|
r"already exists|file exists|已存在|重复|EEXIST",
|
||||||
|
re.IGNORECASE,
|
||||||
|
),
|
||||||
|
ErrorType.ALREADY_EXISTS,
|
||||||
|
"资源已存在",
|
||||||
|
[
|
||||||
|
"使用 -f 参数强制覆盖(如适用)",
|
||||||
|
"先删除已有资源再重新创建",
|
||||||
|
"使用不同名称创建",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
re.compile(
|
||||||
|
r"invalid argument|illegal option|bad option|无效参数|非法选项|invalid option",
|
||||||
|
re.IGNORECASE,
|
||||||
|
),
|
||||||
|
ErrorType.INVALID_ARGUMENT,
|
||||||
|
"无效参数",
|
||||||
|
[
|
||||||
|
"检查命令参数是否正确",
|
||||||
|
"使用 --help 查看支持的参数",
|
||||||
|
"确认参数值类型和范围",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
re.compile(
|
||||||
|
r"no such process|process not found|进程不存在|进程未找到",
|
||||||
|
re.IGNORECASE,
|
||||||
|
),
|
||||||
|
ErrorType.PROCESS_NOT_FOUND,
|
||||||
|
"进程不存在",
|
||||||
|
[
|
||||||
|
"确认进程 ID 是否正确",
|
||||||
|
"使用 ps aux 查看运行中的进程",
|
||||||
|
"进程可能已经结束",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
(
|
||||||
|
re.compile(
|
||||||
|
r"network is unreachable|no route to host|name resolution|网络不可达|无法解析|ENETUNREACH",
|
||||||
|
re.IGNORECASE,
|
||||||
|
),
|
||||||
|
ErrorType.NETWORK_ERROR,
|
||||||
|
"网络错误",
|
||||||
|
[
|
||||||
|
"检查网络连接是否正常",
|
||||||
|
"确认 DNS 解析是否正常: nslookup <domain>",
|
||||||
|
"检查代理设置",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class OutputParser:
|
||||||
|
"""命令输出结构化解析器
|
||||||
|
|
||||||
|
将命令行输出(stdout + stderr)和退出码解析为结构化的 ParsedOutput,
|
||||||
|
包含错误类型识别、消息摘要和可操作建议。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def parse(self, output: str, exit_code: int) -> ParsedOutput:
|
||||||
|
"""解析命令输出
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output: 命令的标准输出和错误输出合并文本
|
||||||
|
exit_code: 命令退出码
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ParsedOutput 结构化解析结果
|
||||||
|
"""
|
||||||
|
is_error = exit_code != 0
|
||||||
|
message = self._extract_message(output)
|
||||||
|
error_type = ErrorType.NONE
|
||||||
|
suggestions: list[str] = []
|
||||||
|
|
||||||
|
if is_error:
|
||||||
|
error_type, suggestions = self._classify_error(output, exit_code)
|
||||||
|
|
||||||
|
return ParsedOutput(
|
||||||
|
exit_code=exit_code,
|
||||||
|
is_error=is_error,
|
||||||
|
error_type=error_type,
|
||||||
|
message=message,
|
||||||
|
raw_output=output,
|
||||||
|
suggestions=suggestions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _extract_message(self, output: str) -> str:
|
||||||
|
"""从输出中提取关键消息
|
||||||
|
|
||||||
|
取最后几行非空输出中的关键行作为消息摘要。
|
||||||
|
"""
|
||||||
|
if not output:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
lines = [line.strip() for line in output.strip().splitlines() if line.strip()]
|
||||||
|
if not lines:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 取最后一行作为摘要,如果太长则截断
|
||||||
|
message = lines[-1]
|
||||||
|
if len(message) > 200:
|
||||||
|
message = message[:200] + "..."
|
||||||
|
return message
|
||||||
|
|
||||||
|
def _classify_error(
|
||||||
|
self, output: str, exit_code: int
|
||||||
|
) -> tuple[ErrorType, list[str]]:
|
||||||
|
"""根据输出内容和退出码分类错误类型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
output: 命令输出
|
||||||
|
exit_code: 退出码
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(error_type, suggestions) 元组
|
||||||
|
"""
|
||||||
|
# 优先根据输出内容匹配
|
||||||
|
for pattern, error_type, _msg, suggestions in _ERROR_PATTERNS:
|
||||||
|
if pattern.search(output):
|
||||||
|
return error_type, suggestions
|
||||||
|
|
||||||
|
# 退出码兜底分类
|
||||||
|
if exit_code == 126:
|
||||||
|
return ErrorType.PERMISSION_DENIED, [
|
||||||
|
"检查文件是否有执行权限: chmod +x <file>",
|
||||||
|
"确认文件格式是否正确(如行尾符)",
|
||||||
|
]
|
||||||
|
if exit_code == 127:
|
||||||
|
return ErrorType.NOT_FOUND, [
|
||||||
|
"检查命令是否已安装",
|
||||||
|
"确认命令名称拼写是否正确",
|
||||||
|
"检查 PATH 环境变量是否包含命令所在目录",
|
||||||
|
]
|
||||||
|
if exit_code == 130:
|
||||||
|
return ErrorType.TIMEOUT, [
|
||||||
|
"命令被 Ctrl+C 中断",
|
||||||
|
"可能需要增加超时时间",
|
||||||
|
]
|
||||||
|
|
||||||
|
return ErrorType.UNKNOWN, [
|
||||||
|
"检查命令输出中的错误信息",
|
||||||
|
"使用 --verbose 或 --debug 获取更多详情",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,341 @@
|
||||||
|
"""PTYSession - 伪终端会话,支持交互式命令
|
||||||
|
|
||||||
|
基于 asyncio + os.openpty() 实现伪终端,支持交互式命令和自动应答。
|
||||||
|
不依赖 pexpect,仅使用标准库。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import fcntl
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import struct
|
||||||
|
import termios
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 自动应答规则:(prompt_pattern, response)
|
||||||
|
_AUTO_RESPOND_RULES: list[tuple[str, str]] = [
|
||||||
|
(r"\[y/N\]\s*$", "y"),
|
||||||
|
(r"\[Y/n\]\s*$", "y"),
|
||||||
|
(r"\[yes/no\]\s*$", "yes"),
|
||||||
|
(r"\(yes/no\)\s*$", "yes"),
|
||||||
|
(r"\(yes/no/\[fingerprint\]\)\s*$", "yes"),
|
||||||
|
(r"continue\?\s*$", "y"),
|
||||||
|
(r"are you sure\?\s*$", "y"),
|
||||||
|
(r"password:\s*$", ""), # 密码提示不自动应答,需要人工介入
|
||||||
|
(r"passphrase:\s*$", ""),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PTYOutput:
|
||||||
|
"""PTY 输出结果
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
output: 输出文本
|
||||||
|
exit_code: 退出码(-1 表示超时或未结束)
|
||||||
|
timed_out: 是否超时
|
||||||
|
"""
|
||||||
|
|
||||||
|
output: str
|
||||||
|
exit_code: int = -1
|
||||||
|
timed_out: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class PTYSession:
|
||||||
|
"""伪终端会话 - 支持交互式命令
|
||||||
|
|
||||||
|
使用 os.openpty() 创建伪终端对,通过 asyncio 异步读写。
|
||||||
|
支持自动检测提示并应答(如 yes/no 确认)。
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
pty = PTYSession()
|
||||||
|
await pty.start()
|
||||||
|
output = await pty.run_command("ssh-keygen", timeout=10)
|
||||||
|
await pty.close()
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
auto_respond: bool = True,
|
||||||
|
custom_rules: list[tuple[str, str]] | None = None,
|
||||||
|
default_timeout: float = 30.0,
|
||||||
|
buffer_size: int = 4096,
|
||||||
|
):
|
||||||
|
"""初始化 PTY 会话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auto_respond: 是否自动应答已知提示
|
||||||
|
custom_rules: 自定义应答规则列表 [(prompt_pattern, response)]
|
||||||
|
default_timeout: 默认超时时间(秒)
|
||||||
|
buffer_size: 读取缓冲区大小
|
||||||
|
"""
|
||||||
|
self._auto_respond = auto_respond
|
||||||
|
self._respond_rules = list(_AUTO_RESPOND_RULES)
|
||||||
|
if custom_rules:
|
||||||
|
self._respond_rules.extend(custom_rules)
|
||||||
|
self._default_timeout = default_timeout
|
||||||
|
self._buffer_size = buffer_size
|
||||||
|
|
||||||
|
self._master_fd: int | None = None
|
||||||
|
self._slave_fd: int | None = None
|
||||||
|
self._process: asyncio.subprocess.Process | None = None
|
||||||
|
self._running = False
|
||||||
|
self._output_buffer = ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_running(self) -> bool:
|
||||||
|
"""PTY 会话是否已启动(伪终端已创建)"""
|
||||||
|
return self._running
|
||||||
|
|
||||||
|
async def start(self) -> None:
|
||||||
|
"""启动 PTY 会话(创建伪终端对)
|
||||||
|
|
||||||
|
在执行命令前调用,创建 master/slave 文件描述符。
|
||||||
|
"""
|
||||||
|
if self._running:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._master_fd, self._slave_fd = os.openpty()
|
||||||
|
|
||||||
|
# 设置终端大小
|
||||||
|
try:
|
||||||
|
winsize = struct.pack("HHHH", 24, 80, 0, 0)
|
||||||
|
fcntl.ioctl(self._slave_fd, termios.TIOCSWINSZ, winsize)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 设置 master fd 为非阻塞
|
||||||
|
flags = fcntl.fcntl(self._master_fd, fcntl.F_GETFL)
|
||||||
|
fcntl.fcntl(self._master_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||||
|
|
||||||
|
self._running = True
|
||||||
|
logger.debug("PTY session started: master=%d slave=%d", self._master_fd, self._slave_fd)
|
||||||
|
|
||||||
|
async def run_command(
|
||||||
|
self,
|
||||||
|
command: str,
|
||||||
|
timeout: float | None = None,
|
||||||
|
cwd: str | None = None,
|
||||||
|
env: dict[str, str] | None = None,
|
||||||
|
) -> PTYOutput:
|
||||||
|
"""在 PTY 中运行命令,等待完成
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command: 要执行的命令
|
||||||
|
timeout: 超时时间(秒)
|
||||||
|
cwd: 工作目录
|
||||||
|
env: 环境变量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PTYOutput 输出结果
|
||||||
|
"""
|
||||||
|
if not self._running:
|
||||||
|
await self.start()
|
||||||
|
|
||||||
|
timeout = timeout or self._default_timeout
|
||||||
|
self._output_buffer = ""
|
||||||
|
|
||||||
|
# 构建环境
|
||||||
|
cmd_env = dict(os.environ)
|
||||||
|
if env:
|
||||||
|
cmd_env.update(env)
|
||||||
|
|
||||||
|
# 启动子进程,使用 slave 作为 stdin/stdout/stderr
|
||||||
|
self._process = await asyncio.create_subprocess_shell(
|
||||||
|
command,
|
||||||
|
stdin=self._slave_fd,
|
||||||
|
stdout=self._slave_fd,
|
||||||
|
stderr=self._slave_fd,
|
||||||
|
cwd=cwd,
|
||||||
|
env=cmd_env,
|
||||||
|
start_new_session=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 关闭 slave 端(子进程已继承)
|
||||||
|
os.close(self._slave_fd)
|
||||||
|
self._slave_fd = None
|
||||||
|
|
||||||
|
# 异步读取输出
|
||||||
|
try:
|
||||||
|
exit_code = await self._read_until_exit(timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
self._output_buffer += "\n[PTY 命令执行超时]"
|
||||||
|
return PTYOutput(
|
||||||
|
output=self._output_buffer,
|
||||||
|
exit_code=-1,
|
||||||
|
timed_out=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return PTYOutput(
|
||||||
|
output=self._output_buffer,
|
||||||
|
exit_code=exit_code,
|
||||||
|
timed_out=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def send(self, line: str) -> None:
|
||||||
|
"""向 PTY 发送一行输入
|
||||||
|
|
||||||
|
Args:
|
||||||
|
line: 要发送的文本(自动追加换行符)
|
||||||
|
"""
|
||||||
|
if self._master_fd is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
data = (line + "\n").encode("utf-8")
|
||||||
|
try:
|
||||||
|
os.write(self._master_fd, data)
|
||||||
|
except OSError as e:
|
||||||
|
logger.warning("PTY write failed: %s", e)
|
||||||
|
|
||||||
|
async def read_output(self, timeout: float = 1.0) -> str:
|
||||||
|
"""读取当前可用的 PTY 输出
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: 读取超时(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
读取到的输出文本
|
||||||
|
"""
|
||||||
|
if self._master_fd is None:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
output = ""
|
||||||
|
deadline = time.monotonic() + timeout
|
||||||
|
|
||||||
|
while time.monotonic() < deadline:
|
||||||
|
try:
|
||||||
|
chunk = os.read(self._master_fd, self._buffer_size)
|
||||||
|
if chunk:
|
||||||
|
text = chunk.decode("utf-8", errors="replace")
|
||||||
|
output += text
|
||||||
|
self._output_buffer += text
|
||||||
|
|
||||||
|
# 自动应答
|
||||||
|
if self._auto_respond:
|
||||||
|
await self._try_auto_respond(text)
|
||||||
|
except (BlockingIOError, OSError):
|
||||||
|
# 没有数据可读
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if output:
|
||||||
|
# 读取到数据后短暂等待看是否还有更多
|
||||||
|
await asyncio.sleep(0.05)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
"""关闭 PTY 会话,清理资源"""
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
if self._process is not None and self._process.returncode is None:
|
||||||
|
try:
|
||||||
|
self._process.terminate()
|
||||||
|
await asyncio.wait_for(self._process.wait(), timeout=5.0)
|
||||||
|
except (asyncio.TimeoutError, ProcessLookupError):
|
||||||
|
try:
|
||||||
|
self._process.kill()
|
||||||
|
except ProcessLookupError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if self._master_fd is not None:
|
||||||
|
try:
|
||||||
|
os.close(self._master_fd)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
self._master_fd = None
|
||||||
|
|
||||||
|
if self._slave_fd is not None:
|
||||||
|
try:
|
||||||
|
os.close(self._slave_fd)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
self._slave_fd = None
|
||||||
|
|
||||||
|
self._process = None
|
||||||
|
logger.debug("PTY session closed")
|
||||||
|
|
||||||
|
async def _read_until_exit(self, timeout: float) -> int:
|
||||||
|
"""持续读取输出直到进程退出
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: 超时时间(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
进程退出码
|
||||||
|
"""
|
||||||
|
deadline = time.monotonic() + timeout
|
||||||
|
|
||||||
|
while True:
|
||||||
|
# 检查超时
|
||||||
|
if time.monotonic() > deadline:
|
||||||
|
raise asyncio.TimeoutError()
|
||||||
|
|
||||||
|
# 检查进程是否已退出
|
||||||
|
if self._process.returncode is not None:
|
||||||
|
# 进程已退出,再读一次剩余输出
|
||||||
|
await self._drain_remaining_output()
|
||||||
|
return self._process.returncode
|
||||||
|
|
||||||
|
# 读取输出
|
||||||
|
try:
|
||||||
|
chunk = os.read(self._master_fd, self._buffer_size)
|
||||||
|
if chunk:
|
||||||
|
text = chunk.decode("utf-8", errors="replace")
|
||||||
|
self._output_buffer += text
|
||||||
|
|
||||||
|
# 自动应答
|
||||||
|
if self._auto_respond:
|
||||||
|
await self._try_auto_respond(text)
|
||||||
|
except (BlockingIOError, OSError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 检查进程状态
|
||||||
|
if self._process.returncode is None:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(
|
||||||
|
self._process.wait(), timeout=0.05
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
await self._drain_remaining_output()
|
||||||
|
return self._process.returncode
|
||||||
|
|
||||||
|
await asyncio.sleep(0.02)
|
||||||
|
|
||||||
|
async def _drain_remaining_output(self) -> None:
|
||||||
|
"""排空剩余输出"""
|
||||||
|
for _ in range(10): # 最多尝试 10 次
|
||||||
|
try:
|
||||||
|
chunk = os.read(self._master_fd, self._buffer_size)
|
||||||
|
if chunk:
|
||||||
|
text = chunk.decode("utf-8", errors="replace")
|
||||||
|
self._output_buffer += text
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
except (BlockingIOError, OSError):
|
||||||
|
break
|
||||||
|
await asyncio.sleep(0.01)
|
||||||
|
|
||||||
|
async def _try_auto_respond(self, recent_output: str) -> None:
|
||||||
|
"""检测提示并自动应答
|
||||||
|
|
||||||
|
Args:
|
||||||
|
recent_output: 最近的输出文本
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
for pattern, response in self._respond_rules:
|
||||||
|
if not response:
|
||||||
|
# 空响应规则(如密码提示)跳过
|
||||||
|
continue
|
||||||
|
if re.search(pattern, recent_output, re.IGNORECASE | re.MULTILINE):
|
||||||
|
logger.debug("Auto-responding to prompt '%s' with '%s'", pattern, response)
|
||||||
|
await self.send(response)
|
||||||
|
break
|
||||||
|
|
@ -0,0 +1,432 @@
|
||||||
|
"""ShellTool - Shell 命令执行工具
|
||||||
|
|
||||||
|
支持无会话模式(向后兼容)和有会话模式(跨命令保持状态)。
|
||||||
|
危险命令通过确认回调请求人工确认,所有操作记录审计日志。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from typing import Any, Callable, Awaitable
|
||||||
|
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
from agentkit.tools.output_parser import OutputParser, ParsedOutput
|
||||||
|
from agentkit.tools.terminal_session import TerminalSession, TerminalSessionManager
|
||||||
|
from agentkit.tools.pty_session import PTYSession
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 安全白名单:这些命令前缀不需要确认
|
||||||
|
_SAFE_COMMAND_PREFIXES: tuple[str, ...] = (
|
||||||
|
"ls",
|
||||||
|
"cat",
|
||||||
|
"head",
|
||||||
|
"tail",
|
||||||
|
"grep",
|
||||||
|
"find",
|
||||||
|
"pwd",
|
||||||
|
"echo",
|
||||||
|
"which",
|
||||||
|
"whoami",
|
||||||
|
"id",
|
||||||
|
"date",
|
||||||
|
"uname",
|
||||||
|
"df",
|
||||||
|
"du",
|
||||||
|
"free",
|
||||||
|
"ps",
|
||||||
|
"top",
|
||||||
|
"env",
|
||||||
|
"printenv",
|
||||||
|
"type",
|
||||||
|
"file",
|
||||||
|
"stat",
|
||||||
|
"wc",
|
||||||
|
"sort",
|
||||||
|
"uniq",
|
||||||
|
"diff",
|
||||||
|
"git status",
|
||||||
|
"git log",
|
||||||
|
"git diff",
|
||||||
|
"git branch",
|
||||||
|
"git remote",
|
||||||
|
"pip list",
|
||||||
|
"pip show",
|
||||||
|
"python --version",
|
||||||
|
"python3 --version",
|
||||||
|
"node --version",
|
||||||
|
"npm list",
|
||||||
|
"docker ps",
|
||||||
|
"docker images",
|
||||||
|
"curl",
|
||||||
|
"wget",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 危险命令模式:这些命令需要人工确认
|
||||||
|
_DANGEROUS_PATTERNS: tuple[str, ...] = (
|
||||||
|
"rm ",
|
||||||
|
"rm -",
|
||||||
|
"rmdir",
|
||||||
|
"mkfs",
|
||||||
|
"dd ",
|
||||||
|
"format",
|
||||||
|
"del ",
|
||||||
|
"erase",
|
||||||
|
"> /dev/",
|
||||||
|
"shutdown",
|
||||||
|
"reboot",
|
||||||
|
"init 0",
|
||||||
|
"init 6",
|
||||||
|
"kill -9",
|
||||||
|
"killall",
|
||||||
|
"chmod 777",
|
||||||
|
"chown",
|
||||||
|
"mv /",
|
||||||
|
"pip uninstall",
|
||||||
|
"npm uninstall",
|
||||||
|
"apt remove",
|
||||||
|
"yum remove",
|
||||||
|
"brew uninstall",
|
||||||
|
"docker rm",
|
||||||
|
"docker rmi",
|
||||||
|
"git push --force",
|
||||||
|
"git reset --hard",
|
||||||
|
"git clean -f",
|
||||||
|
"drop table",
|
||||||
|
"drop database",
|
||||||
|
"truncate",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ShellTool(Tool):
|
||||||
|
"""Shell 命令执行工具
|
||||||
|
|
||||||
|
支持两种模式:
|
||||||
|
1. 无会话模式(默认):每次命令独立执行,不保持状态
|
||||||
|
2. 有会话模式:通过 session_id 指定会话,跨命令保持 cwd/env/history
|
||||||
|
|
||||||
|
安全控制:
|
||||||
|
- 危险命令通过 confirm_callback 请求人工确认
|
||||||
|
- 所有操作记录审计日志
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# 无会话模式
|
||||||
|
tool = ShellTool()
|
||||||
|
result = await tool.execute(command="ls -la")
|
||||||
|
|
||||||
|
# 有会话模式
|
||||||
|
result = await tool.execute(command="cd /tmp", session_id="build-01")
|
||||||
|
result = await tool.execute(command="pwd", session_id="build-01") # 输出 /tmp
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str = "shell",
|
||||||
|
description: str = "执行 Shell 命令,支持会话模式保持跨命令状态",
|
||||||
|
input_schema: dict[str, Any] | None = None,
|
||||||
|
output_schema: dict[str, Any] | None = None,
|
||||||
|
version: str = "1.0.0",
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
confirm_callback: Callable[[str], Awaitable[bool]] | None = None,
|
||||||
|
default_timeout: float = 60.0,
|
||||||
|
max_output_length: int = 50000,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
input_schema=input_schema or self._default_input_schema(),
|
||||||
|
output_schema=output_schema or self._default_output_schema(),
|
||||||
|
version=version,
|
||||||
|
tags=tags or ["shell", "terminal", "system"],
|
||||||
|
)
|
||||||
|
self._session_manager = TerminalSessionManager()
|
||||||
|
self._output_parser = OutputParser()
|
||||||
|
self._confirm_callback = confirm_callback
|
||||||
|
self._default_timeout = default_timeout
|
||||||
|
self._max_output_length = max_output_length
|
||||||
|
self._audit_log: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _default_input_schema() -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "要执行的 Shell 命令",
|
||||||
|
},
|
||||||
|
"timeout": {
|
||||||
|
"type": "number",
|
||||||
|
"description": "超时时间(秒),默认 60",
|
||||||
|
"default": 60,
|
||||||
|
},
|
||||||
|
"working_dir": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "工作目录(仅无会话模式有效)",
|
||||||
|
},
|
||||||
|
"session_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "会话 ID,指定后在会话中执行命令,跨命令保持状态",
|
||||||
|
},
|
||||||
|
"interactive": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "是否使用交互式模式(PTY),用于需要用户输入的命令",
|
||||||
|
"default": False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["command"],
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _default_output_schema() -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"output": {"type": "string", "description": "命令输出"},
|
||||||
|
"exit_code": {"type": "integer", "description": "退出码"},
|
||||||
|
"is_error": {"type": "boolean", "description": "是否为错误"},
|
||||||
|
"error_type": {"type": "string", "description": "错误类型"},
|
||||||
|
"message": {"type": "string", "description": "消息摘要"},
|
||||||
|
"suggestions": {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"description": "可操作建议",
|
||||||
|
},
|
||||||
|
"session_id": {"type": "string", "description": "会话 ID(仅会话模式)"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
async def execute(self, **kwargs) -> dict:
|
||||||
|
"""执行 Shell 命令
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command: 要执行的命令(必需)
|
||||||
|
timeout: 超时时间(秒)
|
||||||
|
working_dir: 工作目录(仅无会话模式)
|
||||||
|
session_id: 会话 ID(启用会话模式)
|
||||||
|
interactive: 是否使用交互式模式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含 output, exit_code, is_error 等字段的字典
|
||||||
|
"""
|
||||||
|
command = kwargs.get("command")
|
||||||
|
if not command:
|
||||||
|
return {
|
||||||
|
"output": "",
|
||||||
|
"exit_code": 1,
|
||||||
|
"is_error": True,
|
||||||
|
"error_type": "invalid_argument",
|
||||||
|
"message": "command 参数是必需的",
|
||||||
|
"suggestions": ["提供要执行的 Shell 命令"],
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout = kwargs.get("timeout", self._default_timeout)
|
||||||
|
working_dir = kwargs.get("working_dir")
|
||||||
|
session_id = kwargs.get("session_id")
|
||||||
|
interactive = kwargs.get("interactive", False)
|
||||||
|
|
||||||
|
# 安全检查:危险命令需要确认
|
||||||
|
if self._is_dangerous(command):
|
||||||
|
confirmed = await self._request_confirmation(command)
|
||||||
|
if not confirmed:
|
||||||
|
self._log_audit(command, None, blocked=True)
|
||||||
|
return {
|
||||||
|
"output": "",
|
||||||
|
"exit_code": 126,
|
||||||
|
"is_error": True,
|
||||||
|
"error_type": "permission_denied",
|
||||||
|
"message": f"危险命令已被拒绝执行: {command[:100]}",
|
||||||
|
"suggestions": [
|
||||||
|
"如需执行此命令,请手动确认",
|
||||||
|
"考虑使用更安全的替代命令",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
# 根据模式执行
|
||||||
|
if session_id:
|
||||||
|
result = await self._execute_in_session(
|
||||||
|
command, session_id, timeout, working_dir, interactive
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result = await self._execute_standalone(command, timeout, working_dir, interactive)
|
||||||
|
|
||||||
|
# 审计日志
|
||||||
|
self._log_audit(command, session_id, exit_code=result.exit_code)
|
||||||
|
|
||||||
|
# 截断过长输出
|
||||||
|
output = result.raw_output
|
||||||
|
if len(output) > self._max_output_length:
|
||||||
|
output = output[: self._max_output_length] + "\n... [输出已截断]"
|
||||||
|
|
||||||
|
return {
|
||||||
|
"output": output,
|
||||||
|
"exit_code": result.exit_code,
|
||||||
|
"is_error": result.is_error,
|
||||||
|
"error_type": result.error_type.value,
|
||||||
|
"message": result.message,
|
||||||
|
"suggestions": result.suggestions,
|
||||||
|
"session_id": session_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _execute_standalone(
|
||||||
|
self,
|
||||||
|
command: str,
|
||||||
|
timeout: float,
|
||||||
|
working_dir: str | None,
|
||||||
|
interactive: bool,
|
||||||
|
) -> ParsedOutput:
|
||||||
|
"""无会话模式执行命令(向后兼容)"""
|
||||||
|
if interactive:
|
||||||
|
return await self._execute_with_pty(command, timeout, working_dir)
|
||||||
|
|
||||||
|
start = time.monotonic()
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_shell(
|
||||||
|
command,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.STDOUT,
|
||||||
|
cwd=working_dir,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
stdout, _ = await asyncio.wait_for(
|
||||||
|
proc.communicate(),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
proc.kill()
|
||||||
|
await proc.wait()
|
||||||
|
output = f"命令执行超时({timeout}s)"
|
||||||
|
exit_code = -1
|
||||||
|
else:
|
||||||
|
output = stdout.decode("utf-8", errors="replace") if stdout else ""
|
||||||
|
exit_code = proc.returncode if proc.returncode is not None else 0
|
||||||
|
except Exception as e:
|
||||||
|
output = str(e)
|
||||||
|
exit_code = -1
|
||||||
|
|
||||||
|
return self._output_parser.parse(output, exit_code)
|
||||||
|
|
||||||
|
async def _execute_in_session(
|
||||||
|
self,
|
||||||
|
command: str,
|
||||||
|
session_id: str,
|
||||||
|
timeout: float,
|
||||||
|
working_dir: str | None,
|
||||||
|
interactive: bool,
|
||||||
|
) -> ParsedOutput:
|
||||||
|
"""会话模式执行命令"""
|
||||||
|
session = self._session_manager.get_or_create(
|
||||||
|
session_id,
|
||||||
|
cwd=working_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
if interactive:
|
||||||
|
return await self._execute_with_pty(
|
||||||
|
command, timeout, session.cwd, session.env
|
||||||
|
)
|
||||||
|
|
||||||
|
return await session.execute(command, timeout=timeout)
|
||||||
|
|
||||||
|
async def _execute_with_pty(
|
||||||
|
self,
|
||||||
|
command: str,
|
||||||
|
timeout: float,
|
||||||
|
cwd: str | None = None,
|
||||||
|
env: dict[str, str] | None = None,
|
||||||
|
) -> ParsedOutput:
|
||||||
|
"""使用 PTY 执行交互式命令"""
|
||||||
|
pty = PTYSession()
|
||||||
|
try:
|
||||||
|
await pty.start()
|
||||||
|
result = await pty.run_command(
|
||||||
|
command,
|
||||||
|
timeout=timeout,
|
||||||
|
cwd=cwd,
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
output = result.output
|
||||||
|
exit_code = result.exit_code
|
||||||
|
except Exception as e:
|
||||||
|
output = str(e)
|
||||||
|
exit_code = -1
|
||||||
|
finally:
|
||||||
|
await pty.close()
|
||||||
|
|
||||||
|
return self._output_parser.parse(output, exit_code)
|
||||||
|
|
||||||
|
def _is_dangerous(self, command: str) -> bool:
|
||||||
|
"""检查命令是否为危险操作
|
||||||
|
|
||||||
|
白名单命令直接放行,其他命令检查是否匹配危险模式。
|
||||||
|
"""
|
||||||
|
command_stripped = command.strip()
|
||||||
|
|
||||||
|
# 白名单检查
|
||||||
|
for prefix in _SAFE_COMMAND_PREFIXES:
|
||||||
|
if command_stripped.startswith(prefix):
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 危险模式检查
|
||||||
|
command_lower = command_stripped.lower()
|
||||||
|
for pattern in _DANGEROUS_PATTERNS:
|
||||||
|
if pattern in command_lower:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def _request_confirmation(self, command: str) -> bool:
|
||||||
|
"""请求人工确认危险命令
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command: 待确认的命令
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否确认执行
|
||||||
|
"""
|
||||||
|
if self._confirm_callback:
|
||||||
|
try:
|
||||||
|
return await self._confirm_callback(command)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("确认回调执行失败: %s", e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 无回调时默认拒绝
|
||||||
|
logger.warning("危险命令被拒绝(无确认回调): %s", command[:100])
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _log_audit(
|
||||||
|
self,
|
||||||
|
command: str,
|
||||||
|
session_id: str | None,
|
||||||
|
exit_code: int | None = None,
|
||||||
|
blocked: bool = False,
|
||||||
|
) -> None:
|
||||||
|
"""记录审计日志"""
|
||||||
|
entry = {
|
||||||
|
"timestamp": time.time(),
|
||||||
|
"command": command[:500],
|
||||||
|
"session_id": session_id,
|
||||||
|
"exit_code": exit_code,
|
||||||
|
"blocked": blocked,
|
||||||
|
}
|
||||||
|
self._audit_log.append(entry)
|
||||||
|
logger.info(
|
||||||
|
"Shell audit: command=%r session=%s exit=%s blocked=%s",
|
||||||
|
command[:100],
|
||||||
|
session_id,
|
||||||
|
exit_code,
|
||||||
|
blocked,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session_manager(self) -> TerminalSessionManager:
|
||||||
|
"""获取会话管理器"""
|
||||||
|
return self._session_manager
|
||||||
|
|
||||||
|
@property
|
||||||
|
def audit_log(self) -> list[dict[str, Any]]:
|
||||||
|
"""获取审计日志(副本)"""
|
||||||
|
return list(self._audit_log)
|
||||||
|
|
@ -0,0 +1,352 @@
|
||||||
|
"""TerminalSession - 终端会话状态管理
|
||||||
|
|
||||||
|
维护 cwd、env、history,支持跨命令保持状态。
|
||||||
|
通过在命令前注入 cd 和 export 语句实现跨命令状态持久化。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from agentkit.tools.output_parser import OutputParser, ParsedOutput
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CommandRecord:
|
||||||
|
"""命令执行记录
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
command: 执行的命令
|
||||||
|
exit_code: 退出码
|
||||||
|
output: 标准输出+错误输出
|
||||||
|
cwd: 执行时的工作目录
|
||||||
|
timestamp: 执行时间戳
|
||||||
|
duration_ms: 执行耗时(毫秒)
|
||||||
|
"""
|
||||||
|
|
||||||
|
command: str
|
||||||
|
exit_code: int
|
||||||
|
output: str
|
||||||
|
cwd: str
|
||||||
|
timestamp: float
|
||||||
|
duration_ms: int
|
||||||
|
|
||||||
|
|
||||||
|
class TerminalSession:
|
||||||
|
"""终端会话 - 跨命令保持 cwd/env/history 状态
|
||||||
|
|
||||||
|
通过在命令前注入 `cd {cwd} && ` 和 `export K=V && ` 实现跨命令状态持久化。
|
||||||
|
每次命令执行后自动更新 cwd 和 env 状态。
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
session = TerminalSession(session_id="build-01")
|
||||||
|
result = await session.execute("cd /tmp")
|
||||||
|
result = await session.execute("pwd") # 输出 /tmp
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
cwd: str | None = None,
|
||||||
|
env: dict[str, str] | None = None,
|
||||||
|
max_history: int = 1000,
|
||||||
|
):
|
||||||
|
self.session_id = session_id
|
||||||
|
self._cwd = cwd or os.getcwd()
|
||||||
|
self._env: dict[str, str] = dict(env or os.environ)
|
||||||
|
self._history: list[CommandRecord] = []
|
||||||
|
self._max_history = max_history
|
||||||
|
self._output_parser = OutputParser()
|
||||||
|
self._created_at = time.time()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def cwd(self) -> str:
|
||||||
|
"""当前工作目录"""
|
||||||
|
return self._cwd
|
||||||
|
|
||||||
|
@property
|
||||||
|
def env(self) -> dict[str, str]:
|
||||||
|
"""当前环境变量(副本)"""
|
||||||
|
return dict(self._env)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def history(self) -> list[CommandRecord]:
|
||||||
|
"""命令执行历史(副本)"""
|
||||||
|
return list(self._history)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def created_at(self) -> float:
|
||||||
|
"""会话创建时间戳"""
|
||||||
|
return self._created_at
|
||||||
|
|
||||||
|
def get_cwd(self) -> str:
|
||||||
|
"""获取当前工作目录"""
|
||||||
|
return self._cwd
|
||||||
|
|
||||||
|
def set_cwd(self, cwd: str) -> None:
|
||||||
|
"""手动设置当前工作目录"""
|
||||||
|
self._cwd = cwd
|
||||||
|
|
||||||
|
def get_env(self) -> dict[str, str]:
|
||||||
|
"""获取当前环境变量(副本)"""
|
||||||
|
return dict(self._env)
|
||||||
|
|
||||||
|
def set_env(self, key: str, value: str) -> None:
|
||||||
|
"""设置单个环境变量"""
|
||||||
|
self._env[key] = value
|
||||||
|
|
||||||
|
def update_env(self, env: dict[str, str]) -> None:
|
||||||
|
"""批量更新环境变量"""
|
||||||
|
self._env.update(env)
|
||||||
|
|
||||||
|
def get_history(self) -> list[CommandRecord]:
|
||||||
|
"""获取命令执行历史(副本)"""
|
||||||
|
return list(self._history)
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
command: str,
|
||||||
|
timeout: float | None = None,
|
||||||
|
) -> ParsedOutput:
|
||||||
|
"""在会话上下文中执行命令
|
||||||
|
|
||||||
|
自动在命令前注入 cd 和 export 语句以保持会话状态。
|
||||||
|
执行后自动更新 cwd 和 env。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
command: 要执行的命令
|
||||||
|
timeout: 超时时间(秒),None 表示不超时
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ParsedOutput 结构化解析结果
|
||||||
|
"""
|
||||||
|
full_command = self._build_command(command)
|
||||||
|
start = time.monotonic()
|
||||||
|
|
||||||
|
try:
|
||||||
|
proc = await asyncio.create_subprocess_shell(
|
||||||
|
full_command,
|
||||||
|
stdout=asyncio.subprocess.PIPE,
|
||||||
|
stderr=asyncio.subprocess.STDOUT,
|
||||||
|
env=self._env,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
stdout, _ = await asyncio.wait_for(
|
||||||
|
proc.communicate(),
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
proc.kill()
|
||||||
|
await proc.wait()
|
||||||
|
output = f"命令执行超时({timeout}s)"
|
||||||
|
exit_code = -1
|
||||||
|
else:
|
||||||
|
output = stdout.decode("utf-8", errors="replace") if stdout else ""
|
||||||
|
exit_code = proc.returncode if proc.returncode is not None else 0
|
||||||
|
except Exception as e:
|
||||||
|
output = str(e)
|
||||||
|
exit_code = -1
|
||||||
|
|
||||||
|
duration_ms = int((time.monotonic() - start) * 1000)
|
||||||
|
|
||||||
|
# 更新会话状态
|
||||||
|
self._update_state_after_execution(command, output, exit_code)
|
||||||
|
|
||||||
|
# 记录历史
|
||||||
|
record = CommandRecord(
|
||||||
|
command=command,
|
||||||
|
exit_code=exit_code,
|
||||||
|
output=output,
|
||||||
|
cwd=self._cwd,
|
||||||
|
timestamp=time.time(),
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
self._add_history(record)
|
||||||
|
|
||||||
|
# 解析输出
|
||||||
|
parsed = self._output_parser.parse(output, exit_code)
|
||||||
|
logger.debug(
|
||||||
|
"Session %s: command=%r exit_code=%d duration=%dms",
|
||||||
|
self.session_id,
|
||||||
|
command,
|
||||||
|
exit_code,
|
||||||
|
duration_ms,
|
||||||
|
)
|
||||||
|
return parsed
|
||||||
|
|
||||||
|
def _build_command(self, command: str) -> str:
|
||||||
|
"""构建带会话状态的完整命令
|
||||||
|
|
||||||
|
在原始命令前注入 cd 和 export 语句。
|
||||||
|
"""
|
||||||
|
parts: list[str] = []
|
||||||
|
|
||||||
|
# 注入 cd
|
||||||
|
if self._cwd:
|
||||||
|
# 使用 shlex.quote 风格的简单转义
|
||||||
|
cwd_escaped = self._cwd.replace("'", "'\\''")
|
||||||
|
parts.append(f"cd '{cwd_escaped}'")
|
||||||
|
|
||||||
|
# 注入环境变量
|
||||||
|
for key, value in self._env.items():
|
||||||
|
# 跳过 os.environ 中已有的且值未变的变量,减少命令长度
|
||||||
|
val_escaped = value.replace("'", "'\\''")
|
||||||
|
parts.append(f"export {key}='{val_escaped}'")
|
||||||
|
|
||||||
|
parts.append(command)
|
||||||
|
return " && ".join(parts)
|
||||||
|
|
||||||
|
def _update_state_after_execution(
|
||||||
|
self, command: str, output: str, exit_code: int
|
||||||
|
) -> None:
|
||||||
|
"""命令执行后更新会话状态
|
||||||
|
|
||||||
|
解析 cd 和 export 命令更新 cwd 和 env。
|
||||||
|
"""
|
||||||
|
if exit_code != 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 解析 cd 命令更新 cwd
|
||||||
|
self._parse_cd_commands(command, output)
|
||||||
|
|
||||||
|
# 解析 export 命令更新 env
|
||||||
|
self._parse_export_commands(command)
|
||||||
|
|
||||||
|
def _parse_cd_commands(self, command: str, output: str) -> None:
|
||||||
|
"""从命令中解析 cd 并更新 cwd
|
||||||
|
|
||||||
|
支持:
|
||||||
|
- cd /path
|
||||||
|
- cd dir (相对路径,需要通过 pwd 获取实际路径)
|
||||||
|
- cd - (切换到上一个目录)
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
# 匹配 cd 命令(可能出现在 && 链中)
|
||||||
|
cd_pattern = re.compile(r"(?:^|\s|&&\s*)cd\s+(.+?)(?:\s*&&|\s*$)")
|
||||||
|
matches = cd_pattern.findall(command)
|
||||||
|
|
||||||
|
for target in matches:
|
||||||
|
target = target.strip().strip("'\"")
|
||||||
|
if not target:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if target == "-":
|
||||||
|
# cd - 切换到 OLDPWD
|
||||||
|
old_pwd = self._env.get("OLDPWD")
|
||||||
|
if old_pwd:
|
||||||
|
self._cwd = old_pwd
|
||||||
|
elif os.path.isabs(target):
|
||||||
|
self._cwd = target
|
||||||
|
else:
|
||||||
|
# 相对路径:拼接后规范化
|
||||||
|
new_cwd = os.path.normpath(os.path.join(self._cwd, target))
|
||||||
|
self._cwd = new_cwd
|
||||||
|
|
||||||
|
def _parse_export_commands(self, command: str) -> None:
|
||||||
|
"""从命令中解析 export 并更新 env
|
||||||
|
|
||||||
|
支持:
|
||||||
|
- export KEY=VALUE
|
||||||
|
- export KEY="VALUE WITH SPACES"
|
||||||
|
"""
|
||||||
|
import re
|
||||||
|
|
||||||
|
export_pattern = re.compile(
|
||||||
|
r"(?:^|\s|&&\s*)export\s+(\w+)=(.+?)(?:\s*&&|\s*$)"
|
||||||
|
)
|
||||||
|
matches = export_pattern.findall(command)
|
||||||
|
|
||||||
|
for key, value in matches:
|
||||||
|
value = value.strip().strip("'\"")
|
||||||
|
self._env[key] = value
|
||||||
|
|
||||||
|
def _add_history(self, record: CommandRecord) -> None:
|
||||||
|
"""添加命令记录到历史,超出上限时移除最旧记录"""
|
||||||
|
self._history.append(record)
|
||||||
|
while len(self._history) > self._max_history:
|
||||||
|
self._history.pop(0)
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""关闭会话,清理资源"""
|
||||||
|
logger.info(
|
||||||
|
"Session %s closed: %d commands executed",
|
||||||
|
self.session_id,
|
||||||
|
len(self._history),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TerminalSessionManager:
|
||||||
|
"""终端会话管理器 - 按 ID 管理多个 TerminalSession
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
manager = TerminalSessionManager()
|
||||||
|
session = manager.get_or_create("build-01")
|
||||||
|
session = manager.get("build-01")
|
||||||
|
manager.remove("build-01")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_sessions: int = 100):
|
||||||
|
self._sessions: dict[str, TerminalSession] = {}
|
||||||
|
self._max_sessions = max_sessions
|
||||||
|
|
||||||
|
def get_or_create(
|
||||||
|
self,
|
||||||
|
session_id: str,
|
||||||
|
cwd: str | None = None,
|
||||||
|
env: dict[str, str] | None = None,
|
||||||
|
) -> TerminalSession:
|
||||||
|
"""获取或创建会话
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_id: 会话 ID
|
||||||
|
cwd: 初始工作目录(仅创建时使用)
|
||||||
|
env: 初始环境变量(仅创建时使用)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TerminalSession 实例
|
||||||
|
"""
|
||||||
|
if session_id not in self._sessions:
|
||||||
|
if len(self._sessions) >= self._max_sessions:
|
||||||
|
# 移除最旧的会话
|
||||||
|
oldest_id = min(
|
||||||
|
self._sessions, key=lambda k: self._sessions[k].created_at
|
||||||
|
)
|
||||||
|
self.remove(oldest_id)
|
||||||
|
self._sessions[session_id] = TerminalSession(
|
||||||
|
session_id=session_id,
|
||||||
|
cwd=cwd,
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
logger.info("Session created: %s", session_id)
|
||||||
|
return self._sessions[session_id]
|
||||||
|
|
||||||
|
def get(self, session_id: str) -> TerminalSession | None:
|
||||||
|
"""获取会话,不存在返回 None"""
|
||||||
|
return self._sessions.get(session_id)
|
||||||
|
|
||||||
|
def remove(self, session_id: str) -> None:
|
||||||
|
"""移除并关闭会话"""
|
||||||
|
session = self._sessions.pop(session_id, None)
|
||||||
|
if session:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
def list_sessions(self) -> list[str]:
|
||||||
|
"""列出所有会话 ID"""
|
||||||
|
return list(self._sessions.keys())
|
||||||
|
|
||||||
|
def has_session(self, session_id: str) -> bool:
|
||||||
|
"""检查会话是否存在"""
|
||||||
|
return session_id in self._sessions
|
||||||
|
|
||||||
|
def close_all(self) -> None:
|
||||||
|
"""关闭所有会话"""
|
||||||
|
for session_id in list(self._sessions.keys()):
|
||||||
|
self.remove(session_id)
|
||||||
|
|
@ -0,0 +1,512 @@
|
||||||
|
"""Tests for PathOptimizer - 执行路径优化器"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.evolution.path_optimizer import ExecutionPath, PathOptimizer, PathUpdateResult
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def optimizer():
|
||||||
|
"""默认 PathOptimizer 实例"""
|
||||||
|
return PathOptimizer(min_sample_count=3, success_rate_threshold=0.05, duration_improvement_threshold=0.2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def optimizer_custom_thresholds():
|
||||||
|
"""自定义阈值的 PathOptimizer"""
|
||||||
|
return PathOptimizer(
|
||||||
|
min_sample_count=5,
|
||||||
|
success_rate_threshold=0.1,
|
||||||
|
duration_improvement_threshold=0.3,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_path(
|
||||||
|
task_type: str = "code_review",
|
||||||
|
steps: list[str] | None = None,
|
||||||
|
total_duration: float = 10.0,
|
||||||
|
success_rate: float = 0.8,
|
||||||
|
sample_count: int = 5,
|
||||||
|
is_recommended: bool = False,
|
||||||
|
path_id: str = "",
|
||||||
|
created_at: datetime | None = None,
|
||||||
|
) -> ExecutionPath:
|
||||||
|
"""创建测试用 ExecutionPath"""
|
||||||
|
return ExecutionPath(
|
||||||
|
path_id=path_id,
|
||||||
|
task_type=task_type,
|
||||||
|
steps=steps or ["step1", "step2", "step3"],
|
||||||
|
total_duration=total_duration,
|
||||||
|
success_rate=success_rate,
|
||||||
|
sample_count=sample_count,
|
||||||
|
is_recommended=is_recommended,
|
||||||
|
created_at=created_at or datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── ExecutionPath 数据模型测试 ────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecutionPath:
|
||||||
|
def test_default_values(self):
|
||||||
|
path = ExecutionPath()
|
||||||
|
assert path.path_id == ""
|
||||||
|
assert path.task_type == ""
|
||||||
|
assert path.steps == []
|
||||||
|
assert path.total_duration == 0.0
|
||||||
|
assert path.success_rate == 0.0
|
||||||
|
assert path.sample_count == 0
|
||||||
|
assert path.is_recommended is False
|
||||||
|
assert isinstance(path.created_at, datetime)
|
||||||
|
|
||||||
|
def test_custom_values(self):
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
path = ExecutionPath(
|
||||||
|
path_id="p1",
|
||||||
|
task_type="code_review",
|
||||||
|
steps=["analyze", "review", "report"],
|
||||||
|
total_duration=15.5,
|
||||||
|
success_rate=0.9,
|
||||||
|
sample_count=10,
|
||||||
|
is_recommended=True,
|
||||||
|
created_at=now,
|
||||||
|
)
|
||||||
|
assert path.path_id == "p1"
|
||||||
|
assert path.task_type == "code_review"
|
||||||
|
assert path.steps == ["analyze", "review", "report"]
|
||||||
|
assert path.total_duration == 15.5
|
||||||
|
assert path.success_rate == 0.9
|
||||||
|
assert path.sample_count == 10
|
||||||
|
assert path.is_recommended is True
|
||||||
|
assert path.created_at == now
|
||||||
|
|
||||||
|
|
||||||
|
# ── PathUpdateResult 数据模型测试 ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestPathUpdateResult:
|
||||||
|
def test_default_values(self):
|
||||||
|
result = PathUpdateResult()
|
||||||
|
assert result.updated is False
|
||||||
|
assert result.old_path is None
|
||||||
|
assert result.new_path is None
|
||||||
|
assert result.reason == ""
|
||||||
|
|
||||||
|
def test_updated_result(self):
|
||||||
|
old = _make_path(success_rate=0.7)
|
||||||
|
new = _make_path(success_rate=0.9)
|
||||||
|
result = PathUpdateResult(
|
||||||
|
updated=True,
|
||||||
|
old_path=old,
|
||||||
|
new_path=new,
|
||||||
|
reason="成功率显著提升",
|
||||||
|
)
|
||||||
|
assert result.updated is True
|
||||||
|
assert result.old_path.success_rate == 0.7
|
||||||
|
assert result.new_path.success_rate == 0.9
|
||||||
|
assert "成功率" in result.reason
|
||||||
|
|
||||||
|
|
||||||
|
# ── get_recommended_path 测试 ─────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetRecommendedPath:
|
||||||
|
async def test_no_recommended_path(self, optimizer):
|
||||||
|
result = optimizer.get_recommended_path("code_review")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
async def test_returns_recommended_path(self, optimizer):
|
||||||
|
path = _make_path(task_type="code_review", success_rate=0.8, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", path)
|
||||||
|
result = optimizer.get_recommended_path("code_review")
|
||||||
|
assert result is not None
|
||||||
|
assert result.success_rate == 0.8
|
||||||
|
assert result.is_recommended is True
|
||||||
|
|
||||||
|
async def test_different_task_types_independent(self, optimizer):
|
||||||
|
path_a = _make_path(task_type="code_review", success_rate=0.8, sample_count=5)
|
||||||
|
path_b = _make_path(task_type="data_analysis", success_rate=0.9, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", path_a)
|
||||||
|
await optimizer.evaluate_and_update("data_analysis", path_b)
|
||||||
|
|
||||||
|
result_a = optimizer.get_recommended_path("code_review")
|
||||||
|
result_b = optimizer.get_recommended_path("data_analysis")
|
||||||
|
assert result_a is not None
|
||||||
|
assert result_b is not None
|
||||||
|
assert result_a.success_rate == 0.8
|
||||||
|
assert result_b.success_rate == 0.9
|
||||||
|
|
||||||
|
|
||||||
|
# ── 样本量不足测试 ────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestInsufficientSamples:
|
||||||
|
async def test_insufficient_samples_no_update(self, optimizer):
|
||||||
|
"""样本量不足 → 不更新,记录待观察"""
|
||||||
|
path = _make_path(sample_count=2, success_rate=0.9)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", path)
|
||||||
|
assert result.updated is False
|
||||||
|
assert "样本量不足" in result.reason
|
||||||
|
assert optimizer.get_recommended_path("code_review") is None
|
||||||
|
|
||||||
|
async def test_insufficient_samples_recorded_as_pending(self, optimizer):
|
||||||
|
"""样本量不足的路径被记录到待观察列表"""
|
||||||
|
path = _make_path(sample_count=2, success_rate=0.9)
|
||||||
|
await optimizer.evaluate_and_update("code_review", path)
|
||||||
|
pending = optimizer.get_pending_paths("code_review")
|
||||||
|
assert len(pending) == 1
|
||||||
|
assert pending[0].success_rate == 0.9
|
||||||
|
|
||||||
|
async def test_exact_min_samples_updates(self, optimizer):
|
||||||
|
"""刚好达到最小样本量 → 可以更新"""
|
||||||
|
path = _make_path(sample_count=3, success_rate=0.8)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", path)
|
||||||
|
assert result.updated is True
|
||||||
|
assert result.reason == "无现有推荐路径,直接设为推荐"
|
||||||
|
|
||||||
|
async def test_custom_min_sample_count(self, optimizer_custom_thresholds):
|
||||||
|
"""自定义最小样本量"""
|
||||||
|
path = _make_path(sample_count=4, success_rate=0.9)
|
||||||
|
result = await optimizer_custom_thresholds.evaluate_and_update("code_review", path)
|
||||||
|
assert result.updated is False
|
||||||
|
assert "样本量不足" in result.reason
|
||||||
|
|
||||||
|
|
||||||
|
# ── 首次设置推荐路径测试 ──────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestFirstRecommendation:
|
||||||
|
async def test_first_path_becomes_recommended(self, optimizer):
|
||||||
|
"""无现有推荐路径时,新路径直接设为推荐"""
|
||||||
|
path = _make_path(success_rate=0.7, sample_count=5)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", path)
|
||||||
|
assert result.updated is True
|
||||||
|
assert result.old_path is None
|
||||||
|
assert result.new_path is not None
|
||||||
|
assert result.new_path.is_recommended is True
|
||||||
|
assert "无现有推荐路径" in result.reason
|
||||||
|
|
||||||
|
async def test_auto_generates_path_id(self, optimizer):
|
||||||
|
"""未提供 path_id 时自动生成"""
|
||||||
|
path = _make_path(path_id="", sample_count=5)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", path)
|
||||||
|
assert result.updated is True
|
||||||
|
assert result.new_path is not None
|
||||||
|
assert len(result.new_path.path_id) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── 成功率显著提升测试 ────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestSuccessRateImprovement:
|
||||||
|
async def test_higher_success_rate_updates(self, optimizer):
|
||||||
|
"""新路径成功率更高 → 更新推荐路径"""
|
||||||
|
old_path = _make_path(success_rate=0.7, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", old_path)
|
||||||
|
|
||||||
|
new_path = _make_path(success_rate=0.85, sample_count=5)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", new_path)
|
||||||
|
assert result.updated is True
|
||||||
|
assert result.old_path.success_rate == 0.7
|
||||||
|
assert result.new_path.success_rate == 0.85
|
||||||
|
assert "成功率显著提升" in result.reason
|
||||||
|
|
||||||
|
async def test_marginal_success_rate_no_update(self, optimizer):
|
||||||
|
"""成功率提升不足阈值 → 不更新"""
|
||||||
|
old_path = _make_path(success_rate=0.8, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", old_path)
|
||||||
|
|
||||||
|
# 提升仅 0.03,低于默认阈值 0.05
|
||||||
|
new_path = _make_path(success_rate=0.83, sample_count=5)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", new_path)
|
||||||
|
assert result.updated is False
|
||||||
|
assert "无明显优势" in result.reason
|
||||||
|
|
||||||
|
async def test_custom_success_rate_threshold(self, optimizer_custom_thresholds):
|
||||||
|
"""自定义成功率阈值"""
|
||||||
|
old_path = _make_path(success_rate=0.7, sample_count=10)
|
||||||
|
await optimizer_custom_thresholds.evaluate_and_update("code_review", old_path)
|
||||||
|
|
||||||
|
# 提升 0.08,低于自定义阈值 0.1
|
||||||
|
new_path = _make_path(success_rate=0.78, sample_count=10)
|
||||||
|
result = await optimizer_custom_thresholds.evaluate_and_update("code_review", new_path)
|
||||||
|
assert result.updated is False
|
||||||
|
|
||||||
|
async def test_lower_success_rate_no_update(self, optimizer):
|
||||||
|
"""新路径成功率更低 → 不更新"""
|
||||||
|
old_path = _make_path(success_rate=0.9, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", old_path)
|
||||||
|
|
||||||
|
new_path = _make_path(success_rate=0.6, sample_count=5)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", new_path)
|
||||||
|
assert result.updated is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── 耗时显著更短测试 ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestDurationImprovement:
|
||||||
|
async def test_shorter_duration_with_similar_success_rate_updates(self, optimizer):
|
||||||
|
"""成功率相近但耗时显著更短 → 更新推荐路径"""
|
||||||
|
old_path = _make_path(total_duration=100.0, success_rate=0.8, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", old_path)
|
||||||
|
|
||||||
|
# 耗时减少 30%(> 20% 阈值),成功率相近
|
||||||
|
new_path = _make_path(total_duration=70.0, success_rate=0.82, sample_count=5)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", new_path)
|
||||||
|
assert result.updated is True
|
||||||
|
assert "耗时显著更短" in result.reason
|
||||||
|
|
||||||
|
async def test_marginal_duration_improvement_no_update(self, optimizer):
|
||||||
|
"""耗时改善不足阈值 → 不更新"""
|
||||||
|
old_path = _make_path(total_duration=100.0, success_rate=0.8, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", old_path)
|
||||||
|
|
||||||
|
# 耗时减少仅 10%(< 20% 阈值)
|
||||||
|
new_path = _make_path(total_duration=90.0, success_rate=0.82, sample_count=5)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", new_path)
|
||||||
|
assert result.updated is False
|
||||||
|
assert "无明显优势" in result.reason
|
||||||
|
|
||||||
|
async def test_longer_duration_no_update(self, optimizer):
|
||||||
|
"""耗时更长 → 不更新"""
|
||||||
|
old_path = _make_path(total_duration=50.0, success_rate=0.8, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", old_path)
|
||||||
|
|
||||||
|
new_path = _make_path(total_duration=80.0, success_rate=0.82, sample_count=5)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", new_path)
|
||||||
|
assert result.updated is False
|
||||||
|
|
||||||
|
async def test_custom_duration_improvement_threshold(self, optimizer_custom_thresholds):
|
||||||
|
"""自定义耗时改善阈值"""
|
||||||
|
old_path = _make_path(total_duration=100.0, success_rate=0.8, sample_count=10)
|
||||||
|
await optimizer_custom_thresholds.evaluate_and_update("code_review", old_path)
|
||||||
|
|
||||||
|
# 耗时减少 25%(< 30% 自定义阈值)
|
||||||
|
new_path = _make_path(total_duration=75.0, success_rate=0.82, sample_count=10)
|
||||||
|
result = await optimizer_custom_thresholds.evaluate_and_update("code_review", new_path)
|
||||||
|
assert result.updated is False
|
||||||
|
|
||||||
|
async def test_zero_duration_current_path(self, optimizer):
|
||||||
|
"""现有路径耗时为 0 → 不因耗时更新"""
|
||||||
|
old_path = _make_path(total_duration=0.0, success_rate=0.8, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", old_path)
|
||||||
|
|
||||||
|
new_path = _make_path(total_duration=10.0, success_rate=0.82, sample_count=5)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", new_path)
|
||||||
|
assert result.updated is False
|
||||||
|
|
||||||
|
async def test_both_zero_duration(self, optimizer):
|
||||||
|
"""两者耗时均为 0 → 不因耗时更新"""
|
||||||
|
old_path = _make_path(total_duration=0.0, success_rate=0.8, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", old_path)
|
||||||
|
|
||||||
|
new_path = _make_path(total_duration=0.0, success_rate=0.82, sample_count=5)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", new_path)
|
||||||
|
assert result.updated is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── 保留现有推荐路径测试 ──────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestKeepCurrentPath:
|
||||||
|
async def test_no_advantage_keeps_current(self, optimizer):
|
||||||
|
"""新路径无明显优势 → 保留现有推荐路径"""
|
||||||
|
old_path = _make_path(total_duration=50.0, success_rate=0.8, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", old_path)
|
||||||
|
|
||||||
|
new_path = _make_path(total_duration=48.0, success_rate=0.79, sample_count=5)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", new_path)
|
||||||
|
assert result.updated is False
|
||||||
|
assert result.old_path.success_rate == 0.8
|
||||||
|
# 推荐路径不变
|
||||||
|
recommended = optimizer.get_recommended_path("code_review")
|
||||||
|
assert recommended is not None
|
||||||
|
assert recommended.success_rate == 0.8
|
||||||
|
|
||||||
|
async def test_is_recommended_flag_preserved(self, optimizer):
|
||||||
|
"""未更新时,现有路径的 is_recommended 标志保持为 True"""
|
||||||
|
old_path = _make_path(success_rate=0.8, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", old_path)
|
||||||
|
|
||||||
|
new_path = _make_path(success_rate=0.79, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", new_path)
|
||||||
|
|
||||||
|
recommended = optimizer.get_recommended_path("code_review")
|
||||||
|
assert recommended is not None
|
||||||
|
assert recommended.is_recommended is True
|
||||||
|
|
||||||
|
|
||||||
|
# ── is_recommended 标志管理测试 ────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestIsRecommendedFlag:
|
||||||
|
async def test_old_path_loses_recommended_flag(self, optimizer):
|
||||||
|
"""更新后旧路径的 is_recommended 变为 False"""
|
||||||
|
old_path = _make_path(success_rate=0.7, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", old_path)
|
||||||
|
assert old_path.is_recommended is True # 首次设置,is_recommended 为 True
|
||||||
|
|
||||||
|
new_path = _make_path(success_rate=0.9, sample_count=5)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", new_path)
|
||||||
|
assert result.updated is True
|
||||||
|
assert result.old_path.is_recommended is False # 更新后旧路径失去标志
|
||||||
|
assert result.new_path.is_recommended is True
|
||||||
|
|
||||||
|
|
||||||
|
# ── 多次迭代优化测试 ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestIterativeOptimization:
|
||||||
|
async def test_multiple_updates_converge_to_best(self, optimizer):
|
||||||
|
"""多次迭代后推荐路径收敛到最优"""
|
||||||
|
# 第一次:初始路径
|
||||||
|
path1 = _make_path(success_rate=0.6, total_duration=100.0, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", path1)
|
||||||
|
assert optimizer.get_recommended_path("code_review").success_rate == 0.6
|
||||||
|
|
||||||
|
# 第二次:成功率显著提升
|
||||||
|
path2 = _make_path(success_rate=0.8, total_duration=90.0, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", path2)
|
||||||
|
assert optimizer.get_recommended_path("code_review").success_rate == 0.8
|
||||||
|
|
||||||
|
# 第三次:成功率相近但耗时更短
|
||||||
|
path3 = _make_path(success_rate=0.82, total_duration=50.0, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", path3)
|
||||||
|
assert optimizer.get_recommended_path("code_review").total_duration == 50.0
|
||||||
|
|
||||||
|
# 第四次:无明显优势
|
||||||
|
path4 = _make_path(success_rate=0.81, total_duration=48.0, sample_count=5)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", path4)
|
||||||
|
assert result.updated is False
|
||||||
|
assert optimizer.get_recommended_path("code_review").total_duration == 50.0
|
||||||
|
|
||||||
|
async def test_different_task_types_evolve_independently(self, optimizer):
|
||||||
|
"""不同任务类型的推荐路径独立进化"""
|
||||||
|
path_a1 = _make_path(task_type="code_review", success_rate=0.7, sample_count=5)
|
||||||
|
path_b1 = _make_path(task_type="data_analysis", success_rate=0.6, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", path_a1)
|
||||||
|
await optimizer.evaluate_and_update("data_analysis", path_b1)
|
||||||
|
|
||||||
|
path_a2 = _make_path(task_type="code_review", success_rate=0.9, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", path_a2)
|
||||||
|
|
||||||
|
# code_review 更新了,data_analysis 不受影响
|
||||||
|
assert optimizer.get_recommended_path("code_review").success_rate == 0.9
|
||||||
|
assert optimizer.get_recommended_path("data_analysis").success_rate == 0.6
|
||||||
|
|
||||||
|
|
||||||
|
# ── 待观察路径管理测试 ────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestPendingPaths:
|
||||||
|
async def test_pending_paths_empty_initially(self, optimizer):
|
||||||
|
assert optimizer.get_pending_paths("code_review") == []
|
||||||
|
|
||||||
|
async def test_pending_paths_accumulate(self, optimizer):
|
||||||
|
"""多次样本不足的路径会累积"""
|
||||||
|
path1 = _make_path(sample_count=1, success_rate=0.9)
|
||||||
|
path2 = _make_path(sample_count=2, success_rate=0.85)
|
||||||
|
await optimizer.evaluate_and_update("code_review", path1)
|
||||||
|
await optimizer.evaluate_and_update("code_review", path2)
|
||||||
|
|
||||||
|
pending = optimizer.get_pending_paths("code_review")
|
||||||
|
assert len(pending) == 2
|
||||||
|
|
||||||
|
async def test_pending_paths_isolated_by_task_type(self, optimizer):
|
||||||
|
"""不同任务类型的待观察路径相互隔离"""
|
||||||
|
path_a = _make_path(task_type="code_review", sample_count=1, success_rate=0.9)
|
||||||
|
path_b = _make_path(task_type="data_analysis", sample_count=1, success_rate=0.8)
|
||||||
|
await optimizer.evaluate_and_update("code_review", path_a)
|
||||||
|
await optimizer.evaluate_and_update("data_analysis", path_b)
|
||||||
|
|
||||||
|
assert len(optimizer.get_pending_paths("code_review")) == 1
|
||||||
|
assert len(optimizer.get_pending_paths("data_analysis")) == 1
|
||||||
|
|
||||||
|
async def test_sufficient_samples_not_pending(self, optimizer):
|
||||||
|
"""样本量充足的路径不会进入待观察列表"""
|
||||||
|
path = _make_path(sample_count=5, success_rate=0.8)
|
||||||
|
await optimizer.evaluate_and_update("code_review", path)
|
||||||
|
assert optimizer.get_pending_paths("code_review") == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── ExperienceStore 集成测试 ──────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestExperienceStoreIntegration:
|
||||||
|
async def test_with_experience_store(self):
|
||||||
|
"""PathOptimizer 可以接受 ExperienceStore 实例"""
|
||||||
|
from agentkit.evolution.experience_store import InMemoryExperienceStore
|
||||||
|
|
||||||
|
store = InMemoryExperienceStore()
|
||||||
|
optimizer = PathOptimizer(experience_store=store, min_sample_count=3)
|
||||||
|
|
||||||
|
path = _make_path(success_rate=0.8, sample_count=5)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", path)
|
||||||
|
assert result.updated is True
|
||||||
|
|
||||||
|
async def test_without_experience_store(self, optimizer):
|
||||||
|
"""PathOptimizer 可以不依赖 ExperienceStore 独立运行"""
|
||||||
|
path = _make_path(success_rate=0.8, sample_count=5)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", path)
|
||||||
|
assert result.updated is True
|
||||||
|
|
||||||
|
|
||||||
|
# ── 边界条件测试 ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeCases:
|
||||||
|
async def test_same_path_twice(self, optimizer):
|
||||||
|
"""提交相同路径两次"""
|
||||||
|
path = _make_path(success_rate=0.8, sample_count=5)
|
||||||
|
result1 = await optimizer.evaluate_and_update("code_review", path)
|
||||||
|
assert result1.updated is True
|
||||||
|
|
||||||
|
# 第二次提交相同参数的路径(但不同实例)
|
||||||
|
path2 = _make_path(success_rate=0.8, sample_count=5)
|
||||||
|
result2 = await optimizer.evaluate_and_update("code_review", path2)
|
||||||
|
# 成功率相同,耗时相同 → 无明显优势
|
||||||
|
assert result2.updated is False
|
||||||
|
|
||||||
|
async def test_success_rate_at_boundary(self, optimizer):
|
||||||
|
"""成功率刚好在阈值边界"""
|
||||||
|
old_path = _make_path(success_rate=0.8, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", old_path)
|
||||||
|
|
||||||
|
# 提升恰好等于阈值 0.05,不满足 > threshold
|
||||||
|
new_path = _make_path(success_rate=0.85, sample_count=5)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", new_path)
|
||||||
|
assert result.updated is False
|
||||||
|
|
||||||
|
async def test_duration_improvement_at_boundary(self, optimizer):
|
||||||
|
"""耗时改善刚好在阈值边界"""
|
||||||
|
old_path = _make_path(total_duration=100.0, success_rate=0.8, sample_count=5)
|
||||||
|
await optimizer.evaluate_and_update("code_review", old_path)
|
||||||
|
|
||||||
|
# 改善恰好等于阈值 20%,不满足 > threshold
|
||||||
|
new_path = _make_path(total_duration=80.0, success_rate=0.82, sample_count=5)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", new_path)
|
||||||
|
assert result.updated is False
|
||||||
|
|
||||||
|
async def test_zero_sample_count(self, optimizer):
|
||||||
|
"""样本量为 0"""
|
||||||
|
path = _make_path(sample_count=0, success_rate=0.9)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", path)
|
||||||
|
assert result.updated is False
|
||||||
|
assert "样本量不足" in result.reason
|
||||||
|
|
||||||
|
async def test_path_task_type_override(self, optimizer):
|
||||||
|
"""evaluate_and_update 会用传入的 task_type 覆盖路径的 task_type"""
|
||||||
|
path = _make_path(task_type="wrong_type", success_rate=0.8, sample_count=5)
|
||||||
|
result = await optimizer.evaluate_and_update("code_review", path)
|
||||||
|
assert result.updated is True
|
||||||
|
assert path.task_type == "code_review"
|
||||||
|
recommended = optimizer.get_recommended_path("code_review")
|
||||||
|
assert recommended is not None
|
||||||
|
|
@ -0,0 +1,595 @@
|
||||||
|
"""Tests for PitfallDetector - 任务避坑预警检测"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.core.plan_schema import PlanStep, PlanStepStatus
|
||||||
|
from agentkit.evolution.experience_schema import TaskExperience
|
||||||
|
from agentkit.evolution.experience_store import InMemoryExperienceStore
|
||||||
|
from agentkit.evolution.pitfall_detector import (
|
||||||
|
PitfallDetector,
|
||||||
|
PitfallWarning,
|
||||||
|
WarningLevel,
|
||||||
|
_compute_name_similarity,
|
||||||
|
_determine_warning_level,
|
||||||
|
_extract_keywords,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def store():
|
||||||
|
"""无 embedder 的 InMemoryExperienceStore"""
|
||||||
|
return InMemoryExperienceStore(decay_rate=0.01, alpha=0.7)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def detector(store):
|
||||||
|
"""基于 InMemoryExperienceStore 的 PitfallDetector"""
|
||||||
|
return PitfallDetector(experience_store=store, similarity_threshold=0.3)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_experience(
|
||||||
|
task_type: str = "code_review",
|
||||||
|
goal: str = "Review the PR",
|
||||||
|
outcome: str = "success",
|
||||||
|
steps_summary: str | list[dict] = "",
|
||||||
|
failure_reasons: list[str] | None = None,
|
||||||
|
optimization_tips: list[str] | None = None,
|
||||||
|
success_rate: float = 1.0,
|
||||||
|
) -> TaskExperience:
|
||||||
|
"""创建测试用 TaskExperience"""
|
||||||
|
return TaskExperience(
|
||||||
|
experience_id="",
|
||||||
|
task_type=task_type,
|
||||||
|
goal=goal,
|
||||||
|
steps_summary=steps_summary,
|
||||||
|
outcome=outcome,
|
||||||
|
duration_seconds=10.0,
|
||||||
|
success_rate=success_rate,
|
||||||
|
failure_reasons=failure_reasons or [],
|
||||||
|
optimization_tips=optimization_tips or [],
|
||||||
|
created_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_step(
|
||||||
|
name: str = "step",
|
||||||
|
description: str = "do something",
|
||||||
|
step_id: str = "s1",
|
||||||
|
) -> PlanStep:
|
||||||
|
"""创建测试用 PlanStep"""
|
||||||
|
return PlanStep(
|
||||||
|
step_id=step_id,
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
status=PlanStepStatus.PENDING,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── 辅助函数测试 ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractKeywords:
|
||||||
|
def test_basic_extraction(self):
|
||||||
|
keywords = _extract_keywords("Call API Gateway")
|
||||||
|
assert "call" in keywords
|
||||||
|
assert "api" in keywords
|
||||||
|
assert "gateway" in keywords
|
||||||
|
|
||||||
|
def test_stop_words_filtered(self):
|
||||||
|
keywords = _extract_keywords("Call the API and check the result")
|
||||||
|
assert "the" not in keywords
|
||||||
|
assert "and" not in keywords
|
||||||
|
assert "call" in keywords
|
||||||
|
assert "api" in keywords
|
||||||
|
|
||||||
|
def test_underscore_and_hyphen(self):
|
||||||
|
keywords = _extract_keywords("call_api-gateway")
|
||||||
|
assert "call" in keywords
|
||||||
|
assert "api" in keywords
|
||||||
|
assert "gateway" in keywords
|
||||||
|
|
||||||
|
def test_single_char_filtered(self):
|
||||||
|
keywords = _extract_keywords("a b cd")
|
||||||
|
assert "a" not in keywords
|
||||||
|
assert "b" not in keywords
|
||||||
|
assert "cd" in keywords
|
||||||
|
|
||||||
|
def test_empty_string(self):
|
||||||
|
keywords = _extract_keywords("")
|
||||||
|
assert len(keywords) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestComputeNameSimilarity:
|
||||||
|
def test_identical_names(self):
|
||||||
|
sim = _compute_name_similarity("Call API Gateway", "", "Call API Gateway")
|
||||||
|
assert sim == pytest.approx(1.0)
|
||||||
|
|
||||||
|
def test_partial_overlap(self):
|
||||||
|
sim = _compute_name_similarity("Call API Gateway", "", "Call External API")
|
||||||
|
# 共享: call, api; 并集: call, api, gateway, external
|
||||||
|
assert 0.0 < sim < 1.0
|
||||||
|
|
||||||
|
def test_no_overlap(self):
|
||||||
|
sim = _compute_name_similarity("Deploy Service", "", "Analyze Data")
|
||||||
|
assert sim == 0.0
|
||||||
|
|
||||||
|
def test_description_contributes(self):
|
||||||
|
sim_no_desc = _compute_name_similarity("Deploy", "", "Deploy Service")
|
||||||
|
sim_with_desc = _compute_name_similarity("Deploy", "Deploy Service", "Deploy Service")
|
||||||
|
# description 中包含匹配关键词,应提高相似度
|
||||||
|
assert sim_with_desc >= sim_no_desc
|
||||||
|
|
||||||
|
def test_empty_inputs(self):
|
||||||
|
sim = _compute_name_similarity("", "", "Call API")
|
||||||
|
assert sim == 0.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestDetermineWarningLevel:
|
||||||
|
def test_high_threshold(self):
|
||||||
|
assert _determine_warning_level(0.6) == WarningLevel.HIGH
|
||||||
|
assert _determine_warning_level(0.5) == WarningLevel.HIGH
|
||||||
|
|
||||||
|
def test_medium_threshold(self):
|
||||||
|
assert _determine_warning_level(0.3) == WarningLevel.MEDIUM
|
||||||
|
assert _determine_warning_level(0.2) == WarningLevel.MEDIUM
|
||||||
|
|
||||||
|
def test_low_threshold(self):
|
||||||
|
assert _determine_warning_level(0.1) == WarningLevel.LOW
|
||||||
|
assert _determine_warning_level(0.01) == WarningLevel.LOW
|
||||||
|
|
||||||
|
|
||||||
|
# ── PitfallDetector.check_pitfalls 测试 ──────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckPitfalls:
|
||||||
|
async def test_no_planned_steps_returns_empty(self, detector):
|
||||||
|
warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=[])
|
||||||
|
assert warnings == []
|
||||||
|
|
||||||
|
async def test_no_failed_experiences_returns_empty(self, detector, store):
|
||||||
|
"""无历史失败记录 → 返回空列表"""
|
||||||
|
# 只记录成功经验
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(task_type="code_review", outcome="success")
|
||||||
|
)
|
||||||
|
steps = [_make_step(name="Review Code")]
|
||||||
|
warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps)
|
||||||
|
assert warnings == []
|
||||||
|
|
||||||
|
async def test_high_failure_rate_returns_high_warning(self, detector, store):
|
||||||
|
"""计划包含历史高失败率步骤 → 返回 HIGH 级别预警"""
|
||||||
|
# 记录多次失败经验,其中 "Call API Gateway" 步骤失败率高
|
||||||
|
for _ in range(6):
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="deployment",
|
||||||
|
outcome="failure",
|
||||||
|
success_rate=0.0,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Call API Gateway", "outcome": "failure", "error": "Timeout"},
|
||||||
|
{"step_name": "Deploy Container", "outcome": "success"},
|
||||||
|
],
|
||||||
|
failure_reasons=["API Gateway timeout"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# 记录少数成功经验
|
||||||
|
for _ in range(4):
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="deployment",
|
||||||
|
outcome="success",
|
||||||
|
success_rate=1.0,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Call API Gateway", "outcome": "success"},
|
||||||
|
{"step_name": "Deploy Container", "outcome": "success"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
steps = [_make_step(name="Call API Gateway", description="Invoke API Gateway endpoint")]
|
||||||
|
warnings = await detector.check_pitfalls(task_type="deployment", planned_steps=steps)
|
||||||
|
|
||||||
|
assert len(warnings) == 1
|
||||||
|
warning = warnings[0]
|
||||||
|
assert warning.step_name == "Call API Gateway"
|
||||||
|
assert warning.warning_level == WarningLevel.HIGH
|
||||||
|
assert warning.failure_rate >= 0.5
|
||||||
|
assert "Timeout" in warning.historical_failures
|
||||||
|
|
||||||
|
async def test_medium_failure_rate(self, detector, store):
|
||||||
|
"""中等失败率 → MEDIUM 级别预警"""
|
||||||
|
# 3 次失败,7 次成功 → 失败率 0.3
|
||||||
|
for _ in range(3):
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="data_analysis",
|
||||||
|
outcome="failure",
|
||||||
|
success_rate=0.0,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Fetch Data", "outcome": "failure", "error": "Connection refused"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for _ in range(7):
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="data_analysis",
|
||||||
|
outcome="success",
|
||||||
|
success_rate=1.0,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Fetch Data", "outcome": "success"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
steps = [_make_step(name="Fetch Data", description="Fetch data from source")]
|
||||||
|
warnings = await detector.check_pitfalls(task_type="data_analysis", planned_steps=steps)
|
||||||
|
|
||||||
|
assert len(warnings) == 1
|
||||||
|
assert warnings[0].warning_level == WarningLevel.MEDIUM
|
||||||
|
assert 0.2 <= warnings[0].failure_rate < 0.5
|
||||||
|
|
||||||
|
async def test_low_failure_rate(self, detector, store):
|
||||||
|
"""低失败率 → LOW 级别预警"""
|
||||||
|
# 1 次失败,9 次成功 → 失败率 0.1
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="testing",
|
||||||
|
outcome="failure",
|
||||||
|
success_rate=0.0,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Run Unit Tests", "outcome": "failure", "error": "Flaky test"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for _ in range(9):
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="testing",
|
||||||
|
outcome="success",
|
||||||
|
success_rate=1.0,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Run Unit Tests", "outcome": "success"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
steps = [_make_step(name="Run Unit Tests", description="Execute unit test suite")]
|
||||||
|
warnings = await detector.check_pitfalls(task_type="testing", planned_steps=steps)
|
||||||
|
|
||||||
|
assert len(warnings) == 1
|
||||||
|
assert warnings[0].warning_level == WarningLevel.LOW
|
||||||
|
|
||||||
|
async def test_multiple_steps_with_risks_sorted_by_severity(self, detector, store):
|
||||||
|
"""多个步骤有风险 → 按严重程度排序返回"""
|
||||||
|
# "Call API" 高失败率,"Validate Input" 低失败率
|
||||||
|
for _ in range(6):
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="integration",
|
||||||
|
outcome="failure",
|
||||||
|
success_rate=0.0,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Call API", "outcome": "failure", "error": "Timeout"},
|
||||||
|
{"step_name": "Validate Input", "outcome": "success"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for _ in range(4):
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="integration",
|
||||||
|
outcome="success",
|
||||||
|
success_rate=1.0,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Call API", "outcome": "success"},
|
||||||
|
{"step_name": "Validate Input", "outcome": "success"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# 单独给 Validate Input 加一条失败记录
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="integration",
|
||||||
|
outcome="partial",
|
||||||
|
success_rate=0.5,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Call API", "outcome": "success"},
|
||||||
|
{"step_name": "Validate Input", "outcome": "failure", "error": "Invalid schema"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
steps = [
|
||||||
|
_make_step(name="Validate Input", description="Validate input data", step_id="s1"),
|
||||||
|
_make_step(name="Call API", description="Call external API", step_id="s2"),
|
||||||
|
]
|
||||||
|
warnings = await detector.check_pitfalls(task_type="integration", planned_steps=steps)
|
||||||
|
|
||||||
|
assert len(warnings) == 2
|
||||||
|
# HIGH 应排在 MEDIUM/LOW 之前
|
||||||
|
assert warnings[0].warning_level == WarningLevel.HIGH
|
||||||
|
assert warnings[0].step_name == "Call API"
|
||||||
|
|
||||||
|
async def test_no_matching_steps_returns_empty(self, detector, store):
|
||||||
|
"""计划步骤与历史失败步骤无匹配 → 返回空列表"""
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="code_review",
|
||||||
|
outcome="failure",
|
||||||
|
success_rate=0.0,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Run Linter", "outcome": "failure", "error": "Config error"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 计划步骤名称与历史步骤完全不同
|
||||||
|
steps = [_make_step(name="Deploy Application", description="Deploy to production")]
|
||||||
|
warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps)
|
||||||
|
assert warnings == []
|
||||||
|
|
||||||
|
async def test_different_task_type_no_cross_contamination(self, detector, store):
|
||||||
|
"""不同 task_type 的失败经验不会跨类型预警"""
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="deployment",
|
||||||
|
outcome="failure",
|
||||||
|
success_rate=0.0,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Deploy Service", "outcome": "failure", "error": "OOM"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 查询 code_review 类型,不应返回 deployment 的失败经验
|
||||||
|
steps = [_make_step(name="Deploy Service", description="Deploy the service")]
|
||||||
|
warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps)
|
||||||
|
assert warnings == []
|
||||||
|
|
||||||
|
async def test_partial_outcome_included(self, detector, store):
|
||||||
|
"""partial 结果的经验也应被检索"""
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="migration",
|
||||||
|
outcome="partial",
|
||||||
|
success_rate=0.5,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Migrate Database", "outcome": "failure", "error": "Schema mismatch"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
steps = [_make_step(name="Migrate Database", description="Migrate DB schema")]
|
||||||
|
warnings = await detector.check_pitfalls(task_type="migration", planned_steps=steps)
|
||||||
|
assert len(warnings) == 1
|
||||||
|
|
||||||
|
async def test_steps_summary_as_string_ignored(self, detector, store):
|
||||||
|
"""steps_summary 为字符串时无法提取步骤级信息,不产生预警"""
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="code_review",
|
||||||
|
outcome="failure",
|
||||||
|
success_rate=0.0,
|
||||||
|
steps_summary="Executed code_review task", # 字符串格式
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
steps = [_make_step(name="Review Code", description="Review the code")]
|
||||||
|
warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps)
|
||||||
|
assert warnings == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── AE3 场景测试 ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestAE3Scenario:
|
||||||
|
"""AE3: "调用 X 系统 API 在高峰期超时率 60%" → 新任务调用时自动预警"""
|
||||||
|
|
||||||
|
async def test_api_timeout_high_failure_rate_warning(self, detector, store):
|
||||||
|
"""调用 X 系统 API 在高峰期超时率 60% → 新任务调用时自动预警"""
|
||||||
|
# 模拟历史:10 次调用,6 次超时 → 60% 失败率
|
||||||
|
for _ in range(6):
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="order_processing",
|
||||||
|
goal="Process orders via X system",
|
||||||
|
outcome="failure",
|
||||||
|
success_rate=0.0,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Call X System API", "outcome": "failure", "error": "高峰期超时"},
|
||||||
|
{"step_name": "Process Order", "outcome": "success"},
|
||||||
|
],
|
||||||
|
failure_reasons=["X System API timeout during peak hours"],
|
||||||
|
optimization_tips=["Avoid peak hours", "Add retry logic"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for _ in range(4):
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="order_processing",
|
||||||
|
goal="Process orders via X system",
|
||||||
|
outcome="success",
|
||||||
|
success_rate=1.0,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Call X System API", "outcome": "success"},
|
||||||
|
{"step_name": "Process Order", "outcome": "success"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 新任务计划包含调用 X 系统 API
|
||||||
|
steps = [
|
||||||
|
_make_step(name="Call X System API", description="Invoke X system API for orders"),
|
||||||
|
]
|
||||||
|
warnings = await detector.check_pitfalls(task_type="order_processing", planned_steps=steps)
|
||||||
|
|
||||||
|
assert len(warnings) == 1
|
||||||
|
warning = warnings[0]
|
||||||
|
assert warning.warning_level == WarningLevel.HIGH
|
||||||
|
assert warning.failure_rate >= 0.5
|
||||||
|
assert any("超时" in reason for reason in warning.historical_failures)
|
||||||
|
assert warning.suggestion # 应有建议
|
||||||
|
|
||||||
|
|
||||||
|
# ── PitfallWarning 数据模型测试 ───────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestPitfallWarning:
|
||||||
|
def test_creation(self):
|
||||||
|
warning = PitfallWarning(
|
||||||
|
step_name="Call API",
|
||||||
|
warning_level=WarningLevel.HIGH,
|
||||||
|
failure_rate=0.6,
|
||||||
|
historical_failures=["Timeout", "Connection refused"],
|
||||||
|
suggestion="Add retry logic",
|
||||||
|
)
|
||||||
|
assert warning.step_name == "Call API"
|
||||||
|
assert warning.warning_level == WarningLevel.HIGH
|
||||||
|
assert warning.failure_rate == 0.6
|
||||||
|
assert warning.historical_failures == ["Timeout", "Connection refused"]
|
||||||
|
assert warning.suggestion == "Add retry logic"
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
warning = PitfallWarning(
|
||||||
|
step_name="Test",
|
||||||
|
warning_level=WarningLevel.LOW,
|
||||||
|
failure_rate=0.1,
|
||||||
|
)
|
||||||
|
assert warning.historical_failures == []
|
||||||
|
assert warning.suggestion == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ── WarningLevel 枚举测试 ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestWarningLevel:
|
||||||
|
def test_values(self):
|
||||||
|
assert WarningLevel.HIGH.value == "high"
|
||||||
|
assert WarningLevel.MEDIUM.value == "medium"
|
||||||
|
assert WarningLevel.LOW.value == "low"
|
||||||
|
|
||||||
|
def test_string_comparison(self):
|
||||||
|
assert WarningLevel.HIGH == "high"
|
||||||
|
assert WarningLevel.MEDIUM == "medium"
|
||||||
|
assert WarningLevel.LOW == "low"
|
||||||
|
|
||||||
|
|
||||||
|
# ── 相似度阈值配置测试 ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestSimilarityThreshold:
|
||||||
|
async def test_custom_threshold(self, store):
|
||||||
|
"""自定义相似度阈值"""
|
||||||
|
# 低阈值:更容易匹配
|
||||||
|
detector_low = PitfallDetector(experience_store=store, similarity_threshold=0.1)
|
||||||
|
# 高阈值:更难匹配
|
||||||
|
detector_high = PitfallDetector(experience_store=store, similarity_threshold=0.8)
|
||||||
|
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="testing",
|
||||||
|
outcome="failure",
|
||||||
|
success_rate=0.0,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Run Integration Tests", "outcome": "failure", "error": "Timeout"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
steps = [_make_step(name="Run Unit Tests", description="Execute tests")]
|
||||||
|
# 低阈值可能匹配,高阈值可能不匹配
|
||||||
|
warnings_low = await detector_low.check_pitfalls(task_type="testing", planned_steps=steps)
|
||||||
|
warnings_high = await detector_high.check_pitfalls(task_type="testing", planned_steps=steps)
|
||||||
|
# 低阈值匹配数 >= 高阈值匹配数
|
||||||
|
assert len(warnings_low) >= len(warnings_high)
|
||||||
|
|
||||||
|
|
||||||
|
# ── 端到端流程测试 ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestEndToEnd:
|
||||||
|
async def test_full_pitfall_detection_flow(self, detector, store):
|
||||||
|
"""完整的避坑检测流程"""
|
||||||
|
# 1. 记录多种失败经验
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="deployment",
|
||||||
|
outcome="failure",
|
||||||
|
success_rate=0.0,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Build Docker Image", "outcome": "failure", "error": "OOM"},
|
||||||
|
{"step_name": "Push to Registry", "outcome": "success"},
|
||||||
|
],
|
||||||
|
failure_reasons=["Docker build OOM"],
|
||||||
|
optimization_tips=["Increase memory limit"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="deployment",
|
||||||
|
outcome="failure",
|
||||||
|
success_rate=0.0,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Build Docker Image", "outcome": "failure", "error": "Dependency conflict"},
|
||||||
|
{"step_name": "Push to Registry", "outcome": "success"},
|
||||||
|
],
|
||||||
|
failure_reasons=["Dependency conflict"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="deployment",
|
||||||
|
outcome="success",
|
||||||
|
success_rate=1.0,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Build Docker Image", "outcome": "success"},
|
||||||
|
{"step_name": "Push to Registry", "outcome": "success"},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. 新任务计划
|
||||||
|
steps = [
|
||||||
|
_make_step(name="Build Docker Image", description="Build the container image", step_id="s1"),
|
||||||
|
_make_step(name="Push to Registry", description="Push image to container registry", step_id="s2"),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 3. 检测避坑
|
||||||
|
warnings = await detector.check_pitfalls(task_type="deployment", planned_steps=steps)
|
||||||
|
|
||||||
|
# 4. 验证结果
|
||||||
|
assert len(warnings) >= 1
|
||||||
|
# Build Docker Image 失败率 2/3 ≈ 0.667,应为 HIGH
|
||||||
|
build_warning = next((w for w in warnings if w.step_name == "Build Docker Image"), None)
|
||||||
|
assert build_warning is not None
|
||||||
|
assert build_warning.warning_level == WarningLevel.HIGH
|
||||||
|
assert build_warning.failure_rate == pytest.approx(2.0 / 3.0, abs=0.01)
|
||||||
|
|
||||||
|
async def test_suggestion_contains_useful_info(self, detector, store):
|
||||||
|
"""预警建议应包含有用的失败原因和优化建议"""
|
||||||
|
await store.record_experience(
|
||||||
|
_make_experience(
|
||||||
|
task_type="api_integration",
|
||||||
|
outcome="failure",
|
||||||
|
success_rate=0.0,
|
||||||
|
steps_summary=[
|
||||||
|
{"step_name": "Authenticate", "outcome": "failure", "error": "Token expired"},
|
||||||
|
],
|
||||||
|
failure_reasons=["Token expired"],
|
||||||
|
optimization_tips=["Refresh token before expiry"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
steps = [_make_step(name="Authenticate", description="Authenticate with API")]
|
||||||
|
warnings = await detector.check_pitfalls(task_type="api_integration", planned_steps=steps)
|
||||||
|
|
||||||
|
assert len(warnings) == 1
|
||||||
|
assert "Token expired" in warnings[0].suggestion
|
||||||
|
|
@ -0,0 +1,217 @@
|
||||||
|
"""PTYSession 单元测试
|
||||||
|
|
||||||
|
测试场景:
|
||||||
|
- PTY 会话启动和关闭
|
||||||
|
- 交互式命令执行
|
||||||
|
- 自动应答提示
|
||||||
|
- 超时处理
|
||||||
|
- 自定义应答规则
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.tools.pty_session import PTYSession, PTYOutput
|
||||||
|
from agentkit.tools.shell import ShellTool
|
||||||
|
|
||||||
|
|
||||||
|
class TestPTYSessionConstruction:
|
||||||
|
"""测试 PTYSession 构造"""
|
||||||
|
|
||||||
|
def test_default_construction(self):
|
||||||
|
pty = PTYSession()
|
||||||
|
assert pty.is_running is False
|
||||||
|
assert pty._auto_respond is True
|
||||||
|
assert pty._default_timeout == 30.0
|
||||||
|
|
||||||
|
def test_custom_construction(self):
|
||||||
|
pty = PTYSession(
|
||||||
|
auto_respond=False,
|
||||||
|
custom_rules=[(r"confirm\?", "yes")],
|
||||||
|
default_timeout=60.0,
|
||||||
|
)
|
||||||
|
assert pty._auto_respond is False
|
||||||
|
assert len(pty._respond_rules) > len(pty._respond_rules) - 1
|
||||||
|
assert pty._default_timeout == 60.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestPTYSessionLifecycle:
|
||||||
|
"""测试 PTYSession 生命周期"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_and_close(self):
|
||||||
|
"""启动和关闭 PTY 会话"""
|
||||||
|
pty = PTYSession()
|
||||||
|
await pty.start()
|
||||||
|
assert pty.is_running is True
|
||||||
|
await pty.close()
|
||||||
|
assert pty.is_running is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_idempotent(self):
|
||||||
|
"""重复启动不报错"""
|
||||||
|
pty = PTYSession()
|
||||||
|
await pty.start()
|
||||||
|
await pty.start() # 不应抛出异常
|
||||||
|
assert pty.is_running is True
|
||||||
|
await pty.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_close_without_start(self):
|
||||||
|
"""未启动时关闭不报错"""
|
||||||
|
pty = PTYSession()
|
||||||
|
await pty.close() # 不应抛出异常
|
||||||
|
|
||||||
|
|
||||||
|
class TestPTYSessionExecution:
|
||||||
|
"""测试 PTYSession 命令执行"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_simple_command(self):
|
||||||
|
"""执行简单命令"""
|
||||||
|
pty = PTYSession(default_timeout=10.0)
|
||||||
|
try:
|
||||||
|
await pty.start()
|
||||||
|
result = await pty.run_command("echo hello_pty")
|
||||||
|
assert "hello_pty" in result.output
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert result.timed_out is False
|
||||||
|
finally:
|
||||||
|
await pty.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_command_with_cwd(self):
|
||||||
|
"""指定工作目录执行命令"""
|
||||||
|
pty = PTYSession(default_timeout=10.0)
|
||||||
|
try:
|
||||||
|
await pty.start()
|
||||||
|
result = await pty.run_command("pwd", cwd="/tmp")
|
||||||
|
assert "/tmp" in result.output
|
||||||
|
finally:
|
||||||
|
await pty.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_command_with_env(self):
|
||||||
|
"""指定环境变量执行命令"""
|
||||||
|
pty = PTYSession(default_timeout=10.0)
|
||||||
|
try:
|
||||||
|
await pty.start()
|
||||||
|
result = await pty.run_command(
|
||||||
|
"echo $PTY_TEST_VAR",
|
||||||
|
env={"PTY_TEST_VAR": "pty_value"},
|
||||||
|
)
|
||||||
|
assert "pty_value" in result.output
|
||||||
|
finally:
|
||||||
|
await pty.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_failing_command(self):
|
||||||
|
"""执行失败命令"""
|
||||||
|
pty = PTYSession(default_timeout=10.0)
|
||||||
|
try:
|
||||||
|
await pty.start()
|
||||||
|
result = await pty.run_command("ls /nonexistent_dir_xyz_12345")
|
||||||
|
assert result.exit_code != 0
|
||||||
|
finally:
|
||||||
|
await pty.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_command_timeout(self):
|
||||||
|
"""命令超时"""
|
||||||
|
pty = PTYSession(default_timeout=10.0)
|
||||||
|
try:
|
||||||
|
await pty.start()
|
||||||
|
result = await pty.run_command("sleep 30", timeout=0.5)
|
||||||
|
assert result.timed_out is True
|
||||||
|
assert result.exit_code == -1
|
||||||
|
finally:
|
||||||
|
await pty.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TestPTYSessionAutoRespond:
|
||||||
|
"""测试 PTYSession 自动应答"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_respond_yes_no(self):
|
||||||
|
"""自动应答 [y/N] 提示"""
|
||||||
|
# 使用 echo 模拟包含提示的输出,然后验证自动应答规则存在
|
||||||
|
pty = PTYSession(auto_respond=True)
|
||||||
|
# 验证规则已加载
|
||||||
|
rule_patterns = [r[0] for r in pty._respond_rules]
|
||||||
|
assert any("y/N" in p or "Y/n" in p for p in rule_patterns)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auto_respond_disabled(self):
|
||||||
|
"""禁用自动应答"""
|
||||||
|
pty = PTYSession(auto_respond=False)
|
||||||
|
assert pty._auto_respond is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_custom_respond_rules(self):
|
||||||
|
"""自定义应答规则"""
|
||||||
|
pty = PTYSession(
|
||||||
|
auto_respond=True,
|
||||||
|
custom_rules=[(r"continue\?\s*$", "yes")],
|
||||||
|
)
|
||||||
|
rule_patterns = [r[0] for r in pty._respond_rules]
|
||||||
|
assert r"continue\?\s*$" in rule_patterns
|
||||||
|
|
||||||
|
|
||||||
|
class TestPTYSessionSendAndRead:
|
||||||
|
"""测试 PTYSession 发送和读取"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_without_start(self):
|
||||||
|
"""未启动时发送不报错"""
|
||||||
|
pty = PTYSession()
|
||||||
|
await pty.send("test") # 不应抛出异常
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_read_output_without_start(self):
|
||||||
|
"""未启动时读取返回空"""
|
||||||
|
pty = PTYSession()
|
||||||
|
output = await pty.read_output()
|
||||||
|
assert output == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestPTYOutput:
|
||||||
|
"""测试 PTYOutput 数据类"""
|
||||||
|
|
||||||
|
def test_default_values(self):
|
||||||
|
output = PTYOutput(output="test")
|
||||||
|
assert output.output == "test"
|
||||||
|
assert output.exit_code == -1
|
||||||
|
assert output.timed_out is False
|
||||||
|
|
||||||
|
def test_custom_values(self):
|
||||||
|
output = PTYOutput(output="error", exit_code=1, timed_out=True)
|
||||||
|
assert output.exit_code == 1
|
||||||
|
assert output.timed_out is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestShellToolInteractiveMode:
|
||||||
|
"""测试 ShellTool 交互式模式"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_interactive_mode(self):
|
||||||
|
"""ShellTool interactive 模式执行命令"""
|
||||||
|
tool = ShellTool()
|
||||||
|
result = await tool.execute(command="echo interactive_test", interactive=True)
|
||||||
|
assert result["exit_code"] == 0
|
||||||
|
assert "interactive_test" in result["output"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_interactive_mode_with_session(self):
|
||||||
|
"""ShellTool 会话模式 + 交互式"""
|
||||||
|
tool = ShellTool()
|
||||||
|
result = await tool.execute(
|
||||||
|
command="echo session_interactive",
|
||||||
|
session_id="int-session",
|
||||||
|
interactive=True,
|
||||||
|
)
|
||||||
|
assert result["exit_code"] == 0
|
||||||
|
assert "session_interactive" in result["output"]
|
||||||
|
|
@ -0,0 +1,600 @@
|
||||||
|
"""TerminalSession 和 ShellTool 单元测试
|
||||||
|
|
||||||
|
测试场景:
|
||||||
|
- 跨命令保持 cwd → cd 后执行 pwd 返回正确目录
|
||||||
|
- 跨命令保持 env → export 后执行 echo 返回正确值
|
||||||
|
- 危险命令需确认 → rm 命令触发确认回调
|
||||||
|
- 输出解析 → 错误输出结构化为错误类型+建议
|
||||||
|
- 无 session_id 时保持现有行为
|
||||||
|
- 会话管理器功能
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.tools.terminal_session import TerminalSession, TerminalSessionManager, CommandRecord
|
||||||
|
from agentkit.tools.shell import ShellTool
|
||||||
|
from agentkit.tools.output_parser import OutputParser, ParsedOutput, ErrorType
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# OutputParser 测试
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestOutputParser:
|
||||||
|
"""测试 OutputParser 结构化解析"""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
self.parser = OutputParser()
|
||||||
|
|
||||||
|
def test_parse_success_output(self):
|
||||||
|
"""成功输出解析"""
|
||||||
|
result = self.parser.parse("hello world", 0)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert result.is_error is False
|
||||||
|
assert result.error_type == ErrorType.NONE
|
||||||
|
assert result.message == "hello world"
|
||||||
|
assert result.suggestions == []
|
||||||
|
|
||||||
|
def test_parse_empty_output(self):
|
||||||
|
"""空输出解析"""
|
||||||
|
result = self.parser.parse("", 0)
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert result.is_error is False
|
||||||
|
assert result.message == ""
|
||||||
|
|
||||||
|
def test_parse_permission_denied(self):
|
||||||
|
"""权限不足错误解析"""
|
||||||
|
result = self.parser.parse("permission denied: /root/secret", 1)
|
||||||
|
assert result.is_error is True
|
||||||
|
assert result.error_type == ErrorType.PERMISSION_DENIED
|
||||||
|
assert len(result.suggestions) > 0
|
||||||
|
assert any("sudo" in s for s in result.suggestions)
|
||||||
|
|
||||||
|
def test_parse_not_found(self):
|
||||||
|
"""文件不存在错误解析"""
|
||||||
|
result = self.parser.parse("No such file or directory: /tmp/missing", 1)
|
||||||
|
assert result.is_error is True
|
||||||
|
assert result.error_type == ErrorType.NOT_FOUND
|
||||||
|
|
||||||
|
def test_parse_timeout(self):
|
||||||
|
"""超时错误解析"""
|
||||||
|
result = self.parser.parse("Connection timed out", 1)
|
||||||
|
assert result.is_error is True
|
||||||
|
assert result.error_type == ErrorType.TIMEOUT
|
||||||
|
|
||||||
|
def test_parse_syntax_error(self):
|
||||||
|
"""语法错误解析"""
|
||||||
|
result = self.parser.parse("syntax error near unexpected token", 2)
|
||||||
|
assert result.is_error is True
|
||||||
|
assert result.error_type == ErrorType.SYNTAX_ERROR
|
||||||
|
|
||||||
|
def test_parse_connection_refused(self):
|
||||||
|
"""连接被拒绝解析"""
|
||||||
|
result = self.parser.parse("Connection refused on port 8080", 1)
|
||||||
|
assert result.is_error is True
|
||||||
|
assert result.error_type == ErrorType.CONNECTION_REFUSED
|
||||||
|
|
||||||
|
def test_parse_out_of_memory(self):
|
||||||
|
"""内存不足解析"""
|
||||||
|
result = self.parser.parse("Out of memory: cannot allocate", 1)
|
||||||
|
assert result.is_error is True
|
||||||
|
assert result.error_type == ErrorType.OUT_OF_MEMORY
|
||||||
|
|
||||||
|
def test_parse_disk_full(self):
|
||||||
|
"""磁盘满解析"""
|
||||||
|
result = self.parser.parse("No space left on device", 1)
|
||||||
|
assert result.is_error is True
|
||||||
|
assert result.error_type == ErrorType.DISK_FULL
|
||||||
|
|
||||||
|
def test_parse_already_exists(self):
|
||||||
|
"""已存在解析"""
|
||||||
|
result = self.parser.parse("File already exists: /tmp/test", 1)
|
||||||
|
assert result.is_error is True
|
||||||
|
assert result.error_type == ErrorType.ALREADY_EXISTS
|
||||||
|
|
||||||
|
def test_parse_invalid_argument(self):
|
||||||
|
"""无效参数解析"""
|
||||||
|
result = self.parser.parse("invalid argument: --unknown-flag", 1)
|
||||||
|
assert result.is_error is True
|
||||||
|
assert result.error_type == ErrorType.INVALID_ARGUMENT
|
||||||
|
|
||||||
|
def test_parse_network_error(self):
|
||||||
|
"""网络错误解析"""
|
||||||
|
result = self.parser.parse("Network is unreachable", 1)
|
||||||
|
assert result.is_error is True
|
||||||
|
assert result.error_type == ErrorType.NETWORK_ERROR
|
||||||
|
|
||||||
|
def test_parse_exit_code_126(self):
|
||||||
|
"""退出码 126 → 权限不足"""
|
||||||
|
result = self.parser.parse("some unknown error", 126)
|
||||||
|
assert result.is_error is True
|
||||||
|
assert result.error_type == ErrorType.PERMISSION_DENIED
|
||||||
|
|
||||||
|
def test_parse_exit_code_127(self):
|
||||||
|
"""退出码 127 → 命令未找到"""
|
||||||
|
result = self.parser.parse("some unknown error", 127)
|
||||||
|
assert result.is_error is True
|
||||||
|
assert result.error_type == ErrorType.NOT_FOUND
|
||||||
|
|
||||||
|
def test_parse_exit_code_130(self):
|
||||||
|
"""退出码 130 → 被中断"""
|
||||||
|
result = self.parser.parse("some unknown error", 130)
|
||||||
|
assert result.is_error is True
|
||||||
|
assert result.error_type == ErrorType.TIMEOUT
|
||||||
|
|
||||||
|
def test_parse_unknown_error(self):
|
||||||
|
"""未知错误"""
|
||||||
|
result = self.parser.parse("something went wrong", 1)
|
||||||
|
assert result.is_error is True
|
||||||
|
assert result.error_type == ErrorType.UNKNOWN
|
||||||
|
|
||||||
|
def test_parse_long_message_truncated(self):
|
||||||
|
"""长消息截断"""
|
||||||
|
long_output = "x" * 300
|
||||||
|
result = self.parser.parse(long_output, 0)
|
||||||
|
assert len(result.message) <= 203 # 200 + "..."
|
||||||
|
|
||||||
|
def test_parsed_output_to_dict(self):
|
||||||
|
"""ParsedOutput.to_dict()"""
|
||||||
|
result = self.parser.parse("permission denied", 1)
|
||||||
|
d = result.to_dict()
|
||||||
|
assert d["exit_code"] == 1
|
||||||
|
assert d["is_error"] is True
|
||||||
|
assert d["error_type"] == "permission_denied"
|
||||||
|
assert isinstance(d["suggestions"], list)
|
||||||
|
|
||||||
|
def test_parse_chinese_error_messages(self):
|
||||||
|
"""中文错误消息解析"""
|
||||||
|
result = self.parser.parse("权限不足: 无法访问", 1)
|
||||||
|
assert result.is_error is True
|
||||||
|
assert result.error_type == ErrorType.PERMISSION_DENIED
|
||||||
|
|
||||||
|
def test_parse_multiline_output_message_is_last_line(self):
|
||||||
|
"""多行输出取最后一行作为消息"""
|
||||||
|
output = "line1\nline2\nline3"
|
||||||
|
result = self.parser.parse(output, 0)
|
||||||
|
assert result.message == "line3"
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# TerminalSession 测试
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestTerminalSession:
|
||||||
|
"""测试 TerminalSession 会话状态管理"""
|
||||||
|
|
||||||
|
def test_construction_default(self):
|
||||||
|
"""默认构造"""
|
||||||
|
session = TerminalSession(session_id="test")
|
||||||
|
assert session.session_id == "test"
|
||||||
|
assert session.cwd == os.getcwd()
|
||||||
|
assert isinstance(session.env, dict)
|
||||||
|
assert session.history == []
|
||||||
|
|
||||||
|
def test_construction_custom_cwd(self):
|
||||||
|
"""自定义工作目录"""
|
||||||
|
session = TerminalSession(session_id="test", cwd="/tmp")
|
||||||
|
assert session.cwd == "/tmp"
|
||||||
|
|
||||||
|
def test_construction_custom_env(self):
|
||||||
|
"""自定义环境变量"""
|
||||||
|
session = TerminalSession(session_id="test", env={"FOO": "bar"})
|
||||||
|
assert session.env.get("FOO") == "bar"
|
||||||
|
|
||||||
|
def test_set_cwd(self):
|
||||||
|
"""手动设置 cwd"""
|
||||||
|
session = TerminalSession(session_id="test")
|
||||||
|
session.set_cwd("/usr/local")
|
||||||
|
assert session.cwd == "/usr/local"
|
||||||
|
|
||||||
|
def test_set_env(self):
|
||||||
|
"""手动设置环境变量"""
|
||||||
|
session = TerminalSession(session_id="test")
|
||||||
|
session.set_env("MY_VAR", "hello")
|
||||||
|
assert session.env.get("MY_VAR") == "hello"
|
||||||
|
|
||||||
|
def test_update_env(self):
|
||||||
|
"""批量更新环境变量"""
|
||||||
|
session = TerminalSession(session_id="test")
|
||||||
|
session.update_env({"A": "1", "B": "2"})
|
||||||
|
assert session.env.get("A") == "1"
|
||||||
|
assert session.env.get("B") == "2"
|
||||||
|
|
||||||
|
def test_get_env_returns_copy(self):
|
||||||
|
"""get_env 返回副本,修改不影响原数据"""
|
||||||
|
session = TerminalSession(session_id="test")
|
||||||
|
env = session.get_env()
|
||||||
|
env["HACKED"] = "yes"
|
||||||
|
assert "HACKED" not in session.env
|
||||||
|
|
||||||
|
def test_get_history_returns_copy(self):
|
||||||
|
"""get_history 返回副本"""
|
||||||
|
session = TerminalSession(session_id="test")
|
||||||
|
history = session.get_history()
|
||||||
|
assert history is not session._history
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_simple_command(self):
|
||||||
|
"""执行简单命令"""
|
||||||
|
session = TerminalSession(session_id="test")
|
||||||
|
result = await session.execute("echo hello")
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "hello" in result.raw_output
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_records_history(self):
|
||||||
|
"""执行命令记录历史"""
|
||||||
|
session = TerminalSession(session_id="test")
|
||||||
|
await session.execute("echo first")
|
||||||
|
await session.execute("echo second")
|
||||||
|
assert len(session.history) == 2
|
||||||
|
assert session.history[0].command == "echo first"
|
||||||
|
assert session.history[1].command == "echo second"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cross_command_cwd(self):
|
||||||
|
"""跨命令保持 cwd:cd 后 pwd 返回正确目录"""
|
||||||
|
session = TerminalSession(session_id="test")
|
||||||
|
await session.execute("cd /tmp")
|
||||||
|
assert session.cwd == "/tmp"
|
||||||
|
|
||||||
|
result = await session.execute("pwd")
|
||||||
|
assert "/tmp" in result.raw_output
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cross_command_env(self):
|
||||||
|
"""跨命令保持 env:export 后 echo 返回正确值"""
|
||||||
|
session = TerminalSession(session_id="test")
|
||||||
|
await session.execute("export MY_TEST_VAR=hello123")
|
||||||
|
assert session.env.get("MY_TEST_VAR") == "hello123"
|
||||||
|
|
||||||
|
result = await session.execute("echo $MY_TEST_VAR")
|
||||||
|
assert "hello123" in result.raw_output
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cd_relative_path(self):
|
||||||
|
"""cd 相对路径(目录存在时更新 cwd)"""
|
||||||
|
# 使用 /usr 作为基础目录,cd local(/usr/local 存在)
|
||||||
|
session = TerminalSession(session_id="test", cwd="/usr")
|
||||||
|
await session.execute("cd local")
|
||||||
|
assert session.cwd == "/usr/local"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cd_absolute_path(self):
|
||||||
|
"""cd 绝对路径"""
|
||||||
|
session = TerminalSession(session_id="test")
|
||||||
|
await session.execute("cd /usr")
|
||||||
|
assert session.cwd == "/usr"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_failed_command_no_state_update(self):
|
||||||
|
"""失败命令不更新状态"""
|
||||||
|
session = TerminalSession(session_id="test", cwd="/tmp")
|
||||||
|
await session.execute("cd /nonexistent_dir_xyz")
|
||||||
|
# cd 失败,cwd 不应更新
|
||||||
|
assert session.cwd == "/tmp"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_timeout(self):
|
||||||
|
"""命令超时"""
|
||||||
|
session = TerminalSession(session_id="test")
|
||||||
|
result = await session.execute("sleep 10", timeout=0.5)
|
||||||
|
assert result.exit_code == -1
|
||||||
|
assert result.is_error is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_max_history(self):
|
||||||
|
"""历史记录上限"""
|
||||||
|
session = TerminalSession(session_id="test", max_history=3)
|
||||||
|
for i in range(5):
|
||||||
|
await session.execute(f"echo {i}")
|
||||||
|
assert len(session.history) == 3
|
||||||
|
assert session.history[0].command == "echo 2"
|
||||||
|
|
||||||
|
def test_close(self):
|
||||||
|
"""关闭会话"""
|
||||||
|
session = TerminalSession(session_id="test")
|
||||||
|
session.close() # 不应抛出异常
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# TerminalSessionManager 测试
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestTerminalSessionManager:
|
||||||
|
"""测试 TerminalSessionManager 会话管理"""
|
||||||
|
|
||||||
|
def test_get_or_create_new(self):
|
||||||
|
"""创建新会话"""
|
||||||
|
manager = TerminalSessionManager()
|
||||||
|
session = manager.get_or_create("s1")
|
||||||
|
assert session.session_id == "s1"
|
||||||
|
|
||||||
|
def test_get_or_create_existing(self):
|
||||||
|
"""获取已有会话"""
|
||||||
|
manager = TerminalSessionManager()
|
||||||
|
s1 = manager.get_or_create("s1")
|
||||||
|
s1.set_cwd("/tmp")
|
||||||
|
s2 = manager.get_or_create("s1")
|
||||||
|
assert s2.cwd == "/tmp"
|
||||||
|
|
||||||
|
def test_get_existing(self):
|
||||||
|
"""get 获取已有会话"""
|
||||||
|
manager = TerminalSessionManager()
|
||||||
|
manager.get_or_create("s1")
|
||||||
|
session = manager.get("s1")
|
||||||
|
assert session is not None
|
||||||
|
|
||||||
|
def test_get_nonexistent(self):
|
||||||
|
"""get 不存在的会话返回 None"""
|
||||||
|
manager = TerminalSessionManager()
|
||||||
|
assert manager.get("nonexistent") is None
|
||||||
|
|
||||||
|
def test_remove(self):
|
||||||
|
"""移除会话"""
|
||||||
|
manager = TerminalSessionManager()
|
||||||
|
manager.get_or_create("s1")
|
||||||
|
manager.remove("s1")
|
||||||
|
assert manager.get("s1") is None
|
||||||
|
|
||||||
|
def test_list_sessions(self):
|
||||||
|
"""列出会话"""
|
||||||
|
manager = TerminalSessionManager()
|
||||||
|
manager.get_or_create("s1")
|
||||||
|
manager.get_or_create("s2")
|
||||||
|
assert sorted(manager.list_sessions()) == ["s1", "s2"]
|
||||||
|
|
||||||
|
def test_has_session(self):
|
||||||
|
"""检查会话是否存在"""
|
||||||
|
manager = TerminalSessionManager()
|
||||||
|
manager.get_or_create("s1")
|
||||||
|
assert manager.has_session("s1") is True
|
||||||
|
assert manager.has_session("s2") is False
|
||||||
|
|
||||||
|
def test_max_sessions_eviction(self):
|
||||||
|
"""超过最大会话数时移除最旧会话"""
|
||||||
|
manager = TerminalSessionManager(max_sessions=2)
|
||||||
|
manager.get_or_create("s1")
|
||||||
|
manager.get_or_create("s2")
|
||||||
|
manager.get_or_create("s3") # 应该移除 s1
|
||||||
|
assert not manager.has_session("s1")
|
||||||
|
assert manager.has_session("s2")
|
||||||
|
assert manager.has_session("s3")
|
||||||
|
|
||||||
|
def test_close_all(self):
|
||||||
|
"""关闭所有会话"""
|
||||||
|
manager = TerminalSessionManager()
|
||||||
|
manager.get_or_create("s1")
|
||||||
|
manager.get_or_create("s2")
|
||||||
|
manager.close_all()
|
||||||
|
assert manager.list_sessions() == []
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# ShellTool 测试
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestShellToolConstruction:
|
||||||
|
"""测试 ShellTool 构造"""
|
||||||
|
|
||||||
|
def test_default_construction(self):
|
||||||
|
tool = ShellTool()
|
||||||
|
assert tool.name == "shell"
|
||||||
|
assert tool.input_schema is not None
|
||||||
|
assert "command" in tool.input_schema["properties"]
|
||||||
|
assert "session_id" in tool.input_schema["properties"]
|
||||||
|
assert tool.input_schema["required"] == ["command"]
|
||||||
|
|
||||||
|
def test_custom_construction(self):
|
||||||
|
tool = ShellTool(name="my_shell", version="2.0.0")
|
||||||
|
assert tool.name == "my_shell"
|
||||||
|
assert tool.version == "2.0.0"
|
||||||
|
|
||||||
|
def test_to_dict(self):
|
||||||
|
tool = ShellTool()
|
||||||
|
d = tool.to_dict()
|
||||||
|
assert d["name"] == "shell"
|
||||||
|
assert "input_schema" in d
|
||||||
|
|
||||||
|
def test_repr(self):
|
||||||
|
tool = ShellTool()
|
||||||
|
r = repr(tool)
|
||||||
|
assert "ShellTool" in r
|
||||||
|
assert "shell" in r
|
||||||
|
|
||||||
|
|
||||||
|
class TestShellToolExecution:
|
||||||
|
"""测试 ShellTool 命令执行"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_simple_command(self):
|
||||||
|
"""执行简单命令(无会话模式)"""
|
||||||
|
tool = ShellTool()
|
||||||
|
result = await tool.execute(command="echo hello")
|
||||||
|
assert result["exit_code"] == 0
|
||||||
|
assert "hello" in result["output"]
|
||||||
|
assert result["is_error"] is False
|
||||||
|
assert result["session_id"] is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_missing_command(self):
|
||||||
|
"""缺少 command 参数"""
|
||||||
|
tool = ShellTool()
|
||||||
|
result = await tool.execute()
|
||||||
|
assert result["is_error"] is True
|
||||||
|
assert result["exit_code"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_with_working_dir(self):
|
||||||
|
"""指定工作目录"""
|
||||||
|
tool = ShellTool()
|
||||||
|
result = await tool.execute(command="pwd", working_dir="/tmp")
|
||||||
|
assert result["exit_code"] == 0
|
||||||
|
assert "/tmp" in result["output"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_with_session(self):
|
||||||
|
"""会话模式执行命令"""
|
||||||
|
tool = ShellTool()
|
||||||
|
result = await tool.execute(command="echo session_test", session_id="s1")
|
||||||
|
assert result["exit_code"] == 0
|
||||||
|
assert "session_test" in result["output"]
|
||||||
|
assert result["session_id"] == "s1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_preserves_cwd(self):
|
||||||
|
"""会话模式保持 cwd"""
|
||||||
|
tool = ShellTool()
|
||||||
|
await tool.execute(command="cd /tmp", session_id="cwd-test")
|
||||||
|
result = await tool.execute(command="pwd", session_id="cwd-test")
|
||||||
|
assert "/tmp" in result["output"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_preserves_env(self):
|
||||||
|
"""会话模式保持 env"""
|
||||||
|
tool = ShellTool()
|
||||||
|
await tool.execute(
|
||||||
|
command="export SHELL_TEST_VAR=world", session_id="env-test"
|
||||||
|
)
|
||||||
|
result = await tool.execute(
|
||||||
|
command="echo $SHELL_TEST_VAR", session_id="env-test"
|
||||||
|
)
|
||||||
|
assert "world" in result["output"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_session_id_backward_compatible(self):
|
||||||
|
"""无 session_id 时保持现有行为"""
|
||||||
|
tool = ShellTool()
|
||||||
|
result = await tool.execute(command="echo no_session")
|
||||||
|
assert result["exit_code"] == 0
|
||||||
|
assert "no_session" in result["output"]
|
||||||
|
assert result["session_id"] is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_different_sessions_independent(self):
|
||||||
|
"""不同会话互不影响"""
|
||||||
|
tool = ShellTool()
|
||||||
|
await tool.execute(command="cd /tmp", session_id="s1")
|
||||||
|
await tool.execute(command="cd /usr", session_id="s2")
|
||||||
|
|
||||||
|
r1 = await tool.execute(command="pwd", session_id="s1")
|
||||||
|
r2 = await tool.execute(command="pwd", session_id="s2")
|
||||||
|
|
||||||
|
assert "/tmp" in r1["output"]
|
||||||
|
assert "/usr" in r2["output"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestShellToolSecurity:
|
||||||
|
"""测试 ShellTool 安全控制"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_safe_command_allowed(self):
|
||||||
|
"""安全命令直接执行"""
|
||||||
|
tool = ShellTool()
|
||||||
|
result = await tool.execute(command="ls /tmp")
|
||||||
|
assert result["exit_code"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dangerous_command_blocked_without_callback(self):
|
||||||
|
"""危险命令无确认回调时被拒绝"""
|
||||||
|
tool = ShellTool()
|
||||||
|
result = await tool.execute(command="rm -rf /tmp/test")
|
||||||
|
assert result["is_error"] is True
|
||||||
|
assert result["exit_code"] == 126
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dangerous_command_confirmed(self):
|
||||||
|
"""危险命令通过确认回调允许执行"""
|
||||||
|
confirm = AsyncMock(return_value=True)
|
||||||
|
tool = ShellTool(confirm_callback=confirm)
|
||||||
|
result = await tool.execute(command="rm -rf /tmp/nonexistent_test_dir")
|
||||||
|
assert confirm.called
|
||||||
|
# 命令本身可能失败(目录不存在),但不应被安全机制拒绝
|
||||||
|
assert result["exit_code"] != 126 or not result["is_error"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dangerous_command_rejected_by_callback(self):
|
||||||
|
"""确认回调拒绝危险命令"""
|
||||||
|
confirm = AsyncMock(return_value=False)
|
||||||
|
tool = ShellTool(confirm_callback=confirm)
|
||||||
|
result = await tool.execute(command="rm -rf /tmp/test")
|
||||||
|
assert result["is_error"] is True
|
||||||
|
assert result["exit_code"] == 126
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_audit_log_recorded(self):
|
||||||
|
"""审计日志记录"""
|
||||||
|
tool = ShellTool()
|
||||||
|
await tool.execute(command="echo audit_test")
|
||||||
|
assert len(tool.audit_log) > 0
|
||||||
|
assert tool.audit_log[0]["command"] == "echo audit_test"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_blocked_command_in_audit_log(self):
|
||||||
|
"""被阻止的命令记录在审计日志"""
|
||||||
|
tool = ShellTool()
|
||||||
|
await tool.execute(command="rm -rf /tmp/test")
|
||||||
|
blocked_entries = [e for e in tool.audit_log if e.get("blocked")]
|
||||||
|
assert len(blocked_entries) > 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_git_push_force_is_dangerous(self):
|
||||||
|
"""git push --force 是危险命令"""
|
||||||
|
tool = ShellTool()
|
||||||
|
result = await tool.execute(command="git push --force origin main")
|
||||||
|
assert result["is_error"] is True
|
||||||
|
assert result["exit_code"] == 126
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_git_status_is_safe(self):
|
||||||
|
"""git status 是安全命令"""
|
||||||
|
tool = ShellTool()
|
||||||
|
result = await tool.execute(command="git status")
|
||||||
|
# git status 可能在非 git 目录失败,但不应被安全机制拒绝
|
||||||
|
assert result["exit_code"] != 126
|
||||||
|
|
||||||
|
|
||||||
|
class TestShellToolOutputParsing:
|
||||||
|
"""测试 ShellTool 输出解析集成"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_error_output_structured(self):
|
||||||
|
"""错误输出结构化"""
|
||||||
|
tool = ShellTool()
|
||||||
|
result = await tool.execute(command="ls /nonexistent_dir_xyz_12345")
|
||||||
|
assert result["is_error"] is True
|
||||||
|
assert result["error_type"] in ("not_found", "unknown")
|
||||||
|
assert isinstance(result["suggestions"], list)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_success_output_not_error(self):
|
||||||
|
"""成功输出不标记为错误"""
|
||||||
|
tool = ShellTool()
|
||||||
|
result = await tool.execute(command="echo success")
|
||||||
|
assert result["is_error"] is False
|
||||||
|
assert result["error_type"] == "none"
|
||||||
|
|
||||||
|
|
||||||
|
class TestShellToolSessionManager:
|
||||||
|
"""测试 ShellTool 会话管理器访问"""
|
||||||
|
|
||||||
|
def test_session_manager_accessible(self):
|
||||||
|
tool = ShellTool()
|
||||||
|
assert tool.session_manager is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_created_on_first_use(self):
|
||||||
|
"""首次使用 session_id 时创建会话"""
|
||||||
|
tool = ShellTool()
|
||||||
|
assert not tool.session_manager.has_session("new-session")
|
||||||
|
await tool.execute(command="echo test", session_id="new-session")
|
||||||
|
assert tool.session_manager.has_session("new-session")
|
||||||
Loading…
Reference in New Issue