"""End-to-end scenario tests for CostAwareRouter.

Verification goals:

1. Every routing path sets ``execution_mode`` correctly.
2. No hard-coded string matching that enumerates ``match_method`` values.
3. Every input scenario has an explicit routing result.
4. Boundary conditions (empty input, very long input, special characters)
   do not raise.
5. ``execution_mode`` is the sole basis for selecting the execution path.

Run: python3 tests/test_routing_chain.py
"""

import asyncio
import os
import sys

# Make ``agentkit`` importable when this script is run directly, without an
# installed package. The project root is resolved to an absolute path so the
# entry stays valid if the working directory changes later. ``<root>/src`` is
# added too (and takes precedence) for a src/ layout -- this file reads
# portal.py from ``<root>/src/agentkit/...``, so the package lives there.
_PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, _PROJECT_ROOT)
sys.path.insert(0, os.path.join(_PROJECT_ROOT, "src"))

from agentkit.chat.skill_routing import (  # noqa: E402  (needs the sys.path setup above)
    CostAwareRouter,
    ExecutionMode,
    SkillRoutingResult,
    HeuristicClassifier,
    _GREETING_RE,
    _CHAT_MODE_RE,
    _IDENTITY_RE,
    parse_skill_prefix,
)
# ============================================================
# 1. Layer 0 regex hard-coding tests
# ============================================================


def _report_regex_cases(regex, samples, should_match):
    """Print a PASS/FAIL line for each sample checked against a Layer 0 regex.

    Each sample is stripped and tested with ``regex.match`` (anchored at the
    start of the string).

    Args:
        regex: Compiled pattern to probe.
        samples: Input strings to try.
        should_match: ``True`` when every sample is expected to match,
            ``False`` when none of them should.

    Returns:
        The number of samples whose outcome differed from ``should_match``.
    """
    failures = 0
    for text in samples:
        m = regex.match(text.strip())
        ok = bool(m) == should_match
        if not ok:
            failures += 1
        status = "PASS" if ok else "FAIL"
        print(f" [{status}] {text!r:30s} → {'MATCH' if m else 'MISS'}")
    return failures


def test_layer0_regex_hardcoding():
    """Do the Layer 0 regexes only cover enumerated words and miss common variants?

    Probes ``_GREETING_RE``, ``_CHAT_MODE_RE`` and ``_IDENTITY_RE`` with inputs
    that should and should not match, printing PASS/FAIL per sample followed
    by a failure subtotal. Mismatches are reported, not raised, so the rest of
    the script still runs.
    """
    print("\n" + "=" * 60)
    print("1. Layer 0 正则硬编码测试")
    print("=" * 60)

    # Greetings -- should match.
    greeting_should_match = [
        "你好", "hi", "hello", "hey", "嗨", "哈喽",
        "早上好", "下午好", "晚上好",
        "Good morning", "Good afternoon", "Good evening",
        "你好!", "hi?", "hello.",
    ]
    # Greetings -- should NOT match.
    greeting_should_not_match = [
        "你好,帮我写个代码",  # not a bare greeting
        "hello world",  # not a greeting
        "hi there, I need help",  # has follow-up content
        "你好啊,请问一下",  # has follow-up content
    ]

    # Simple chat -- should match.
    chat_should_match = [
        "谢谢", "感谢", "thanks", "thank you", "ok",
        "好的", "嗯", "对", "不是", "没关系", "再见", "bye", "goodbye",
    ]
    # Simple chat -- should NOT match.
    chat_should_not_match = [
        "谢谢你的帮助,但我还有个问题",  # has follow-up content
        "ok let's do it",  # has follow-up content
    ]

    # Identity questions -- should match.
    identity_should_match = [
        "你是谁", "你叫什么", "你是什么", "你是哪个",
        "who are you", "what are you", "what's your name",
        "介绍一下你自己", "自我介绍", "你叫啥", "你叫什么名字", "你的名字",
        "你是谁?", "who are you?",
    ]
    # Identity questions -- should NOT match.
    identity_should_not_match = [
        "你是谁的粉丝",  # not asking about identity
        "你是什么时候发布的",  # not asking about identity
        "who are you talking to",  # not asking about identity
    ]

    # (section heading, regex under test, samples, expect a match?)
    checks = [
        ("\n 问候匹配测试:", _GREETING_RE, greeting_should_match, True),
        ("\n 问候不应匹配测试:", _GREETING_RE, greeting_should_not_match, False),
        ("\n 简单对话匹配测试:", _CHAT_MODE_RE, chat_should_match, True),
        ("\n 简单对话不应匹配测试:", _CHAT_MODE_RE, chat_should_not_match, False),
        ("\n 身份问题匹配测试:", _IDENTITY_RE, identity_should_match, True),
        ("\n 身份问题不应匹配测试:", _IDENTITY_RE, identity_should_not_match, False),
    ]

    failures = 0
    for heading, regex, samples, should_match in checks:
        print(heading)
        failures += _report_regex_cases(regex, samples, should_match)

    # A single count makes regressions visible without scanning every line.
    print(f"\n 小计: {failures} 项不符合预期")
# ============================================================
# 2. HeuristicClassifier hard-coding tests
# ============================================================


def test_heuristic_classifier():
    """Report whether HeuristicClassifier scores common inputs in plausible bands.

    Every case pairs an input with an inclusive ``(low, high)`` complexity
    band. A score inside the band prints PASS, anything else prints WARN;
    nothing is raised.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("2. HeuristicClassifier 复杂度评估测试")
    print(rule)

    scorer = HeuristicClassifier()

    # Each entry: (input text, inclusive expected complexity band, description)
    expectations = (
        ("你是谁", (0.0, 0.3), "短问题,无复杂度暗示"),
        ("今天天气怎么样", (0.0, 0.3), "简单闲聊"),
        ("帮我写一个 Python 函数", (0.6, 1.0), "含代码关键词"),
        ("请搜索一下最近的新闻", (0.6, 1.0), "含搜索关键词"),
        ("如何学习 Python?", (0.3, 0.6), "含中等关键词"),
        ("为什么天空是蓝色的", (0.3, 0.6), "含中等关键词"),
        ("执行命令 ls -la 并分析结果,然后搜索相关的配置文件,最后生成一份报告", (0.7, 1.0), "多步复杂任务"),
        ("", (0.0, 0.0), "空输入"),
        ("a", (0.0, 0.3), "极短输入"),
        ("帮我分析这段代码的性能瓶颈,并给出优化建议", (0.6, 1.0), "含分析和优化关键词"),
    )

    for sample, band, label in expectations:
        lo, hi = band
        value = scorer.classify(sample)
        verdict = "PASS" if lo <= value <= hi else "WARN"
        print(f" [{verdict}] {label:30s} → {value:.2f} (期望 {lo:.1f}-{hi:.1f})")
# ============================================================
# 3. CostAwareRouter end-to-end execution_mode tests
# ============================================================


async def test_execution_mode_chain():
    """Verify that each routing path sets the expected ``execution_mode``.

    Routes a table of inputs through a heuristic-only ``CostAwareRouter`` (no
    LLM gateway, merged LLM classification off). It runs first with no default
    tools and then with a single mock ``shell`` tool, printing PASS/FAIL per
    case. A failing case also prints its complexity, match flag and agent name.
    Exceptions are printed as ERROR lines instead of aborting the run.
    """
    print("\n" + "=" * 60)
    print("3. CostAwareRouter execution_mode 全链路测试")
    print("=" * 60)

    # Router that needs no LLM: heuristic classifier, merged LLM classification off.
    router = CostAwareRouter(
        llm_gateway=None,  # no LLM
        classifier="heuristic",
        merged_llm_classify=False,  # no merged LLM classification
    )

    # Minimal stand-ins: an empty skill registry and an intent router with no behaviour.
    class MockSkillRegistry:
        def list_skills(self):
            return []

        def get(self, name):
            raise KeyError(f"Skill not found: {name}")

    class MockIntentRouter:
        pass

    skill_registry = MockSkillRegistry()
    intent_router = MockIntentRouter()

    async def run_cases(cases, tools, session_id):
        """Route each ``(text, expected_mode, desc)`` case with *tools* and print the outcome.

        The whole route-and-report step stays inside the ``try`` on purpose.
        A malformed result (e.g. an ``execution_mode`` without ``.value``) is
        then reported as an ERROR line, just like an exception from ``route``.
        """
        for text, expected_mode, desc in cases:
            try:
                result = await router.route(
                    content=text,
                    skill_registry=skill_registry,
                    intent_router=intent_router,
                    default_tools=tools,
                    default_system_prompt=None,
                    default_model="default",
                    default_agent_name="default",
                    session_id=session_id,
                    transparency="TRACE",
                )
                mode_match = result.execution_mode == expected_mode
                status = "PASS" if mode_match else "FAIL"
                print(
                    f" [{status}] {desc:40s} → mode={result.execution_mode.value:12s} "
                    f"(expected={expected_mode.value}) method={result.match_method}"
                )
                if not mode_match:
                    print(f" complexity={result.complexity}, matched={result.matched}, agent={result.agent_name}")
            except Exception as e:  # diagnostic sweep: report and keep going
                # Include the type: str(e) is empty for e.g. a bare AssertionError.
                print(f" [ERROR] {desc:40s} → exception: {type(e).__name__}: {e}")

    # No default tools: every path below is expected to end in DIRECT_CHAT.
    test_cases = [
        # Layer 0
        ("你好", ExecutionMode.DIRECT_CHAT, "Layer 0: 问候"),
        ("hello", ExecutionMode.DIRECT_CHAT, "Layer 0: 英文问候"),
        ("谢谢", ExecutionMode.DIRECT_CHAT, "Layer 0: 简单对话"),
        ("你是谁", ExecutionMode.DIRECT_CHAT, "Layer 0: 身份问题"),
        ("who are you", ExecutionMode.DIRECT_CHAT, "Layer 0: 英文身份问题"),
        ("@skill:coder 帮我写代码", ExecutionMode.DIRECT_CHAT, "Layer 0: 显式技能不存在+无工具→DIRECT_CHAT"),

        # Layer 1 (heuristic, no tools -> always DIRECT_CHAT)
        ("今天天气怎么样", ExecutionMode.DIRECT_CHAT, "Layer 1: 低复杂度+无工具"),
        ("帮我写一个 Python 函数排序数组", ExecutionMode.DIRECT_CHAT, "Layer 1: 高复杂度+无工具→DIRECT_CHAT"),
        ("如何学习编程", ExecutionMode.DIRECT_CHAT, "Layer 1: 中等复杂度+无工具→DIRECT_CHAT"),

        # Boundary conditions
        ("", ExecutionMode.DIRECT_CHAT, "边界: 空输入"),
        ("a", ExecutionMode.DIRECT_CHAT, "边界: 极短输入"),
    ]
    await run_cases(test_cases, [], "test")

    # ---- Cases with a tool available ----
    print("\n --- 有工具场景 ---")
    mock_tool = type('Tool', (), {'name': 'shell', 'description': 'Run shell commands'})()

    test_cases_with_tools = [
        ("帮我写一个 Python 函数排序数组", ExecutionMode.REACT, "Layer 1: 高复杂度+有工具→REACT"),
        ("如何学习编程", ExecutionMode.REACT, "Layer 1: 中等复杂度+有工具→REACT"),
        ("今天天气怎么样", ExecutionMode.DIRECT_CHAT, "Layer 1: 低复杂度+有工具→DIRECT_CHAT"),
    ]
    await run_cases(test_cases_with_tools, [mock_tool], "test-tools")
# ============================================================
# 4. execution_mode coverage-gap tests
# ============================================================


async def test_execution_mode_gaps():
    """Check the ``execution_mode`` default and that unusual inputs still route.

    First asserts that a bare ``SkillRoutingResult()`` defaults to
    ``ExecutionMode.DIRECT_CHAT``. That assertion is outside any ``try``, so a
    wrong default aborts the run. It then routes edge-case strings with no
    tools. Each prints PASS when ``route`` returns an ``ExecutionMode`` member,
    or FAIL with the exception otherwise.
    """
    print("\n" + "=" * 60)
    print("4. execution_mode 覆盖盲区测试")
    print("=" * 60)

    # Default value of a freshly constructed result.
    default_result = SkillRoutingResult()
    print(f" 默认 execution_mode: {default_result.execution_mode.value}")
    assert default_result.execution_mode == ExecutionMode.DIRECT_CHAT, \
        f"默认值应为 DIRECT_CHAT,实际为 {default_result.execution_mode}"

    # Heuristic-only router: no LLM gateway, merged LLM classification off.
    router = CostAwareRouter(
        llm_gateway=None,
        classifier="heuristic",
        merged_llm_classify=False,
    )

    class MockSkillRegistry:
        def list_skills(self):
            return []

        def get(self, name):
            raise KeyError(f"Skill not found: {name}")

    class MockIntentRouter:
        pass

    # The mocks hold no state, so one instance of each serves every call.
    skill_registry = MockSkillRegistry()
    intent_router = MockIntentRouter()

    # Odd inputs: routing must not raise and must yield a valid ExecutionMode.
    edge_cases = [
        " " * 100,  # only spaces
        "啊" * 500,  # very long repeated character
        "!@#$%^&*()",  # special characters
        "12345",  # digits only
        "你好\n\n\n你好",  # multi-line
        "\t\t\t",  # tabs
        "mix中English混合input",  # mixed Chinese/English
        "帮我执行 `rm -rf /` 命令",  # inline code and a dangerous command
        "请问一下,你能不能帮我分析一下这个错误日志,然后搜索一下解决方案",  # long sentence, many keywords
    ]

    for text in edge_cases:
        # repr() keeps whitespace/control characters visible and the report on one line.
        label = repr(text[:40])
        try:
            result = await router.route(
                content=text,
                skill_registry=skill_registry,
                intent_router=intent_router,
                default_tools=[],
                default_system_prompt=None,
                session_id="test-edge",
                transparency="TRACE",
            )
            # Explicit check rather than ``assert`` so it still runs under ``python -O``.
            if not isinstance(result.execution_mode, ExecutionMode):
                raise TypeError(f"execution_mode 不是 ExecutionMode: {result.execution_mode}")
            print(f" [PASS] {label:40s} → mode={result.execution_mode.value}")
        except Exception as e:  # diagnostic sweep: report and keep going
            print(f" [FAIL] {label:40s} → exception: {type(e).__name__}: {e}")
# ============================================================
|
||
# 5. portal.py 执行路径选择硬编码检测
|
||
# ============================================================
|
||
|
||
def test_portal_hardcoding():
|
||
"""检测 portal.py 中是否还有 match_method 字符串匹配。"""
|
||
print("\n" + "=" * 60)
|
||
print("5. portal.py 硬编码检测")
|
||
print("=" * 60)
|
||
|
||
portal_path = os.path.join(
|
||
os.path.dirname(__file__), "..",
|
||
"src", "agentkit", "server", "routes", "portal.py"
|
||
)
|
||
portal_path = os.path.normpath(portal_path)
|
||
|
||
if not os.path.exists(portal_path):
|
||
print(f" [SKIP] portal.py not found at {portal_path}")
|
||
return
|
||
|
||
with open(portal_path, "r") as f:
|
||
content = f.read()
|
||
|
||
# 检测是否还有 match_method 字符串匹配
|
||
import re
|
||
# 匹配 "match_method" in (...) 或 match_method == "xxx"
|
||
suspicious_patterns = [
|
||
(r'match_method\s+in\s*\(', 'match_method in (...) — 应改用 execution_mode'),
|
||
(r'match_method\s*==\s*["\']', 'match_method == "xxx" — 应改用 execution_mode'),
|
||
(r'agent_name\s*==\s*["\']default["\']', 'agent_name == "default" — 应改用 execution_mode'),
|
||
]
|
||
|
||
issues = []
|
||
for pattern, desc in suspicious_patterns:
|
||
matches = list(re.finditer(pattern, content))
|
||
for m in matches:
|
||
# 获取行号
|
||
line_no = content[:m.start()].count('\n') + 1
|
||
line = content.split('\n')[line_no - 1].strip()
|
||
issues.append((line_no, desc, line))
|
||
|
||
if issues:
|
||
for line_no, desc, line in issues:
|
||
print(f" [WARN] Line {line_no}: {desc}")
|
||
print(f" {line}")
|
||
else:
|
||
print(" [PASS] 未发现 match_method 字符串匹配硬编码")
|
||
|
||
|
||
# ============================================================
# 6. resolve_skill_routing execution_mode tests
# ============================================================


async def test_resolve_skill_routing_modes():
    """Check the ``execution_mode`` chosen by ``resolve_skill_routing``.

    With an empty skill registry, no tools is expected to give ``DIRECT_CHAT``
    and a single mock tool to give ``REACT``. Each case prints PASS/FAIL. An
    exception from a call prints an ERROR line and the remaining cases still
    run, matching how the other routing sections report.
    """
    print("\n" + "=" * 60)
    print("6. resolve_skill_routing execution_mode 测试")
    print("=" * 60)

    from agentkit.chat.skill_routing import resolve_skill_routing

    class MockSkillRegistry:
        def list_skills(self):
            return []

        def get(self, name):
            raise KeyError(f"Skill not found: {name}")

    class MockIntentRouter:
        pass

    # Same shape (name + description) as the tool mock used in the
    # execution_mode chain section.
    mock_tool = type('Tool', (), {'name': 'shell', 'description': 'Run shell commands'})()

    # (content, default tools, expected mode, description)
    cases = [
        # No skill match + no tools -> DIRECT_CHAT
        ("你好", [], ExecutionMode.DIRECT_CHAT, "无技能+无工具"),
        # No skill match + tools available -> REACT
        ("帮我写代码", [mock_tool], ExecutionMode.REACT, "无技能+有工具"),
    ]

    for content, tools, expected, desc in cases:
        try:
            result = await resolve_skill_routing(
                content=content,
                skill_registry=MockSkillRegistry(),
                intent_router=MockIntentRouter(),
                default_tools=tools,
                default_system_prompt=None,
                session_id="test",
            )
            status = "PASS" if result.execution_mode == expected else "FAIL"
            print(f" [{status}] {desc} → {result.execution_mode.value} (expected={expected.value})")
        except Exception as e:  # diagnostic sweep: report and keep going
            print(f" [ERROR] {desc} → exception: {type(e).__name__}: {e}")
# ============================================================
# Main
# ============================================================


async def main():
    """Run every verification section in order, awaiting the async ones."""
    print("CostAwareRouter 全链路场景验证")
    print("=" * 60)

    sections = (
        test_layer0_regex_hardcoding,
        test_heuristic_classifier,
        test_execution_mode_chain,
        test_execution_mode_gaps,
        test_portal_hardcoding,
        test_resolve_skill_routing_modes,
    )
    for section in sections:
        outcome = section()
        # Async sections hand back a coroutine; finish it before the next section.
        if asyncio.iscoroutine(outcome):
            await outcome

    print("\n" + "=" * 60)
    print("验证完成")


if __name__ == "__main__":
    asyncio.run(main())