# File-viewer header captured with this source: 278 lines · 11 KiB · Python
import pytest
|
|
from unittest.mock import AsyncMock, Mock, patch, MagicMock
|
|
|
|
from app.services.ai_engine.base import (
|
|
AIEngineAdapter,
|
|
AIQueryResult,
|
|
CitationInfo,
|
|
EngineType,
|
|
)
|
|
from app.services.ai_engine.kimi import KimiAdapter
|
|
from app.services.ai_engine.wenxin import WenxinAdapter
|
|
from app.services.ai_engine.doubao import DoubaoAdapter
|
|
from app.services.ai_engine.yuanbao import YuanbaoAdapter
|
|
|
|
|
|
def _make_mock_response(status_code=200, json_data=None, text="", headers=None):
|
|
mock_resp = Mock()
|
|
mock_resp.status_code = status_code
|
|
mock_resp.json.return_value = json_data or {}
|
|
mock_resp.text = text
|
|
mock_resp.headers = headers or {}
|
|
return mock_resp
|
|
|
|
|
|
class TestKimiAdapter:
    """Tests for ``KimiAdapter``.

    The query tests patch ``_request_with_retry`` out entirely. They therefore
    check how ``query`` turns a response payload into an ``AIQueryResult``, and
    that it lets errors raised by that layer propagate. The HTTP/retry code
    itself is not exercised here.
    """

    def test_initialization(self):
        # Nothing here awaits, so this is a plain synchronous test rather than
        # an ``asyncio``-marked coroutine.
        adapter = KimiAdapter(api_key="test-key")

        assert adapter.api_key == "test-key"
        assert adapter.get_engine_type() == EngineType.KIMI

    @pytest.mark.asyncio
    async def test_query_returns_ai_query_result(self):
        adapter = KimiAdapter(api_key="test-key")
        # Chat-completion shaped payload: choices[0].message.content holds the
        # answer text, alongside usage and model fields.
        mock_response_data = {
            "choices": [{"message": {"content": "华为是全球领先的ICT公司"}}],
            "usage": {"total_tokens": 100},
            "model": "moonshot-v1-8k",
        }

        # If ``_request_with_retry`` is an async method, ``patch.object``
        # substitutes an AsyncMock, so awaiting it yields this payload.
        with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
            result = await adapter.query("华为公司", brand_name="华为")

        assert isinstance(result, AIQueryResult)
        assert result.engine_type == EngineType.KIMI
        assert "华为" in result.raw_response
        assert result.has_brand_citation is True
        assert result.metadata.get("model") == "moonshot-v1-8k"

    @pytest.mark.asyncio
    async def test_api_error_handling(self):
        adapter = KimiAdapter(api_key="test-key")

        with patch.object(
            adapter,
            "_request_with_retry",
            side_effect=Exception("HTTP 500: Internal Server Error"),
        ):
            with pytest.raises(Exception, match="HTTP 500"):
                await adapter.query("测试问题", brand_name="华为")

    @pytest.mark.asyncio
    async def test_rate_limit_handling(self):
        # NOTE(review): ``_request_with_retry`` is mocked out, so this only
        # proves that a 429-style error raised by that layer reaches the
        # caller. Whatever retry/back-off handling lives in that method is
        # bypassed and untested here.
        adapter = KimiAdapter(api_key="test-key")

        with patch.object(
            adapter,
            "_request_with_retry",
            side_effect=Exception("HTTP 429: rate limited"),
        ):
            with pytest.raises(Exception, match="429"):
                await adapter.query("测试问题", brand_name="华为")


class TestWenxinAdapter:
    """Tests for ``WenxinAdapter``.

    The ``app.services.ai_engine.wenxin`` module keeps module-level token
    state (``_cached_token`` / ``_token_expires_at``). Every test here starts
    with that state cleared, so the first mocked ``post`` is consumed as the
    access-token request. The state is restored afterwards so nothing leaks
    into other tests.
    """

    @pytest.fixture(autouse=True)
    def _reset_token_cache(self, monkeypatch):
        """Clear the module-level token cache for the duration of one test."""
        import app.services.ai_engine.wenxin as wenxin_mod

        # monkeypatch undoes these at teardown. The previous plain assignments
        # left whatever the test produced (presumably the fake token) in the
        # module globals. monkeypatch also raises AttributeError if the
        # attributes are ever renamed, instead of silently creating new ones.
        monkeypatch.setattr(wenxin_mod, "_cached_token", None)
        monkeypatch.setattr(wenxin_mod, "_token_expires_at", 0.0)

    @staticmethod
    def _token_response(token):
        """Return a successful mocked access-token response carrying *token*."""
        return _make_mock_response(200, {"access_token": token, "expires_in": 2592000})

    @staticmethod
    def _client_with_responses(*responses):
        """Return an AsyncMock client whose successive ``post`` calls yield *responses* in order."""
        client = AsyncMock()
        client.post.side_effect = list(responses)
        return client

    def test_initialization(self):
        # Nothing here awaits, so this is a plain synchronous test.
        adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")

        assert adapter.api_key == "test-key"
        assert adapter.secret_key == "test-secret"
        assert adapter.get_engine_type() == EngineType.WENXIN

    @pytest.mark.asyncio
    async def test_query_returns_ai_query_result(self):
        adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
        mock_chat_data = {"result": "华为是一家全球领先的科技公司", "usage": {"total_tokens": 100}}
        # First post -> token fetch, second post -> chat completion.
        mock_client = self._client_with_responses(
            self._token_response("test-access-token"),
            _make_mock_response(200, mock_chat_data),
        )

        with patch.object(adapter, "_client", mock_client):
            result = await adapter.query("华为公司", brand_name="华为")

        assert isinstance(result, AIQueryResult)
        assert result.engine_type == EngineType.WENXIN
        assert "华为" in result.raw_response
        assert result.has_brand_citation is True

    @pytest.mark.asyncio
    async def test_api_error_handling(self):
        adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
        mock_client = self._client_with_responses(
            self._token_response("test-token"),
            _make_mock_response(500, text="Internal Server Error"),
        )

        with patch.object(adapter, "_client", mock_client):
            with pytest.raises(RuntimeError, match="文心"):
                await adapter.query("测试问题", brand_name="华为")

    @pytest.mark.asyncio
    async def test_rate_limit_handling(self):
        adapter = WenxinAdapter(api_key="test-key", secret_key="test-secret")
        mock_client = self._client_with_responses(
            self._token_response("test-token"),
            _make_mock_response(429, headers={"Retry-After": "1"}),
        )

        with patch.object(adapter, "_client", mock_client):
            with pytest.raises(RuntimeError, match="限流"):
                await adapter.query("测试问题", brand_name="华为")


class TestDoubaoAdapter:
    """Tests for ``DoubaoAdapter``.

    As with the Kimi tests, ``_request_with_retry`` is patched out. The query
    tests therefore cover payload parsing and error propagation in ``query``,
    not the HTTP/retry layer.
    """

    def test_initialization(self):
        # Nothing here awaits, so this is a plain synchronous test.
        adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")

        assert adapter.api_key == "test-key"
        # NOTE(review): asserts on a private attribute; no public accessor for
        # the endpoint id is visible from this file.
        assert adapter._endpoint_id == "ep-test"
        assert adapter.get_engine_type() == EngineType.DOUBAO

    @pytest.mark.asyncio
    async def test_query_returns_ai_query_result(self):
        adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")
        mock_response_data = {
            "choices": [{"message": {"content": "华为是全球知名企业"}}],
            "model": "ep-test",
        }

        with patch.object(adapter, "_request_with_retry", return_value=mock_response_data):
            result = await adapter.query("华为公司", brand_name="华为")

        assert isinstance(result, AIQueryResult)
        assert result.engine_type == EngineType.DOUBAO
        assert "华为" in result.raw_response
        assert result.has_brand_citation is True

    @pytest.mark.asyncio
    async def test_api_error_handling(self):
        adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")

        with patch.object(
            adapter,
            "_request_with_retry",
            side_effect=Exception("HTTP 500: Internal Server Error"),
        ):
            with pytest.raises(Exception, match="HTTP 500"):
                await adapter.query("测试问题", brand_name="华为")

    @pytest.mark.asyncio
    async def test_rate_limit_handling(self):
        # NOTE(review): only checks that a 429-style error from the mocked
        # ``_request_with_retry`` propagates. Any retry/back-off logic inside
        # that method is bypassed here.
        adapter = DoubaoAdapter(api_key="test-key", endpoint_id="ep-test")

        with patch.object(
            adapter,
            "_request_with_retry",
            side_effect=Exception("HTTP 429: rate limited"),
        ):
            with pytest.raises(Exception, match="429"):
                await adapter.query("测试问题", brand_name="华为")


class TestEngineType:
    """Each adapter must report its own ``EngineType`` member and string value."""

    @staticmethod
    def _check(adapter, member, value):
        """Assert that *adapter* reports *member*, whose ``.value`` is *value*."""
        reported = adapter.get_engine_type()
        assert reported == member
        assert reported.value == value

    def test_kimi_engine_type(self):
        self._check(KimiAdapter(api_key="test-key"), EngineType.KIMI, "kimi")

    def test_wenxin_engine_type(self):
        self._check(
            WenxinAdapter(api_key="test-key", secret_key="test-secret"),
            EngineType.WENXIN,
            "wenxin",
        )

    def test_doubao_engine_type(self):
        self._check(
            DoubaoAdapter(api_key="test-key", endpoint_id="ep-test"),
            EngineType.DOUBAO,
            "doubao",
        )

    def test_yuanbao_engine_type(self):
        self._check(YuanbaoAdapter(api_key="test-key"), EngineType.YUANBAO, "yuanbao")


class TestChineseCitationDetection:
    """Tests for ``_detect_citations`` on Chinese and mixed-case text.

    ``_detect_citations`` returns a 4-tuple
    ``(has_brand, has_competitor, brand_context, competitor_contexts)``.
    """

    def test_brand_name_detection_chinese(self):
        adapter = KimiAdapter(api_key="test-key")

        has_brand, _has_comp, brand_ctx, _comp_ctx = adapter._detect_citations(
            "华为是全球领先的ICT基础设施和智能终端提供商",
            brand_name="华为",
            competitor_names=None,
        )

        assert has_brand is True
        assert brand_ctx is not None
        assert "华为" in brand_ctx

    def test_brand_name_detection_case_insensitive(self):
        adapter = KimiAdapter(api_key="test-key")

        # The brand appears ONLY with different casing ("Apple" vs "apple").
        # The previous text also contained a lowercase "apple", so a purely
        # case-sensitive search would have passed and the test proved nothing.
        has_brand, _, _, _ = adapter._detect_citations(
            "Apple is a great company that makes phones",
            brand_name="apple",
            competitor_names=None,
        )

        assert has_brand is True

    def test_competitor_name_detection(self):
        adapter = KimiAdapter(api_key="test-key")

        has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
            "华为和小米都是中国知名手机品牌",
            brand_name="华为",
            competitor_names=["小米"],
        )

        assert has_brand is True
        assert has_comp is True
        assert brand_ctx is not None
        assert len(comp_ctx) > 0

    def test_no_citations_when_no_match(self):
        adapter = KimiAdapter(api_key="test-key")

        has_brand, has_comp, brand_ctx, comp_ctx = adapter._detect_citations(
            "今天天气很好",
            brand_name="华为",
            competitor_names=["小米"],
        )

        assert has_brand is False
        assert has_comp is False
        assert brand_ctx is None
        assert len(comp_ctx) == 0


class TestAdapterInheritance:
    """Structural checks that every engine adapter implements the base interface."""

    @staticmethod
    def _adapter_cases():
        """Return ``(adapter_class, constructor_kwargs)`` for every engine adapter.

        The list is kept in one place so the checks below cannot drift apart.
        The kwargs mirror how the other test classes in this module construct
        each adapter: Wenxin gets a secret_key, and Doubao gets the endpoint_id
        it is given everywhere else here.
        """
        return (
            (KimiAdapter, {"api_key": "test-key"}),
            (WenxinAdapter, {"api_key": "test-key", "secret_key": "test-secret"}),
            (DoubaoAdapter, {"api_key": "test-key", "endpoint_id": "ep-test"}),
            (YuanbaoAdapter, {"api_key": "test-key"}),
        )

    def test_all_adapters_inherit_base(self):
        for cls, _ in self._adapter_cases():
            assert issubclass(cls, AIEngineAdapter), cls.__name__

    def test_all_adapters_have_query_method(self):
        for cls, _ in self._adapter_cases():
            # A missing attribute yields None, which is not callable.
            assert callable(getattr(cls, "query", None)), cls.__name__

    def test_all_adapters_have_detect_citations(self):
        for cls, _ in self._adapter_cases():
            assert hasattr(cls, "_detect_citations"), cls.__name__

    def test_all_adapters_have_get_engine_type(self):
        for cls, kwargs in self._adapter_cases():
            instance = cls(**kwargs)
            # isinstance is stricter than ``in EngineType``. On Python 3.12+,
            # ``"kimi" in EngineType`` is True for a raw value; on 3.8-3.11 a
            # non-member raises TypeError instead of failing the assertion.
            assert isinstance(instance.get_engine_type(), EngineType), cls.__name__