""" AI平台适配器测试 - 验证各平台适配器是否正常工作 测试内容: 1. Kimi适配器返回有效响应结构 2. 适配器限流处理 3. 引用提取器 - 5种提取方式 4. 适配器错误降级 """ import pytest from unittest.mock import Mock, patch, AsyncMock, MagicMock from app.services.ai_engine.kimi import KimiAdapter from app.services.ai_engine.wenxin import WenxinAdapter from app.services.ai_engine.doubao import DoubaoAdapter from app.workers.citation_extractor import ( extract_markdown_links, extract_urls_with_context, extract_footnotes, extract_source_annotations, extract_data_source, analyze_citations, CitationAnalysisResult, ExtractedCitation, ) class TestPlatformAdapters: """AI平台适配器测试""" @pytest.mark.asyncio async def test_kimi_adapter_returns_valid_response(self): """Kimi适配器应返回有效响应结构""" adapter = KimiAdapter() # Mock API响应 mock_response_data = { "choices": [{ "message": { "content": "根据搜索结果,Apple是一家科技公司...来源: https://example.com" } }] } with patch.object(adapter, '_get_client') as mock_get_client: mock_client = AsyncMock() mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = mock_response_data mock_client.post.return_value = mock_response mock_get_client.return_value = mock_client result = await adapter.query("Apple公司") # 验证返回结构包含data_source标记或正常文本 assert result is not None assert isinstance(result, str) assert len(result) > 0 @pytest.mark.asyncio async def test_kimi_adapter_handles_rate_limit(self): """Kimi适配器应处理限流(429状态码)""" adapter = KimiAdapter() with patch.object(adapter, '_get_client') as mock_get_client: mock_client = AsyncMock() mock_response = Mock() mock_response.status_code = 429 mock_response.headers = {"Retry-After": "1"} mock_client.post.return_value = mock_response mock_get_client.return_value = mock_client # 应该抛出RuntimeError并触发重试,最终回退到搜索引擎 result = await adapter.query("test") # 验证最终有回退结果 assert result is not None assert "search_engine" in result or "ai_platform" in result @pytest.mark.asyncio async def test_kimi_fallback_to_search_engine(self): """Kimi未配置时应回退到搜索引擎""" adapter = KimiAdapter() # 模拟未配置API Key的情况 - patch api_key属性 with patch.object(adapter, '_api_key', ''): result = await adapter.query("test keyword") assert result is not None assert "search_engine" in result @pytest.mark.asyncio async def test_wenxin_adapter_response_structure(self): """文心适配器应返回有效响应""" adapter = WenxinAdapter() mock_response_data = { "result": "文心一言回答内容,来源: https://example.com" } with patch.object(adapter, '_get_client') as mock_get_client: mock_client = AsyncMock() # Mock token请求 token_response = Mock() token_response.status_code = 200 token_response.json.return_value = {"access_token": "test_token"} # Mock chat请求 chat_response = Mock() chat_response.status_code = 200 chat_response.json.return_value = mock_response_data mock_client.post.side_effect = [token_response, chat_response] mock_get_client.return_value = mock_client result = await adapter.query("测试问题") assert result is not None assert isinstance(result, str) @pytest.mark.asyncio async def test_doubao_adapter_response_structure(self): """豆包适配器应返回有效响应""" adapter = DoubaoAdapter() mock_response_data = { "choices": [{ "message": { "content": "豆包回答内容,参考 https://example.com" } }] } with patch.object(adapter, '_get_client') as mock_get_client: mock_client = AsyncMock() mock_response = Mock() mock_response.status_code = 200 mock_response.json.return_value = mock_response_data mock_client.post.return_value = mock_response mock_get_client.return_value = mock_client result = await adapter.query("测试") assert result is not None assert isinstance(result, str) @pytest.mark.asyncio async def test_adapter_error_returns_fallback(self): """适配器错误时应返回降级结果而非抛出异常""" adapter = KimiAdapter() with patch.object(adapter, '_get_client') as mock_get_client: mock_client = AsyncMock() mock_response = Mock() mock_response.status_code = 500 mock_response.text = "Internal Server Error" mock_client.post.return_value = mock_response mock_get_client.return_value = mock_client # 应该捕获异常并返回降级结果 result = await adapter.query("test") # 验证最终有回退结果而不是抛出异常 assert result is not None assert "search_engine" in result class TestCitationExtractor: """引用提取器测试 - 验证5种提取方式""" def test_citation_extraction_from_markdown_link(self): """1. Markdown链接格式 [text](url)""" text = "Apple是一家伟大的公司 [参考](https://example.com)" result = analyze_citations(text) assert len(result.citations) > 0, "应该提取到至少一个引用" urls = [c.source_url for c in result.citations if c.source_url] assert "https://example.com" in urls, "应包含预期URL" def test_citation_extraction_from_bare_url(self): """2. 裸URL格式""" text = "更多信息请访问 https://example.com 还有 https://test.com" result = analyze_citations(text) assert len(result.citations) > 0, "应该提取到至少一个裸URL" urls = [c.source_url for c in result.citations if c.source_url] assert "https://example.com" in urls assert "https://test.com" in urls def test_citation_extraction_from_footnote(self): """3. 脚注格式 [^n]""" text = """ Apple是一家伟大的公司[1] 微软也是知名公司[2] [1]: https://apple.com [2]: https://microsoft.com """ citations = extract_footnotes(text) assert len(citations) > 0, "应该提取到脚注引用" urls = [c.source_url for c in citations if c.source_url] assert "https://apple.com" in urls assert "https://microsoft.com" in urls def test_citation_extraction_from_source_annotation(self): """4. 来源标注格式 (来源: / 据...报道 / 参考: / 引用: / 出处:)""" text = "Apple是一家伟大的公司。来源: https://source1.com。据新浪报道,..." citations = extract_source_annotations(text) # 来源标注可能没有URL但有标题 assert len(citations) > 0, "应该提取到来源标注" def test_citation_extraction_from_data_source_marker(self): """5. data_source标记""" text = "[data_source: ai_platform]\n这是AI生成的回答" source, clean_text = extract_data_source(text) assert source == "ai_platform", "应正确识别data_source" assert "AI生成" in clean_text, "应正确清理文本" def test_analyze_citations_complete_flow(self): """完整引用分析流程测试""" text = "[data_source: ai_platform]\n\nApple是一家伟大的公司[1]。\n\n更多详情请访问 https://apple.com\n\n[1]: https://reference.com" result = analyze_citations(text) assert isinstance(result, CitationAnalysisResult) assert result.data_source == "ai_platform" assert len(result.citations) > 0, "应提取到引用" assert "Apple是一家伟大的公司" in result.clean_response def test_analyze_citations_empty_text(self): """空文本处理""" result = analyze_citations("") assert isinstance(result, CitationAnalysisResult) assert result.data_source == "unknown" assert len(result.citations) == 0 def test_citation_analyzer_complete_flow(self): """完整引用分析流程测试""" text = "这是一个测试 [链接](https://test.com) 和裸URL https://bare.com" result = analyze_citations(text) assert isinstance(result, CitationAnalysisResult) assert len(result.citations) > 0 def test_duplicate_url_deduplication(self): """重复URL应该去重""" text = """ 访问 https://example.com 再访问 https://example.com [链接](https://example.com) """ citations = extract_urls_with_context(text) urls = [c.source_url for c in citations if c.source_url] # 去重后应该只有一个 assert urls.count("https://example.com") == 1 def test_markdown_link_priority(self): """Markdown链接应优先于裸URL(标题更丰富)""" text = "[Apple官网](https://apple.com) 和裸URL https://microsoft.com" md_links = extract_markdown_links(text) bare_urls = extract_urls_with_context(text) # Markdown链接应该有标题 assert len(md_links) > 0 assert md_links[0].source_title is not None def test_url_cleanup_punctuation(self): """URL末尾标点符号应该被清理""" text = "访问 https://example.com, 和 https://test.com; 结束" citations = extract_urls_with_context(text) urls = [c.source_url for c in citations if c.source_url] # URL不应该以逗号或分号结尾 for url in urls: assert not url.endswith(',') assert not url.endswith(';') class TestAdapterIntegration: """适配器集成测试 - 验证所有平台适配器""" def test_all_adapters_inherit_base(self): """所有适配器应继承BasePlatformAdapter""" adapters = [KimiAdapter, WenxinAdapter, DoubaoAdapter] for adapter_cls in adapters: instance = adapter_cls() assert hasattr(instance, 'query') assert hasattr(instance, 'platform_name') assert hasattr(instance, 'platform_url') assert callable(instance.query) def test_adapter_has_required_properties(self): """适配器应具有必需的属性""" adapter = KimiAdapter() assert adapter.platform_name == "kimi" assert adapter.platform_url == "https://kimi.moonshot.cn" assert hasattr(adapter, 'is_configured') assert hasattr(adapter, 'close') def test_kimi_adapter_properties(self): """Kimi适配器特定属性""" adapter = KimiAdapter() assert adapter.platform_name == "kimi" # is_configured取决于API Key是否设置 assert isinstance(adapter.is_configured, bool) def test_wenxin_adapter_properties(self): """文心适配器特定属性""" adapter = WenxinAdapter() assert adapter.platform_name == "wenxin" assert adapter.platform_url == "https://yiyan.baidu.com" assert hasattr(adapter, 'secret_key') def test_doubao_adapter_properties(self): """豆包适配器特定属性""" adapter = DoubaoAdapter() assert adapter.platform_name == "doubao" assert hasattr(adapter, 'endpoint_id') if __name__ == "__main__": pytest.main([__file__, "-v"])