"""U4 测试: RetrieveKnowledgeTool - RAG 管线内置工具 测试 retrieve_knowledge 工具的创建、执行、自动注册和集成。 """ from unittest.mock import AsyncMock, MagicMock, patch import pytest from agentkit.memory.base import Memory, MemoryItem from agentkit.memory.retriever import MemoryRetriever, RetrieveKnowledgeTool from agentkit.tools.base import Tool # ── In-Memory Memory 实现(用于测试) ──────────────────── class InMemoryMemory(Memory): """基于内存的 Memory 实现,用于测试""" def __init__(self): self._store: dict[str, MemoryItem] = {} async def store(self, key: str, value, metadata=None) -> None: self._store[key] = MemoryItem( key=key, value=value, metadata=metadata or {}, score=1.0 ) async def retrieve(self, key: str) -> MemoryItem | None: return self._store.get(key) async def search(self, query: str, top_k: int = 5, filters=None) -> list[MemoryItem]: results = [] for item in self._store.values(): if query.lower() in str(item.value).lower() or query.lower() in item.key.lower(): results.append(item) return results[:top_k] async def delete(self, key: str) -> bool: return self._store.pop(key, None) is not None # ── TestRetrieveKnowledgeToolCreation ────────────────────── class TestRetrieveKnowledgeToolCreation: """RetrieveKnowledgeTool 创建测试""" def test_create_retrieve_tool_returns_tool_when_semantic_configured(self): """有 semantic memory 时 create_retrieve_tool() 返回 Tool""" semantic = InMemoryMemory() retriever = MemoryRetriever(semantic_memory=semantic) tool = retriever.create_retrieve_tool() assert tool is not None assert isinstance(tool, Tool) def test_create_retrieve_tool_returns_none_when_no_semantic(self): """无 semantic memory 时 create_retrieve_tool() 返回 None""" retriever = MemoryRetriever() tool = retriever.create_retrieve_tool() assert tool is None def test_create_retrieve_tool_with_working_only_returns_none(self): """仅有 working memory 时返回 None""" working = InMemoryMemory() retriever = MemoryRetriever(working_memory=working) tool = retriever.create_retrieve_tool() assert tool is None def test_tool_has_correct_name(self): """工具名称为 retrieve_knowledge""" semantic = InMemoryMemory() retriever = MemoryRetriever(semantic_memory=semantic) tool = retriever.create_retrieve_tool() assert tool.name == "retrieve_knowledge" def test_tool_has_description(self): """工具包含描述""" semantic = InMemoryMemory() retriever = MemoryRetriever(semantic_memory=semantic) tool = retriever.create_retrieve_tool() assert isinstance(tool.description, str) assert len(tool.description) > 0 def test_tool_has_input_schema(self): """工具包含 input_schema""" semantic = InMemoryMemory() retriever = MemoryRetriever(semantic_memory=semantic) tool = retriever.create_retrieve_tool() assert tool.input_schema is not None assert tool.input_schema["type"] == "object" assert "query" in tool.input_schema["properties"] assert "query" in tool.input_schema["required"] def test_tool_is_retrieve_knowledge_tool_instance(self): """工具是 RetrieveKnowledgeTool 实例""" semantic = InMemoryMemory() retriever = MemoryRetriever(semantic_memory=semantic) tool = retriever.create_retrieve_tool() assert isinstance(tool, RetrieveKnowledgeTool) # ── TestRetrieveKnowledgeToolExecution ───────────────────── class TestRetrieveKnowledgeToolExecution: """RetrieveKnowledgeTool 执行测试""" async def test_execute_calls_retriever_retrieve(self): """execute() 调用 MemoryRetriever.retrieve()""" semantic = InMemoryMemory() await semantic.store("s1", "AI趋势报告", metadata={"source": "report.pdf"}) retriever = MemoryRetriever(semantic_memory=semantic) tool = retriever.create_retrieve_tool() result = await tool.execute(query="AI趋势") assert "results" in result assert len(result["results"]) >= 1 async def test_execute_results_formatted_correctly(self): """结果包含 content, score, source, document_title""" semantic = InMemoryMemory() await semantic.store( "s1", "AI趋势报告内容", metadata={"source": "report.pdf", "document_title": "2024 AI Report"}, ) retriever = MemoryRetriever(semantic_memory=semantic) tool = retriever.create_retrieve_tool() result = await tool.execute(query="AI趋势") assert "results" in result for item in result["results"]: assert "content" in item assert "score" in item assert "source" in item assert "document_title" in item async def test_execute_empty_query_returns_error(self): """空 query 返回错误""" semantic = InMemoryMemory() retriever = MemoryRetriever(semantic_memory=semantic) tool = retriever.create_retrieve_tool() result = await tool.execute(query="") assert "error" in result assert result["results"] == [] async def test_execute_max_calls_limit(self): """超过 max_calls 限制后返回错误""" semantic = InMemoryMemory() await semantic.store("s1", "Some content") retriever = MemoryRetriever(semantic_memory=semantic) tool = retriever.create_retrieve_tool(max_calls=3) # 前 3 次调用应该成功 for i in range(3): result = await tool.execute(query="content") assert "error" not in result or result.get("call_count") == i + 1 # 第 4 次调用应该返回错误 result = await tool.execute(query="content") assert "error" in result assert "Maximum retrieval calls" in result["error"] assert result["results"] == [] async def test_execute_call_count_tracking(self): """call_count 在响应中正确跟踪""" semantic = InMemoryMemory() await semantic.store("s1", "Some content") retriever = MemoryRetriever(semantic_memory=semantic) tool = retriever.create_retrieve_tool(max_calls=5) for i in range(1, 4): result = await tool.execute(query="content") assert result["call_count"] == i async def test_execute_exception_handling(self): """retriever 抛出异常时返回错误响应""" retriever = MemoryRetriever(semantic_memory=InMemoryMemory()) tool = retriever.create_retrieve_tool() # Mock retriever.retrieve to raise exception tool._retriever.retrieve = AsyncMock(side_effect=Exception("Service unavailable")) result = await tool.execute(query="test") assert "error" in result assert "Service unavailable" in result["error"] assert result["results"] == [] async def test_execute_returns_query_in_response(self): """响应中包含原始查询""" semantic = InMemoryMemory() await semantic.store("s1", "Some content") retriever = MemoryRetriever(semantic_memory=semantic) tool = retriever.create_retrieve_tool() result = await tool.execute(query="AI趋势") assert result["query"] == "AI趋势" # ── TestRetrieveKnowledgeToolAutoRegistration ────────────── class TestRetrieveKnowledgeToolAutoRegistration: """RetrieveKnowledgeTool 自动注册测试""" def test_agent_with_semantic_memory_has_tool(self): """ConfigDrivenAgent 配置了 semantic memory 时自动注册 retrieve_knowledge""" from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent config = AgentConfig.from_dict({ "name": "test_agent", "agent_type": "test", "task_mode": "llm_generate", "prompt": { "identity": "Test agent", "instructions": "Test", }, "memory": { "semantic": { "enabled": True, "base_url": "http://localhost:8080", "knowledge_base_ids": ["kb1"], }, }, }) # Patch imports inside the try block of ConfigDrivenAgent.__init__ with patch("agentkit.memory.http_rag.HttpRAGService") as mock_rag, \ patch("agentkit.memory.semantic.SemanticMemory") as mock_sem: mock_sem.return_value = InMemoryMemory() agent = ConfigDrivenAgent(config=config) tool_names = [t.name for t in agent._tools] assert "retrieve_knowledge" in tool_names def test_agent_without_semantic_memory_does_not_have_tool(self): """ConfigDrivenAgent 未配置 semantic memory 时不注册 retrieve_knowledge""" from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent config = AgentConfig.from_dict({ "name": "test_agent", "agent_type": "test", "task_mode": "llm_generate", "prompt": { "identity": "Test agent", "instructions": "Test", }, }) agent = ConfigDrivenAgent(config=config) tool_names = [t.name for t in agent._tools] assert "retrieve_knowledge" not in tool_names def test_auto_registered_tool_is_retrieve_knowledge_instance(self): """自动注册的工具是 RetrieveKnowledgeTool 实例""" from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent config = AgentConfig.from_dict({ "name": "test_agent", "agent_type": "test", "task_mode": "llm_generate", "prompt": { "identity": "Test agent", "instructions": "Test", }, "memory": { "semantic": { "enabled": True, "base_url": "http://localhost:8080", "knowledge_base_ids": ["kb1"], }, }, }) with patch("agentkit.memory.http_rag.HttpRAGService"), \ patch("agentkit.memory.semantic.SemanticMemory") as mock_sem: mock_sem.return_value = InMemoryMemory() agent = ConfigDrivenAgent(config=config) retrieve_tools = [t for t in agent._tools if t.name == "retrieve_knowledge"] assert len(retrieve_tools) == 1 assert isinstance(retrieve_tools[0], RetrieveKnowledgeTool) # ── TestRetrieveKnowledgeToolIntegration ─────────────────── class TestRetrieveKnowledgeToolIntegration: """RetrieveKnowledgeTool 集成测试""" async def test_tool_works_with_query_transformer(self): """工具配合 query transformer 工作""" from agentkit.memory.query_transformer import QueryTransformerBase, TransformedQuery class SimpleTransformer(QueryTransformerBase): async def transform(self, query: str) -> TransformedQuery: return TransformedQuery( main_query=f"enhanced: {query}", sub_queries=[], ) semantic = InMemoryMemory() await semantic.store("s1", "enhanced: AI trends data") retriever = MemoryRetriever( semantic_memory=semantic, query_transformer=SimpleTransformer(), ) tool = retriever.create_retrieve_tool() result = await tool.execute(query="AI") assert "results" in result async def test_tool_returns_structured_results_for_llm(self): """工具返回 LLM 可用的结构化结果""" semantic = InMemoryMemory() await semantic.store( "s1", "GEO optimization improves brand visibility", metadata={"source": "guide.md", "document_title": "GEO Guide"}, ) await semantic.store( "s2", "Another relevant document about SEO", metadata={"source": "seo.md", "document_title": "SEO Basics"}, ) retriever = MemoryRetriever(semantic_memory=semantic) tool = retriever.create_retrieve_tool() result = await tool.execute(query="optimization") assert isinstance(result, dict) assert "query" in result assert "results" in result assert "call_count" in result assert isinstance(result["results"], list) for item in result["results"]: assert isinstance(item, dict) assert "content" in item assert "score" in item