fischer-agentkit/tests/unit/test_memory_retriever.py

238 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""MemoryRetriever 单元测试 - 混合检索器
使用 InMemoryMemory 实现进行测试,不需要真实 Redis/PG 环境。
"""
from unittest.mock import AsyncMock
import pytest
from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.retriever import MemoryRetriever
# ── In-Memory Memory 实现(用于测试) ────────────────────
class InMemoryMemory(Memory):
    """Dictionary-backed ``Memory`` used as a test double.

    Lets the retriever tests run without a real Redis/PostgreSQL backend.
    Items live in a plain ``dict`` keyed by ``key``; every stored item
    carries a fixed score of ``1.0``.
    """

    def __init__(self):
        # key -> MemoryItem; dict insertion order is the order of search hits.
        self._store: dict[str, MemoryItem] = {}

    async def store(self, key: str, value, metadata=None) -> None:
        """Create (or overwrite) the item under ``key`` with score 1.0.

        A falsy ``metadata`` (``None`` or empty) is replaced by a new empty dict.
        """
        item = MemoryItem(
            key=key,
            value=value,
            metadata=metadata if metadata else {},
            score=1.0,
        )
        self._store[key] = item

    async def retrieve(self, key: str) -> MemoryItem | None:
        """Return the item stored under ``key``, or ``None`` if there is none."""
        return self._store.get(key)

    async def search(self, query: str, top_k: int = 5, filters=None) -> list[MemoryItem]:
        """Case-insensitive substring search over item values and keys.

        An item matches when ``query`` occurs (ignoring case) in either
        ``str(item.value)`` or ``item.key``. ``filters`` is accepted only for
        interface compatibility and is ignored. Hits keep insertion order and
        are truncated with ``[:top_k]``.

        NOTE: the returned objects are the stored ``MemoryItem`` instances
        themselves, not copies.
        """
        needle = query.lower()
        hits = [
            item
            for item in self._store.values()
            if needle in str(item.value).lower() or needle in item.key.lower()
        ]
        return hits[:top_k]

    async def delete(self, key: str) -> bool:
        """Remove ``key``; return ``True`` only if an item was actually removed."""
        removed = self._store.pop(key, None)
        return removed is not None
# ── MemoryRetriever 测试 ─────────────────────────────────
class TestMemoryRetrieverParallelQuery:
    """Querying several memory layers in one retrieval."""

    async def test_parallel_query_across_layers(self):
        """One matching item per layer: the retriever should surface at least three hits."""
        layers = {
            "working": InMemoryMemory(),
            "episodic": InMemoryMemory(),
            "semantic": InMemoryMemory(),
        }
        await layers["working"].store("w1", "Working memory content about AI")
        await layers["episodic"].store("e1", "Episodic memory content about AI")
        await layers["semantic"].store("s1", "Semantic memory content about AI")

        retriever = MemoryRetriever(
            working_memory=layers["working"],
            episodic_memory=layers["episodic"],
            semantic_memory=layers["semantic"],
        )
        hits = await retriever.retrieve("AI")

        assert len(hits) >= 3

    async def test_single_layer_query(self):
        """A retriever configured with only the working layer still returns results."""
        working_layer = InMemoryMemory()
        await working_layer.store("w1", "Only working memory result")

        hits = await MemoryRetriever(working_memory=working_layer).retrieve("working")

        assert len(hits) >= 1
class TestMemoryRetrieverWeightFusion:
    """Weight-based fusion ranking."""

    async def test_weight_based_fusion_sorting(self):
        """Layer weights drive fusion: the heavier layer's hit scores higher and ranks first."""
        working = InMemoryMemory()
        semantic = InMemoryMemory()
        await working.store("w1", "Working memory result")
        await semantic.store("s1", "Semantic memory result")
        # Semantic is weighted far above working.
        retriever = MemoryRetriever(
            working_memory=working,
            semantic_memory=semantic,
            weights={"working": 0.1, "semantic": 0.9},
        )
        results = await retriever.retrieve("result")
        assert len(results) >= 2
        keys = [r.key for r in results]
        # Require both hits explicitly; guarding the comparison with an `if`
        # lets the test pass vacuously when either one is missing.
        assert "s1" in keys
        assert "w1" in keys
        semantic_item = results[keys.index("s1")]
        working_item = results[keys.index("w1")]
        assert semantic_item.score > working_item.score
        # The higher-weighted layer's result should be ranked in front.
        assert keys.index("s1") < keys.index("w1")

    def test_default_weights(self):
        """Default weight configuration."""
        # Plain (non-async) test: nothing here is awaited.
        retriever = MemoryRetriever()
        assert retriever._weights == {"working": 0.2, "episodic": 0.4, "semantic": 0.4}

    def test_custom_weights(self):
        """Custom weights passed to the constructor are kept as given."""
        # Plain (non-async) test: nothing here is awaited.
        retriever = MemoryRetriever(
            weights={"working": 0.5, "episodic": 0.3, "semantic": 0.2}
        )
        assert retriever._weights["working"] == 0.5
        assert retriever._weights["episodic"] == 0.3
        assert retriever._weights["semantic"] == 0.2
class TestMemoryRetrieverTokenBudget:
    """Token-budget management."""

    async def test_token_budget_truncation(self):
        """Results are cut down when they would exceed the token budget."""
        working = InMemoryMemory()
        # Fill the layer with many long texts.
        for idx in range(20):
            await working.store(f"item_{idx}", f"Long content item number {idx} " * 50)

        results = await MemoryRetriever(working_memory=working).retrieve(
            "content", token_budget=200
        )

        # Rough estimate of ~4 characters per token; a little overshoot
        # above the 200-token budget is tolerated.
        approx_tokens = sum(len(str(r.value)) for r in results) // 4
        assert approx_tokens <= 250

    async def test_large_budget_returns_more(self):
        """A larger budget never yields fewer results than a smaller one."""
        working = InMemoryMemory()
        for idx in range(10):
            await working.store(f"item_{idx}", f"Content item {idx}")
        retriever = MemoryRetriever(working_memory=working)

        small = await retriever.retrieve("Content", token_budget=10)
        large = await retriever.retrieve("Content", token_budget=10000)

        assert len(small) <= len(large)

    async def test_zero_budget_returns_empty(self):
        """A zero budget yields no results at all."""
        working = InMemoryMemory()
        await working.store("w1", "Some content")

        results = await MemoryRetriever(working_memory=working).retrieve(
            "content", token_budget=0
        )

        assert len(results) == 0
class TestMemoryRetrieverMissingLayer:
    """Missing or failing memory layers."""

    async def test_missing_memory_layer_doesnt_break(self):
        """Unconfigured (None) layers do not make retrieval fail."""
        working = InMemoryMemory()
        await working.store("w1", "Working memory only")
        # Only the working layer is configured; episodic and semantic are None.
        retriever = MemoryRetriever(
            working_memory=working,
            episodic_memory=None,
            semantic_memory=None,
        )
        results = await retriever.retrieve("Working")
        assert len(results) >= 1

    async def test_no_memory_layers_returns_empty(self):
        """With no layers configured, retrieval returns an empty list."""
        retriever = MemoryRetriever()
        results = await retriever.retrieve("anything")
        assert results == []

    async def test_exception_in_layer_doesnt_break(self):
        """An exception raised by one layer does not affect the other layers."""
        working = InMemoryMemory()
        await working.store("w1", "Working memory result")
        # A mock layer whose search always raises. spec=Memory makes it pass
        # isinstance(..., Memory) and limits it to the Memory interface, in
        # case the retriever type-checks its layers and would otherwise skip
        # this one without ever hitting the failure path.
        failing_memory = AsyncMock(spec=Memory)
        failing_memory.search = AsyncMock(side_effect=Exception("Service unavailable"))
        retriever = MemoryRetriever(
            working_memory=working,
            episodic_memory=failing_memory,
        )
        results = await retriever.retrieve("Working")
        # The failing layer must actually have been queried ...
        failing_memory.search.assert_awaited()
        # ... and the working layer's hit must still be returned.
        assert "w1" in [r.key for r in results]
class TestMemoryRetrieverContextString:
    """Tests for ``get_context_string``."""

    async def test_get_context_string_returns_formatted_string(self):
        """get_context_string returns a string containing the matching memory."""
        working = InMemoryMemory()
        await working.store("ctx1", "Context about Python programming")
        await working.store("ctx2", "Context about AI research")
        retriever = MemoryRetriever(working_memory=working)
        context = await retriever.get_context_string("Python")
        assert isinstance(context, str)
        assert "Python" in context

    async def test_get_context_string_empty_result(self):
        """With no matching memories the result is the empty string."""
        working = InMemoryMemory()
        await working.store("ctx1", "Unrelated content")
        retriever = MemoryRetriever(working_memory=working)
        context = await retriever.get_context_string("nonexistent_topic")
        # InMemoryMemory.search matches the query (case-insensitively) against
        # both key and value; neither "ctx1" nor "Unrelated content" contains
        # "nonexistent_topic", so the search yields no hits.
        assert isinstance(context, str)
        assert context == ""

    async def test_get_context_string_multiple_items(self):
        """Multiple results are separated by a double newline."""
        working = InMemoryMemory()
        await working.store("ctx1", "First context item about testing")
        await working.store("ctx2", "Second context item about testing")
        retriever = MemoryRetriever(working_memory=working)
        context = await retriever.get_context_string("testing")
        # Both items match "testing": require them instead of silently
        # skipping the separator check when one is absent.
        assert "First" in context
        assert "Second" in context
        assert "\n\n" in context