238 lines
8.4 KiB
Python
238 lines
8.4 KiB
Python
"""MemoryRetriever 单元测试 - 混合检索器
|
||
|
||
使用 InMemoryMemory 实现进行测试,不需要真实 Redis/PG 环境。
|
||
"""
|
||
|
||
from unittest.mock import AsyncMock
|
||
|
||
import pytest
|
||
|
||
from agentkit.memory.base import Memory, MemoryItem
|
||
from agentkit.memory.retriever import MemoryRetriever
|
||
|
||
|
||
# ── In-Memory Memory 实现(用于测试) ────────────────────
|
||
|
||
|
||
class InMemoryMemory(Memory):
    """Dictionary-backed ``Memory`` implementation used as a test double.

    All items are kept in a plain in-process dict keyed by memory key.
    """

    def __init__(self):
        # Maps each key to the MemoryItem most recently stored under it.
        self._store: dict[str, MemoryItem] = {}

    async def store(self, key: str, value, metadata=None) -> None:
        """Save ``value`` under ``key``, replacing any existing entry.

        A falsy ``metadata`` is stored as an empty dict, and every item is
        created with a score of 1.0.
        """
        entry = MemoryItem(key=key, value=value, metadata=metadata or {}, score=1.0)
        self._store[key] = entry

    async def retrieve(self, key: str) -> MemoryItem | None:
        """Return the item stored under ``key``, or ``None`` when absent."""
        return self._store.get(key, None)

    async def search(self, query: str, top_k: int = 5, filters=None) -> list[MemoryItem]:
        """Return up to ``top_k`` items whose value or key contains ``query``.

        Matching is a case-insensitive substring test against the
        stringified value first and the key second. Items are returned in
        the dict's iteration order. ``filters`` is accepted but not used.
        """

        def _matches(entry: MemoryItem) -> bool:
            # Check the value before the key, short-circuiting on a hit.
            if query.lower() in str(entry.value).lower():
                return True
            return query.lower() in entry.key.lower()

        hits = [entry for entry in self._store.values() if _matches(entry)]
        return hits[:top_k]

    async def delete(self, key: str) -> bool:
        """Remove ``key`` and report whether an item was actually removed."""
        removed = self._store.pop(key, None)
        return removed is not None
|
||
|
||
|
||
# ── MemoryRetriever 测试 ─────────────────────────────────
|
||
|
||
|
||
class TestMemoryRetrieverParallelQuery:
    """Queries issued against one or several configured memory layers."""

    async def test_parallel_query_across_layers(self):
        """A query matching all three layers yields at least three results."""
        working, episodic, semantic = InMemoryMemory(), InMemoryMemory(), InMemoryMemory()

        # One matching item per layer, each mentioning "AI".
        seeds = (
            (working, "w1", "Working memory content about AI"),
            (episodic, "e1", "Episodic memory content about AI"),
            (semantic, "s1", "Semantic memory content about AI"),
        )
        for layer, key, text in seeds:
            await layer.store(key, text)

        retriever = MemoryRetriever(
            working_memory=working,
            episodic_memory=episodic,
            semantic_memory=semantic,
        )

        hits = await retriever.retrieve("AI")
        assert len(hits) >= 3

    async def test_single_layer_query(self):
        """Retrieval works when only the working layer is configured."""
        only_layer = InMemoryMemory()
        await only_layer.store("w1", "Only working memory result")

        hits = await MemoryRetriever(working_memory=only_layer).retrieve("working")
        assert len(hits) >= 1
|
||
|
||
|
||
class TestMemoryRetrieverWeightFusion:
    """Weight configuration and its effect on fused scores."""

    async def test_weight_based_fusion_sorting(self):
        """An item from the heavier-weighted layer scores above the lighter one."""
        working = InMemoryMemory()
        semantic = InMemoryMemory()
        await working.store("w1", "Working memory result")
        await semantic.store("s1", "Semantic memory result")

        # The semantic layer is weighted far above the working layer.
        retriever = MemoryRetriever(
            working_memory=working,
            semantic_memory=semantic,
            weights={"working": 0.1, "semantic": 0.9},
        )

        fused = await retriever.retrieve("result")
        assert len(fused) >= 2

        from_semantic = [hit for hit in fused if hit.key == "s1"]
        from_working = [hit for hit in fused if hit.key == "w1"]
        # NOTE(review): the score comparison is skipped when either key is
        # missing from the results, so this check can pass without running.
        both_present = bool(from_semantic) and bool(from_working)
        if both_present:
            assert from_semantic[0].score > from_working[0].score

    async def test_default_weights(self):
        """A retriever built without arguments uses the default weights."""
        expected = {"working": 0.2, "episodic": 0.4, "semantic": 0.4}
        retriever = MemoryRetriever()
        assert retriever._weights == expected

    async def test_custom_weights(self):
        """Weights passed to the constructor are readable via ``_weights``."""
        retriever = MemoryRetriever(
            weights={"working": 0.5, "episodic": 0.3, "semantic": 0.2}
        )
        # Expected values are listed separately from the dict handed to the
        # constructor, so the check does not depend on that object.
        for layer, weight in (("working", 0.5), ("episodic", 0.3), ("semantic", 0.2)):
            assert retriever._weights[layer] == weight
|
||
|
||
|
||
class TestMemoryRetrieverTokenBudget:
    """Behaviour of the ``token_budget`` argument of ``retrieve``."""

    async def test_token_budget_truncation(self):
        """Results are cut down when they would exceed the token budget."""
        working = InMemoryMemory()
        # Twenty long entries, each a short phrase repeated 50 times.
        for i in range(20):
            await working.store(f"item_{i}", f"Long content item number {i} " * 50)

        retriever = MemoryRetriever(working_memory=working)
        hits = await retriever.retrieve("content", token_budget=200)

        # The test estimates tokens as characters // 4 and tolerates a small
        # overshoot above the 200-token budget.
        char_count = sum(len(str(hit.value)) for hit in hits)
        estimated_tokens = char_count // 4
        assert estimated_tokens <= 250

    async def test_large_budget_returns_more(self):
        """A larger budget never returns fewer results than a smaller one."""
        working = InMemoryMemory()
        for i in range(10):
            await working.store(f"item_{i}", f"Content item {i}")

        retriever = MemoryRetriever(working_memory=working)
        # Same query, same retriever: small budget first, then large.
        with_small = await retriever.retrieve("Content", token_budget=10)
        with_large = await retriever.retrieve("Content", token_budget=10000)

        assert len(with_large) >= len(with_small)

    async def test_zero_budget_returns_empty(self):
        """A budget of zero yields no results even when an item matches."""
        working = InMemoryMemory()
        await working.store("w1", "Some content")

        hits = await MemoryRetriever(working_memory=working).retrieve(
            "content", token_budget=0
        )
        assert len(hits) == 0
|
||
|
||
|
||
class TestMemoryRetrieverMissingLayer:
    """Retrieval when layers are absent or failing."""

    async def test_missing_memory_layer_doesnt_break(self):
        """Explicitly passing ``None`` for two layers still returns results."""
        working = InMemoryMemory()
        await working.store("w1", "Working memory only")

        # Only the working layer is real; the other two are None.
        retriever = MemoryRetriever(
            working_memory=working,
            episodic_memory=None,
            semantic_memory=None,
        )

        hits = await retriever.retrieve("Working")
        assert len(hits) >= 1

    async def test_no_memory_layers_returns_empty(self):
        """A retriever with no layers at all returns an empty list."""
        hits = await MemoryRetriever().retrieve("anything")
        assert hits == []

    async def test_exception_in_layer_doesnt_break(self):
        """An exception raised by one layer does not suppress the others."""
        working = InMemoryMemory()
        await working.store("w1", "Working memory result")

        # Stand-in episodic layer whose search always raises.
        broken_layer = AsyncMock()
        broken_layer.search = AsyncMock(side_effect=Exception("Service unavailable"))

        retriever = MemoryRetriever(
            working_memory=working,
            episodic_memory=broken_layer,
        )

        # The working layer's match must still come back.
        hits = await retriever.retrieve("Working")
        assert len(hits) >= 1
|
||
|
||
|
||
class TestMemoryRetrieverContextString:
    """Tests for ``MemoryRetriever.get_context_string``."""

    async def test_get_context_string_returns_formatted_string(self):
        """The context is a string that includes the matching content."""
        working = InMemoryMemory()
        for key, text in (
            ("ctx1", "Context about Python programming"),
            ("ctx2", "Context about AI research"),
        ):
            await working.store(key, text)

        retriever = MemoryRetriever(working_memory=working)
        rendered = await retriever.get_context_string("Python")

        assert isinstance(rendered, str)
        assert "Python" in rendered

    async def test_get_context_string_empty_result(self):
        """A query with no match still produces a string."""
        working = InMemoryMemory()
        await working.store("ctx1", "Unrelated content")

        retriever = MemoryRetriever(working_memory=working)
        rendered = await retriever.get_context_string("nonexistent_topic")

        # InMemoryMemory.search also matches on keys, so what comes back
        # depends on the query.
        # NOTE(review): only the type is asserted, not that it is empty.
        assert isinstance(rendered, str)

    async def test_get_context_string_multiple_items(self):
        """Multiple results are separated by a blank line."""
        working = InMemoryMemory()
        await working.store("ctx1", "First context item about testing")
        await working.store("ctx2", "Second context item about testing")

        retriever = MemoryRetriever(working_memory=working)
        rendered = await retriever.get_context_string("testing")

        # NOTE(review): the separator is only checked when both items made
        # it into the context; otherwise nothing is asserted.
        has_both = "First" in rendered and "Second" in rendered
        if has_both:
            assert "\n\n" in rendered
|