"""U4 tests: memory system - three memory layers, hybrid retrieval, and
BaseAgent lifecycle integration."""

import math
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock

import pytest

from agentkit.core.base import BaseAgent
from agentkit.core.protocol import AgentCapability, TaskMessage, TaskResult, TaskStatus
from agentkit.memory.base import Memory, MemoryItem
from agentkit.memory.episodic import EpisodicMemory
from agentkit.memory.retriever import MemoryRetriever
from agentkit.memory.semantic import SemanticMemory
from agentkit.memory.working import WorkingMemory

# NOTE(review): the async tests below carry no asyncio marker; presumably the
# project configures pytest-asyncio with ``asyncio_mode = auto`` - confirm.


# ── Shared test helpers ─────────────────────────────────


def _make_capability(agent_name: str) -> AgentCapability:
    """Build the fixed "test" capability descriptor shared by every test agent.

    Args:
        agent_name: Name reported in the descriptor's ``agent_name`` field.

    Returns:
        An ``AgentCapability`` with type "test", version "1.0", a single
        supported task "test" and a max concurrency of 1.
    """
    return AgentCapability(
        agent_name=agent_name,
        agent_type="test",
        version="1.0",
        supported_tasks=["test"],
        max_concurrency=1,
        description="test",
    )


def _make_task(task_id: str, agent_name: str) -> TaskMessage:
    """Build a minimal "test" task addressed to ``agent_name``.

    Args:
        task_id: Identifier stored on the task; tests use it to build memory keys.
        agent_name: Name of the agent the task is addressed to.

    Returns:
        A ``TaskMessage`` of type "test" with priority 1, empty input data,
        no callback URL and a timezone-aware (UTC) creation timestamp.
    """
    return TaskMessage(
        task_id=task_id,
        agent_name=agent_name,
        task_type="test",
        priority=1,
        input_data={},
        callback_url=None,
        created_at=datetime.now(timezone.utc),
    )


# ── In-memory Memory implementation (for tests) ─────────


class InMemoryMemory(Memory):
    """Dict-backed ``Memory`` implementation used as a test double."""

    def __init__(self):
        # Maps key -> stored item; a later store() on the same key overwrites.
        self._store: dict[str, MemoryItem] = {}

    async def store(self, key: str, value, metadata=None) -> None:
        """Save ``value`` under ``key`` with a fixed score of 1.0.

        A missing (or falsy) ``metadata`` is stored as an empty dict.
        """
        self._store[key] = MemoryItem(
            key=key, value=value, metadata=metadata or {}, score=1.0
        )

    async def retrieve(self, key: str) -> MemoryItem | None:
        """Return the item stored under ``key``, or ``None`` if absent."""
        return self._store.get(key)

    async def search(self, query: str, top_k: int = 5, filters=None) -> list[MemoryItem]:
        """Return up to ``top_k`` items whose value or key contains ``query``.

        Matching is a case-insensitive substring test against ``str(value)``
        and against the key. Results keep insertion order. ``filters`` is
        accepted for interface compatibility but is not used.
        """
        # Lower-case the query once instead of twice per stored item.
        needle = query.lower()
        matches = [
            item
            for item in self._store.values()
            if needle in str(item.value).lower() or needle in item.key.lower()
        ]
        return matches[:top_k]

    async def delete(self, key: str) -> bool:
        """Remove ``key``; return ``True`` if it existed, ``False`` otherwise."""
        return self._store.pop(key, None) is not None


# ── Memory base-class tests ─────────────────────────────


class TestMemoryBase:
    """Exercises the ``Memory`` interface through ``InMemoryMemory``."""

    async def test_in_memory_store_and_retrieve(self):
        """A stored value and its metadata can be read back by key."""
        mem = InMemoryMemory()
        await mem.store("key1", {"data": "hello"}, {"tag": "test"})

        item = await mem.retrieve("key1")
        assert item is not None
        assert item.value["data"] == "hello"
        assert item.metadata["tag"] == "test"

    async def test_in_memory_search(self):
        """Search returns only the items containing the query."""
        mem = InMemoryMemory()
        await mem.store("agent:task1", "Generated content about AI")
        await mem.store("agent:task2", "Analysis of trends")
        await mem.store("agent:task3", "AI research summary")

        results = await mem.search("AI")
        assert len(results) == 2

    async def test_in_memory_delete(self):
        """Delete reports whether the key existed and removes the item."""
        mem = InMemoryMemory()
        await mem.store("key1", "value1")

        assert await mem.delete("key1") is True
        assert await mem.retrieve("key1") is None
        assert await mem.delete("nonexistent") is False

    async def test_batch_store(self):
        """``store_batch`` (inherited from ``Memory``) stores every tuple."""
        mem = InMemoryMemory()
        items = [("k1", "v1", None), ("k2", "v2", {"tag": "t"})]
        await mem.store_batch(items)

        assert await mem.retrieve("k1") is not None
        assert await mem.retrieve("k2") is not None

    async def test_get_context(self):
        """``get_context`` (inherited from ``Memory``) includes matching content."""
        mem = InMemoryMemory()
        await mem.store("k1", "First context item about Python")
        await mem.store("k2", "Second context item about AI")

        context = await mem.get_context("Python")
        assert "Python" in context


# ── SemanticMemory tests ────────────────────────────────


class TestSemanticMemory:
    """Checks ``SemanticMemory`` against stubbed RAG and graph services."""

    async def test_rag_search(self):
        """Knowledge is retrieved through the RAG service."""

        class MockRAGService:
            async def search(self, query, knowledge_base_ids=None, top_k=5):
                return [
                    {
                        "id": "doc1",
                        "content": f"Knowledge about {query}",
                        "score": 0.9,
                        "source": "kb1",
                    },
                ]

        mem = SemanticMemory(rag_service=MockRAGService(), knowledge_base_ids=["kb1"])
        results = await mem.search("AI trends")

        assert len(results) == 1
        assert results[0].value == "Knowledge about AI trends"
        assert results[0].score == 0.9

    async def test_graph_search(self):
        """Knowledge is retrieved through the knowledge graph."""

        class MockGraphService:
            async def query(self, query, depth=2):
                return [
                    {
                        "id": "node1",
                        "content": f"Entity: {query}",
                        "score": 0.7,
                        "entities": ["AI"],
                    },
                ]

        mem = SemanticMemory(graph_service=MockGraphService())
        results = await mem.search("machine learning")

        assert len(results) == 1
        assert results[0].metadata["source"] == "graph"

    async def test_combined_rag_and_graph(self):
        """RAG and graph results are merged into one ranked list."""

        class MockRAG:
            async def search(self, query, **kwargs):
                return [{"id": "r1", "content": "RAG result", "score": 0.8, "source": "rag"}]

        class MockGraph:
            async def query(self, query, **kwargs):
                return [{"id": "g1", "content": "Graph result", "score": 0.6, "source": "graph"}]

        mem = SemanticMemory(rag_service=MockRAG(), graph_service=MockGraph())
        results = await mem.search("test")

        assert len(results) == 2
        # Results must be ordered by descending score.
        assert results[0].score >= results[1].score

    async def test_semantic_read_only(self):
        """Semantic memory is normally read-only, so delete reports False."""
        mem = SemanticMemory()
        assert await mem.delete("any") is False


# ── EpisodicMemory tests (using a mock ORM) ─────────────


class TestEpisodicMemory:
    """Checks the time-decay weighting used for episodic memories."""

    async def test_time_decay(self):
        """Time decay: recent experience outweighs older experience."""
        # Exercises the decay formula directly.
        # NOTE(review): this never calls EpisodicMemory itself, so it only
        # validates the arithmetic - confirm the formula matches the class.
        decay_rate = 0.01
        recent_score = 0.8 * math.exp(-decay_rate * 1)  # 1 hour ago
        old_score = 0.8 * math.exp(-decay_rate * 100)  # 100 hours ago

        assert recent_score > old_score
        assert recent_score > 0.7
        assert old_score < 0.5


# ── MemoryRetriever tests ───────────────────────────────


class TestMemoryRetriever:
    """Checks the hybrid retriever on top of ``InMemoryMemory`` layers."""

    async def test_retriever_with_in_memory(self):
        """The hybrid retriever works with an ``InMemoryMemory`` layer."""
        working = InMemoryMemory()
        await working.store("current_task", "Working on AI content generation")

        retriever = MemoryRetriever(working_memory=working)
        results = await retriever.retrieve("AI content")
        assert len(results) >= 1

    async def test_retriever_weights(self):
        """Per-layer weights influence the ordering of results."""
        working = InMemoryMemory()
        semantic = InMemoryMemory()
        await working.store("task1", "Working memory result")
        await semantic.store("doc1", "Semantic memory result")

        retriever = MemoryRetriever(
            working_memory=working,
            semantic_memory=semantic,
            weights={"working": 0.2, "semantic": 0.8},
        )
        results = await retriever.retrieve("result")

        # Semantic has the higher weight, so it should rank first.
        # NOTE(review): the check only verifies descending score order and is
        # skipped entirely when fewer than two results come back.
        if len(results) >= 2:
            assert results[0].score >= results[1].score

    async def test_retriever_token_budget(self):
        """Retrieval respects the token budget."""
        working = InMemoryMemory()
        for i in range(20):
            await working.store(f"item_{i}", f"Long content item number {i} " * 50)

        retriever = MemoryRetriever(working_memory=working)
        results = await retriever.retrieve("content", token_budget=200)

        total_chars = sum(len(str(r.value)) for r in results)
        # Rough estimate (4 characters per token) must not exceed the budget
        # by much; a small overshoot is allowed.
        assert total_chars // 4 <= 250

    async def test_get_context_string(self):
        """The formatted context string contains the matching content."""
        working = InMemoryMemory()
        await working.store("ctx1", "Context about Python programming")

        retriever = MemoryRetriever(working_memory=working)
        context = await retriever.get_context_string("Python")
        assert "Python" in context

    async def test_empty_retriever(self):
        """A retriever with no memory layers returns an empty list."""
        retriever = MemoryRetriever()
        results = await retriever.retrieve("anything")
        assert results == []


# ── BaseAgent + Memory lifecycle integration tests ──────


class TestAgentMemoryIntegration:
    """Checks that agent lifecycle hooks can read and write memory."""

    async def test_memory_injected_into_agent(self):
        """A memory can be injected into an agent."""

        class TestAgent(BaseAgent):
            async def handle_task(self, task):
                return {"result": "done"}

            def get_capabilities(self):
                return _make_capability(self.name)

        agent = TestAgent(name="test", agent_type="test")
        mem = InMemoryMemory()
        agent.use_memory(mem)

        assert agent.memory is mem

    async def test_on_task_complete_stores_memory(self):
        """The ``on_task_complete`` hook can store the task output."""

        class MemoryAwareAgent(BaseAgent):
            async def handle_task(self, task):
                return {"answer": "42"}

            def get_capabilities(self):
                return _make_capability(self.name)

            async def on_task_complete(self, task, output):
                if self.memory:
                    await self.memory.store(
                        f"task:{task.task_id}",
                        output,
                        {"agent_name": self.name, "task_type": task.task_type},
                    )

        mem = InMemoryMemory()
        agent = MemoryAwareAgent(name="mem_agent", agent_type="test")
        agent.use_memory(mem)

        result = await agent.execute(_make_task("t-001", "mem_agent"))
        assert result.status == TaskStatus.COMPLETED

        # Verify that the memory was stored.
        stored = await mem.retrieve("task:t-001")
        assert stored is not None
        assert stored.value["answer"] == "42"

    async def test_on_task_start_loads_memory(self):
        """The ``on_task_start`` hook can load memory into the task context."""

        class ContextAwareAgent(BaseAgent):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.loaded_context = None

            async def handle_task(self, task):
                return {"context_used": self.loaded_context is not None}

            def get_capabilities(self):
                return _make_capability(self.name)

            async def on_task_start(self, task):
                if self.memory:
                    context = await self.memory.get_context(task.task_type)
                    self.loaded_context = context

        mem = InMemoryMemory()
        # The key "test" matches the task type used as the context query.
        await mem.store("test", "Previous experience with similar tasks")

        agent = ContextAwareAgent(name="ctx_agent", agent_type="test")
        agent.use_memory(mem)

        result = await agent.execute(_make_task("t-002", "ctx_agent"))
        assert result.output_data["context_used"] is True

    async def test_on_task_failed_records_failure(self):
        """The ``on_task_failed`` hook can record the failure."""

        class ResilientAgent(BaseAgent):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.failure_recorded = False

            async def handle_task(self, task):
                raise ValueError("simulated failure")

            def get_capabilities(self):
                return _make_capability(self.name)

            async def on_task_failed(self, task, error):
                if self.memory:
                    await self.memory.store(
                        f"failure:{task.task_id}",
                        {"error": str(error), "task_type": task.task_type},
                        {"outcome": "failure"},
                    )
                    self.failure_recorded = True

        mem = InMemoryMemory()
        agent = ResilientAgent(name="resilient", agent_type="test")
        agent.use_memory(mem)

        result = await agent.execute(_make_task("t-003", "resilient"))
        assert result.status == TaskStatus.FAILED
        assert agent.failure_recorded is True

        stored = await mem.retrieve("failure:t-003")
        assert stored is not None
        assert "simulated failure" in stored.value["error"]