363 lines
13 KiB
Python
363 lines
13 KiB
Python
"""U4 测试: RetrieveKnowledgeTool - RAG 管线内置工具
|
|
|
|
测试 retrieve_knowledge 工具的创建、执行、自动注册和集成。
|
|
"""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from agentkit.memory.base import Memory, MemoryItem
|
|
from agentkit.memory.retriever import MemoryRetriever, RetrieveKnowledgeTool
|
|
from agentkit.tools.base import Tool
|
|
|
|
|
|
# ── In-Memory Memory 实现(用于测试) ────────────────────
|
|
|
|
|
|
class InMemoryMemory(Memory):
    """Dict-backed ``Memory`` implementation used as a test double.

    Every stored item is given a fixed score of 1.0. ``search`` performs a
    case-insensitive substring match against each item's stringified value
    and its key, returning matches in insertion order. The ``filters``
    argument is accepted for interface compatibility but is ignored.
    """

    def __init__(self):
        # key -> stored item; dict insertion order determines search order.
        self._store: dict[str, MemoryItem] = {}

    async def store(self, key: str, value, metadata: dict | None = None) -> None:
        """Store (or overwrite) ``value`` under ``key``.

        Args:
            key: Identifier of the item.
            value: Payload to store; stringified when searched.
            metadata: Optional metadata dict; ``None`` becomes ``{}``.
        """
        self._store[key] = MemoryItem(
            key=key, value=value, metadata=metadata or {}, score=1.0
        )

    async def retrieve(self, key: str) -> MemoryItem | None:
        """Return the item stored under ``key``, or ``None`` if absent."""
        return self._store.get(key)

    async def search(self, query: str, top_k: int = 5, filters=None) -> list[MemoryItem]:
        """Return up to ``top_k`` items whose value or key contains ``query``.

        Matching is case-insensitive substring matching. ``filters`` is
        ignored by this test double.
        """
        # Lower-case the query once rather than twice per stored item.
        needle = query.lower()
        matches = [
            item
            for item in self._store.values()
            if needle in str(item.value).lower() or needle in item.key.lower()
        ]
        return matches[:top_k]

    async def delete(self, key: str) -> bool:
        """Remove ``key``; return True if an item was stored under it."""
        return self._store.pop(key, None) is not None
|
|
|
|
|
|
# ── TestRetrieveKnowledgeToolCreation ──────────────────────
|
|
|
|
|
|
class TestRetrieveKnowledgeToolCreation:
    """Tests covering how MemoryRetriever.create_retrieve_tool() builds the tool."""

    @staticmethod
    def _semantic_backed_tool():
        """Return the tool built by a retriever backed by an empty semantic store."""
        backing_store = InMemoryMemory()
        return MemoryRetriever(semantic_memory=backing_store).create_retrieve_tool()

    def test_create_retrieve_tool_returns_tool_when_semantic_configured(self):
        """With a semantic memory configured, a Tool is returned."""
        built = self._semantic_backed_tool()

        assert built is not None
        assert isinstance(built, Tool)

    def test_create_retrieve_tool_returns_none_when_no_semantic(self):
        """With no semantic memory at all, no tool is produced."""
        assert MemoryRetriever().create_retrieve_tool() is None

    def test_create_retrieve_tool_with_working_only_returns_none(self):
        """A working memory on its own is not enough to get a retrieval tool."""
        working_only = MemoryRetriever(working_memory=InMemoryMemory())

        assert working_only.create_retrieve_tool() is None

    def test_tool_has_correct_name(self):
        """The tool is exposed under the name 'retrieve_knowledge'."""
        assert self._semantic_backed_tool().name == "retrieve_knowledge"

    def test_tool_has_description(self):
        """The tool carries a non-empty string description."""
        text = self._semantic_backed_tool().description

        assert isinstance(text, str)
        assert text

    def test_tool_has_input_schema(self):
        """The input schema is an object schema with a required 'query' property."""
        schema = self._semantic_backed_tool().input_schema

        assert schema is not None
        assert schema["type"] == "object"
        assert "query" in schema["properties"]
        assert "query" in schema["required"]

    def test_tool_is_retrieve_knowledge_tool_instance(self):
        """The concrete tool type is RetrieveKnowledgeTool."""
        assert isinstance(self._semantic_backed_tool(), RetrieveKnowledgeTool)
|
|
|
|
|
|
# ── TestRetrieveKnowledgeToolExecution ─────────────────────
|
|
|
|
|
|
class TestRetrieveKnowledgeToolExecution:
    """Tests for RetrieveKnowledgeTool.execute()."""

    async def test_execute_calls_retriever_retrieve(self):
        """execute() goes through MemoryRetriever.retrieve() and returns its hits."""
        semantic = InMemoryMemory()
        await semantic.store("s1", "AI趋势报告", metadata={"source": "report.pdf"})
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool()

        result = await tool.execute(query="AI趋势")

        assert "results" in result
        assert len(result["results"]) >= 1

    async def test_execute_results_formatted_correctly(self):
        """Each result item exposes content, score, source and document_title."""
        semantic = InMemoryMemory()
        await semantic.store(
            "s1",
            "AI趋势报告内容",
            metadata={"source": "report.pdf", "document_title": "2024 AI Report"},
        )
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool()

        result = await tool.execute(query="AI趋势")

        assert "results" in result
        # Without this, an empty result list would make the loop below pass
        # without checking anything.
        assert result["results"]
        for item in result["results"]:
            assert "content" in item
            assert "score" in item
            assert "source" in item
            assert "document_title" in item

    async def test_execute_empty_query_returns_error(self):
        """An empty query yields an error and no results."""
        semantic = InMemoryMemory()
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool()

        result = await tool.execute(query="")

        assert "error" in result
        assert result["results"] == []

    async def test_execute_max_calls_limit(self):
        """Calls beyond max_calls return an error instead of results."""
        semantic = InMemoryMemory()
        await semantic.store("s1", "Some content")
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool(max_calls=3)

        # Every call within the limit must succeed and report its running count.
        # Use .get() so a success response carrying "error": None still passes.
        for expected_count in range(1, 4):
            result = await tool.execute(query="content")
            assert not result.get("error")
            assert result["call_count"] == expected_count

        # The 4th call exceeds the limit.
        result = await tool.execute(query="content")

        assert "error" in result
        assert "Maximum retrieval calls" in result["error"]
        assert result["results"] == []

    async def test_execute_call_count_tracking(self):
        """call_count in the response increments with each invocation."""
        semantic = InMemoryMemory()
        await semantic.store("s1", "Some content")
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool(max_calls=5)

        for i in range(1, 4):
            result = await tool.execute(query="content")
            assert result["call_count"] == i

    async def test_execute_exception_handling(self):
        """An exception raised by the retriever is turned into an error response."""
        retriever = MemoryRetriever(semantic_memory=InMemoryMemory())
        tool = retriever.create_retrieve_tool()

        # Make the retriever used by the tool fail.
        tool._retriever.retrieve = AsyncMock(side_effect=Exception("Service unavailable"))

        result = await tool.execute(query="test")

        assert "error" in result
        assert "Service unavailable" in result["error"]
        assert result["results"] == []

    async def test_execute_returns_query_in_response(self):
        """The response echoes the original query."""
        semantic = InMemoryMemory()
        await semantic.store("s1", "Some content")
        retriever = MemoryRetriever(semantic_memory=semantic)
        tool = retriever.create_retrieve_tool()

        result = await tool.execute(query="AI趋势")

        assert result["query"] == "AI趋势"
|
|
|
|
|
|
# ── TestRetrieveKnowledgeToolAutoRegistration ──────────────
|
|
|
|
|
|
class TestRetrieveKnowledgeToolAutoRegistration:
    """Tests that ConfigDrivenAgent registers retrieve_knowledge only when semantic memory is configured."""

    @staticmethod
    def _config_dict(with_semantic: bool) -> dict:
        """Build a fresh minimal agent config dict.

        Args:
            with_semantic: When True, add an enabled ``memory.semantic`` section.
        """
        data = {
            "name": "test_agent",
            "agent_type": "test",
            "task_mode": "llm_generate",
            "prompt": {
                "identity": "Test agent",
                "instructions": "Test",
            },
        }
        if with_semantic:
            data["memory"] = {
                "semantic": {
                    "enabled": True,
                    "base_url": "http://localhost:8080",
                    "knowledge_base_ids": ["kb1"],
                },
            }
        return data

    @classmethod
    def _build_agent_with_semantic_memory(cls):
        """Construct a ConfigDrivenAgent with semantic memory, stubbing the RAG backend."""
        from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent

        config = AgentConfig.from_dict(cls._config_dict(with_semantic=True))

        # Patch the imports performed inside the try block of
        # ConfigDrivenAgent.__init__ so no real HTTP RAG service is created.
        with (
            patch("agentkit.memory.http_rag.HttpRAGService"),
            patch("agentkit.memory.semantic.SemanticMemory") as mock_sem,
        ):
            mock_sem.return_value = InMemoryMemory()
            return ConfigDrivenAgent(config=config)

    def test_agent_with_semantic_memory_has_tool(self):
        """With semantic memory configured, retrieve_knowledge is auto-registered."""
        agent = self._build_agent_with_semantic_memory()

        tool_names = [t.name for t in agent._tools]
        assert "retrieve_knowledge" in tool_names

    def test_agent_without_semantic_memory_does_not_have_tool(self):
        """Without semantic memory, retrieve_knowledge is not registered."""
        from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent

        config = AgentConfig.from_dict(self._config_dict(with_semantic=False))
        agent = ConfigDrivenAgent(config=config)

        tool_names = [t.name for t in agent._tools]
        assert "retrieve_knowledge" not in tool_names

    def test_auto_registered_tool_is_retrieve_knowledge_instance(self):
        """Exactly one retrieve_knowledge tool is registered, and it is a RetrieveKnowledgeTool."""
        agent = self._build_agent_with_semantic_memory()

        retrieve_tools = [t for t in agent._tools if t.name == "retrieve_knowledge"]
        assert len(retrieve_tools) == 1
        assert isinstance(retrieve_tools[0], RetrieveKnowledgeTool)
|
|
|
|
|
|
# ── TestRetrieveKnowledgeToolIntegration ───────────────────
|
|
|
|
|
|
class TestRetrieveKnowledgeToolIntegration:
|
|
"""RetrieveKnowledgeTool 集成测试"""
|
|
|
|
async def test_tool_works_with_query_transformer(self):
    """The tool still returns hits when the retriever has a query transformer."""
    from agentkit.memory.query_transformer import QueryTransformerBase, TransformedQuery

    class SimpleTransformer(QueryTransformerBase):
        async def transform(self, query: str) -> TransformedQuery:
            # Prefix the query and produce no sub-queries.
            return TransformedQuery(
                main_query=f"enhanced: {query}",
                sub_queries=[],
            )

    semantic = InMemoryMemory()
    await semantic.store("s1", "enhanced: AI trends data")
    retriever = MemoryRetriever(
        semantic_memory=semantic,
        query_transformer=SimpleTransformer(),
    )
    tool = retriever.create_retrieve_tool()

    result = await tool.execute(query="AI")

    assert "results" in result
    # The raw query ("AI") and the transformed one ("enhanced: AI") both match
    # the stored text, so a hit is expected either way. This checks that the
    # transformer does not break retrieval; it does not prove the transformer
    # was applied.
    assert result["results"]
|
|
|
|
async def test_tool_returns_structured_results_for_llm(self):
|
|
"""工具返回 LLM 可用的结构化结果"""
|
|
semantic = InMemoryMemory()
|
|
await semantic.store(
|
|
"s1",
|
|
"GEO optimization improves brand visibility",
|
|
metadata={"source": "guide.md", "document_title": "GEO Guide"},
|
|
)
|
|
await semantic.store(
|
|
"s2",
|
|
"Another relevant document about SEO",
|
|
metadata={"source": "seo.md", "document_title": "SEO Basics"},
|
|
)
|
|
retriever = MemoryRetriever(semantic_memory=semantic)
|
|
tool = retriever.create_retrieve_tool()
|
|
|
|
result = await tool.execute(query="optimization")
|
|
|
|
assert isinstance(result, dict)
|
|
assert "query" in result
|
|
assert "results" in result
|
|
assert "call_count" in result
|
|
assert isinstance(result["results"], list)
|
|
for item in result["results"]:
|
|
assert isinstance(item, dict)
|
|
assert "content" in item
|
|
assert "score" in item
|