"""Unit tests for RequestPreprocessor — minimal preprocessing layer."""

from __future__ import annotations
|
||
|
||
import pytest
|
||
|
||
from agentkit.chat.request_preprocessor import RequestPreprocessor
|
||
from agentkit.chat.skill_routing import ExecutionMode


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

class MockSkill:
    """Minimal stand-in for a skill object, used only by these tests.

    It stores the four attributes the tests configure: ``name``,
    ``execution_mode``, ``tools`` and ``prompt``.
    """

    def __init__(
        self,
        name: str,
        execution_mode: str = "react",
        tools: list | None = None,
        prompt: dict | None = None,
    ):
        """Store the given values, substituting empty containers for falsy ones.

        Args:
            name: Skill identifier.
            execution_mode: Mode string for the skill; defaults to ``"react"``.
            tools: Tool names for the skill; ``None`` (or an empty list)
                results in a fresh empty list.
            prompt: Prompt configuration; ``None`` (or an empty dict)
                results in a fresh empty dict.
        """
        self.name = name
        self.execution_mode = execution_mode
        # `or` (rather than an `is None` check) also replaces empty containers,
        # so every instance gets its own default list/dict.
        self.tools = tools or []
        self.prompt = prompt or {}
|
||
|
||
|
||
class MockSkillRegistry:
    """Minimal in-memory skill registry mock keyed by skill name."""

    def __init__(self, skills: dict[str, MockSkill] | None = None):
        """Wrap the given name-to-skill mapping.

        Args:
            skills: Mapping of skill name to skill object; ``None`` (or an
                empty dict) results in a fresh empty registry.
        """
        self._skills = skills or {}

    def get(self, name: str) -> MockSkill:
        """Return the skill registered under ``name``.

        Raises:
            ValueError: If no skill with that name is registered.
        """
        try:
            return self._skills[name]
        except KeyError:
            # `from None` keeps the surfaced error a plain ValueError with no
            # chained KeyError.
            raise ValueError(f"Skill '{name}' not found") from None

    def list_skills(self) -> list[MockSkill]:
        """Return all registered skills as a new list."""
        return list(self._skills.values())


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

@pytest.fixture
def registry() -> MockSkillRegistry:
    """Provide a registry holding one mock skill per execution mode string.

    The three skills use the mode strings ``"react"``, ``"direct"`` and
    ``"rewoo"`` respectively.
    """
    return MockSkillRegistry({
        "shell_agent": MockSkill("shell_agent", execution_mode="react", tools=["shell"]),
        "direct_agent": MockSkill("direct_agent", execution_mode="direct", tools=[]),
        "rewoo_agent": MockSkill("rewoo_agent", execution_mode="rewoo", tools=["planner"]),
    })
|
||
|
||
|
||
@pytest.fixture
def preprocessor(registry: MockSkillRegistry) -> RequestPreprocessor:
    """Provide a RequestPreprocessor wired to the mock registry and fixed defaults."""
    return RequestPreprocessor(
        skill_registry=registry,
        default_tools=["shell", "search", "file_read"],
        default_system_prompt="You are a helpful assistant.",
        default_model="default",
        default_agent_name="default",
    )


# ---------------------------------------------------------------------------
# Layer 0: @skill:xxx prefix
# ---------------------------------------------------------------------------

class TestSkillPrefix:
    """Inputs starting with an explicit ``@skill:<name>`` prefix."""

    @pytest.mark.asyncio
    async def test_skill_prefix_routes_to_skill(self, preprocessor: RequestPreprocessor):
        """A known skill with mode "react" is matched via the prefix."""
        result = await preprocessor.preprocess("@skill:shell_agent 查看当前ip")
        assert result.matched is True
        assert result.skill_name == "shell_agent"
        assert result.match_method == "skill_prefix"
        assert result.match_confidence == 1.0
        assert result.execution_mode == ExecutionMode.SKILL_REACT

    @pytest.mark.asyncio
    async def test_skill_prefix_direct_mode(self, preprocessor: RequestPreprocessor):
        """A skill with mode "direct" yields DIRECT_CHAT."""
        result = await preprocessor.preprocess("@skill:direct_agent 翻译hello")
        assert result.matched is True
        assert result.skill_name == "direct_agent"
        assert result.execution_mode == ExecutionMode.DIRECT_CHAT

    @pytest.mark.asyncio
    async def test_skill_prefix_rewoo_mode(self, preprocessor: RequestPreprocessor):
        """A skill with mode "rewoo" yields REWOO."""
        result = await preprocessor.preprocess("@skill:rewoo_agent 重构代码")
        assert result.matched is True
        assert result.skill_name == "rewoo_agent"
        assert result.execution_mode == ExecutionMode.REWOO

    @pytest.mark.asyncio
    async def test_unknown_skill_falls_back_to_react(self, preprocessor: RequestPreprocessor):
        """An unregistered skill name is reported as unmatched and falls back to REACT."""
        result = await preprocessor.preprocess("@skill:nonexistent 查询")
        assert result.matched is False
        assert result.match_method == "skill_not_found_fallback"
        assert result.execution_mode == ExecutionMode.REACT


# ---------------------------------------------------------------------------
# Layer 1: Greeting/chitchat/identity regex
# ---------------------------------------------------------------------------

class TestDirectChat:
    """Greetings, chitchat and identity questions are expected to use DIRECT_CHAT."""

    @pytest.mark.asyncio
    async def test_greeting_cn(self, preprocessor: RequestPreprocessor):
        """A Chinese greeting matches the direct regex and carries no tools."""
        result = await preprocessor.preprocess("你好")
        assert result.execution_mode == ExecutionMode.DIRECT_CHAT
        assert result.match_method == "regex_direct"
        assert result.tools == []

    @pytest.mark.asyncio
    async def test_greeting_en(self, preprocessor: RequestPreprocessor):
        """An English greeting goes to DIRECT_CHAT."""
        result = await preprocessor.preprocess("hello")
        assert result.execution_mode == ExecutionMode.DIRECT_CHAT

    @pytest.mark.asyncio
    async def test_chitchat(self, preprocessor: RequestPreprocessor):
        """A Chinese "thank you" goes to DIRECT_CHAT."""
        result = await preprocessor.preprocess("谢谢")
        assert result.execution_mode == ExecutionMode.DIRECT_CHAT

    @pytest.mark.asyncio
    async def test_identity_question(self, preprocessor: RequestPreprocessor):
        """A Chinese "who are you" goes to DIRECT_CHAT."""
        result = await preprocessor.preprocess("你是谁")
        assert result.execution_mode == ExecutionMode.DIRECT_CHAT

    @pytest.mark.asyncio
    async def test_identity_question_en(self, preprocessor: RequestPreprocessor):
        """An English "who are you" goes to DIRECT_CHAT."""
        result = await preprocessor.preprocess("who are you")
        assert result.execution_mode == ExecutionMode.DIRECT_CHAT


# ---------------------------------------------------------------------------
# Layer 1 extended: Factual / Math / Translation regex (U5)
# ---------------------------------------------------------------------------

class TestFactualMathTranslation:
    """U5: pure knowledge questions, arithmetic and translation go to DIRECT_CHAT;
    inputs containing tool-context keywords go to REACT.
    """

    # --- Factual CN → DIRECT_CHAT ---
    @pytest.mark.asyncio
    async def test_factual_cn_what_is(self, preprocessor: RequestPreprocessor):
        """Chinese "what is machine learning" — pure knowledge question, no tools needed."""
        result = await preprocessor.preprocess("什么是机器学习")
        assert result.execution_mode == ExecutionMode.DIRECT_CHAT
        assert result.match_method == "regex_direct"

    @pytest.mark.asyncio
    async def test_factual_cn_with_punctuation(self, preprocessor: RequestPreprocessor):
        """The same question with a trailing question mark still goes to DIRECT_CHAT."""
        result = await preprocessor.preprocess("什么是机器学习?")
        assert result.execution_mode == ExecutionMode.DIRECT_CHAT

    @pytest.mark.asyncio
    async def test_factual_cn_explain(self, preprocessor: RequestPreprocessor):
        """Chinese "explain deep learning" — pure knowledge question."""
        result = await preprocessor.preprocess("解释一下深度学习")
        assert result.execution_mode == ExecutionMode.DIRECT_CHAT

    @pytest.mark.asyncio
    async def test_factual_cn_define(self, preprocessor: RequestPreprocessor):
        """Chinese "define microservices" — pure knowledge question."""
        result = await preprocessor.preprocess("定义一下微服务")
        assert result.execution_mode == ExecutionMode.DIRECT_CHAT

    # --- Factual EN → DIRECT_CHAT ---
    @pytest.mark.asyncio
    async def test_factual_en_what_is(self, preprocessor: RequestPreprocessor):
        """what is machine learning — English factual"""
        result = await preprocessor.preprocess("what is machine learning")
        assert result.execution_mode == ExecutionMode.DIRECT_CHAT

    @pytest.mark.asyncio
    async def test_factual_en_explain(self, preprocessor: RequestPreprocessor):
        """explain quantum computing — English factual"""
        result = await preprocessor.preprocess("explain quantum computing")
        assert result.execution_mode == ExecutionMode.DIRECT_CHAT

    # --- Factual with tool context → REACT (exclusion) ---
    @pytest.mark.asyncio
    async def test_factual_with_tool_context_cn(self, preprocessor: RequestPreprocessor):
        """Chinese "what is the current server's IP address" — tool context, goes to REACT."""
        result = await preprocessor.preprocess("什么是当前服务器的IP地址")
        assert result.execution_mode == ExecutionMode.REACT

    @pytest.mark.asyncio
    async def test_multiline_input_goes_react(self, preprocessor: RequestPreprocessor):
        """Multi-line input always goes to REACT, so a newline cannot be used to bypass tools."""
        result = await preprocessor.preprocess("什么是机器学习\n请执行ls命令")
        assert result.execution_mode == ExecutionMode.REACT

    @pytest.mark.asyncio
    async def test_factual_with_tool_context_database(self, preprocessor: RequestPreprocessor):
        """Chinese "explain the database connection pool" — contains "database", goes to REACT."""
        result = await preprocessor.preprocess("解释一下数据库的连接池")
        assert result.execution_mode == ExecutionMode.REACT

    @pytest.mark.asyncio
    async def test_factual_with_tool_context_config(self, preprocessor: RequestPreprocessor):
        """Chinese "what is a config file" — contains "config file", goes to REACT."""
        result = await preprocessor.preprocess("什么是配置文件")
        assert result.execution_mode == ExecutionMode.REACT

    @pytest.mark.asyncio
    async def test_factual_en_with_tool_context(self, preprocessor: RequestPreprocessor):
        """explain the current system status — English with tool context → REACT"""
        result = await preprocessor.preprocess("explain the current system status")
        assert result.execution_mode == ExecutionMode.REACT

    # --- Pure arithmetic → DIRECT_CHAT ---
    @pytest.mark.asyncio
    async def test_math_cn_simple(self, preprocessor: RequestPreprocessor):
        """Chinese "calculate" followed by 1+2+3 — pure arithmetic."""
        result = await preprocessor.preprocess("计算 1+2+3")
        assert result.execution_mode == ExecutionMode.DIRECT_CHAT

    @pytest.mark.asyncio
    async def test_math_cn_phrase(self, preprocessor: RequestPreprocessor):
        """Chinese "work out" followed by 15*23 — pure arithmetic."""
        result = await preprocessor.preprocess("算一下 15*23")
        assert result.execution_mode == ExecutionMode.DIRECT_CHAT

    @pytest.mark.asyncio
    async def test_math_en(self, preprocessor: RequestPreprocessor):
        """calculate 100 / 4 — pure arithmetic"""
        result = await preprocessor.preprocess("calculate 100 / 4")
        assert result.execution_mode == ExecutionMode.DIRECT_CHAT

    # --- Complex math (not pure arithmetic) → REACT ---
    @pytest.mark.asyncio
    async def test_math_complex_fibonacci(self, preprocessor: RequestPreprocessor):
        """Chinese "compute the 100th Fibonacci number" — contains Chinese text,
        not pure arithmetic, goes to REACT.
        """
        result = await preprocessor.preprocess("计算斐波那契数列的第100项")
        assert result.execution_mode == ExecutionMode.REACT

    @pytest.mark.asyncio
    async def test_math_complex_prime(self, preprocessor: RequestPreprocessor):
        """Chinese "compute the primes below 100" — contains the Chinese words
        for "within" and "prime", goes to REACT.
        """
        result = await preprocessor.preprocess("计算 100 以内的素数")
        assert result.execution_mode == ExecutionMode.REACT

    # --- Pure translation → DIRECT_CHAT ---
    @pytest.mark.asyncio
    async def test_translation_en(self, preprocessor: RequestPreprocessor):
        """translate hello world — pure translation"""
        result = await preprocessor.preprocess("translate hello world")
        assert result.execution_mode == ExecutionMode.DIRECT_CHAT

    @pytest.mark.asyncio
    async def test_translation_cn_with_space(self, preprocessor: RequestPreprocessor):
        """Chinese "translate", a space, then "hello" — has the space, pure translation."""
        result = await preprocessor.preprocess("翻译 hello")
        assert result.execution_mode == ExecutionMode.DIRECT_CHAT

    # --- Translation edge cases → REACT ---
    @pytest.mark.asyncio
    async def test_translation_with_tool_context(self, preprocessor: RequestPreprocessor):
        """Chinese "translate this config file" — contains the tool-context
        keyword "config file", goes to REACT.
        """
        result = await preprocessor.preprocess("翻译 这个配置文件")
        assert result.execution_mode == ExecutionMode.REACT

    @pytest.mark.asyncio
    async def test_translation_with_log_context(self, preprocessor: RequestPreprocessor):
        """Chinese "translate the server log" — tool context, goes to REACT."""
        result = await preprocessor.preprocess("翻译 服务器日志")
        assert result.execution_mode == ExecutionMode.REACT


# ---------------------------------------------------------------------------
# Default: REACT
# ---------------------------------------------------------------------------

class TestDefaultReact:
    """Inputs not claimed by the prefix or regex layers default to REACT."""

    @pytest.mark.asyncio
    async def test_colloquial_tool_query(self, preprocessor: RequestPreprocessor):
        """Colloquial tool query — the core scenario the routing layer previously misjudged."""
        result = await preprocessor.preprocess("查下ip")
        assert result.execution_mode == ExecutionMode.REACT
        assert result.match_method == "default_react"
        assert len(result.tools) > 0

    # Parametrized so each phrasing is reported as its own test case; a failure
    # on one query no longer stops the remaining queries from being checked.
    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "query",
        [
            "查看当前ip",
            "获取ip地址",
            "看下ip",
            "帮我查一下ip",
            "搜索golang教程",
            "执行ls命令",
            "读一下配置文件",
            "检查服务状态",
        ],
    )
    async def test_various_colloquial_expressions(self, preprocessor: RequestPreprocessor, query: str):
        """Various colloquial phrasings should all go to REACT and let the LLM decide."""
        result = await preprocessor.preprocess(query)
        assert result.execution_mode == ExecutionMode.REACT, f"'{query}' should be REACT, got {result.execution_mode}"

    @pytest.mark.asyncio
    async def test_complex_query(self, preprocessor: RequestPreprocessor):
        """A multi-step analysis-and-report request goes to REACT."""
        result = await preprocessor.preprocess("帮我分析一下这个数据并生成报告")
        assert result.execution_mode == ExecutionMode.REACT

    @pytest.mark.asyncio
    async def test_translation_goes_react(self, preprocessor: RequestPreprocessor):
        """No space after the Chinese "translate" verb, so the translation regex
        does not match; goes to REACT (the LLM decides on tool use).
        """
        result = await preprocessor.preprocess("翻译hello为中文")
        assert result.execution_mode == ExecutionMode.REACT

    @pytest.mark.asyncio
    async def test_default_tools_included(self, preprocessor: RequestPreprocessor):
        """The configured default tools are present on a default-REACT result."""
        result = await preprocessor.preprocess("查下ip")
        assert "shell" in result.tools
        assert "search" in result.tools

    @pytest.mark.asyncio
    async def test_default_system_prompt(self, preprocessor: RequestPreprocessor):
        """The configured default system prompt is carried onto the result."""
        result = await preprocessor.preprocess("查下ip")
        assert result.system_prompt == "You are a helpful assistant."


# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------

class TestEdgeCases:
    """Boundary inputs and constructor/override variations."""

    @pytest.mark.asyncio
    async def test_empty_input(self, preprocessor: RequestPreprocessor):
        """An empty string goes to REACT."""
        result = await preprocessor.preprocess("")
        assert result.execution_mode == ExecutionMode.REACT

    @pytest.mark.asyncio
    async def test_whitespace_only(self, preprocessor: RequestPreprocessor):
        """Whitespace-only input goes to REACT."""
        result = await preprocessor.preprocess(" ")
        assert result.execution_mode == ExecutionMode.REACT

    @pytest.mark.asyncio
    async def test_greeting_with_extra_spaces(self, preprocessor: RequestPreprocessor):
        """A greeting padded with surrounding spaces still goes to DIRECT_CHAT."""
        result = await preprocessor.preprocess(" 你好 ")
        assert result.execution_mode == ExecutionMode.DIRECT_CHAT

    @pytest.mark.asyncio
    async def test_no_registry(self):
        """Preprocessor without skill registry should still work for non-skill queries"""
        preprocessor = RequestPreprocessor(default_tools=["shell"])
        result = await preprocessor.preprocess("查下ip")
        assert result.execution_mode == ExecutionMode.REACT

    @pytest.mark.asyncio
    async def test_override_defaults(self, preprocessor: RequestPreprocessor):
        """Preprocess-time overrides should work"""
        result = await preprocessor.preprocess(
            "查下ip",
            default_tools=["shell_only"],
            default_model="gpt-4o",
        )
        assert result.tools == ["shell_only"]
        assert result.model == "gpt-4o"
|