"""Unit tests for MemoryRetriever, the hybrid (multi-layer) retriever.

Uses the in-memory ``InMemoryMemory`` implementation defined below, so no
real Redis/PG environment is required.
"""

from unittest.mock import AsyncMock

import pytest

from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.retriever import MemoryRetriever


# ── In-memory Memory implementation (test double) ─────────────────────


class InMemoryMemory(Memory):
    """Dict-backed ``Memory`` implementation used as a test double."""

    def __init__(self):
        self._store: dict[str, MemoryItem] = {}

    async def store(self, key: str, value, metadata=None) -> None:
        """Store ``value`` under ``key``; every stored item gets score 1.0."""
        self._store[key] = MemoryItem(
            key=key, value=value, metadata=metadata or {}, score=1.0
        )

    async def retrieve(self, key: str) -> MemoryItem | None:
        """Return the item stored under ``key``, or ``None`` if absent."""
        return self._store.get(key)

    async def search(
        self, query: str, top_k: int = 5, filters=None
    ) -> list[MemoryItem]:
        """Case-insensitive substring match against each item's value and key.

        ``filters`` is accepted for interface compatibility but ignored.
        Matches are returned in the dict's insertion order, capped at
        ``top_k``.
        """
        needle = query.lower()  # computed once rather than per item
        matches = [
            item
            for item in self._store.values()
            if needle in str(item.value).lower() or needle in item.key.lower()
        ]
        return matches[:top_k]

    async def delete(self, key: str) -> bool:
        """Remove ``key``; return ``True`` if it was present."""
        return self._store.pop(key, None) is not None


# ── MemoryRetriever tests ─────────────────────────────────────────────
# NOTE(review): the async tests below carry no asyncio marker; presumably
# pytest-asyncio runs in "auto" mode via project config — confirm.


class TestMemoryRetrieverParallelQuery:
    """Parallel query tests."""

    async def test_parallel_query_across_layers(self):
        """Query several memory layers in parallel."""
        working = InMemoryMemory()
        episodic = InMemoryMemory()
        semantic = InMemoryMemory()

        await working.store("w1", "Working memory content about AI")
        await episodic.store("e1", "Episodic memory content about AI")
        await semantic.store("s1", "Semantic memory content about AI")

        retriever = MemoryRetriever(
            working_memory=working,
            episodic_memory=episodic,
            semantic_memory=semantic,
        )

        results = await retriever.retrieve("AI")
        assert len(results) >= 3

    async def test_single_layer_query(self):
        """Works normally when only one memory layer is configured."""
        working = InMemoryMemory()
        await working.store("w1", "Only working memory result")

        retriever = MemoryRetriever(working_memory=working)
        results = await retriever.retrieve("working")
        assert len(results) >= 1


class TestMemoryRetrieverWeightFusion:
    """Weighted fusion ranking tests."""

    async def test_weight_based_fusion_sorting(self):
        """Layer weights drive fusion: the higher-weighted layer scores higher."""
        working = InMemoryMemory()
        semantic = InMemoryMemory()

        await working.store("w1", "Working memory result")
        await semantic.store("s1", "Semantic memory result")

        # Semantic is weighted far above Working.
        retriever = MemoryRetriever(
            working_memory=working,
            semantic_memory=semantic,
            weights={"working": 0.1, "semantic": 0.9},
        )

        results = await retriever.retrieve("result")
        assert len(results) >= 2

        semantic_items = [r for r in results if r.key == "s1"]
        working_items = [r for r in results if r.key == "w1"]
        # Both stored items match "result", so both must come back; asserting
        # this explicitly keeps the score comparison below from being skipped.
        assert semantic_items, "semantic result missing"
        assert working_items, "working result missing"
        # The higher-weighted semantic layer must produce the higher score.
        assert semantic_items[0].score > working_items[0].score

    async def test_default_weights(self):
        """Default weight configuration."""
        retriever = MemoryRetriever()
        assert retriever._weights == {"working": 0.2, "episodic": 0.4, "semantic": 0.4}

    async def test_custom_weights(self):
        """Custom weights are stored as given."""
        retriever = MemoryRetriever(
            weights={"working": 0.5, "episodic": 0.3, "semantic": 0.2}
        )
        assert retriever._weights["working"] == 0.5
        assert retriever._weights["episodic"] == 0.3
        assert retriever._weights["semantic"] == 0.2


class TestMemoryRetrieverTokenBudget:
    """Token budget management tests."""

    async def test_token_budget_truncation(self):
        """Results are truncated when they exceed the token budget."""
        working = InMemoryMemory()
        # Store many long texts.
        for i in range(20):
            await working.store(f"item_{i}", f"Long content item number {i} " * 50)

        retriever = MemoryRetriever(working_memory=working)
        results = await retriever.retrieve("content", token_budget=200)

        total_chars = sum(len(str(r.value)) for r in results)
        # Rough estimate (~4 chars per token): must not far exceed the budget.
        assert total_chars // 4 <= 250  # allow a small overflow

    async def test_large_budget_returns_more(self):
        """A larger budget returns at least as many results."""
        working = InMemoryMemory()
        for i in range(10):
            await working.store(f"item_{i}", f"Content item {i}")

        retriever = MemoryRetriever(working_memory=working)

        small_budget = await retriever.retrieve("Content", token_budget=10)
        large_budget = await retriever.retrieve("Content", token_budget=10000)

        assert len(large_budget) >= len(small_budget)

    async def test_zero_budget_returns_empty(self):
        """A zero budget returns no results."""
        working = InMemoryMemory()
        await working.store("w1", "Some content")

        retriever = MemoryRetriever(working_memory=working)
        results = await retriever.retrieve("content", token_budget=0)
        assert len(results) == 0


class TestMemoryRetrieverMissingLayer:
    """Missing memory layer tests."""

    async def test_missing_memory_layer_doesnt_break(self):
        """A missing memory layer does not make retrieval fail."""
        working = InMemoryMemory()
        await working.store("w1", "Working memory only")

        # Only `working` is configured; episodic and semantic are None.
        retriever = MemoryRetriever(
            working_memory=working,
            episodic_memory=None,
            semantic_memory=None,
        )

        results = await retriever.retrieve("Working")
        assert len(results) >= 1

    async def test_no_memory_layers_returns_empty(self):
        """With no memory layers at all, an empty list is returned."""
        retriever = MemoryRetriever()
        results = await retriever.retrieve("anything")
        assert results == []

    async def test_exception_in_layer_doesnt_break(self):
        """An exception in one layer does not affect the other layers."""
        working = InMemoryMemory()
        await working.store("w1", "Working memory result")

        # A mock memory whose search always raises.
        failing_memory = AsyncMock()
        failing_memory.search = AsyncMock(side_effect=Exception("Service unavailable"))

        retriever = MemoryRetriever(
            working_memory=working,
            episodic_memory=failing_memory,
        )

        results = await retriever.retrieve("Working")
        # Even though episodic fails, the working-layer result is still returned.
        assert len(results) >= 1


class TestMemoryRetrieverContextString:
    """get_context_string tests."""

    async def test_get_context_string_returns_formatted_string(self):
        """get_context_string returns a formatted string."""
        working = InMemoryMemory()
        await working.store("ctx1", "Context about Python programming")
        await working.store("ctx2", "Context about AI research")

        retriever = MemoryRetriever(working_memory=working)
        context = await retriever.get_context_string("Python")

        assert isinstance(context, str)
        assert "Python" in context

    async def test_get_context_string_empty_result(self):
        """Returns an empty string when nothing matches."""
        working = InMemoryMemory()
        await working.store("ctx1", "Unrelated content")

        retriever = MemoryRetriever(working_memory=working)
        context = await retriever.get_context_string("nonexistent_topic")

        # InMemoryMemory.search matches on value and key; "nonexistent_topic"
        # appears in neither "Unrelated content" nor "ctx1", so nothing is
        # retrieved and the context must be empty.
        assert isinstance(context, str)
        assert context == ""

    async def test_get_context_string_multiple_items(self):
        """Multiple results are separated by a blank line."""
        working = InMemoryMemory()
        await working.store("ctx1", "First context item about testing")
        await working.store("ctx2", "Second context item about testing")

        retriever = MemoryRetriever(working_memory=working)
        context = await retriever.get_context_string("testing")

        # Both items contain "testing", so both must be present and joined
        # by a double newline.
        assert "First" in context
        assert "Second" in context
        assert "\n\n" in context