# fischer-agentkit/tests/unit/test_cost_aware_router.py
#
# 469 lines
# 18 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""CostAwareRouter 单元测试 - 三层成本感知路由"""
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult
from agentkit.llm.protocol import LLMResponse, TokenUsage
from agentkit.router.intent import IntentRouter, RoutingResult
from agentkit.skills.base import IntentConfig, Skill, SkillConfig
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_skill(
    name: str,
    keywords: list[str] | None = None,
    description: str = "",
    examples: list[str] | None = None,
) -> Skill:
    """Build a ``Skill`` whose config carries an intent section.

    Args:
        name: Skill name; also interpolated into the system prompt.
        keywords: Intent keywords; a falsy value becomes an empty list.
        description: Intent description text.
        examples: Intent example utterances; a falsy value becomes an empty list.

    Returns:
        A ``Skill`` wrapping the freshly built ``SkillConfig``.
    """
    intent_section = {
        "keywords": keywords if keywords else [],
        "description": description,
        "examples": examples if examples else [],
    }
    system_prompt = {"system": f"You are a {name} skill."}
    return Skill(
        config=SkillConfig(
            name=name,
            agent_type="test",
            task_mode="llm_generate",
            prompt=system_prompt,
            intent=intent_section,
        )
    )
def _make_llm_gateway(response_content: str) -> MagicMock:
    """Create a mock LLM gateway whose ``chat`` coroutine returns fixed content.

    Args:
        response_content: Text placed in the ``content`` field of the canned
            ``LLMResponse``.

    Returns:
        A ``MagicMock`` whose ``chat`` attribute is an ``AsyncMock`` resolving
        to that canned response.
    """
    canned_usage = TokenUsage(prompt_tokens=10, completion_tokens=20)
    canned_response = LLMResponse(
        content=response_content,
        model="test-model",
        usage=canned_usage,
    )
    mock_gateway = MagicMock()
    mock_gateway.chat = AsyncMock(return_value=canned_response)
    return mock_gateway
def _make_skill_registry(skills: list[Skill] | None = None) -> MagicMock:
    """Create a mock skill registry backed by an in-memory list.

    Args:
        skills: Skills the registry should expose; a falsy value means no skills.

    Returns:
        A ``MagicMock`` whose ``list_skills()`` returns the list and whose
        ``get(name)`` returns the first skill with that name.

    Note:
        ``get`` raises ``KeyError`` when no skill carries the requested name.
    """
    known_skills = skills or []
    not_found = object()  # sentinel so the lookup stops at the first hit

    def _get(name: str):
        hit = next((skill for skill in known_skills if skill.name == name), not_found)
        if hit is not_found:
            raise KeyError(f"Skill '{name}' not found")
        return hit

    mock_registry = MagicMock()
    mock_registry.list_skills.return_value = known_skills
    mock_registry.get = MagicMock(side_effect=_get)
    return mock_registry
def _make_intent_router() -> IntentRouter:
    """Create an ``IntentRouter`` with no LLM gateway (keyword matching only)."""
    router_without_llm = IntentRouter(model="default", llm_gateway=None)
    return router_without_llm
# ---------------------------------------------------------------------------
# Layer 0: Rule-based (zero cost)
# ---------------------------------------------------------------------------
class TestLayer0Greeting:
    """Layer 0: greeting pattern matching."""

    @pytest.mark.asyncio
    async def test_chinese_greeting_hits_layer0(self):
        """'你好' should be tagged as a greeting: zero complexity, default agent, no skill."""
        registry = _make_skill_registry()
        keyword_router = _make_intent_router()

        routed = await CostAwareRouter().route(
            content="你好",
            skill_registry=registry,
            intent_router=keyword_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert routed.match_method == "greeting"
        assert routed.complexity == 0.0
        assert routed.agent_name == "default"
        assert routed.matched is False

    @pytest.mark.asyncio
    async def test_english_greeting_hits_layer0(self):
        """'hello' should be tagged as a greeting with zero complexity."""
        registry = _make_skill_registry()
        keyword_router = _make_intent_router()

        routed = await CostAwareRouter().route(
            content="hello",
            skill_registry=registry,
            intent_router=keyword_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert routed.match_method == "greeting"
        assert routed.complexity == 0.0

    @pytest.mark.asyncio
    async def test_greeting_with_punctuation(self):
        """A greeting followed by punctuation should still be tagged as a greeting."""
        registry = _make_skill_registry()
        keyword_router = _make_intent_router()

        routed = await CostAwareRouter().route(
            content="你好!",
            skill_registry=registry,
            intent_router=keyword_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert routed.match_method == "greeting"
class TestLayer0ChatMode:
    """Layer 0: simple conversational replies."""

    @pytest.mark.asyncio
    async def test_thanks_hits_chat_mode(self):
        """'谢谢' should be tagged ``chat_mode`` with zero complexity."""
        registry = _make_skill_registry()
        keyword_router = _make_intent_router()

        routed = await CostAwareRouter().route(
            content="谢谢",
            skill_registry=registry,
            intent_router=keyword_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert routed.match_method == "chat_mode"
        assert routed.complexity == 0.0

    @pytest.mark.asyncio
    async def test_ok_hits_chat_mode(self):
        """'好的' should be tagged ``chat_mode``."""
        registry = _make_skill_registry()
        keyword_router = _make_intent_router()

        routed = await CostAwareRouter().route(
            content="好的",
            skill_registry=registry,
            intent_router=keyword_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert routed.match_method == "chat_mode"
class TestLayer0ExplicitSkill:
    """Layer 0: explicit ``@skill:`` prefix."""

    @pytest.mark.asyncio
    async def test_skill_prefix_hits_layer0(self):
        """'@skill:search ...' should resolve to the ``search`` skill with zero complexity."""
        registry = _make_skill_registry(
            [_make_skill("search", keywords=["搜索"], description="搜索信息")]
        )
        # The IntentRouter is given an LLM gateway so that its LLM fallback is available.
        llm_reply = json.dumps({"skill": "search", "confidence": 0.9})
        llm_backed_router = IntentRouter(
            llm_gateway=_make_llm_gateway(llm_reply),
            model="default",
        )

        routed = await CostAwareRouter().route(
            content="@skill:search 搜索XX",
            skill_registry=registry,
            intent_router=llm_backed_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert routed.matched is True
        assert routed.skill_name == "search"
        assert routed.complexity == 0.0
# ---------------------------------------------------------------------------
# Layer 1: LLM quick classify (~100 tokens)
# ---------------------------------------------------------------------------
class TestLayer1Classification:
    """Layer 1: quick LLM classification."""

    @pytest.mark.asyncio
    async def test_medium_complexity_routes_via_intent_router(self):
        """A classifier reply of 0.5 should surface as a complexity within [0.3, 0.7]."""
        # The classifier gateway reports a medium complexity.
        classifier_gateway = _make_llm_gateway(json.dumps({"complexity": 0.5}))
        registry = _make_skill_registry(
            [_make_skill("search", keywords=["分析"], description="数据分析")]
        )
        # The IntentRouter gets its own LLM gateway as well.
        intent_reply = json.dumps({"skill": "search", "confidence": 0.9})
        llm_backed_router = IntentRouter(
            llm_gateway=_make_llm_gateway(intent_reply),
            model="default",
        )
        cost_router = CostAwareRouter(llm_gateway=classifier_gateway, model="default")

        routed = await cost_router.route(
            content="分析下这个数据",
            skill_registry=registry,
            intent_router=llm_backed_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert 0.3 <= routed.complexity <= 0.7

    @pytest.mark.asyncio
    async def test_low_complexity_routes_to_default(self):
        """A complexity below 0.3 should be routed to the default agent."""
        classifier_gateway = _make_llm_gateway(json.dumps({"complexity": 0.1}))
        cost_router = CostAwareRouter(llm_gateway=classifier_gateway, model="default")
        registry = _make_skill_registry()
        keyword_router = _make_intent_router()

        routed = await cost_router.route(
            content="随便聊聊",
            skill_registry=registry,
            intent_router=keyword_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert routed.complexity < 0.3
        assert routed.match_method == "low_complexity"
        assert routed.agent_name == "default"

    @pytest.mark.asyncio
    async def test_no_llm_gateway_defaults_to_medium(self):
        """Without an LLM gateway ``quick_classify`` should return 0.5 (medium)."""
        cost_router = CostAwareRouter(llm_gateway=None)

        score = await cost_router.quick_classify("分析下这个数据")

        assert score == 0.5

    @pytest.mark.asyncio
    async def test_llm_malformed_response_defaults_to_medium(self):
        """A non-JSON LLM reply should make ``quick_classify`` return 0.5."""
        cost_router = CostAwareRouter(
            llm_gateway=_make_llm_gateway("这不是JSON"),
            model="default",
        )

        score = await cost_router.quick_classify("分析下这个数据")

        assert score == 0.5

    @pytest.mark.asyncio
    async def test_complexity_clamped_to_0_1(self):
        """Out-of-range complexity values should be clamped into [0.0, 1.0]."""
        # (value reported by the LLM, user text, expected clamped score)
        cases = (
            (1.5, "超级复杂任务", 1.0),
            (-0.5, "简单任务", 0.0),
        )
        for reported, text, expected in cases:
            classifier_gateway = _make_llm_gateway(json.dumps({"complexity": reported}))
            cost_router = CostAwareRouter(llm_gateway=classifier_gateway, model="default")
            score = await cost_router.quick_classify(text)
            assert score == expected
# ---------------------------------------------------------------------------
# Layer 2: Capability matching / Auction
# ---------------------------------------------------------------------------
class TestLayer2CapabilityMatching:
    """Layer 2: capability matching / auction."""

    @pytest.mark.asyncio
    async def test_high_complexity_triggers_capability_matching(self):
        """Complexity above 0.7 with an org context should yield a ``capability`` match."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.85}))
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value="market-researcher")
        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)

        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert result.complexity > 0.7
        assert result.match_method == "capability"
        assert result.agent_name == "market-researcher"
        assert result.matched is True

    @pytest.mark.asyncio
    async def test_layer2_with_org_context_object(self):
        """When ``find_best_agent`` returns an object, its ``name`` attribute should be used."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.9}))
        agent_obj = MagicMock()
        agent_obj.name = "analyst-agent"
        org_context = MagicMock()
        org_context.find_best_agent = MagicMock(return_value=agent_obj)
        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)

        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=_make_skill_registry(),
            intent_router=_make_intent_router(),
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert result.agent_name == "analyst-agent"

    @pytest.mark.asyncio
    async def test_layer2_without_org_context_falls_back_to_intent_router(self):
        """Without an org context, Layer 2 should fall back to the IntentRouter."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.8}))
        search_skill = _make_skill("search", keywords=["调研"], description="市场调研")
        registry = _make_skill_registry([search_skill])
        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")
        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=None)

        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert result.complexity > 0.7
        # After falling back to the IntentRouter the request may match a skill
        # or end up on the default path, so several methods are acceptable.
        assert result.match_method in ("capability", "keyword", "llm", "intent_router_fallback", None)

    @pytest.mark.asyncio
    async def test_layer2_org_context_find_best_agent_returns_none(self):
        """When ``find_best_agent`` returns ``None``, routing should fall back to the IntentRouter."""
        gateway = _make_llm_gateway(json.dumps({"complexity": 0.8}))
        org_context = MagicMock()
        # Use a plain MagicMock, as the sibling tests in this class do, so the
        # caller receives None itself.  An AsyncMock hands a synchronous caller
        # a coroutine object instead, in which case the None branch this test
        # is named after would never be exercised.
        org_context.find_best_agent = MagicMock(return_value=None)
        search_skill = _make_skill("search", keywords=["调研"], description="市场调研")
        registry = _make_skill_registry([search_skill])
        intent_gateway = _make_llm_gateway(json.dumps({"skill": "search", "confidence": 0.9}))
        intent_router = IntentRouter(llm_gateway=intent_gateway, model="default")
        router = CostAwareRouter(llm_gateway=gateway, model="default", org_context=org_context)

        result = await router.route(
            content="做市场调研+竞品分析",
            skill_registry=registry,
            intent_router=intent_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert result.complexity > 0.7
        # Guard against a vacuous pass: the None-returning lookup must have been consulted.
        org_context.find_best_agent.assert_called()

    @pytest.mark.asyncio
    async def test_auction_disabled_by_default(self):
        """Auction mode should be off unless explicitly requested."""
        router = CostAwareRouter()
        assert router._auction_enabled is False

    @pytest.mark.asyncio
    async def test_auction_can_be_enabled(self):
        """Auction mode should be switchable on through the constructor."""
        router = CostAwareRouter(auction_enabled=True)
        assert router._auction_enabled is True
# ---------------------------------------------------------------------------
# Transparency
# ---------------------------------------------------------------------------
class TestTransparency:
    """Switching between transparency levels."""

    @pytest.mark.asyncio
    async def test_silent_mode_no_trace(self):
        """``SILENT`` should expose no routing trace."""
        registry = _make_skill_registry()
        keyword_router = _make_intent_router()

        routed = await CostAwareRouter().route(
            content="你好",
            skill_registry=registry,
            intent_router=keyword_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
            transparency="SILENT",
        )

        assert routed.execution_trace == []
        assert routed.transparency_level == "SILENT"

    @pytest.mark.asyncio
    async def test_verbose_mode_shows_trace(self):
        """``VERBOSE`` should expose a trace whose first entry is the Layer 0 greeting."""
        registry = _make_skill_registry()
        keyword_router = _make_intent_router()

        routed = await CostAwareRouter().route(
            content="你好",
            skill_registry=registry,
            intent_router=keyword_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
            transparency="VERBOSE",
        )

        assert len(routed.execution_trace) > 0
        assert routed.execution_trace[0]["layer"] == 0
        assert routed.execution_trace[0]["method"] == "greeting"
        assert routed.transparency_level == "VERBOSE"

    @pytest.mark.asyncio
    async def test_trace_mode_shows_full_trace(self):
        """``TRACE`` should expose entries for both Layer 1 and Layer 2."""
        classifier_gateway = _make_llm_gateway(json.dumps({"complexity": 0.85}))
        org_context = MagicMock()
        org_context.find_best_agent = AsyncMock(return_value="analyst")
        cost_router = CostAwareRouter(
            llm_gateway=classifier_gateway,
            model="default",
            org_context=org_context,
        )
        registry = _make_skill_registry()
        keyword_router = _make_intent_router()

        routed = await cost_router.route(
            content="做市场调研+竞品分析",
            skill_registry=registry,
            intent_router=keyword_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
            transparency="TRACE",
        )

        assert len(routed.execution_trace) > 0
        # The trace should hold records from Layer 1 quick_classify and from Layer 2.
        seen_layers = [entry["layer"] for entry in routed.execution_trace]
        assert 1 in seen_layers  # Layer 1 quick_classify
        assert 2 in seen_layers  # Layer 2 capability matching
        assert routed.transparency_level == "TRACE"

    @pytest.mark.asyncio
    async def test_default_transparency_is_silent(self):
        """Omitting ``transparency`` should behave like ``SILENT``."""
        registry = _make_skill_registry()
        keyword_router = _make_intent_router()

        routed = await CostAwareRouter().route(
            content="你好",
            skill_registry=registry,
            intent_router=keyword_router,
            default_tools=[],
            default_system_prompt="You are helpful.",
        )

        assert routed.transparency_level == "SILENT"
        assert routed.execution_trace == []
# ---------------------------------------------------------------------------
# SkillRoutingResult 新字段
# ---------------------------------------------------------------------------
class TestSkillRoutingResultNewFields:
    """Checks on the newly added ``SkillRoutingResult`` fields."""

    def test_default_transparency_level(self):
        """A bare result should report ``SILENT`` transparency."""
        assert SkillRoutingResult().transparency_level == "SILENT"

    def test_default_execution_trace(self):
        """A bare result should carry an empty execution trace."""
        assert SkillRoutingResult().execution_trace == []

    def test_default_complexity(self):
        """A bare result should report zero complexity."""
        assert SkillRoutingResult().complexity == 0.0

    def test_new_fields_backward_compatible(self):
        """Construction with only the older fields should still work and default the new ones."""
        legacy_kwargs = {
            "skill_name": "test",
            "matched": True,
            "match_method": "keyword",
        }

        legacy_result = SkillRoutingResult(**legacy_kwargs)

        assert legacy_result.transparency_level == "SILENT"
        assert legacy_result.execution_trace == []
        assert legacy_result.complexity == 0.0