# geo/backend/tests/test_services/test_qwen_gemini_adapters.py
#
# Source listing metadata: 233 lines, 8.9 KiB, Python.

from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
from app.services.ai_engine.gemini import GeminiAdapter
from app.services.ai_engine.qwen import QwenAdapter
class TestQwenAdapter:
    """Unit tests for QwenAdapter, the Qwen (DashScope) engine adapter.

    Most tests stub ``_request_with_retry`` so ``query`` receives a canned
    JSON payload. The rate-limiter test stubs ``_client.post`` instead,
    leaving the adapter's own request code in place.
    """

    @staticmethod
    def _chat_payload(content, **extra):
        """Build a canned response body.

        The body has a single ``choices`` entry carrying *content*, the model
        name ``qwen-plus``, and any *extra* top-level keys (e.g. ``usage``).
        """
        return {
            "choices": [{"message": {"content": content}}],
            "model": "qwen-plus",
            **extra,
        }

    @pytest.mark.asyncio
    async def test_initialization(self):
        qwen = QwenAdapter(api_key="test-dashscope-key")
        expected = (
            "test-dashscope-key",
            "qwen-plus",
            "https://dashscope.aliyuncs.com/compatible-mode/v1",
        )
        assert (qwen.api_key, qwen._model, qwen._base_url) == expected

    @pytest.mark.asyncio
    async def test_initialization_with_custom_params(self):
        overrides = {
            "api_key": "custom-key",
            "model": "qwen-max",
            "base_url": "https://custom-url.com/v1",
        }
        qwen = QwenAdapter(**overrides)
        assert qwen.api_key == overrides["api_key"]
        assert qwen._model == overrides["model"]
        assert qwen._base_url == overrides["base_url"]

    @pytest.mark.asyncio
    async def test_query_returns_ai_query_result(self):
        qwen = QwenAdapter(api_key="test-key")
        payload = self._chat_payload(
            "华为是全球领先的ICT基础设施和智能终端提供商",
            usage={"total_tokens": 100},
        )
        with patch.object(qwen, "_request_with_retry", return_value=payload):
            outcome = await qwen.query("华为公司", brand_name="华为")

        assert isinstance(outcome, AIQueryResult)
        assert outcome.engine_type == EngineType.QWEN
        assert "华为" in outcome.raw_response
        assert outcome.has_brand_citation is True
        assert outcome.metadata.get("model") == "qwen-plus"

    @pytest.mark.asyncio
    async def test_get_engine_type(self):
        engine = QwenAdapter(api_key="test-key").get_engine_type()
        assert engine == EngineType.QWEN
        assert engine.value == "qwen"

    @pytest.mark.asyncio
    async def test_error_handling(self):
        qwen = QwenAdapter(api_key="test-key")
        failure = Exception("HTTP 500: Internal Server Error")
        with patch.object(qwen, "_request_with_retry", side_effect=failure):
            # The stubbed failure's message must reach the caller.
            with pytest.raises(Exception, match="HTTP 500"):
                await qwen.query("测试问题", brand_name="华为")

    @pytest.mark.asyncio
    async def test_chinese_brand_citation_detection(self):
        qwen = QwenAdapter(api_key="test-key")
        payload = self._chat_payload("华为和小米都是中国知名的科技企业")
        with patch.object(qwen, "_request_with_retry", return_value=payload):
            outcome = await qwen.query(
                "科技公司", brand_name="华为", competitor_names=["小米"]
            )

        assert outcome.has_brand_citation is True
        assert outcome.has_competitor_citation is True
        assert outcome.brand_context is not None
        assert "华为" in outcome.brand_context
        assert len(outcome.competitor_contexts) == 1

    @pytest.mark.asyncio
    async def test_rate_limiter_called(self):
        limiter = AsyncMock()
        qwen = QwenAdapter(api_key="test-key", rate_limiter=limiter)

        fake_http = MagicMock()
        fake_http.status_code = 200
        fake_http.json.return_value = self._chat_payload("测试回复")

        with patch.object(
            qwen._client, "post", new_callable=AsyncMock, return_value=fake_http
        ):
            await qwen.query("测试", brand_name="华为")

        limiter.acquire.assert_awaited()
class TestGeminiAdapter:
    """Unit tests for GeminiAdapter.

    Every test that calls ``query`` replaces ``adapter._client.post`` with an
    ``AsyncMock``, so no real HTTP request is made.
    """

    @staticmethod
    def _ok_response(text: str) -> MagicMock:
        """Return a fake HTTP 200 response whose JSON body holds one candidate.

        The body has the ``candidates[0].content.parts[0].text`` shape that
        the tests below feed to ``GeminiAdapter.query``; *text* is that part.
        """
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = {
            "candidates": [{"content": {"parts": [{"text": text}]}}],
        }
        return response

    @pytest.mark.asyncio
    async def test_initialization(self):
        adapter = GeminiAdapter(api_key="test-google-key")
        assert adapter.api_key == "test-google-key"
        assert adapter._model == "gemini-pro"

    @pytest.mark.asyncio
    async def test_initialization_with_custom_params(self):
        adapter = GeminiAdapter(
            api_key="custom-key",
            model="gemini-1.5-pro",
        )
        assert adapter.api_key == "custom-key"
        assert adapter._model == "gemini-1.5-pro"

    @pytest.mark.asyncio
    async def test_query_returns_ai_query_result(self):
        adapter = GeminiAdapter(api_key="test-key")
        response = self._ok_response("Google is a leading technology company.")
        with patch.object(
            adapter._client, "post", new_callable=AsyncMock, return_value=response
        ):
            result = await adapter.query("best tech companies", brand_name="Google")
        assert isinstance(result, AIQueryResult)
        assert result.engine_type == EngineType.GEMINI
        assert "Google" in result.raw_response
        assert result.has_brand_citation is True

    @pytest.mark.asyncio
    async def test_get_engine_type(self):
        adapter = GeminiAdapter(api_key="test-key")
        assert adapter.get_engine_type() == EngineType.GEMINI
        assert adapter.get_engine_type().value == "gemini"

    @pytest.mark.asyncio
    async def test_error_handling(self):
        """A non-2xx reply must surface as a RuntimeError mentioning Gemini."""
        adapter = GeminiAdapter(api_key="test-key")
        response = MagicMock()
        response.status_code = 400
        response.text = "Bad Request"
        with patch.object(
            adapter._client, "post", new_callable=AsyncMock, return_value=response
        ):
            with pytest.raises(RuntimeError, match="Gemini"):
                await adapter.query("test query", brand_name="Google")

    @pytest.mark.asyncio
    async def test_english_brand_citation_detection(self):
        adapter = GeminiAdapter(api_key="test-key")
        response = self._ok_response("Google and Microsoft are major tech companies.")
        with patch.object(
            adapter._client, "post", new_callable=AsyncMock, return_value=response
        ):
            result = await adapter.query(
                "tech companies", brand_name="Google", competitor_names=["Microsoft"]
            )
        assert result.has_brand_citation is True
        assert result.has_competitor_citation is True
        assert result.brand_context is not None
        assert "Google" in result.brand_context
        assert len(result.competitor_contexts) == 1

    @pytest.mark.asyncio
    async def test_rate_limiter_called(self):
        mock_limiter = AsyncMock()
        adapter = GeminiAdapter(api_key="test-key", rate_limiter=mock_limiter)
        response = self._ok_response("Test response")
        with patch.object(
            adapter._client, "post", new_callable=AsyncMock, return_value=response
        ):
            await adapter.query("test", brand_name="Google")
        mock_limiter.acquire.assert_awaited()

    @pytest.mark.asyncio
    async def test_api_key_in_url(self):
        """The API key must be carried in the request URL as ``key=...``.

        NOTE(review): the Gemini REST API also accepts the key via an
        ``x-goog-api-key`` header, which keeps it out of URLs and access logs.
        If the adapter switches to that, this test must change with it.
        """
        adapter = GeminiAdapter(api_key="my-secret-key")
        with patch.object(
            adapter._client,
            "post",
            new_callable=AsyncMock,
            return_value=self._ok_response("response"),
        ) as mock_post:
            await adapter.query("test", brand_name="BrandX")

        # Fail with a clear message if no request was sent at all.
        mock_post.assert_awaited()
        # Inspect only the URL (first positional argument or ``url=`` keyword).
        # Stringifying the whole call would also match the key if it appeared
        # in headers or the JSON body, which is not what this test claims.
        call = mock_post.call_args
        url = call.args[0] if call.args else call.kwargs.get("url", "")
        assert "key=my-secret-key" in str(url)

    @pytest.mark.asyncio
    async def test_proxy_support(self):
        adapter = GeminiAdapter(api_key="test-key", proxy="http://proxy:8080")
        assert adapter._proxy == "http://proxy:8080"
class TestAdapterInheritance:
    """Structural checks on the adapter classes themselves; no instances are built."""

    @staticmethod
    def _assert_callable_attr(adapter_cls, attr_name):
        """Assert that *adapter_cls* has *attr_name* and that it is callable."""
        assert hasattr(adapter_cls, attr_name)
        assert callable(getattr(adapter_cls, attr_name))

    def test_qwen_inherits_base(self):
        assert issubclass(QwenAdapter, AIEngineAdapter)

    def test_gemini_inherits_base(self):
        assert issubclass(GeminiAdapter, AIEngineAdapter)

    def test_qwen_has_query_method(self):
        self._assert_callable_attr(QwenAdapter, "query")

    def test_gemini_has_query_method(self):
        self._assert_callable_attr(GeminiAdapter, "query")

    def test_qwen_has_detect_citations(self):
        # Only presence is checked here, not callability.
        assert hasattr(QwenAdapter, "_detect_citations")

    def test_gemini_has_detect_citations(self):
        assert hasattr(GeminiAdapter, "_detect_citations")