369 lines
12 KiB
Python
369 lines
12 KiB
Python
"""集成测试 - CostAwareRouter 合并 LLM 分类功能
|
||
|
||
测试 merged_llm_classify 参数控制下单次/双次 LLM 调用的路由行为。
|
||
仅 mock LLMGateway(外部 API),使用真实 CostAwareRouter 实例。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import pytest
|
||
|
||
from agentkit.chat.skill_routing import CostAwareRouter, SkillRoutingResult
|
||
from agentkit.llm.gateway import LLMGateway
|
||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||
from agentkit.org.context import AgentProfile, OrganizationContext
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def make_response(
    content: str = "",
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
) -> LLMResponse:
    """Build a canned ``LLMResponse`` for a mocked gateway to hand back.

    The response always reports the model name ``"test-model"``; the two
    token counts are forwarded into its ``TokenUsage``.
    """
    token_usage = TokenUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    return LLMResponse(content=content, model="test-model", usage=token_usage)
|
||
|
||
|
||
def make_mock_gateway(responses: list[LLMResponse]) -> MagicMock:
    """Return an ``LLMGateway``-spec'd mock whose ``chat`` replays ``responses``.

    ``responses`` is installed as the ``side_effect`` of an ``AsyncMock``, so
    each awaited ``gateway.chat(...)`` call yields the next item in order.
    """
    chat_stub = AsyncMock(side_effect=responses)
    mock_gateway = MagicMock(spec=LLMGateway)
    mock_gateway.chat = chat_stub
    return mock_gateway
|
||
|
||
|
||
def make_mock_skill_registry(skill_name: str | None = None) -> MagicMock:
    """Create a mock skill registry that optionally holds a single skill.

    With a non-empty ``skill_name`` the registry lists one mock skill carrying
    that name and returns it from every ``get`` call. Otherwise
    ``list_skills`` returns an empty list and ``get`` is left unconfigured.
    """
    mock_registry = MagicMock()
    if not skill_name:
        mock_registry.list_skills.return_value = []
        return mock_registry

    skill = MagicMock()
    # Assigned as an attribute: the ``name`` constructor argument of MagicMock
    # configures the mock's repr rather than creating a ``name`` attribute.
    skill.name = skill_name
    skill.tools = []
    skill.config.llm = None
    skill.config.prompt = None
    skill.config.intent.keywords = ["test"]

    mock_registry.get.return_value = skill
    mock_registry.list_skills.return_value = [skill]
    return mock_registry
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Test 1: Merged classify returns valid JSON + skill_hint → route to that skill
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestMergedClassifyValidSkillHint:
    """Merged classification returns valid JSON that carries a skill_hint."""

    @staticmethod
    async def _route(router, skill_registry):
        """Send the shared prompt through ``router`` with the default agent settings.

        Per the original author's note, the prompt is medium-complexity content
        meant to trigger the merged classify path: it contains "如何" (medium
        complexity) and "代码" (high complexity).
        """
        return await router.route(
            content="如何优化代码性能",
            skill_registry=skill_registry,
            intent_router=AsyncMock(),
            default_tools=[],
            default_system_prompt="You are helpful",
            default_model="default",
            default_agent_name="default",
        )

    @pytest.mark.asyncio
    async def test_merged_classify_routes_to_skill(self):
        """A skill_hint naming a registered skill routes to that skill."""
        classification = {
            "complexity": 0.5,
            "intent": "code_review",
            "skill_hint": "code_reviewer",
        }
        llm_reply = make_response(content=json.dumps(classification))

        llm_gateway = make_mock_gateway([llm_reply])
        registry = make_mock_skill_registry("code_reviewer")
        router = CostAwareRouter(llm_gateway=llm_gateway, merged_llm_classify=True)

        outcome = await self._route(router, registry)

        assert outcome.matched is True
        assert outcome.skill_name == "code_reviewer"
        assert outcome.match_method == "merged_llm"
        assert outcome.complexity == 0.5

    @pytest.mark.asyncio
    async def test_merged_classify_skill_hint_not_found_falls_back(self):
        """A skill_hint naming an unknown skill falls back to the default agent."""
        classification = {
            "complexity": 0.4,
            "intent": "unknown",
            "skill_hint": "nonexistent_skill",
        }
        llm_reply = make_response(content=json.dumps(classification))

        llm_gateway = make_mock_gateway([llm_reply])
        registry = make_mock_skill_registry()  # registry lists no skills
        router = CostAwareRouter(llm_gateway=llm_gateway, merged_llm_classify=True)

        outcome = await self._route(router, registry)

        # The hinted skill does not exist, so the default agent handles the request.
        assert outcome.matched is False
        assert outcome.agent_name == "default"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Test 2: Merged classify returns malformed output → fall back to default agent
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestMergedClassifyFormatError:
    """Malformed merged-classification output falls back to the default agent."""

    @staticmethod
    async def _route_with_llm_reply(raw_reply):
        """Route the shared prompt through a router whose LLM answers ``raw_reply``.

        Builds a gateway that returns a single response with ``raw_reply`` as
        its content, an empty skill registry and a router with merged
        classification enabled, then returns the routing result.
        """
        llm_gateway = make_mock_gateway([make_response(content=raw_reply)])
        registry = make_mock_skill_registry()
        router = CostAwareRouter(llm_gateway=llm_gateway, merged_llm_classify=True)
        return await router.route(
            content="如何优化代码性能",
            skill_registry=registry,
            intent_router=AsyncMock(),
            default_tools=[],
            default_system_prompt="You are helpful",
            default_model="default",
            default_agent_name="default",
        )

    @pytest.mark.asyncio
    async def test_invalid_json_falls_back_to_default(self):
        """A reply that is not JSON at all falls back to the default agent."""
        outcome = await self._route_with_llm_reply("This is not JSON at all")

        assert outcome.matched is False
        assert outcome.agent_name == "default"
        assert outcome.match_method == "merged_llm_fallback"

    @pytest.mark.asyncio
    async def test_partial_json_falls_back_to_default(self):
        """A truncated JSON reply falls back to the default agent."""
        # Deliberately incomplete JSON document.
        outcome = await self._route_with_llm_reply('{"complexity": 0.5, "intent":')

        assert outcome.matched is False
        assert outcome.match_method == "merged_llm_fallback"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Test 3: merged_llm_classify=False → use the standalone IntentRouter (2 LLM calls)
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestMergedClassifyDisabled:
    """With merged_llm_classify=False routing uses the separate IntentRouter path."""

    @pytest.mark.asyncio
    async def test_disabled_uses_separate_intent_router(self):
        """Disabled merged classification resolves to the default agent without an LLM call."""
        # Original author's note: with merged_llm_classify=False and
        # classifier="heuristic", medium complexity goes through
        # resolve_skill_routing rather than _classify_merged.
        gateway = make_mock_gateway([])  # no queued replies: no LLM call is expected
        skill_registry = make_mock_skill_registry()
        mock_intent_router = AsyncMock()

        router = CostAwareRouter(
            llm_gateway=gateway,
            merged_llm_classify=False,
        )

        result = await router.route(
            content="如何优化代码性能",
            skill_registry=skill_registry,
            intent_router=mock_intent_router,
            default_tools=[],
            default_system_prompt="You are helpful",
            default_model="default",
            default_agent_name="default",
        )

        # No skill is registered, so the request lands on the default agent.
        assert result.agent_name == "default"
        # Pin the "no LLM call" expectation instead of leaving it as a comment.
        gateway.chat.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_disabled_no_merged_llm_call(self):
        """Disabled merged classification must not call the LLM gateway."""

        class CountingGateway:
            """Minimal gateway stand-in exposing only an awaitable ``chat``."""

            def __init__(self):
                # AsyncMock records every await itself, so no hand-written
                # counter is needed, and any call signature is accepted.
                self.chat = AsyncMock(
                    return_value=make_response('{"complexity": 0.5}'),
                )

        gateway = CountingGateway()
        skill_registry = make_mock_skill_registry()

        router = CostAwareRouter(
            llm_gateway=gateway,
            merged_llm_classify=False,
        )

        await router.route(
            content="如何优化代码性能",
            skill_registry=skill_registry,
            intent_router=AsyncMock(),
            default_tools=[],
            default_system_prompt="You are helpful",
            default_model="default",
            default_agent_name="default",
        )

        # Original author's note: merged_llm_classify=False plus the heuristic
        # classifier is zero-cost and should never reach the LLM.
        call_count = gateway.chat.await_count
        assert call_count == 0, f"Expected 0 LLM calls with heuristic classifier, got {call_count}"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Test 4: Merged classify returns high complexity → delegate to Layer 2
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestMergedClassifyHighComplexity:
    """A high-complexity merged classification is delegated to Layer 2."""

    @staticmethod
    def _deep_analysis_reply(complexity):
        """Build the LLM reply: a deep-analysis classification without a skill hint."""
        classification = {
            "complexity": complexity,
            "intent": "deep_analysis",
            "skill_hint": None,
        }
        return make_response(content=json.dumps(classification))

    @staticmethod
    async def _route(router, skill_registry):
        """Send the shared analysis prompt through ``router`` with default settings."""
        return await router.route(
            content="请对市场趋势进行深度分析并给出投资建议",
            skill_registry=skill_registry,
            intent_router=AsyncMock(),
            default_tools=[],
            default_system_prompt="You are helpful",
            default_model="default",
            default_agent_name="default",
        )

    @pytest.mark.asyncio
    async def test_high_complexity_delegates_to_layer2(self):
        """A high-complexity result with an org context routes to the matched agent."""
        llm_gateway = make_mock_gateway([self._deep_analysis_reply(0.85)])

        org_context = OrganizationContext()
        analyst_profile = AgentProfile(
            name="analyst",
            agent_type="react",
            capabilities=["分析", "市场", "调研"],
            skills=["market_analysis"],
            current_load=0,
        )
        org_context.register_agent(analyst_profile)
        # Force agent selection to return the registered analyst profile.
        registered_analyst = org_context.get_agent_profile("analyst")
        org_context.find_best_agent = MagicMock(return_value=registered_analyst)

        registry = make_mock_skill_registry()
        router = CostAwareRouter(
            llm_gateway=llm_gateway,
            org_context=org_context,
            merged_llm_classify=True,
        )

        outcome = await self._route(router, registry)

        # High complexity should delegate to Layer 2, matched via org_context capabilities.
        assert outcome.complexity >= 0.7
        assert outcome.matched is True
        assert outcome.agent_name == "analyst"

    @pytest.mark.asyncio
    async def test_high_complexity_no_org_context_uses_intent_router(self):
        """A high-complexity result without an org context ends at the default agent."""
        llm_gateway = make_mock_gateway([self._deep_analysis_reply(0.8)])
        registry = make_mock_skill_registry()
        router = CostAwareRouter(
            llm_gateway=llm_gateway,
            org_context=None,
            merged_llm_classify=True,
        )

        outcome = await self._route(router, registry)

        # Without org_context, routing falls back to the IntentRouter and the default agent.
        assert outcome.complexity >= 0.7
        assert outcome.agent_name == "default"
|