geo/backend/tests/test_services/test_ai_engine_chinese.py

278 lines
11 KiB
Python

import pytest
from unittest.mock import AsyncMock, Mock, patch, MagicMock
from app.services.ai_engine.base import (
AIEngineAdapter,
AIQueryResult,
CitationInfo,
EngineType,
)
from app.services.ai_engine.kimi import KimiAdapter
from app.services.ai_engine.wenxin import WenxinAdapter
from app.services.ai_engine.doubao import DoubaoAdapter
from app.services.ai_engine.yuanbao import YuanbaoAdapter
def _make_mock_response(status_code=200, json_data=None, text="", headers=None):
mock_resp = Mock()
mock_resp.status_code = status_code
mock_resp.json.return_value = json_data or {}
mock_resp.text = text
mock_resp.headers = headers or {}
return mock_resp
class TestKimiAdapter:
    """Tests for ``KimiAdapter`` with ``_request_with_retry`` patched out."""

    @staticmethod
    def _new_adapter():
        # All tests share the same dummy credential.
        return KimiAdapter(api_key="test-key")

    async def _assert_query_raises(self, failure_message, pattern):
        """Make the request layer raise and check ``query`` lets it propagate."""
        kimi = self._new_adapter()
        boom = Exception(failure_message)
        with patch.object(kimi, "_request_with_retry", side_effect=boom):
            with pytest.raises(Exception, match=pattern):
                await kimi.query("测试问题", brand_name="华为")

    @pytest.mark.asyncio
    async def test_initialization(self):
        kimi = self._new_adapter()
        assert kimi.api_key == "test-key"
        assert kimi.get_engine_type() == EngineType.KIMI

    @pytest.mark.asyncio
    async def test_query_returns_ai_query_result(self):
        kimi = self._new_adapter()
        canned_payload = {
            "choices": [{"message": {"content": "华为是全球领先的ICT公司"}}],
            "usage": {"total_tokens": 100},
            "model": "moonshot-v1-8k",
        }
        with patch.object(kimi, "_request_with_retry", return_value=canned_payload):
            outcome = await kimi.query("华为公司", brand_name="华为")

        assert isinstance(outcome, AIQueryResult)
        assert outcome.engine_type == EngineType.KIMI
        assert "华为" in outcome.raw_response
        assert outcome.has_brand_citation is True
        assert outcome.metadata.get("model") == "moonshot-v1-8k"

    @pytest.mark.asyncio
    async def test_api_error_handling(self):
        await self._assert_query_raises("HTTP 500: Internal Server Error", "HTTP 500")

    @pytest.mark.asyncio
    async def test_rate_limit_handling(self):
        await self._assert_query_raises("HTTP 429: rate limited", "429")
class TestWenxinAdapter:
    """Tests for ``WenxinAdapter`` with its ``_client`` replaced by an ``AsyncMock``.

    In the query tests the mocked ``post`` returns an access-token response
    first and the chat (or error) response second.
    """

    @pytest.fixture(autouse=True)
    def _reset_token_cache(self, monkeypatch):
        """Start every test with an empty module-level Wenxin token cache.

        ``monkeypatch`` restores the previous values on teardown, so the fake
        token cached during these tests does not leak into other tests.
        """
        import app.services.ai_engine.wenxin as wenxin_mod

        # raising=False mirrors the old plain assignment, which did not
        # require the attributes to exist beforehand.
        monkeypatch.setattr(wenxin_mod, "_cached_token", None, raising=False)
        monkeypatch.setattr(wenxin_mod, "_token_expires_at", 0.0, raising=False)

    @staticmethod
    def _mock_client(*responses):
        """Return an ``AsyncMock`` client whose ``post`` yields *responses* in order."""
        client = AsyncMock()
        client.post.side_effect = list(responses)
        return client

    @pytest.mark.asyncio
    async def test_initialization(self):
        adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
        assert adapter.api_key == "test-key"
        assert adapter.secret_key == "test-secret"
        assert adapter.get_engine_type() == EngineType.WENXIN

    @pytest.mark.asyncio
    async def test_query_returns_ai_query_result(self):
        adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
        mock_token_data = {"access_token": "test-access-token", "expires_in": 2592000}
        mock_chat_data = {"result": "华为是一家全球领先的科技公司", "usage": {"total_tokens": 100}}
        mock_client = self._mock_client(
            _make_mock_response(200, mock_token_data),
            _make_mock_response(200, mock_chat_data),
        )
        with patch.object(adapter, "_client", mock_client):
            result = await adapter.query("华为公司", brand_name="华为")

        assert isinstance(result, AIQueryResult)
        assert result.engine_type == EngineType.WENXIN
        assert "华为" in result.raw_response
        assert result.has_brand_citation is True

    @pytest.mark.asyncio
    async def test_api_error_handling(self):
        adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
        mock_client = self._mock_client(
            _make_mock_response(200, {"access_token": "test-token", "expires_in": 2592000}),
            _make_mock_response(500, text="Internal Server Error"),
        )
        with patch.object(adapter, "_client", mock_client):
            with pytest.raises(RuntimeError, match="文心"):
                await adapter.query("测试问题", brand_name="华为")

    @pytest.mark.asyncio
    async def test_rate_limit_handling(self):
        adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
        mock_client = self._mock_client(
            _make_mock_response(200, {"access_token": "test-token", "expires_in": 2592000}),
            _make_mock_response(429, headers={"Retry-After": "1"}),
        )
        with patch.object(adapter, "_client", mock_client):
            with pytest.raises(RuntimeError, match="限流"):
                await adapter.query("测试问题", brand_name="华为")
class TestDoubaoAdapter:
    """Tests for ``DoubaoAdapter`` with ``_request_with_retry`` patched out."""

    @staticmethod
    def _new_adapter():
        # Dummy credential plus the endpoint id Doubao is constructed with.
        return DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")

    async def _assert_query_raises(self, failure_message, pattern):
        """Make the request layer raise and check ``query`` lets it propagate."""
        doubao = self._new_adapter()
        boom = Exception(failure_message)
        with patch.object(doubao, "_request_with_retry", side_effect=boom):
            with pytest.raises(Exception, match=pattern):
                await doubao.query("测试问题", brand_name="华为")

    @pytest.mark.asyncio
    async def test_initialization(self):
        doubao = self._new_adapter()
        assert doubao.api_key == "test-key"
        assert doubao._endpoint_id == "ep-test"
        assert doubao.get_engine_type() == EngineType.DOUBAO

    @pytest.mark.asyncio
    async def test_query_returns_ai_query_result(self):
        doubao = self._new_adapter()
        canned_payload = {
            "choices": [{"message": {"content": "华为是全球知名企业"}}],
            "model": "ep-test",
        }
        with patch.object(doubao, "_request_with_retry", return_value=canned_payload):
            outcome = await doubao.query("华为公司", brand_name="华为")

        assert isinstance(outcome, AIQueryResult)
        assert outcome.engine_type == EngineType.DOUBAO
        assert "华为" in outcome.raw_response
        assert outcome.has_brand_citation is True

    @pytest.mark.asyncio
    async def test_api_error_handling(self):
        await self._assert_query_raises("HTTP 500: Internal Server Error", "HTTP 500")

    @pytest.mark.asyncio
    async def test_rate_limit_handling(self):
        await self._assert_query_raises("HTTP 429: rate limited", "429")
class TestEngineType:
    """Each adapter reports the expected ``EngineType`` member and string value."""

    @staticmethod
    def _check_engine(adapter, expected_member, expected_value):
        assert adapter.get_engine_type() == expected_member
        assert adapter.get_engine_type().value == expected_value

    def test_kimi_engine_type(self):
        self._check_engine(KimiAdapter(api_key="test-key"), EngineType.KIMI, "kimi")

    def test_wenxin_engine_type(self):
        wenxin = WenxinAdapter(api_key="test-key", secret_key="test-secret")
        self._check_engine(wenxin, EngineType.WENXIN, "wenxin")

    def test_doubao_engine_type(self):
        doubao = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")
        self._check_engine(doubao, EngineType.DOUBAO, "doubao")

    def test_yuanbao_engine_type(self):
        self._check_engine(YuanbaoAdapter(api_key="test-key"), EngineType.YUANBAO, "yuanbao")
class TestChineseCitationDetection:
    """Tests for ``_detect_citations`` on Chinese and mixed-case text.

    The tests unpack its result as a 4-tuple
    ``(has_brand, has_competitor, brand_context, competitor_contexts)``.
    """

    def test_brand_name_detection_chinese(self):
        adapter = KimiAdapter(api_key="test-key")
        has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
            "华为是全球领先的ICT基础设施和智能终端提供商",
            brand_name="华为",
            competitor_names=None,
        )
        assert has_brand is True
        assert brand_ctx is not None
        assert "华为" in brand_ctx

    def test_brand_name_detection_case_insensitive(self):
        adapter = KimiAdapter(api_key="test-key")
        # The text deliberately contains the brand only in casings different
        # from ``brand_name``. The previous text also contained lowercase
        # "apple", so a case-sensitive matcher would have passed too.
        has_brand, _, _, _ = adapter._detect_citations(
            "Apple is a great company and APPLE makes phones",
            brand_name="apple",
            competitor_names=None,
        )
        assert has_brand is True

    def test_competitor_name_detection(self):
        adapter = KimiAdapter(api_key="test-key")
        has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
            "华为和小米都是中国知名手机品牌",
            brand_name="华为",
            competitor_names=["小米"],
        )
        assert has_brand is True
        assert has_comp is True
        assert brand_ctx is not None
        assert len(comp_ctx) > 0

    def test_no_citations_when_no_match(self):
        adapter = KimiAdapter(api_key="test-key")
        has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
            "今天天气很好",
            brand_name="华为",
            competitor_names=["小米"],
        )
        assert has_brand is False
        assert has_comp is False
        assert brand_ctx is None
        assert len(comp_ctx) == 0
class TestAdapterInheritance:
    """Structural checks that every Chinese-engine adapter shares."""

    # Adapter classes under test.
    _ADAPTERS = (KimiAdapter, WenxinAdapter, DoubaoAdapter, YuanbaoAdapter)

    @staticmethod
    def _instantiate(cls):
        """Build *cls* with dummy credentials; Wenxin also gets a ``secret_key``."""
        if cls is WenxinAdapter:
            return cls(api_key="test-key", secret_key="s")
        return cls(api_key="test-key")

    def test_all_adapters_inherit_base(self):
        for cls in self._ADAPTERS:
            assert issubclass(cls, AIEngineAdapter), cls

    def test_all_adapters_have_query_method(self):
        for cls in self._ADAPTERS:
            assert hasattr(cls, "query"), cls
            assert callable(getattr(cls, "query")), cls

    def test_all_adapters_have_detect_citations(self):
        for cls in self._ADAPTERS:
            assert hasattr(cls, "_detect_citations"), cls

    def test_all_adapters_have_get_engine_type(self):
        for cls in self._ADAPTERS:
            instance = self._instantiate(cls)
            # isinstance() gives the same answer on every Python version.
            # ``x in EngineType`` raises TypeError for non-members before
            # 3.12 and can match raw values from 3.12 on.
            assert isinstance(instance.get_engine_type(), EngineType), cls