# geo/backend/tests/test_services/test_smart_router_key_integ...
#
# 166 lines
# 7.1 KiB
# Python

import pytest
from app.services.api_key_manager import APIKeyManager, KeySource
from app.services.smart_router import ENGINE_COST_PROFILES, CostTier, SmartRouter
from app.services.engine_selector import EngineSelector
class TestSmartRouterWithKeyManager:
    """Integration tests for SmartRouter working together with APIKeyManager."""

    @pytest.fixture
    def key_manager(self):
        """Provide a manager holding one key each for deepseek, qwen and gemini."""
        manager = APIKeyManager()
        registrations = (
            ("deepseek", "test-deepseek-key", KeySource.SYSTEM, 1),
            ("qwen", "test-qwen-key", KeySource.SYSTEM, 1),
            ("gemini", "test-gemini-key", KeySource.ENV, 0),
        )
        for engine, api_key, source, priority in registrations:
            manager.add_key(engine, api_key, source=source, priority=priority)
        return manager

    @pytest.fixture
    def key_manager_with_user_key(self):
        """Provide a manager where deepseek has both a SYSTEM key and a USER key.

        The USER key is registered with priority 10, the SYSTEM key with
        priority 0; qwen only has a SYSTEM key.
        """
        manager = APIKeyManager()
        manager.add_key(
            "deepseek", "system-deepseek-key", source=KeySource.SYSTEM, priority=0
        )
        manager.add_key(
            "deepseek",
            "user-deepseek-key",
            source=KeySource.USER,
            priority=10,
            user_id="user_123",
        )
        manager.add_key(
            "qwen", "system-qwen-key", source=KeySource.SYSTEM, priority=1
        )
        return manager

    @pytest.fixture
    def key_manager_no_keys(self):
        """Provide a freshly constructed manager with no keys added."""
        return APIKeyManager()

    def test_router_accepts_key_manager(self, key_manager):
        """A manager passed to the constructor is stored on the router as-is."""
        assert SmartRouter(key_manager=key_manager)._key_manager is key_manager

    def test_set_key_manager_after_init(self):
        """A router built without a manager starts at None and accepts one later."""
        router = SmartRouter()
        assert router._key_manager is None

        late_manager = APIKeyManager()
        router.set_key_manager(late_manager)
        assert router._key_manager is late_manager

    def test_filter_by_available_keys(self, key_manager):
        """Engines with a registered key survive the filter; the others do not."""
        router = SmartRouter(key_manager=key_manager)
        filtered = router._filter_by_available_keys(
            list(ENGINE_COST_PROFILES.keys())
        )

        for keyed in ("deepseek", "qwen", "gemini"):
            assert keyed in filtered
        for unkeyed in ("chatgpt", "perplexity", "wenxin"):
            assert unkeyed not in filtered

    def test_filter_by_available_keys_no_manager(self):
        """Without a key manager the filter hands back every engine."""
        every_engine = list(ENGINE_COST_PROFILES.keys())
        unfiltered = SmartRouter()._filter_by_available_keys(every_engine)
        assert set(unfiltered) == set(every_engine)

    def test_select_engines_filters_no_key(self, key_manager):
        """Every engine chosen by select_engines has an available key."""
        router = SmartRouter(key_manager=key_manager)
        for engine in router.select_engines(max_engines=10):
            assert key_manager.get_any_available_key(engine) is not None, (
                f"Selected engine {engine} has no available key"
            )

    def test_select_engines_returns_subset(self, key_manager):
        """With max_engines=2 the selection is non-empty and at most two long."""
        selected = SmartRouter(key_manager=key_manager).select_engines(max_engines=2)
        assert 0 < len(selected) <= 2

    def test_select_engines_user_key_priority(self, key_manager_with_user_key):
        """deepseek is selected and its priority-10 USER key is the one returned."""
        manager = key_manager_with_user_key
        router = SmartRouter(key_manager=manager)

        assert "deepseek" in router.select_engines(max_engines=10)
        assert manager.get_any_available_key("deepseek") == "user-deepseek-key"

    def test_get_available_engines(self, key_manager):
        """Every engine reported as available has a key in the manager."""
        router = SmartRouter(key_manager=key_manager)
        assert all(
            key_manager.get_any_available_key(engine) is not None
            for engine in router.get_available_engines()
        )

    def test_get_engines_by_cost_tier_with_filter(self, key_manager):
        """Engines returned for the FREE tier are FREE-tier and have a key."""
        router = SmartRouter(key_manager=key_manager)
        for engine in router.get_engines_by_cost_tier(CostTier.FREE):
            assert ENGINE_COST_PROFILES[engine].cost_tier == CostTier.FREE
            assert key_manager.get_any_available_key(engine) is not None

    def test_all_engines_no_keys_returns_empty(self, key_manager_no_keys):
        """With no keys, the selection is empty or limited to keyless cheap engines.

        A non-empty result is tolerated only when every selected engine is in
        the FREE or LOW_COST tier and does not require its own key.
        """

        def usable_without_key(engine):
            profile = ENGINE_COST_PROFILES[engine]
            cheap = profile.cost_tier in (CostTier.FREE, CostTier.LOW_COST)
            return cheap and not profile.requires_own_key

        router = SmartRouter(key_manager=key_manager_no_keys)
        selected = router.select_engines(max_engines=5)
        assert selected == [] or all(usable_without_key(e) for e in selected)
class TestEngineSelector:
    """Tests for the EngineSelector smart engine selector."""

    @pytest.fixture
    def key_manager(self):
        """Provide a manager with keys for deepseek, qwen, gemini and wenxin."""
        manager = APIKeyManager()
        registrations = (
            ("deepseek", "test-deepseek-key", KeySource.SYSTEM),
            ("qwen", "test-qwen-key", KeySource.SYSTEM),
            ("gemini", "test-gemini-key", KeySource.ENV),
            ("wenxin", "test-wenxin-key", KeySource.SYSTEM),
        )
        for engine, api_key, source in registrations:
            manager.add_key(engine, api_key, source=source)
        return manager

    @pytest.fixture
    def selector(self, key_manager):
        """Provide an EngineSelector built on the key_manager fixture."""
        return EngineSelector(key_manager=key_manager)

    def test_initialization(self, selector, key_manager):
        """The selector keeps the manager and shares it with its router."""
        assert selector.key_manager is key_manager
        assert selector.router is not None
        assert selector.router._key_manager is key_manager

    def test_select_engines_returns_valid_engines(self, selector, key_manager):
        """Every engine the selector picks has an available key."""
        for engine in selector.select_engines(max_engines=5):
            assert key_manager.get_any_available_key(engine) is not None, (
                f"Engine {engine} selected but has no key"
            )

    def test_select_engines_respects_max_engines(self, selector):
        """The selection never exceeds the requested max_engines."""
        assert len(selector.select_engines(max_engines=2)) <= 2

    def test_select_engines_with_min_cost_tier(self, selector):
        """Engines selected with min_cost_tier="free" all carry a known tier."""
        tier_order = ["free", "low_cost", "mid_cost", "high_cost"]
        # "free" is rank 0, so the rank comparison below can only fail by
        # raising on a tier value that is missing from tier_order.
        floor_rank = tier_order.index("free")
        known_tiers = (
            CostTier.FREE,
            CostTier.LOW_COST,
            CostTier.MID_COST,
            CostTier.HIGH_COST,
        )

        for engine in selector.select_engines(max_engines=10, min_cost_tier="free"):
            tier = ENGINE_COST_PROFILES[engine].cost_tier
            assert tier in known_tiers
            assert tier_order.index(tier.value) >= floor_rank

    def test_get_best_value_engine(self, selector, key_manager):
        """A returned best-value engine has a key and is FREE or LOW_COST tier."""
        best = selector.get_best_value_engine()
        if not best:
            # Nothing was returned, so there is nothing further to check here.
            return

        profile = ENGINE_COST_PROFILES[best]
        assert key_manager.get_any_available_key(best) is not None
        assert profile.cost_tier in (CostTier.FREE, CostTier.LOW_COST)

    def test_get_best_value_returns_none_when_no_keys(self):
        """With an empty key manager there is no best-value engine."""
        empty_selector = EngineSelector(key_manager=APIKeyManager())
        assert empty_selector.get_best_value_engine() is None

    def test_select_engines_priority_domestic(self, selector):
        """With prefer_domestic=True, domestic engines precede international ones.

        The ordering is only checked when the selection contains both kinds.
        """
        selected = selector.select_engines(max_engines=10, prefer_domestic=True)

        domestic, international = [], []
        for engine in selected:
            is_domestic = ENGINE_COST_PROFILES[engine].domestic
            (domestic if is_domestic else international).append(engine)

        if not (domestic and international):
            return

        # list.index gives the first occurrence of each engine in the selection.
        last_domestic = max(map(selected.index, domestic))
        first_international = min(map(selected.index, international))
        assert last_domestic < first_international