# geo/backend/tests/test_services/test_yuanbao_adapter.py
#
# 228 lines
# 8.2 KiB
# Python

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.services.ai_engine.base import (
AIEngineAdapter,
AIQueryResult,
EngineType,
)
from app.services.ai_engine.yuanbao import YuanbaoAdapter
class TestYuanbaoAdapterInitialization:
    """Constructor behaviour of ``YuanbaoAdapter``.

    Covers explicit keyword arguments (API key, model, base URL) and
    configuration supplied through ``HUNYUAN_*`` environment variables.
    """

    @pytest.mark.asyncio
    async def test_initialization_with_api_key(self, monkeypatch):
        """An explicit API key is stored; model and base URL use defaults."""
        # The adapter can be configured through HUNYUAN_* environment
        # variables (see test_initialization_with_env_vars below), so remove
        # any ambient overrides; otherwise the default-value assertions
        # depend on the machine the suite happens to run on.
        monkeypatch.delenv("HUNYUAN_MODEL", raising=False)
        monkeypatch.delenv("HUNYUAN_BASE_URL", raising=False)

        adapter = YuanbaoAdapter(api_key="test-hunyuan-key")

        assert adapter.api_key == "test-hunyuan-key"
        assert adapter._model == "hunyuan-lite"
        assert adapter._base_url == "https://api.hunyuan.cloud.tencent.com/v1"

    @pytest.mark.asyncio
    async def test_initialization_with_custom_model(self):
        """A ``model`` keyword argument is stored on the adapter."""
        adapter = YuanbaoAdapter(api_key="test-key", model="hunyuan-pro")

        assert adapter._model == "hunyuan-pro"

    @pytest.mark.asyncio
    async def test_initialization_with_custom_base_url(self):
        """A ``base_url`` keyword argument is stored on the adapter."""
        adapter = YuanbaoAdapter(
            api_key="test-key",
            base_url="https://custom-api.example.com/v1",
        )

        assert adapter._base_url == "https://custom-api.example.com/v1"

    @pytest.mark.asyncio
    async def test_initialization_with_env_vars(self):
        """With no arguments, key, model and base URL come from the environment."""
        env_overrides = {
            "HUNYUAN_API_KEY": "env-key",
            "HUNYUAN_MODEL": "hunyuan-turbo",
            "HUNYUAN_BASE_URL": "https://env-url.example.com/v1",
        }
        with patch.dict("os.environ", env_overrides):
            adapter = YuanbaoAdapter()

            assert adapter.api_key == "env-key"
            assert adapter._model == "hunyuan-turbo"
            assert adapter._base_url == "https://env-url.example.com/v1"
class TestYuanbaoAdapterEngineType:
    """Engine-type identity reported by the adapter and the enum member."""

    def test_get_engine_type_returns_yuanbao(self):
        """``get_engine_type`` returns ``EngineType.YUANBAO``."""
        reported = YuanbaoAdapter(api_key="test-key").get_engine_type()

        assert reported == EngineType.YUANBAO

    def test_engine_type_value(self):
        """The YUANBAO member compares equal to the string ``"yuanbao"``."""
        member = EngineType.YUANBAO

        assert member == "yuanbao"

    def test_engine_type_in_enum(self):
        """The YUANBAO member is contained in ``EngineType``."""
        member = EngineType.YUANBAO

        assert member in EngineType
class TestYuanbaoAdapterQuery:
    """``YuanbaoAdapter.query`` exercised against stubbed transport layers.

    Most tests replace ``_request_with_retry`` with a stub returning a canned
    chat-completion payload; the rate-limiter test stubs ``_client.post``
    instead.
    """

    @pytest.mark.asyncio
    async def test_query_returns_ai_query_result(self):
        """A payload mentioning the brand yields a populated ``AIQueryResult``."""
        adapter = YuanbaoAdapter(api_key="test-key")
        payload = {
            "choices": [{"message": {"content": "华为是全球领先的ICT基础设施和智能终端提供商"}}],
            "usage": {"total_tokens": 100},
            "model": "hunyuan-lite",
        }
        stub = patch.object(adapter, "_request_with_retry", return_value=payload)

        with stub:
            result = await adapter.query("华为公司", brand_name="华为")

        assert isinstance(result, AIQueryResult)
        assert result.engine_type == EngineType.YUANBAO
        assert "华为" in result.raw_response
        assert result.has_brand_citation is True
        assert result.metadata.get("model") == "hunyuan-lite"

    @pytest.mark.asyncio
    async def test_query_with_competitor_detection(self):
        """Brand and competitor both present in the reply are both flagged."""
        adapter = YuanbaoAdapter(api_key="test-key")
        payload = {
            "choices": [{"message": {"content": "华为和小米都是中国知名手机品牌"}}],
            "model": "hunyuan-lite",
        }
        stub = patch.object(adapter, "_request_with_retry", return_value=payload)

        with stub:
            result = await adapter.query(
                "手机品牌",
                brand_name="华为",
                competitor_names=["小米"],
            )

        assert result.has_brand_citation is True
        assert result.has_competitor_citation is True
        assert len(result.competitor_contexts) > 0

    @pytest.mark.asyncio
    async def test_query_no_brand_found(self):
        """A reply without the brand yields no citations and empty contexts."""
        adapter = YuanbaoAdapter(api_key="test-key")
        payload = {
            "choices": [{"message": {"content": "今天天气很好"}}],
            "model": "hunyuan-lite",
        }
        stub = patch.object(adapter, "_request_with_retry", return_value=payload)

        with stub:
            result = await adapter.query("天气", brand_name="华为")

        assert result.has_brand_citation is False
        assert result.has_competitor_citation is False
        assert result.brand_context is None
        assert result.competitor_contexts == []

    @pytest.mark.asyncio
    async def test_query_records_response_time(self):
        """The result carries a non-negative response time and a timestamp."""
        adapter = YuanbaoAdapter(api_key="test-key")
        payload = {
            "choices": [{"message": {"content": "测试回复"}}],
            "model": "hunyuan-lite",
        }
        stub = patch.object(adapter, "_request_with_retry", return_value=payload)

        with stub:
            result = await adapter.query("测试", brand_name="华为")

        assert result.response_time_ms >= 0
        assert result.timestamp is not None

    @pytest.mark.asyncio
    async def test_query_with_rate_limiter(self):
        """A supplied rate limiter has ``acquire`` awaited during a query."""
        limiter = AsyncMock()
        adapter = YuanbaoAdapter(api_key="test-key", rate_limiter=limiter)
        payload = {
            "choices": [{"message": {"content": "测试回复"}}],
            "model": "hunyuan-lite",
        }
        # Stand-in for the HTTP response object returned by ``_client.post``.
        http_response = MagicMock(status_code=200)
        http_response.json.return_value = payload
        post_stub = patch.object(
            adapter._client,
            "post",
            new_callable=AsyncMock,
            return_value=http_response,
        )

        with post_stub:
            await adapter.query("测试", brand_name="华为")

        limiter.acquire.assert_awaited()
class TestYuanbaoAdapterErrorHandling:
    """Exceptions raised by ``_request_with_retry`` surface from ``query``."""

    @pytest.mark.asyncio
    async def test_api_error_handling(self):
        """An exception mentioning HTTP 500 propagates out of ``query``."""
        adapter = YuanbaoAdapter(api_key="test-key")
        failing_request = patch.object(
            adapter,
            "_request_with_retry",
            side_effect=Exception("HTTP 500: Internal Server Error"),
        )

        with failing_request, pytest.raises(Exception, match="HTTP 500"):
            await adapter.query("测试问题", brand_name="华为")

    @pytest.mark.asyncio
    async def test_rate_limit_handling(self):
        """An exception mentioning 429 propagates out of ``query``."""
        adapter = YuanbaoAdapter(api_key="test-key")
        failing_request = patch.object(
            adapter,
            "_request_with_retry",
            side_effect=Exception("HTTP 429: rate limited"),
        )

        with failing_request, pytest.raises(Exception, match="429"):
            await adapter.query("测试问题", brand_name="华为")
class TestYuanbaoAdapterChineseBrandDetection:
    """``_detect_citations`` applied to Chinese-language text.

    The helper is unpacked here as a 4-tuple:
    (brand found, competitor found, brand context, competitor contexts).
    """

    def test_chinese_brand_name_detection(self):
        """A Chinese brand name in the text is found, with a context containing it."""
        text = "华为是全球领先的ICT基础设施和智能终端提供商"
        adapter = YuanbaoAdapter(api_key="test-key")

        brand_found, _, brand_context, _ = adapter._detect_citations(
            text,
            brand_name="华为",
            competitor_names=None,
        )

        assert brand_found is True
        assert brand_context is not None
        assert "华为" in brand_context

    def test_chinese_competitor_detection(self):
        """Brand and competitor appearing together are both detected."""
        text = "华为和小米都是中国知名手机品牌"
        adapter = YuanbaoAdapter(api_key="test-key")

        detection = adapter._detect_citations(
            text,
            brand_name="华为",
            competitor_names=["小米"],
        )
        brand_found, competitor_found, brand_context, competitor_contexts = detection

        assert brand_found is True
        assert competitor_found is True
        assert brand_context is not None
        assert len(competitor_contexts) > 0

    def test_no_match_chinese(self):
        """Text mentioning neither name yields no citations and no contexts."""
        text = "今天天气很好"
        adapter = YuanbaoAdapter(api_key="test-key")

        detection = adapter._detect_citations(
            text,
            brand_name="华为",
            competitor_names=["小米"],
        )
        brand_found, competitor_found, brand_context, competitor_contexts = detection

        assert brand_found is False
        assert competitor_found is False
        assert brand_context is None
        assert len(competitor_contexts) == 0
class TestYuanbaoAdapterInheritance:
    """Structural checks: base class, expected attributes, context manager."""

    def test_inherits_from_base(self):
        """``YuanbaoAdapter`` is a subclass of ``AIEngineAdapter``."""
        assert issubclass(YuanbaoAdapter, AIEngineAdapter)

    def test_has_query_method(self):
        """The class exposes a callable ``query`` attribute."""
        assert hasattr(YuanbaoAdapter, "query")
        query_attr = YuanbaoAdapter.query
        assert callable(query_attr)

    def test_has_detect_citations(self):
        """The class exposes a ``_detect_citations`` attribute."""
        assert hasattr(YuanbaoAdapter, "_detect_citations")

    def test_has_get_engine_type(self):
        """``get_engine_type`` returns a member of ``EngineType``."""
        engine_type = YuanbaoAdapter(api_key="test-key").get_engine_type()

        assert engine_type in EngineType

    @pytest.mark.asyncio
    async def test_context_manager(self):
        """``async with adapter`` binds the adapter itself."""
        adapter = YuanbaoAdapter(api_key="test-key")

        async with adapter as entered:
            assert entered is adapter