311 lines
11 KiB
Python
311 lines
11 KiB
Python
"""
AI平台适配器测试 - 验证各平台适配器是否正常工作

测试内容:
1. Kimi / 文心 / 豆包 适配器返回有效响应结构
2. 适配器限流处理(429 错误向调用方抛出)
3. 引用提取器 - 5种提取方式
4. 适配器错误处理(服务端错误时抛出异常)
5. 适配器集成 - 基类继承与必需属性
"""
|
||
|
||
import pytest
|
||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||
|
||
from app.services.ai_engine.kimi import KimiAdapter
|
||
from app.services.ai_engine.wenxin import WenxinAdapter
|
||
from app.services.ai_engine.doubao import DoubaoAdapter
|
||
from app.services.ai_engine.base import AIEngineAdapter, AIQueryResult, EngineType
|
||
from app.workers.citation_extractor import (
|
||
extract_markdown_links,
|
||
extract_urls_with_context,
|
||
extract_footnotes,
|
||
extract_source_annotations,
|
||
extract_data_source,
|
||
analyze_citations,
|
||
CitationAnalysisResult,
|
||
ExtractedCitation,
|
||
)
|
||
|
||
|
||
class TestPlatformAdapters:
    """Unit tests for the individual AI platform adapters.

    The network layer is always mocked (either ``_request_with_retry`` or the
    adapter's ``_client.post``), so these tests cover response handling and
    error propagation only -- not real HTTP traffic and not the retry loop.
    """

    @pytest.mark.asyncio
    async def test_kimi_adapter_returns_valid_response(self):
        """Kimi adapter should wrap a chat-completion payload in an AIQueryResult."""
        adapter = KimiAdapter()

        # Minimal chat-completion payload returned by the mocked request layer.
        mock_response_data = {
            "choices": [{
                "message": {
                    "content": "根据搜索结果,Apple是一家科技公司...来源: https://example.com"
                }
            }],
            "usage": {"prompt_tokens": 10, "completion_tokens": 20},
        }

        with patch.object(adapter, '_request_with_retry', new_callable=AsyncMock) as mock_retry:
            mock_retry.return_value = mock_response_data

            result = await adapter.query("Apple公司", "Apple", ["Samsung"])

            # Make sure the result really came through the mocked request layer,
            # rather than from an early return that never touched it.
            mock_retry.assert_awaited()

            # Structural checks on the returned result.
            assert result is not None
            assert isinstance(result, AIQueryResult)
            assert result.engine_type == EngineType.KIMI
            assert isinstance(result.raw_response, str)
            assert len(result.raw_response) > 0

    @pytest.mark.asyncio
    async def test_kimi_adapter_handles_rate_limit(self):
        """A rate-limit (HTTP 429) error from the request layer must reach the caller.

        ``_request_with_retry`` is mocked to raise immediately, so the retry
        loop itself is not exercised; this only checks that ``query`` does not
        swallow the error.
        """
        adapter = KimiAdapter()

        with patch.object(adapter, '_request_with_retry', new_callable=AsyncMock) as mock_retry:
            mock_retry.side_effect = Exception("HTTP 429: Rate limited")

            with pytest.raises(Exception, match="429|Rate limited"):
                await adapter.query("test", "test_brand")

            mock_retry.assert_awaited()

    def test_kimi_fallback_to_search_engine(self):
        """An explicitly empty API key is stored unchanged on the adapter.

        Despite its name, this test does not exercise any search-engine
        fallback; it only checks the ``api_key`` attribute. It awaits nothing,
        so it is a plain synchronous test.
        """
        adapter = KimiAdapter(api_key="")

        assert adapter.api_key == ""

    @pytest.mark.asyncio
    async def test_wenxin_adapter_response_structure(self):
        """Wenxin adapter should turn a mocked ``result`` payload into an AIQueryResult."""
        adapter = WenxinAdapter()

        mock_token = "test_access_token"
        mock_response_data = {
            "result": "文心一言回答内容,来源: https://example.com",
            "usage": {"prompt_tokens": 10, "completion_tokens": 20},
        }

        # Skip the real OAuth token exchange.
        with patch.object(adapter, '_get_access_token', new_callable=AsyncMock) as mock_token_fn:
            mock_token_fn.return_value = mock_token

            # Fake HTTP response object returned by the mocked client.post.
            mock_response = Mock()
            mock_response.status_code = 200
            mock_response.json.return_value = mock_response_data

            with patch.object(adapter._client, 'post', new_callable=AsyncMock) as mock_post:
                mock_post.return_value = mock_response

                result = await adapter.query("测试问题", "测试品牌")

                # The payload must have been fetched through the mocked client.
                mock_post.assert_awaited()

                assert result is not None
                assert isinstance(result, AIQueryResult)
                assert result.engine_type == EngineType.WENXIN

    @pytest.mark.asyncio
    async def test_doubao_adapter_response_structure(self):
        """Doubao adapter should wrap a chat-completion payload in an AIQueryResult."""
        adapter = DoubaoAdapter()

        mock_response_data = {
            "choices": [{
                "message": {
                    "content": "豆包回答内容,参考 https://example.com"
                }
            }],
            "usage": {"prompt_tokens": 10, "completion_tokens": 20},
        }

        with patch.object(adapter, '_request_with_retry', new_callable=AsyncMock) as mock_retry:
            mock_retry.return_value = mock_response_data

            result = await adapter.query("测试", "测试品牌")

            mock_retry.assert_awaited()

            assert result is not None
            assert isinstance(result, AIQueryResult)
            assert result.engine_type == EngineType.DOUBAO

    @pytest.mark.asyncio
    async def test_adapter_error_returns_fallback(self):
        """A server error (HTTP 500) from the request layer must reach the caller.

        Despite its name, this test expects ``query`` to raise, not to return a
        fallback value. The retry loop is mocked away, so this only checks
        error propagation.
        """
        adapter = KimiAdapter()

        with patch.object(adapter, '_request_with_retry', new_callable=AsyncMock) as mock_retry:
            mock_retry.side_effect = Exception("HTTP 500: Internal Server Error")

            with pytest.raises(Exception, match="500|Internal Server Error"):
                await adapter.query("test", "test_brand")

            mock_retry.assert_awaited()
|
||
|
||
|
||
class TestCitationExtractor:
    """Citation extractor tests covering the five supported extraction methods."""

    def test_citation_extraction_from_markdown_link(self):
        """1. Markdown link format ``[text](url)``."""
        text = "Apple是一家伟大的公司 [参考](https://example.com)"
        result = analyze_citations(text)

        assert len(result.citations) > 0, "应该提取到至少一个引用"
        urls = [c.source_url for c in result.citations if c.source_url]
        assert "https://example.com" in urls, "应包含预期URL"

    def test_citation_extraction_from_bare_url(self):
        """2. Bare URLs embedded in running text."""
        text = "更多信息请访问 https://example.com 还有 https://test.com"
        result = analyze_citations(text)

        assert len(result.citations) > 0, "应该提取到至少一个裸URL"
        urls = [c.source_url for c in result.citations if c.source_url]
        assert "https://example.com" in urls
        assert "https://test.com" in urls

    def test_citation_extraction_from_footnote(self):
        """3. Numbered footnotes: ``[n]`` markers with ``[n]: url`` definitions."""
        # Built line by line so every line (in particular the "[n]: url"
        # definitions) starts at column 0, independent of method indentation.
        text = (
            "\n"
            "Apple是一家伟大的公司[1]\n"
            "微软也是知名公司[2]\n"
            "\n"
            "[1]: https://apple.com\n"
            "[2]: https://microsoft.com\n"
        )
        citations = extract_footnotes(text)

        assert len(citations) > 0, "应该提取到脚注引用"
        urls = [c.source_url for c in citations if c.source_url]
        assert "https://apple.com" in urls
        assert "https://microsoft.com" in urls

    def test_citation_extraction_from_source_annotation(self):
        """4. Source annotations (来源: / 据...报道 / 参考: / 引用: / 出处:)."""
        text = "Apple是一家伟大的公司。来源: https://source1.com。据新浪报道,..."

        citations = extract_source_annotations(text)

        # An annotation may carry a title but no URL, so only the count is checked.
        assert len(citations) > 0, "应该提取到来源标注"

    def test_citation_extraction_from_data_source_marker(self):
        """5. Leading ``[data_source: ...]`` marker."""
        text = "[data_source: ai_platform]\n这是AI生成的回答"

        source, clean_text = extract_data_source(text)

        assert source == "ai_platform", "应正确识别data_source"
        assert "AI生成" in clean_text, "应正确清理文本"

    def test_analyze_citations_complete_flow(self):
        """End to end: data_source marker, footnote and bare URL in one response."""
        text = (
            "[data_source: ai_platform]\n\n"
            "Apple是一家伟大的公司[1]。\n\n"
            "更多详情请访问 https://apple.com\n\n"
            "[1]: https://reference.com"
        )

        result = analyze_citations(text)

        assert isinstance(result, CitationAnalysisResult)
        assert result.data_source == "ai_platform"
        assert len(result.citations) > 0, "应提取到引用"
        assert "Apple是一家伟大的公司" in result.clean_response

    def test_analyze_citations_empty_text(self):
        """Empty input yields no citations and an ``unknown`` data source."""
        result = analyze_citations("")

        assert isinstance(result, CitationAnalysisResult)
        assert result.data_source == "unknown"
        assert len(result.citations) == 0

    def test_citation_analyzer_complete_flow(self):
        """End to end: a Markdown link and a bare URL in the same text."""
        text = "这是一个测试 [链接](https://test.com) 和裸URL https://bare.com"

        result = analyze_citations(text)

        assert isinstance(result, CitationAnalysisResult)
        assert len(result.citations) > 0

    def test_duplicate_url_deduplication(self):
        """A URL that appears several times is reported only once."""
        text = (
            "\n"
            "访问 https://example.com\n"
            "再访问 https://example.com\n"
            "[链接](https://example.com)\n"
        )

        citations = extract_urls_with_context(text)
        urls = [c.source_url for c in citations if c.source_url]

        # After deduplication exactly one entry should remain.
        assert urls.count("https://example.com") == 1

    def test_markdown_link_priority(self):
        """Markdown links carry a title; bare URLs are still picked up separately."""
        text = "[Apple官网](https://apple.com) 和裸URL https://microsoft.com"

        md_links = extract_markdown_links(text)
        bare_urls = extract_urls_with_context(text)

        # The Markdown link should be found and keep its link text as a title.
        assert len(md_links) > 0
        assert md_links[0].source_url == "https://apple.com"
        assert md_links[0].source_title is not None

        # The bare URL that is not wrapped in Markdown must still be extracted.
        bare = [c.source_url for c in bare_urls if c.source_url]
        assert "https://microsoft.com" in bare

    def test_url_cleanup_punctuation(self):
        """Trailing punctuation is stripped from extracted URLs."""
        text = "访问 https://example.com, 和 https://test.com; 结束"

        citations = extract_urls_with_context(text)
        urls = [c.source_url for c in citations if c.source_url]

        # Require the cleaned URLs to be present; otherwise an empty result
        # would make the punctuation loop below pass vacuously.
        assert "https://example.com" in urls
        assert "https://test.com" in urls

        # No URL may keep a trailing comma or semicolon.
        for url in urls:
            assert not url.endswith(',')
            assert not url.endswith(';')
|
||
|
||
|
||
class TestAdapterIntegration:
    """Cross-adapter checks: shared base class and required attributes."""

    def test_all_adapters_inherit_base(self):
        """Every platform adapter subclasses AIEngineAdapter and exposes query()."""
        for cls in (KimiAdapter, WenxinAdapter, DoubaoAdapter):
            assert issubclass(cls, AIEngineAdapter)

            obj = cls()
            for attr_name in ('query', 'get_engine_type'):
                assert hasattr(obj, attr_name)
            assert callable(obj.query)

    def test_adapter_has_required_properties(self):
        """The Kimi adapter reports its engine type and exposes config/close hooks."""
        kimi = KimiAdapter()

        assert kimi.get_engine_type() == EngineType.KIMI
        # Either a configuration flag or the raw key attribute is acceptable.
        assert any(hasattr(kimi, name) for name in ('is_configured', 'api_key'))
        assert hasattr(kimi, 'close')

    def test_kimi_adapter_properties(self):
        """Kimi-specific attributes."""
        kimi = KimiAdapter()

        engine = kimi.get_engine_type()
        assert engine == EngineType.KIMI
        # The key may be empty when unconfigured, but it is always a string.
        assert isinstance(kimi.api_key, str)

    def test_wenxin_adapter_properties(self):
        """Wenxin-specific attributes."""
        wenxin = WenxinAdapter()

        engine = wenxin.get_engine_type()
        assert engine == EngineType.WENXIN
        assert hasattr(wenxin, 'secret_key')

    def test_doubao_adapter_properties(self):
        """Doubao-specific attributes."""
        doubao = DoubaoAdapter()

        engine = doubao.get_engine_type()
        assert engine == EngineType.DOUBAO
        assert hasattr(doubao, '_endpoint_id')
|
||
|
||
|
||
if __name__ == "__main__":
    # Propagate pytest's exit status so shell/CI invocations of this file
    # see test failures instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|