"""End-to-end scenario tests for CostAwareRouter.

Goals:
1. Every routing path sets ``execution_mode`` correctly.
2. No hard-coded string matching on ``match_method`` (enumerated values) remains.
3. Every input scenario produces an explicit routing result.
4. Edge cases (empty input, very long input, special characters) do not raise.
5. ``execution_mode`` is the sole basis for choosing the execution path.

Run: python3 tests/test_routing_chain.py
"""

import asyncio
import sys
import os

# Make the package importable when run directly from a checkout. The project
# root was the original (and still first) entry. The package appears to use a
# ``src/`` layout (the portal check in this file reads ``src/agentkit/...``),
# so ``<root>/src`` is added right after it when that directory exists;
# otherwise ``agentkit`` is only importable if it happens to be installed.
_PROJECT_ROOT = os.path.join(os.path.dirname(__file__), "..")
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
sys.path.insert(0, _PROJECT_ROOT)
if os.path.isdir(_SRC_DIR):
    sys.path.insert(1, _SRC_DIR)

from agentkit.chat.skill_routing import (
    CostAwareRouter,
    ExecutionMode,
    SkillRoutingResult,
    HeuristicClassifier,
    _GREETING_RE,
    _CHAT_MODE_RE,
    _IDENTITY_RE,
    parse_skill_prefix,
)


# ============================================================
# 1. Layer 0 hard-coded regex test
# ============================================================


def _report_regex_cases(pattern, cases, expect_match):
    """Print one PASS/FAIL line per case after running ``pattern.match`` on it.

    Args:
        pattern: compiled regex under test.
        cases: input strings; each is ``strip()``-ed before matching.
        expect_match: True if every case should match, False if none should.
    """
    for text in cases:
        m = pattern.match(text.strip())
        # ``re.match`` returns a match object (truthy) or None.
        status = "PASS" if bool(m) == expect_match else "FAIL"
        print(f" [{status}] {text!r:30s} → {'MATCH' if m else 'MISS'}")


def test_layer0_regex_hardcoding():
    """Check whether the Layer 0 regexes cover common variants, not just enumerated words.

    Results are printed as a report; this function does not assert.
    """
    print("\n" + "=" * 60)
    print("1. Layer 0 正则硬编码测试")
    print("=" * 60)

    # Greetings that should match.
    greeting_should_match = [
        "你好", "hi", "hello", "hey", "嗨", "哈喽",
        "早上好", "下午好", "晚上好",
        "Good morning", "Good afternoon", "Good evening",
        "你好!", "hi?", "hello.",
    ]
    # Greetings that should NOT match (greeting followed by more content).
    greeting_should_not_match = [
        "你好,帮我写个代码",  # not a pure greeting
        "hello world",  # not a greeting
        "hi there, I need help",  # has follow-up content
        "你好啊,请问一下",  # has follow-up content
    ]
    print("\n 问候匹配测试:")
    _report_regex_cases(_GREETING_RE, greeting_should_match, True)
    print("\n 问候不应匹配测试:")
    _report_regex_cases(_GREETING_RE, greeting_should_not_match, False)

    # Simple chat acknowledgements that should match.
    chat_should_match = [
        "谢谢", "感谢", "thanks", "thank you",
        "ok", "好的", "嗯", "对", "不是",
        "没关系", "再见", "bye", "goodbye",
    ]
    # Simple chat that should NOT match (has follow-up content).
    chat_should_not_match = [
        "谢谢你的帮助,但我还有个问题",  # has follow-up content
        "ok let's do it",  # has follow-up content
    ]
    print("\n 简单对话匹配测试:")
    _report_regex_cases(_CHAT_MODE_RE, chat_should_match, True)
    print("\n 简单对话不应匹配测试:")
    _report_regex_cases(_CHAT_MODE_RE, chat_should_not_match, False)

    # Identity questions that should match.
    identity_should_match = [
        "你是谁", "你叫什么", "你是什么", "你是哪个",
        "who are you", "what are you", "what's your name",
        "介绍一下你自己", "自我介绍", "你叫啥", "你叫什么名字", "你的名字",
        "你是谁?", "who are you?",
    ]
    # Phrases that look like identity questions but are not.
    identity_should_not_match = [
        "你是谁的粉丝",  # not asking about identity
        "你是什么时候发布的",  # not asking about identity
        "who are you talking to",  # not asking about identity
    ]
    print("\n 身份问题匹配测试:")
    _report_regex_cases(_IDENTITY_RE, identity_should_match, True)
    print("\n 身份问题不应匹配测试:")
    _report_regex_cases(_IDENTITY_RE, identity_should_not_match, False)


# ============================================================
# 2. HeuristicClassifier hard-coding test
# ============================================================


def test_heuristic_classifier():
    """Report whether HeuristicClassifier scores common inputs within expected ranges.

    Out-of-range scores are reported as WARN, not failures; nothing is asserted.
    """
    print("\n" + "=" * 60)
    print("2. HeuristicClassifier 复杂度评估测试")
    print("=" * 60)

    classifier = HeuristicClassifier()

    # (input, inclusive expected complexity range, description)
    test_cases = [
        ("你是谁", (0.0, 0.3), "短问题,无复杂度暗示"),
        ("今天天气怎么样", (0.0, 0.3), "简单闲聊"),
        ("帮我写一个 Python 函数", (0.6, 1.0), "含代码关键词"),
        ("请搜索一下最近的新闻", (0.6, 1.0), "含搜索关键词"),
        ("如何学习 Python?", (0.3, 0.6), "含中等关键词"),
        ("为什么天空是蓝色的", (0.3, 0.6), "含中等关键词"),
        ("执行命令 ls -la 并分析结果,然后搜索相关的配置文件,最后生成一份报告", (0.7, 1.0), "多步复杂任务"),
        ("", (0.0, 0.0), "空输入"),
        ("a", (0.0, 0.3), "极短输入"),
        ("帮我分析这段代码的性能瓶颈,并给出优化建议", (0.6, 1.0), "含分析和优化关键词"),
    ]

    for text, (low, high), desc in test_cases:
        score = classifier.classify(text)
        status = "PASS" if low <= score <= high else "WARN"
        print(f" [{status}] {desc:30s} → {score:.2f} (期望 {low:.1f}-{high:.1f})")
# ============================================================
# Shared stubs and helpers for the routing tests (sections 3, 4, 6)
# ============================================================


class _MockSkillRegistry:
    """Skill-registry stub with no skills: ``list_skills`` is empty, ``get`` raises KeyError."""

    def list_skills(self):
        return []

    def get(self, name):
        raise KeyError(f"Skill not found: {name}")


class _MockIntentRouter:
    """Attribute-less stand-in for the intent router."""


def _make_offline_router():
    """Build a CostAwareRouter configured not to use an LLM (heuristic classifier, merged LLM off)."""
    return CostAwareRouter(
        llm_gateway=None,
        classifier="heuristic",
        merged_llm_classify=False,
    )


async def _route(router, text, default_tools, session_id):
    """Route *text* with the argument set shared by every router test in this file.

    All callers pass the same defaults so a failure reflects the input being
    probed rather than a difference in call signature.
    """
    return await router.route(
        content=text,
        skill_registry=_MockSkillRegistry(),
        intent_router=_MockIntentRouter(),
        default_tools=default_tools,
        default_system_prompt=None,
        default_model="default",
        default_agent_name="default",
        session_id=session_id,
        transparency="TRACE",
    )


async def _check_modes(router, cases, default_tools, session_id):
    """Route each ``(text, expected_mode, desc)`` case and print PASS/FAIL/ERROR.

    Exceptions are reported per case rather than aborting the run.
    """
    for text, expected_mode, desc in cases:
        try:
            result = await _route(router, text, default_tools, session_id)
            mode_match = result.execution_mode == expected_mode
            status = "PASS" if mode_match else "FAIL"
            print(
                f" [{status}] {desc:40s} → mode={result.execution_mode.value:12s} "
                f"(expected={expected_mode.value}) method={result.match_method}"
            )
            if not mode_match:
                print(f" complexity={result.complexity}, matched={result.matched}, agent={result.agent_name}")
        except Exception as e:
            print(f" [ERROR] {desc:40s} → exception: {e}")


# ============================================================
# 3. CostAwareRouter end-to-end execution_mode test
# ============================================================


async def test_execution_mode_chain():
    """Check that every routing path sets ``execution_mode`` to the expected value."""
    print("\n" + "=" * 60)
    print("3. CostAwareRouter execution_mode 全链路测试")
    print("=" * 60)

    router = _make_offline_router()

    # No tools available: every path is expected to end in DIRECT_CHAT.
    no_tool_cases = [
        # Layer 0
        ("你好", ExecutionMode.DIRECT_CHAT, "Layer 0: 问候"),
        ("hello", ExecutionMode.DIRECT_CHAT, "Layer 0: 英文问候"),
        ("谢谢", ExecutionMode.DIRECT_CHAT, "Layer 0: 简单对话"),
        ("你是谁", ExecutionMode.DIRECT_CHAT, "Layer 0: 身份问题"),
        ("who are you", ExecutionMode.DIRECT_CHAT, "Layer 0: 英文身份问题"),
        ("@skill:coder 帮我写代码", ExecutionMode.DIRECT_CHAT, "Layer 0: 显式技能不存在+无工具→DIRECT_CHAT"),
        # Layer 1 (heuristic, no tools)
        ("今天天气怎么样", ExecutionMode.DIRECT_CHAT, "Layer 1: 低复杂度+无工具"),
        ("帮我写一个 Python 函数排序数组", ExecutionMode.DIRECT_CHAT, "Layer 1: 高复杂度+无工具→DIRECT_CHAT"),
        ("如何学习编程", ExecutionMode.DIRECT_CHAT, "Layer 1: 中等复杂度+无工具→DIRECT_CHAT"),
        # Edge cases
        ("", ExecutionMode.DIRECT_CHAT, "边界: 空输入"),
        ("a", ExecutionMode.DIRECT_CHAT, "边界: 极短输入"),
    ]
    await _check_modes(router, no_tool_cases, [], "test")

    print("\n --- 有工具场景 ---")
    mock_tool = type('Tool', (), {'name': 'shell', 'description': 'Run shell commands'})()
    with_tool_cases = [
        ("帮我写一个 Python 函数排序数组", ExecutionMode.REACT, "Layer 1: 高复杂度+有工具→REACT"),
        ("如何学习编程", ExecutionMode.REACT, "Layer 1: 中等复杂度+有工具→REACT"),
        ("今天天气怎么样", ExecutionMode.DIRECT_CHAT, "Layer 1: 低复杂度+有工具→DIRECT_CHAT"),
    ]
    await _check_modes(router, with_tool_cases, [mock_tool], "test-tools")


# ============================================================
# 4. execution_mode coverage blind-spot test
# ============================================================


async def test_execution_mode_gaps():
    """Check the SkillRoutingResult default and that odd inputs route without raising.

    Raises:
        AssertionError: if ``SkillRoutingResult()`` does not default to DIRECT_CHAT.
    """
    print("\n" + "=" * 60)
    print("4. execution_mode 覆盖盲区测试")
    print("=" * 60)

    default_result = SkillRoutingResult()
    print(f" 默认 execution_mode: {default_result.execution_mode.value}")
    assert default_result.execution_mode == ExecutionMode.DIRECT_CHAT, \
        f"默认值应为 DIRECT_CHAT,实际为 {default_result.execution_mode}"

    router = _make_offline_router()

    # Unusual inputs: each must route without raising and yield a real enum value.
    edge_cases = [
        " " * 100,  # only spaces
        "啊" * 500,  # very long repeated character
        "!@#$%^&*()",  # special characters
        "12345",  # digits only
        "你好\n\n\n你好",  # multiple lines
        "\t\t\t",  # tabs
        "mix中English混合input",  # mixed Chinese / English
        "帮我执行 `rm -rf /` 命令",  # code plus a dangerous command
        "请问一下,你能不能帮我分析一下这个错误日志,然后搜索一下解决方案",  # long, many keywords
    ]

    for text in edge_cases:
        # repr() keeps newlines/tabs in the input from breaking the report line.
        try:
            result = await _route(router, text, [], "test-edge")
            assert isinstance(result.execution_mode, ExecutionMode), \
                f"execution_mode 不是 ExecutionMode: {result.execution_mode}"
            print(f" [PASS] {text[:40]!r:40s} → mode={result.execution_mode.value}")
        except Exception as e:
            print(f" [FAIL] {text[:40]!r:40s} → exception: {e}")


# ============================================================
# 5. portal.py hard-coded execution-path selection check
# ============================================================


def test_portal_hardcoding():
    """Scan portal.py for leftover string matching on match_method / agent_name.

    Prints a WARN line per suspicious occurrence, or SKIP if the file is absent.
    """
    import re

    print("\n" + "=" * 60)
    print("5. portal.py 硬编码检测")
    print("=" * 60)

    portal_path = os.path.normpath(os.path.join(
        os.path.dirname(__file__), "..", "src", "agentkit", "server", "routes", "portal.py"
    ))

    if not os.path.exists(portal_path):
        print(f" [SKIP] portal.py not found at {portal_path}")
        return

    # Python source is UTF-8 by default; do not depend on the locale encoding.
    with open(portal_path, "r", encoding="utf-8") as f:
        content = f.read()
    lines = content.split('\n')

    # (regex, description): `match_method in (...)`, `match_method == "x"`, `agent_name == "default"`
    suspicious_patterns = [
        (r'match_method\s+in\s*\(', 'match_method in (...) — 应改用 execution_mode'),
        (r'match_method\s*==\s*["\']', 'match_method == "xxx" — 应改用 execution_mode'),
        (r'agent_name\s*==\s*["\']default["\']', 'agent_name == "default" — 应改用 execution_mode'),
    ]

    issues = []
    for pattern, desc in suspicious_patterns:
        for m in re.finditer(pattern, content):
            line_no = content.count('\n', 0, m.start()) + 1
            issues.append((line_no, desc, lines[line_no - 1].strip()))

    if issues:
        for line_no, desc, line in issues:
            print(f" [WARN] Line {line_no}: {desc}")
            print(f" {line}")
    else:
        print(" [PASS] 未发现 match_method 字符串匹配硬编码")


# ============================================================
# 6. resolve_skill_routing execution_mode test
# ============================================================


async def test_resolve_skill_routing_modes():
    """Check execution_mode from resolve_skill_routing with and without tools."""
    print("\n" + "=" * 60)
    print("6. resolve_skill_routing execution_mode 测试")
    print("=" * 60)

    from agentkit.chat.skill_routing import resolve_skill_routing

    # No matching skill + no tools -> DIRECT_CHAT
    result = await resolve_skill_routing(
        content="你好",
        skill_registry=_MockSkillRegistry(),
        intent_router=_MockIntentRouter(),
        default_tools=[],
        default_system_prompt=None,
        session_id="test",
    )
    status = "PASS" if result.execution_mode == ExecutionMode.DIRECT_CHAT else "FAIL"
    print(f" [{status}] 无技能+无工具 → {result.execution_mode.value} (expected=direct_chat)")

    # No matching skill + tools available -> REACT
    result = await resolve_skill_routing(
        content="帮我写代码",
        skill_registry=_MockSkillRegistry(),
        intent_router=_MockIntentRouter(),
        default_tools=[type('Tool', (), {'name': 'shell'})()],
        default_system_prompt=None,
        session_id="test",
    )
    status = "PASS" if result.execution_mode == ExecutionMode.REACT else "FAIL"
    print(f" [{status}] 无技能+有工具 → {result.execution_mode.value} (expected=react)")


# ============================================================
# Main
# ============================================================


async def main():
    """Run every section in order, printing a report to stdout."""
    print("CostAwareRouter 全链路场景验证")
    print("=" * 60)

    test_layer0_regex_hardcoding()
    test_heuristic_classifier()
    await test_execution_mode_chain()
    await test_execution_mode_gaps()
    test_portal_hardcoding()
    await test_resolve_skill_routing_modes()

    print("\n" + "=" * 60)
    print("验证完成")


if __name__ == "__main__":
    asyncio.run(main())