355 lines
14 KiB
Python
355 lines
14 KiB
Python
"""Intent Router 单元测试 - 两级意图路由:关键词匹配 → LLM 分类"""
|
||
|
||
import json
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import pytest
|
||
|
||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||
from agentkit.router import IntentRouter, RoutingResult
|
||
from agentkit.skills.base import IntentConfig, Skill, SkillConfig
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _make_skill(
    name: str,
    keywords: list[str] | None = None,
    description: str = "",
    examples: list[str] | None = None,
) -> Skill:
    """Build a ``Skill`` for tests whose intent section comes from the arguments.

    Args:
        name: Skill name; also interpolated into the system prompt.
        keywords: Intent keywords; an omitted or empty value becomes ``[]``.
        description: Intent description text.
        examples: Intent example utterances; an omitted or empty value
            becomes ``[]``.

    Returns:
        A ``Skill`` wrapping a ``SkillConfig`` with fixed ``agent_type`` and
        ``task_mode`` values.
    """
    intent_section = {
        "keywords": keywords if keywords else [],
        "description": description,
        "examples": examples if examples else [],
    }
    system_prompt = f"You are a {name} skill."
    return Skill(
        config=SkillConfig(
            name=name,
            agent_type="test",
            task_mode="llm_generate",
            prompt={"system": system_prompt},
            intent=intent_section,
        )
    )
|
||
|
||
|
||
def _make_llm_gateway(response_content: str) -> MagicMock:
    """Return a mock LLM gateway whose async ``chat`` yields a canned reply.

    Args:
        response_content: Text placed in the ``content`` field of the
            ``LLMResponse`` that ``chat`` resolves to.

    Returns:
        A ``MagicMock`` whose ``chat`` attribute is an ``AsyncMock``.
    """
    canned_reply = LLMResponse(
        content=response_content,
        model="test-model",
        usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
    )
    mock_gateway = MagicMock()
    mock_gateway.chat = AsyncMock(return_value=canned_reply)
    return mock_gateway
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# RoutingResult 数据类
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestRoutingResult:
    """Basic checks on the ``RoutingResult`` data class."""

    def test_create_routing_result(self):
        """Constructor arguments are readable back as attributes."""
        routed = RoutingResult(matched_skill="weather", method="keyword", confidence=1.0)

        observed = (routed.matched_skill, routed.method, routed.confidence)
        assert observed == ("weather", "keyword", 1.0)

    def test_routing_result_contains_method_and_confidence(self):
        """The result exposes ``method`` and ``confidence`` with the given values."""
        routed = RoutingResult(matched_skill="search", method="llm", confidence=0.85)

        assert all(hasattr(routed, attr) for attr in ("method", "confidence"))
        assert (routed.method, routed.confidence) == ("llm", 0.85)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 关键词匹配 (Level 1)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestKeywordMatching:
    """Level 1 routing: keyword matching."""

    @pytest.mark.asyncio
    async def test_keyword_match_returns_keyword_method(self):
        """Input containing an intent keyword yields method='keyword', confidence=1.0."""
        weather_skill = _make_skill("weather", keywords=["天气", "weather", "气温"])
        intent_router = IntentRouter()

        outcome = await intent_router.route({"query": "今天天气怎么样"}, [weather_skill])

        observed = (outcome.matched_skill, outcome.method, outcome.confidence)
        assert observed == ("weather", "keyword", 1.0)

    @pytest.mark.asyncio
    async def test_keyword_no_match_falls_through(self):
        """Input containing no keyword is expected to be classified by the LLM."""
        llm_reply = json.dumps({"skill": "search", "confidence": 0.9})
        intent_router = IntentRouter(llm_gateway=_make_llm_gateway(llm_reply))
        candidates = [
            _make_skill("weather", keywords=["天气"]),
            _make_skill("search", keywords=["搜索"], description="搜索信息"),
        ]

        outcome = await intent_router.route({"query": "帮我找一下附近的餐厅"}, candidates)

        # The LLM fallback is expected to have produced the result.
        assert outcome.method == "llm"
        assert outcome.matched_skill == "search"

    @pytest.mark.asyncio
    async def test_keyword_match_case_insensitive(self):
        """Keyword matching ignores letter case."""
        mixed_case_skill = _make_skill("weather", keywords=["Weather", "TEMPERATURE"])
        intent_router = IntentRouter()

        outcome = await intent_router.route(
            {"query": "what's the weather today"},
            [mixed_case_skill],
        )

        observed = (outcome.matched_skill, outcome.method, outcome.confidence)
        assert observed == ("weather", "keyword", 1.0)

    @pytest.mark.asyncio
    async def test_keyword_confidence_always_1(self):
        """A keyword-routed result carries confidence 1.0."""
        calc_skill = _make_skill("calc", keywords=["计算", "算数"])
        intent_router = IntentRouter()

        outcome = await intent_router.route({"text": "帮我计算一下"}, [calc_skill])

        assert outcome.confidence == 1.0

    @pytest.mark.asyncio
    async def test_keyword_match_nested_input(self):
        """Keyword matching inspects string values nested inside ``input_data``."""
        translate_skill = _make_skill("translate", keywords=["翻译", "translate"])
        nested_payload = {"message": {"content": "请翻译这段话", "lang": "en"}}
        intent_router = IntentRouter()

        outcome = await intent_router.route(nested_payload, [translate_skill])

        assert (outcome.matched_skill, outcome.method) == ("translate", "keyword")

    @pytest.mark.asyncio
    async def test_keyword_match_multiple_hits_returns_first(self):
        """When several skills' keywords match, the first matching skill is returned."""
        intent_router = IntentRouter()
        candidates = [
            _make_skill("weather", keywords=["天气"]),
            _make_skill("translate", keywords=["翻译"]),
        ]

        # "天气" is expected to match first.
        outcome = await intent_router.route({"query": "天气翻译"}, candidates)

        assert outcome.matched_skill == "weather"

    @pytest.mark.asyncio
    async def test_keyword_match_in_list_values(self):
        """Keyword matching inspects string values held in lists inside ``input_data``."""
        search_skill = _make_skill("search", keywords=["搜索"])
        list_payload = {"messages": ["你好", "帮我搜索一下"], "type": "chat"}
        intent_router = IntentRouter()

        outcome = await intent_router.route(list_payload, [search_skill])

        assert (outcome.matched_skill, outcome.method) == ("search", "keyword")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# LLM 分类 (Level 2)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestLLMClassification:
    """Level 2 routing: LLM classification."""

    @staticmethod
    def _sent_prompt_text(chat_mock) -> str:
        """Flatten the ``messages`` passed to the mocked ``chat`` call into one string.

        Looks for a ``messages`` keyword argument first, then the first
        positional argument. All message contents are joined so the check does
        not depend on whether the skill catalogue is placed in the first
        (e.g. system) message or a later one.

        Raises:
            AssertionError: if ``chat`` was called without any messages argument.
        """
        call = chat_mock.call_args
        if "messages" in call.kwargs:
            messages = call.kwargs["messages"]
        elif call.args:
            messages = call.args[0]
        else:
            raise AssertionError("gateway.chat was called without a messages argument")

        if not isinstance(messages, list):
            return str(messages)
        return "\n".join(
            str(msg["content"]) if isinstance(msg, dict) and "content" in msg else str(msg)
            for msg in messages
        )

    @pytest.mark.asyncio
    async def test_llm_classification_returns_llm_method(self):
        """No keyword hit and a valid LLM answer yields method='llm'."""
        gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.92}))
        router = IntentRouter(llm_gateway=gateway)

        weather = _make_skill("weather", keywords=["天气"], description="查询天气")
        search = _make_skill("search", keywords=["搜索"], description="搜索信息")
        skills = [weather, search]

        result = await router.route({"query": "附近有什么好吃的"}, skills)

        assert result.matched_skill == "search"
        assert result.method == "llm"
        assert result.confidence == 0.92

    @pytest.mark.asyncio
    async def test_llm_confidence_from_response(self):
        """The confidence of an LLM-routed result is taken from the LLM response."""
        gateway = _make_llm_gateway(json.dumps({"skill": "weather", "confidence": 0.75}))
        router = IntentRouter(llm_gateway=gateway)

        weather = _make_skill("weather", keywords=["天气"], description="查询天气")
        search = _make_skill("search", keywords=["搜索"], description="搜索信息")
        skills = [weather, search]

        result = await router.route({"query": "外面冷不冷"}, skills)

        assert result.confidence == 0.75

    @pytest.mark.asyncio
    async def test_llm_nonexistent_skill_raises_value_error(self):
        """An LLM answer naming an unknown skill raises ``ValueError``."""
        gateway = _make_llm_gateway(json.dumps({"skill": "nonexistent", "confidence": 0.5}))
        router = IntentRouter(llm_gateway=gateway)

        weather = _make_skill("weather", keywords=["天气"], description="查询天气")
        search = _make_skill("search", keywords=["搜索"], description="搜索信息")
        skills = [weather, search]

        with pytest.raises(ValueError, match="nonexistent"):
            await router.route({"query": "你好"}, skills)

    @pytest.mark.asyncio
    async def test_llm_malformed_json_extracts_skill_name(self):
        """A non-JSON LLM answer is expected to be scanned for a skill name."""
        gateway = _make_llm_gateway('我觉得应该匹配 weather 这个技能')
        router = IntentRouter(llm_gateway=gateway)

        weather = _make_skill("weather", keywords=["天气"], description="查询天气")
        search = _make_skill("search", keywords=["搜索"], description="搜索信息")
        skills = [weather, search]

        result = await router.route({"query": "外面冷不冷"}, skills)

        # "weather" should be recoverable from the free-form text.
        assert result.matched_skill == "weather"
        assert result.method == "llm"

    @pytest.mark.asyncio
    async def test_llm_no_gateway_raises_error(self):
        """Without an LLM gateway, a keyword miss raises an error."""
        router = IntentRouter(llm_gateway=None)

        weather = _make_skill("weather", keywords=["天气"])
        search = _make_skill("search", keywords=["搜索"])
        skills = [weather, search]

        with pytest.raises((ValueError, RuntimeError)):
            await router.route({"query": "你好世界"}, skills)

    @pytest.mark.asyncio
    async def test_llm_classification_uses_skill_description_and_examples(self):
        """The classification prompt includes each skill's description and examples."""
        gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        router = IntentRouter(llm_gateway=gateway)

        search = _make_skill(
            "search",
            keywords=["搜索"],
            description="搜索互联网上的信息",
            examples=["帮我搜一下", "查找相关资料"],
        )
        weather = _make_skill("weather", keywords=["天气"], description="查询天气")
        skills = [search, weather]

        await router.route({"query": "找找看"}, skills)

        # The LLM must have been called exactly once, and the text it was sent
        # (across all messages, not just the first) must carry the skill
        # description and at least the first example.
        gateway.chat.assert_called_once()
        prompt_text = self._sent_prompt_text(gateway.chat)
        assert "搜索互联网上的信息" in prompt_text
        assert "帮我搜一下" in prompt_text
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 边界情况
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestEdgeCases:
    """Edge cases."""

    @pytest.mark.asyncio
    async def test_single_skill_returns_directly(self):
        """With exactly one skill, that skill is returned even if no keyword matches."""
        router = IntentRouter()
        skill = _make_skill("only_one", keywords=["唯一"])
        skills = [skill]

        result = await router.route({"query": "随便什么输入"}, skills)

        assert result.matched_skill == "only_one"
        assert result.method == "keyword"
        assert result.confidence == 1.0

    @pytest.mark.asyncio
    async def test_empty_skill_list_raises_value_error(self):
        """An empty skill list raises ``ValueError`` mentioning "skill"."""
        router = IntentRouter()

        with pytest.raises(ValueError, match="[Ss]kill"):
            await router.route({"query": "hello"}, [])

    @pytest.mark.asyncio
    async def test_skill_with_empty_keywords(self):
        """A lone skill with an empty keyword list is still the routing result."""
        gateway = _make_llm_gateway(json.dumps({"skill": "generic", "confidence": 0.6}))
        router = IntentRouter(llm_gateway=gateway)

        skill = _make_skill("generic", keywords=[], description="通用技能")
        skills = [skill]

        result = await router.route({"query": "你好"}, skills)

        # Only one skill is registered, so it is expected to be returned directly.
        assert result.matched_skill == "generic"

    @pytest.mark.asyncio
    async def test_input_data_with_no_string_values(self):
        """Input without any string values cannot keyword-match and goes to the LLM."""
        gateway = _make_llm_gateway(json.dumps({"skill": "weather", "confidence": 0.8}))
        router = IntentRouter(llm_gateway=gateway)

        weather = _make_skill("weather", keywords=["天气"], description="查询天气")
        search = _make_skill("search", keywords=["搜索"], description="搜索信息")
        skills = [weather, search]

        result = await router.route({"count": 42, "flag": True}, skills)

        assert result.method == "llm"

    @pytest.mark.asyncio
    async def test_model_parameter_passed_to_gateway(self):
        """The router's ``model`` argument is forwarded to the gateway's ``chat`` call."""
        gateway = _make_llm_gateway(json.dumps({"skill": "weather", "confidence": 0.9}))
        router = IntentRouter(llm_gateway=gateway, model="gpt-4")

        weather = _make_skill("weather", keywords=["天气"], description="查询天气")
        search = _make_skill("search", keywords=["搜索"], description="搜索信息")
        skills = [weather, search]

        await router.route({"query": "你好"}, skills)

        gateway.chat.assert_called_once()
        call = gateway.chat.call_args
        # Accept the model either as a keyword or as the second positional
        # argument. Guard the positional lookup so a missing model shows up as
        # an assertion failure rather than an IndexError.
        positional_model = call.args[1] if len(call.args) > 1 else None
        assert "gpt-4" in (call.kwargs.get("model"), positional_model), (
            f"model was not forwarded to gateway.chat: {call!r}"
        )
|