# geo/backend/tests/test_platform_adapters.py

# 340 lines
# 12 KiB
# Python
# Raw Blame History

# This file contains ambiguous Unicode characters

# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
AI平台适配器测试 - 验证各平台适配器是否正常工作
测试内容:
1. Kimi适配器返回有效响应结构
2. 适配器限流处理
3. 引用提取器 - 5种提取方式
4. 适配器错误降级
"""
import pytest
from unittest.mock import Mock, patch, AsyncMock, MagicMock
from app.services.ai_engine.kimi import KimiAdapter
from app.services.ai_engine.wenxin import WenxinAdapter
from app.services.ai_engine.doubao import DoubaoAdapter
from app.workers.citation_extractor import (
extract_markdown_links,
extract_urls_with_context,
extract_footnotes,
extract_source_annotations,
extract_data_source,
analyze_citations,
CitationAnalysisResult,
ExtractedCitation,
)
class TestPlatformAdapters:
    """Tests for the AI platform adapters (Kimi, Wenxin, Doubao)."""

    @staticmethod
    def _http_response(status_code, payload=None, **attrs):
        """Return a ``Mock`` standing in for an HTTP response.

        ``status_code`` and any extra ``attrs`` (e.g. ``headers``, ``text``)
        are set as plain attributes. When ``payload`` is given it becomes the
        return value of ``response.json()``.
        """
        response = Mock(status_code=status_code, **attrs)
        if payload is not None:
            response.json.return_value = payload
        return response

    @staticmethod
    def _client_posting(*responses):
        """Return an ``AsyncMock`` client whose ``post`` yields ``responses``.

        A single response is returned on every call; several responses are
        handed out one per call, in the order given.
        """
        client = AsyncMock()
        if len(responses) == 1:
            client.post.return_value = responses[0]
        else:
            client.post.side_effect = list(responses)
        return client

    @pytest.mark.asyncio
    async def test_kimi_adapter_returns_valid_response(self):
        """The Kimi adapter should return a valid response structure."""
        adapter = KimiAdapter()
        # Mocked API payload.
        payload = {
            "choices": [
                {
                    "message": {
                        "content": "根据搜索结果Apple是一家科技公司...来源: https://example.com",
                    },
                },
            ],
        }
        client = self._client_posting(self._http_response(200, payload))
        with patch.object(adapter, '_get_client', return_value=client):
            result = await adapter.query("Apple公司")
        # The result should be non-empty text (it may carry a data_source
        # marker or be plain text).
        assert result is not None
        assert isinstance(result, str)
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_kimi_adapter_handles_rate_limit(self):
        """The Kimi adapter should cope with an HTTP 429 rate-limit status."""
        adapter = KimiAdapter()
        throttled = self._http_response(429, headers={"Retry-After": "1"})
        client = self._client_posting(throttled)
        with patch.object(adapter, '_get_client', return_value=client):
            # Per the original author's note: the adapter is expected to
            # raise RuntimeError internally, retry, and finally fall back to
            # the search engine.
            result = await adapter.query("test")
        # A fallback result must be produced in the end.
        assert result is not None
        assert "search_engine" in result or "ai_platform" in result

    @pytest.mark.asyncio
    async def test_kimi_fallback_to_search_engine(self):
        """An unconfigured Kimi adapter should fall back to the search engine."""
        adapter = KimiAdapter()
        # Simulate a missing API key by patching the ``_api_key`` attribute.
        with patch.object(adapter, '_api_key', ''):
            result = await adapter.query("test keyword")
        assert result is not None
        assert "search_engine" in result

    @pytest.mark.asyncio
    async def test_wenxin_adapter_response_structure(self):
        """The Wenxin adapter should return a valid response."""
        adapter = WenxinAdapter()
        # First POST: token request. Second POST: chat request.
        token_reply = self._http_response(200, {"access_token": "test_token"})
        chat_reply = self._http_response(
            200,
            {"result": "文心一言回答内容,来源: https://example.com"},
        )
        client = self._client_posting(token_reply, chat_reply)
        with patch.object(adapter, '_get_client', return_value=client):
            result = await adapter.query("测试问题")
        assert result is not None
        assert isinstance(result, str)

    @pytest.mark.asyncio
    async def test_doubao_adapter_response_structure(self):
        """The Doubao adapter should return a valid response."""
        adapter = DoubaoAdapter()
        payload = {
            "choices": [
                {
                    "message": {
                        "content": "豆包回答内容,参考 https://example.com",
                    },
                },
            ],
        }
        client = self._client_posting(self._http_response(200, payload))
        with patch.object(adapter, '_get_client', return_value=client):
            result = await adapter.query("测试")
        assert result is not None
        assert isinstance(result, str)

    @pytest.mark.asyncio
    async def test_adapter_error_returns_fallback(self):
        """On an adapter error a degraded result is returned, not an exception."""
        adapter = KimiAdapter()
        server_error = self._http_response(500, text="Internal Server Error")
        client = self._client_posting(server_error)
        with patch.object(adapter, '_get_client', return_value=client):
            # The adapter is expected to catch the failure and degrade.
            result = await adapter.query("test")
        # A fallback result must come back instead of an exception.
        assert result is not None
        assert "search_engine" in result
class TestCitationExtractor:
    """Tests for the citation extractor, covering its five extraction modes."""

    @staticmethod
    def _urls(citations):
        """Return the non-empty ``source_url`` values found in ``citations``."""
        return [c.source_url for c in citations if c.source_url]

    def test_citation_extraction_from_markdown_link(self):
        """1. Markdown link format: ``[text](url)``."""
        text = "Apple是一家伟大的公司 [参考](https://example.com)"
        result = analyze_citations(text)
        assert len(result.citations) > 0, "应该提取到至少一个引用"
        assert "https://example.com" in self._urls(result.citations), "应包含预期URL"

    def test_citation_extraction_from_bare_url(self):
        """2. Bare URL format."""
        text = "更多信息请访问 https://example.com 还有 https://test.com"
        result = analyze_citations(text)
        assert len(result.citations) > 0, "应该提取到至少一个裸URL"
        urls = self._urls(result.citations)
        assert "https://example.com" in urls
        assert "https://test.com" in urls

    def test_citation_extraction_from_footnote(self):
        """3. Footnote format: ``[n]`` markers with ``[n]: url`` definitions."""
        # Built line by line so the sample text does not depend on how this
        # method happens to be indented.
        text = (
            "\n"
            "Apple是一家伟大的公司[1]\n"
            "微软也是知名公司[2]\n"
            "[1]: https://apple.com\n"
            "[2]: https://microsoft.com\n"
        )
        citations = extract_footnotes(text)
        assert len(citations) > 0, "应该提取到脚注引用"
        urls = self._urls(citations)
        assert "https://apple.com" in urls
        assert "https://microsoft.com" in urls

    def test_citation_extraction_from_source_annotation(self):
        """4. Source annotation format.

        The annotation markers listed by the original author are
        来源: / 据...报道 / 参考: / 引用: / 出处:.
        """
        text = "Apple是一家伟大的公司。来源: https://source1.com。据新浪报道..."
        citations = extract_source_annotations(text)
        # A source annotation may carry a title without any URL, so only the
        # number of citations is checked.
        assert len(citations) > 0, "应该提取到来源标注"

    def test_citation_extraction_from_data_source_marker(self):
        """5. The ``[data_source: ...]`` marker."""
        text = "[data_source: ai_platform]\n这是AI生成的回答"
        source, clean_text = extract_data_source(text)
        assert source == "ai_platform", "应正确识别data_source"
        assert "AI生成" in clean_text, "应正确清理文本"

    def test_analyze_citations_complete_flow(self):
        """Full analysis flow: data_source marker, footnote and bare URL together."""
        text = "[data_source: ai_platform]\n\nApple是一家伟大的公司[1]。\n\n更多详情请访问 https://apple.com\n\n[1]: https://reference.com"
        result = analyze_citations(text)
        assert isinstance(result, CitationAnalysisResult)
        assert result.data_source == "ai_platform"
        assert len(result.citations) > 0, "应提取到引用"
        assert "Apple是一家伟大的公司" in result.clean_response

    def test_analyze_citations_empty_text(self):
        """Empty input yields an ``unknown`` data source and no citations."""
        result = analyze_citations("")
        assert isinstance(result, CitationAnalysisResult)
        assert result.data_source == "unknown"
        assert len(result.citations) == 0

    def test_citation_analyzer_complete_flow(self):
        """Full analysis flow: Markdown link and bare URL in one text."""
        text = "这是一个测试 [链接](https://test.com) 和裸URL https://bare.com"
        result = analyze_citations(text)
        assert isinstance(result, CitationAnalysisResult)
        assert len(result.citations) > 0

    def test_duplicate_url_deduplication(self):
        """A URL that appears several times should be reported once."""
        text = (
            "\n"
            "访问 https://example.com\n"
            "再访问 https://example.com\n"
            "[链接](https://example.com)\n"
        )
        urls = self._urls(extract_urls_with_context(text))
        # After de-duplication exactly one entry should remain.
        assert urls.count("https://example.com") == 1

    def test_markdown_link_priority(self):
        """Markdown links should carry a title (richer than a bare URL)."""
        text = "[Apple官网](https://apple.com) 和裸URL https://microsoft.com"
        md_links = extract_markdown_links(text)
        # The bare-URL extractor still runs on the same mixed input so that a
        # crash there fails this test; its result was never asserted on, so
        # it is no longer bound to an unused local.
        extract_urls_with_context(text)
        # A Markdown link is expected to have a title.
        assert len(md_links) > 0
        assert md_links[0].source_title is not None

    def test_url_cleanup_punctuation(self):
        """Trailing punctuation must be stripped from extracted URLs."""
        text = "访问 https://example.com, 和 https://test.com; 结束"
        urls = self._urls(extract_urls_with_context(text))
        # Without this guard the loop below would pass vacuously if the
        # extractor returned nothing at all.
        assert urls, "expected at least one URL to be extracted"
        # No URL may end with a comma or a semicolon.
        for url in urls:
            assert not url.endswith((',', ';'))
class TestAdapterIntegration:
    """Integration-level checks applied to every platform adapter."""

    def test_all_adapters_inherit_base(self):
        """Every adapter should expose the common adapter interface.

        NOTE(review): the original description says the adapters should
        inherit from BasePlatformAdapter, but only the presence of the shared
        attributes is checked here, not the class hierarchy.
        """
        for adapter_cls in (KimiAdapter, WenxinAdapter, DoubaoAdapter):
            instance = adapter_cls()
            for attr in ('query', 'platform_name', 'platform_url'):
                assert hasattr(instance, attr)
            assert callable(instance.query)

    def test_adapter_has_required_properties(self):
        """An adapter should provide the required properties."""
        kimi = KimiAdapter()
        assert kimi.platform_name == "kimi"
        assert kimi.platform_url == "https://kimi.moonshot.cn"
        for attr in ('is_configured', 'close'):
            assert hasattr(kimi, attr)

    def test_kimi_adapter_properties(self):
        """Kimi-specific adapter properties."""
        kimi = KimiAdapter()
        assert kimi.platform_name == "kimi"
        # Whether the adapter is configured depends on the API key being set,
        # so only the type of the flag is checked.
        assert isinstance(kimi.is_configured, bool)

    def test_wenxin_adapter_properties(self):
        """Wenxin-specific adapter properties."""
        wenxin = WenxinAdapter()
        assert wenxin.platform_name == "wenxin"
        assert wenxin.platform_url == "https://yiyan.baidu.com"
        assert hasattr(wenxin, 'secret_key')

    def test_doubao_adapter_properties(self):
        """Doubao-specific adapter properties."""
        doubao = DoubaoAdapter()
        assert doubao.platform_name == "doubao"
        assert hasattr(doubao, 'endpoint_id')
if __name__ == "__main__":
    # Propagate pytest's exit code so that running this file directly ends
    # with a non-zero process status when a test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))