"""U4: Memory integration into the agent loop — integration tests.

Covers the full flow of injecting a MemoryRetriever into ReActEngine:

1. Before execution, relevant context is retrieved and injected into the system_prompt.
2. After execution, a trajectory summary is written to EpisodicMemory.
3. A failed memory retrieval does not abort task execution.
4. ConfigDrivenAgent automatically creates a MemoryRetriever from config.memory.
5. The BaseAgent.use_memory_retriever() method.
"""
|
||
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from agentkit.core.react import ReActEngine, ReActResult
|
||
from agentkit.llm.gateway import LLMGateway
|
||
from agentkit.llm.protocol import LLMResponse, TokenUsage
|
||
|
||
|
||
# ── Test Helpers ──────────────────────────────────────────


def make_mock_gateway(responses: list[LLMResponse]) -> LLMGateway:
    """Build a fake ``LLMGateway`` whose ``chat`` coroutine replays *responses*.

    Each await of ``gateway.chat(...)`` resolves to the next item of
    *responses*, in order (``AsyncMock`` iterable ``side_effect``). Awaiting
    more times than there are responses raises ``StopAsyncIteration``.
    """
    return MagicMock(spec=LLMGateway, chat=AsyncMock(side_effect=responses))


def make_response(
    content: str = "",
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
) -> LLMResponse:
    """Build a plain-text ``LLMResponse`` for model ``"test-model"``.

    The response carries no tool calls, so the engine treats it as a final
    answer. Token usage is filled from *prompt_tokens* / *completion_tokens*.
    """
    usage = TokenUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    return LLMResponse(
        content=content,
        model="test-model",
        usage=usage,
        tool_calls=[],
    )


def make_mock_memory_retriever(context_string: str = "past experience data"):
    """Build a stand-in ``MemoryRetriever``.

    ``get_context_string`` is an ``AsyncMock`` that always resolves to
    *context_string*. ``_episodic`` starts as ``None`` (no episodic backend);
    tests that exercise storage assign their own ``_episodic`` afterwards.
    """
    fake = MagicMock()
    fake._episodic = None
    fake.get_context_string = AsyncMock(return_value=context_string)
    return fake


def make_mock_episodic_memory():
    """Build a stand-in ``EpisodicMemory`` whose ``store`` is an ``AsyncMock``.

    Tests inspect the recorded call via ``store.assert_awaited_once()`` and
    ``store.call_args``.
    """
    return MagicMock(store=AsyncMock())


# ── Test: Memory context injected into system_prompt ──────────


class TestMemoryContextInjection:
    """Injection of retrieved memory context into the system_prompt."""

    @staticmethod
    def _sent_messages(gateway):
        """Return the ``messages`` keyword argument of the last ``gateway.chat`` call.

        Only the keyword form is read: ``call_args[1]`` is the same mapping as
        ``call_args.kwargs``, so falling back to it can never find anything new.
        A missing call or missing kwarg fails with a descriptive assertion
        instead of an opaque ``TypeError`` on ``None[0]``.
        """
        call = gateway.chat.call_args
        assert call is not None, "gateway.chat was never awaited"
        messages = call.kwargs.get("messages")
        assert messages is not None, f"gateway.chat called without messages=: {call!r}"
        return messages

    async def test_memory_context_appended_to_existing_system_prompt(self):
        """With an existing system_prompt, the memory context is appended to it."""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        retriever = make_mock_memory_retriever("Previous task result: success")

        result = await engine.execute(
            messages=[{"role": "user", "content": "Do something"}],
            system_prompt="You are a helpful assistant.",
            memory_retriever=retriever,
        )

        assert isinstance(result, ReActResult)
        retriever.get_context_string.assert_awaited_once_with(
            query="Do something",
            top_k=5,
            token_budget=2000,
        )

        # The first message sent to the LLM is the system prompt, which must
        # now carry both the original text and the memory section.
        system_msg = self._sent_messages(gateway)[0]
        assert system_msg["role"] == "system"
        assert "You are a helpful assistant." in system_msg["content"]
        assert "Relevant Past Experience" in system_msg["content"]
        assert "Previous task result: success" in system_msg["content"]

    async def test_memory_context_used_as_system_prompt_when_none(self):
        """Without a system_prompt, the memory context becomes the system prompt."""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        retriever = make_mock_memory_retriever("Past context only")

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
        )

        assert isinstance(result, ReActResult)
        system_msg = self._sent_messages(gateway)[0]
        assert system_msg["role"] == "system"
        assert "Relevant Past Experience" in system_msg["content"]
        assert "Past context only" in system_msg["content"]

    async def test_no_memory_context_when_retriever_is_none(self):
        """With memory_retriever=None, no memory context is injected."""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            system_prompt="You are a helper.",
            memory_retriever=None,
        )

        assert isinstance(result, ReActResult)
        system_msg = self._sent_messages(gateway)[0]
        assert system_msg["content"] == "You are a helper."
        assert "Relevant Past Experience" not in system_msg["content"]

    async def test_empty_memory_context_not_injected(self):
        """An empty memory context string is not injected."""
        gateway = make_mock_gateway([make_response(content="final answer")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        retriever = make_mock_memory_retriever(context_string="")

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            system_prompt="You are a helper.",
            memory_retriever=retriever,
        )

        assert isinstance(result, ReActResult)
        system_msg = self._sent_messages(gateway)[0]
        assert system_msg["content"] == "You are a helper."
        assert "Relevant Past Experience" not in system_msg["content"]


# ── Test: Memory retrieval failure doesn't break execution ──────────


class TestMemoryRetrievalFailure:
    """A failing memory retrieval must not abort task execution."""

    async def test_retrieval_failure_continues_without_context(self):
        """If retrieval raises, the task still runs with the unmodified system_prompt."""
        gateway = make_mock_gateway([make_response(content="still works")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        retriever = make_mock_memory_retriever()
        retriever.get_context_string = AsyncMock(side_effect=RuntimeError("Redis down"))

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            system_prompt="You are a helper.",
            memory_retriever=retriever,
        )

        # The failure path must actually have been exercised; without this
        # the test would pass vacuously if the engine never consulted memory.
        retriever.get_context_string.assert_awaited_once()

        # Task should still complete
        assert isinstance(result, ReActResult)
        assert result.output == "still works"

        # The system prompt is sent as given, with no memory section.
        # call_args[1] is the same mapping as call_args.kwargs, so indexing
        # .kwargs directly is sufficient (and fails loudly if it is missing).
        system_msg = gateway.chat.call_args.kwargs["messages"][0]
        assert "Relevant Past Experience" not in system_msg["content"]
        assert system_msg["content"] == "You are a helper."


# ── Test: Task result stored in episodic memory ──────────


class TestEpisodicMemoryStorage:
    """After execution, a trajectory summary is written to EpisodicMemory."""

    async def test_result_stored_in_episodic_memory(self):
        """On completion, a result summary is stored in EpisodicMemory."""
        gateway = make_mock_gateway([make_response(content="The answer is 42")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        episodic = make_mock_episodic_memory()
        retriever = make_mock_memory_retriever(context_string="")
        retriever._episodic = episodic

        result = await engine.execute(
            messages=[{"role": "user", "content": "What is the answer?"}],
            memory_retriever=retriever,
            task_id="task-123",
            agent_name="test-agent",
            task_type="qa",
        )

        assert isinstance(result, ReActResult)
        episodic.store.assert_awaited_once()
        # call_args[1] and call_args.kwargs are the same mapping; read it once.
        store_kwargs = episodic.store.call_args.kwargs
        assert store_kwargs.get("key") == "task:task-123"

        # Verify metadata
        metadata = store_kwargs.get("metadata")
        assert metadata is not None, "store() was called without metadata="
        assert metadata["task_type"] == "qa"
        assert metadata["outcome"] == "success"

    async def test_no_storage_when_no_episodic_memory(self):
        """Without an EpisodicMemory backend, no storage is attempted."""
        gateway = make_mock_gateway([make_response(content="done")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        retriever = make_mock_memory_retriever(context_string="")
        retriever._episodic = None

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
        )

        assert isinstance(result, ReActResult)
        # There is no backend whose calls could be inspected; this test only
        # guards against the engine raising when retriever._episodic is None.

    async def test_storage_failure_doesnt_break_execution(self):
        """A failing EpisodicMemory.store() does not abort the task."""
        gateway = make_mock_gateway([make_response(content="done")])
        engine = ReActEngine(llm_gateway=gateway, max_steps=3)

        episodic = make_mock_episodic_memory()
        episodic.store = AsyncMock(side_effect=RuntimeError("DB down"))

        retriever = make_mock_memory_retriever(context_string="")
        retriever._episodic = episodic

        result = await engine.execute(
            messages=[{"role": "user", "content": "Hello"}],
            memory_retriever=retriever,
            # Supply a task_id, as the successful-storage tests do, so the
            # engine reaches store(); without it the failure path might
            # never execute and the test would pass vacuously.
            task_id="task-store-fails",
        )

        # The failing store() must actually have been attempted...
        episodic.store.assert_awaited_once()

        # ...and the task should still complete.
        assert isinstance(result, ReActResult)
        assert result.output == "done"


# ── Test: execute_stream with memory ──────────


class TestMemoryInStreamMode:
    """Memory integration when the engine runs through ``execute_stream``."""

    async def test_stream_injects_memory_context(self):
        """``execute_stream`` consults the retriever for context too."""
        engine = ReActEngine(
            llm_gateway=make_mock_gateway([make_response(content="streamed answer")]),
            max_steps=3,
        )
        retriever = make_mock_memory_retriever("Stream context")

        collected = [
            event
            async for event in engine.execute_stream(
                messages=[{"role": "user", "content": "Hello"}],
                system_prompt="You are a helper.",
                memory_retriever=retriever,
            )
        ]

        # The stream must have produced at least one event.
        assert collected
        retriever.get_context_string.assert_awaited_once()

    async def test_stream_stores_to_episodic(self):
        """Once the stream is drained, the result is stored in EpisodicMemory."""
        engine = ReActEngine(
            llm_gateway=make_mock_gateway([make_response(content="streamed answer")]),
            max_steps=3,
        )
        episodic = make_mock_episodic_memory()
        retriever = make_mock_memory_retriever(context_string="")
        retriever._episodic = episodic

        # Drain the stream completely so the post-run storage step executes.
        _drained = [
            event
            async for event in engine.execute_stream(
                messages=[{"role": "user", "content": "Hello"}],
                memory_retriever=retriever,
                task_id="stream-task-1",
            )
        ]

        episodic.store.assert_awaited_once()


# ── Test: BaseAgent.use_memory_retriever() ──────────


class TestBaseAgentMemoryRetriever:
    """Tests for ``BaseAgent.use_memory_retriever()``."""

    @staticmethod
    def _make_agent():
        """Build a minimal concrete ``BaseAgent`` with name/type ``"test"``.

        Both tests need the same stub subclass, so it is defined once here
        instead of being copied into every test.
        """
        from agentkit.core.base import BaseAgent

        class _StubAgent(BaseAgent):
            async def handle_task(self, task):
                return {}

            def get_capabilities(self):
                from agentkit.core.protocol import AgentCapability

                return AgentCapability(
                    agent_name=self.name,
                    agent_type=self.agent_type,
                    version=self.version,
                )

        return _StubAgent(name="test", agent_type="test")

    def test_use_memory_retriever_sets_field(self):
        """use_memory_retriever() stores the retriever in _memory_retriever."""
        agent = self._make_agent()
        mock_retriever = MagicMock()

        result = agent.use_memory_retriever(mock_retriever)

        # Should return self for chaining
        assert result is agent
        assert agent._memory_retriever is mock_retriever

    def test_memory_retriever_default_is_none(self):
        """_memory_retriever defaults to None."""
        agent = self._make_agent()
        assert agent._memory_retriever is None


# ── Test: ConfigDrivenAgent memory integration ──────────


class TestConfigDrivenAgentMemory:
    """ConfigDrivenAgent automatically creates a MemoryRetriever from config.memory."""

    def test_memory_retriever_created_from_config(self):
        """A non-empty config.memory results in a MemoryRetriever being created."""
        from agentkit.core.config_driven import ConfigDrivenAgent, AgentConfig

        config = AgentConfig(
            name="test-agent",
            agent_type="test",
            task_mode="llm_generate",
            prompt={"identity": "Test agent"},
            memory={
                "working": {"enabled": False},
                "episodic": {"enabled": False},
            },
        )

        # A single patch is applied explicitly. `patch(...) or other_patch`
        # would always pick the first (truthy) patcher, so a second patch
        # chained with `or` is never entered. create=True lets the patch
        # succeed even if config_driven has no module-level MemoryRetriever
        # name; in that case (presumably a lazy import) the real retriever is
        # built, which is harmless because both backends are disabled.
        with patch("agentkit.core.config_driven.MemoryRetriever", create=True):
            agent = ConfigDrivenAgent(config=config)
            # MemoryRetriever should have been created (with no backends since both disabled)
            assert agent._memory_retriever is not None

    def test_no_memory_retriever_when_no_config(self):
        """Without config.memory, no MemoryRetriever is created."""
        from agentkit.core.config_driven import ConfigDrivenAgent, AgentConfig

        config = AgentConfig(
            name="test-agent",
            agent_type="test",
            task_mode="llm_generate",
            prompt={"identity": "Test agent"},
        )

        agent = ConfigDrivenAgent(config=config)
        assert agent._memory_retriever is None

    def test_memory_retriever_created_with_empty_memory_dict(self):
        """An empty config.memory dict is falsy, so no MemoryRetriever is created."""
        from agentkit.core.config_driven import ConfigDrivenAgent, AgentConfig

        config = AgentConfig(
            name="test-agent",
            agent_type="test",
            task_mode="llm_generate",
            prompt={"identity": "Test agent"},
            memory={},
        )

        agent = ConfigDrivenAgent(config=config)
        # Empty dict is falsy, so no retriever
        assert agent._memory_retriever is None

    def test_memory_retriever_failure_graceful(self):
        """A failing memory initialisation degrades gracefully instead of raising."""
        from agentkit.core.config_driven import ConfigDrivenAgent, AgentConfig

        config = AgentConfig(
            name="test-agent",
            agent_type="test",
            task_mode="llm_generate",
            prompt={"identity": "Test agent"},
            memory={"working": {"enabled": True, "redis_url": "redis://nonexistent:6379"}},
        )

        # Construction must not raise. Whether the retriever ends up created
        # (lazy connection) or set to None (eager failure, logged) depends on
        # the backend, so only successful construction is asserted here.
        agent = ConfigDrivenAgent(config=config)
        assert isinstance(agent, ConfigDrivenAgent)