"""U4: Memory wired into the agent loop - integration tests.

Covers the full flow of injecting a MemoryRetriever into ReActEngine:
1. Relevant context is retrieved before execution and injected into system_prompt
2. A trajectory summary is written to EpisodicMemory after execution
3. A memory retrieval failure does not interrupt task execution
4. ConfigDrivenAgent creates a MemoryRetriever automatically from config.memory
5. The BaseAgent.use_memory_retriever() method
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from agentkit.core.react import ReActEngine, ReActResult
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, TokenUsage


# ── Test Helpers ──────────────────────────────────────────


def make_mock_gateway(responses: list[LLMResponse]) -> LLMGateway:
    """Create a mock LLMGateway whose ``chat`` returns ``responses`` in order."""
    gateway = MagicMock(spec=LLMGateway)
    gateway.chat = AsyncMock(side_effect=responses)
    return gateway


def make_response(
    content: str = "",
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
) -> LLMResponse:
    """Build an LLMResponse with no tool calls and the given content/usage."""
    return LLMResponse(
        content=content,
        model="test-model",
        usage=TokenUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        ),
        tool_calls=[],
    )


def make_mock_memory_retriever(context_string: str = "past experience data") -> MagicMock:
    """Create a mock MemoryRetriever.

    ``get_context_string`` resolves to ``context_string``, ``store_episode`` is
    an awaitable no-op, and ``_episodic`` starts out as ``None``.
    """
    retriever = MagicMock()
    retriever.get_context_string = AsyncMock(return_value=context_string)
    retriever._episodic = None
    retriever.store_episode = AsyncMock()
    return retriever


def make_mock_episodic_memory() -> MagicMock:
    """Create a mock EpisodicMemory with an awaitable ``store``."""
    episodic = MagicMock()
    episodic.store = AsyncMock()
    return episodic


def get_sent_system_message(gateway) -> dict:
    """Return the first message passed to the most recent ``gateway.chat`` call.

    ``messages`` is read from the call's keyword arguments; a call that did not
    pass ``messages`` by keyword raises ``KeyError`` and fails the test.
    """
    return gateway.chat.call_args.kwargs["messages"][0]


# ── Test: Memory context injected into system_prompt ──────────


class TestMemoryContextInjection:
    """Tests for injecting memory context into system_prompt."""

    async def test_memory_context_appended_to_existing_system_prompt(self):
        """With a system_prompt present, memory context is appended to it."""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        retriever = make_mock_memory_retriever("Previous task result: success")

        result = await engine.execute(
            messages=[{"role": "user", "content": "Do something"}],
            system_prompt="You are a helpful assistant.",
            memory_retriever=retriever,
        )

        assert isinstance(result, ReActResult)
        retriever.get_context_string.assert_awaited_once_with(
            query="Do something",
            top_k=5,
            token_budget=2000,
        )

        # Verify system_prompt was augmented with memory context:
        # the first message should be system with appended context.
        system_msg = get_sent_system_message(gateway)
        assert system_msg["role"] == "system"
        assert "You are a helpful assistant." in system_msg["content"]
        assert "参考信息" in system_msg["content"]
        assert "Previous task result: success" in system_msg["content"]

    async def test_memory_context_used_as_system_prompt_when_none(self):
        """Without a system_prompt, memory context becomes the system_prompt."""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        retriever = make_mock_memory_retriever("Past context only")

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
        )

        assert isinstance(result, ReActResult)
        system_msg = get_sent_system_message(gateway)
        assert system_msg["role"] == "system"
        assert "参考信息" in system_msg["content"]
        assert "Past context only" in system_msg["content"]

    async def test_no_memory_context_when_retriever_is_none(self):
        """With memory_retriever=None, no memory context is injected."""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            system_prompt="You are a helper.",
            memory_retriever=None,
        )

        assert isinstance(result, ReActResult)
        system_msg = get_sent_system_message(gateway)
        assert system_msg["content"] == "You are a helper."
        assert "参考信息" not in system_msg["content"]

    async def test_empty_memory_context_not_injected(self):
        """An empty memory context string is not injected."""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        retriever = make_mock_memory_retriever(context_string="")

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            system_prompt="You are a helper.",
            memory_retriever=retriever,
        )

        assert isinstance(result, ReActResult)
        system_msg = get_sent_system_message(gateway)
        assert system_msg["content"] == "You are a helper."
        assert "参考信息" not in system_msg["content"]


# ── Test: Memory retrieval failure doesn't break execution ──────────


class TestMemoryRetrievalFailure:
    """A memory retrieval failure does not interrupt task execution."""

    async def test_retrieval_failure_continues_without_context(self):
        """The task completes normally when memory retrieval raises."""
        gateway = make_mock_gateway([make_response(content="still works")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        retriever = make_mock_memory_retriever()
        retriever.get_context_string = AsyncMock(side_effect=RuntimeError("Redis down"))

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            system_prompt="You are a helper.",
            memory_retriever=retriever,
        )

        # Task should still complete
        assert isinstance(result, ReActResult)
        assert result.output == "still works"

        # system_prompt should NOT have memory context
        system_msg = get_sent_system_message(gateway)
        assert "参考信息" not in system_msg["content"]


# ── Test: Task result stored in episodic memory ──────────


class TestEpisodicMemoryStorage:
    """A trajectory summary is written to EpisodicMemory after execution."""

    async def test_result_stored_in_episodic_memory(self):
        """After the task completes, a result summary is stored."""
        gateway = make_mock_gateway([make_response(content="The answer is 42")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        episodic = make_mock_episodic_memory()
        retriever = make_mock_memory_retriever(context_string="")
        retriever._episodic = episodic

        result = await engine.execute(
            messages=[{"role": "user", "content": "What is the answer?"}],
            memory_retriever=retriever,
            task_id="task-123",
            agent_name="test-agent",
            task_type="qa",
        )

        assert isinstance(result, ReActResult)
        retriever.store_episode.assert_awaited_once()

        store_kwargs = retriever.store_episode.call_args.kwargs
        assert store_kwargs.get("key") == "task:task-123"

        # Verify metadata
        metadata = store_kwargs["metadata"]
        assert metadata["task_type"] == "qa"
        assert metadata["outcome"] == "success"

    async def test_no_storage_when_no_episodic_memory(self):
        """Execution succeeds when the retriever has no EpisodicMemory."""
        gateway = make_mock_gateway([make_response(content="done")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        retriever = make_mock_memory_retriever(context_string="")
        retriever._episodic = None

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
        )

        # Only the absence of an exception is checked here; whether a store
        # call happens is not asserted.
        assert isinstance(result, ReActResult)

    async def test_storage_failure_doesnt_break_execution(self):
        """A failing episodic store does not interrupt the task."""
        gateway = make_mock_gateway([make_response(content="done")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        retriever = make_mock_memory_retriever(context_string="")
        retriever.store_episode = AsyncMock(side_effect=RuntimeError("DB down"))

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
        )

        # Task should still complete
        assert isinstance(result, ReActResult)
        assert result.output == "done"


# ── Test: execute_stream with memory ──────────


class TestMemoryInStreamMode:
    """Memory integration in execute_stream mode."""

    async def test_stream_injects_memory_context(self):
        """execute_stream also retrieves memory context."""
        gateway = make_mock_gateway([make_response(content="streamed answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        retriever = make_mock_memory_retriever("Stream context")

        events = [
            event
            async for event in engine.execute_stream(
                messages=[{"role": "user", "content": "Hello"}],
                system_prompt="You are a helper.",
                memory_retriever=retriever,
            )
        ]

        # Should have events
        assert len(events) > 0
        retriever.get_context_string.assert_awaited_once()

    async def test_stream_stores_to_episodic(self):
        """execute_stream also stores an episode once the stream finishes."""
        gateway = make_mock_gateway([make_response(content="streamed answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        episodic = make_mock_episodic_memory()
        retriever = make_mock_memory_retriever(context_string="")
        retriever._episodic = episodic

        # Drain the stream; the events themselves are not inspected.
        async for _ in engine.execute_stream(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
            task_id="stream-task-1",
        ):
            pass

        retriever.store_episode.assert_awaited_once()


# ── Test: BaseAgent.use_memory_retriever() ──────────


class TestBaseAgentMemoryRetriever:
    """Tests for the BaseAgent.use_memory_retriever() method."""

    @staticmethod
    def _make_agent():
        """Build a minimal concrete BaseAgent subclass instance for testing."""
        from agentkit.core.base import BaseAgent

        class _StubAgent(BaseAgent):
            async def handle_task(self, task):
                return {}

            def get_capabilities(self):
                from agentkit.core.protocol import AgentCapability

                return AgentCapability(
                    agent_name=self.name,
                    agent_type=self.agent_type,
                    version=self.version,
                )

        return _StubAgent(name="test", agent_type="test")

    def test_use_memory_retriever_sets_field(self):
        """use_memory_retriever() sets _memory_retriever and returns self."""
        agent = self._make_agent()
        mock_retriever = MagicMock()

        result = agent.use_memory_retriever(mock_retriever)

        # Should return self for chaining
        assert result is agent
        assert agent._memory_retriever is mock_retriever

    def test_memory_retriever_default_is_none(self):
        """_memory_retriever defaults to None."""
        agent = self._make_agent()
        assert agent._memory_retriever is None


# ── Test: ConfigDrivenAgent memory integration ──────────


class TestConfigDrivenAgentMemory:
    """ConfigDrivenAgent creates a MemoryRetriever from config.memory."""

    def test_memory_retriever_created_from_config(self):
        """A MemoryRetriever is created when config.memory is set."""
        from agentkit.core.config_driven import ConfigDrivenAgent, AgentConfig

        config = AgentConfig(
            name="test-agent",
            agent_type="test",
            task_mode="llm_generate",
            prompt={"identity": "Test agent"},
            memory={
                "working": {"enabled": False},
                "episodic": {"enabled": False},
            },
        )

        # create=True lets the patch succeed even if config_driven has no
        # module-level ``MemoryRetriever`` attribute.
        with patch("agentkit.core.config_driven.MemoryRetriever", create=True):
            agent = ConfigDrivenAgent(config=config)

        # MemoryRetriever should have been created (with no backends since both disabled)
        assert agent._memory_retriever is not None

    def test_no_memory_retriever_when_no_config(self):
        """No MemoryRetriever is created when config.memory is absent."""
        from agentkit.core.config_driven import ConfigDrivenAgent, AgentConfig

        config = AgentConfig(
            name="test-agent",
            agent_type="test",
            task_mode="llm_generate",
            prompt={"identity": "Test agent"},
        )

        agent = ConfigDrivenAgent(config=config)
        assert agent._memory_retriever is None

    def test_memory_retriever_created_with_empty_memory_dict(self):
        """An empty config.memory dict results in no MemoryRetriever."""
        from agentkit.core.config_driven import ConfigDrivenAgent, AgentConfig

        config = AgentConfig(
            name="test-agent",
            agent_type="test",
            task_mode="llm_generate",
            prompt={"identity": "Test agent"},
            memory={},
        )

        agent = ConfigDrivenAgent(config=config)
        # Empty dict is falsy, so no retriever
        assert agent._memory_retriever is None

    def test_memory_retriever_failure_graceful(self):
        """Constructing the agent does not raise when memory init may fail."""
        from agentkit.core.config_driven import ConfigDrivenAgent, AgentConfig

        config = AgentConfig(
            name="test-agent",
            agent_type="test",
            task_mode="llm_generate",
            prompt={"identity": "Test agent"},
            memory={"working": {"enabled": True, "redis_url": "redis://nonexistent:6379"}},
        )

        # Either the retriever was created or initialization failed
        # gracefully; the key is that no exception is raised.
        ConfigDrivenAgent(config=config)

    def test_episodic_memory_created_from_config(self):
        """EpisodicMemory is created when config.memory.episodic.enabled=True."""
        from agentkit.core.config_driven import ConfigDrivenAgent, AgentConfig

        config = AgentConfig(
            name="test-agent",
            agent_type="test",
            task_mode="llm_generate",
            prompt={"identity": "Test agent"},
            memory={
                "episodic": {
                    "enabled": True,
                    "pgvector_enabled": False,
                    "table_name": "test_memories",
                    "decay_rate": 0.02,
                    "alpha": 0.8,
                },
            },
        )

        agent = ConfigDrivenAgent(config=config)

        # MemoryRetriever should be created with episodic memory
        assert agent._memory_retriever is not None

        # Episodic memory should be configured
        episodic = agent._memory_retriever._episodic
        assert episodic is not None
        assert episodic._pgvector_enabled is False
        assert episodic._table_name == "test_memories"
        assert episodic._decay_rate == 0.02
        assert episodic._alpha == 0.8


# ── Test: Structured Context Injection ──────────


class TestStructuredContextInjection:
    """U3: structured context injection tests."""

    async def test_structured_format_with_rag_results(self):
        """Structured format: RAG results carry a knowledge-base heading."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        retriever = MemoryRetriever(context_template="structured")

        # Mock retrieve to return RAG items
        rag_item = MemoryItem(
            key="doc-1",
            value="AI行业在2025年呈现三大趋势...",
            metadata={"source": "rag", "kb_type": "行业库", "document_title": "AI行业趋势报告"},
            score=0.92,
        )
        retriever.retrieve = AsyncMock(return_value=[rag_item])

        result = await retriever.get_context_string(query="AI trends", top_k=5, token_budget=3000)

        assert "### 知识库参考 [来源: 行业库 | 相关度: 0.92 | 文档: AI行业趋势报告]" in result
        assert "AI行业在2025年呈现三大趋势..." in result

    async def test_structured_format_with_episodic_results(self):
        """Structured format: episodic results carry a past-experience heading."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        retriever = MemoryRetriever(context_template="structured")

        episodic_item = MemoryItem(
            key="task:seo-001",
            value="上次分析竞品SEO策略时发现...",
            metadata={"source": "episodic", "task_type": "seo_analysis"},
            score=0.85,
        )
        retriever.retrieve = AsyncMock(return_value=[episodic_item])

        result = await retriever.get_context_string(query="SEO analysis", top_k=5, token_budget=3000)

        assert "### 过往经验 [来源: 情景记忆 | 任务类型: seo_analysis]" in result
        assert "上次分析竞品SEO策略时发现..." in result

    async def test_structured_format_with_mixed_sources(self):
        """Structured format: each source type produces its own heading."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        retriever = MemoryRetriever(context_template="structured")

        items = [
            MemoryItem(
                key="doc-1",
                value="RAG content here",
                metadata={"source": "rag", "kb_type": "行业库", "document_title": "报告A"},
                score=0.90,
            ),
            MemoryItem(
                key="task:ep-1",
                value="Episodic content here",
                metadata={"source": "episodic", "task_type": "analysis"},
                score=0.80,
            ),
            MemoryItem(
                key="entity-1",
                value="Graph content here",
                metadata={"source": "graph"},
                score=0.75,
            ),
            MemoryItem(
                key="ctx-1",
                value="Working memory content",
                metadata={"source": "working"},
                score=0.60,
            ),
            MemoryItem(
                key="other-1",
                value="Unknown source content",
                metadata={"source": "custom"},
                score=0.50,
            ),
        ]
        retriever.retrieve = AsyncMock(return_value=items)

        result = await retriever.get_context_string(query="test", top_k=5, token_budget=3000)

        assert "### 知识库参考" in result
        assert "### 过往经验" in result
        assert "### 知识图谱" in result
        assert "### 工作记忆" in result
        assert "### 参考 [来源: custom" in result

    async def test_flat_format_backward_compatible(self):
        """Flat format: plain text concatenation with no heading lines."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        retriever = MemoryRetriever(context_template="flat")

        items = [
            MemoryItem(
                key="doc-1",
                value="First result",
                metadata={"source": "rag"},
                score=0.9,
            ),
            MemoryItem(
                key="ep-1",
                value="Second result",
                metadata={"source": "episodic"},
                score=0.8,
            ),
        ]
        retriever.retrieve = AsyncMock(return_value=items)

        result = await retriever.get_context_string(query="test", top_k=5, token_budget=3000)

        # No structured headers
        assert "### 知识库参考" not in result
        assert "### 过往经验" not in result

        # Just plain text values joined by double newline
        assert "First result" in result
        assert "Second result" in result
        assert result == "First result\n\nSecond result"

    async def test_token_budget_truncation_in_structured_format(self):
        """Structured format: over-long results are truncated to the token budget."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        retriever = MemoryRetriever(context_template="structured")

        # Create a very long content item
        long_value = "A" * 20000
        item = MemoryItem(
            key="doc-1",
            value=long_value,
            metadata={"source": "rag", "kb_type": "知识库", "document_title": "大文档"},
            score=0.9,
        )
        retriever.retrieve = AsyncMock(return_value=[item])

        # Very small token budget
        result = await retriever.get_context_string(query="test", top_k=5, token_budget=100)

        # Result should be truncated (100 tokens * 4 chars = 400 chars max)
        assert len(result) <= 400

    async def test_empty_results_returns_empty_string(self):
        """Empty results: an empty string is returned."""
        from agentkit.memory.retriever import MemoryRetriever

        retriever = MemoryRetriever(context_template="structured")
        retriever.retrieve = AsyncMock(return_value=[])

        result = await retriever.get_context_string(query="test", top_k=5, token_budget=3000)

        assert result == ""

    async def test_context_template_parameter(self):
        """context_template: flat yields plain text, structured yields headings."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        item = MemoryItem(
            key="doc-1",
            value="Flat content",
            metadata={"source": "rag"},
            score=0.9,
        )

        # Test with flat template
        retriever_flat = MemoryRetriever(context_template="flat")
        retriever_flat.retrieve = AsyncMock(return_value=[item])
        result_flat = await retriever_flat.get_context_string(query="test")
        assert "### 知识库参考" not in result_flat
        assert "Flat content" in result_flat

        # Test with structured template (default)
        retriever_structured = MemoryRetriever(context_template="structured")
        retriever_structured.retrieve = AsyncMock(return_value=[item])
        result_structured = await retriever_structured.get_context_string(query="test")
        assert "### 知识库参考" in result_structured

    async def test_structured_format_default_kb_type(self):
        """Structured format: a RAG result without kb_type uses the default."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        retriever = MemoryRetriever(context_template="structured")

        item = MemoryItem(
            key="doc-1",
            value="Content without kb_type",
            metadata={"source": "rag", "document_title": "报告B"},
            score=0.88,
        )
        retriever.retrieve = AsyncMock(return_value=[item])

        result = await retriever.get_context_string(query="test")

        assert "### 知识库参考 [来源: 知识库 | 相关度: 0.88 | 文档: 报告B]" in result

    async def test_structured_format_default_task_type(self):
        """Structured format: an episodic result without task_type uses the default."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        retriever = MemoryRetriever(context_template="structured")

        item = MemoryItem(
            key="ep-1",
            value="Content without task_type",
            metadata={"source": "episodic"},
            score=0.75,
        )
        retriever.retrieve = AsyncMock(return_value=[item])

        result = await retriever.get_context_string(query="test")

        assert "### 过往经验 [来源: 情景记忆 | 任务类型: 未知]" in result


# ── Test: ReAct Context Injection Format ──────────


class TestReActContextInjectionFormat:
    """U3: ReActEngine uses the new heading format."""

    async def test_react_uses_new_heading(self):
        """ReActEngine uses the '## 参考信息' heading, not the legacy one."""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        retriever = make_mock_memory_retriever("Some context data")

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            system_prompt="You are a helper.",
            memory_retriever=retriever,
        )

        assert isinstance(result, ReActResult)
        system_msg = get_sent_system_message(gateway)
        assert "## 参考信息" in system_msg["content"]
        assert "Relevant Past Experience" not in system_msg["content"]

    async def test_react_new_heading_when_no_system_prompt(self):
        """Without a system_prompt, the new heading starts the system message."""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        retriever = make_mock_memory_retriever("Context only")

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
        )

        assert isinstance(result, ReActResult)
        system_msg = get_sent_system_message(gateway)
        assert system_msg["content"].startswith("## 参考信息")
        assert "Relevant Past Experience" not in system_msg["content"]