340 lines
12 KiB
Python
340 lines
12 KiB
Python
"""
AI平台适配器测试 - 验证各平台适配器是否正常工作

测试内容:
1. Kimi/文心/豆包适配器返回有效响应结构
2. 适配器限流处理
3. 引用提取器 - 5种提取方式
4. 适配器错误降级
5. 适配器集成 - 公共接口与平台属性
"""
|
||
|
||
import pytest
|
||
from unittest.mock import Mock, patch, AsyncMock, MagicMock
|
||
|
||
from app.services.ai_engine.kimi import KimiAdapter
|
||
from app.services.ai_engine.wenxin import WenxinAdapter
|
||
from app.services.ai_engine.doubao import DoubaoAdapter
|
||
from app.workers.citation_extractor import (
|
||
extract_markdown_links,
|
||
extract_urls_with_context,
|
||
extract_footnotes,
|
||
extract_source_annotations,
|
||
extract_data_source,
|
||
analyze_citations,
|
||
CitationAnalysisResult,
|
||
ExtractedCitation,
|
||
)
|
||
|
||
|
||
class TestPlatformAdapters:
    """Unit tests for the AI platform adapters (Kimi, Wenxin, Doubao).

    Outbound HTTP is faked by patching each adapter's ``_get_client`` so that
    it hands back an ``AsyncMock`` client whose ``post`` returns canned
    responses.
    """

    # Dummy credential patched onto KimiAdapter so the mocked HTTP path is
    # exercised regardless of whether a real key exists in the environment.
    # test_kimi_fallback_to_search_engine relies on ``_api_key`` controlling
    # the configured/unconfigured switch (it patches it to ''), so setting it
    # to a non-empty value here presumably forces the configured path.
    _FAKE_KIMI_KEY = "test-api-key"

    @staticmethod
    def _mock_response(status_code, json_data=None, **attrs):
        """Build a fake HTTP response.

        Args:
            status_code: value exposed as ``response.status_code``.
            json_data: if given, returned by ``response.json()``.
            **attrs: extra attributes to set verbatim (e.g. ``headers``, ``text``).
        """
        response = Mock()
        response.status_code = status_code
        if json_data is not None:
            response.json.return_value = json_data
        for name, value in attrs.items():
            setattr(response, name, value)
        return response

    @staticmethod
    def _mock_http_client(*responses):
        """Return an ``AsyncMock`` client whose ``post`` yields *responses*.

        With one response, every ``post`` call returns it (so retry loops keep
        receiving the same reply); with several, they are returned in call
        order.
        """
        client = AsyncMock()
        if len(responses) == 1:
            client.post.return_value = responses[0]
        else:
            client.post.side_effect = list(responses)
        return client

    @pytest.mark.asyncio
    async def test_kimi_adapter_returns_valid_response(self):
        """Kimi should return a non-empty string for a successful chat reply."""
        adapter = KimiAdapter()
        payload = {
            "choices": [{
                "message": {
                    "content": "根据搜索结果,Apple是一家科技公司...来源: https://example.com"
                }
            }]
        }
        client = self._mock_http_client(self._mock_response(200, payload))

        with patch.object(adapter, '_api_key', self._FAKE_KIMI_KEY), \
                patch.object(adapter, '_get_client', return_value=client):
            result = await adapter.query("Apple公司")

        # The result may carry a data_source marker or be plain text.
        assert result is not None
        assert isinstance(result, str)
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_kimi_adapter_handles_rate_limit(self):
        """Kimi should survive an HTTP 429 reply and still produce a result."""
        adapter = KimiAdapter()
        client = self._mock_http_client(
            self._mock_response(429, headers={"Retry-After": "1"})
        )

        with patch.object(adapter, '_api_key', self._FAKE_KIMI_KEY), \
                patch.object(adapter, '_get_client', return_value=client):
            # Expected to raise internally, retry, and finally fall back to
            # the search engine.
            result = await adapter.query("test")

        # A fallback (or eventual platform) result must be returned.
        assert result is not None
        assert "search_engine" in result or "ai_platform" in result

    @pytest.mark.asyncio
    async def test_kimi_fallback_to_search_engine(self):
        """Kimi without an API key should fall back to the search engine."""
        adapter = KimiAdapter()

        # Simulate a missing API key by blanking the ``_api_key`` attribute.
        with patch.object(adapter, '_api_key', ''):
            result = await adapter.query("test keyword")

        assert result is not None
        assert "search_engine" in result

    @pytest.mark.asyncio
    async def test_wenxin_adapter_response_structure(self):
        """Wenxin should return a string for a successful reply.

        The first POST is mocked as the token request and the second as the
        chat request.
        """
        adapter = WenxinAdapter()
        token_response = self._mock_response(200, {"access_token": "test_token"})
        chat_response = self._mock_response(
            200, {"result": "文心一言回答内容,来源: https://example.com"}
        )
        client = self._mock_http_client(token_response, chat_response)

        with patch.object(adapter, '_get_client', return_value=client):
            result = await adapter.query("测试问题")

        assert result is not None
        assert isinstance(result, str)

    @pytest.mark.asyncio
    async def test_doubao_adapter_response_structure(self):
        """Doubao should return a string for a successful chat reply."""
        adapter = DoubaoAdapter()
        payload = {
            "choices": [{
                "message": {
                    "content": "豆包回答内容,参考 https://example.com"
                }
            }]
        }
        client = self._mock_http_client(self._mock_response(200, payload))

        with patch.object(adapter, '_get_client', return_value=client):
            result = await adapter.query("测试")

        assert result is not None
        assert isinstance(result, str)

    @pytest.mark.asyncio
    async def test_adapter_error_returns_fallback(self):
        """A server error should yield a degraded result instead of raising."""
        adapter = KimiAdapter()
        client = self._mock_http_client(
            self._mock_response(500, text="Internal Server Error")
        )

        with patch.object(adapter, '_api_key', self._FAKE_KIMI_KEY), \
                patch.object(adapter, '_get_client', return_value=client):
            # The adapter is expected to catch the failure and degrade.
            result = await adapter.query("test")

        # A fallback result, not an exception, must come back.
        assert result is not None
        assert "search_engine" in result
|
||
|
||
|
||
class TestCitationExtractor:
    """Citation extractor tests covering the five extraction strategies."""

    def test_citation_extraction_from_markdown_link(self):
        """1. Markdown link format ``[text](url)``."""
        text = "Apple是一家伟大的公司 [参考](https://example.com)"
        result = analyze_citations(text)

        assert len(result.citations) > 0, "应该提取到至少一个引用"
        urls = [c.source_url for c in result.citations if c.source_url]
        assert "https://example.com" in urls, "应包含预期URL"

    def test_citation_extraction_from_bare_url(self):
        """2. Bare URLs embedded in running text."""
        text = "更多信息请访问 https://example.com 还有 https://test.com"
        result = analyze_citations(text)

        assert len(result.citations) > 0, "应该提取到至少一个裸URL"
        urls = [c.source_url for c in result.citations if c.source_url]
        assert "https://example.com" in urls
        assert "https://test.com" in urls

    def test_citation_extraction_from_footnote(self):
        """3. Reference-style footnotes: ``[n]`` markers plus ``[n]: url`` definitions."""
        # Built from explicit lines so the definitions start at column 0.
        text = (
            "\n"
            "Apple是一家伟大的公司[1]\n"
            "微软也是知名公司[2]\n"
            "\n"
            "[1]: https://apple.com\n"
            "[2]: https://microsoft.com\n"
        )
        citations = extract_footnotes(text)

        assert len(citations) > 0, "应该提取到脚注引用"
        urls = [c.source_url for c in citations if c.source_url]
        assert "https://apple.com" in urls
        assert "https://microsoft.com" in urls

    def test_citation_extraction_from_source_annotation(self):
        """4. Source annotations (来源: / 据...报道 / 参考: / 引用: / 出处:)."""
        text = "Apple是一家伟大的公司。来源: https://source1.com。据新浪报道,..."
        citations = extract_source_annotations(text)

        # An annotation may yield a title without a URL, so only count them.
        assert len(citations) > 0, "应该提取到来源标注"

    def test_citation_extraction_from_data_source_marker(self):
        """5. ``[data_source: ...]`` marker is parsed and stripped from the text."""
        text = "[data_source: ai_platform]\n这是AI生成的回答"
        source, clean_text = extract_data_source(text)

        assert source == "ai_platform", "应正确识别data_source"
        assert "AI生成" in clean_text, "应正确清理文本"

    def test_analyze_citations_complete_flow(self):
        """Full analysis: data_source marker, footnote and bare URL together."""
        text = "[data_source: ai_platform]\n\nApple是一家伟大的公司[1]。\n\n更多详情请访问 https://apple.com\n\n[1]: https://reference.com"
        result = analyze_citations(text)

        assert isinstance(result, CitationAnalysisResult)
        assert result.data_source == "ai_platform"
        assert len(result.citations) > 0, "应提取到引用"
        assert "Apple是一家伟大的公司" in result.clean_response

    def test_analyze_citations_empty_text(self):
        """Empty input yields an empty result with an unknown data source."""
        result = analyze_citations("")

        assert isinstance(result, CitationAnalysisResult)
        assert result.data_source == "unknown"
        assert len(result.citations) == 0

    def test_citation_analyzer_complete_flow(self):
        """Mixed input: both a Markdown link and a bare URL must be extracted."""
        text = "这是一个测试 [链接](https://test.com) 和裸URL https://bare.com"
        result = analyze_citations(text)

        assert isinstance(result, CitationAnalysisResult)
        assert len(result.citations) > 0
        # Previously only the count was checked; pin both expected URLs.
        urls = [c.source_url for c in result.citations if c.source_url]
        assert "https://test.com" in urls
        assert "https://bare.com" in urls

    def test_duplicate_url_deduplication(self):
        """The same URL appearing several times is reported only once."""
        text = (
            "\n"
            "访问 https://example.com\n"
            "再访问 https://example.com\n"
            "[链接](https://example.com)\n"
        )
        citations = extract_urls_with_context(text)
        urls = [c.source_url for c in citations if c.source_url]

        # After de-duplication exactly one copy must remain.
        assert urls.count("https://example.com") == 1

    def test_markdown_link_priority(self):
        """Markdown links carry a title (the reason they are preferred over bare URLs).

        Also checks that the bare URL in the same text is still found by the
        bare-URL extractor; the original computed that result but never
        asserted on it.
        """
        text = "[Apple官网](https://apple.com) 和裸URL https://microsoft.com"
        md_links = extract_markdown_links(text)
        bare_urls = [
            c.source_url for c in extract_urls_with_context(text) if c.source_url
        ]

        # The Markdown link must be extracted together with its title.
        assert len(md_links) > 0
        assert md_links[0].source_url == "https://apple.com"
        assert md_links[0].source_title is not None
        assert "https://microsoft.com" in bare_urls

    def test_url_cleanup_punctuation(self):
        """Trailing punctuation is stripped from extracted URLs."""
        text = "访问 https://example.com, 和 https://test.com; 结束"
        citations = extract_urls_with_context(text)
        urls = [c.source_url for c in citations if c.source_url]

        # Guard against a vacuous pass: the cleaned URLs must actually exist.
        assert "https://example.com" in urls
        assert "https://test.com" in urls
        # No URL may end with a comma or semicolon.
        for url in urls:
            assert not url.endswith(',')
            assert not url.endswith(';')
|
||
|
||
|
||
class TestAdapterIntegration:
    """Integration checks shared by every platform adapter."""

    def test_all_adapters_inherit_base(self):
        """Every adapter should derive from BasePlatformAdapter and expose its interface.

        The original only duck-typed attributes and never checked inheritance.
        The base is matched by class name in the MRO, so no import path has to
        be assumed here.
        """
        for adapter_cls in (KimiAdapter, WenxinAdapter, DoubaoAdapter):
            ancestor_names = {klass.__name__ for klass in adapter_cls.__mro__}
            assert "BasePlatformAdapter" in ancestor_names, adapter_cls.__name__

            instance = adapter_cls()
            assert hasattr(instance, 'query')
            assert hasattr(instance, 'platform_name')
            assert hasattr(instance, 'platform_url')
            assert callable(instance.query)

    def test_adapter_has_required_properties(self):
        """An adapter exposes its name, URL, configuration flag and close hook."""
        adapter = KimiAdapter()

        assert adapter.platform_name == "kimi"
        assert adapter.platform_url == "https://kimi.moonshot.cn"
        assert hasattr(adapter, 'is_configured')
        assert hasattr(adapter, 'close')

    def test_kimi_adapter_properties(self):
        """Kimi-specific properties."""
        adapter = KimiAdapter()

        assert adapter.platform_name == "kimi"
        # Whether this is True depends on the API key in the environment;
        # only the type is checked.
        assert isinstance(adapter.is_configured, bool)

    def test_wenxin_adapter_properties(self):
        """Wenxin-specific properties."""
        adapter = WenxinAdapter()

        assert adapter.platform_name == "wenxin"
        assert adapter.platform_url == "https://yiyan.baidu.com"
        assert hasattr(adapter, 'secret_key')

    def test_doubao_adapter_properties(self):
        """Doubao-specific properties."""
        adapter = DoubaoAdapter()

        assert adapter.platform_name == "doubao"
        assert hasattr(adapter, 'endpoint_id')
|
||
|
||
|
||
if __name__ == "__main__":
    # Propagate pytest's exit code so a direct `python this_file.py` run
    # reports failures to the shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|