"""Unit tests for CostAwareRouter team upgrade logic and HeuristicClassifier."""

from __future__ import annotations

from unittest.mock import MagicMock

from agentkit.chat.skill_routing import (
    CostAwareRouter,
    ExecutionMode,
    HeuristicClassifier,
    SkillRoutingResult,
)
from agentkit.experts.config import ExpertConfig, ExpertTemplate
from agentkit.experts.registry import ExpertTemplateRegistry
from agentkit.experts.router import ExpertTeamRouter


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_router(expert_team_router: ExpertTeamRouter | None = None) -> CostAwareRouter:
    """Create a CostAwareRouter with no LLM gateway and the heuristic classifier.

    Args:
        expert_team_router: Optional team router to attach; ``None`` disables
            team upgrades entirely.
    """
    return CostAwareRouter(
        llm_gateway=None,
        model="test",
        classifier="heuristic",
        expert_team_router=expert_team_router,
    )


def _make_react_result(content: str, complexity: float) -> SkillRoutingResult:
    """Build a matched, capability-routed result in REACT execution mode.

    The tests pass the same ``content`` and ``complexity`` to
    ``_try_team_upgrade``, so the result always agrees with the call.
    """
    return SkillRoutingResult(
        clean_content=content,
        matched=True,
        match_method="capability",
        match_confidence=0.8,
        complexity=complexity,
        execution_mode=ExecutionMode.REACT,
    )


def _make_team_router_with_templates() -> ExpertTeamRouter:
    """Create an ExpertTeamRouter with three sample templates.

    The templates are "analyst", "strategist" and "reviewer"; "analyst" is
    marked as the lead expert.
    """
    registry = ExpertTemplateRegistry()
    for name in ("analyst", "strategist", "reviewer"):
        config = ExpertConfig(
            name=name,
            agent_type="expert",
            persona=f"Expert in {name}",
            thinking_style="analytical",
            bound_skills=[],
            is_lead=(name == "analyst"),
            task_mode="llm_generate",
            prompt={"identity": f"Expert in {name}"},
        )
        template = ExpertTemplate(
            name=name,
            config=config,
            description=f"Handles {name} tasks",
        )
        registry.register(template)
    return ExpertTeamRouter(template_registry=registry)


def _make_team_router_empty() -> ExpertTeamRouter:
    """Create an ExpertTeamRouter with no templates."""
    return ExpertTeamRouter(template_registry=ExpertTemplateRegistry())


# ---------------------------------------------------------------------------
# Tests: ExpertTeamRouter.can_handle()
# ---------------------------------------------------------------------------


class TestExpertTeamRouterCanHandle:
    def test_can_handle_with_templates(self) -> None:
        router = _make_team_router_with_templates()
        assert router.can_handle("analyze this data") is True

    def test_can_handle_no_templates(self) -> None:
        router = _make_team_router_empty()
        assert router.can_handle("analyze this data") is False

    def test_can_handle_name_match(self) -> None:
        router = _make_team_router_with_templates()
        assert router.can_handle("I need a strategist for this") is True

    def test_can_handle_description_match(self) -> None:
        router = _make_team_router_with_templates()
        assert router.can_handle("handles review tasks") is True


# ---------------------------------------------------------------------------
# Tests: _try_team_upgrade()
# ---------------------------------------------------------------------------


class TestTryTeamUpgrade:
    def test_upgrade_react_to_team_collab(self) -> None:
        router = _make_router(expert_team_router=_make_team_router_with_templates())
        result = _make_react_result("complex multi-step analysis task", 0.8)
        trace: list[dict] = []
        upgraded = router._try_team_upgrade(result, "complex multi-step analysis task", 0.8, trace)
        assert upgraded.execution_mode == ExecutionMode.TEAM_COLLAB
        assert any(t.get("method") == "team_upgrade" for t in trace)

    def test_no_upgrade_low_complexity(self) -> None:
        router = _make_router(expert_team_router=_make_team_router_with_templates())
        result = _make_react_result("simple question", 0.3)
        trace: list[dict] = []
        upgraded = router._try_team_upgrade(result, "simple question", 0.3, trace)
        assert upgraded.execution_mode == ExecutionMode.REACT
        assert not any(t.get("method") == "team_upgrade" for t in trace)

    def test_no_upgrade_no_team_router(self) -> None:
        router = _make_router(expert_team_router=None)
        result = _make_react_result("complex analysis", 0.9)
        trace: list[dict] = []
        upgraded = router._try_team_upgrade(result, "complex analysis", 0.9, trace)
        assert upgraded.execution_mode == ExecutionMode.REACT

    def test_no_upgrade_empty_templates(self) -> None:
        router = _make_router(expert_team_router=_make_team_router_empty())
        result = _make_react_result("complex analysis", 0.8)
        trace: list[dict] = []
        upgraded = router._try_team_upgrade(result, "complex analysis", 0.8, trace)
        assert upgraded.execution_mode == ExecutionMode.REACT

    def test_no_upgrade_direct_chat_mode(self) -> None:
        router = _make_router(expert_team_router=_make_team_router_with_templates())
        # Built explicitly: unlike the REACT helper, this is an unmatched greeting.
        result = SkillRoutingResult(
            clean_content="hello",
            matched=False,
            match_method="greeting",
            match_confidence=1.0,
            complexity=0.0,
            execution_mode=ExecutionMode.DIRECT_CHAT,
        )
        trace: list[dict] = []
        upgraded = router._try_team_upgrade(result, "hello", 0.0, trace)
        assert upgraded.execution_mode == ExecutionMode.DIRECT_CHAT

    def test_team_upgrade_exception_handled(self) -> None:
        """When ExpertTeamRouter raises, the upgrade is silently skipped."""
        broken_router = MagicMock()
        broken_router.can_handle.side_effect = RuntimeError("boom")
        router = _make_router(expert_team_router=broken_router)
        result = _make_react_result("complex task", 0.8)
        trace: list[dict] = []
        upgraded = router._try_team_upgrade(result, "complex task", 0.8, trace)
        assert upgraded.execution_mode == ExecutionMode.REACT


# ---------------------------------------------------------------------------
# Tests: ExpertTeamRouter.resolve() with complexity
# ---------------------------------------------------------------------------


class TestExpertTeamRouterResolve:
    def test_explicit_team_prefix(self) -> None:
        router = _make_team_router_with_templates()
        result = router.resolve("@team:analyst,strategist analyze the market", 0.5)
        assert result.team_mode is True
        assert result.match_method == "explicit_team"
        assert "analyst" in result.specified_experts
        assert "strategist" in result.specified_experts

    def test_complexity_suggestion(self) -> None:
        router = _make_team_router_with_templates()
        result = router.resolve("complex multi-step analysis", 0.8)
        assert result.team_mode is True
        assert result.match_method == "complexity_suggestion"
        assert result.auto_compose is True

    def test_no_team_low_complexity(self) -> None:
        router = _make_team_router_with_templates()
        result = router.resolve("simple question", 0.2)
        assert result.team_mode is False


# ---------------------------------------------------------------------------
# Tests: HeuristicClassifier complexity calibration
# ---------------------------------------------------------------------------


class TestHeuristicClassifierLowComplexity:
    """Low-complexity signals should produce scores < 0.3."""

    def setup_method(self) -> None:
        self.clf = HeuristicClassifier()

    def test_chinese_greeting(self) -> None:
        assert self.clf.classify("你好") < 0.3

    def test_chinese_greeting_hi(self) -> None:
        assert self.clf.classify("嗨") < 0.3

    def test_english_greeting_hello(self) -> None:
        assert self.clf.classify("Hello") < 0.3

    def test_english_greeting_hi(self) -> None:
        assert self.clf.classify("hi") < 0.3

    def test_multiple_low_complexity_words(self) -> None:
        assert self.clf.classify("嗨,早上好") < 0.3

    def test_greeting_with_high_complexity_word_not_suppressed(self) -> None:
        """Low-complexity signal should NOT override high-complexity signal."""
        # "你好" ("hello") is low, but "分析" ("analyze") is high → should score high.
        assert self.clf.classify("你好,请帮我分析一下这个数据") > 0.5


class TestHeuristicClassifierIdentity:
    """Identity queries should produce scores < 0.3."""

    def setup_method(self) -> None:
        self.clf = HeuristicClassifier()

    def test_who_are_you_cn(self) -> None:
        assert self.clf.classify("你是谁") < 0.3

    def test_what_is_your_name_cn(self) -> None:
        assert self.clf.classify("你叫什么") < 0.3


class TestHeuristicClassifierNegation:
    """Negated high-complexity words should not contribute to score."""

    def setup_method(self) -> None:
        self.clf = HeuristicClassifier()

    def test_negate_search_cn(self) -> None:
        assert self.clf.classify("不要搜索") < 0.3

    def test_negate_analyze_cn(self) -> None:
        assert self.clf.classify("无需分析,直接告诉我答案") < 0.3

    def test_partial_negation_still_high(self) -> None:
        """'搜索' (search) is negated but '分析' (analyze) is not — should still be high."""
        assert self.clf.classify("分析市场趋势,但不要搜索") > 0.5


class TestHeuristicClassifierThresholds:
    """Verify adjusted base scores."""

    def setup_method(self) -> None:
        self.clf = HeuristicClassifier()

    def test_no_keyword_short_message(self) -> None:
        assert self.clf.classify("好的") <= 0.10

    def test_medium_complexity_base(self) -> None:
        """Medium complexity keyword should start at 0.35 (not 0.45)."""
        score = self.clf.classify("如何使用Python?")
        # '如何' ("how to") is a medium-complexity hint → base 0.35; 'Python' is
        # not in the high/medium hint lists. If the short-question deduction
        # (-0.10) also applies, the score becomes 0.25. The range accepts both.
        assert 0.25 <= score <= 0.45


class TestHeuristicClassifierShortQuestion:
    """Short questions ending with '?' (ASCII or full-width) should get a deduction."""

    def setup_method(self) -> None:
        self.clf = HeuristicClassifier()

    def test_short_question_deduction(self) -> None:
        assert self.clf.classify("怎么用?") < 0.3

    def test_long_question_no_deduction(self) -> None:
        assert self.clf.classify("如何设计一个高可用的微服务架构?") > 0.5


class TestHeuristicClassifierHighComplexity:
    """Complex tasks should score high (> 0.6 with one hint, > 0.7 with several)."""

    def setup_method(self) -> None:
        self.clf = HeuristicClassifier()

    def test_two_high_complexity_words(self) -> None:
        # "分析" + "搜索" are both in _HIGH_COMPLEXITY_HINTS_CN → base 0.80
        assert self.clf.classify("分析市场数据并搜索相关信息") > 0.7

    def test_single_high_complexity_word(self) -> None:
        # "分析" alone → base 0.65
        assert self.clf.classify("分析市场趋势并生成报告") > 0.6

    def test_execute_and_restart(self) -> None:
        assert self.clf.classify("执行部署脚本并重启服务") > 0.7


class TestHeuristicClassifierEdgeCases:
    """Boundary conditions."""

    def setup_method(self) -> None:
        self.clf = HeuristicClassifier()

    def test_empty_string(self) -> None:
        assert self.clf.classify("") == 0.0

    def test_whitespace_only(self) -> None:
        assert self.clf.classify(" ") == 0.0

    def test_long_low_complexity_message(self) -> None:
        """Even a long greeting (more than 200 characters) should stay low."""
        # "你好" is 2 characters: 100 repetitions is exactly 200 (not > 200),
        # so use 101 repetitions (202 characters) to really exceed the bound.
        long_greeting = "你好" * 101
        assert len(long_greeting) > 200
        assert self.clf.classify(long_greeting) <= 0.15