228 lines
8.2 KiB
Python
228 lines
8.2 KiB
Python
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from app.services.ai_engine.base import (
|
|
AIEngineAdapter,
|
|
AIQueryResult,
|
|
EngineType,
|
|
)
|
|
from app.services.ai_engine.yuanbao import YuanbaoAdapter
|
|
|
|
|
|
class TestYuanbaoAdapterInitialization:
    """Constructor behaviour: explicit arguments, built-in defaults, env-var fallback."""

    @pytest.fixture(autouse=True)
    def _isolate_hunyuan_env(self, monkeypatch):
        """Remove the HUNYUAN_* variables for the duration of each test.

        ``test_initialization_with_env_vars`` shows the constructor falls back to
        these variables when arguments are omitted, so without this fixture a
        shell or CI job that exports e.g. ``HUNYUAN_MODEL`` would make the
        default-value assertions below fail. ``raising=False`` tolerates
        variables that are not set; monkeypatch restores them after the test.
        """
        for var in ("HUNYUAN_API_KEY", "HUNYUAN_MODEL", "HUNYUAN_BASE_URL"):
            monkeypatch.delenv(var, raising=False)

    # Nothing below awaits anything, and the other test classes in this module
    # already construct the adapter outside an event loop, so these are plain
    # synchronous tests (no asyncio marker needed).

    def test_initialization_with_api_key(self):
        """An explicit key is stored; model and base URL use the built-in defaults."""
        adapter = YuanbaoAdapter(api_key="test-hunyuan-key")
        assert adapter.api_key == "test-hunyuan-key"
        assert adapter._model == "hunyuan-lite"
        assert adapter._base_url == "https://api.hunyuan.cloud.tencent.com/v1"

    def test_initialization_with_custom_model(self):
        """An explicit ``model`` argument overrides the default model."""
        adapter = YuanbaoAdapter(api_key="test-key", model="hunyuan-pro")
        assert adapter._model == "hunyuan-pro"

    def test_initialization_with_custom_base_url(self):
        """An explicit ``base_url`` argument overrides the default endpoint."""
        adapter = YuanbaoAdapter(
            api_key="test-key",
            base_url="https://custom-api.example.com/v1",
        )
        assert adapter._base_url == "https://custom-api.example.com/v1"

    def test_initialization_with_env_vars(self):
        """With no arguments, key/model/base URL are taken from HUNYUAN_* env vars."""
        with patch.dict("os.environ", {
            "HUNYUAN_API_KEY": "env-key",
            "HUNYUAN_MODEL": "hunyuan-turbo",
            "HUNYUAN_BASE_URL": "https://env-url.example.com/v1",
        }):
            adapter = YuanbaoAdapter()
            assert adapter.api_key == "env-key"
            assert adapter._model == "hunyuan-turbo"
            assert adapter._base_url == "https://env-url.example.com/v1"
|
|
|
|
|
|
class TestYuanbaoAdapterEngineType:
    """The adapter identifies itself with the YUANBAO engine-type member."""

    def test_get_engine_type_returns_yuanbao(self):
        """get_engine_type() reports EngineType.YUANBAO."""
        reported = YuanbaoAdapter(api_key="test-key").get_engine_type()
        assert reported == EngineType.YUANBAO

    def test_engine_type_value(self):
        """The YUANBAO member compares equal to the string "yuanbao"."""
        member = EngineType.YUANBAO
        assert member == "yuanbao"

    def test_engine_type_in_enum(self):
        """YUANBAO is a member of the EngineType enumeration."""
        member = EngineType.YUANBAO
        assert member in EngineType
|
|
|
|
|
|
class TestYuanbaoAdapterQuery:
    """query() behaviour with the network layer stubbed out."""

    # Every test in this class is a coroutine.
    pytestmark = pytest.mark.asyncio

    @staticmethod
    def _chat_payload(content, **extra):
        """Return a canned response body whose single choice carries *content*.

        ``extra`` adds further top-level keys (e.g. ``usage``).
        """
        return {
            "choices": [{"message": {"content": content}}],
            "model": "hunyuan-lite",
            **extra,
        }

    async def test_query_returns_ai_query_result(self):
        """A successful reply is wrapped in an AIQueryResult with brand detection and model metadata."""
        adapter = YuanbaoAdapter(api_key="test-key")
        canned = self._chat_payload(
            "华为是全球领先的ICT基础设施和智能终端提供商",
            usage={"total_tokens": 100},
        )

        with patch.object(adapter, "_request_with_retry", return_value=canned):
            outcome = await adapter.query("华为公司", brand_name="华为")

        assert isinstance(outcome, AIQueryResult)
        assert outcome.engine_type == EngineType.YUANBAO
        assert "华为" in outcome.raw_response
        assert outcome.has_brand_citation is True
        assert outcome.metadata.get("model") == "hunyuan-lite"

    async def test_query_with_competitor_detection(self):
        """Both the brand and a listed competitor are flagged when the reply names them."""
        adapter = YuanbaoAdapter(api_key="test-key")
        canned = self._chat_payload("华为和小米都是中国知名手机品牌")

        with patch.object(adapter, "_request_with_retry", return_value=canned):
            outcome = await adapter.query(
                "手机品牌",
                brand_name="华为",
                competitor_names=["小米"],
            )

        assert outcome.has_brand_citation is True
        assert outcome.has_competitor_citation is True
        assert len(outcome.competitor_contexts) > 0

    async def test_query_no_brand_found(self):
        """A reply that mentions neither brand yields empty citation fields."""
        adapter = YuanbaoAdapter(api_key="test-key")
        canned = self._chat_payload("今天天气很好")

        with patch.object(adapter, "_request_with_retry", return_value=canned):
            outcome = await adapter.query("天气", brand_name="华为")

        assert outcome.has_brand_citation is False
        assert outcome.has_competitor_citation is False
        assert outcome.brand_context is None
        assert outcome.competitor_contexts == []

    async def test_query_records_response_time(self):
        """Timing and timestamp fields are populated on the result."""
        adapter = YuanbaoAdapter(api_key="test-key")
        canned = self._chat_payload("测试回复")

        with patch.object(adapter, "_request_with_retry", return_value=canned):
            outcome = await adapter.query("测试", brand_name="华为")

        assert outcome.response_time_ms >= 0
        assert outcome.timestamp is not None

    async def test_query_with_rate_limiter(self):
        """A supplied rate limiter's acquire() is awaited during query().

        Only ``_client.post`` is patched here (not ``_request_with_retry``) so the
        adapter's own request path, including the limiter call, actually runs.
        """
        limiter = AsyncMock()
        adapter = YuanbaoAdapter(api_key="test-key", rate_limiter=limiter)

        fake_http_response = MagicMock()
        fake_http_response.status_code = 200
        fake_http_response.json.return_value = self._chat_payload("测试回复")

        with patch.object(
            adapter._client, "post", new_callable=AsyncMock, return_value=fake_http_response
        ):
            await adapter.query("测试", brand_name="华为")

        limiter.acquire.assert_awaited()
|
|
|
|
|
|
class TestYuanbaoAdapterErrorHandling:
    """Exceptions raised by the request layer must surface from query()."""

    # Every test in this class is a coroutine.
    pytestmark = pytest.mark.asyncio

    @staticmethod
    async def _expect_propagated(failure_message, pattern):
        """Make ``_request_with_retry`` raise and assert query() re-raises.

        The patch replaces ``_request_with_retry`` wholesale, so any retry or
        back-off logic inside it is NOT exercised here; these tests only check
        that an error from that layer reaches the caller with a matching message.
        """
        adapter = YuanbaoAdapter(api_key="test-key")
        failing_request = patch.object(
            adapter,
            "_request_with_retry",
            side_effect=Exception(failure_message),
        )
        with failing_request, pytest.raises(Exception, match=pattern):
            await adapter.query("测试问题", brand_name="华为")

    async def test_api_error_handling(self):
        """An HTTP 500 failure propagates out of query()."""
        await self._expect_propagated("HTTP 500: Internal Server Error", "HTTP 500")

    async def test_rate_limit_handling(self):
        """An HTTP 429 failure propagates out of query()."""
        await self._expect_propagated("HTTP 429: rate limited", "429")
|
|
|
|
|
|
class TestYuanbaoAdapterChineseBrandDetection:
    """_detect_citations() on Chinese-language text, tracking the brand "华为"."""

    @staticmethod
    def _detect(text, competitors):
        """Run _detect_citations on a fresh adapter with brand "华为".

        Returns the 4-tuple (has_brand, has_competitor, brand_context,
        competitor_contexts) exactly as the adapter produces it.
        """
        adapter = YuanbaoAdapter(api_key="test-key")
        return adapter._detect_citations(
            text,
            brand_name="华为",
            competitor_names=competitors,
        )

    def test_chinese_brand_name_detection(self):
        """The brand is found and its context snippet contains the brand name."""
        found_brand, _found_comp, brand_snippet, _comp_snippets = self._detect(
            "华为是全球领先的ICT基础设施和智能终端提供商", None
        )
        assert found_brand is True
        assert brand_snippet is not None
        assert "华为" in brand_snippet

    def test_chinese_competitor_detection(self):
        """Brand and competitor are both detected when the text names both."""
        found_brand, found_comp, brand_snippet, comp_snippets = self._detect(
            "华为和小米都是中国知名手机品牌", ["小米"]
        )
        assert found_brand is True
        assert found_comp is True
        assert brand_snippet is not None
        assert len(comp_snippets) > 0

    def test_no_match_chinese(self):
        """Text naming neither party yields no flags and no context."""
        found_brand, found_comp, brand_snippet, comp_snippets = self._detect(
            "今天天气很好", ["小米"]
        )
        assert found_brand is False
        assert found_comp is False
        assert brand_snippet is None
        assert len(comp_snippets) == 0
|
|
|
|
|
|
class TestYuanbaoAdapterInheritance:
    """YuanbaoAdapter satisfies the AIEngineAdapter interface."""

    def test_inherits_from_base(self):
        """The adapter is a subclass of the shared base adapter."""
        assert issubclass(YuanbaoAdapter, AIEngineAdapter)

    def test_has_query_method(self):
        """``query`` exists on the class and is callable."""
        assert hasattr(YuanbaoAdapter, "query")
        query_attr = YuanbaoAdapter.query
        assert callable(query_attr)

    def test_has_detect_citations(self):
        """The citation-detection helper is present on the class."""
        assert hasattr(YuanbaoAdapter, "_detect_citations")

    def test_has_get_engine_type(self):
        """get_engine_type() returns a member of EngineType."""
        engine = YuanbaoAdapter(api_key="test-key").get_engine_type()
        assert engine in EngineType

    @pytest.mark.asyncio
    async def test_context_manager(self):
        """``async with`` yields the adapter instance itself."""
        adapter = YuanbaoAdapter(api_key="test-key")
        async with adapter as entered:
            assert entered is adapter
|