513 lines
20 KiB
Python
513 lines
20 KiB
Python
"""CostAwareRouter 单元测试 - 三层成本感知路由"""
|
||
|
||
import json
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult, _tokenize_content
|
||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||
from agentkit.router.intent import IntentRouter, RoutingResult
|
||
from agentkit.skills.base import IntentConfig, Skill, SkillConfig
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _make_skill(
    name: str,
    keywords: list[str] | None = None,
    description: str = "",
    examples: list[str] | None = None,
) -> Skill:
    """Build a minimal ``Skill`` whose config carries an ``intent`` section.

    Args:
        name: Skill name; also interpolated into the system prompt.
        keywords: Intent keywords; ``None`` is stored as an empty list.
        description: Intent description text.
        examples: Intent example utterances; ``None`` is stored as an empty list.

    Returns:
        A ``Skill`` wrapping a ``SkillConfig`` with ``agent_type="test"`` and
        ``task_mode="llm_generate"``.
    """
    intent_section = {
        "keywords": keywords or [],
        "description": description,
        "examples": examples or [],
    }
    skill_config = SkillConfig(
        name=name,
        agent_type="test",
        task_mode="llm_generate",
        prompt={"system": f"You are a {name} skill."},
        intent=intent_section,
    )
    return Skill(config=skill_config)
|
||
|
||
|
||
def _make_llm_gateway(response_content: str) -> MagicMock:
    """Return a mock LLM gateway whose async ``chat`` always yields ``response_content``.

    The canned ``LLMResponse`` reports model ``"test-model"`` and a fixed
    ``TokenUsage`` of 10 prompt tokens / 20 completion tokens.
    """
    canned_reply = LLMResponse(
        content=response_content,
        model="test-model",
        usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
    )
    mock_gateway = MagicMock()
    mock_gateway.chat = AsyncMock(return_value=canned_reply)
    return mock_gateway
|
||
|
||
|
||
def _make_skill_registry(skills: list[Skill] | None = None) -> MagicMock:
    """Return a mock skill registry backed by ``skills``.

    ``list_skills()`` returns the given list (an empty list when ``skills`` is
    ``None`` or empty). ``get(name)`` returns the first skill whose ``name``
    matches and raises ``KeyError`` when none does.
    """
    known_skills = skills or []

    # The parameter stays called ``name`` so keyword calls (get(name=...)) keep working.
    def _lookup(name: str):
        hit = next((candidate for candidate in known_skills if candidate.name == name), None)
        if hit is None:
            raise KeyError(f"Skill '{name}' not found")
        return hit

    mock_registry = MagicMock()
    mock_registry.list_skills.return_value = known_skills
    mock_registry.get = MagicMock(side_effect=_lookup)
    return mock_registry
|
||
|
||
|
||
def _make_intent_router() -> IntentRouter:
    """Return an ``IntentRouter`` built without an LLM gateway (keyword matching only)."""
    keyword_only_router = IntentRouter(llm_gateway=None, model="default")
    return keyword_only_router
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Layer 0: Rule-based (zero cost)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestLayer0Greeting:
    """Layer 0: greeting patterns are matched by rules."""

    @staticmethod
    async def _route_plain(message: str):
        """Route ``message`` through a gateway-less router with an empty registry."""
        return await CostAwareRouter().route(
            content=message,
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

    @pytest.mark.asyncio
    async def test_chinese_greeting_hits_layer0(self):
        """'你好' hits the Layer 0 greeting rule at zero token cost."""
        result = await self._route_plain("你好")

        assert result.match_method == "greeting"
        assert result.complexity == 0.0
        assert result.agent_name == "default"
        assert result.matched is False

    @pytest.mark.asyncio
    async def test_english_greeting_hits_layer0(self):
        """'hello' hits the Layer 0 greeting rule."""
        result = await self._route_plain("hello")

        assert result.match_method == "greeting"
        assert result.complexity == 0.0

    @pytest.mark.asyncio
    async def test_greeting_with_punctuation(self):
        """A greeting with trailing punctuation still hits Layer 0."""
        result = await self._route_plain("你好!")

        assert result.match_method == "greeting"
|
||
|
||
|
||
class TestLayer0ChatMode:
    """Layer 0: short small-talk replies are matched by the chat-mode rule."""

    @staticmethod
    async def _route_small_talk(utterance: str):
        """Route ``utterance`` through a gateway-less router with an empty registry."""
        cost_router = CostAwareRouter()
        return await cost_router.route(
            content=utterance,
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

    @pytest.mark.asyncio
    async def test_thanks_hits_chat_mode(self):
        """'谢谢' ("thanks") hits the Layer 0 chat-mode rule."""
        outcome = await self._route_small_talk("谢谢")

        assert outcome.match_method == "chat_mode"
        assert outcome.complexity == 0.0

    @pytest.mark.asyncio
    async def test_ok_hits_chat_mode(self):
        """'好的' ("OK") hits the Layer 0 chat-mode rule."""
        outcome = await self._route_small_talk("好的")

        assert outcome.match_method == "chat_mode"
|
||
|
||
|
||
class TestLayer0ExplicitSkill:
    """Layer 0: an explicit ``@skill:<name>`` prefix selects the skill directly."""

    @pytest.mark.asyncio
    async def test_skill_prefix_hits_layer0(self):
        """'@skill:search ...' is resolved by the Layer 0 explicit-skill rule at zero token cost."""
        registry = _make_skill_registry(
            [_make_skill("search", keywords=["搜索"], description="搜索信息")]
        )
        # The IntentRouter is given an LLM gateway so an LLM fallback is available to it.
        intent_router = IntentRouter(
            llm_gateway=_make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9})),
            model="default",
        )

        result = await CostAwareRouter().route(
            content="@skill:search 搜索XX",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert result.matched is True
        assert result.skill_name == "search"
        assert result.complexity == 0.0
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Layer 1: LLM quick classify (~100 tokens)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestLayer1Classification:
    """Layer 1: quick LLM complexity classification."""

    @pytest.mark.asyncio
    async def test_medium_complexity_routes_via_intent_router(self):
        """'分析下这个数据' is classified by the Layer 1 LLM as medium complexity
        and is then handed to the IntentRouter."""
        # The classifier LLM reports medium complexity.
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
        search_skill = _make_skill("search", keywords=["分析"], description="数据分析")
        registry = _make_skill_registry([search_skill])

        # The IntentRouter gets its own LLM gateway as well.
        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")

        router = CostAwareRouter(llm_gateway=gateway, model="default")
        result = await router.route(
            content="分析下这个数据",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert 0.3 <= result.complexity <= 0.7
        # quick_classify also yields 0.5 when it has no gateway or cannot parse the
        # reply (see the tests below), so the range check alone would still pass if
        # the classifier LLM were never consulted. Pin that it actually was.
        gateway.chat.assert_awaited()

    @pytest.mark.asyncio
    async def test_low_complexity_routes_to_default(self):
        """Low complexity (< 0.3) is routed to the default agent."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.1}))

        router = CostAwareRouter(llm_gateway=gateway, model="default")
        result = await router.route(
            content="随便聊聊",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert result.complexity < 0.3
        assert result.match_method == "low_complexity"
        assert result.agent_name == "default"

    @pytest.mark.asyncio
    async def test_no_llm_gateway_defaults_to_medium(self):
        """Without an LLM gateway, quick_classify returns 0.5 (medium complexity)."""
        router = CostAwareRouter(llm_gateway=None)

        complexity = await router.quick_classify("分析下这个数据")

        assert complexity == 0.5

    @pytest.mark.asyncio
    async def test_llm_malformed_response_defaults_to_medium(self):
        """When the LLM reply is not JSON, quick_classify returns 0.5."""
        gateway = _make_llm_gateway("这不是JSON")
        router = CostAwareRouter(llm_gateway=gateway, model="default")

        complexity = await router.quick_classify("分析下这个数据")

        assert complexity == 0.5
        # 0.5 is also the no-gateway fallback; make sure the malformed reply was
        # actually requested (and therefore parsed and rejected).
        gateway.chat.assert_awaited()

    @pytest.mark.asyncio
    async def test_complexity_clamped_to_0_1(self):
        """Complexity values are clamped to the [0.0, 1.0] range."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 1.5}))
        router = CostAwareRouter(llm_gateway=gateway, model="default")
        complexity = await router.quick_classify("超级复杂任务")
        assert complexity == 1.0

        gateway2 = _make_llm_gateway(json.dumps({"complexity": -0.5}))
        router2 = CostAwareRouter(llm_gateway=gateway2, model="default")
        complexity2 = await router2.quick_classify("简单任务")
        assert complexity2 == 0.0
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Layer 2: Capability matching / Auction
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestLayer2CapabilityMatching:
    """Layer 2: capability matching / auction for high-complexity requests."""

    @pytest.mark.asyncio
    async def test_high_complexity_triggers_capability_matching(self):
        """'做市场调研+竞品分析' scores > 0.7 and triggers capability matching."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.85}))
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value="market-researcher")

        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert result.complexity > 0.7
        assert result.match_method == "capability"
        assert result.agent_name == "market-researcher"
        assert result.matched is True

    @pytest.mark.asyncio
    async def test_layer2_with_org_context_object(self):
        """When org_context.find_best_agent returns an object, its ``name`` is used."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.9}))
        agent_obj = MagicMock()
        agent_obj.name = "analyst-agent"
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value=agent_obj)

        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert result.agent_name == "analyst-agent"

    @pytest.mark.asyncio
    async def test_layer2_without_org_context_falls_back_to_intent_router(self):
        """Without an org_context, Layer 2 falls back to the IntentRouter."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.8}))
        search_skill = _make_skill("search", keywords=["调研"], description="市场调研")
        registry = _make_skill_registry([search_skill])

        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")

        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=None)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert result.complexity > 0.7
        # Falls back to the IntentRouter: it may match a skill or go to the default
        # agent. With no org_context there is nothing to capability-match against,
        # so "capability" must NOT be accepted here (it would hide a broken fallback).
        assert result.match_method in ("keyword", "llm", "intent_router_fallback", None)

    @pytest.mark.asyncio
    async def test_layer2_org_context_find_best_agent_returns_none(self):
        """When org_context.find_best_agent returns None, routing falls back to the IntentRouter."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.8}))
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value=None)

        search_skill = _make_skill("search", keywords=["调研"], description="市场调研")
        registry = _make_skill_registry([search_skill])
        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")

        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert result.complexity > 0.7
        # Layer 2 must actually have asked the org context for an agent ...
        org_context.find_best_agent.assert_called()
        # ... and, having received none, must not report a capability match.
        assert result.match_method != "capability"

    @pytest.mark.asyncio
    async def test_auction_disabled_by_default(self):
        """Auction mode is disabled by default."""
        router = CostAwareRouter()
        assert router._auction_enabled is False

    @pytest.mark.asyncio
    async def test_auction_can_be_enabled(self):
        """Auction mode can be switched on explicitly."""
        router = CostAwareRouter(auction_enabled=True)
        assert router._auction_enabled is True
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Transparency
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestTransparency:
    """Switching between routing transparency levels."""

    @staticmethod
    async def _route_greeting(**route_overrides):
        """Route the greeting '你好' via a gateway-less router.

        ``route_overrides`` are forwarded verbatim to ``route`` (e.g.
        ``transparency``); nothing extra is passed when it is empty.
        """
        return await CostAwareRouter().route(
            content="你好",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
            **route_overrides,
        )

    @pytest.mark.asyncio
    async def test_silent_mode_no_trace(self):
        """SILENT mode exposes no routing trace."""
        outcome = await self._route_greeting(transparency="SILENT")

        assert outcome.execution_trace == []
        assert outcome.transparency_level == "SILENT"

    @pytest.mark.asyncio
    async def test_verbose_mode_shows_trace(self):
        """VERBOSE mode exposes the routing trace."""
        outcome = await self._route_greeting(transparency="VERBOSE")

        trace = outcome.execution_trace
        assert len(trace) > 0
        first_step = trace[0]
        assert first_step["layer"] == 0
        assert first_step["method"] == "greeting"
        assert outcome.transparency_level == "VERBOSE"

    @pytest.mark.asyncio
    async def test_trace_mode_shows_full_trace(self):
        """TRACE mode exposes the full routing trace across layers."""
        classifier = _make_llm_gateway(json.dumps({"complexity": 0.85}))
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value="analyst")

        cost_router = CostAwareRouter(
            llm_gateway=classifier, model="default", org_context=org_context
        )
        outcome = await cost_router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
            transparency="TRACE",
        )

        assert len(outcome.execution_trace) > 0
        # The trace should hold entries for both Layer 1 (quick_classify)
        # and Layer 2 (capability matching).
        visited_layers = [step["layer"] for step in outcome.execution_trace]
        assert 1 in visited_layers
        assert 2 in visited_layers
        assert outcome.transparency_level == "TRACE"

    @pytest.mark.asyncio
    async def test_default_transparency_is_silent(self):
        """Transparency defaults to SILENT when not specified."""
        outcome = await self._route_greeting()

        assert outcome.transparency_level == "SILENT"
        assert outcome.execution_trace == []
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# SkillRoutingResult 新字段
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestSkillRoutingResultNewFields:
    """Defaults of the fields newly added to ``SkillRoutingResult``."""

    def test_default_transparency_level(self):
        """A default-constructed result reports the SILENT transparency level."""
        assert SkillRoutingResult().transparency_level == "SILENT"

    def test_default_execution_trace(self):
        """A default-constructed result carries an empty execution trace."""
        assert SkillRoutingResult().execution_trace == []

    def test_default_complexity(self):
        """A default-constructed result has complexity 0.0."""
        assert SkillRoutingResult().complexity == 0.0

    def test_new_fields_backward_compatible(self):
        """Constructing a result the old way still yields the new-field defaults."""
        legacy = SkillRoutingResult(skill_name="test", matched=True, match_method="keyword")

        assert legacy.transparency_level == "SILENT"
        assert legacy.execution_trace == []
        assert legacy.complexity == 0.0
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# _tokenize_content: 中文分词增强
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestTokenizeContent:
    """Tests for the Chinese-aware tokenization in ``_tokenize_content``."""

    def test_chinese_content(self):
        """Chinese input '帮我做数据分析' yields a token covering '数据分析'."""
        tokens = _tokenize_content("帮我做数据分析")
        # With no punctuation to split on, the whole run is cut into 2-grams:
        # 帮我, 我做, 做数, 数据, 据分, 分析
        assert any(candidate in tokens for candidate in ("数据", "数据分析"))

    def test_english_content(self):
        """English input yields 'code', 'generation' or the phrase 'code generation'."""
        tokens = _tokenize_content("help with code generation")
        assert any(
            candidate in tokens for candidate in ("code", "generation", "code generation")
        )

    def test_mixed_content(self):
        """Mixed input '用python做data analysis' keeps 'analysis' and a python-related token."""
        tokens = _tokenize_content("用python做data analysis")
        # Splitting on whitespace yields the segment "用python做data" (longer than
        # 4 characters, so it is cut into 2-grams) and the standalone "analysis".
        assert "analysis" in tokens
        assert any("python" in tok for tok in tokens) or "data analysis" in tokens

    def test_stopwords_filtered(self):
        """A short phrase made only of stopwords leaves at most one token."""
        tokens = _tokenize_content("我的一个")
        # The 4-character phrase itself is kept whole (it is not in the stopword
        # set), while stopword 2-grams such as "我的", "的一", "一个" are dropped.
        assert len(tokens) <= 1

    def test_bigram_generation(self):
        """'机器学习模型训练' yields every overlapping 2-gram."""
        tokens = _tokenize_content("机器学习模型训练")
        for expected in ("机器", "器学", "学习", "习模", "模型", "型训", "训练"):
            assert expected in tokens, f"缺少 2-gram: {expected}"
|