# fischer-agentkit/tests/unit/test_cost_aware_router.py
#
# Repository-viewer residue (not part of the module): 809 lines, 32 KiB, Python;
# "Raw / Blame / History" navigation links.
#
# Viewer warning: this file contains ambiguous Unicode characters, i.e. Unicode
# characters that might be confused with other characters. If that is
# intentional, the warning can safely be ignored; in the viewer, the Escape
# button reveals them.

"""CostAwareRouter 单元测试 - 三层成本感知路由"""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.chat.skill_routing import CostAwareRouter, HeuristicClassifier, SkillRoutingResult, _tokenize_content
from agentkit.llm.protocol import LLMResponse, TokenUsage
from agentkit.router.intent import IntentRouter, RoutingResult
from agentkit.skills.base import IntentConfig, Skill, SkillConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_skill(
    name: str,
    keywords: list[str] | None = None,
    description: str = "",
    examples: list[str] | None = None,
) -> Skill:
    """Build a ``Skill`` carrying an intent configuration, for use in tests.

    Args:
        name: Skill name; also interpolated into the system prompt.
        keywords: Intent keywords. ``None`` (or an empty list) becomes ``[]``.
        description: Intent description text.
        examples: Intent example utterances. ``None`` becomes ``[]``.

    Returns:
        A ``Skill`` wrapping the constructed ``SkillConfig``.
    """
    config = SkillConfig(
        name=name,
        agent_type="test",
        task_mode="llm_generate",
        prompt={"system": f"You are a {name} skill."},
        intent={
            "keywords": keywords or [],
            "description": description,
            "examples": examples or [],
        },
    )
    return Skill(config=config)
def _make_llm_gateway(response_content: str) -> MagicMock:
    """Build a mock LLM gateway whose async ``chat`` returns a fixed response.

    Args:
        response_content: Text placed in the ``content`` field of the
            ``LLMResponse`` that every ``chat`` call resolves to.

    Returns:
        A ``MagicMock`` whose ``chat`` attribute is an ``AsyncMock``, so tests
        can both await it and inspect its call count afterwards.
    """
    gateway = MagicMock()
    gateway.chat = AsyncMock(
        return_value=LLMResponse(
            content=response_content,
            model="test-model",
            usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
        )
    )
    return gateway
def _make_skill_registry(skills: list[Skill] | None = None) -> MagicMock:
    """Build a mock skill registry backed by an in-memory list of skills.

    Args:
        skills: Skills the registry should expose. ``None`` (or an empty
            list) yields an empty registry.

    Returns:
        A ``MagicMock`` where ``list_skills()`` returns the skill list and
        ``get(name)`` returns the first skill whose ``name`` matches, raising
        ``KeyError`` when no skill has that name.
    """
    registry = MagicMock()
    _skills = skills or []
    registry.list_skills.return_value = _skills

    def _get(name: str):
        # Linear scan keeps "first match wins" semantics for duplicate names.
        for s in _skills:
            if s.name == name:
                return s
        raise KeyError(f"Skill '{name}' not found")

    # Wrap the lookup in a MagicMock so tests can still assert on calls.
    registry.get = MagicMock(side_effect=_get)
    return registry
def _make_intent_router() -> IntentRouter:
    """Build an ``IntentRouter`` with no LLM gateway attached.

    The original author's note says such a router relies on keyword matching
    only; that behaviour lives in ``IntentRouter`` and is not visible here.
    """
    return IntentRouter(llm_gateway=None, model="default")
# ---------------------------------------------------------------------------
# Layer 0: Rule-based (zero cost)
# ---------------------------------------------------------------------------
class TestLayer0Greeting:
    """Layer 0: greeting pattern matching."""

    @pytest.mark.asyncio
    async def test_chinese_greeting_hits_layer0(self):
        """'你好' is routed by the Layer 0 greeting rule with zero complexity."""
        router = CostAwareRouter()
        result = await router.route(
            content="你好",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "greeting"
        assert result.complexity == 0.0
        assert result.agent_name == "default"
        assert result.matched is False

    @pytest.mark.asyncio
    async def test_english_greeting_hits_layer0(self):
        """'hello' is routed by the Layer 0 greeting rule."""
        router = CostAwareRouter()
        result = await router.route(
            content="hello",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "greeting"
        assert result.complexity == 0.0

    @pytest.mark.asyncio
    async def test_greeting_with_punctuation(self):
        """A greeting followed by punctuation still hits Layer 0."""
        router = CostAwareRouter()
        result = await router.route(
            content="你好!",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "greeting"
class TestLayer0ChatMode:
    """Layer 0: simple chit-chat patterns."""

    @pytest.mark.asyncio
    async def test_thanks_hits_chat_mode(self):
        """'谢谢' is routed by the Layer 0 chat-mode rule with zero complexity."""
        router = CostAwareRouter()
        result = await router.route(
            content="谢谢",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "chat_mode"
        assert result.complexity == 0.0

    @pytest.mark.asyncio
    async def test_ok_hits_chat_mode(self):
        """'好的' is routed by the Layer 0 chat-mode rule."""
        router = CostAwareRouter()
        result = await router.route(
            content="好的",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "chat_mode"
class TestLayer0ExplicitSkill:
    """Layer 0: explicit ``@skill:`` prefix."""

    @pytest.mark.asyncio
    async def test_skill_prefix_hits_layer0(self):
        """'@skill:search ...' selects the named skill with zero complexity."""
        search_skill = _make_skill("search", keywords=["搜索"], description="搜索信息")
        registry = _make_skill_registry([search_skill])
        # The IntentRouter is given an LLM gateway so an LLM fallback is available.
        gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=gateway, model="default")
        router = CostAwareRouter()
        result = await router.route(
            content="@skill:search 搜索XX",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.matched is True
        assert result.skill_name == "search"
        assert result.complexity == 0.0
# ---------------------------------------------------------------------------
# Layer 1: LLM quick classify (~100 tokens)
# ---------------------------------------------------------------------------
class TestLayer1Classification:
    """Layer 1: quick LLM classification."""

    @pytest.mark.asyncio
    async def test_medium_complexity_routes_via_intent_router(self):
        """A medium-complexity request ends up with complexity in [0.3, 0.7]."""
        # The classification LLM reports medium complexity.
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
        search_skill = _make_skill("search", keywords=["分析"], description="数据分析")
        registry = _make_skill_registry([search_skill])
        # The IntentRouter gets its own LLM gateway.
        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")
        router = CostAwareRouter(llm_gateway=gateway, model="default")
        result = await router.route(
            content="分析下这个数据",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert 0.3 <= result.complexity <= 0.7

    @pytest.mark.asyncio
    async def test_low_complexity_routes_to_default(self):
        """Low complexity (< 0.3) is routed to the default agent."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.1}))
        router = CostAwareRouter(llm_gateway=gateway, model="default")
        result = await router.route(
            content="随便聊聊",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.complexity < 0.3
        assert result.match_method == "low_complexity"
        assert result.agent_name == "default"

    @pytest.mark.asyncio
    async def test_no_llm_gateway_defaults_to_medium(self):
        """Without an LLM gateway, ``quick_classify`` returns 0.5 (medium)."""
        router = CostAwareRouter(llm_gateway=None)
        complexity = await router.quick_classify("分析下这个数据")
        assert complexity == 0.5

    @pytest.mark.asyncio
    async def test_llm_malformed_response_defaults_to_medium(self):
        """A non-JSON LLM reply makes ``quick_classify`` return 0.5."""
        gateway = _make_llm_gateway("这不是JSON")
        router = CostAwareRouter(llm_gateway=gateway, model="default")
        complexity = await router.quick_classify("分析下这个数据")
        assert complexity == 0.5

    @pytest.mark.asyncio
    async def test_complexity_clamped_to_0_1(self):
        """Complexity values are clamped to the range [0.0, 1.0]."""
        # Above the upper bound -> clamped to 1.0.
        gateway = _make_llm_gateway(json.dumps({"complexity": 1.5}))
        router = CostAwareRouter(llm_gateway=gateway, model="default")
        complexity = await router.quick_classify("超级复杂任务")
        assert complexity == 1.0

        # Below the lower bound -> clamped to 0.0.
        gateway2 = _make_llm_gateway(json.dumps({"complexity": -0.5}))
        router2 = CostAwareRouter(llm_gateway=gateway2, model="default")
        complexity2 = await router2.quick_classify("简单任务")
        assert complexity2 == 0.0
# ---------------------------------------------------------------------------
# Layer 2: Capability matching / Auction
# ---------------------------------------------------------------------------
class TestLayer2CapabilityMatching:
    """Layer 2: capability matching / auction."""

    @pytest.mark.asyncio
    async def test_high_complexity_triggers_capability_matching(self):
        """Complexity above 0.7 triggers capability matching via org_context."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.85}))
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value="market-researcher")
        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.complexity > 0.7
        assert result.match_method == "capability"
        assert result.agent_name == "market-researcher"
        assert result.matched is True

    @pytest.mark.asyncio
    async def test_layer2_with_org_context_object(self):
        """When ``find_best_agent`` returns an object, its ``name`` is used."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.9}))
        agent_obj = MagicMock()
        agent_obj.name = "analyst-agent"
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value=agent_obj)
        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.agent_name == "analyst-agent"

    @pytest.mark.asyncio
    async def test_layer2_without_org_context_falls_back_to_intent_router(self):
        """Without an org_context, Layer 2 falls back to the IntentRouter."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.8}))
        search_skill = _make_skill("search", keywords=["调研"], description="市场调研")
        registry = _make_skill_registry([search_skill])
        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")
        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=None)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.complexity > 0.7
        # After the fallback the request may match a skill or go to the
        # default agent, so several match methods are accepted.
        assert result.match_method in ("capability", "keyword", "llm", "intent_router_fallback", None)

    @pytest.mark.asyncio
    async def test_layer2_org_context_find_best_agent_returns_none(self):
        """When ``find_best_agent`` returns None, routing still completes."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.8}))
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value=None)
        search_skill = _make_skill("search", keywords=["调研"], description="市场调研")
        registry = _make_skill_registry([search_skill])
        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")
        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.complexity > 0.7

    @pytest.mark.asyncio
    async def test_auction_disabled_by_default(self):
        """Auction mode is disabled by default."""
        router = CostAwareRouter()
        assert router._auction_enabled is False

    @pytest.mark.asyncio
    async def test_auction_can_be_enabled(self):
        """Auction mode can be switched on explicitly."""
        router = CostAwareRouter(auction_enabled=True)
        assert router._auction_enabled is True
# ---------------------------------------------------------------------------
# Transparency
# ---------------------------------------------------------------------------
class TestTransparency:
    """Switching between transparency levels."""

    @pytest.mark.asyncio
    async def test_silent_mode_no_trace(self):
        """SILENT mode does not expose a routing trace."""
        router = CostAwareRouter()
        result = await router.route(
            content="你好",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
            transparency="SILENT",
        )
        assert result.execution_trace == []
        assert result.transparency_level == "SILENT"

    @pytest.mark.asyncio
    async def test_verbose_mode_shows_trace(self):
        """VERBOSE mode exposes the routing trace."""
        router = CostAwareRouter()
        result = await router.route(
            content="你好",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
            transparency="VERBOSE",
        )
        assert len(result.execution_trace) > 0
        assert result.execution_trace[0]["layer"] == 0
        assert result.execution_trace[0]["method"] == "greeting"
        assert result.transparency_level == "VERBOSE"

    @pytest.mark.asyncio
    async def test_trace_mode_shows_full_trace(self):
        """TRACE mode exposes the full routing trace across layers."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.85}))
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value="analyst")
        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
            transparency="TRACE",
        )
        assert len(result.execution_trace) > 0
        # The trace should contain entries for both Layer 1 and Layer 2.
        layers = [t["layer"] for t in result.execution_trace]
        assert 1 in layers  # Layer 1 quick_classify
        assert 2 in layers  # Layer 2 capability matching
        assert result.transparency_level == "TRACE"

    @pytest.mark.asyncio
    async def test_default_transparency_is_silent(self):
        """The default transparency level is SILENT."""
        router = CostAwareRouter()
        result = await router.route(
            content="你好",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.transparency_level == "SILENT"
        assert result.execution_trace == []
# ---------------------------------------------------------------------------
# SkillRoutingResult 新字段
# ---------------------------------------------------------------------------
class TestSkillRoutingResultNewFields:
    """Defaults of the newer ``SkillRoutingResult`` fields."""

    def test_default_transparency_level(self):
        """``transparency_level`` defaults to "SILENT"."""
        result = SkillRoutingResult()
        assert result.transparency_level == "SILENT"

    def test_default_execution_trace(self):
        """``execution_trace`` defaults to an empty list."""
        result = SkillRoutingResult()
        assert result.execution_trace == []

    def test_default_complexity(self):
        """``complexity`` defaults to 0.0."""
        result = SkillRoutingResult()
        assert result.complexity == 0.0

    def test_new_fields_backward_compatible(self):
        """Code that only sets the older fields still gets the new defaults."""
        result = SkillRoutingResult(
            skill_name="test",
            matched=True,
            match_method="keyword",
        )
        assert result.transparency_level == "SILENT"
        assert result.execution_trace == []
        assert result.complexity == 0.0
# ---------------------------------------------------------------------------
# _tokenize_content: 中文分词增强
# ---------------------------------------------------------------------------
class TestTokenizeContent:
    """Tests for ``_tokenize_content`` and its Chinese tokenisation."""

    def test_chinese_content(self):
        """Chinese text yields a token related to '数据分析'."""
        tokens = _tokenize_content("帮我做数据分析")
        # Author's note (tokenizer internals are not visible here): the text
        # has no punctuation to split on, so 2-grams such as 帮我, 我做, 做数,
        # 数据, 据分, 分析 are expected.
        assert "数据" in tokens or "数据分析" in tokens

    def test_english_content(self):
        """English text yields 'code', 'generation' or 'code generation'."""
        tokens = _tokenize_content("help with code generation")
        assert "code" in tokens or "generation" in tokens or "code generation" in tokens

    def test_mixed_content(self):
        """Mixed Chinese/English text yields 'analysis' plus a python- or data-related token."""
        tokens = _tokenize_content("用python做data analysis")
        # Author's note (not verifiable from this file): splitting on spaces
        # leaves "用python做data" as one segment and "analysis" as another.
        assert "analysis" in tokens
        # Author's note: the first segment is longer than 4 characters, so it
        # is presumably broken into 2-grams, some of which relate to python.
        has_python_related = any("python" in t for t in tokens)
        assert has_python_related or "data analysis" in tokens

    def test_stopwords_filtered(self):
        """A short phrase made of stopwords leaves at most one token."""
        tokens = _tokenize_content("我的一个")
        # Author's note (not verifiable from this file): the 4-character input
        # is kept as a whole token because it is not itself a stopword, while
        # its stopword 2-grams ("我的", "的一", "一个") are filtered out.
        assert len(tokens) <= 1

    def test_bigram_generation(self):
        """Every 2-gram of '机器学习模型训练' appears in the token set."""
        tokens = _tokenize_content("机器学习模型训练")
        expected_bigrams = ["机器", "器学", "学习", "习模", "模型", "型训", "训练"]
        for bigram in expected_bigrams:
            assert bigram in tokens, f"缺少 2-gram: {bigram}"
# ---------------------------------------------------------------------------
# HeuristicClassifier
# ---------------------------------------------------------------------------
class TestHeuristicClassifier:
    """Tests for the local ``HeuristicClassifier``."""

    def setup_method(self):
        """Create a fresh classifier before each test."""
        self.classifier = HeuristicClassifier()

    def test_short_greeting_low_complexity(self):
        """A short greeting scores as low complexity."""
        score = self.classifier.classify("你好呀")
        assert score < 0.3

    def test_simple_question_medium_complexity(self):
        """A simple 'how to' question scores as medium complexity."""
        score = self.classifier.classify("如何使用这个功能?")
        assert 0.3 <= score <= 0.7

    def test_tool_request_high_complexity(self):
        """A request containing a tool keyword scores above 0.5."""
        score = self.classifier.classify("帮我搜索一下最新的新闻")
        assert score > 0.5

    def test_code_request_high_complexity(self):
        """A code-related request scores above 0.6."""
        score = self.classifier.classify("写一个Python函数实现快速排序")
        assert score > 0.6

    def test_multi_step_request_high_complexity(self):
        """A multi-step analysis request scores above 0.7."""
        score = self.classifier.classify("分析这个数据,比较不同方案的优缺点,然后给出推荐")
        assert score > 0.7

    def test_empty_string_zero_complexity(self):
        """Empty and whitespace-only input score exactly 0.0."""
        assert self.classifier.classify("") == 0.0
        assert self.classifier.classify(" ") == 0.0

    def test_long_message_higher_complexity(self):
        """A longer message scores higher than its short prefix."""
        short = "帮我查一下"
        long = "帮我查一下" + "关于机器学习和深度学习的最新进展" * 10
        assert self.classifier.classify(long) > self.classifier.classify(short)

    def test_code_patterns_boost_complexity(self):
        """Code markers (backticks / parentheses) raise the score."""
        with_code = "运行这段代码 `print('hello')`"
        without_code = "运行这段代码 print hello"
        assert self.classifier.classify(with_code) > self.classifier.classify(without_code)

    def test_score_bounded_0_to_1(self):
        """Scores always stay within [0.0, 1.0]."""
        test_inputs = [
            "", "你好", "如何做", "帮我搜索并分析数据,设计一个完整的解决方案,包含代码实现和部署配置",
        ]
        for inp in test_inputs:
            score = self.classifier.classify(inp)
            assert 0.0 <= score <= 1.0, f"Score {score} out of range for '{inp}'"
class TestHeuristicClassifierIntegration:
    """Integration of ``HeuristicClassifier`` inside ``CostAwareRouter``."""

    @pytest.mark.asyncio
    async def test_heuristic_mode_no_llm_call(self):
        """Heuristic mode with merged classification off makes no LLM call."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
        router = CostAwareRouter(
            llm_gateway=gateway,
            model="default",
            classifier="heuristic",
            merged_llm_classify=False,
        )
        result = await router.route(
            content="帮我分析一下数据",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # gateway.chat must stay untouched: heuristic + merged disabled.
        gateway.chat.assert_not_called()
        # The complexity therefore comes from the heuristic classifier.
        assert result.complexity > 0.0

    @pytest.mark.asyncio
    async def test_llm_mode_uses_llm(self):
        """In "llm" classifier mode the LLM gateway is called."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
        router = CostAwareRouter(llm_gateway=gateway, model="default", classifier="llm")
        result = await router.route(
            content="帮我分析一下数据",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # gateway.chat must have been called at least once.
        gateway.chat.assert_called()

    @pytest.mark.asyncio
    async def test_heuristic_greeting_still_layer0(self):
        """In heuristic mode a greeting is still handled by Layer 0."""
        router = CostAwareRouter(classifier="heuristic")
        result = await router.route(
            content="你好",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "greeting"
        assert result.complexity == 0.0

    @pytest.mark.asyncio
    async def test_heuristic_default_classifier_mode(self):
        """The default classifier mode is "heuristic"."""
        router = CostAwareRouter()
        assert router._classifier == "heuristic"
# ---------------------------------------------------------------------------
# U1: Merged LLM Classify
# ---------------------------------------------------------------------------
class TestMergedLLMClassify:
    """Tests for the merged (single-call) LLM routing classification."""

    @pytest.mark.asyncio
    async def test_merged_classify_returns_valid_skill(self):
        """Valid JSON with a known ``skill_hint`` routes to that skill."""
        merged_response = json.dumps({
            "complexity": 0.6,
            "intent": "code_generation",
            "skill_hint": "search",
        })
        gateway = _make_llm_gateway(merged_response)
        search_skill = _make_skill("search", keywords=["搜索"], description="搜索信息")
        registry = _make_skill_registry([search_skill])
        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        result = await router.route(
            content="帮我搜索一下最新的新闻",
            skill_registry=registry,
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.matched is True
        assert result.skill_name == "search"
        assert result.match_method == "merged_llm"
        assert result.complexity > 0.3

    @pytest.mark.asyncio
    async def test_merged_classify_malformed_response_fallback(self):
        """A malformed merged reply falls back with complexity 0.5."""
        gateway = _make_llm_gateway("这不是JSON")
        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        result = await router.route(
            content="帮我分析一下数据",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.match_method == "merged_llm_fallback"
        assert result.complexity == 0.5

    @pytest.mark.asyncio
    async def test_merged_classify_low_complexity(self):
        """A merged reply with complexity < 0.3 takes the low-complexity route."""
        merged_response = json.dumps({
            "complexity": 0.2,
            "intent": "greeting",
            "skill_hint": None,
        })
        gateway = _make_llm_gateway(merged_response)
        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        result = await router.route(
            # Author's note: the heuristic returns ~0.45 for this input,
            # which triggers the merged LLM call.
            content="如何使用这个功能?",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # The merged LLM reported complexity < 0.3, so a low-complexity
        # match method is expected (any name containing "low", which also
        # covers "merged_llm_low").
        assert result.complexity < 0.3
        assert "low" in result.match_method

    @pytest.mark.asyncio
    async def test_merged_classify_high_complexity(self):
        """A merged reply with complexity > 0.7 proceeds to Layer 2."""
        merged_response = json.dumps({
            "complexity": 0.85,
            "intent": "research",
            "skill_hint": None,
        })
        gateway = _make_llm_gateway(merged_response)
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value="researcher")
        router = CostAwareRouter(
            llm_gateway=gateway, model="default",
            org_context=org_context, merged_llm_classify=True,
        )
        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        assert result.complexity > 0.7
        assert result.match_method == "capability"

    @pytest.mark.asyncio
    async def test_merged_classify_disabled_falls_back_to_intent_router(self):
        """With ``merged_llm_classify=False`` the merged path is not used."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
        search_skill = _make_skill("search", keywords=["分析"], description="数据分析")
        registry = _make_skill_registry([search_skill])
        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")
        router = CostAwareRouter(
            llm_gateway=gateway, model="default",
            merged_llm_classify=False,
        )
        result = await router.route(
            content="分析下这个数据",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # Should not use merged_llm match_method
        assert result.match_method != "merged_llm"

    @pytest.mark.asyncio
    async def test_merged_classify_no_llm_gateway_falls_back(self):
        """Without an LLM gateway, merged mode still returns a result."""
        search_skill = _make_skill("search", keywords=["分析"], description="数据分析")
        registry = _make_skill_registry([search_skill])
        router = CostAwareRouter(llm_gateway=None, merged_llm_classify=True)
        result = await router.route(
            content="分析下这个数据",
            skill_registry=registry,
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # Should not crash, should use IntentRouter fallback
        assert result is not None

    @pytest.mark.asyncio
    async def test_merged_classify_skill_hint_not_found_fallback(self):
        """A ``skill_hint`` missing from the registry yields no skill match."""
        merged_response = json.dumps({
            "complexity": 0.5,
            "intent": "unknown",
            "skill_hint": "nonexistent_skill",
        })
        gateway = _make_llm_gateway(merged_response)
        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        result = await router.route(
            content="帮我分析一下数据",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # Should fallback to default agent (medium complexity, no skill match)
        assert result.matched is False
        assert result.match_method == "merged_llm_medium"

    @pytest.mark.asyncio
    async def test_merged_classify_only_one_llm_call(self):
        """In merged mode a medium-complexity request makes one LLM call."""
        merged_response = json.dumps({
            "complexity": 0.5,
            "intent": "question",
            "skill_hint": None,
        })
        gateway = _make_llm_gateway(merged_response)
        router = CostAwareRouter(llm_gateway=gateway, model="default", merged_llm_classify=True)
        await router.route(
            content="如何使用这个功能?",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )
        # Only 1 LLM call should have been made (the merged classify)
        assert gateway.chat.call_count == 1