fischer-agentkit/tests/unit/test_memory_integration.py

703 lines
26 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U4: 记忆接入 Agent 循环 - 集成测试
测试 MemoryRetriever 注入 ReActEngine 的完整流程:
1. 执行前检索相关上下文注入 system_prompt
2. 执行后写入轨迹摘要到 EpisodicMemory
3. Memory 检索失败不中断任务执行
4. ConfigDrivenAgent 从 config.memory 自动创建 MemoryRetriever
5. BaseAgent.use_memory_retriever() 方法
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.core.react import ReActEngine, ReActResult
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMResponse, TokenUsage
# ── Test Helpers ──────────────────────────────────────────
def make_mock_gateway(responses: list[LLMResponse]) -> LLMGateway:
    """Build a mock ``LLMGateway`` whose ``chat`` replays ``responses`` in order.

    Args:
        responses: Values handed to ``AsyncMock(side_effect=...)``; each awaited
            ``chat`` call yields the next one.

    Returns:
        A ``MagicMock`` specced against ``LLMGateway`` with an async ``chat``.
    """
    mocked_gateway = MagicMock(spec=LLMGateway)
    # ``side_effect`` with an iterable makes successive calls return successive items.
    mocked_gateway.chat = AsyncMock(side_effect=responses)
    return mocked_gateway
def make_response(
    content: str = "",
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
) -> LLMResponse:
    """Build an ``LLMResponse`` for tests with a fixed model name and no tool calls.

    Args:
        content: Text placed in the response's ``content`` field.
        prompt_tokens: Value forwarded to ``TokenUsage.prompt_tokens``.
        completion_tokens: Value forwarded to ``TokenUsage.completion_tokens``.

    Returns:
        An ``LLMResponse`` with ``model="test-model"`` and ``tool_calls=[]``.
    """
    token_usage = TokenUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    return LLMResponse(
        content=content,
        model="test-model",
        usage=token_usage,
        tool_calls=[],
    )
def make_mock_memory_retriever(context_string: str = "past experience data"):
    """Build a mock memory retriever for the engine tests.

    Args:
        context_string: Value returned when ``get_context_string`` is awaited.

    Returns:
        A ``MagicMock`` with async ``get_context_string`` and ``store_episode``
        attributes and ``_episodic`` set to ``None``.
    """
    mocked_retriever = MagicMock()
    mocked_retriever._episodic = None
    mocked_retriever.store_episode = AsyncMock()
    mocked_retriever.get_context_string = AsyncMock(return_value=context_string)
    return mocked_retriever
def make_mock_episodic_memory():
    """Build a mock episodic memory whose ``store`` attribute is an ``AsyncMock``."""
    # Keyword arguments to ``MagicMock`` are set as attributes on the new mock.
    return MagicMock(store=AsyncMock())
# ── Test: Memory context injected into system_prompt ──────────
class TestMemoryContextInjection:
    """Tests for injecting retrieved memory context into the system prompt.

    Each test drives ``ReActEngine.execute`` with a mocked gateway and inspects
    the first message handed to ``gateway.chat``.
    """

    @staticmethod
    def _system_message(gateway):
        """Return the first message of the most recent ``gateway.chat`` call.

        ``call_args.kwargs`` and ``call_args[1]`` refer to the same mapping, so
        a single keyword lookup is sufficient.
        """
        return gateway.chat.call_args.kwargs["messages"][0]

    async def test_memory_context_appended_to_existing_system_prompt(self):
        """With a system prompt present, the memory context is added to it."""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        retriever = make_mock_memory_retriever("Previous task result: success")

        result = await engine.execute(
            messages=[{"role": "user", "content": "Do something"}],
            system_prompt="You are a helpful assistant.",
            memory_retriever=retriever,
        )

        assert isinstance(result, ReActResult)
        retriever.get_context_string.assert_awaited_once_with(
            query="Do something",
            top_k=5,
            token_budget=2000,
        )
        # The system message must carry both the original prompt and the context.
        system_msg = self._system_message(gateway)
        assert system_msg["role"] == "system"
        assert "You are a helpful assistant." in system_msg["content"]
        assert "参考信息" in system_msg["content"]
        assert "Previous task result: success" in system_msg["content"]

    async def test_memory_context_used_as_system_prompt_when_none(self):
        """Without a system prompt, the memory context becomes the system message."""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        retriever = make_mock_memory_retriever("Past context only")

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
        )

        assert isinstance(result, ReActResult)
        system_msg = self._system_message(gateway)
        assert system_msg["role"] == "system"
        assert "参考信息" in system_msg["content"]
        assert "Past context only" in system_msg["content"]

    async def test_no_memory_context_when_retriever_is_none(self):
        """With ``memory_retriever=None`` the system prompt is sent unchanged."""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            system_prompt="You are a helper.",
            memory_retriever=None,
        )

        assert isinstance(result, ReActResult)
        system_msg = self._system_message(gateway)
        assert system_msg["content"] == "You are a helper."
        assert "参考信息" not in system_msg["content"]

    async def test_empty_memory_context_not_injected(self):
        """An empty context string from the retriever leaves the prompt unchanged."""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)
        retriever = make_mock_memory_retriever(context_string="")

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            system_prompt="You are a helper.",
            memory_retriever=retriever,
        )

        assert isinstance(result, ReActResult)
        system_msg = self._system_message(gateway)
        assert system_msg["content"] == "You are a helper."
        assert "参考信息" not in system_msg["content"]
# ── Test: Memory retrieval failure doesn't break execution ──────────
class TestMemoryRetrievalFailure:
    """A failing memory lookup must not abort task execution."""

    async def test_retrieval_failure_continues_without_context(self):
        """The task still completes when ``get_context_string`` raises."""
        failing_retriever = make_mock_memory_retriever()
        failing_retriever.get_context_string = AsyncMock(
            side_effect=RuntimeError("Redis down")
        )
        llm = make_mock_gateway([make_response(content="still works")])
        react = ReActEngine(llm_gateway=llm, max_steps=3)

        outcome = await react.execute(
            messages=[{"role": "user", "content": "Hello"}],
            system_prompt="You are a helper.",
            memory_retriever=failing_retriever,
        )

        # The run finishes normally with the gateway's answer.
        assert isinstance(outcome, ReActResult)
        assert outcome.output == "still works"

        # No memory section may appear in the system message that was sent.
        sent_messages = llm.chat.call_args.kwargs.get("messages")
        first_message = sent_messages[0]
        assert "参考信息" not in first_message["content"]
# ── Test: Task result stored in episodic memory ──────────
class TestEpisodicMemoryStorage:
    """After execution, a summary of the run is written through the retriever."""

    async def test_result_stored_in_episodic_memory(self):
        """A finished task is stored once, keyed by task id, with its metadata."""
        llm = make_mock_gateway([make_response(content="The answer is 42")])
        react = ReActEngine(llm_gateway=llm, max_steps=3)
        mem_retriever = make_mock_memory_retriever(context_string="")
        mem_retriever._episodic = make_mock_episodic_memory()

        outcome = await react.execute(
            messages=[{"role": "user", "content": "What is the answer?"}],
            memory_retriever=mem_retriever,
            task_id="task-123",
            agent_name="test-agent",
            task_type="qa",
        )

        assert isinstance(outcome, ReActResult)
        mem_retriever.store_episode.assert_awaited_once()

        store_kwargs = mem_retriever.store_episode.call_args.kwargs
        assert store_kwargs.get("key") == "task:task-123"

        # Metadata must record the task type and a successful outcome.
        stored_metadata = store_kwargs.get("metadata")
        assert stored_metadata["task_type"] == "qa"
        assert stored_metadata["outcome"] == "success"

    async def test_no_storage_when_no_episodic_memory(self):
        """Execution completes when the retriever has no episodic backend."""
        llm = make_mock_gateway([make_response(content="done")])
        react = ReActEngine(llm_gateway=llm, max_steps=3)
        mem_retriever = make_mock_memory_retriever(context_string="")
        mem_retriever._episodic = None

        outcome = await react.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=mem_retriever,
        )

        # Reaching this point means no exception escaped ``execute``.
        assert isinstance(outcome, ReActResult)

    async def test_storage_failure_doesnt_break_execution(self):
        """A raising ``store_episode`` does not prevent the task from finishing."""
        llm = make_mock_gateway([make_response(content="done")])
        react = ReActEngine(llm_gateway=llm, max_steps=3)
        mem_retriever = make_mock_memory_retriever(context_string="")
        mem_retriever.store_episode = AsyncMock(side_effect=RuntimeError("DB down"))

        outcome = await react.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=mem_retriever,
        )

        # The run still yields the gateway's answer.
        assert isinstance(outcome, ReActResult)
        assert outcome.output == "done"
# ── Test: execute_stream with memory ──────────
class TestMemoryInStreamMode:
    """Memory integration when the engine is driven through ``execute_stream``."""

    async def test_stream_injects_memory_context(self):
        """Streaming execution also queries the retriever for context."""
        llm = make_mock_gateway([make_response(content="streamed answer")])
        react = ReActEngine(llm_gateway=llm, max_steps=3)
        mem_retriever = make_mock_memory_retriever("Stream context")

        collected = [
            item
            async for item in react.execute_stream(
                messages=[{"role": "user", "content": "Hello"}],
                system_prompt="You are a helper.",
                memory_retriever=mem_retriever,
            )
        ]

        # At least one event must have been produced.
        assert collected
        mem_retriever.get_context_string.assert_awaited_once()

    async def test_stream_stores_to_episodic(self):
        """Streaming execution also stores the episode once the stream is drained."""
        llm = make_mock_gateway([make_response(content="streamed answer")])
        react = ReActEngine(llm_gateway=llm, max_steps=3)
        mem_retriever = make_mock_memory_retriever(context_string="")
        mem_retriever._episodic = make_mock_episodic_memory()

        # Drain the stream; the events themselves are not inspected here.
        async for _ in react.execute_stream(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=mem_retriever,
            task_id="stream-task-1",
        ):
            pass

        mem_retriever.store_episode.assert_awaited_once()
# ── Test: BaseAgent.use_memory_retriever() ──────────
class TestBaseAgentMemoryRetriever:
    """Tests for ``BaseAgent.use_memory_retriever()`` and its default state."""

    @staticmethod
    def _make_agent():
        """Return an instance of a minimal concrete ``BaseAgent`` subclass.

        The subclass only supplies the ``handle_task`` and ``get_capabilities``
        methods; both tests share it instead of redefining it.
        """
        from agentkit.core.base import BaseAgent

        class _ConcreteAgent(BaseAgent):
            async def handle_task(self, task):
                return {}

            def get_capabilities(self):
                # Imported lazily so the import only runs if the method is called.
                from agentkit.core.protocol import AgentCapability

                return AgentCapability(
                    agent_name=self.name,
                    agent_type=self.agent_type,
                    version=self.version,
                )

        return _ConcreteAgent(name="test", agent_type="test")

    def test_use_memory_retriever_sets_field(self):
        """``use_memory_retriever()`` stores the retriever and returns the agent."""
        agent = self._make_agent()
        mock_retriever = MagicMock()

        result = agent.use_memory_retriever(mock_retriever)

        # Returning ``self`` allows call chaining.
        assert result is agent
        assert agent._memory_retriever is mock_retriever

    def test_memory_retriever_default_is_none(self):
        """A freshly constructed agent has ``_memory_retriever`` set to ``None``."""
        agent = self._make_agent()
        assert agent._memory_retriever is None
# ── Test: ConfigDrivenAgent memory integration ──────────
class TestConfigDrivenAgentMemory:
    """Tests for how ``ConfigDrivenAgent`` sets ``_memory_retriever`` from ``config.memory``."""

    @staticmethod
    def _make_config(**overrides):
        """Build the ``AgentConfig`` shared by these tests.

        Args:
            **overrides: Extra keyword arguments (for example ``memory=...``)
                forwarded to ``AgentConfig`` alongside the common fields.
        """
        from agentkit.core.config_driven import AgentConfig

        return AgentConfig(
            name="test-agent",
            agent_type="test",
            task_mode="llm_generate",
            prompt={"identity": "Test agent"},
            **overrides,
        )

    def test_memory_retriever_created_from_config(self):
        """A non-empty ``config.memory`` results in a non-``None`` retriever."""
        from agentkit.core.config_driven import ConfigDrivenAgent

        config = self._make_config(
            memory={
                "working": {"enabled": False},
                "episodic": {"enabled": False},
            },
        )
        # A patcher object is always truthy, so chaining a second patcher with
        # ``or`` would never apply it; only this patch is used. ``create=True``
        # lets the patch succeed even when the attribute does not already exist.
        with patch("agentkit.core.config_driven.MemoryRetriever", create=True):
            agent = ConfigDrivenAgent(config=config)

        # Both backends are disabled, yet a retriever object is still expected.
        assert agent._memory_retriever is not None

    @staticmethod
    def _patch_memory_imports():
        """Return a patcher for ``agentkit.memory.retriever.MemoryRetriever``."""
        return patch("agentkit.memory.retriever.MemoryRetriever")

    def test_no_memory_retriever_when_no_config(self):
        """Without ``config.memory`` no retriever is created."""
        from agentkit.core.config_driven import ConfigDrivenAgent

        agent = ConfigDrivenAgent(config=self._make_config())
        assert agent._memory_retriever is None

    def test_memory_retriever_created_with_empty_memory_dict(self):
        """An empty ``config.memory`` dict yields no retriever."""
        from agentkit.core.config_driven import ConfigDrivenAgent

        agent = ConfigDrivenAgent(config=self._make_config(memory={}))
        # Empty dict is falsy, so no retriever
        assert agent._memory_retriever is None

    def test_memory_retriever_failure_graceful(self):
        """Constructing the agent with an unreachable Redis URL must not raise."""
        from agentkit.core.config_driven import ConfigDrivenAgent

        config = self._make_config(
            memory={"working": {"enabled": True, "redis_url": "redis://nonexistent:6379"}},
        )
        # Either a retriever is created or initialisation degrades gracefully;
        # the test passes as long as construction raises nothing.
        ConfigDrivenAgent(config=config)
# ── Test: Structured Context Injection ──────────
class TestStructuredContextInjection:
    """U3: tests for the text produced by ``MemoryRetriever.get_context_string``.

    Every test replaces ``retriever.retrieve`` with an ``AsyncMock`` so the
    items being formatted come straight from the test.
    """

    async def test_structured_format_with_rag_results(self):
        """Structured template: a RAG item gets a knowledge-base heading."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        mem = MemoryRetriever(context_template="structured")
        # ``retrieve`` is mocked to hand back a single RAG hit.
        rag_hit = MemoryItem(
            key="doc-1",
            value="AI行业在2025年呈现三大趋势...",
            metadata={"source": "rag", "kb_type": "行业库", "document_title": "AI行业趋势报告"},
            score=0.92,
        )
        mem.retrieve = AsyncMock(return_value=[rag_hit])

        rendered = await mem.get_context_string(query="AI trends", top_k=5, token_budget=3000)

        assert "### 知识库参考 [来源: 行业库 | 相关度: 0.92 | 文档: AI行业趋势报告]" in rendered
        assert "AI行业在2025年呈现三大趋势..." in rendered

    async def test_structured_format_with_episodic_results(self):
        """Structured template: an episodic item gets a past-experience heading."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        mem = MemoryRetriever(context_template="structured")
        past_episode = MemoryItem(
            key="task:seo-001",
            value="上次分析竞品SEO策略时发现...",
            metadata={"source": "episodic", "task_type": "seo_analysis"},
            score=0.85,
        )
        mem.retrieve = AsyncMock(return_value=[past_episode])

        rendered = await mem.get_context_string(query="SEO analysis", top_k=5, token_budget=3000)

        assert "### 过往经验 [来源: 情景记忆 | 任务类型: seo_analysis]" in rendered
        assert "上次分析竞品SEO策略时发现..." in rendered

    async def test_structured_format_with_mixed_sources(self):
        """Structured template: each source kind produces its own heading."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        mem = MemoryRetriever(context_template="structured")
        # (key, value, metadata, score) for one item per source kind.
        item_specs = [
            (
                "doc-1",
                "RAG content here",
                {"source": "rag", "kb_type": "行业库", "document_title": "报告A"},
                0.90,
            ),
            (
                "task:ep-1",
                "Episodic content here",
                {"source": "episodic", "task_type": "analysis"},
                0.80,
            ),
            ("entity-1", "Graph content here", {"source": "graph"}, 0.75),
            ("ctx-1", "Working memory content", {"source": "working"}, 0.60),
            ("other-1", "Unknown source content", {"source": "custom"}, 0.50),
        ]
        mixed_items = [
            MemoryItem(key=key, value=value, metadata=metadata, score=score)
            for key, value, metadata, score in item_specs
        ]
        mem.retrieve = AsyncMock(return_value=mixed_items)

        rendered = await mem.get_context_string(query="test", top_k=5, token_budget=3000)

        expected_headings = (
            "### 知识库参考",
            "### 过往经验",
            "### 知识图谱",
            "### 工作记忆",
            "### 参考 [来源: custom",
        )
        for heading in expected_headings:
            assert heading in rendered

    async def test_flat_format_backward_compatible(self):
        """Flat template: values are joined as plain text with no heading lines."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        mem = MemoryRetriever(context_template="flat")
        first_hit = MemoryItem(
            key="doc-1",
            value="First result",
            metadata={"source": "rag"},
            score=0.9,
        )
        second_hit = MemoryItem(
            key="ep-1",
            value="Second result",
            metadata={"source": "episodic"},
            score=0.8,
        )
        mem.retrieve = AsyncMock(return_value=[first_hit, second_hit])

        rendered = await mem.get_context_string(query="test", top_k=5, token_budget=3000)

        # Structured headings must be absent in flat mode.
        assert "### 知识库参考" not in rendered
        assert "### 过往经验" not in rendered
        # The values appear verbatim, separated by a blank line.
        assert "First result" in rendered
        assert "Second result" in rendered
        assert rendered == "First result\n\nSecond result"

    async def test_token_budget_truncation_in_structured_format(self):
        """Structured template: oversized content is cut to fit the token budget."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        mem = MemoryRetriever(context_template="structured")
        # A single item far larger than the budget used below.
        oversized = MemoryItem(
            key="doc-1",
            value="A" * 20000,
            metadata={"source": "rag", "kb_type": "知识库", "document_title": "大文档"},
            score=0.9,
        )
        mem.retrieve = AsyncMock(return_value=[oversized])

        rendered = await mem.get_context_string(query="test", top_k=5, token_budget=100)

        # The test budgets 4 characters per token: 100 tokens -> at most 400 chars.
        assert len(rendered) <= 400

    async def test_empty_results_returns_empty_string(self):
        """No retrieved items yields an empty string."""
        from agentkit.memory.retriever import MemoryRetriever

        mem = MemoryRetriever(context_template="structured")
        mem.retrieve = AsyncMock(return_value=[])

        rendered = await mem.get_context_string(query="test", top_k=5, token_budget=3000)

        assert rendered == ""

    async def test_context_template_parameter(self):
        """``context_template`` switches between flat and structured output."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        shared_item = MemoryItem(
            key="doc-1",
            value="Flat content",
            metadata={"source": "rag"},
            score=0.9,
        )

        # Flat template: plain value, no heading.
        flat_mem = MemoryRetriever(context_template="flat")
        flat_mem.retrieve = AsyncMock(return_value=[shared_item])
        flat_text = await flat_mem.get_context_string(query="test")
        assert "### 知识库参考" not in flat_text
        assert "Flat content" in flat_text

        # Structured template: the same item now gets a heading.
        structured_mem = MemoryRetriever(context_template="structured")
        structured_mem.retrieve = AsyncMock(return_value=[shared_item])
        structured_text = await structured_mem.get_context_string(query="test")
        assert "### 知识库参考" in structured_text

    async def test_structured_format_default_kb_type(self):
        """Structured template: a RAG item without ``kb_type`` uses a default label."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        mem = MemoryRetriever(context_template="structured")
        untyped_doc = MemoryItem(
            key="doc-1",
            value="Content without kb_type",
            metadata={"source": "rag", "document_title": "报告B"},
            score=0.88,
        )
        mem.retrieve = AsyncMock(return_value=[untyped_doc])

        rendered = await mem.get_context_string(query="test")

        assert "### 知识库参考 [来源: 知识库 | 相关度: 0.88 | 文档: 报告B]" in rendered

    async def test_structured_format_default_task_type(self):
        """Structured template: an episodic item without ``task_type`` uses a default label."""
        from agentkit.memory.base import MemoryItem
        from agentkit.memory.retriever import MemoryRetriever

        mem = MemoryRetriever(context_template="structured")
        untyped_episode = MemoryItem(
            key="ep-1",
            value="Content without task_type",
            metadata={"source": "episodic"},
            score=0.75,
        )
        mem.retrieve = AsyncMock(return_value=[untyped_episode])

        rendered = await mem.get_context_string(query="test")

        assert "### 过往经验 [来源: 情景记忆 | 任务类型: 未知]" in rendered
# ── Test: ReAct Context Injection Format ──────────
class TestReActContextInjectionFormat:
    """U3: ``ReActEngine`` wraps memory context under the new heading."""

    async def test_react_uses_new_heading(self):
        """The system message uses the '## 参考信息' heading, not the old one."""
        llm = make_mock_gateway([make_response(content="final answer")])
        react = ReActEngine(llm_gateway=llm, max_steps=3)
        mem_retriever = make_mock_memory_retriever("Some context data")

        outcome = await react.execute(
            messages=[{"role": "user", "content": "Hello"}],
            system_prompt="You are a helper.",
            memory_retriever=mem_retriever,
        )

        assert isinstance(outcome, ReActResult)
        sent_messages = llm.chat.call_args.kwargs.get("messages")
        system_text = sent_messages[0]["content"]
        assert "## 参考信息" in system_text
        assert "Relevant Past Experience" not in system_text

    async def test_react_new_heading_when_no_system_prompt(self):
        """Without a system prompt, the system message starts with the new heading."""
        llm = make_mock_gateway([make_response(content="final answer")])
        react = ReActEngine(llm_gateway=llm, max_steps=3)
        mem_retriever = make_mock_memory_retriever("Context only")

        outcome = await react.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=mem_retriever,
        )

        assert isinstance(outcome, ReActResult)
        sent_messages = llm.chat.call_args.kwargs.get("messages")
        system_text = sent_messages[0]["content"]
        assert system_text.startswith("## 参考信息")
        assert "Relevant Past Experience" not in system_text