166 lines
5.6 KiB
Python
166 lines
5.6 KiB
Python
import pytest
|
||
from unittest.mock import AsyncMock, patch, MagicMock
|
||
from app.workers.llm_adapter import LLMAdapter, LLMAdapterError, BRAND_CITATION_PROMPT
|
||
|
||
|
||
class TestLLMAdapter:
    """Tests for ``LLMAdapter``.

    The query tests replace the adapter's ``_call_deepseek`` coroutine with
    an ``AsyncMock`` so that each scenario is driven by a canned payload or a
    scripted sequence of failures.
    """

    @staticmethod
    def _stub_deepseek(adapter, **mock_kwargs):
        """Return a patcher swapping ``adapter._call_deepseek`` for an AsyncMock.

        ``mock_kwargs`` (for example ``return_value`` or ``side_effect``) are
        forwarded to the ``AsyncMock`` constructor by ``patch.object``.
        """
        return patch.object(
            adapter, "_call_deepseek", new_callable=AsyncMock, **mock_kwargs
        )

    @pytest.fixture
    def llm_adapter(self):
        """Provide an ``LLMAdapter`` instance to the requesting test."""
        adapter = LLMAdapter()
        return adapter

    @pytest.mark.asyncio
    async def test_llm_adapter_cited_brand(self, llm_adapter):
        """A payload reporting a citation is exposed field by field on the result."""
        payload = {
            "cited": True,
            "position": 1,
            "citation_text": "XXX是一款非常优秀的品牌产品",
            "sentiment": "positive",
            "confidence": 0.95,
        }

        with self._stub_deepseek(llm_adapter, return_value=payload):
            outcome = await llm_adapter.query_brand_citation(
                keyword="AI搜索",
                brand_name="XXX",
                brand_aliases=["品牌别名1", "品牌别名2"],
            )

        assert outcome.cited is True
        assert outcome.position == 1
        assert outcome.citation_text == "XXX是一款非常优秀的品牌产品"
        assert outcome.sentiment == "positive"
        assert outcome.confidence == 0.95

    @pytest.mark.asyncio
    async def test_llm_adapter_not_cited(self, llm_adapter):
        """A payload reporting no citation yields empty position and text."""
        payload = {
            "cited": False,
            "position": None,
            "citation_text": None,
            "sentiment": "neutral",
            "confidence": 0.90,
        }

        with self._stub_deepseek(llm_adapter, return_value=payload):
            outcome = await llm_adapter.query_brand_citation(
                keyword="AI搜索",
                brand_name="YYY",
                brand_aliases=[],
            )

        assert outcome.cited is False
        assert outcome.position is None
        assert outcome.citation_text is None
        assert outcome.sentiment == "neutral"

    @pytest.mark.asyncio
    async def test_llm_adapter_sentiment_positive(self, llm_adapter):
        """A positive sentiment in the payload is reported on the result."""
        payload = {
            "cited": True,
            "position": 2,
            "citation_text": "YYY品牌产品质量非常好,用户口碑极佳",
            "sentiment": "positive",
            "confidence": 0.92,
        }

        with self._stub_deepseek(llm_adapter, return_value=payload):
            outcome = await llm_adapter.query_brand_citation(
                keyword="AI搜索",
                brand_name="YYY",
                brand_aliases=[],
            )

        assert outcome.sentiment == "positive"

    @pytest.mark.asyncio
    async def test_llm_adapter_sentiment_negative(self, llm_adapter):
        """A negative sentiment in the payload is reported on the result."""
        payload = {
            "cited": True,
            "position": 3,
            "citation_text": "ZZZ品牌存在质量问题,遭到用户投诉",
            "sentiment": "negative",
            "confidence": 0.88,
        }

        with self._stub_deepseek(llm_adapter, return_value=payload):
            outcome = await llm_adapter.query_brand_citation(
                keyword="AI搜索",
                brand_name="ZZZ",
                brand_aliases=[],
            )

        assert outcome.sentiment == "negative"

    @pytest.mark.asyncio
    async def test_llm_adapter_api_error_retry(self, llm_adapter):
        """The query succeeds after two failed calls, using three calls in total."""
        recovered_payload = {
            "cited": True,
            "position": 1,
            "citation_text": "测试文本",
            "sentiment": "neutral",
            "confidence": 0.90,
        }
        # Simulate two failing calls followed by a successful third one.
        scripted_outcomes = [
            Exception("API调用失败"),
            Exception("API调用失败"),
            recovered_payload,
        ]

        with self._stub_deepseek(
            llm_adapter, side_effect=scripted_outcomes
        ) as deepseek_stub:
            outcome = await llm_adapter.query_brand_citation(
                keyword="AI搜索",
                brand_name="测试品牌",
                brand_aliases=[],
            )

        assert outcome.cited is True
        assert deepseek_stub.call_count == 3

    @pytest.mark.asyncio
    async def test_llm_adapter_parse_error(self, llm_adapter):
        """A payload without the expected fields raises ``LLMAdapterError``."""
        malformed_payload = {"invalid": "response"}

        with self._stub_deepseek(llm_adapter, return_value=malformed_payload):
            with pytest.raises(LLMAdapterError) as exc_info:
                await llm_adapter.query_brand_citation(
                    keyword="AI搜索",
                    brand_name="测试品牌",
                    brand_aliases=[],
                )

        # The message should mention either a missing required field or a
        # response-parsing failure.
        message = str(exc_info.value)
        accepted_hints = ("响应缺少必需字段", "解析响应失败")
        assert any(hint in message for hint in accepted_hints)

    def test_build_prompt(self, llm_adapter):
        """The built prompt contains the keyword, the brand name and every alias."""
        prompt = llm_adapter._build_prompt(
            keyword="AI搜索",
            brand_name="测试品牌",
            brand_aliases=["别名1", "别名2"],
        )

        for expected_fragment in ("AI搜索", "测试品牌", "别名1", "别名2"):
            assert expected_fragment in prompt
|