"""Tests for exception handling in app.services.ai_engine.batch_query."""
import asyncio
import logging
from unittest.mock import patch, MagicMock, call

import httpx
import pytest

from app.services.ai_engine.base import AIEngineAdapter, EngineType
|
|
|
|
|
|
class TestBatchQueryExceptionHandling:
    """Exception-handling behaviour of the adapter builders in ``batch_query``.

    Each test swaps a deliberately broken adapter class into
    ``_ADAPTER_CLASSES`` and checks that the function under test leaves that
    engine out of its result instead of propagating the constructor error.
    """

    @staticmethod
    def _raising_adapter(engine_type, exc_type, message, *, takes_key_manager=False):
        """Return an ``AIEngineAdapter`` subclass whose constructor always raises.

        Args:
            engine_type: ``EngineType`` member reported by ``get_engine_type``.
            exc_type: Exception class raised from ``__init__``.
            message: Message passed to ``exc_type``.
            takes_key_manager: When true, ``__init__`` accepts optional
                ``key_manager`` and ``user_id`` arguments (the constructor
                shape used by the key-manager tests); otherwise it takes no
                arguments.

        Returns:
            The freshly created adapter class (not an instance).
        """

        class _RaisingAdapter(AIEngineAdapter):
            if takes_key_manager:

                def __init__(self, key_manager=None, user_id=None):
                    super().__init__(api_key="test")
                    raise exc_type(message)

            else:

                def __init__(self):
                    super().__init__(api_key="test")
                    raise exc_type(message)

            def get_engine_type(self):
                return engine_type

            def _get_env_key(self):
                return "TEST_KEY"

            async def query(self, query, brand_name, competitor_names=None):
                pass

        return _RaisingAdapter

    @staticmethod
    def _build_with(engine_type, adapter_cls):
        """Run ``_build_adapters`` with ``adapter_cls`` registered for ``engine_type``.

        ``_build_adapters`` exposes ``cache_clear`` (it is memoised), so the
        cache is cleared before the call to force a rebuild and again
        afterwards so the result computed against the patched registry cannot
        leak into other tests.
        """
        from app.services.ai_engine.batch_query import _ADAPTER_CLASSES, _build_adapters

        with patch.dict(_ADAPTER_CLASSES, {engine_type: adapter_cls}):
            _build_adapters.cache_clear()
            try:
                return _build_adapters()
            finally:
                _build_adapters.cache_clear()

    @staticmethod
    def _build_with_key_manager(engine_type, adapter_cls):
        """Run ``_build_adapters_with_key_manager`` against a patched registry.

        The builder is called with a ``MagicMock`` key manager and the user id
        ``"test_user"``.
        """
        from app.services.ai_engine.batch_query import (
            _ADAPTER_CLASSES,
            _build_adapters_with_key_manager,
        )

        with patch.dict(_ADAPTER_CLASSES, {engine_type: adapter_cls}):
            return _build_adapters_with_key_manager(
                key_manager=MagicMock(),
                user_id="test_user",
            )

    @staticmethod
    def _logged(caplog, fragment):
        """Return True if any captured log message contains ``fragment``."""
        return any(fragment in record.getMessage() for record in caplog.records)

    def test_build_adapters_handles_generic_exception(self, caplog):
        """_build_adapters skips an adapter raising a generic error and logs at ERROR."""
        adapter_cls = self._raising_adapter(
            EngineType.CHATGPT, RuntimeError, "Simulated initialization failure"
        )

        result = self._build_with(EngineType.CHATGPT, adapter_cls)

        assert EngineType.CHATGPT.value not in result
        # The documented expectation for this branch is an error-level record.
        assert any(record.levelno >= logging.ERROR for record in caplog.records)

    def test_build_adapters_handles_httpx_http_error(self, caplog):
        """_build_adapters skips an adapter raising httpx.HTTPError and logs it."""
        adapter_cls = self._raising_adapter(
            EngineType.PERPLEXITY, httpx.HTTPError, "HTTP connection failed"
        )

        result = self._build_with(EngineType.PERPLEXITY, adapter_cls)

        assert EngineType.PERPLEXITY.value not in result
        assert self._logged(caplog, "HTTP error")

    def test_build_adapters_handles_timeout_error(self, caplog):
        """_build_adapters skips an adapter raising asyncio.TimeoutError and logs it."""
        adapter_cls = self._raising_adapter(
            EngineType.KIMI, asyncio.TimeoutError, "Connection timeout"
        )

        result = self._build_with(EngineType.KIMI, adapter_cls)

        assert EngineType.KIMI.value not in result
        assert self._logged(caplog, "Timeout")

    def test_build_adapters_with_key_manager_handles_httpx_http_error(self, caplog):
        """_build_adapters_with_key_manager skips an adapter raising httpx.HTTPError."""
        adapter_cls = self._raising_adapter(
            EngineType.WENXIN,
            httpx.HTTPError,
            "HTTP connection failed",
            takes_key_manager=True,
        )

        result = self._build_with_key_manager(EngineType.WENXIN, adapter_cls)

        assert EngineType.WENXIN.value not in result
        assert self._logged(caplog, "HTTP error")

    def test_build_adapters_with_key_manager_handles_timeout_error(self, caplog):
        """_build_adapters_with_key_manager skips an adapter raising asyncio.TimeoutError."""
        adapter_cls = self._raising_adapter(
            EngineType.DEEPSEEK,
            asyncio.TimeoutError,
            "Connection timeout",
            takes_key_manager=True,
        )

        result = self._build_with_key_manager(EngineType.DEEPSEEK, adapter_cls)

        assert EngineType.DEEPSEEK.value not in result
        assert self._logged(caplog, "Timeout")

    def test_build_adapters_with_key_manager_handles_generic_exception(self, caplog):
        """_build_adapters_with_key_manager skips an adapter raising a generic error."""
        adapter_cls = self._raising_adapter(
            EngineType.QWEN,
            RuntimeError,
            "Simulated key manager failure",
            takes_key_manager=True,
        )

        result = self._build_with_key_manager(EngineType.QWEN, adapter_cls)

        assert EngineType.QWEN.value not in result

    def test_register_adapter_handles_exception(self):
        """register_adapter returns normally for an adapter whose constructor raises."""
        from app.services.ai_engine.batch_query import _ADAPTER_CLASSES, register_adapter

        adapter_cls = self._raising_adapter(
            EngineType.GEMINI, ValueError, "Adapter registration failed"
        )

        # NOTE(review): whether register_adapter mutates _ADAPTER_CLASSES is not
        # visible from this file; patch.dict snapshots the registry and
        # restores it on exit so this test cannot affect others either way.
        with patch.dict(_ADAPTER_CLASSES):
            # The test passes as long as this call does not raise.
            register_adapter(adapter_cls)
|