420 lines
13 KiB
Python
420 lines
13 KiB
Python
"""EpisodicMemory unit tests - task-experience memory backed by pgvector + PostgreSQL.

Uses a mock session_factory and real SQLAlchemy ORM models for unit testing;
no real PostgreSQL/pgvector environment is required.
"""
|
||
|
||
import uuid
|
||
from contextlib import asynccontextmanager
|
||
from datetime import datetime, timedelta, timezone
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import pytest
|
||
from sqlalchemy import Column, DateTime, Float, String, delete as sql_delete, select
|
||
from sqlalchemy.orm import DeclarativeBase
|
||
|
||
from agentkit.memory.episodic import EpisodicMemory
|
||
from agentkit.memory.base import MemoryItem
|
||
|
||
|
||
# ── 真实 SQLAlchemy 模型(用于测试) ─────────────────────
|
||
|
||
|
||
class Base(DeclarativeBase):
    """Declarative base class for the ORM models defined in this test module."""
|
||
|
||
|
||
class MockEpisodicModel(Base):
    """Stand-in EpisodicMemory ORM model built from real SQLAlchemy column definitions."""

    __tablename__ = "test_episodic_memory"

    # Primary key stored as a string; the column default generates a fresh UUID4 string.
    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    agent_name = Column(String, default="")
    task_type = Column(String, default="")
    input_summary = Column(String, default="")
    output_summary = Column(String, default="")
    outcome = Column(String, default="success")
    quality_score = Column(Float, default=0.5)
    reflection = Column(String, default="")
    # Declared as a nullable String here.
    # NOTE(review): tests in this file compare this attribute against a list of
    # floats; presumably fine because instances are never flushed - confirm.
    embedding = Column(String, nullable=True)
    # Column default is the current time as a timezone-aware UTC datetime.
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||
|
||
|
||
# ── Mock 辅助工具 ────────────────────────────────────────
|
||
|
||
|
||
def make_mock_entry(
    id: uuid.UUID | None = None,
    agent_name: str = "test_agent",
    task_type: str = "analysis",
    input_summary: str = "test input",
    output_summary: str = "test output",
    outcome: str = "success",
    quality_score: float = 0.8,
    reflection: str = "",
    created_at: datetime | None = None,
):
    """Build a MockEpisodicModel instance to serve as a fake ORM row.

    Args:
        id: Identifier for the entry; a random UUID4 is used when falsy.
            It is stored on the model as a string.
        agent_name: Value for the ``agent_name`` column.
        task_type: Value for the ``task_type`` column.
        input_summary: Value for the ``input_summary`` column.
        output_summary: Value for the ``output_summary`` column.
        outcome: Value for the ``outcome`` column.
        quality_score: Value for the ``quality_score`` column.
        reflection: Value for the ``reflection`` column.
        created_at: Creation timestamp; the current UTC time is used when falsy.

    Returns:
        A new ``MockEpisodicModel`` instance populated with the given values.
    """
    entry_id = str(id or uuid.uuid4())
    timestamp = created_at or datetime.now(timezone.utc)

    column_values = {
        "id": entry_id,
        "agent_name": agent_name,
        "task_type": task_type,
        "input_summary": input_summary,
        "output_summary": output_summary,
        "outcome": outcome,
        "quality_score": quality_score,
        "reflection": reflection,
        "created_at": timestamp,
    }
    return MockEpisodicModel(**column_values)
|
||
|
||
|
||
def make_mock_session_factory(entries: list | None = None):
|
||
"""创建一个 mock session_factory,返回包含指定 entries 的 session
|
||
|
||
Args:
|
||
entries: search 方法返回的 ORM entry 列表
|
||
"""
|
||
entries = entries or []
|
||
|
||
mock_session = AsyncMock()
|
||
mock_session.add = MagicMock()
|
||
mock_session.commit = AsyncMock()
|
||
mock_session.rollback = AsyncMock()
|
||
|
||
# 模拟 execute 返回的 result 对象
|
||
mock_result = MagicMock()
|
||
mock_scalars = MagicMock()
|
||
mock_scalars.all.return_value = entries
|
||
mock_result.scalars.return_value = mock_scalars
|
||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||
|
||
@asynccontextmanager
|
||
async def factory():
|
||
yield mock_session
|
||
|
||
return factory, mock_session
|
||
|
||
|
||
# ── EpisodicMemory 测试 ──────────────────────────────────
|
||
|
||
|
||
class TestEpisodicMemoryStore:
    """Tests for EpisodicMemory.store.

    NOTE(review): the async tests carry no explicit asyncio marker; presumably
    the project runs pytest-asyncio in auto mode - confirm in the pytest config.
    """

    async def test_store_writes_entry_with_correct_fields(self):
        """store adds a single entry carrying the supplied fields and commits."""
        session_factory, session = make_mock_session_factory()
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
        )

        metadata = {
            "agent_name": "analyst_agent",
            "task_type": "financial_analysis",
            "output_summary": "Report generated",
            "outcome": "success",
            "quality_score": 0.9,
            "reflection": "Good analysis",
        }
        await memory.store(
            key="task:001",
            value="Analyzed financial data",
            metadata=metadata,
        )

        session.add.assert_called_once()
        session.commit.assert_called_once()

        # Inspect the object that was handed to session.add.
        stored = session.add.call_args.args[0]
        assert isinstance(stored, MockEpisodicModel)

        expected_fields = {
            "agent_name": "analyst_agent",
            "task_type": "financial_analysis",
            "input_summary": "Analyzed financial data",
            "output_summary": "Report generated",
            "outcome": "success",
            "quality_score": 0.9,
            "reflection": "Good analysis",
        }
        for field_name, expected in expected_fields.items():
            assert getattr(stored, field_name) == expected

    async def test_store_with_embedder_generates_embedding(self):
        """With an embedder configured, store embeds the value and keeps the vector."""
        session_factory, session = make_mock_session_factory()

        embedder = AsyncMock()
        embedder.embed = AsyncMock(return_value=[0.1, 0.2, 0.3])

        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
            embedder=embedder,
        )

        await memory.store("key1", "some value", {"agent_name": "test"})

        embedder.embed.assert_called_once()
        embedded_text = embedder.embed.call_args.args[0]
        assert "some value" in embedded_text

        # The stored entry must carry the embedding returned by the embedder.
        stored = session.add.call_args.args[0]
        assert stored.embedding == [0.1, 0.2, 0.3]

    async def test_store_without_embedder_no_embedding(self):
        """Without an embedder, the stored entry's embedding is None."""
        session_factory, session = make_mock_session_factory()
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
            embedder=None,
        )

        await memory.store("key1", "some value")

        stored = session.add.call_args.args[0]
        assert stored.embedding is None

    async def test_store_rollback_on_error(self):
        """A failing commit propagates the error and triggers a rollback."""
        session_factory, session = make_mock_session_factory()

        # Make commit raise so the error path is exercised.
        session.commit = AsyncMock(side_effect=Exception("DB error"))

        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
        )

        with pytest.raises(Exception, match="DB error"):
            await memory.store("key1", "value1")

        session.rollback.assert_called_once()

    async def test_store_default_metadata_values(self):
        """Fields missing from metadata end up with the expected default values."""
        session_factory, session = make_mock_session_factory()
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
        )

        await memory.store("key1", "value1")

        stored = session.add.call_args.args[0]
        expected_defaults = {
            "agent_name": "",
            "task_type": "",
            "outcome": "success",
            "quality_score": 0.5,
            "reflection": "",
        }
        for field_name, expected in expected_defaults.items():
            assert getattr(stored, field_name) == expected
|
||
|
||
|
||
class TestEpisodicMemorySearch:
    """Tests for EpisodicMemory.search."""

    async def test_search_with_time_decay_recent_scores_higher(self):
        """Time decay: with equal quality, scores come back strictly descending."""
        now = datetime.now(timezone.utc)
        one_hour_ago = now - timedelta(hours=1)
        hundred_hours_ago = now - timedelta(hours=100)

        recent_entry = make_mock_entry(quality_score=0.8, created_at=one_hour_ago)
        old_entry = make_mock_entry(quality_score=0.8, created_at=hundred_hours_ago)

        session_factory, _ = make_mock_session_factory([recent_entry, old_entry])
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
            decay_rate=0.01,
        )

        results = await memory.search("test query")

        assert len(results) == 2
        # The more recent entry is expected to rank first.
        # NOTE(review): only the score ordering is asserted, not which entry is first.
        first, second = results[0], results[1]
        assert first.score > second.score

    async def test_search_with_quality_score_factor(self):
        """quality_score influences the final score when timestamps are equal."""
        now = datetime.now(timezone.utc)
        one_hour_ago = now - timedelta(hours=1)

        high_quality = make_mock_entry(quality_score=0.9, created_at=one_hour_ago)
        low_quality = make_mock_entry(quality_score=0.1, created_at=one_hour_ago)

        session_factory, _ = make_mock_session_factory([high_quality, low_quality])
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
        )

        results = await memory.search("test query")

        assert len(results) == 2
        # The higher-quality entry is expected to rank first.
        first, second = results[0], results[1]
        assert first.score > second.score

    async def test_search_empty_store_returns_empty(self):
        """Searching when the query yields no rows returns an empty list."""
        session_factory, _ = make_mock_session_factory([])
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
        )

        results = await memory.search("anything")

        assert results == []

    async def test_search_applies_agent_name_filter(self):
        """search runs a query when an agent_name filter is supplied."""
        session_factory, session = make_mock_session_factory([])
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
        )

        await memory.search("test", filters={"agent_name": "specific_agent"})

        # Confirms the query was executed.
        # NOTE(review): the statement itself is not inspected for the filter.
        session.execute.assert_called_once()

    async def test_search_applies_task_type_filter(self):
        """search runs a query when a task_type filter is supplied."""
        session_factory, session = make_mock_session_factory([])
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
        )

        await memory.search("test", filters={"task_type": "analysis"})

        session.execute.assert_called_once()

    async def test_search_applies_outcome_filter(self):
        """search runs a query when an outcome filter is supplied."""
        session_factory, session = make_mock_session_factory([])
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
        )

        await memory.search("test", filters={"outcome": "success"})

        session.execute.assert_called_once()

    async def test_search_top_k_limits_results(self):
        """top_k caps the number of results returned by search."""
        now = datetime.now(timezone.utc)
        candidates = [
            make_mock_entry(quality_score=0.5 + idx * 0.05, created_at=now)
            for idx in range(10)
        ]

        session_factory, _ = make_mock_session_factory(candidates)
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
        )

        results = await memory.search("test", top_k=3)

        assert len(results) <= 3

    async def test_search_returns_memory_items(self):
        """search returns MemoryItem objects exposing the entry's fields."""
        entry = make_mock_entry(
            agent_name="test_agent",
            task_type="analysis",
            input_summary="test input",
            output_summary="test output",
            outcome="success",
            quality_score=0.9,
            reflection="good",
            created_at=datetime.now(timezone.utc),
        )

        session_factory, _ = make_mock_session_factory([entry])
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
        )

        results = await memory.search("test")

        assert len(results) == 1
        item = results[0]
        assert isinstance(item, MemoryItem)

        # Summaries and outcome live in item.value; agent/task live in item.metadata.
        assert item.value["input_summary"] == "test input"
        assert item.value["output_summary"] == "test output"
        assert item.value["outcome"] == "success"
        assert item.metadata["agent_name"] == "test_agent"
        assert item.metadata["task_type"] == "analysis"
|
||
|
||
|
||
class TestEpisodicMemoryDelete:
    """Tests for EpisodicMemory.delete."""

    async def test_delete_removes_entry_by_id(self):
        """delete executes one statement, commits, and returns True."""
        session_factory, session = make_mock_session_factory()
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
        )

        entry_id = str(uuid.uuid4())
        deleted = await memory.delete(entry_id)

        assert deleted is True
        session.execute.assert_called_once()
        session.commit.assert_called_once()

    async def test_delete_returns_false_on_error(self):
        """delete returns False and rolls back when execute raises."""
        session_factory, session = make_mock_session_factory()

        # Force the statement execution to fail.
        session.execute = AsyncMock(side_effect=Exception("DB error"))

        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
        )

        deleted = await memory.delete(str(uuid.uuid4()))

        assert deleted is False
        session.rollback.assert_called_once()
|
||
|
||
|
||
class TestEpisodicMemoryRetrieve:
    """Tests for EpisodicMemory.retrieve."""

    async def test_retrieve_always_returns_none(self):
        """retrieve always returns None (exact key lookup is unsupported by design)."""
        session_factory, _ = make_mock_session_factory()
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
            pgvector_enabled=False,
        )

        retrieved = await memory.retrieve("any_key")

        assert retrieved is None
|