# Source path: fischer-agentkit/tests/unit/test_episodic_memory.py
#
# Repository viewer metadata: 419 lines, 13 KiB, Python.
#
# Viewer notice: this file contains Unicode characters that might be confused
# with other characters. If that is intentional, the warning can be ignored.

"""Unit tests for EpisodicMemory, the task-experience memory backed by pgvector + PostgreSQL.

The tests use a mock session_factory together with real SQLAlchemy ORM models,
so no real PostgreSQL/pgvector environment is required.
"""
import uuid
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock
import pytest
from sqlalchemy import Column, DateTime, Float, String, delete as sql_delete, select
from sqlalchemy.orm import DeclarativeBase
from agentkit.memory.episodic import EpisodicMemory
from agentkit.memory.base import MemoryItem
# ── Real SQLAlchemy models (test-only) ───────────────────
class Base(DeclarativeBase):
    """Declarative base class for the ORM models declared in this test module."""
class MockEpisodicModel(Base):
    """Test double for the EpisodicMemory ORM model, declared with real SQLAlchemy columns."""

    __tablename__ = "test_episodic_memory"

    # Primary key stored as text; the column default is a fresh UUID4 string.
    id = Column(
        String,
        primary_key=True,
        default=lambda: str(uuid.uuid4()),
    )

    # Descriptive text columns, each with an empty-string column default.
    agent_name = Column(String, default="")
    task_type = Column(String, default="")
    input_summary = Column(String, default="")
    output_summary = Column(String, default="")

    # Outcome label and numeric quality rating, with their column defaults.
    outcome = Column(String, default="success")
    quality_score = Column(Float, default=0.5)

    reflection = Column(String, default="")

    # Declared as a nullable String; tests elsewhere in this file expect a
    # plain Python list to be readable back from this attribute.
    embedding = Column(String, nullable=True)

    # Column default is the current timezone-aware UTC time.
    created_at = Column(
        DateTime,
        default=lambda: datetime.now(timezone.utc),
    )
# ── Mock helpers ─────────────────────────────────────────
def make_mock_entry(
    id: uuid.UUID | None = None,
    agent_name: str = "test_agent",
    task_type: str = "analysis",
    input_summary: str = "test input",
    output_summary: str = "test output",
    outcome: str = "success",
    quality_score: float = 0.8,
    reflection: str = "",
    created_at: datetime | None = None,
):
    """Build an ORM entry for tests as a real ``MockEpisodicModel`` instance.

    A random UUID4 is substituted when ``id`` is falsy, and the current UTC
    time when ``created_at`` is falsy. The id is always stored in its string
    form; every other argument is passed to the model unchanged.
    """
    entry_id = id if id else uuid.uuid4()
    timestamp = created_at if created_at else datetime.now(timezone.utc)

    field_values = {
        "id": str(entry_id),
        "agent_name": agent_name,
        "task_type": task_type,
        "input_summary": input_summary,
        "output_summary": output_summary,
        "outcome": outcome,
        "quality_score": quality_score,
        "reflection": reflection,
        "created_at": timestamp,
    }
    return MockEpisodicModel(**field_values)
def make_mock_session_factory(entries: list | None = None):
"""创建一个 mock session_factory返回包含指定 entries 的 session
Args:
entries: search 方法返回的 ORM entry 列表
"""
entries = entries or []
mock_session = AsyncMock()
mock_session.add = MagicMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
# 模拟 execute 返回的 result 对象
mock_result = MagicMock()
mock_scalars = MagicMock()
mock_scalars.all.return_value = entries
mock_result.scalars.return_value = mock_scalars
mock_session.execute = AsyncMock(return_value=mock_result)
@asynccontextmanager
async def factory():
yield mock_session
return factory, mock_session
# ── EpisodicMemory tests ─────────────────────────────────
class TestEpisodicMemoryStore:
    """Tests for ``EpisodicMemory.store``."""

    @staticmethod
    def _memory_for(factory, **extra):
        """Build an EpisodicMemory bound to ``factory`` and the mock ORM model."""
        return EpisodicMemory(
            session_factory=factory,
            episodic_model=MockEpisodicModel,
            **extra,
        )

    async def test_store_writes_entry_with_correct_fields(self):
        """store() adds and commits one entry carrying the supplied fields."""
        factory, session = make_mock_session_factory()
        memory = self._memory_for(factory)
        metadata = {
            "agent_name": "analyst_agent",
            "task_type": "financial_analysis",
            "output_summary": "Report generated",
            "outcome": "success",
            "quality_score": 0.9,
            "reflection": "Good analysis",
        }

        await memory.store(
            key="task:001",
            value="Analyzed financial data",
            metadata=metadata,
        )

        session.add.assert_called_once()
        session.commit.assert_called_once()

        # Inspect the object that was handed to session.add().
        stored = session.add.call_args.args[0]
        assert isinstance(stored, MockEpisodicModel)
        expected_fields = (
            ("agent_name", "analyst_agent"),
            ("task_type", "financial_analysis"),
            ("input_summary", "Analyzed financial data"),
            ("output_summary", "Report generated"),
            ("outcome", "success"),
            ("quality_score", 0.9),
            ("reflection", "Good analysis"),
        )
        for field, wanted in expected_fields:
            assert getattr(stored, field) == wanted

    async def test_store_with_embedder_generates_embedding(self):
        """With an embedder configured, store() embeds the value once and keeps the vector."""
        factory, session = make_mock_session_factory()
        vector = [0.1, 0.2, 0.3]
        embedder = AsyncMock()
        embedder.embed = AsyncMock(return_value=vector)
        memory = self._memory_for(factory, embedder=embedder)

        await memory.store("key1", "some value", {"agent_name": "test"})

        embedder.embed.assert_called_once()
        embedded_text = embedder.embed.call_args.args[0]
        assert "some value" in embedded_text

        # The entry passed to session.add() must carry the embedding.
        stored = session.add.call_args.args[0]
        assert stored.embedding == [0.1, 0.2, 0.3]

    async def test_store_without_embedder_no_embedding(self):
        """Without an embedder, the stored entry's embedding is None."""
        factory, session = make_mock_session_factory()
        memory = self._memory_for(factory, embedder=None)

        await memory.store("key1", "some value")

        stored = session.add.call_args.args[0]
        assert stored.embedding is None

    async def test_store_rollback_on_error(self):
        """A failing commit propagates the error and triggers one rollback."""
        factory, session = make_mock_session_factory()
        # Make commit raise when awaited.
        session.commit.side_effect = Exception("DB error")
        memory = self._memory_for(factory)

        with pytest.raises(Exception, match="DB error"):
            await memory.store("key1", "value1")

        session.rollback.assert_called_once()

    async def test_store_default_metadata_values(self):
        """Fields absent from metadata fall back to their default values."""
        factory, session = make_mock_session_factory()
        memory = self._memory_for(factory)

        await memory.store("key1", "value1")

        stored = session.add.call_args.args[0]
        expected_defaults = (
            ("agent_name", ""),
            ("task_type", ""),
            ("outcome", "success"),
            ("quality_score", 0.5),
            ("reflection", ""),
        )
        for field, wanted in expected_defaults:
            assert getattr(stored, field) == wanted
class TestEpisodicMemorySearch:
    """Tests for ``EpisodicMemory.search``."""

    @staticmethod
    def _memory_for(factory, **extra):
        """Build an EpisodicMemory bound to ``factory`` and the mock ORM model."""
        return EpisodicMemory(
            session_factory=factory,
            episodic_model=MockEpisodicModel,
            **extra,
        )

    async def test_search_with_time_decay_recent_scores_higher(self):
        """Time decay: at equal quality, the more recent entry scores higher."""
        reference = datetime.now(timezone.utc)
        one_hour_old = make_mock_entry(
            quality_score=0.8,
            created_at=reference - timedelta(hours=1),
        )
        hundred_hours_old = make_mock_entry(
            quality_score=0.8,
            created_at=reference - timedelta(hours=100),
        )
        factory, _ = make_mock_session_factory([one_hour_old, hundred_hours_old])
        memory = self._memory_for(factory, decay_rate=0.01)

        hits = await memory.search("test query")

        assert len(hits) == 2
        # The first result must outscore the second.
        first, second = hits[0], hits[1]
        assert first.score > second.score

    async def test_search_with_quality_score_factor(self):
        """quality_score feeds into the final score at equal age."""
        reference = datetime.now(timezone.utc)
        same_age = reference - timedelta(hours=1)
        strong = make_mock_entry(quality_score=0.9, created_at=same_age)
        weak = make_mock_entry(quality_score=0.1, created_at=same_age)
        factory, _ = make_mock_session_factory([strong, weak])
        memory = self._memory_for(factory)

        hits = await memory.search("test query")

        assert len(hits) == 2
        # The first result must outscore the second.
        first, second = hits[0], hits[1]
        assert first.score > second.score

    async def test_search_empty_store_returns_empty(self):
        """Searching when the session yields no rows returns an empty list."""
        factory, _ = make_mock_session_factory([])
        memory = self._memory_for(factory)

        hits = await memory.search("anything")

        assert hits == []

    async def test_search_applies_agent_name_filter(self):
        """search() accepts an agent_name filter and runs exactly one query."""
        factory, session = make_mock_session_factory([])
        memory = self._memory_for(factory)

        await memory.search("test", filters={"agent_name": "specific_agent"})

        # NOTE(review): this only proves a query was executed; the statement's
        # WHERE clause is not inspected.
        session.execute.assert_called_once()

    async def test_search_applies_task_type_filter(self):
        """search() accepts a task_type filter and runs exactly one query."""
        factory, session = make_mock_session_factory([])
        memory = self._memory_for(factory)

        await memory.search("test", filters={"task_type": "analysis"})

        # NOTE(review): same limitation as above - the filter itself is not checked.
        session.execute.assert_called_once()

    async def test_search_applies_outcome_filter(self):
        """search() accepts an outcome filter and runs exactly one query."""
        factory, session = make_mock_session_factory([])
        memory = self._memory_for(factory)

        await memory.search("test", filters={"outcome": "success"})

        # NOTE(review): same limitation as above - the filter itself is not checked.
        session.execute.assert_called_once()

    async def test_search_top_k_limits_results(self):
        """top_k caps how many results search() returns."""
        reference = datetime.now(timezone.utc)
        candidates = [
            make_mock_entry(quality_score=0.5 + i * 0.05, created_at=reference)
            for i in range(10)
        ]
        factory, _ = make_mock_session_factory(candidates)
        memory = self._memory_for(factory)

        hits = await memory.search("test", top_k=3)

        # NOTE(review): an upper bound only - an empty result would also pass.
        assert len(hits) <= 3

    async def test_search_returns_memory_items(self):
        """search() returns MemoryItem objects exposing value and metadata fields."""
        row = make_mock_entry(
            agent_name="test_agent",
            task_type="analysis",
            input_summary="test input",
            output_summary="test output",
            outcome="success",
            quality_score=0.9,
            reflection="good",
            created_at=datetime.now(timezone.utc),
        )
        factory, _ = make_mock_session_factory([row])
        memory = self._memory_for(factory)

        hits = await memory.search("test")

        assert len(hits) == 1
        found = hits[0]
        assert isinstance(found, MemoryItem)

        expected_value = (
            ("input_summary", "test input"),
            ("output_summary", "test output"),
            ("outcome", "success"),
        )
        for key, wanted in expected_value:
            assert found.value[key] == wanted

        expected_metadata = (
            ("agent_name", "test_agent"),
            ("task_type", "analysis"),
        )
        for key, wanted in expected_metadata:
            assert found.metadata[key] == wanted
class TestEpisodicMemoryDelete:
    """Tests for ``EpisodicMemory.delete``."""

    async def test_delete_removes_entry_by_id(self):
        """delete() by id reports True after one execute and one commit."""
        factory, session = make_mock_session_factory()
        memory = EpisodicMemory(session_factory=factory, episodic_model=MockEpisodicModel)
        target_id = str(uuid.uuid4())

        deleted = await memory.delete(target_id)

        assert deleted is True
        session.execute.assert_called_once()
        session.commit.assert_called_once()

    async def test_delete_returns_false_on_error(self):
        """A failing execute makes delete() return False and roll back once."""
        factory, session = make_mock_session_factory()
        # Make execute raise when awaited.
        session.execute.side_effect = Exception("DB error")
        memory = EpisodicMemory(session_factory=factory, episodic_model=MockEpisodicModel)

        deleted = await memory.delete(str(uuid.uuid4()))

        assert deleted is False
        session.rollback.assert_called_once()
class TestEpisodicMemoryRetrieve:
    """Tests for ``EpisodicMemory.retrieve``."""

    async def test_retrieve_always_returns_none(self):
        """retrieve() always returns None (exact key lookup is unsupported by design)."""
        factory, _ = make_mock_session_factory()
        memory = EpisodicMemory(session_factory=factory, episodic_model=MockEpisodicModel)

        found = await memory.retrieve("any_key")

        assert found is None