596 lines
23 KiB
Python
596 lines
23 KiB
Python
"""Tests for PitfallDetector - 任务避坑预警检测"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from datetime import datetime, timezone
|
||
|
||
import pytest
|
||
|
||
from agentkit.core.plan_schema import PlanStep, PlanStepStatus
|
||
from agentkit.evolution.experience_schema import TaskExperience
|
||
from agentkit.evolution.experience_store import InMemoryExperienceStore
|
||
from agentkit.evolution.pitfall_detector import (
|
||
PitfallDetector,
|
||
PitfallWarning,
|
||
WarningLevel,
|
||
_compute_name_similarity,
|
||
_determine_warning_level,
|
||
_extract_keywords,
|
||
)
|
||
|
||
|
||
# ── Fixtures ──────────────────────────────────────────────
|
||
|
||
|
||
@pytest.fixture
def store():
    """Provide a fresh InMemoryExperienceStore with no embedder.

    Constructed with ``decay_rate=0.01`` and ``alpha=0.7``.
    """
    in_memory_store = InMemoryExperienceStore(alpha=0.7, decay_rate=0.01)
    return in_memory_store
|
||
|
||
|
||
@pytest.fixture
def detector(store):
    """Provide a PitfallDetector backed by the ``store`` fixture.

    Constructed with ``similarity_threshold=0.3``.
    """
    pitfall_detector = PitfallDetector(
        similarity_threshold=0.3,
        experience_store=store,
    )
    return pitfall_detector
|
||
|
||
|
||
def _make_experience(
    task_type: str = "code_review",
    goal: str = "Review the PR",
    outcome: str = "success",
    steps_summary: str | list[dict] = "",
    failure_reasons: list[str] | None = None,
    optimization_tips: list[str] | None = None,
    success_rate: float = 1.0,
) -> TaskExperience:
    """Build a TaskExperience for use in tests.

    Args:
        task_type: Task category the experience belongs to.
        goal: Free-text goal of the task.
        outcome: Outcome label, e.g. "success", "failure" or "partial".
        steps_summary: Either a plain string or a list of per-step dicts
            (keys used in this file: "step_name", "outcome", "error").
        failure_reasons: Optional failure reasons; ``None`` becomes ``[]``.
        optimization_tips: Optional tips; ``None`` becomes ``[]``.
        success_rate: Success rate stored on the experience.

    Returns:
        A TaskExperience with an empty ``experience_id``, a fixed
        ``duration_seconds`` of 10.0 and ``created_at`` set to now (UTC).
    """
    reasons = failure_reasons or []
    tips = optimization_tips or []
    timestamp = datetime.now(timezone.utc)
    return TaskExperience(
        experience_id="",
        task_type=task_type,
        goal=goal,
        outcome=outcome,
        success_rate=success_rate,
        duration_seconds=10.0,
        steps_summary=steps_summary,
        failure_reasons=reasons,
        optimization_tips=tips,
        created_at=timestamp,
    )
|
||
|
||
|
||
def _make_step(
    name: str = "step",
    description: str = "do something",
    step_id: str = "s1",
) -> PlanStep:
    """Build a PlanStep for use in tests.

    Args:
        name: Step name.
        description: Step description.
        step_id: Step identifier.

    Returns:
        A PlanStep whose status is ``PlanStepStatus.PENDING``.
    """
    step_kwargs = {
        "step_id": step_id,
        "name": name,
        "description": description,
        "status": PlanStepStatus.PENDING,
    }
    return PlanStep(**step_kwargs)
|
||
|
||
|
||
# ── 辅助函数测试 ──────────────────────────────────────────
|
||
|
||
|
||
class TestExtractKeywords:
    """Tests for ``_extract_keywords``.

    The cases below expect lowercased tokens, removal of stop words such as
    "the"/"and", splitting on ``_`` and ``-``, and dropping one-character tokens.
    """

    def test_basic_extraction(self):
        extracted = _extract_keywords("Call API Gateway")
        for expected in ("call", "api", "gateway"):
            assert expected in extracted

    def test_stop_words_filtered(self):
        extracted = _extract_keywords("Call the API and check the result")
        for stop_word in ("the", "and"):
            assert stop_word not in extracted
        for expected in ("call", "api"):
            assert expected in extracted

    def test_underscore_and_hyphen(self):
        extracted = _extract_keywords("call_api-gateway")
        for expected in ("call", "api", "gateway"):
            assert expected in extracted

    def test_single_char_filtered(self):
        extracted = _extract_keywords("a b cd")
        for too_short in ("a", "b"):
            assert too_short not in extracted
        assert "cd" in extracted

    def test_empty_string(self):
        extracted = _extract_keywords("")
        assert len(extracted) == 0
|
||
|
||
|
||
class TestComputeNameSimilarity:
    """Tests for ``_compute_name_similarity``.

    NOTE(review): from the call sites below the arguments appear to be
    (planned step name, planned step description, historical step name) —
    confirm against the implementation.
    """

    def test_identical_names(self):
        sim = _compute_name_similarity("Call API Gateway", "", "Call API Gateway")
        assert sim == pytest.approx(1.0)

    def test_partial_overlap(self):
        sim = _compute_name_similarity("Call API Gateway", "", "Call External API")
        # Shared keywords: call, api; union: call, api, gateway, external.
        assert 0.0 < sim < 1.0

    def test_no_overlap(self):
        sim = _compute_name_similarity("Deploy Service", "", "Analyze Data")
        assert sim == 0.0

    def test_description_contributes(self):
        sim_no_desc = _compute_name_similarity("Deploy", "", "Deploy Service")
        sim_with_desc = _compute_name_similarity("Deploy", "Deploy Service", "Deploy Service")
        # The description supplies the keyword "service" that the name lacks,
        # so it must strictly raise the score. A ``>=`` check would also pass
        # if the description were ignored entirely, defeating the test.
        assert sim_with_desc > sim_no_desc

    def test_empty_inputs(self):
        sim = _compute_name_similarity("", "", "Call API")
        assert sim == 0.0
|
||
|
||
|
||
class TestDetermineWarningLevel:
    """Tests mapping a failure rate to a WarningLevel.

    Pinned cases: 0.6 and 0.5 -> HIGH; 0.3 and 0.2 -> MEDIUM; 0.1 and 0.01 -> LOW.
    """

    def test_high_threshold(self):
        for rate in (0.6, 0.5):
            assert _determine_warning_level(rate) == WarningLevel.HIGH

    def test_medium_threshold(self):
        for rate in (0.3, 0.2):
            assert _determine_warning_level(rate) == WarningLevel.MEDIUM

    def test_low_threshold(self):
        for rate in (0.1, 0.01):
            assert _determine_warning_level(rate) == WarningLevel.LOW
|
||
|
||
|
||
# ── PitfallDetector.check_pitfalls 测试 ──────────────────
|
||
|
||
|
||
class TestCheckPitfalls:
    """Tests for ``PitfallDetector.check_pitfalls``.

    NOTE(review): these are ``async def`` tests with no explicit asyncio
    marker; they presumably rely on pytest-asyncio auto mode (or an
    equivalent plugin setting) configured outside this file — confirm.
    """

    async def test_no_planned_steps_returns_empty(self, detector):
        """An empty plan produces no warnings."""
        warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=[])
        assert warnings == []

    async def test_no_failed_experiences_returns_empty(self, detector, store):
        """No historical failures -> empty list."""
        # Record only a successful experience.
        await store.record_experience(
            _make_experience(task_type="code_review", outcome="success")
        )
        steps = [_make_step(name="Review Code")]
        warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps)
        assert warnings == []

    async def test_high_failure_rate_returns_high_warning(self, detector, store):
        """A planned step with a high historical failure rate -> HIGH warning."""
        # Record several failures in which "Call API Gateway" is the failing step.
        for _ in range(6):
            await store.record_experience(
                _make_experience(
                    task_type="deployment",
                    outcome="failure",
                    success_rate=0.0,
                    steps_summary=[
                        {"step_name": "Call API Gateway", "outcome": "failure", "error": "Timeout"},
                        {"step_name": "Deploy Container", "outcome": "success"},
                    ],
                    failure_reasons=["API Gateway timeout"],
                )
            )
        # Record a smaller number of successes.
        for _ in range(4):
            await store.record_experience(
                _make_experience(
                    task_type="deployment",
                    outcome="success",
                    success_rate=1.0,
                    steps_summary=[
                        {"step_name": "Call API Gateway", "outcome": "success"},
                        {"step_name": "Deploy Container", "outcome": "success"},
                    ],
                )
            )

        steps = [_make_step(name="Call API Gateway", description="Invoke API Gateway endpoint")]
        warnings = await detector.check_pitfalls(task_type="deployment", planned_steps=steps)

        assert len(warnings) == 1
        warning = warnings[0]
        assert warning.step_name == "Call API Gateway"
        assert warning.warning_level == WarningLevel.HIGH
        assert warning.failure_rate >= 0.5
        assert "Timeout" in warning.historical_failures

    async def test_medium_failure_rate(self, detector, store):
        """A moderate failure rate -> MEDIUM warning."""
        # 3 failures, 7 successes -> failure rate 0.3.
        for _ in range(3):
            await store.record_experience(
                _make_experience(
                    task_type="data_analysis",
                    outcome="failure",
                    success_rate=0.0,
                    steps_summary=[
                        {"step_name": "Fetch Data", "outcome": "failure", "error": "Connection refused"},
                    ],
                )
            )
        for _ in range(7):
            await store.record_experience(
                _make_experience(
                    task_type="data_analysis",
                    outcome="success",
                    success_rate=1.0,
                    steps_summary=[
                        {"step_name": "Fetch Data", "outcome": "success"},
                    ],
                )
            )

        steps = [_make_step(name="Fetch Data", description="Fetch data from source")]
        warnings = await detector.check_pitfalls(task_type="data_analysis", planned_steps=steps)

        assert len(warnings) == 1
        assert warnings[0].warning_level == WarningLevel.MEDIUM
        assert 0.2 <= warnings[0].failure_rate < 0.5

    async def test_low_failure_rate(self, detector, store):
        """A low failure rate -> LOW warning."""
        # 1 failure, 9 successes -> failure rate 0.1.
        await store.record_experience(
            _make_experience(
                task_type="testing",
                outcome="failure",
                success_rate=0.0,
                steps_summary=[
                    {"step_name": "Run Unit Tests", "outcome": "failure", "error": "Flaky test"},
                ],
            )
        )
        for _ in range(9):
            await store.record_experience(
                _make_experience(
                    task_type="testing",
                    outcome="success",
                    success_rate=1.0,
                    steps_summary=[
                        {"step_name": "Run Unit Tests", "outcome": "success"},
                    ],
                )
            )

        steps = [_make_step(name="Run Unit Tests", description="Execute unit test suite")]
        warnings = await detector.check_pitfalls(task_type="testing", planned_steps=steps)

        assert len(warnings) == 1
        assert warnings[0].warning_level == WarningLevel.LOW

    async def test_multiple_steps_with_risks_sorted_by_severity(self, detector, store):
        """Several risky steps -> warnings are returned ordered by severity."""
        # "Call API" has a high failure rate, "Validate Input" a low one.
        for _ in range(6):
            await store.record_experience(
                _make_experience(
                    task_type="integration",
                    outcome="failure",
                    success_rate=0.0,
                    steps_summary=[
                        {"step_name": "Call API", "outcome": "failure", "error": "Timeout"},
                        {"step_name": "Validate Input", "outcome": "success"},
                    ],
                )
            )
        for _ in range(4):
            await store.record_experience(
                _make_experience(
                    task_type="integration",
                    outcome="success",
                    success_rate=1.0,
                    steps_summary=[
                        {"step_name": "Call API", "outcome": "success"},
                        {"step_name": "Validate Input", "outcome": "success"},
                    ],
                )
            )
        # A single extra record in which only "Validate Input" fails.
        await store.record_experience(
            _make_experience(
                task_type="integration",
                outcome="partial",
                success_rate=0.5,
                steps_summary=[
                    {"step_name": "Call API", "outcome": "success"},
                    {"step_name": "Validate Input", "outcome": "failure", "error": "Invalid schema"},
                ],
            )
        )

        # The plan lists the low-risk step first, so a correct result must reorder.
        steps = [
            _make_step(name="Validate Input", description="Validate input data", step_id="s1"),
            _make_step(name="Call API", description="Call external API", step_id="s2"),
        ]
        warnings = await detector.check_pitfalls(task_type="integration", planned_steps=steps)

        assert len(warnings) == 2
        # HIGH must come before MEDIUM/LOW.
        assert warnings[0].warning_level == WarningLevel.HIGH
        assert warnings[0].step_name == "Call API"
        # Pin the rest of the ordering as well: the lower-risk step comes
        # second, is not HIGH, and is not riskier than the first.
        assert warnings[1].step_name == "Validate Input"
        assert warnings[1].warning_level in (WarningLevel.MEDIUM, WarningLevel.LOW)
        assert warnings[0].failure_rate >= warnings[1].failure_rate

    async def test_no_matching_steps_returns_empty(self, detector, store):
        """Planned steps that match no historical failing step -> empty list."""
        await store.record_experience(
            _make_experience(
                task_type="code_review",
                outcome="failure",
                success_rate=0.0,
                steps_summary=[
                    {"step_name": "Run Linter", "outcome": "failure", "error": "Config error"},
                ],
            )
        )

        # The planned step name is entirely different from the historical one.
        steps = [_make_step(name="Deploy Application", description="Deploy to production")]
        warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps)
        assert warnings == []

    async def test_different_task_type_no_cross_contamination(self, detector, store):
        """Failures recorded under one task_type do not raise warnings for another."""
        await store.record_experience(
            _make_experience(
                task_type="deployment",
                outcome="failure",
                success_rate=0.0,
                steps_summary=[
                    {"step_name": "Deploy Service", "outcome": "failure", "error": "OOM"},
                ],
            )
        )

        # Querying code_review must not surface the deployment failure.
        steps = [_make_step(name="Deploy Service", description="Deploy the service")]
        warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps)
        assert warnings == []

    async def test_partial_outcome_included(self, detector, store):
        """Experiences with a "partial" outcome are also considered."""
        await store.record_experience(
            _make_experience(
                task_type="migration",
                outcome="partial",
                success_rate=0.5,
                steps_summary=[
                    {"step_name": "Migrate Database", "outcome": "failure", "error": "Schema mismatch"},
                ],
            )
        )

        steps = [_make_step(name="Migrate Database", description="Migrate DB schema")]
        warnings = await detector.check_pitfalls(task_type="migration", planned_steps=steps)
        assert len(warnings) == 1

    async def test_steps_summary_as_string_ignored(self, detector, store):
        """A string steps_summary has no per-step data, so no warning is produced."""
        await store.record_experience(
            _make_experience(
                task_type="code_review",
                outcome="failure",
                success_rate=0.0,
                steps_summary="Executed code_review task",  # string form
            )
        )

        steps = [_make_step(name="Review Code", description="Review the code")]
        warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps)
        assert warnings == []
|
||
|
||
|
||
# ── AE3 场景测试 ──────────────────────────────────────────
|
||
|
||
|
||
class TestAE3Scenario:
    """AE3: "calls to system X's API time out 60% of the time at peak hours"
    -> a new task that calls it is warned automatically."""

    async def test_api_timeout_high_failure_rate_warning(self, detector, store):
        """10 historical calls, 6 peak-hour timeouts (60%) -> HIGH warning with a suggestion."""

        def failed_run():
            # Fresh literals on every call so no list is shared between experiences.
            return _make_experience(
                task_type="order_processing",
                goal="Process orders via X system",
                outcome="failure",
                success_rate=0.0,
                steps_summary=[
                    {"step_name": "Call X System API", "outcome": "failure", "error": "高峰期超时"},
                    {"step_name": "Process Order", "outcome": "success"},
                ],
                failure_reasons=["X System API timeout during peak hours"],
                optimization_tips=["Avoid peak hours", "Add retry logic"],
            )

        def successful_run():
            return _make_experience(
                task_type="order_processing",
                goal="Process orders via X system",
                outcome="success",
                success_rate=1.0,
                steps_summary=[
                    {"step_name": "Call X System API", "outcome": "success"},
                    {"step_name": "Process Order", "outcome": "success"},
                ],
            )

        for _ in range(6):
            await store.record_experience(failed_run())
        for _ in range(4):
            await store.record_experience(successful_run())

        # The new plan calls system X's API.
        plan = [
            _make_step(name="Call X System API", description="Invoke X system API for orders"),
        ]
        warnings = await detector.check_pitfalls(task_type="order_processing", planned_steps=plan)

        assert len(warnings) == 1
        (only_warning,) = warnings
        assert only_warning.warning_level == WarningLevel.HIGH
        assert only_warning.failure_rate >= 0.5
        assert any("超时" in reason for reason in only_warning.historical_failures)
        assert only_warning.suggestion  # a suggestion must be present
|
||
|
||
|
||
# ── PitfallWarning 数据模型测试 ───────────────────────────
|
||
|
||
|
||
class TestPitfallWarning:
    """Tests for the PitfallWarning data model."""

    def test_creation(self):
        warning = PitfallWarning(
            step_name="Call API",
            warning_level=WarningLevel.HIGH,
            failure_rate=0.6,
            historical_failures=["Timeout", "Connection refused"],
            suggestion="Add retry logic",
        )
        checks = (
            (warning.step_name, "Call API"),
            (warning.warning_level, WarningLevel.HIGH),
            (warning.failure_rate, 0.6),
            (warning.historical_failures, ["Timeout", "Connection refused"]),
            (warning.suggestion, "Add retry logic"),
        )
        for actual, expected in checks:
            assert actual == expected

    def test_default_values(self):
        # Only the required fields are passed; the optional ones fall back to defaults.
        warning = PitfallWarning(
            step_name="Test",
            warning_level=WarningLevel.LOW,
            failure_rate=0.1,
        )
        assert warning.historical_failures == []
        assert warning.suggestion == ""
|
||
|
||
|
||
# ── WarningLevel 枚举测试 ─────────────────────────────────
|
||
|
||
|
||
class TestWarningLevel:
    """Tests for the WarningLevel enum's values and string comparison."""

    def test_values(self):
        pairs = (
            (WarningLevel.HIGH, "high"),
            (WarningLevel.MEDIUM, "medium"),
            (WarningLevel.LOW, "low"),
        )
        for member, text in pairs:
            assert member.value == text

    def test_string_comparison(self):
        # Members compare equal to their plain string values.
        pairs = (
            (WarningLevel.HIGH, "high"),
            (WarningLevel.MEDIUM, "medium"),
            (WarningLevel.LOW, "low"),
        )
        for member, text in pairs:
            assert member == text
|
||
|
||
|
||
# ── 相似度阈值配置测试 ─────────────────────────────────────
|
||
|
||
|
||
class TestSimilarityThreshold:
    """Tests for the configurable ``similarity_threshold`` of PitfallDetector."""

    async def test_custom_threshold(self, store):
        """A loose threshold matches a partially similar step; a strict one does not."""
        # Low threshold: matches easily.
        detector_low = PitfallDetector(experience_store=store, similarity_threshold=0.1)
        # High threshold: hard to match.
        detector_high = PitfallDetector(experience_store=store, similarity_threshold=0.8)

        await store.record_experience(
            _make_experience(
                task_type="testing",
                outcome="failure",
                success_rate=0.0,
                steps_summary=[
                    {"step_name": "Run Integration Tests", "outcome": "failure", "error": "Timeout"},
                ],
            )
        )

        # "Run Unit Tests" only partially overlaps "Run Integration Tests"
        # (presumably shared keywords: run, tests), so it should clear 0.1
        # but not 0.8.
        steps = [_make_step(name="Run Unit Tests", description="Execute tests")]
        warnings_low = await detector_low.check_pitfalls(task_type="testing", planned_steps=steps)
        warnings_high = await detector_high.check_pitfalls(task_type="testing", planned_steps=steps)

        # Pin both sides. A bare ``len(low) >= len(high)`` check also passes
        # when neither detector matches anything, so it proves nothing.
        assert len(warnings_low) == 1
        assert warnings_high == []
|
||
|
||
|
||
# ── 端到端流程测试 ─────────────────────────────────────────
|
||
|
||
|
||
class TestEndToEnd:
    """End-to-end pitfall detection: record history, then check a new plan."""

    async def test_full_pitfall_detection_flow(self, detector, store):
        """Full flow: mixed history -> HIGH warning for the step that failed 2 of 3 runs."""
        # 1. Record several experiences. Each spec is turned into an experience
        #    and recorded one at a time, in order.
        history_specs = [
            {
                "task_type": "deployment",
                "outcome": "failure",
                "success_rate": 0.0,
                "steps_summary": [
                    {"step_name": "Build Docker Image", "outcome": "failure", "error": "OOM"},
                    {"step_name": "Push to Registry", "outcome": "success"},
                ],
                "failure_reasons": ["Docker build OOM"],
                "optimization_tips": ["Increase memory limit"],
            },
            {
                "task_type": "deployment",
                "outcome": "failure",
                "success_rate": 0.0,
                "steps_summary": [
                    {"step_name": "Build Docker Image", "outcome": "failure", "error": "Dependency conflict"},
                    {"step_name": "Push to Registry", "outcome": "success"},
                ],
                "failure_reasons": ["Dependency conflict"],
            },
            {
                "task_type": "deployment",
                "outcome": "success",
                "success_rate": 1.0,
                "steps_summary": [
                    {"step_name": "Build Docker Image", "outcome": "success"},
                    {"step_name": "Push to Registry", "outcome": "success"},
                ],
            },
        ]
        for spec in history_specs:
            await store.record_experience(_make_experience(**spec))

        # 2. The new plan.
        plan = [
            _make_step(name="Build Docker Image", description="Build the container image", step_id="s1"),
            _make_step(name="Push to Registry", description="Push image to container registry", step_id="s2"),
        ]

        # 3. Run pitfall detection.
        warnings = await detector.check_pitfalls(task_type="deployment", planned_steps=plan)

        # 4. Check the result.
        assert len(warnings) >= 1
        # Build Docker Image failed 2 of 3 runs (~0.667), which should be HIGH.
        build_matches = [w for w in warnings if w.step_name == "Build Docker Image"]
        build_warning = build_matches[0] if build_matches else None
        assert build_warning is not None
        assert build_warning.warning_level == WarningLevel.HIGH
        assert build_warning.failure_rate == pytest.approx(2.0 / 3.0, abs=0.01)

    async def test_suggestion_contains_useful_info(self, detector, store):
        """The warning suggestion should surface the historical failure reason."""
        failed_auth = _make_experience(
            task_type="api_integration",
            outcome="failure",
            success_rate=0.0,
            steps_summary=[
                {"step_name": "Authenticate", "outcome": "failure", "error": "Token expired"},
            ],
            failure_reasons=["Token expired"],
            optimization_tips=["Refresh token before expiry"],
        )
        await store.record_experience(failed_auth)

        plan = [_make_step(name="Authenticate", description="Authenticate with API")]
        warnings = await detector.check_pitfalls(task_type="api_integration", planned_steps=plan)

        assert len(warnings) == 1
        assert "Token expired" in warnings[0].suggestion
|