233 lines
8.9 KiB
Python
233 lines
8.9 KiB
Python
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
|
|
from app.services.ai_engine.gemini import GeminiAdapter
|
|
from app.services.ai_engine.qwen import QwenAdapter
|
|
|
|
|
|
class TestQwenAdapter:
    """Exercise ``QwenAdapter`` with its request layer replaced by mocks."""

    @staticmethod
    def _chat_payload(text, **extra):
        """Return a chat-completion style dict whose single choice carries *text*.

        The ``model`` field is always ``"qwen-plus"``; any *extra* keyword
        arguments are appended as additional top-level keys.
        """
        return {
            "choices": [{"message": {"content": text}}],
            "model": "qwen-plus",
            **extra,
        }

    @pytest.mark.asyncio
    async def test_initialization(self):
        """A key-only construction stores the key and the default model/base URL."""
        qwen = QwenAdapter(api_key="test-dashscope-key")

        assert qwen.api_key == "test-dashscope-key"
        assert qwen._model == "qwen-plus"
        assert qwen._base_url == "https://dashscope.aliyuncs.com/compatible-mode/v1"

    @pytest.mark.asyncio
    async def test_initialization_with_custom_params(self):
        """Explicit ``model`` and ``base_url`` arguments override the defaults."""
        overrides = {
            "api_key": "custom-key",
            "model": "qwen-max",
            "base_url": "https://custom-url.com/v1",
        }
        qwen = QwenAdapter(**overrides)

        assert qwen.api_key == "custom-key"
        assert qwen._model == "qwen-max"
        assert qwen._base_url == "https://custom-url.com/v1"

    @pytest.mark.asyncio
    async def test_query_returns_ai_query_result(self):
        """``query`` wraps the mocked payload in an ``AIQueryResult``."""
        qwen = QwenAdapter(api_key="test-key")
        payload = self._chat_payload(
            "华为是全球领先的ICT基础设施和智能终端提供商",
            usage={"total_tokens": 100},
        )

        with patch.object(qwen, "_request_with_retry", return_value=payload):
            outcome = await qwen.query("华为公司", brand_name="华为")

        assert isinstance(outcome, AIQueryResult)
        assert outcome.engine_type == EngineType.QWEN
        assert "华为" in outcome.raw_response
        assert outcome.has_brand_citation is True
        assert outcome.metadata.get("model") == "qwen-plus"

    @pytest.mark.asyncio
    async def test_get_engine_type(self):
        """The adapter reports ``EngineType.QWEN``, whose value is ``"qwen"``."""
        engine = QwenAdapter(api_key="test-key").get_engine_type()

        assert engine == EngineType.QWEN
        assert engine.value == "qwen"

    @pytest.mark.asyncio
    async def test_error_handling(self):
        """An exception raised by the request layer propagates out of ``query``."""
        qwen = QwenAdapter(api_key="test-key")
        failure = Exception("HTTP 500: Internal Server Error")
        broken_request = patch.object(qwen, "_request_with_retry", side_effect=failure)

        with broken_request, pytest.raises(Exception, match="HTTP 500"):
            await qwen.query("测试问题", brand_name="华为")

    @pytest.mark.asyncio
    async def test_chinese_brand_citation_detection(self):
        """Chinese brand and competitor names found in the reply are both flagged."""
        qwen = QwenAdapter(api_key="test-key")
        payload = self._chat_payload("华为和小米都是中国知名的科技企业")

        with patch.object(qwen, "_request_with_retry", return_value=payload):
            outcome = await qwen.query(
                "科技公司",
                brand_name="华为",
                competitor_names=["小米"],
            )

        assert outcome.has_brand_citation is True
        assert outcome.has_competitor_citation is True
        assert outcome.brand_context is not None
        assert "华为" in outcome.brand_context
        assert len(outcome.competitor_contexts) == 1

    @pytest.mark.asyncio
    async def test_rate_limiter_called(self):
        """``query`` awaits ``acquire`` on the injected rate limiter."""
        limiter = AsyncMock()
        qwen = QwenAdapter(api_key="test-key", rate_limiter=limiter)
        http_ok = MagicMock(status_code=200)
        http_ok.json.return_value = self._chat_payload("测试回复")

        with patch.object(qwen._client, "post", new_callable=AsyncMock) as fake_post:
            fake_post.return_value = http_ok
            await qwen.query("测试", brand_name="华为")

        limiter.acquire.assert_awaited()
|
|
|
|
|
|
class TestGeminiAdapter:
    """Exercise ``GeminiAdapter`` with its HTTP client's ``post`` mocked."""

    @staticmethod
    def _ok_response(text):
        """Return a mock HTTP 200 response whose JSON holds one candidate with *text*."""
        reply = MagicMock(status_code=200)
        reply.json.return_value = {
            "candidates": [{"content": {"parts": [{"text": text}]}}],
        }
        return reply

    @pytest.mark.asyncio
    async def test_initialization(self):
        """A key-only construction stores the key and the ``gemini-pro`` model."""
        gemini = GeminiAdapter(api_key="test-google-key")

        assert gemini.api_key == "test-google-key"
        assert gemini._model == "gemini-pro"

    @pytest.mark.asyncio
    async def test_initialization_with_custom_params(self):
        """An explicit ``model`` argument overrides the default."""
        overrides = {"api_key": "custom-key", "model": "gemini-1.5-pro"}
        gemini = GeminiAdapter(**overrides)

        assert gemini.api_key == "custom-key"
        assert gemini._model == "gemini-1.5-pro"

    @pytest.mark.asyncio
    async def test_query_returns_ai_query_result(self):
        """``query`` turns a mocked 200 response into an ``AIQueryResult``."""
        gemini = GeminiAdapter(api_key="test-key")
        reply = self._ok_response("Google is a leading technology company.")

        with patch.object(
            gemini._client, "post", new_callable=AsyncMock, return_value=reply
        ):
            outcome = await gemini.query("best tech companies", brand_name="Google")

        assert isinstance(outcome, AIQueryResult)
        assert outcome.engine_type == EngineType.GEMINI
        assert "Google" in outcome.raw_response
        assert outcome.has_brand_citation is True

    @pytest.mark.asyncio
    async def test_get_engine_type(self):
        """The adapter reports ``EngineType.GEMINI``, whose value is ``"gemini"``."""
        engine = GeminiAdapter(api_key="test-key").get_engine_type()

        assert engine == EngineType.GEMINI
        assert engine.value == "gemini"

    @pytest.mark.asyncio
    async def test_error_handling(self):
        """A mocked 400 response makes ``query`` raise ``RuntimeError`` mentioning Gemini."""
        gemini = GeminiAdapter(api_key="test-key")
        rejected = MagicMock(status_code=400, text="Bad Request")
        failing_post = patch.object(
            gemini._client, "post", new_callable=AsyncMock, return_value=rejected
        )

        with failing_post, pytest.raises(RuntimeError, match="Gemini"):
            await gemini.query("test query", brand_name="Google")

    @pytest.mark.asyncio
    async def test_english_brand_citation_detection(self):
        """English brand and competitor names found in the reply are both flagged."""
        gemini = GeminiAdapter(api_key="test-key")
        reply = self._ok_response("Google and Microsoft are major tech companies.")

        with patch.object(
            gemini._client, "post", new_callable=AsyncMock, return_value=reply
        ):
            outcome = await gemini.query(
                "tech companies",
                brand_name="Google",
                competitor_names=["Microsoft"],
            )

        assert outcome.has_brand_citation is True
        assert outcome.has_competitor_citation is True
        assert outcome.brand_context is not None
        assert "Google" in outcome.brand_context
        assert len(outcome.competitor_contexts) == 1

    @pytest.mark.asyncio
    async def test_rate_limiter_called(self):
        """``query`` awaits ``acquire`` on the injected rate limiter."""
        limiter = AsyncMock()
        gemini = GeminiAdapter(api_key="test-key", rate_limiter=limiter)
        reply = self._ok_response("Test response")

        with patch.object(
            gemini._client, "post", new_callable=AsyncMock, return_value=reply
        ):
            await gemini.query("test", brand_name="Google")

        limiter.acquire.assert_awaited()

    @pytest.mark.asyncio
    async def test_api_key_in_url(self):
        """The recorded ``post`` call contains ``key=<api key>`` in its arguments."""
        gemini = GeminiAdapter(api_key="my-secret-key")
        reply = self._ok_response("response")

        with patch.object(
            gemini._client, "post", new_callable=AsyncMock, return_value=reply
        ) as fake_post:
            await gemini.query("test", brand_name="BrandX")

        # The stringified call record covers positional and keyword arguments
        # alike, so the check does not depend on how the key is passed.
        recorded_call = fake_post.call_args
        assert "key=my-secret-key" in str(recorded_call)

    @pytest.mark.asyncio
    async def test_proxy_support(self):
        """The ``proxy`` constructor argument is stored on ``_proxy``."""
        gemini = GeminiAdapter(api_key="test-key", proxy="http://proxy:8080")

        assert gemini._proxy == "http://proxy:8080"
|
|
|
|
|
|
class TestAdapterInheritance:
    """Structural checks on the adapter classes; no instances are created."""

    def test_qwen_inherits_base(self):
        """``QwenAdapter`` is a subclass of ``AIEngineAdapter``."""
        assert issubclass(QwenAdapter, AIEngineAdapter)

    def test_gemini_inherits_base(self):
        """``GeminiAdapter`` is a subclass of ``AIEngineAdapter``."""
        assert issubclass(GeminiAdapter, AIEngineAdapter)

    def test_qwen_has_query_method(self):
        """``QwenAdapter`` exposes a callable ``query`` attribute."""
        assert hasattr(QwenAdapter, "query")
        assert callable(QwenAdapter.query)

    def test_gemini_has_query_method(self):
        """``GeminiAdapter`` exposes a callable ``query`` attribute."""
        assert hasattr(GeminiAdapter, "query")
        assert callable(GeminiAdapter.query)

    def test_qwen_has_detect_citations(self):
        """``QwenAdapter`` provides a ``_detect_citations`` attribute."""
        assert hasattr(QwenAdapter, "_detect_citations")

    def test_gemini_has_detect_citations(self):
        """``GeminiAdapter`` provides a ``_detect_citations`` attribute."""
        assert hasattr(GeminiAdapter, "_detect_citations")
|