333 lines
12 KiB
Python
333 lines
12 KiB
Python
"""Unit tests for CostAwareRouter team upgrade logic and HeuristicClassifier."""
|
||
|
||
from __future__ import annotations
|
||
|
||
from unittest.mock import MagicMock
|
||
|
||
from agentkit.chat.skill_routing import (
|
||
CostAwareRouter,
|
||
ExecutionMode,
|
||
HeuristicClassifier,
|
||
SkillRoutingResult,
|
||
)
|
||
from agentkit.experts.config import ExpertConfig, ExpertTemplate
|
||
from agentkit.experts.registry import ExpertTemplateRegistry
|
||
from agentkit.experts.router import ExpertTeamRouter
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _make_router(expert_team_router: ExpertTeamRouter | None = None) -> CostAwareRouter:
    """Build a CostAwareRouter suitable for unit tests.

    The router is constructed with no LLM gateway, the model name ``"test"``
    and the ``"heuristic"`` classifier. ``expert_team_router`` is forwarded
    unchanged and defaults to ``None``.
    """
    router_kwargs = {
        "llm_gateway": None,
        "model": "test",
        "classifier": "heuristic",
        "expert_team_router": expert_team_router,
    }
    return CostAwareRouter(**router_kwargs)
|
||
|
||
|
||
def _make_team_router_with_templates() -> ExpertTeamRouter:
    """Build an ExpertTeamRouter whose registry holds three sample templates.

    Templates named ``analyst``, ``strategist`` and ``reviewer`` are registered
    in that order; only ``analyst`` is flagged as the lead.
    """
    template_registry = ExpertTemplateRegistry()
    for expert_name in ("analyst", "strategist", "reviewer"):
        # The same sentence is used for both the persona and the prompt identity.
        identity = f"Expert in {expert_name}"
        template_registry.register(
            ExpertTemplate(
                name=expert_name,
                config=ExpertConfig(
                    name=expert_name,
                    agent_type="expert",
                    persona=identity,
                    thinking_style="analytical",
                    bound_skills=[],
                    is_lead=expert_name == "analyst",
                    task_mode="llm_generate",
                    prompt={"identity": identity},
                ),
                description=f"Handles {expert_name} tasks",
            )
        )
    return ExpertTeamRouter(template_registry=template_registry)
|
||
|
||
|
||
def _make_team_router_empty() -> ExpertTeamRouter:
    """Build an ExpertTeamRouter backed by a registry with nothing registered."""
    empty_registry = ExpertTemplateRegistry()
    return ExpertTeamRouter(template_registry=empty_registry)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Tests: ExpertTeamRouter.can_handle()
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestExpertTeamRouterCanHandle:
    """Checks of ExpertTeamRouter.can_handle() for populated and empty registries."""

    def test_can_handle_with_templates(self) -> None:
        """With templates registered, an analysis request is expected to be handled."""
        handled = _make_team_router_with_templates().can_handle("analyze this data")
        assert handled is True

    def test_can_handle_no_templates(self) -> None:
        """With no templates registered, the same request is expected to be refused."""
        handled = _make_team_router_empty().can_handle("analyze this data")
        assert handled is False

    def test_can_handle_name_match(self) -> None:
        """A message mentioning a template name ("strategist") is expected to be handled."""
        handled = _make_team_router_with_templates().can_handle("I need a strategist for this")
        assert handled is True

    def test_can_handle_description_match(self) -> None:
        """A message echoing template description wording is expected to be handled."""
        handled = _make_team_router_with_templates().can_handle("handles review tasks")
        assert handled is True
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Tests: _try_team_upgrade()
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestTryTeamUpgrade:
    """Tests for CostAwareRouter._try_team_upgrade().

    Every test builds a SkillRoutingResult, calls ``_try_team_upgrade`` with
    the same content/complexity and a fresh trace list, then inspects the
    returned object's ``execution_mode`` (and, where relevant, the trace).
    The shared steps live in ``_attempt_upgrade`` so each test states only
    what differs.
    """

    @staticmethod
    def _attempt_upgrade(
        router: CostAwareRouter,
        content: str,
        complexity: float,
        *,
        matched: bool = True,
        match_method: str = "capability",
        match_confidence: float = 0.8,
        execution_mode: ExecutionMode = ExecutionMode.REACT,
    ) -> tuple:
        """Run ``router._try_team_upgrade`` on a freshly built routing result.

        Args:
            router: Router under test.
            content: Message text, used both as ``clean_content`` of the
                routing result and as the content argument of the call.
            complexity: Complexity score, used both on the routing result and
                as the complexity argument of the call.
            matched, match_method, match_confidence, execution_mode: Remaining
                SkillRoutingResult fields; the defaults describe a capability
                match at confidence 0.8 in REACT mode.

        Returns:
            ``(upgraded, trace)`` where ``upgraded`` is whatever
            ``_try_team_upgrade`` returned and ``trace`` is the list passed
            to it (initially empty).
        """
        result = SkillRoutingResult(
            clean_content=content,
            matched=matched,
            match_method=match_method,
            match_confidence=match_confidence,
            complexity=complexity,
            execution_mode=execution_mode,
        )
        trace: list[dict] = []
        upgraded = router._try_team_upgrade(result, content, complexity, trace)
        return upgraded, trace

    def test_upgrade_react_to_team_collab(self) -> None:
        """A complexity-0.8 REACT result is expected to become TEAM_COLLAB and be traced."""
        router = _make_router(expert_team_router=_make_team_router_with_templates())
        upgraded, trace = self._attempt_upgrade(router, "complex multi-step analysis task", 0.8)
        assert upgraded.execution_mode == ExecutionMode.TEAM_COLLAB
        assert any(t.get("method") == "team_upgrade" for t in trace)

    def test_no_upgrade_low_complexity(self) -> None:
        """A complexity-0.3 result is expected to stay REACT with no team_upgrade trace entry."""
        router = _make_router(expert_team_router=_make_team_router_with_templates())
        upgraded, trace = self._attempt_upgrade(router, "simple question", 0.3)
        assert upgraded.execution_mode == ExecutionMode.REACT
        assert not any(t.get("method") == "team_upgrade" for t in trace)

    def test_no_upgrade_no_team_router(self) -> None:
        """Without an expert team router, even complexity 0.9 is expected to stay REACT."""
        router = _make_router(expert_team_router=None)
        upgraded, _trace = self._attempt_upgrade(router, "complex analysis", 0.9)
        assert upgraded.execution_mode == ExecutionMode.REACT

    def test_no_upgrade_empty_templates(self) -> None:
        """With an empty template registry, complexity 0.8 is expected to stay REACT."""
        router = _make_router(expert_team_router=_make_team_router_empty())
        upgraded, _trace = self._attempt_upgrade(router, "complex analysis", 0.8)
        assert upgraded.execution_mode == ExecutionMode.REACT

    def test_no_upgrade_direct_chat_mode(self) -> None:
        """A DIRECT_CHAT greeting result is expected to be left in DIRECT_CHAT."""
        router = _make_router(expert_team_router=_make_team_router_with_templates())
        upgraded, _trace = self._attempt_upgrade(
            router,
            "hello",
            0.0,
            matched=False,
            match_method="greeting",
            match_confidence=1.0,
            execution_mode=ExecutionMode.DIRECT_CHAT,
        )
        assert upgraded.execution_mode == ExecutionMode.DIRECT_CHAT

    def test_team_upgrade_exception_handled(self) -> None:
        """When ExpertTeamRouter raises, the upgrade is silently skipped."""
        broken_router = MagicMock()
        # Any call to can_handle() on the stand-in raises RuntimeError.
        broken_router.can_handle.side_effect = RuntimeError("boom")
        router = _make_router(expert_team_router=broken_router)
        upgraded, _trace = self._attempt_upgrade(router, "complex task", 0.8)
        assert upgraded.execution_mode == ExecutionMode.REACT
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Tests: ExpertTeamRouter.resolve() with complexity
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestExpertTeamRouterResolve:
    """Checks of ExpertTeamRouter.resolve() given a message and a complexity score."""

    def test_explicit_team_prefix(self) -> None:
        """An ``@team:`` prefix is expected to select exactly the listed experts' team mode."""
        team_router = _make_team_router_with_templates()
        resolution = team_router.resolve("@team:analyst,strategist analyze the market", 0.5)
        assert resolution.team_mode is True
        assert resolution.match_method == "explicit_team"
        for expert in ("analyst", "strategist"):
            assert expert in resolution.specified_experts

    def test_complexity_suggestion(self) -> None:
        """Complexity 0.8 without a prefix is expected to suggest an auto-composed team."""
        team_router = _make_team_router_with_templates()
        resolution = team_router.resolve("complex multi-step analysis", 0.8)
        assert resolution.team_mode is True
        assert resolution.match_method == "complexity_suggestion"
        assert resolution.auto_compose is True

    def test_no_team_low_complexity(self) -> None:
        """Complexity 0.2 without a prefix is expected to leave team mode off."""
        team_router = _make_team_router_with_templates()
        resolution = team_router.resolve("simple question", 0.2)
        assert resolution.team_mode is False
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Tests: HeuristicClassifier complexity calibration
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestHeuristicClassifierLowComplexity:
    """Greetings and similar low-complexity messages are expected to score below 0.3."""

    def setup_method(self) -> None:
        """Create a fresh classifier before each test."""
        self.clf = HeuristicClassifier()

    def test_chinese_greeting(self) -> None:
        score = self.clf.classify("你好")
        assert score < 0.3

    def test_chinese_greeting_hi(self) -> None:
        score = self.clf.classify("嗨")
        assert score < 0.3

    def test_english_greeting_hello(self) -> None:
        score = self.clf.classify("Hello")
        assert score < 0.3

    def test_english_greeting_hi(self) -> None:
        score = self.clf.classify("hi")
        assert score < 0.3

    def test_multiple_low_complexity_words(self) -> None:
        score = self.clf.classify("嗨,早上好")
        assert score < 0.3

    def test_greeting_with_high_complexity_word_not_suppressed(self) -> None:
        """A low-complexity signal must not override a high-complexity one."""
        # The message opens with the greeting "你好" but also contains "分析",
        # so the expected score is high.
        score = self.clf.classify("你好,请帮我分析一下这个数据")
        assert score > 0.5
|
||
|
||
|
||
class TestHeuristicClassifierIdentity:
    """Questions about the assistant's identity are expected to score below 0.3."""

    def setup_method(self) -> None:
        """Create a fresh classifier before each test."""
        self.clf = HeuristicClassifier()

    def test_who_are_you_cn(self) -> None:
        score = self.clf.classify("你是谁")
        assert score < 0.3

    def test_what_is_your_name_cn(self) -> None:
        score = self.clf.classify("你叫什么")
        assert score < 0.3
|
||
|
||
|
||
class TestHeuristicClassifierNegation:
    """A negated high-complexity word is expected not to raise the score."""

    def setup_method(self) -> None:
        """Create a fresh classifier before each test."""
        self.clf = HeuristicClassifier()

    def test_negate_search_cn(self) -> None:
        score = self.clf.classify("不要搜索")
        assert score < 0.3

    def test_negate_analyze_cn(self) -> None:
        score = self.clf.classify("无需分析,直接告诉我答案")
        assert score < 0.3

    def test_partial_negation_still_high(self) -> None:
        """Only '搜索' is negated; '分析' is not, so the score should stay high."""
        score = self.clf.classify("分析市场趋势,但不要搜索")
        assert score > 0.5
|
||
|
||
|
||
class TestHeuristicClassifierThresholds:
    """Verify adjusted base scores.

    The bounds checked here coincide with values the classifier is expected
    to produce by float arithmetic, so comparisons allow a tiny tolerance:
    ``0.35 - 0.10`` evaluates to ``0.24999999999999997`` in binary floating
    point, which a bare ``0.25 <= score`` would reject.
    """

    # Absolute slack for boundary comparisons; far smaller than any score step.
    _TOLERANCE = 1e-9

    def setup_method(self) -> None:
        """Create a fresh classifier before each test."""
        self.clf = HeuristicClassifier()

    def test_no_keyword_short_message(self) -> None:
        """A short message with no keyword is expected to score at most 0.10."""
        score = self.clf.classify("好的")
        assert score <= 0.10 + self._TOLERANCE

    def test_medium_complexity_base(self) -> None:
        """Medium complexity keyword should start at 0.35 (not 0.45)."""
        score = self.clf.classify("如何使用Python?")
        # Author's expectation: '如何' is a medium keyword (base 0.35) and
        # 'Python' is in neither the high nor the medium list. If the
        # short-question deduction (-0.10) applies the score is 0.25,
        # otherwise it stays at the 0.35 base; both fall inside the range.
        # NOTE(review): the inclusive upper bound also admits the old 0.45
        # base mentioned in the docstring - confirm whether it should be
        # tightened to a strict "< 0.45".
        assert 0.25 - self._TOLERANCE <= score <= 0.45 + self._TOLERANCE
|
||
|
||
|
||
class TestHeuristicClassifierShortQuestion:
    """Short messages ending in a question mark are expected to receive a deduction."""

    def setup_method(self) -> None:
        """Create a fresh classifier before each test."""
        self.clf = HeuristicClassifier()

    def test_short_question_deduction(self) -> None:
        """A very short question is expected to land below 0.3."""
        score = self.clf.classify("怎么用?")
        assert score < 0.3

    def test_long_question_no_deduction(self) -> None:
        """A longer design question is expected to stay above 0.5."""
        score = self.clf.classify("如何设计一个高可用的微服务架构?")
        assert score > 0.5
|
||
|
||
|
||
class TestHeuristicClassifierHighComplexity:
    """Multi-step task requests are expected to score high (above 0.6-0.7)."""

    def setup_method(self) -> None:
        """Create a fresh classifier before each test."""
        self.clf = HeuristicClassifier()

    def test_two_high_complexity_words(self) -> None:
        # Author's note: "分析" and "搜索" are both in _HIGH_COMPLEXITY_HINTS_CN,
        # giving a base of 0.80.
        score = self.clf.classify("分析市场数据并搜索相关信息")
        assert score > 0.7

    def test_single_high_complexity_word(self) -> None:
        # Author's note: "分析" on its own gives a base of 0.65.
        score = self.clf.classify("分析市场趋势并生成报告")
        assert score > 0.6

    def test_execute_and_restart(self) -> None:
        score = self.clf.classify("执行部署脚本并重启服务")
        assert score > 0.7
|
||
|
||
|
||
class TestHeuristicClassifierEdgeCases:
    """Boundary inputs: empty text, whitespace, and a very long greeting."""

    def setup_method(self) -> None:
        """Create a fresh classifier before each test."""
        self.clf = HeuristicClassifier()

    def test_empty_string(self) -> None:
        """The empty string is expected to score exactly 0.0."""
        score = self.clf.classify("")
        assert score == 0.0

    def test_whitespace_only(self) -> None:
        """Whitespace-only input is expected to score exactly 0.0."""
        score = self.clf.classify(" ")
        assert score == 0.0

    def test_long_low_complexity_message(self) -> None:
        """Even a long greeting should stay low."""
        # Two characters repeated 100 times: exactly 200 characters.
        # NOTE(review): the original comment said ">200 chars"; if the
        # classifier's length rule is strictly "> 200", this input sits on
        # the boundary rather than past it - confirm the intended length.
        long_greeting = "你好" * 100
        score = self.clf.classify(long_greeting)
        assert score <= 0.15
|