"""Tests for HttpRAGService — an HTTP client that calls a business system's knowledge-base API."""

import json
import textwrap
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest

from agentkit.memory.http_rag import HttpRAGService
from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.retriever import MemoryRetriever


# ---------------------------------------------------------------------------
# Shared mock helpers
# ---------------------------------------------------------------------------

def _json_response(payload, status_code=200):
    """Build a fake successful response whose ``.json()`` returns *payload*."""
    resp = MagicMock()
    resp.status_code = status_code
    resp.raise_for_status = MagicMock()
    resp.json.return_value = payload
    return resp


def _error_response(status_code, text=None):
    """Build a fake response whose ``raise_for_status()`` raises ``httpx.HTTPStatusError``."""
    resp = MagicMock()
    resp.status_code = status_code
    if text is not None:
        resp.text = text
    resp.raise_for_status.side_effect = httpx.HTTPStatusError(
        str(status_code), request=MagicMock(), response=resp
    )
    return resp


def _install_client(svc, **methods):
    """Make ``svc._get_client()`` return an AsyncMock client.

    Each keyword argument (e.g. ``post=``, ``get=``) replaces that attribute
    on the fake client. The client is returned so tests can inspect calls.
    """
    client = AsyncMock()
    for name, mock in methods.items():
        setattr(client, name, mock)
    svc._get_client = MagicMock(return_value=client)
    return client


# ---------------------------------------------------------------------------
# HttpRAGService unit tests
# ---------------------------------------------------------------------------

class TestHttpRAGServiceInit:
    """HttpRAGService construction."""

    def test_basic_init(self):
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        assert svc._base_url == "http://localhost:8000/api/knowledge"
        assert svc._api_key is None
        assert svc._knowledge_base_ids == []
        assert svc._timeout == 30

    def test_init_with_all_params(self):
        svc = HttpRAGService(
            base_url="http://geo:8000/api/knowledge/",
            api_key="sk-test",
            knowledge_base_ids=["kb-1", "kb-2"],
            timeout=60,
        )
        assert svc._base_url == "http://geo:8000/api/knowledge"  # trailing slash stripped
        assert svc._api_key == "sk-test"
        assert svc._knowledge_base_ids == ["kb-1", "kb-2"]
        assert svc._timeout == 60

    def test_trailing_slash_stripped(self):
        svc = HttpRAGService(base_url="http://host/api/")
        assert svc._base_url == "http://host/api"


class TestHttpRAGServiceSearch:
    """HttpRAGService.search — semantic retrieval."""

    @pytest.fixture
    def svc(self):
        return HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            api_key="test-key",
            knowledge_base_ids=["kb-industry", "kb-enterprise"],
        )

    @pytest.mark.asyncio
    async def test_search_standard_response(self, svc):
        """Standard SearchResponse shape: {"results": [...]}."""
        resp = _json_response({
            "results": [
                {
                    "chunk_id": "c1",
                    "content": "AI 行业趋势分析",
                    "score": 0.92,
                    "document_id": "d1",
                    "document_title": "行业报告",
                    "metadata": {"page": 1},
                },
                {
                    "chunk_id": "c2",
                    "content": "企业数字化转型",
                    "score": 0.85,
                    "document_id": "d2",
                    "document_title": "企业案例",
                    "metadata": {},
                },
            ],
            "total": 2,
            "latency_ms": 50,
        })
        client = _install_client(svc, post=AsyncMock(return_value=resp))

        results = await svc.search("AI 趋势", top_k=5)

        assert len(results) == 2
        assert results[0]["id"] == "c1"
        assert results[0]["content"] == "AI 行业趋势分析"
        assert results[0]["score"] == 0.92
        assert results[0]["document_id"] == "d1"
        assert results[1]["content"] == "企业数字化转型"

        # Verify the request path and payload.
        call_args = client.post.call_args
        assert call_args[0][0] == "/search"
        payload = call_args[1]["json"]
        assert payload["query"] == "AI 趋势"
        assert payload["knowledge_base_ids"] == ["kb-industry", "kb-enterprise"]
        assert payload["top_k"] == 5

    @pytest.mark.asyncio
    async def test_search_list_response(self, svc):
        """A bare list response: [...]."""
        resp = _json_response([
            {"chunk_id": "c1", "content": "test", "score": 0.8},
        ])
        _install_client(svc, post=AsyncMock(return_value=resp))

        results = await svc.search("test")
        assert len(results) == 1
        assert results[0]["content"] == "test"

    @pytest.mark.asyncio
    async def test_search_custom_kb_ids(self, svc):
        """Explicit knowledge_base_ids override the service defaults."""
        client = _install_client(
            svc, post=AsyncMock(return_value=_json_response({"results": []}))
        )

        await svc.search("test", knowledge_base_ids=["custom-kb"])
        payload = client.post.call_args[1]["json"]
        assert payload["knowledge_base_ids"] == ["custom-kb"]

    @pytest.mark.asyncio
    async def test_search_http_error_returns_empty(self, svc):
        """An HTTP error status yields an empty list."""
        resp = _error_response(500, text="Internal Server Error")
        _install_client(svc, post=AsyncMock(return_value=resp))

        results = await svc.search("test")
        assert results == []

    @pytest.mark.asyncio
    async def test_search_connection_error_returns_empty(self, svc):
        """A connection error yields an empty list."""
        _install_client(
            svc, post=AsyncMock(side_effect=httpx.ConnectError("Connection refused"))
        )

        results = await svc.search("test")
        assert results == []

    @pytest.mark.asyncio
    async def test_search_unexpected_format_returns_empty(self, svc):
        """An unrecognised response shape yields an empty list."""
        resp = _json_response({"error": "something"})
        _install_client(svc, post=AsyncMock(return_value=resp))

        results = await svc.search("test")
        assert results == []


class TestHttpRAGServiceIngest:
    """HttpRAGService.ingest — document upload."""

    @pytest.mark.asyncio
    async def test_ingest_success(self):
        svc = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        resp = _json_response({"id": "doc-1", "status": "processing"}, status_code=201)
        client = _install_client(svc, post=AsyncMock(return_value=resp))

        result = await svc.ingest("测试文档", "文档内容")
        assert result["id"] == "doc-1"

        # Verify the endpoint and payload.
        call_args = client.post.call_args
        assert call_args[0][0] == "/bases/kb-1/documents"
        payload = call_args[1]["json"]
        assert payload["title"] == "测试文档"
        assert payload["content"] == "文档内容"

    @pytest.mark.asyncio
    async def test_ingest_no_kb_ids_returns_none(self):
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        result = await svc.ingest("test", "content")
        assert result is None

    @pytest.mark.asyncio
    async def test_ingest_http_error_returns_none(self):
        svc = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        _install_client(svc, post=AsyncMock(return_value=_error_response(500)))

        result = await svc.ingest("test", "content")
        assert result is None


class TestHttpRAGServiceHealthCheck:
    """HttpRAGService.health_check."""

    @pytest.mark.asyncio
    async def test_health_check_ok(self):
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        _install_client(svc, get=AsyncMock(return_value=MagicMock(status_code=200)))

        assert await svc.health_check() is True

    @pytest.mark.asyncio
    async def test_health_check_401_still_healthy(self):
        """A 401 means the service is up and only requires authentication."""
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        _install_client(svc, get=AsyncMock(return_value=MagicMock(status_code=401)))

        assert await svc.health_check() is True

    @pytest.mark.asyncio
    async def test_health_check_connection_error(self):
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        _install_client(svc, get=AsyncMock(side_effect=Exception("Connection refused")))

        assert await svc.health_check() is False


class TestHttpRAGServiceClient:
    """HttpRAGService HTTP client lifecycle."""

    @pytest.mark.asyncio
    async def test_client_lazy_init(self):
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge", api_key="sk-test")
        assert svc._client is None
        try:
            client = svc._get_client()
            assert client is not None
            assert "Bearer sk-test" in str(client.headers.get("Authorization", ""))
        finally:
            # A real httpx client is created here; close it so it isn't leaked.
            await svc.close()

    @pytest.mark.asyncio
    async def test_client_reuse(self):
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        try:
            c1 = svc._get_client()
            c2 = svc._get_client()
            assert c1 is c2
        finally:
            await svc.close()

    @pytest.mark.asyncio
    async def test_close(self):
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        svc._get_client()  # force client creation
        await svc.close()
        assert svc._client is None

    @pytest.mark.asyncio
    async def test_context_manager(self):
        svc = HttpRAGService(base_url="http://localhost:8000/api/knowledge")
        async with svc as s:
            s._get_client()
            assert s._client is not None
        assert svc._client is None


# ---------------------------------------------------------------------------
# SemanticMemory + HttpRAGService integration
# ---------------------------------------------------------------------------

class TestSemanticMemoryWithHttpRAG:
    """SemanticMemory retrieves from the knowledge base through HttpRAGService."""

    @pytest.mark.asyncio
    async def test_search_delegates_to_rag_service(self):
        """SemanticMemory.search delegates to HttpRAGService.search."""
        rag = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        resp = _json_response({
            "results": [
                {"chunk_id": "c1", "content": "行业知识", "score": 0.9, "document_id": "d1"},
            ]
        })
        _install_client(rag, post=AsyncMock(return_value=resp))

        semantic = SemanticMemory(rag_service=rag, knowledge_base_ids=["kb-1"])
        items = await semantic.search("行业趋势", top_k=3)

        assert len(items) == 1
        assert items[0].key == "c1"
        assert items[0].value == "行业知识"
        assert items[0].score == 0.9
        assert items[0].metadata["source"] == "rag"

    @pytest.mark.asyncio
    async def test_search_no_rag_service_returns_empty(self):
        """Without a RAG service, search returns an empty list."""
        semantic = SemanticMemory()
        items = await semantic.search("test")
        assert items == []


class TestMemoryRetrieverWithSemantic:
    """MemoryRetriever integrated with SemanticMemory + HttpRAGService."""

    @pytest.mark.asyncio
    async def test_retriever_queries_semantic_layer(self):
        """MemoryRetriever queries the semantic layer and fuses the results."""
        rag = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        resp = _json_response({
            "results": [
                {"chunk_id": "c1", "content": "知识库内容", "score": 0.95, "document_id": "d1"},
            ]
        })
        _install_client(rag, post=AsyncMock(return_value=resp))

        semantic = SemanticMemory(rag_service=rag, knowledge_base_ids=["kb-1"])
        retriever = MemoryRetriever(semantic_memory=semantic)

        items = await retriever.retrieve("知识查询", top_k=3)
        assert len(items) >= 1
        # Semantic weight is 0.4 by default.
        assert items[0].score == pytest.approx(0.95 * 0.4, abs=0.01)


# ---------------------------------------------------------------------------
# Config-driven integration tests
# ---------------------------------------------------------------------------

class TestServerConfigMemorySemantic:
    """ServerConfig parsing of the memory.semantic section."""

    def test_from_dict_with_semantic(self):
        from agentkit.server.config import ServerConfig

        data = {
            "memory": {
                "semantic": {
                    "enabled": True,
                    "base_url": "http://geo:8000/api/knowledge",
                    "api_key": "sk-test",
                    "knowledge_base_ids": ["kb-1", "kb-2"],
                    "timeout": 60,
                },
            },
        }
        config = ServerConfig.from_dict(data)
        assert config.memory["semantic"]["enabled"] is True
        assert config.memory["semantic"]["base_url"] == "http://geo:8000/api/knowledge"
        assert config.memory["semantic"]["api_key"] == "sk-test"
        assert config.memory["semantic"]["knowledge_base_ids"] == ["kb-1", "kb-2"]
        assert config.memory["semantic"]["timeout"] == 60

    def test_from_dict_without_memory(self):
        from agentkit.server.config import ServerConfig

        config = ServerConfig.from_dict({})
        assert config.memory == {}

    def test_from_yaml_with_env_var_resolution(self, tmp_path, monkeypatch):
        """from_yaml resolves ${VAR} placeholders from the environment."""
        from agentkit.server.config import ServerConfig

        # monkeypatch restores the environment even if an assertion below fails.
        monkeypatch.setenv("TEST_GEO_API_KEY", "sk-from-env")
        # tmp_path is cleaned up by pytest, so no temp files leak, and the file
        # is closed before being re-read (required on Windows).
        config_path = tmp_path / "config.yaml"
        config_path.write_text(
            textwrap.dedent(
                """\
                memory:
                  semantic:
                    enabled: true
                    base_url: http://geo:8000/api/knowledge
                    api_key: "${TEST_GEO_API_KEY}"
                """
            ),
            encoding="utf-8",
        )

        config = ServerConfig.from_yaml(str(config_path))
        assert config.memory["semantic"]["api_key"] == "sk-from-env"

    def test_from_yaml_with_default_env_var(self, tmp_path, monkeypatch):
        """from_yaml falls back to the default in ${VAR:-default} when VAR is unset."""
        from agentkit.server.config import ServerConfig

        # Guarantee the variable is unset regardless of the host environment.
        monkeypatch.delenv("NONEXISTENT_KEY", raising=False)
        config_path = tmp_path / "config.yaml"
        config_path.write_text(
            textwrap.dedent(
                """\
                memory:
                  semantic:
                    enabled: true
                    base_url: http://geo:8000/api/knowledge
                    api_key: "${NONEXISTENT_KEY:-sk-default}"
                """
            ),
            encoding="utf-8",
        )

        config = ServerConfig.from_yaml(str(config_path))
        assert config.memory["semantic"]["api_key"] == "sk-default"


# ---------------------------------------------------------------------------
# HttpRAGService enhanced_search tests
# ---------------------------------------------------------------------------

class TestHttpRAGServiceEnhancedSearch:
    """HttpRAGService.enhanced_search — enhanced semantic retrieval."""

    @pytest.fixture
    def svc(self):
        return HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            api_key="test-key",
            knowledge_base_ids=["kb-1", "kb-2"],
        )

    @pytest.mark.asyncio
    async def test_enhanced_search_single_kb(self, svc):
        """Single-KB enhanced search; payload carries use_rerank and use_compression."""
        svc._knowledge_base_ids = ["kb-1"]
        resp = _json_response({
            "results": [
                {"chunk_id": "c1", "content": "AI 趋势", "score": 0.95, "document_id": "d1"},
            ]
        })
        client = _install_client(svc, post=AsyncMock(return_value=resp))

        results = await svc.enhanced_search("AI 趋势", top_k=5)

        assert len(results) == 1
        assert results[0]["content"] == "AI 趋势"
        assert results[0]["score"] == 0.95

        # Verify the endpoint and payload.
        call_args = client.post.call_args
        assert call_args[0][0] == "/bases/kb-1/retrieve"
        payload = call_args[1]["json"]
        assert payload["query"] == "AI 趋势"
        assert payload["top_k"] == 5
        assert payload["use_rerank"] is True
        assert payload["use_compression"] is False

    @pytest.mark.asyncio
    async def test_enhanced_search_multiple_kbs(self, svc):
        """Multi-KB enhanced search merges results and sorts them by score, descending."""
        # The first KB returns a single lower-scored hit...
        resp1 = _json_response({
            "results": [
                {"chunk_id": "c1", "content": "KB1 结果", "score": 0.8, "document_id": "d1"},
            ]
        })
        # ...and the second KB a single higher-scored hit.
        resp2 = _json_response({
            "results": [
                {"chunk_id": "c2", "content": "KB2 结果", "score": 0.95, "document_id": "d2"},
            ]
        })
        client = _install_client(svc, post=AsyncMock(side_effect=[resp1, resp2]))

        results = await svc.enhanced_search("test query", top_k=5)

        assert len(results) == 2
        # Merged results are sorted by score, descending.
        assert results[0]["content"] == "KB2 结果"
        assert results[0]["score"] == 0.95
        assert results[1]["content"] == "KB1 结果"
        assert results[1]["score"] == 0.8

        # Both KB endpoints were called, in configured order.
        calls = client.post.call_args_list
        assert calls[0][0][0] == "/bases/kb-1/retrieve"
        assert calls[1][0][0] == "/bases/kb-2/retrieve"

    @pytest.mark.asyncio
    async def test_enhanced_search_404_fallback(self, svc):
        """A 404 from the retrieve endpoint falls back to the standard search method."""
        resp = _error_response(404, text="Not Found")
        _install_client(svc, post=AsyncMock(return_value=resp))

        # Replace the standard search so the fallback call can be observed.
        svc.search = AsyncMock(
            return_value=[{"id": "fallback", "content": "fallback result", "score": 0.5}]
        )

        results = await svc.enhanced_search("test query")

        svc.search.assert_called_once_with(
            "test query", knowledge_base_ids=["kb-1", "kb-2"], top_k=5
        )
        assert len(results) == 1
        assert results[0]["id"] == "fallback"

    @pytest.mark.asyncio
    async def test_enhanced_search_http_error(self, svc):
        """A non-404 HTTP error yields an empty list."""
        resp = _error_response(500, text="Internal Server Error")
        _install_client(svc, post=AsyncMock(return_value=resp))

        results = await svc.enhanced_search("test query")
        assert results == []

    @pytest.mark.asyncio
    async def test_enhanced_search_with_compression(self, svc):
        """use_compression=True is forwarded in the payload."""
        svc._knowledge_base_ids = ["kb-1"]
        client = _install_client(
            svc, post=AsyncMock(return_value=_json_response({"results": []}))
        )

        await svc.enhanced_search("test", use_compression=True)
        payload = client.post.call_args[1]["json"]
        assert payload["use_compression"] is True

    @pytest.mark.asyncio
    async def test_enhanced_search_without_rerank(self, svc):
        """use_rerank=False is forwarded in the payload."""
        svc._knowledge_base_ids = ["kb-1"]
        client = _install_client(
            svc, post=AsyncMock(return_value=_json_response({"results": []}))
        )

        await svc.enhanced_search("test", use_rerank=False)
        payload = client.post.call_args[1]["json"]
        assert payload["use_rerank"] is False


# ---------------------------------------------------------------------------
# SemanticMemory enhanced search mode tests
# ---------------------------------------------------------------------------

class TestSemanticMemoryEnhancedSearch:
    """SemanticMemory search_mode — enhanced retrieval mode."""

    @pytest.mark.asyncio
    async def test_search_mode_enhanced(self):
        """search_mode="enhanced" calls enhanced_search."""
        rag = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        rag.enhanced_search = AsyncMock(return_value=[
            {"id": "c1", "content": "enhanced result", "score": 0.9, "source": "rag", "document_id": "d1"},
        ])

        semantic = SemanticMemory(
            rag_service=rag,
            knowledge_base_ids=["kb-1"],
            search_mode="enhanced",
            use_rerank=True,
            use_compression=False,
        )
        items = await semantic.search("test query", top_k=3)

        rag.enhanced_search.assert_called_once_with(
            "test query",
            knowledge_base_ids=["kb-1"],
            top_k=3,
            use_rerank=True,
            use_compression=False,
        )
        assert len(items) == 1
        assert items[0].value == "enhanced result"

    @pytest.mark.asyncio
    async def test_search_mode_standard(self):
        """search_mode="standard" calls the standard search."""
        rag = HttpRAGService(
            base_url="http://localhost:8000/api/knowledge",
            knowledge_base_ids=["kb-1"],
        )
        resp = _json_response({
            "results": [
                {"chunk_id": "c1", "content": "standard result", "score": 0.8, "document_id": "d1"},
            ]
        })
        client = _install_client(rag, post=AsyncMock(return_value=resp))

        semantic = SemanticMemory(
            rag_service=rag,
            knowledge_base_ids=["kb-1"],
            search_mode="standard",
        )
        items = await semantic.search("test query", top_k=3)

        assert len(items) == 1
        assert items[0].value == "standard result"
        # The standard /search endpoint was hit, not /bases/{kb_id}/retrieve.
        call_args = client.post.call_args
        assert call_args[0][0] == "/search"

    @pytest.mark.asyncio
    async def test_search_mode_enhanced_fallback(self):
        """search_mode="enhanced" falls back to search when the service lacks enhanced_search."""

        class SimpleRAGService:
            """A RAG service without enhanced_search."""

            async def search(self, query, knowledge_base_ids=None, top_k=5):
                return [{"id": "c1", "content": "simple result", "score": 0.7, "source": "rag", "document_id": "d1"}]

        rag = SimpleRAGService()
        semantic = SemanticMemory(
            rag_service=rag,
            knowledge_base_ids=["kb-1"],
            search_mode="enhanced",
        )
        items = await semantic.search("test query", top_k=3)

        assert len(items) == 1
        assert items[0].value == "simple result"


# ---------------------------------------------------------------------------
# Config enhanced search tests
# ---------------------------------------------------------------------------

class TestConfigEnhancedSearch:
    """ServerConfig parsing of enhanced-search options."""

    def test_config_search_mode(self):
        from agentkit.server.config import ServerConfig

        data = {
            "memory": {
                "semantic": {
                    "enabled": True,
                    "base_url": "http://geo:8000/api/knowledge",
                    "api_key": "sk-test",
                    "knowledge_base_ids": ["kb-1"],
                    "search_mode": "enhanced",
                },
            },
        }
        config = ServerConfig.from_dict(data)
        assert config.memory["semantic"]["search_mode"] == "enhanced"

    def test_config_use_rerank(self):
        from agentkit.server.config import ServerConfig

        data = {
            "memory": {
                "semantic": {
                    "enabled": True,
                    "base_url": "http://geo:8000/api/knowledge",
                    "api_key": "sk-test",
                    "knowledge_base_ids": ["kb-1"],
                    "use_rerank": False,
                    "use_compression": True,
                },
            },
        }
        config = ServerConfig.from_dict(data)
        assert config.memory["semantic"]["use_rerank"] is False
        assert config.memory["semantic"]["use_compression"] is True