"""Tests for PitfallDetector (task pitfall early-warning detection)."""

from __future__ import annotations

from datetime import datetime, timezone

import pytest

from agentkit.core.plan_schema import PlanStep, PlanStepStatus
from agentkit.evolution.experience_schema import TaskExperience
from agentkit.evolution.experience_store import InMemoryExperienceStore
from agentkit.evolution.pitfall_detector import (
    PitfallDetector,
    PitfallWarning,
    WarningLevel,
    _compute_name_similarity,
    _determine_warning_level,
    _extract_keywords,
)


# ── Fixtures ──────────────────────────────────────────────


@pytest.fixture
def store():
    """InMemoryExperienceStore constructed without an embedder."""
    return InMemoryExperienceStore(decay_rate=0.01, alpha=0.7)


@pytest.fixture
def detector(store):
    """PitfallDetector backed by the in-memory ``store`` fixture (threshold 0.3)."""
    return PitfallDetector(experience_store=store, similarity_threshold=0.3)


def _make_experience(
    task_type: str = "code_review",
    goal: str = "Review the PR",
    outcome: str = "success",
    steps_summary: str | list[dict] = "",
    failure_reasons: list[str] | None = None,
    optimization_tips: list[str] | None = None,
    success_rate: float = 1.0,
) -> TaskExperience:
    """Build a TaskExperience for tests.

    ``steps_summary`` is either free text or a list of per-step dicts; the
    tests in this module use dicts with ``step_name``, ``outcome`` and an
    optional ``error`` key. ``duration_seconds`` is fixed at 10.0 and
    ``created_at`` is the current UTC time. ``None`` list arguments become
    fresh empty lists.
    """
    return TaskExperience(
        experience_id="",
        task_type=task_type,
        goal=goal,
        steps_summary=steps_summary,
        outcome=outcome,
        duration_seconds=10.0,
        success_rate=success_rate,
        failure_reasons=failure_reasons or [],
        optimization_tips=optimization_tips or [],
        created_at=datetime.now(timezone.utc),
    )


def _make_step(
    name: str = "step",
    description: str = "do something",
    step_id: str = "s1",
) -> PlanStep:
    """Build a PENDING PlanStep for tests."""
    return PlanStep(
        step_id=step_id,
        name=name,
        description=description,
        status=PlanStepStatus.PENDING,
    )


# ── Helper function tests ─────────────────────────────────


class TestExtractKeywords:
    def test_basic_extraction(self):
        keywords = _extract_keywords("Call API Gateway")
        assert "call" in keywords
        assert "api" in keywords
        assert "gateway" in keywords

    def test_stop_words_filtered(self):
        keywords = _extract_keywords("Call the API and check the result")
        assert "the" not in keywords
        assert "and" not in keywords
        assert "call" in keywords
        assert "api" in keywords

    def test_underscore_and_hyphen(self):
        keywords = _extract_keywords("call_api-gateway")
        assert "call" in keywords
        assert "api" in keywords
        assert "gateway" in keywords

    def test_single_char_filtered(self):
        keywords = _extract_keywords("a b cd")
        assert "a" not in keywords
        assert "b" not in keywords
        assert "cd" in keywords

    def test_empty_string(self):
        keywords = _extract_keywords("")
        assert len(keywords) == 0


class TestComputeNameSimilarity:
    def test_identical_names(self):
        sim = _compute_name_similarity("Call API Gateway", "", "Call API Gateway")
        assert sim == pytest.approx(1.0)

    def test_partial_overlap(self):
        sim = _compute_name_similarity("Call API Gateway", "", "Call External API")
        # Shared: call, api; union: call, api, gateway, external.
        assert 0.0 < sim < 1.0

    def test_no_overlap(self):
        sim = _compute_name_similarity("Deploy Service", "", "Analyze Data")
        assert sim == 0.0

    def test_description_contributes(self):
        sim_no_desc = _compute_name_similarity("Deploy", "", "Deploy Service")
        sim_with_desc = _compute_name_similarity("Deploy", "Deploy Service", "Deploy Service")
        # The description carries matching keywords, so it must not lower the similarity.
        assert sim_with_desc >= sim_no_desc

    def test_empty_inputs(self):
        sim = _compute_name_similarity("", "", "Call API")
        assert sim == 0.0


class TestDetermineWarningLevel:
    def test_high_threshold(self):
        assert _determine_warning_level(0.6) == WarningLevel.HIGH
        assert _determine_warning_level(0.5) == WarningLevel.HIGH

    def test_medium_threshold(self):
        assert _determine_warning_level(0.3) == WarningLevel.MEDIUM
        assert _determine_warning_level(0.2) == WarningLevel.MEDIUM

    def test_low_threshold(self):
        assert _determine_warning_level(0.1) == WarningLevel.LOW
        assert _determine_warning_level(0.01) == WarningLevel.LOW


# ── PitfallDetector.check_pitfalls tests ──────────────────


class TestCheckPitfalls:
    async def test_no_planned_steps_returns_empty(self, detector):
        warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=[])
        assert warnings == []

    async def test_no_failed_experiences_returns_empty(self, detector, store):
        """Only successful history -> empty list.

        The success record uses a structured ``steps_summary`` that contains the
        planned step. Without it, the test would pass trivially, because a
        string summary carries no step-level information.
        """
        await store.record_experience(
            _make_experience(
                task_type="code_review",
                outcome="success",
                steps_summary=[
                    {"step_name": "Review Code", "outcome": "success"},
                ],
            )
        )
        steps = [_make_step(name="Review Code")]
        warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps)
        assert warnings == []

    async def test_high_failure_rate_returns_high_warning(self, detector, store):
        """A planned step with a high historical failure rate -> HIGH warning."""
        # Record several failures in which "Call API Gateway" fails (6 of 10 runs).
        for _ in range(6):
            await store.record_experience(
                _make_experience(
                    task_type="deployment",
                    outcome="failure",
                    success_rate=0.0,
                    steps_summary=[
                        {"step_name": "Call API Gateway", "outcome": "failure", "error": "Timeout"},
                        {"step_name": "Deploy Container", "outcome": "success"},
                    ],
                    failure_reasons=["API Gateway timeout"],
                )
            )
        # Plus a few successful runs.
        for _ in range(4):
            await store.record_experience(
                _make_experience(
                    task_type="deployment",
                    outcome="success",
                    success_rate=1.0,
                    steps_summary=[
                        {"step_name": "Call API Gateway", "outcome": "success"},
                        {"step_name": "Deploy Container", "outcome": "success"},
                    ],
                )
            )

        steps = [_make_step(name="Call API Gateway", description="Invoke API Gateway endpoint")]
        warnings = await detector.check_pitfalls(task_type="deployment", planned_steps=steps)

        assert len(warnings) == 1
        warning = warnings[0]
        assert warning.step_name == "Call API Gateway"
        assert warning.warning_level == WarningLevel.HIGH
        assert warning.failure_rate >= 0.5
        assert "Timeout" in warning.historical_failures

    async def test_medium_failure_rate(self, detector, store):
        """Medium failure rate -> MEDIUM warning."""
        # 3 failures, 7 successes -> failure rate 0.3.
        for _ in range(3):
            await store.record_experience(
                _make_experience(
                    task_type="data_analysis",
                    outcome="failure",
                    success_rate=0.0,
                    steps_summary=[
                        {"step_name": "Fetch Data", "outcome": "failure", "error": "Connection refused"},
                    ],
                )
            )
        for _ in range(7):
            await store.record_experience(
                _make_experience(
                    task_type="data_analysis",
                    outcome="success",
                    success_rate=1.0,
                    steps_summary=[
                        {"step_name": "Fetch Data", "outcome": "success"},
                    ],
                )
            )

        steps = [_make_step(name="Fetch Data", description="Fetch data from source")]
        warnings = await detector.check_pitfalls(task_type="data_analysis", planned_steps=steps)

        assert len(warnings) == 1
        assert warnings[0].warning_level == WarningLevel.MEDIUM
        assert 0.2 <= warnings[0].failure_rate < 0.5

    async def test_low_failure_rate(self, detector, store):
        """Low failure rate -> LOW warning."""
        # 1 failure, 9 successes -> failure rate 0.1.
        await store.record_experience(
            _make_experience(
                task_type="testing",
                outcome="failure",
                success_rate=0.0,
                steps_summary=[
                    {"step_name": "Run Unit Tests", "outcome": "failure", "error": "Flaky test"},
                ],
            )
        )
        for _ in range(9):
            await store.record_experience(
                _make_experience(
                    task_type="testing",
                    outcome="success",
                    success_rate=1.0,
                    steps_summary=[
                        {"step_name": "Run Unit Tests", "outcome": "success"},
                    ],
                )
            )

        steps = [_make_step(name="Run Unit Tests", description="Execute unit test suite")]
        warnings = await detector.check_pitfalls(task_type="testing", planned_steps=steps)

        assert len(warnings) == 1
        assert warnings[0].warning_level == WarningLevel.LOW

    async def test_multiple_steps_with_risks_sorted_by_severity(self, detector, store):
        """Several risky steps -> warnings returned ordered by severity."""
        # "Call API" has a high failure rate; "Validate Input" a low one.
        for _ in range(6):
            await store.record_experience(
                _make_experience(
                    task_type="integration",
                    outcome="failure",
                    success_rate=0.0,
                    steps_summary=[
                        {"step_name": "Call API", "outcome": "failure", "error": "Timeout"},
                        {"step_name": "Validate Input", "outcome": "success"},
                    ],
                )
            )
        for _ in range(4):
            await store.record_experience(
                _make_experience(
                    task_type="integration",
                    outcome="success",
                    success_rate=1.0,
                    steps_summary=[
                        {"step_name": "Call API", "outcome": "success"},
                        {"step_name": "Validate Input", "outcome": "success"},
                    ],
                )
            )
        # One extra record in which only "Validate Input" fails.
        await store.record_experience(
            _make_experience(
                task_type="integration",
                outcome="partial",
                success_rate=0.5,
                steps_summary=[
                    {"step_name": "Call API", "outcome": "success"},
                    {"step_name": "Validate Input", "outcome": "failure", "error": "Invalid schema"},
                ],
            )
        )

        steps = [
            _make_step(name="Validate Input", description="Validate input data", step_id="s1"),
            _make_step(name="Call API", description="Call external API", step_id="s2"),
        ]
        warnings = await detector.check_pitfalls(task_type="integration", planned_steps=steps)

        assert len(warnings) == 2
        # HIGH must come before MEDIUM/LOW, even though "Call API" was planned second.
        assert warnings[0].warning_level == WarningLevel.HIGH
        assert warnings[0].step_name == "Call API"
        assert warnings[1].step_name == "Validate Input"
        # Verify the whole list is in non-increasing severity, not just its head.
        severity_rank = {WarningLevel.HIGH: 2, WarningLevel.MEDIUM: 1, WarningLevel.LOW: 0}
        ranks = [severity_rank[w.warning_level] for w in warnings]
        assert ranks == sorted(ranks, reverse=True)

    async def test_no_matching_steps_returns_empty(self, detector, store):
        """Planned steps that match no historically failed step -> empty list."""
        await store.record_experience(
            _make_experience(
                task_type="code_review",
                outcome="failure",
                success_rate=0.0,
                steps_summary=[
                    {"step_name": "Run Linter", "outcome": "failure", "error": "Config error"},
                ],
            )
        )

        # The planned step name is completely different from the historical one.
        steps = [_make_step(name="Deploy Application", description="Deploy to production")]
        warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps)
        assert warnings == []

    async def test_different_task_type_no_cross_contamination(self, detector, store):
        """Failures recorded under one task_type never warn for another task_type."""
        await store.record_experience(
            _make_experience(
                task_type="deployment",
                outcome="failure",
                success_rate=0.0,
                steps_summary=[
                    {"step_name": "Deploy Service", "outcome": "failure", "error": "OOM"},
                ],
            )
        )

        # Querying code_review must not surface the deployment failure.
        steps = [_make_step(name="Deploy Service", description="Deploy the service")]
        warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps)
        assert warnings == []

    async def test_partial_outcome_included(self, detector, store):
        """Experiences with a "partial" outcome are also considered."""
        await store.record_experience(
            _make_experience(
                task_type="migration",
                outcome="partial",
                success_rate=0.5,
                steps_summary=[
                    {"step_name": "Migrate Database", "outcome": "failure", "error": "Schema mismatch"},
                ],
            )
        )

        steps = [_make_step(name="Migrate Database", description="Migrate DB schema")]
        warnings = await detector.check_pitfalls(task_type="migration", planned_steps=steps)
        assert len(warnings) == 1

    async def test_steps_summary_as_string_ignored(self, detector, store):
        """A string steps_summary carries no step-level info, so no warning is produced."""
        await store.record_experience(
            _make_experience(
                task_type="code_review",
                outcome="failure",
                success_rate=0.0,
                steps_summary="Executed code_review task",  # string form
            )
        )

        steps = [_make_step(name="Review Code", description="Review the code")]
        warnings = await detector.check_pitfalls(task_type="code_review", planned_steps=steps)
        assert warnings == []


# ── AE3 scenario tests ────────────────────────────────────


class TestAE3Scenario:
    """AE3: calls to the X system API time out 60% of the time at peak hours.

    A new task that plans to call that API should be warned automatically.
    """

    async def test_api_timeout_high_failure_rate_warning(self, detector, store):
        """60% peak-hour timeout rate on the X system API -> automatic HIGH warning."""
        # Simulated history: 10 calls, 6 timeouts -> 60% failure rate.
        for _ in range(6):
            await store.record_experience(
                _make_experience(
                    task_type="order_processing",
                    goal="Process orders via X system",
                    outcome="failure",
                    success_rate=0.0,
                    steps_summary=[
                        # Error text means "timeout during peak hours".
                        {"step_name": "Call X System API", "outcome": "failure", "error": "高峰期超时"},
                        {"step_name": "Process Order", "outcome": "success"},
                    ],
                    failure_reasons=["X System API timeout during peak hours"],
                    optimization_tips=["Avoid peak hours", "Add retry logic"],
                )
            )
        for _ in range(4):
            await store.record_experience(
                _make_experience(
                    task_type="order_processing",
                    goal="Process orders via X system",
                    outcome="success",
                    success_rate=1.0,
                    steps_summary=[
                        {"step_name": "Call X System API", "outcome": "success"},
                        {"step_name": "Process Order", "outcome": "success"},
                    ],
                )
            )

        # The new task's plan includes a call to the X system API.
        steps = [
            _make_step(name="Call X System API", description="Invoke X system API for orders"),
        ]
        warnings = await detector.check_pitfalls(task_type="order_processing", planned_steps=steps)

        assert len(warnings) == 1
        warning = warnings[0]
        assert warning.warning_level == WarningLevel.HIGH
        assert warning.failure_rate >= 0.5
        # "超时" means "timeout"; it must appear in the recorded failure reasons.
        assert any("超时" in reason for reason in warning.historical_failures)
        assert warning.suggestion  # a suggestion must be provided


# ── PitfallWarning data model tests ───────────────────────


class TestPitfallWarning:
    def test_creation(self):
        warning = PitfallWarning(
            step_name="Call API",
            warning_level=WarningLevel.HIGH,
            failure_rate=0.6,
            historical_failures=["Timeout", "Connection refused"],
            suggestion="Add retry logic",
        )
        assert warning.step_name == "Call API"
        assert warning.warning_level == WarningLevel.HIGH
        assert warning.failure_rate == 0.6
        assert warning.historical_failures == ["Timeout", "Connection refused"]
        assert warning.suggestion == "Add retry logic"

    def test_default_values(self):
        warning = PitfallWarning(
            step_name="Test",
            warning_level=WarningLevel.LOW,
            failure_rate=0.1,
        )
        assert warning.historical_failures == []
        assert warning.suggestion == ""


# ── WarningLevel enum tests ───────────────────────────────


class TestWarningLevel:
    def test_values(self):
        assert WarningLevel.HIGH.value == "high"
        assert WarningLevel.MEDIUM.value == "medium"
        assert WarningLevel.LOW.value == "low"

    def test_string_comparison(self):
        assert WarningLevel.HIGH == "high"
        assert WarningLevel.MEDIUM == "medium"
        assert WarningLevel.LOW == "low"


# ── Similarity threshold configuration tests ──────────────


class TestSimilarityThreshold:
    async def test_custom_threshold(self, store):
        """A custom similarity_threshold controls how eagerly steps match."""
        # Low threshold: matches more easily.
        detector_low = PitfallDetector(experience_store=store, similarity_threshold=0.1)
        # High threshold: matches less easily.
        detector_high = PitfallDetector(experience_store=store, similarity_threshold=0.8)

        await store.record_experience(
            _make_experience(
                task_type="testing",
                outcome="failure",
                success_rate=0.0,
                steps_summary=[
                    {"step_name": "Run Integration Tests", "outcome": "failure", "error": "Timeout"},
                ],
            )
        )

        steps = [_make_step(name="Run Unit Tests", description="Execute tests")]

        # The low threshold may match; the high threshold may not.
        warnings_low = await detector_low.check_pitfalls(task_type="testing", planned_steps=steps)
        warnings_high = await detector_high.check_pitfalls(task_type="testing", planned_steps=steps)

        # A lower threshold never yields fewer matches than a higher one.
        assert len(warnings_low) >= len(warnings_high)


# ── End-to-end flow tests ─────────────────────────────────


class TestEndToEnd:
    async def test_full_pitfall_detection_flow(self, detector, store):
        """Complete pitfall-detection flow: record history, plan, detect, verify."""
        # 1. Record several kinds of failure experience.
        await store.record_experience(
            _make_experience(
                task_type="deployment",
                outcome="failure",
                success_rate=0.0,
                steps_summary=[
                    {"step_name": "Build Docker Image", "outcome": "failure", "error": "OOM"},
                    {"step_name": "Push to Registry", "outcome": "success"},
                ],
                failure_reasons=["Docker build OOM"],
                optimization_tips=["Increase memory limit"],
            )
        )
        await store.record_experience(
            _make_experience(
                task_type="deployment",
                outcome="failure",
                success_rate=0.0,
                steps_summary=[
                    {"step_name": "Build Docker Image", "outcome": "failure", "error": "Dependency conflict"},
                    {"step_name": "Push to Registry", "outcome": "success"},
                ],
                failure_reasons=["Dependency conflict"],
            )
        )
        await store.record_experience(
            _make_experience(
                task_type="deployment",
                outcome="success",
                success_rate=1.0,
                steps_summary=[
                    {"step_name": "Build Docker Image", "outcome": "success"},
                    {"step_name": "Push to Registry", "outcome": "success"},
                ],
            )
        )

        # 2. Plan a new task.
        steps = [
            _make_step(name="Build Docker Image", description="Build the container image", step_id="s1"),
            _make_step(name="Push to Registry", description="Push image to container registry", step_id="s2"),
        ]

        # 3. Detect pitfalls.
        warnings = await detector.check_pitfalls(task_type="deployment", planned_steps=steps)

        # 4. Verify the results.
        assert len(warnings) >= 1

        # "Build Docker Image" failed 2 of 3 times (~0.667), which should be HIGH.
        build_warning = next((w for w in warnings if w.step_name == "Build Docker Image"), None)
        assert build_warning is not None
        assert build_warning.warning_level == WarningLevel.HIGH
        assert build_warning.failure_rate == pytest.approx(2.0 / 3.0, abs=0.01)

    async def test_suggestion_contains_useful_info(self, detector, store):
        """The warning suggestion includes the historical failure reason."""
        await store.record_experience(
            _make_experience(
                task_type="api_integration",
                outcome="failure",
                success_rate=0.0,
                steps_summary=[
                    {"step_name": "Authenticate", "outcome": "failure", "error": "Token expired"},
                ],
                failure_reasons=["Token expired"],
                optimization_tips=["Refresh token before expiry"],
            )
        )

        steps = [_make_step(name="Authenticate", description="Authenticate with API")]
        warnings = await detector.check_pitfalls(task_type="api_integration", planned_steps=steps)

        assert len(warnings) == 1
        assert "Token expired" in warnings[0].suggestion