# fischer-agentkit/tests/unit/chat/test_skill_routing.py
#
# NOTE: the repository viewer flagged this file as containing ambiguous Unicode
# characters (characters that may be confused with others). This is presumably
# intentional: several test inputs are Chinese text that may use full-width
# punctuation — do not normalize those literals.

"""Unit tests for CostAwareRouter team upgrade logic and HeuristicClassifier."""
from __future__ import annotations
from unittest.mock import MagicMock
from agentkit.chat.skill_routing import (
CostAwareRouter,
ExecutionMode,
HeuristicClassifier,
SkillRoutingResult,
)
from agentkit.experts.config import ExpertConfig, ExpertTemplate
from agentkit.experts.registry import ExpertTemplateRegistry
from agentkit.experts.router import ExpertTeamRouter
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_router(expert_team_router: ExpertTeamRouter | None = None) -> CostAwareRouter:
    """Build a CostAwareRouter that uses the heuristic classifier and no LLM gateway.

    Args:
        expert_team_router: Optional team router to wire in. The tests in this
            module pass ``None`` when they expect no team upgrade to be possible.

    Returns:
        A ``CostAwareRouter`` configured with model ``"test"`` and the
        ``"heuristic"`` classifier.
    """
    router_kwargs = {
        "llm_gateway": None,
        "model": "test",
        "classifier": "heuristic",
        "expert_team_router": expert_team_router,
    }
    return CostAwareRouter(**router_kwargs)
def _make_team_router_with_templates() -> ExpertTeamRouter:
    """Return an ExpertTeamRouter backed by three sample expert templates.

    Registers ``analyst`` (marked as the lead), ``strategist`` and ``reviewer``.
    Each template's description is ``"Handles <name> tasks"``, and its persona and
    prompt identity are ``"Expert in <name>"``.
    """
    registry = ExpertTemplateRegistry()
    lead_name = "analyst"
    for expert_name in (lead_name, "strategist", "reviewer"):
        identity = f"Expert in {expert_name}"
        expert_config = ExpertConfig(
            name=expert_name,
            agent_type="expert",
            persona=identity,
            thinking_style="analytical",
            bound_skills=[],
            is_lead=expert_name == lead_name,
            task_mode="llm_generate",
            prompt={"identity": identity},
        )
        registry.register(
            ExpertTemplate(
                name=expert_name,
                config=expert_config,
                description=f"Handles {expert_name} tasks",
            )
        )
    return ExpertTeamRouter(template_registry=registry)
def _make_team_router_empty() -> ExpertTeamRouter:
    """Return an ExpertTeamRouter whose template registry is left empty."""
    empty_registry = ExpertTemplateRegistry()
    return ExpertTeamRouter(template_registry=empty_registry)
# ---------------------------------------------------------------------------
# Tests: ExpertTeamRouter.can_handle()
# ---------------------------------------------------------------------------
class TestExpertTeamRouterCanHandle:
    """Exercise ``ExpertTeamRouter.can_handle`` against populated and empty registries."""

    def test_can_handle_with_templates(self) -> None:
        """A generic analysis request is accepted when templates are registered."""
        request = "analyze this data"
        assert _make_team_router_with_templates().can_handle(request) is True

    def test_can_handle_no_templates(self) -> None:
        """The same request is rejected when the registry is empty."""
        request = "analyze this data"
        assert _make_team_router_empty().can_handle(request) is False

    def test_can_handle_name_match(self) -> None:
        """Mentioning a registered template by name ("strategist") is accepted."""
        request = "I need a strategist for this"
        assert _make_team_router_with_templates().can_handle(request) is True

    def test_can_handle_description_match(self) -> None:
        """Wording close to a template description ("Handles reviewer tasks") is accepted."""
        request = "handles review tasks"
        assert _make_team_router_with_templates().can_handle(request) is True
# ---------------------------------------------------------------------------
# Tests: _try_team_upgrade()
# ---------------------------------------------------------------------------
class TestTryTeamUpgrade:
    """Tests for ``CostAwareRouter._try_team_upgrade``.

    These tests expect a REACT result to be switched to TEAM_COLLAB only when:
    complexity is high, a team router is configured, that router has templates,
    the router does not raise, and the result is not in DIRECT_CHAT mode.
    """

    @staticmethod
    def _make_result(
        content: str,
        complexity: float,
        *,
        execution_mode: ExecutionMode | None = None,
        matched: bool = True,
        match_method: str = "capability",
        match_confidence: float = 0.8,
    ) -> SkillRoutingResult:
        """Build a SkillRoutingResult; the defaults describe a capability-matched REACT task.

        ``execution_mode`` defaults to ``ExecutionMode.REACT``. It is resolved here
        rather than in the signature, so the enum is not touched at class-definition time.
        """
        mode = ExecutionMode.REACT if execution_mode is None else execution_mode
        return SkillRoutingResult(
            clean_content=content,
            matched=matched,
            match_method=match_method,
            match_confidence=match_confidence,
            complexity=complexity,
            execution_mode=mode,
        )

    @staticmethod
    def _has_team_upgrade(trace: list[dict]) -> bool:
        """Return True if any trace entry records a ``team_upgrade`` step."""
        return any(entry.get("method") == "team_upgrade" for entry in trace)

    def test_upgrade_react_to_team_collab(self) -> None:
        """High complexity + templates upgrades REACT to TEAM_COLLAB and records a trace entry."""
        router = _make_router(expert_team_router=_make_team_router_with_templates())
        content = "complex multi-step analysis task"
        trace: list[dict] = []
        upgraded = router._try_team_upgrade(self._make_result(content, 0.8), content, 0.8, trace)
        assert upgraded.execution_mode == ExecutionMode.TEAM_COLLAB
        assert self._has_team_upgrade(trace)

    def test_no_upgrade_low_complexity(self) -> None:
        """Low complexity leaves the result in REACT mode and adds no upgrade trace."""
        router = _make_router(expert_team_router=_make_team_router_with_templates())
        content = "simple question"
        trace: list[dict] = []
        upgraded = router._try_team_upgrade(self._make_result(content, 0.3), content, 0.3, trace)
        assert upgraded.execution_mode == ExecutionMode.REACT
        assert not self._has_team_upgrade(trace)

    def test_no_upgrade_no_team_router(self) -> None:
        """Without a team router, even a very complex task stays in REACT mode."""
        router = _make_router(expert_team_router=None)
        content = "complex analysis"
        trace: list[dict] = []
        upgraded = router._try_team_upgrade(self._make_result(content, 0.9), content, 0.9, trace)
        assert upgraded.execution_mode == ExecutionMode.REACT
        assert not self._has_team_upgrade(trace)

    def test_no_upgrade_empty_templates(self) -> None:
        """A team router with no templates cannot trigger an upgrade."""
        router = _make_router(expert_team_router=_make_team_router_empty())
        content = "complex analysis"
        trace: list[dict] = []
        upgraded = router._try_team_upgrade(self._make_result(content, 0.8), content, 0.8, trace)
        assert upgraded.execution_mode == ExecutionMode.REACT

    def test_no_upgrade_direct_chat_mode(self) -> None:
        """A plain greeting in DIRECT_CHAT mode is left untouched."""
        router = _make_router(expert_team_router=_make_team_router_with_templates())
        content = "hello"
        result = self._make_result(
            content,
            0.0,
            execution_mode=ExecutionMode.DIRECT_CHAT,
            matched=False,
            match_method="greeting",
            match_confidence=1.0,
        )
        trace: list[dict] = []
        upgraded = router._try_team_upgrade(result, content, 0.0, trace)
        assert upgraded.execution_mode == ExecutionMode.DIRECT_CHAT
        assert not self._has_team_upgrade(trace)

    def test_no_upgrade_direct_chat_mode_high_complexity(self) -> None:
        """DIRECT_CHAT is not upgraded even when complexity alone would qualify.

        This uses the same input as ``test_upgrade_react_to_team_collab``; only the
        execution mode differs. That isolates the mode guard from the complexity
        guard, which ``test_no_upgrade_direct_chat_mode`` (complexity 0.0) cannot do.
        """
        router = _make_router(expert_team_router=_make_team_router_with_templates())
        content = "complex multi-step analysis task"
        result = self._make_result(content, 0.8, execution_mode=ExecutionMode.DIRECT_CHAT)
        trace: list[dict] = []
        upgraded = router._try_team_upgrade(result, content, 0.8, trace)
        assert upgraded.execution_mode == ExecutionMode.DIRECT_CHAT

    def test_team_upgrade_exception_handled(self) -> None:
        """When ExpertTeamRouter raises, the upgrade is silently skipped."""
        broken_router = MagicMock()
        # Make both public ExpertTeamRouter entry points raise. Otherwise a plain
        # MagicMock returns truthy mocks from any method that is not configured,
        # and the test could pass without ever hitting the error path.
        broken_router.can_handle.side_effect = RuntimeError("boom")
        broken_router.resolve.side_effect = RuntimeError("boom")
        router = _make_router(expert_team_router=broken_router)
        content = "complex task"
        trace: list[dict] = []
        upgraded = router._try_team_upgrade(self._make_result(content, 0.8), content, 0.8, trace)
        assert upgraded.execution_mode == ExecutionMode.REACT
# ---------------------------------------------------------------------------
# Tests: ExpertTeamRouter.resolve() with complexity
# ---------------------------------------------------------------------------
class TestExpertTeamRouterResolve:
    """Team-mode decisions made by ``ExpertTeamRouter.resolve(content, complexity)``."""

    def setup_method(self) -> None:
        # A fresh router populated with the sample templates is built before each test.
        self.team_router = _make_team_router_with_templates()

    def test_explicit_team_prefix(self) -> None:
        """An ``@team:a,b`` prefix selects team mode with the listed experts."""
        resolved = self.team_router.resolve("@team:analyst,strategist analyze the market", 0.5)
        assert resolved.team_mode is True
        assert resolved.match_method == "explicit_team"
        for expert in ("analyst", "strategist"):
            assert expert in resolved.specified_experts

    def test_complexity_suggestion(self) -> None:
        """High complexity without a prefix suggests an auto-composed team."""
        resolved = self.team_router.resolve("complex multi-step analysis", 0.8)
        assert resolved.team_mode is True
        assert resolved.match_method == "complexity_suggestion"
        assert resolved.auto_compose is True

    def test_no_team_low_complexity(self) -> None:
        """Low complexity without a prefix does not enter team mode."""
        resolved = self.team_router.resolve("simple question", 0.2)
        assert resolved.team_mode is False
# ---------------------------------------------------------------------------
# Tests: HeuristicClassifier complexity calibration
# ---------------------------------------------------------------------------
class TestHeuristicClassifierLowComplexity:
    """Low-complexity signals (greetings) should produce scores < 0.3."""

    def setup_method(self) -> None:
        self.clf = HeuristicClassifier()

    def test_chinese_greeting(self) -> None:
        # "你好" = "hello"
        assert self.clf.classify("你好") < 0.3

    def test_chinese_greeting_hi(self) -> None:
        # "嗨" = "hi". The input must be non-empty: an empty string would pass
        # trivially and never exercise the greeting path this test is named for.
        assert self.clf.classify("嗨") < 0.3

    def test_english_greeting_hello(self) -> None:
        assert self.clf.classify("Hello") < 0.3

    def test_english_greeting_hi(self) -> None:
        assert self.clf.classify("hi") < 0.3

    def test_multiple_low_complexity_words(self) -> None:
        # "嗨" ("hi") + "早上好" ("good morning")
        assert self.clf.classify("嗨,早上好") < 0.3

    def test_greeting_with_high_complexity_word_not_suppressed(self) -> None:
        """Low-complexity signal should NOT override high-complexity signal."""
        # "你好" ("hello") is low, but "分析" ("analyze") is high → should score high
        assert self.clf.classify("你好,请帮我分析一下这个数据") > 0.5
class TestHeuristicClassifierIdentity:
    """Questions about the assistant's identity are simple and must score < 0.3."""

    def setup_method(self) -> None:
        self.clf = HeuristicClassifier()

    def test_who_are_you_cn(self) -> None:
        """'你是谁' ("who are you") scores low."""
        score = self.clf.classify("你是谁")
        assert score < 0.3

    def test_what_is_your_name_cn(self) -> None:
        """'你叫什么' ("what is your name") scores low."""
        score = self.clf.classify("你叫什么")
        assert score < 0.3
class TestHeuristicClassifierNegation:
    """A negated high-complexity keyword must not contribute to the score."""

    def setup_method(self) -> None:
        self.clf = HeuristicClassifier()

    def test_negate_search_cn(self) -> None:
        """'不要搜索' ("don't search") scores low."""
        score = self.clf.classify("不要搜索")
        assert score < 0.3

    def test_negate_analyze_cn(self) -> None:
        """'无需分析…' ("no need to analyze, just tell me the answer") scores low."""
        score = self.clf.classify("无需分析,直接告诉我答案")
        assert score < 0.3

    def test_partial_negation_still_high(self) -> None:
        """Only '搜索' (search) is negated; the un-negated '分析' (analyze) keeps it high."""
        score = self.clf.classify("分析市场趋势,但不要搜索")
        assert score > 0.5
class TestHeuristicClassifierThresholds:
    """Verify the calibrated base scores for messages with few or no hint words."""

    def setup_method(self) -> None:
        self.clf = HeuristicClassifier()

    def test_no_keyword_short_message(self) -> None:
        """A short message with no complexity hints ("好的" = "OK") scores at most 0.10."""
        assert self.clf.classify("好的") <= 0.10

    def test_medium_complexity_base(self) -> None:
        """Medium complexity keyword should start at 0.35 (not 0.45)."""
        score = self.clf.classify("如何使用Python")
        # '如何' ("how to") is a medium-complexity hint → base 0.35.
        # 'Python' is not in the high/medium hint lists, so it adds nothing.
        # The lower bound tolerates a possible -0.10 short-question deduction
        # (0.35 - 0.10 = 0.25). The upper bound is exclusive, so a regression
        # back to the old 0.45 base fails this test.
        assert 0.25 <= score < 0.45
class TestHeuristicClassifierShortQuestion:
    """Short questions ending in a question mark get a score deduction.

    The question mark may be full-width ``？`` or ASCII ``?``. Long questions
    get no such deduction.
    """

    def setup_method(self) -> None:
        self.clf = HeuristicClassifier()

    def test_short_question_deduction(self) -> None:
        """'怎么用?' ("how to use?") is short, so it is pushed below 0.3."""
        score = self.clf.classify("怎么用?")
        assert score < 0.3

    def test_long_question_no_deduction(self) -> None:
        """A long design question ("how to design a highly available microservice architecture?") stays above 0.5."""
        score = self.clf.classify("如何设计一个高可用的微服务架构?")
        assert score > 0.5
class TestHeuristicClassifierHighComplexity:
    """Messages with high-complexity hint words should score high.

    One hint is expected to clear 0.6; two or more should clear 0.7.
    """

    def setup_method(self) -> None:
        self.clf = HeuristicClassifier()

    def test_two_high_complexity_words(self) -> None:
        # "分析" (analyze) and "搜索" (search) are both in
        # _HIGH_COMPLEXITY_HINTS_CN, expected to give a base of 0.80.
        score = self.clf.classify("分析市场数据并搜索相关信息")
        assert score > 0.7

    def test_single_high_complexity_word(self) -> None:
        # "分析" (analyze) on its own is expected to give a base of 0.65.
        score = self.clf.classify("分析市场趋势并生成报告")
        assert score > 0.6

    def test_execute_and_restart(self) -> None:
        # "execute the deploy script and restart the service"
        score = self.clf.classify("执行部署脚本并重启服务")
        assert score > 0.7
class TestHeuristicClassifierEdgeCases:
    """Boundary conditions: empty/blank input and very long low-complexity text."""

    def setup_method(self) -> None:
        self.clf = HeuristicClassifier()

    def test_empty_string(self) -> None:
        assert self.clf.classify("") == 0.0

    def test_whitespace_only(self) -> None:
        assert self.clf.classify(" ") == 0.0

    def test_long_low_complexity_message(self) -> None:
        """Even a long greeting (more than 200 characters) should stay low."""
        # "你好" is 2 characters, so 101 repetitions give 202 characters: strictly
        # longer than 200. 100 repetitions would be exactly 200 and would miss
        # any length adjustment that only applies to longer messages.
        long_greeting = "你好" * 101
        assert len(long_greeting) > 200
        assert self.clf.classify(long_greeting) <= 0.15