809 lines
32 KiB
Python
809 lines
32 KiB
Python
"""CostAwareRouter 单元测试 - 三层成本感知路由"""
|
||
|
||
import json
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from agentkit.chat.skill_routing import CostAwareRouter, HeuristicClassifier, SkillRoutingResult, _tokenize_content
|
||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||
from agentkit.router.intent import IntentRouter, RoutingResult
|
||
from agentkit.skills.base import IntentConfig, Skill, SkillConfig
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Helpers
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _make_skill(
    name: str,
    keywords: list[str] | None = None,
    description: str = "",
    examples: list[str] | None = None,
) -> Skill:
    """Build a minimal ``Skill`` whose config carries an ``intent`` section.

    Args:
        name: Skill name; it is also interpolated into the system prompt.
        keywords: Intent keywords; ``None`` (or any falsy value) becomes ``[]``.
        description: Intent description text.
        examples: Intent example utterances; ``None`` (or falsy) becomes ``[]``.

    Returns:
        A ``Skill`` wrapping a ``SkillConfig`` with ``agent_type="test"`` and
        ``task_mode="llm_generate"``.
    """
    intent_section = {
        "keywords": keywords or [],
        "description": description,
        "examples": examples or [],
    }
    system_prompt = f"You are a {name} skill."
    skill_config = SkillConfig(
        name=name,
        agent_type="test",
        task_mode="llm_generate",
        prompt={"system": system_prompt},
        intent=intent_section,
    )
    return Skill(config=skill_config)
|
||
|
||
|
||
def _make_llm_gateway(response_content: str) -> MagicMock:
    """Return a mock LLM gateway whose async ``chat`` always yields one response.

    The canned response carries *response_content* as ``content``, the model
    name ``"test-model"`` and a fixed usage of 10 prompt / 20 completion tokens.
    """
    canned_reply = LLMResponse(
        content=response_content,
        model="test-model",
        usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
    )
    mock_gateway = MagicMock()
    mock_gateway.chat = AsyncMock(return_value=canned_reply)
    return mock_gateway
|
||
|
||
|
||
def _make_skill_registry(skills: list[Skill] | None = None) -> MagicMock:
    """Return a mock skill registry backed by *skills*.

    ``list_skills()`` returns the skill list (``[]`` when *skills* is falsy).
    ``get(name)`` returns the first skill whose ``name`` equals *name* and
    raises ``KeyError`` when there is none.
    """
    known_skills = skills or []
    not_found = object()  # sentinel distinguishing "no match" from any skill

    def _lookup(skill_name: str):
        hit = next((cand for cand in known_skills if cand.name == skill_name), not_found)
        if hit is not_found:
            raise KeyError(f"Skill '{skill_name}' not found")
        return hit

    mock_registry = MagicMock()
    mock_registry.list_skills.return_value = known_skills
    mock_registry.get = MagicMock(side_effect=_lookup)
    return mock_registry
|
||
|
||
|
||
def _make_intent_router() -> IntentRouter:
    """Return an ``IntentRouter`` with no LLM gateway, i.e. keyword matching only."""
    return IntentRouter(model="default", llm_gateway=None)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Layer 0: Rule-based (zero cost)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestLayer0Greeting:
    """Layer 0: greeting patterns are matched by rules."""

    @staticmethod
    async def _route_plain(text):
        """Route *text* through a bare ``CostAwareRouter`` with an empty registry."""
        return await CostAwareRouter().route(
            content=text,
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

    @pytest.mark.asyncio
    async def test_chinese_greeting_hits_layer0(self):
        """'你好' hits the Layer 0 greeting rule at zero token cost."""
        outcome = await self._route_plain("你好")
        assert outcome.match_method == "greeting"
        assert outcome.complexity == 0.0
        assert outcome.agent_name == "default"
        assert outcome.matched is False

    @pytest.mark.asyncio
    async def test_english_greeting_hits_layer0(self):
        """'hello' hits the Layer 0 greeting rule."""
        outcome = await self._route_plain("hello")
        assert outcome.match_method == "greeting"
        assert outcome.complexity == 0.0

    @pytest.mark.asyncio
    async def test_greeting_with_punctuation(self):
        """'你好!' still hits Layer 0 despite the trailing punctuation."""
        outcome = await self._route_plain("你好!")
        assert outcome.match_method == "greeting"
|
||
|
||
|
||
class TestLayer0ChatMode:
    """Layer 0: short small-talk phrases are matched by the chat-mode rule."""

    @staticmethod
    async def _route_plain(text):
        """Route *text* through a bare ``CostAwareRouter`` with an empty registry."""
        return await CostAwareRouter().route(
            content=text,
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

    @pytest.mark.asyncio
    async def test_thanks_hits_chat_mode(self):
        """'谢谢' hits the Layer 0 chat-mode rule."""
        outcome = await self._route_plain("谢谢")
        assert outcome.match_method == "chat_mode"
        assert outcome.complexity == 0.0

    @pytest.mark.asyncio
    async def test_ok_hits_chat_mode(self):
        """'好的' hits the Layer 0 chat-mode rule."""
        outcome = await self._route_plain("好的")
        assert outcome.match_method == "chat_mode"
|
||
|
||
|
||
class TestLayer0ExplicitSkill:
    """Layer 0: an explicit ``@skill:`` prefix selects the named skill."""

    @pytest.mark.asyncio
    async def test_skill_prefix_hits_layer0(self):
        """'@skill:search 搜索XX' hits the Layer 0 explicit-skill rule at zero token cost."""
        registry = _make_skill_registry(
            [_make_skill("search", keywords=["搜索"], description="搜索信息")]
        )
        # Give the IntentRouter an LLM gateway so an LLM fallback is available.
        intent_llm = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_llm, model="default")

        outcome = await CostAwareRouter().route(
            content="@skill:search 搜索XX",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert outcome.matched is True
        assert outcome.skill_name == "search"
        assert outcome.complexity == 0.0
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Layer 1: LLM quick classify (~100 tokens)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestLayer1Classification:
    """Layer 1: quick LLM complexity classification."""

    @staticmethod
    def _router_replying(reply):
        """Build a ``CostAwareRouter`` whose mocked LLM always answers *reply*."""
        return CostAwareRouter(llm_gateway=_make_llm_gateway(reply), model="default")

    @pytest.mark.asyncio
    async def test_medium_complexity_routes_via_intent_router(self):
        """'分析下这个数据' is classified in Layer 1; medium complexity goes to IntentRouter."""
        # The classifier LLM reports medium complexity.
        router = self._router_replying(json.dumps({"complexity": 0.5}))
        registry = _make_skill_registry(
            [_make_skill("search", keywords=["分析"], description="数据分析")]
        )
        # The IntentRouter gets its own LLM as well.
        intent_llm = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))

        outcome = await router.route(
            content="分析下这个数据",
            skill_registry=registry,
            intent_router=IntentRouter(llm_gateway=intent_llm, model="default"),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert 0.3 <= outcome.complexity <= 0.7

    @pytest.mark.asyncio
    async def test_low_complexity_routes_to_default(self):
        """Low complexity (< 0.3) is routed to the default agent."""
        router = self._router_replying(json.dumps({"complexity": 0.1}))
        outcome = await router.route(
            content="随便聊聊",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert outcome.complexity < 0.3
        assert outcome.match_method == "low_complexity"
        assert outcome.agent_name == "default"

    @pytest.mark.asyncio
    async def test_no_llm_gateway_defaults_to_medium(self):
        """Without an LLM gateway, quick_classify returns 0.5 (medium complexity)."""
        router = CostAwareRouter(llm_gateway=None)
        assert await router.quick_classify("分析下这个数据") == 0.5

    @pytest.mark.asyncio
    async def test_llm_malformed_response_defaults_to_medium(self):
        """A non-JSON LLM reply makes quick_classify return 0.5."""
        router = self._router_replying("这不是JSON")
        assert await router.quick_classify("分析下这个数据") == 0.5

    @pytest.mark.asyncio
    async def test_complexity_clamped_to_0_1(self):
        """Complexity values reported by the LLM are clamped into [0.0, 1.0]."""
        cases = (
            (1.5, "超级复杂任务", 1.0),
            (-0.5, "简单任务", 0.0),
        )
        for raw_score, text, expected in cases:
            router = self._router_replying(json.dumps({"complexity": raw_score}))
            assert await router.quick_classify(text) == expected
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Layer 2: Capability matching / Auction
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestLayer2CapabilityMatching:
    """Layer 2: capability matching / auction for high-complexity requests."""

    @pytest.mark.asyncio
    async def test_high_complexity_triggers_capability_matching(self):
        """'做市场调研+竞品分析' scores above 0.7 and triggers capability matching."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.85}))
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value="market-researcher")

        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert result.complexity > 0.7
        assert result.match_method == "capability"
        assert result.agent_name == "market-researcher"
        assert result.matched is True

    @pytest.mark.asyncio
    async def test_layer2_with_org_context_object(self):
        """When find_best_agent returns an object, its ``name`` attribute is used."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.9}))
        agent_obj = MagicMock()
        agent_obj.name = "analyst-agent"
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value=agent_obj)

        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert result.agent_name == "analyst-agent"

    @pytest.mark.asyncio
    async def test_layer2_without_org_context_falls_back_to_intent_router(self):
        """Without an org_context, Layer 2 falls back to the IntentRouter."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.8}))
        search_skill = _make_skill("search", keywords=["调研"], description="市场调研")
        registry = _make_skill_registry([search_skill])

        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")

        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=None)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert result.complexity > 0.7
        # Falls back to the IntentRouter: either a skill matches or the default agent is used.
        # NOTE(review): this accepted set is very permissive (it even allows None);
        # tighten it once the fallback's match_method is pinned down.
        assert result.match_method in ("capability", "keyword", "llm", "intent_router_fallback", None)

    @pytest.mark.asyncio
    async def test_layer2_org_context_find_best_agent_returns_none(self):
        """When find_best_agent returns None, routing falls back to the IntentRouter."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.8}))
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value=None)

        search_skill = _make_skill("search", keywords=["调研"], description="市场调研")
        registry = _make_skill_registry([search_skill])
        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")

        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert result.complexity > 0.7
        # Layer 2 must actually have consulted the org context before falling back;
        # otherwise this test would pass without exercising the None branch at all.
        org_context.find_best_agent.assert_called()

    # These two tests are synchronous, so they must not carry @pytest.mark.asyncio
    # (pytest-asyncio warns about the mark on non-async tests).
    def test_auction_disabled_by_default(self):
        """Auction mode is disabled by default."""
        router = CostAwareRouter()
        assert router._auction_enabled is False

    def test_auction_can_be_enabled(self):
        """Auction mode can be switched on explicitly."""
        router = CostAwareRouter(auction_enabled=True)
        assert router._auction_enabled is True
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Transparency
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestTransparency:
    """Switching between routing transparency levels."""

    @staticmethod
    async def _route_greeting(**route_kwargs):
        """Route the greeting '你好' through a bare router, forwarding *route_kwargs*."""
        return await CostAwareRouter().route(
            content="你好",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
            **route_kwargs,
        )

    @pytest.mark.asyncio
    async def test_silent_mode_no_trace(self):
        """SILENT mode exposes no routing trace."""
        outcome = await self._route_greeting(transparency="SILENT")
        assert outcome.execution_trace == []
        assert outcome.transparency_level == "SILENT"

    @pytest.mark.asyncio
    async def test_verbose_mode_shows_trace(self):
        """VERBOSE mode exposes the routing trace."""
        outcome = await self._route_greeting(transparency="VERBOSE")
        trace = outcome.execution_trace
        assert len(trace) > 0
        first_step = trace[0]
        assert first_step["layer"] == 0
        assert first_step["method"] == "greeting"
        assert outcome.transparency_level == "VERBOSE"

    @pytest.mark.asyncio
    async def test_trace_mode_shows_full_trace(self):
        """TRACE mode exposes the full routing trace across layers."""
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value="analyst")
        router = CostAwareRouter(
            llm_gateway=_make_llm_gateway(json.dumps({"complexity": 0.85})),
            model="default",
            org_context=org_context,
        )

        outcome = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
            transparency="TRACE",
        )

        assert len(outcome.execution_trace) > 0
        # The trace should contain entries for Layer 1 (quick classify) and Layer 2.
        seen_layers = {step["layer"] for step in outcome.execution_trace}
        assert 1 in seen_layers  # Layer 1 quick_classify
        assert 2 in seen_layers  # Layer 2 capability matching
        assert outcome.transparency_level == "TRACE"

    @pytest.mark.asyncio
    async def test_default_transparency_is_silent(self):
        """Without an explicit transparency argument the level is SILENT."""
        outcome = await self._route_greeting()
        assert outcome.transparency_level == "SILENT"
        assert outcome.execution_trace == []
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# SkillRoutingResult 新字段
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestSkillRoutingResultNewFields:
    """Defaults of the transparency, trace and complexity fields on SkillRoutingResult."""

    def test_default_transparency_level(self):
        """A bare result defaults to the 'SILENT' transparency level."""
        assert SkillRoutingResult().transparency_level == "SILENT"

    def test_default_execution_trace(self):
        """A bare result starts with an empty execution trace."""
        assert SkillRoutingResult().execution_trace == []

    def test_default_complexity(self):
        """A bare result has complexity 0.0."""
        assert SkillRoutingResult().complexity == 0.0

    def test_new_fields_backward_compatible(self):
        """Building a result from legacy fields only still yields the new defaults."""
        legacy = SkillRoutingResult(skill_name="test", matched=True, match_method="keyword")
        assert legacy.transparency_level == "SILENT"
        assert legacy.execution_trace == []
        assert legacy.complexity == 0.0
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# _tokenize_content: 中文分词增强
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestTokenizeContent:
    """Tests for the Chinese-aware tokenisation in ``_tokenize_content``."""

    def test_chinese_content(self):
        """'帮我做数据分析' should yield a token covering '数据' or '数据分析'."""
        produced = _tokenize_content("帮我做数据分析")
        # No punctuation splits the run, so the expected bigrams are:
        # 帮我, 我做, 做数, 数据, 据分, 分析
        assert any(candidate in produced for candidate in ("数据", "数据分析"))

    def test_english_content(self):
        """'help with code generation' should yield 'code', 'generation' or 'code generation'."""
        produced = _tokenize_content("help with code generation")
        assert any(candidate in produced for candidate in ("code", "generation", "code generation"))

    def test_mixed_content(self):
        """'用python做data analysis' should yield 'analysis' plus a python-related token or 'data analysis'."""
        produced = _tokenize_content("用python做data analysis")
        # Splitting on whitespace leaves "用python做data" as one segment (expanded
        # into bigrams because it is longer than 4 chars) and "analysis" on its own.
        assert "analysis" in produced
        python_like = [tok for tok in produced if "python" in tok]
        assert python_like or "data analysis" in produced

    def test_stopwords_filtered(self):
        """A short phrase made only of stopwords produces at most one token."""
        produced = _tokenize_content("我的一个")
        # The 4-char run "我的一个" is kept whole (it is not itself a stopword),
        # while its stopword bigrams "我的", "的一", "一个" are filtered out.
        assert len(produced) <= 1

    def test_bigram_generation(self):
        """'机器学习模型训练' should contain every adjacent-character bigram."""
        text = "机器学习模型训练"
        produced = _tokenize_content(text)
        for first, second in zip(text, text[1:]):
            bigram = first + second
            assert bigram in produced, f"缺少 2-gram: {bigram}"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# HeuristicClassifier
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestHeuristicClassifier:
    """Unit tests for the local rule-based ``HeuristicClassifier``."""

    def setup_method(self):
        # pytest calls this before every test method, so each test gets its own instance.
        self.clf = HeuristicClassifier()

    def test_short_greeting_low_complexity(self):
        """A short greeting scores as low complexity."""
        assert self.clf.classify("你好呀") < 0.3

    def test_simple_question_medium_complexity(self):
        """A simple '如何' (how-to) question scores as medium complexity."""
        assert 0.3 <= self.clf.classify("如何使用这个功能?") <= 0.7

    def test_tool_request_high_complexity(self):
        """A request containing tool keywords scores high."""
        assert self.clf.classify("帮我搜索一下最新的新闻") > 0.5

    def test_code_request_high_complexity(self):
        """A code-writing request scores high."""
        assert self.clf.classify("写一个Python函数实现快速排序") > 0.6

    def test_multi_step_request_high_complexity(self):
        """A multi-step analysis request scores high."""
        assert self.clf.classify("分析这个数据,比较不同方案的优缺点,然后给出推荐") > 0.7

    def test_empty_string_zero_complexity(self):
        """Empty and whitespace-only input score exactly zero."""
        for blank in ("", " "):
            assert self.clf.classify(blank) == 0.0

    def test_long_message_higher_complexity(self):
        """A longer message scores higher than a short one."""
        brief = "帮我查一下"
        lengthy = brief + "关于机器学习和深度学习的最新进展" * 10
        assert self.clf.classify(lengthy) > self.clf.classify(brief)

    def test_code_patterns_boost_complexity(self):
        """Code markers (backticks / parentheses) raise the score."""
        fenced = "运行这段代码 `print('hello')`"
        plain = "运行这段代码 print hello"
        assert self.clf.classify(fenced) > self.clf.classify(plain)

    def test_score_bounded_0_to_1(self):
        """Scores always fall within [0.0, 1.0]."""
        samples = (
            "",
            "你好",
            "如何做",
            "帮我搜索并分析数据,设计一个完整的解决方案,包含代码实现和部署配置",
        )
        for sample in samples:
            value = self.clf.classify(sample)
            assert 0.0 <= value <= 1.0, f"Score {value} out of range for '{sample}'"
|
||
|
||
|
||
class TestHeuristicClassifierIntegration:
    """Integration of HeuristicClassifier inside CostAwareRouter."""

    @pytest.mark.asyncio
    async def test_heuristic_mode_no_llm_call(self):
        """With classifier='heuristic' and merged_llm_classify=False no LLM call is made."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
        router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="heuristic", merged_llm_classify=False)
        result = await router.route(
            content="帮我分析一下数据",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # gateway.chat must not be called (heuristic classifier + merged call disabled).
        gateway.chat.assert_not_called()
        # The complexity therefore comes from the heuristic classifier.
        assert result.complexity > 0.0

    @pytest.mark.asyncio
    async def test_llm_mode_uses_llm(self):
        """With classifier='llm' the router calls the LLM for quick classification."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
        router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="llm")
        # Only the side effect on the gateway matters here, so the result is discarded.
        await router.route(
            content="帮我分析一下数据",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        gateway.chat.assert_called()

    @pytest.mark.asyncio
    async def test_heuristic_greeting_still_layer0(self):
        """In heuristic mode a greeting is still handled by Layer 0."""
        router = CostAwareRouter(classifier="heuristic")
        result = await router.route(
            content="你好",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "greeting"
        assert result.complexity == 0.0

    # Synchronous test: it must not carry @pytest.mark.asyncio
    # (pytest-asyncio warns about the mark on non-async tests).
    def test_heuristic_default_classifier_mode(self):
        """The default classifier mode is 'heuristic'."""
        router = CostAwareRouter()
        assert router._classifier == "heuristic"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# U1: Merged LLM Classify
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestMergedLLMClassify:
    """Merged routing: a single LLM call returns complexity, intent and a skill hint."""

    @pytest.mark.asyncio
    async def test_merged_classify_returns_valid_skill(self):
        """A valid merged reply with a known skill_hint routes to that skill."""
        merged_response = json.dumps({
            "complexity": 0.6,
            "intent": "code_generation",
            "skill_hint": "search",
        })
        gateway = _make_llm_gateway(merged_response)
        search_skill = _make_skill("search", keywords=["搜索"], description="搜索信息")
        registry = _make_skill_registry([search_skill])

        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        result = await router.route(
            content="帮我搜索一下最新的新闻",
            skill_registry=registry,
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.matched is True
        assert result.skill_name == "search"
        assert result.match_method == "merged_llm"
        assert result.complexity > 0.3

    @pytest.mark.asyncio
    async def test_merged_classify_malformed_response_fallback(self):
        """A malformed merged reply falls back to the default agent."""
        gateway = _make_llm_gateway("这不是JSON")
        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        result = await router.route(
            content="帮我分析一下数据",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "merged_llm_fallback"
        assert result.complexity == 0.5

    @pytest.mark.asyncio
    async def test_merged_classify_low_complexity(self):
        """A merged reply with complexity < 0.3 takes the low-complexity route."""
        merged_response = json.dumps({
            "complexity": 0.2,
            "intent": "greeting",
            "skill_hint": None,
        })
        gateway = _make_llm_gateway(merged_response)
        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        result = await router.route(
            content="如何使用这个功能?",  # heuristic returns ~0.45, triggers merged call
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # The merged LLM reported complexity < 0.3, so a low-complexity route is expected.
        assert result.complexity < 0.3
        # Check for None first so a missing method fails cleanly instead of raising TypeError.
        # "low" also covers "merged_llm_low".
        assert result.match_method is not None
        assert "low" in result.match_method

    @pytest.mark.asyncio
    async def test_merged_classify_high_complexity(self):
        """A merged reply with complexity > 0.7 proceeds to Layer 2."""
        merged_response = json.dumps({
            "complexity": 0.85,
            "intent": "research",
            "skill_hint": None,
        })
        gateway = _make_llm_gateway(merged_response)
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value="researcher")

        router = CostAwareRouter(
            llm_gateway=gateway, model="default",
            org_context=org_context, merged_llm_classify=True,
        )
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.complexity > 0.7
        assert result.match_method == "capability"
        # Capability routing should select the agent returned by the org context,
        # matching the expectation in the Layer 2 capability tests.
        assert result.agent_name == "researcher"

    @pytest.mark.asyncio
    async def test_merged_classify_disabled_falls_back_to_intent_router(self):
        """With merged_llm_classify=False the standalone IntentRouter is used."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
        search_skill = _make_skill("search", keywords=["分析"], description="数据分析")
        registry = _make_skill_registry([search_skill])
        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")

        router = CostAwareRouter(
            llm_gateway=gateway, model="default",
            merged_llm_classify=False,
        )
        result = await router.route(
            content="分析下这个数据",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # The merged-LLM match method must not be used when merging is disabled.
        assert result.match_method != "merged_llm"

    @pytest.mark.asyncio
    async def test_merged_classify_no_llm_gateway_falls_back(self):
        """Without an LLM gateway, merged classification falls back to the IntentRouter."""
        search_skill = _make_skill("search", keywords=["分析"], description="数据分析")
        registry = _make_skill_registry([search_skill])

        router = CostAwareRouter(llm_gateway=None, merged_llm_classify=True)
        result = await router.route(
            content="分析下这个数据",
            skill_registry=registry,
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # Must not crash.
        assert result is not None
        # With no gateway, a successful merged-LLM match is impossible.
        assert result.match_method != "merged_llm"

    @pytest.mark.asyncio
    async def test_merged_classify_skill_hint_not_found_fallback(self):
        """A skill_hint missing from the registry falls back to the default agent."""
        merged_response = json.dumps({
            "complexity": 0.5,
            "intent": "unknown",
            "skill_hint": "nonexistent_skill",
        })
        gateway = _make_llm_gateway(merged_response)
        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        result = await router.route(
            content="帮我分析一下数据",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # Medium complexity with no matching skill goes to the default agent.
        assert result.matched is False
        assert result.match_method == "merged_llm_medium"

    @pytest.mark.asyncio
    async def test_merged_classify_only_one_llm_call(self):
        """In merged mode, a medium-complexity request costs exactly one LLM call."""
        merged_response = json.dumps({
            "complexity": 0.5,
            "intent": "question",
            "skill_hint": None,
        })
        gateway = _make_llm_gateway(merged_response)
        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        await router.route(
            content="如何使用这个功能?",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # Only the merged classify call should have reached the gateway.
        assert gateway.chat.call_count == 1
|