"""EpisodicMemory 单元测试 - 基于 pgvector + PostgreSQL 的任务经验记忆 使用 mock session_factory 和真实 SQLAlchemy ORM 模型进行单元测试, 不需要真实的 PostgreSQL/pgvector 环境。 """ import uuid from contextlib import asynccontextmanager from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, MagicMock import pytest from sqlalchemy import Column, DateTime, Float, String, delete as sql_delete, select from sqlalchemy.orm import DeclarativeBase from agentkit.memory.episodic import EpisodicMemory from agentkit.memory.base import MemoryItem # ── 真实 SQLAlchemy 模型(用于测试) ───────────────────── class Base(DeclarativeBase): pass class MockEpisodicModel(Base): """模拟 EpisodicMemory ORM 模型,使用真实 SQLAlchemy 列定义""" __tablename__ = "test_episodic_memory" id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4())) agent_name = Column(String, default="") task_type = Column(String, default="") input_summary = Column(String, default="") output_summary = Column(String, default="") outcome = Column(String, default="success") quality_score = Column(Float, default=0.5) reflection = Column(String, default="") embedding = Column(String, nullable=True) created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc)) # ── Mock 辅助工具 ──────────────────────────────────────── def make_mock_entry( id: uuid.UUID | None = None, agent_name: str = "test_agent", task_type: str = "analysis", input_summary: str = "test input", output_summary: str = "test output", outcome: str = "success", quality_score: float = 0.8, reflection: str = "", created_at: datetime | None = None, ): """创建一个模拟的 ORM entry 对象(使用真实模型实例)""" entry = MockEpisodicModel( id=str(id or uuid.uuid4()), agent_name=agent_name, task_type=task_type, input_summary=input_summary, output_summary=output_summary, outcome=outcome, quality_score=quality_score, reflection=reflection, created_at=created_at or datetime.now(timezone.utc), ) return entry def make_mock_session_factory(entries: list | None = None): """创建一个 mock session_factory,返回包含指定 entries 的 session Args: entries: search 方法返回的 ORM entry 列表 """ entries = entries or [] mock_session = AsyncMock() mock_session.add = MagicMock() mock_session.commit = AsyncMock() mock_session.rollback = AsyncMock() # 模拟 execute 返回的 result 对象 mock_result = MagicMock() mock_scalars = MagicMock() mock_scalars.all.return_value = entries mock_result.scalars.return_value = mock_scalars mock_session.execute = AsyncMock(return_value=mock_result) @asynccontextmanager async def factory(): yield mock_session return factory, mock_session # ── EpisodicMemory 测试 ────────────────────────────────── class TestEpisodicMemoryStore: """EpisodicMemory.store 测试""" async def test_store_writes_entry_with_correct_fields(self): """store 写入包含正确字段的 entry""" factory, mock_session = make_mock_session_factory() mem = EpisodicMemory( session_factory=factory, episodic_model=MockEpisodicModel, ) await mem.store( key="task:001", value="Analyzed financial data", metadata={ "agent_name": "analyst_agent", "task_type": "financial_analysis", "output_summary": "Report generated", "outcome": "success", "quality_score": 0.9, "reflection": "Good analysis", }, ) mock_session.add.assert_called_once() mock_session.commit.assert_called_once() # 验证传入 add 的 entry 参数 entry_arg = mock_session.add.call_args[0][0] assert isinstance(entry_arg, MockEpisodicModel) assert entry_arg.agent_name == "analyst_agent" assert entry_arg.task_type == "financial_analysis" assert entry_arg.input_summary == "Analyzed financial data" assert entry_arg.output_summary == "Report generated" assert entry_arg.outcome == "success" assert entry_arg.quality_score == 0.9 assert entry_arg.reflection == "Good analysis" async def test_store_with_embedder_generates_embedding(self): """store 时有 embedder 则生成 embedding""" factory, mock_session = make_mock_session_factory() mock_embedder = AsyncMock() mock_embedder.embed = AsyncMock(return_value=[0.1, 0.2, 0.3]) mem = EpisodicMemory( session_factory=factory, episodic_model=MockEpisodicModel, embedder=mock_embedder, ) await mem.store("key1", "some value", {"agent_name": "test"}) mock_embedder.embed.assert_called_once() call_args = mock_embedder.embed.call_args[0][0] assert "some value" in call_args # 验证 entry 的 embedding 被设置 entry_arg = mock_session.add.call_args[0][0] assert entry_arg.embedding == [0.1, 0.2, 0.3] async def test_store_without_embedder_no_embedding(self): """store 时无 embedder 则 embedding 为 None""" factory, mock_session = make_mock_session_factory() mem = EpisodicMemory( session_factory=factory, episodic_model=MockEpisodicModel, embedder=None, ) await mem.store("key1", "some value") entry_arg = mock_session.add.call_args[0][0] assert entry_arg.embedding is None async def test_store_rollback_on_error(self): """store 失败时执行 rollback""" factory, mock_session = make_mock_session_factory() # 让 commit 抛出异常 mock_session.commit = AsyncMock(side_effect=Exception("DB error")) mem = EpisodicMemory( session_factory=factory, episodic_model=MockEpisodicModel, ) with pytest.raises(Exception, match="DB error"): await mem.store("key1", "value1") mock_session.rollback.assert_called_once() async def test_store_default_metadata_values(self): """store 时 metadata 缺失字段使用默认值""" factory, mock_session = make_mock_session_factory() mem = EpisodicMemory( session_factory=factory, episodic_model=MockEpisodicModel, ) await mem.store("key1", "value1") entry_arg = mock_session.add.call_args[0][0] assert entry_arg.agent_name == "" assert entry_arg.task_type == "" assert entry_arg.outcome == "success" assert entry_arg.quality_score == 0.5 assert entry_arg.reflection == "" class TestEpisodicMemorySearch: """EpisodicMemory.search 测试""" async def test_search_with_time_decay_recent_scores_higher(self): """时间衰减:近期条目得分更高""" now = datetime.now(timezone.utc) recent_entry = make_mock_entry( quality_score=0.8, created_at=now - timedelta(hours=1), ) old_entry = make_mock_entry( quality_score=0.8, created_at=now - timedelta(hours=100), ) factory, _ = make_mock_session_factory([recent_entry, old_entry]) mem = EpisodicMemory( session_factory=factory, episodic_model=MockEpisodicModel, decay_rate=0.01, ) results = await mem.search("test query") assert len(results) == 2 # 近期条目应排在前面 assert results[0].score > results[1].score async def test_search_with_quality_score_factor(self): """quality_score 影响最终得分""" now = datetime.now(timezone.utc) high_quality = make_mock_entry( quality_score=0.9, created_at=now - timedelta(hours=1), ) low_quality = make_mock_entry( quality_score=0.1, created_at=now - timedelta(hours=1), ) factory, _ = make_mock_session_factory([high_quality, low_quality]) mem = EpisodicMemory( session_factory=factory, episodic_model=MockEpisodicModel, ) results = await mem.search("test query") assert len(results) == 2 # 高质量条目应排在前面 assert results[0].score > results[1].score async def test_search_empty_store_returns_empty(self): """空存储 search 返回空列表""" factory, _ = make_mock_session_factory([]) mem = EpisodicMemory( session_factory=factory, episodic_model=MockEpisodicModel, ) results = await mem.search("anything") assert results == [] async def test_search_applies_agent_name_filter(self): """search 应用 agent_name 过滤""" factory, mock_session = make_mock_session_factory([]) mem = EpisodicMemory( session_factory=factory, episodic_model=MockEpisodicModel, ) await mem.search("test", filters={"agent_name": "specific_agent"}) # 验证 execute 被调用(即查询被执行) mock_session.execute.assert_called_once() async def test_search_applies_task_type_filter(self): """search 应用 task_type 过滤""" factory, mock_session = make_mock_session_factory([]) mem = EpisodicMemory( session_factory=factory, episodic_model=MockEpisodicModel, ) await mem.search("test", filters={"task_type": "analysis"}) mock_session.execute.assert_called_once() async def test_search_applies_outcome_filter(self): """search 应用 outcome 过滤""" factory, mock_session = make_mock_session_factory([]) mem = EpisodicMemory( session_factory=factory, episodic_model=MockEpisodicModel, ) await mem.search("test", filters={"outcome": "success"}) mock_session.execute.assert_called_once() async def test_search_top_k_limits_results(self): """search 的 top_k 限制返回数量""" now = datetime.now(timezone.utc) entries = [ make_mock_entry(quality_score=0.5 + i * 0.05, created_at=now) for i in range(10) ] factory, _ = make_mock_session_factory(entries) mem = EpisodicMemory( session_factory=factory, episodic_model=MockEpisodicModel, ) results = await mem.search("test", top_k=3) assert len(results) <= 3 async def test_search_returns_memory_items(self): """search 返回 MemoryItem 列表""" now = datetime.now(timezone.utc) entry = make_mock_entry( agent_name="test_agent", task_type="analysis", input_summary="test input", output_summary="test output", outcome="success", quality_score=0.9, reflection="good", created_at=now, ) factory, _ = make_mock_session_factory([entry]) mem = EpisodicMemory( session_factory=factory, episodic_model=MockEpisodicModel, ) results = await mem.search("test") assert len(results) == 1 item = results[0] assert isinstance(item, MemoryItem) assert item.value["input_summary"] == "test input" assert item.value["output_summary"] == "test output" assert item.value["outcome"] == "success" assert item.metadata["agent_name"] == "test_agent" assert item.metadata["task_type"] == "analysis" class TestEpisodicMemoryDelete: """EpisodicMemory.delete 测试""" async def test_delete_removes_entry_by_id(self): """delete 按 ID 删除条目""" factory, mock_session = make_mock_session_factory() mem = EpisodicMemory( session_factory=factory, episodic_model=MockEpisodicModel, ) test_id = str(uuid.uuid4()) result = await mem.delete(test_id) assert result is True mock_session.execute.assert_called_once() mock_session.commit.assert_called_once() async def test_delete_returns_false_on_error(self): """delete 失败时返回 False""" factory, mock_session = make_mock_session_factory() mock_session.execute = AsyncMock(side_effect=Exception("DB error")) mem = EpisodicMemory( session_factory=factory, episodic_model=MockEpisodicModel, ) result = await mem.delete(str(uuid.uuid4())) assert result is False mock_session.rollback.assert_called_once() class TestEpisodicMemoryRetrieve: """EpisodicMemory.retrieve 测试""" async def test_retrieve_always_returns_none(self): """EpisodicMemory.retrieve 始终返回 None(按设计不支持 key 精确检索)""" factory, _ = make_mock_session_factory() mem = EpisodicMemory( session_factory=factory, episodic_model=MockEpisodicModel, ) result = await mem.retrieve("any_key") assert result is None