# fischer-agentkit/tests/unit/test_retrieve_knowledge_too...
#
# 363 lines
# 13 KiB
# Python

"""U4 tests: RetrieveKnowledgeTool - built-in tool of the RAG pipeline.

Covers creation, execution, auto-registration, and integration of the
retrieve_knowledge tool.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.retriever import MemoryRetriever, RetrieveKnowledgeTool
from agentkit.tools.base import Tool
# ── In-memory Memory implementation (for tests) ────────────────────
class InMemoryMemory(Memory):
    """Dictionary-backed Memory implementation used as a test double."""

    def __init__(self):
        # Maps each key to the MemoryItem most recently stored under it.
        self._store: dict[str, MemoryItem] = {}

    async def store(self, key: str, value, metadata=None) -> None:
        """Store ``value`` under ``key``, replacing any existing item.

        A falsy ``metadata`` is recorded as an empty dict and the score is
        always 1.0.
        """
        entry = MemoryItem(key=key, value=value, metadata=metadata or {}, score=1.0)
        self._store[key] = entry

    async def retrieve(self, key: str) -> MemoryItem | None:
        """Return the item stored under ``key``, or ``None`` when absent."""
        return self._store.get(key)

    async def search(self, query: str, top_k: int = 5, filters=None) -> list[MemoryItem]:
        """Return at most ``top_k`` items whose value or key contains ``query``.

        Matching is a case-insensitive substring test against
        ``str(item.value)`` and ``item.key``, visited in the store's iteration
        order. ``filters`` is accepted but not used.
        """
        hits = [
            entry
            for entry in self._store.values()
            if query.lower() in str(entry.value).lower()
            or query.lower() in entry.key.lower()
        ]
        return hits[:top_k]

    async def delete(self, key: str) -> bool:
        """Remove ``key`` from the store; return ``True`` if an item was removed."""
        removed = self._store.pop(key, None)
        return removed is not None
# ── TestRetrieveKnowledgeToolCreation ──────────────────────
class TestRetrieveKnowledgeToolCreation:
    """Creation tests for RetrieveKnowledgeTool."""

    def test_create_retrieve_tool_returns_tool_when_semantic_configured(self):
        """create_retrieve_tool() returns a Tool when semantic memory is set."""
        built = MemoryRetriever(semantic_memory=InMemoryMemory()).create_retrieve_tool()

        assert built is not None
        assert isinstance(built, Tool)

    def test_create_retrieve_tool_returns_none_when_no_semantic(self):
        """create_retrieve_tool() returns None when no semantic memory is set."""
        built = MemoryRetriever().create_retrieve_tool()

        assert built is None

    def test_create_retrieve_tool_with_working_only_returns_none(self):
        """create_retrieve_tool() returns None when only working memory is set."""
        built = MemoryRetriever(working_memory=InMemoryMemory()).create_retrieve_tool()

        assert built is None

    def test_tool_has_correct_name(self):
        """The tool is named ``retrieve_knowledge``."""
        built = MemoryRetriever(semantic_memory=InMemoryMemory()).create_retrieve_tool()

        assert built.name == "retrieve_knowledge"

    def test_tool_has_description(self):
        """The tool exposes a non-empty string description."""
        built = MemoryRetriever(semantic_memory=InMemoryMemory()).create_retrieve_tool()
        text = built.description

        assert isinstance(text, str)
        assert len(text) > 0

    def test_tool_has_input_schema(self):
        """The tool exposes an object schema with a required ``query`` property."""
        built = MemoryRetriever(semantic_memory=InMemoryMemory()).create_retrieve_tool()

        assert built.input_schema is not None
        assert built.input_schema["type"] == "object"
        assert "query" in built.input_schema["properties"]
        assert "query" in built.input_schema["required"]

    def test_tool_is_retrieve_knowledge_tool_instance(self):
        """The created tool is a RetrieveKnowledgeTool instance."""
        built = MemoryRetriever(semantic_memory=InMemoryMemory()).create_retrieve_tool()

        assert isinstance(built, RetrieveKnowledgeTool)
# ── TestRetrieveKnowledgeToolExecution ─────────────────────
class TestRetrieveKnowledgeToolExecution:
    """Execution tests for RetrieveKnowledgeTool.

    NOTE(review): the async tests carry no explicit asyncio marker; presumably
    the project runs pytest-asyncio in auto mode - confirm against the pytest
    configuration.
    """

    async def test_execute_calls_retriever_retrieve(self):
        """execute() returns at least one result for a query matching stored content."""
        semantic = InMemoryMemory()
        await semantic.store("s1", "AI趋势报告", metadata={"source": "report.pdf"})
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool()

        result = await tool.execute(query="AI趋势")

        assert "results" in result
        assert len(result["results"]) >= 1

    async def test_execute_results_formatted_correctly(self):
        """Each result carries content, score, source and document_title."""
        semantic = InMemoryMemory()
        await semantic.store(
            "s1",
            "AI趋势报告内容",
            metadata={"source": "report.pdf", "document_title": "2024 AI Report"},
        )
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool()

        result = await tool.execute(query="AI趋势")

        assert "results" in result
        for item in result["results"]:
            assert "content" in item
            assert "score" in item
            assert "source" in item
            assert "document_title" in item

    async def test_execute_empty_query_returns_error(self):
        """An empty query yields an error and an empty result list."""
        semantic = InMemoryMemory()
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool()

        result = await tool.execute(query="")

        assert "error" in result
        assert result["results"] == []

    async def test_execute_max_calls_limit(self):
        """Calls beyond ``max_calls`` return an error and no results."""
        semantic = InMemoryMemory()
        await semantic.store("s1", "Some content")
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool(max_calls=3)

        # The first 3 calls are within the limit and must succeed.
        for i in range(3):
            result = await tool.execute(query="content")
            # Checked separately: the former single "A or B" assertion also
            # accepted an error response as long as it carried the expected
            # call_count, so a failing call could slip through.
            assert not result.get("error")
            assert result.get("call_count") == i + 1

        # The 4th call exceeds the limit and must be rejected.
        result = await tool.execute(query="content")
        assert "error" in result
        assert "Maximum retrieval calls" in result["error"]
        assert result["results"] == []

    async def test_execute_call_count_tracking(self):
        """``call_count`` in the response increases by one per call."""
        semantic = InMemoryMemory()
        await semantic.store("s1", "Some content")
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool(max_calls=5)

        for i in range(1, 4):
            result = await tool.execute(query="content")
            assert result["call_count"] == i

    async def test_execute_exception_handling(self):
        """An exception raised by the retriever is reported as an error response."""
        retriever = MemoryRetriever(semantic_memory=InMemoryMemory())
        tool = retriever.create_retrieve_tool()
        # Replace retrieve() on the tool's retriever so the next call raises.
        tool._retriever.retrieve = AsyncMock(side_effect=Exception("Service unavailable"))

        result = await tool.execute(query="test")

        assert "error" in result
        assert "Service unavailable" in result["error"]
        assert result["results"] == []

    async def test_execute_returns_query_in_response(self):
        """The response echoes the original query string."""
        semantic = InMemoryMemory()
        await semantic.store("s1", "Some content")
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool()

        result = await tool.execute(query="AI趋势")

        assert result["query"] == "AI趋势"
# ── TestRetrieveKnowledgeToolAutoRegistration ──────────────
class TestRetrieveKnowledgeToolAutoRegistration:
    """Auto-registration tests for RetrieveKnowledgeTool on ConfigDrivenAgent."""

    @staticmethod
    def _make_config(with_semantic: bool):
        """Build the AgentConfig shared by the tests in this class.

        Args:
            with_semantic: when true, add an enabled ``memory.semantic``
                section to the configuration.

        Returns:
            The object returned by ``AgentConfig.from_dict``.
        """
        # Imported lazily, as the individual tests originally did.
        from agentkit.core.config_driven import AgentConfig

        raw = {
            "name": "test_agent",
            "agent_type": "test",
            "task_mode": "llm_generate",
            "prompt": {
                "identity": "Test agent",
                "instructions": "Test",
            },
        }
        if with_semantic:
            raw["memory"] = {
                "semantic": {
                    "enabled": True,
                    "base_url": "http://localhost:8080",
                    "knowledge_base_ids": ["kb1"],
                },
            }
        return AgentConfig.from_dict(raw)

    @classmethod
    def _make_semantic_agent(cls):
        """Create a ConfigDrivenAgent whose semantic memory is an InMemoryMemory."""
        from agentkit.core.config_driven import ConfigDrivenAgent

        config = cls._make_config(with_semantic=True)
        # Patch imports inside the try block of ConfigDrivenAgent.__init__
        with patch("agentkit.memory.http_rag.HttpRAGService"), \
                patch("agentkit.memory.semantic.SemanticMemory") as mock_sem:
            mock_sem.return_value = InMemoryMemory()
            return ConfigDrivenAgent(config=config)

    def test_agent_with_semantic_memory_has_tool(self):
        """An agent configured with semantic memory registers retrieve_knowledge."""
        agent = self._make_semantic_agent()

        tool_names = [t.name for t in agent._tools]
        assert "retrieve_knowledge" in tool_names

    def test_agent_without_semantic_memory_does_not_have_tool(self):
        """An agent without semantic memory does not register retrieve_knowledge."""
        from agentkit.core.config_driven import ConfigDrivenAgent

        agent = ConfigDrivenAgent(config=self._make_config(with_semantic=False))

        tool_names = [t.name for t in agent._tools]
        assert "retrieve_knowledge" not in tool_names

    def test_auto_registered_tool_is_retrieve_knowledge_instance(self):
        """The auto-registered tool is a single RetrieveKnowledgeTool instance."""
        agent = self._make_semantic_agent()

        retrieve_tools = [t for t in agent._tools if t.name == "retrieve_knowledge"]
        assert len(retrieve_tools) == 1
        assert isinstance(retrieve_tools[0], RetrieveKnowledgeTool)
# ── TestRetrieveKnowledgeToolIntegration ───────────────────
class TestRetrieveKnowledgeToolIntegration:
    """Integration tests for RetrieveKnowledgeTool."""

    async def test_tool_works_with_query_transformer(self):
        """The tool still returns a ``results`` entry when a query transformer is set."""
        from agentkit.memory.query_transformer import QueryTransformerBase, TransformedQuery

        class SimpleTransformer(QueryTransformerBase):
            # Prefixes the incoming query with "enhanced: " and adds no sub-queries.
            async def transform(self, query: str) -> TransformedQuery:
                rewritten = f"enhanced: {query}"
                return TransformedQuery(main_query=rewritten, sub_queries=[])

        knowledge = InMemoryMemory()
        await knowledge.store("s1", "enhanced: AI trends data")
        tool = MemoryRetriever(
            semantic_memory=knowledge,
            query_transformer=SimpleTransformer(),
        ).create_retrieve_tool()

        response = await tool.execute(query="AI")

        assert "results" in response

    async def test_tool_returns_structured_results_for_llm(self):
        """The tool returns a dict with query, results and call_count for the LLM."""
        knowledge = InMemoryMemory()
        await knowledge.store(
            "s1",
            "GEO optimization improves brand visibility",
            metadata={"source": "guide.md", "document_title": "GEO Guide"},
        )
        await knowledge.store(
            "s2",
            "Another relevant document about SEO",
            metadata={"source": "seo.md", "document_title": "SEO Basics"},
        )
        tool = MemoryRetriever(semantic_memory=knowledge).create_retrieve_tool()

        response = await tool.execute(query="optimization")

        assert isinstance(response, dict)
        for top_level_key in ("query", "results", "call_count"):
            assert top_level_key in response
        assert isinstance(response["results"], list)
        for hit in response["results"]:
            assert isinstance(hit, dict)
            for field in ("content", "score"):
                assert field in hit