224 lines
8.5 KiB
Python
224 lines
8.5 KiB
Python
"""Unit tests for SimpleRouter — minimal routing layer."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import pytest
|
||
|
||
from agentkit.chat.simple_router import SimpleRouter
|
||
from agentkit.chat.skill_routing import ExecutionMode, SkillRoutingResult
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class MockSkill:
    """Stand-in for a skill definition, carrying only the fields these tests set.

    Args:
        name: Skill identifier; the fixtures also use it as the registry key.
        execution_mode: Mode string such as ``"react"``, ``"direct"`` or
            ``"rewoo"``; defaults to ``"react"``.
        tools: Tool names attached to the skill. Any falsy value is replaced
            by a fresh empty list.
        prompt: Prompt configuration. Any falsy value is replaced by a fresh
            empty dict.
    """

    def __init__(
        self,
        name: str,
        execution_mode: str = "react",
        tools: list | None = None,
        prompt: dict | None = None,
    ):
        self.name, self.execution_mode = name, execution_mode
        # Fresh containers per instance so mocks never share mutable state.
        self.tools = tools if tools else []
        self.prompt = prompt if prompt else {}
|
||
|
||
|
||
class MockSkillRegistry:
    """In-memory stand-in for a skill registry, backed by a plain dict.

    Args:
        skills: Mapping of skill name to skill object. A falsy value yields
            an empty registry.
    """

    def __init__(self, skills: dict[str, MockSkill] | None = None):
        self._skills = skills if skills else {}

    def get(self, name: str) -> MockSkill:
        """Return the skill registered under *name*.

        Raises:
            ValueError: If no skill is registered under *name*.
        """
        try:
            return self._skills[name]
        except KeyError:
            raise ValueError(f"Skill '{name}' not found") from None

    def list_skills(self) -> list[MockSkill]:
        """Return a new list of every registered skill, in insertion order."""
        return [*self._skills.values()]
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Fixtures
|
||
# ---------------------------------------------------------------------------
|
||
|
||
@pytest.fixture
def registry() -> MockSkillRegistry:
    """Registry holding one skill per execution mode: react, direct and rewoo."""
    skills = [
        MockSkill("shell_agent", execution_mode="react", tools=["shell"]),
        MockSkill("direct_agent", execution_mode="direct", tools=[]),
        MockSkill("rewoo_agent", execution_mode="rewoo", tools=["planner"]),
    ]
    # Key each skill by its own name so `@skill:<name>` lookups resolve.
    return MockSkillRegistry({skill.name: skill for skill in skills})
|
||
|
||
|
||
@pytest.fixture
def router(registry: MockSkillRegistry) -> SimpleRouter:
    """SimpleRouter wired to the test registry with fixed defaults.

    The default tool list and system prompt asserted by the REACT tests are
    the values configured here.
    """
    defaults = {
        "default_tools": ["shell", "search", "file_read"],
        "default_system_prompt": "You are a helpful assistant.",
        "default_model": "default",
        "default_agent_name": "default",
    }
    return SimpleRouter(skill_registry=registry, **defaults)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Layer 0: @skill:xxx prefix
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestSkillPrefix:
    """Layer 0: an ``@skill:<name>`` prefix selects that skill explicitly."""

    # Every test in this class is a coroutine; mark them all for pytest-asyncio.
    pytestmark = pytest.mark.asyncio

    async def test_skill_prefix_routes_to_skill(self, router: SimpleRouter):
        """A react-mode skill is matched by prefix with full confidence."""
        # Query text: "show the current IP".
        routed = await router.route("@skill:shell_agent 查看当前ip")
        assert routed.matched is True
        assert routed.skill_name == "shell_agent"
        assert routed.match_method == "skill_prefix"
        assert routed.match_confidence == 1.0
        assert routed.execution_mode == ExecutionMode.SKILL_REACT

    async def test_skill_prefix_direct_mode(self, router: SimpleRouter):
        """A direct-mode skill maps to DIRECT_CHAT."""
        # Query text: "translate hello".
        routed = await router.route("@skill:direct_agent 翻译hello")
        assert routed.matched is True
        assert routed.skill_name == "direct_agent"
        assert routed.execution_mode == ExecutionMode.DIRECT_CHAT

    async def test_skill_prefix_rewoo_mode(self, router: SimpleRouter):
        """A rewoo-mode skill maps to REWOO."""
        # Query text: "refactor the code".
        routed = await router.route("@skill:rewoo_agent 重构代码")
        assert routed.matched is True
        assert routed.skill_name == "rewoo_agent"
        assert routed.execution_mode == ExecutionMode.REWOO

    async def test_unknown_skill_falls_back_to_react(self, router: SimpleRouter):
        """An unregistered skill name is reported as unmatched and falls back to REACT."""
        # Query text: "query".
        routed = await router.route("@skill:nonexistent 查询")
        assert routed.matched is False
        assert routed.match_method == "skill_not_found_fallback"
        assert routed.execution_mode == ExecutionMode.REACT
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Layer 1: Greeting/chitchat/identity regex
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestDirectChat:
    """Layer 1: greetings, chitchat and identity questions go straight to DIRECT_CHAT."""

    # Every test in this class is a coroutine; mark them all for pytest-asyncio.
    pytestmark = pytest.mark.asyncio

    async def test_greeting_cn(self, router: SimpleRouter):
        """A Chinese greeting ("hello") is caught by the regex layer with no tools."""
        routed = await router.route("你好")
        assert routed.execution_mode == ExecutionMode.DIRECT_CHAT
        assert routed.match_method == "regex_direct"
        assert routed.tools == []

    async def test_greeting_en(self, router: SimpleRouter):
        """An English greeting routes to DIRECT_CHAT."""
        routed = await router.route("hello")
        assert routed.execution_mode == ExecutionMode.DIRECT_CHAT

    async def test_chitchat(self, router: SimpleRouter):
        """Chitchat ("thanks") routes to DIRECT_CHAT."""
        routed = await router.route("谢谢")
        assert routed.execution_mode == ExecutionMode.DIRECT_CHAT

    async def test_identity_question(self, router: SimpleRouter):
        """A Chinese identity question ("who are you") routes to DIRECT_CHAT."""
        routed = await router.route("你是谁")
        assert routed.execution_mode == ExecutionMode.DIRECT_CHAT

    async def test_identity_question_en(self, router: SimpleRouter):
        """An English identity question routes to DIRECT_CHAT."""
        routed = await router.route("who are you")
        assert routed.execution_mode == ExecutionMode.DIRECT_CHAT
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Default: REACT
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestDefaultReact:
    """Queries not caught by earlier layers fall through to REACT with default settings."""

    @pytest.mark.asyncio
    async def test_colloquial_tool_query(self, router: SimpleRouter):
        """Colloquial tool query: the core case the routing layer used to misclassify."""
        # Query text: "check the IP".
        result = await router.route("查下ip")
        assert result.execution_mode == ExecutionMode.REACT
        assert result.match_method == "default_react"
        assert len(result.tools) > 0

    # Parametrized (rather than looping inside one test) so each phrasing is
    # reported on its own and one failure does not hide the rest.
    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "query",
        [
            "查看当前ip",  # "show the current IP"
            "获取ip地址",  # "get the IP address"
            "看下ip",  # "have a look at the IP"
            "帮我查一下ip",  # "help me look up the IP"
            "搜索golang教程",  # "search for golang tutorials"
            "执行ls命令",  # "run the ls command"
            "读一下配置文件",  # "read the config file"
            "检查服务状态",  # "check the service status"
        ],
    )
    async def test_various_colloquial_expressions(self, router: SimpleRouter, query: str):
        """Every colloquial phrasing should route to REACT and let the LLM decide."""
        result = await router.route(query)
        assert result.execution_mode == ExecutionMode.REACT, f"'{query}' should be REACT, got {result.execution_mode}"

    @pytest.mark.asyncio
    async def test_complex_query(self, router: SimpleRouter):
        """A multi-step request routes to REACT."""
        # Query text: "help me analyze this data and generate a report".
        result = await router.route("帮我分析一下这个数据并生成报告")
        assert result.execution_mode == ExecutionMode.REACT

    @pytest.mark.asyncio
    async def test_translation_goes_react(self, router: SimpleRouter):
        """Translation also routes to REACT; inside the agent loop the LLM decides no tool is needed."""
        # Query text: "translate hello into Chinese".
        result = await router.route("翻译hello为中文")
        assert result.execution_mode == ExecutionMode.REACT
        # The LLM is still offered tools; it just chooses not to call them.
        assert len(result.tools) > 0

    @pytest.mark.asyncio
    async def test_default_tools_included(self, router: SimpleRouter):
        """The router's configured default tools are attached on the REACT path."""
        result = await router.route("查下ip")
        assert "shell" in result.tools
        assert "search" in result.tools

    @pytest.mark.asyncio
    async def test_default_system_prompt(self, router: SimpleRouter):
        """The router's configured default system prompt is used on the REACT path."""
        result = await router.route("查下ip")
        assert result.system_prompt == "You are a helpful assistant."
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Edge cases
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestEdgeCases:
    """Boundary inputs and router construction and override variants."""

    # Every test in this class is a coroutine; mark them all for pytest-asyncio.
    pytestmark = pytest.mark.asyncio

    async def test_empty_input(self, router: SimpleRouter):
        """An empty message falls through to REACT."""
        routed = await router.route("")
        assert routed.execution_mode == ExecutionMode.REACT

    async def test_whitespace_only(self, router: SimpleRouter):
        """A whitespace-only message falls through to REACT."""
        routed = await router.route(" ")
        assert routed.execution_mode == ExecutionMode.REACT

    async def test_greeting_with_extra_spaces(self, router: SimpleRouter):
        """Surrounding whitespace does not stop a greeting ("hello") from matching DIRECT_CHAT."""
        routed = await router.route(" 你好 ")
        assert routed.execution_mode == ExecutionMode.DIRECT_CHAT

    async def test_no_registry(self):
        """Router without skill registry should still work for non-skill queries"""
        bare_router = SimpleRouter(default_tools=["shell"])
        # Query text: "check the IP".
        routed = await bare_router.route("查下ip")
        assert routed.execution_mode == ExecutionMode.REACT

    async def test_override_defaults(self, router: SimpleRouter):
        """Route-time overrides should work"""
        overrides = {"default_tools": ["shell_only"], "default_model": "gpt-4o"}
        routed = await router.route("查下ip", **overrides)
        assert routed.tools == ["shell_only"]
        assert routed.model == "gpt-4o"
|