fischer-agentkit/tests/test_routing_chain.py

442 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""CostAwareRouter 全链路场景测试。
验证目标:
1. 每个路由路径都正确设置 execution_mode
2. 不存在硬编码的字符串匹配(match_method 枚举)
3. 所有输入场景都有明确的路由结果
4. 边界条件(空输入、超长输入、特殊字符)不会导致异常
5. execution_mode 是执行路径选择的唯一依据
运行: python3 tests/test_routing_chain.py
"""
import asyncio
import sys
import os
# Add project root to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from agentkit.chat.skill_routing import (
CostAwareRouter,
ExecutionMode,
SkillRoutingResult,
HeuristicClassifier,
_GREETING_RE,
_CHAT_MODE_RE,
_IDENTITY_RE,
parse_skill_prefix,
)
# ============================================================
# 1. Layer 0 正则硬编码测试
# ============================================================
def test_layer0_regex_hardcoding():
    """Report how the Layer 0 regexes treat listed phrases and their variants.

    For the greeting, short-chat and identity patterns, every sample is
    stripped and passed to ``pattern.match``; a PASS/FAIL line is printed
    depending on whether the outcome agrees with the expectation. Nothing is
    asserted, so a FAIL only shows up in the printed report.
    """

    def report(heading, pattern, samples, expect_match):
        # One line per sample: PASS when "did it match" equals ``expect_match``.
        print(heading)
        for sample in samples:
            hit = pattern.match(sample.strip())
            verdict = "PASS" if bool(hit) == expect_match else "FAIL"
            print(f" [{verdict}] {sample!r:30s}{'MATCH' if hit else 'MISS'}")

    rule = "=" * 60
    print("\n" + rule)
    print("1. Layer 0 正则硬编码测试")
    print(rule)

    # NOTE(review): the empty "" samples in the lists below look like
    # characters that were lost when this file was extracted — confirm
    # against the repository copy.

    # Greetings that are expected to match.
    greeting_hits = [
        "你好", "hi", "hello", "hey", "", "哈喽",
        "早上好", "下午好", "晚上好",
        "Good morning", "Good afternoon", "Good evening",
        "你好!", "hi?", "hello.",
    ]
    # Greetings followed by more content; expected not to match.
    greeting_misses = [
        "你好,帮我写个代码",  # not a bare greeting
        "hello world",  # not a greeting
        "hi there, I need help",  # has trailing content
        "你好啊,请问一下",  # has trailing content
    ]
    # Short conversational replies that are expected to match.
    chat_hits = [
        "谢谢", "感谢", "thanks", "thank you", "ok",
        "好的", "", "", "不是", "没关系", "再见", "bye", "goodbye",
    ]
    # Short replies followed by more content; expected not to match.
    chat_misses = [
        "谢谢你的帮助,但我还有个问题",  # has trailing content
        "ok let's do it",  # has trailing content
    ]
    # Identity questions that are expected to match.
    identity_hits = [
        "你是谁", "你叫什么", "你是什么", "你是哪个",
        "who are you", "what are you", "what's your name",
        "介绍一下你自己", "自我介绍", "你叫啥", "你叫什么名字", "你的名字",
        "你是谁?", "who are you?",
    ]
    # Sentences that only start like an identity question; expected not to match.
    identity_misses = [
        "你是谁的粉丝",  # not asking about identity
        "你是什么时候发布的",  # not asking about identity
        "who are you talking to",  # not asking about identity
    ]

    sections = (
        ("\n 问候匹配测试:", _GREETING_RE, greeting_hits, True),
        ("\n 问候不应匹配测试:", _GREETING_RE, greeting_misses, False),
        ("\n 简单对话匹配测试:", _CHAT_MODE_RE, chat_hits, True),
        ("\n 简单对话不应匹配测试:", _CHAT_MODE_RE, chat_misses, False),
        ("\n 身份问题匹配测试:", _IDENTITY_RE, identity_hits, True),
        ("\n 身份问题不应匹配测试:", _IDENTITY_RE, identity_misses, False),
    )
    for heading, pattern, samples, expect_match in sections:
        report(heading, pattern, samples, expect_match)
# ============================================================
# 2. HeuristicClassifier 硬编码测试
# ============================================================
def test_heuristic_classifier():
    """Print HeuristicClassifier scores next to their expected ranges.

    Each input is scored with ``classify`` and compared against an inclusive
    ``(low, high)`` range. A score outside the range is reported as WARN; the
    function never raises on a mismatch.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("2. HeuristicClassifier 复杂度评估测试")
    print(rule)

    scorer = HeuristicClassifier()
    # Each row: (input text, (low, high) expected score range, report label).
    scenarios = [
        ("你是谁", (0.0, 0.3), "短问题,无复杂度暗示"),
        ("今天天气怎么样", (0.0, 0.3), "简单闲聊"),
        ("帮我写一个 Python 函数", (0.6, 1.0), "含代码关键词"),
        ("请搜索一下最近的新闻", (0.6, 1.0), "含搜索关键词"),
        ("如何学习 Python", (0.3, 0.6), "含中等关键词"),
        ("为什么天空是蓝色的", (0.3, 0.6), "含中等关键词"),
        ("执行命令 ls -la 并分析结果,然后搜索相关的配置文件,最后生成一份报告", (0.7, 1.0), "多步复杂任务"),
        ("", (0.0, 0.0), "空输入"),
        ("a", (0.0, 0.3), "极短输入"),
        ("帮我分析这段代码的性能瓶颈,并给出优化建议", (0.6, 1.0), "含分析和优化关键词"),
    ]
    for prompt, bounds, label in scenarios:
        floor, ceiling = bounds
        score = scorer.classify(prompt)
        if floor <= score <= ceiling:
            status = "PASS"
        else:
            status = "WARN"
        print(f" [{status}] {label:30s}{score:.2f} (期望 {floor:.1f}-{ceiling:.1f})")
# ============================================================
# 3. CostAwareRouter 全链路 execution_mode 测试
# ============================================================
async def test_execution_mode_chain():
    """Check the execution_mode CostAwareRouter picks for sample inputs.

    The router is built with the heuristic classifier, no LLM gateway and
    merged LLM classification switched off, and it is handed an empty skill
    registry. The cases run first with an empty tool list and then with one
    mock tool. Every outcome is printed as PASS/FAIL; an exception raised
    while routing or reporting is printed as ERROR instead of propagating.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("3. CostAwareRouter execution_mode 全链路测试")
    print(rule)

    # Router that does not depend on an LLM (heuristic mode, no merged LLM).
    router = CostAwareRouter(
        llm_gateway=None,
        classifier="heuristic",
        merged_llm_classify=False,
    )

    # Stand-ins for the skill registry and intent router: no skills at all.
    class MockSkillRegistry:
        def list_skills(self):
            return []

        def get(self, name):
            raise KeyError(f"Skill not found: {name}")

    class MockIntentRouter:
        pass

    skill_registry = MockSkillRegistry()
    intent_router = MockIntentRouter()

    async def run_cases(cases, tools, session):
        # Route each case and print whether the chosen mode is the expected one.
        for content, wanted, label in cases:
            try:
                outcome = await router.route(
                    content=content,
                    skill_registry=skill_registry,
                    intent_router=intent_router,
                    default_tools=tools,
                    default_system_prompt=None,
                    default_model="default",
                    default_agent_name="default",
                    session_id=session,
                    transparency="TRACE",
                )
                agrees = outcome.execution_mode == wanted
                verdict = "PASS" if agrees else "FAIL"
                print(
                    f" [{verdict}] {label:40s} → mode={outcome.execution_mode.value:12s} "
                    f"(expected={wanted.value}) method={outcome.match_method}"
                )
                if not agrees:
                    # Extra routing details to help diagnose the mismatch.
                    print(f" complexity={outcome.complexity}, matched={outcome.matched}, agent={outcome.agent_name}")
            except Exception as exc:
                print(f" [ERROR] {label:40s} → exception: {exc}")

    no_tools = []
    cases_without_tools = [
        # Layer 0
        ("你好", ExecutionMode.DIRECT_CHAT, "Layer 0: 问候"),
        ("hello", ExecutionMode.DIRECT_CHAT, "Layer 0: 英文问候"),
        ("谢谢", ExecutionMode.DIRECT_CHAT, "Layer 0: 简单对话"),
        ("你是谁", ExecutionMode.DIRECT_CHAT, "Layer 0: 身份问题"),
        ("who are you", ExecutionMode.DIRECT_CHAT, "Layer 0: 英文身份问题"),
        ("@skill:coder 帮我写代码", ExecutionMode.DIRECT_CHAT, "Layer 0: 显式技能不存在+无工具→DIRECT_CHAT"),
        # Layer 1 (heuristic, no tools): everything is expected to be DIRECT_CHAT
        ("今天天气怎么样", ExecutionMode.DIRECT_CHAT, "Layer 1: 低复杂度+无工具"),
        ("帮我写一个 Python 函数排序数组", ExecutionMode.DIRECT_CHAT, "Layer 1: 高复杂度+无工具→DIRECT_CHAT"),
        ("如何学习编程", ExecutionMode.DIRECT_CHAT, "Layer 1: 中等复杂度+无工具→DIRECT_CHAT"),
        # Boundary inputs
        ("", ExecutionMode.DIRECT_CHAT, "边界: 空输入"),
        ("a", ExecutionMode.DIRECT_CHAT, "边界: 极短输入"),
    ]
    await run_cases(cases_without_tools, no_tools, "test")

    # ---- Scenarios with one tool available ----
    print("\n --- 有工具场景 ---")
    mock_tool = type('Tool', (), {'name': 'shell', 'description': 'Run shell commands'})()
    with_tools = [mock_tool]
    cases_with_tools = [
        ("帮我写一个 Python 函数排序数组", ExecutionMode.REACT, "Layer 1: 高复杂度+有工具→REACT"),
        ("如何学习编程", ExecutionMode.REACT, "Layer 1: 中等复杂度+有工具→REACT"),
        ("今天天气怎么样", ExecutionMode.DIRECT_CHAT, "Layer 1: 低复杂度+有工具→DIRECT_CHAT"),
    ]
    await run_cases(cases_with_tools, with_tools, "test-tools")
# ============================================================
# 4. execution_mode 覆盖盲区测试
# ============================================================
async def test_execution_mode_gaps():
    """Check the default execution_mode and route a set of edge-case inputs.

    First asserts that a bare ``SkillRoutingResult()`` defaults to
    ``ExecutionMode.DIRECT_CHAT``. Then routes a fixed list of unusual inputs
    through a heuristic-only router; an input passes when routing returns a
    result whose ``execution_mode`` is an ``ExecutionMode``. Any exception in
    that loop, including the failed type check, is printed as FAIL.

    Raises:
        AssertionError: if the default execution_mode is not DIRECT_CHAT.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("4. execution_mode 覆盖盲区测试")
    print(rule)

    # Default value of a freshly constructed result.
    blank = SkillRoutingResult()
    print(f" 默认 execution_mode: {blank.execution_mode.value}")
    assert blank.execution_mode == ExecutionMode.DIRECT_CHAT, (
        f"默认值应为 DIRECT_CHAT实际为 {blank.execution_mode}"
    )

    # Heuristic-only router: no LLM gateway, no merged LLM classification.
    router = CostAwareRouter(
        llm_gateway=None,
        classifier="heuristic",
        merged_llm_classify=False,
    )

    class MockSkillRegistry:
        def list_skills(self):
            return []

        def get(self, name):
            raise KeyError(f"Skill not found: {name}")

    class MockIntentRouter:
        pass

    # Unusual inputs; the only requirement is that routing does not blow up.
    edge_cases = [
        " " * 100,  # spaces only
        # NOTE(review): the repeated character below appears to have been lost
        # in extraction, leaving an empty string — confirm.
        "" * 500,  # very long repeated character
        "!@#$%^&*()",  # special characters
        "12345",  # digits only
        "你好\n\n\n你好",  # multi-line
        "\t\t\t",  # tabs
        "mix中English混合input",  # mixed Chinese and English
        "帮我执行 `rm -rf /` 命令",  # inline code with a dangerous command
        "请问一下,你能不能帮我分析一下这个错误日志,然后搜索一下解决方案",  # long sentence, many keywords
    ]
    for raw in edge_cases:
        preview = raw[:40]
        try:
            outcome = await router.route(
                content=raw,
                skill_registry=MockSkillRegistry(),
                intent_router=MockIntentRouter(),
                default_tools=[],
                default_system_prompt=None,
                session_id="test-edge",
                transparency="TRACE",
            )
            # Passing means: no exception and a genuine ExecutionMode value.
            assert isinstance(outcome.execution_mode, ExecutionMode), (
                f"execution_mode 不是 ExecutionMode: {outcome.execution_mode}"
            )
            print(f" [PASS] {preview:40s} → mode={outcome.execution_mode.value}")
        except Exception as exc:
            print(f" [FAIL] {preview:40s} → exception: {exc}")
# ============================================================
# 5. portal.py 执行路径选择硬编码检测
# ============================================================
def test_portal_hardcoding():
    """Scan portal.py for routing decisions keyed on hard-coded strings.

    Looks for ``match_method in (...)``, ``match_method == "..."`` and
    ``agent_name == "default"`` in ``src/agentkit/server/routes/portal.py``
    (resolved relative to this file) and prints one WARN entry per hit with
    its line number and stripped source line, or a single PASS line when
    nothing is found. Prints SKIP and returns early when the file is absent.

    Returns:
        None. The result is reported on stdout only.
    """
    import re  # only this check needs the regex module

    print("\n" + "=" * 60)
    print("5. portal.py 硬编码检测")
    print("=" * 60)
    portal_path = os.path.normpath(
        os.path.join(
            os.path.dirname(__file__), "..",
            "src", "agentkit", "server", "routes", "portal.py",
        )
    )
    if not os.path.exists(portal_path):
        print(f" [SKIP] portal.py not found at {portal_path}")
        return
    # Decode explicitly as UTF-8: the locale default can differ between
    # platforms and mis-decode or reject non-ASCII source text.
    with open(portal_path, "r", encoding="utf-8") as f:
        content = f.read()
    # Split once up front. Kept as split("\n") so the indices agree with the
    # count("\n") based line numbers computed below.
    source_lines = content.split("\n")

    # (regex, message) pairs for comparisons that the messages say should use
    # execution_mode instead.
    suspicious_patterns = [
        (r'match_method\s+in\s*\(', 'match_method in (...) — 应改用 execution_mode'),
        (r'match_method\s*==\s*["\']', 'match_method == "xxx" — 应改用 execution_mode'),
        (r'agent_name\s*==\s*["\']default["\']', 'agent_name == "default" — 应改用 execution_mode'),
    ]
    issues = []
    for pattern, desc in suspicious_patterns:
        for m in re.finditer(pattern, content):
            # 1-based line number of the match start.
            line_no = content[:m.start()].count("\n") + 1
            issues.append((line_no, desc, source_lines[line_no - 1].strip()))

    if not issues:
        print(" [PASS] 未发现 match_method 字符串匹配硬编码")
        return
    for line_no, desc, line in issues:
        print(f" [WARN] Line {line_no}: {desc}")
        print(f" {line}")
# ============================================================
# 6. resolve_skill_routing execution_mode 测试
# ============================================================
async def test_resolve_skill_routing_modes():
    """Check the execution_mode returned by ``resolve_skill_routing``.

    Two scenarios run against an empty skill registry: no tools (expected
    DIRECT_CHAT) and one tool (expected REACT). Each prints a PASS/FAIL line;
    nothing is asserted.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("6. resolve_skill_routing execution_mode 测试")
    print(rule)

    from agentkit.chat.skill_routing import resolve_skill_routing

    class MockSkillRegistry:
        def list_skills(self):
            return []

        def get(self, name):
            raise KeyError(f"Skill not found: {name}")

    class MockIntentRouter:
        pass

    async def check(content, tools, wanted, label, wanted_text):
        # Resolve routing for one scenario and print how it compares to ``wanted``.
        outcome = await resolve_skill_routing(
            content=content,
            skill_registry=MockSkillRegistry(),
            intent_router=MockIntentRouter(),
            default_tools=tools,
            default_system_prompt=None,
            session_id="test",
        )
        verdict = "PASS" if outcome.execution_mode == wanted else "FAIL"
        print(f" [{verdict}] {label} → {outcome.execution_mode.value} (expected={wanted_text})")

    # No skill match + no tools → DIRECT_CHAT
    await check("你好", [], ExecutionMode.DIRECT_CHAT, "无技能+无工具", "direct_chat")
    # No skill match + one tool → REACT
    await check(
        "帮我写代码",
        [type('Tool', (), {'name': 'shell'})()],
        ExecutionMode.REACT,
        "无技能+有工具",
        "react",
    )
# ============================================================
# Main
# ============================================================
async def main():
    """Run every scenario group in order, awaiting the asynchronous ones."""
    rule = "=" * 60
    print("CostAwareRouter 全链路场景验证")
    print(rule)

    test_layer0_regex_hardcoding()
    test_heuristic_classifier()
    await test_execution_mode_chain()
    await test_execution_mode_gaps()
    test_portal_hardcoding()
    await test_resolve_skill_routing_modes()

    print("\n" + rule)
    print("验证完成")


# Script entry point: run all checks when executed directly.
if __name__ == "__main__":
    asyncio.run(main())