1022 lines
33 KiB
Python
1022 lines
33 KiB
Python
"""EpisodicMemory 向量检索单元测试 - cosine similarity + hybrid scoring + pgvector"""
|
||
|
||
import uuid
|
||
from contextlib import asynccontextmanager
|
||
from datetime import datetime, timedelta, timezone
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import pytest
|
||
from sqlalchemy import Column, DateTime, Float, String
|
||
from sqlalchemy.orm import DeclarativeBase
|
||
|
||
from agentkit.memory.episodic import EpisodicMemory
|
||
from agentkit.memory.base import MemoryItem
|
||
from agentkit.memory.embedder import MockEmbedder
|
||
from agentkit.utils.vector_math import compute_cosine_similarity
|
||
|
||
|
||
# ── 真实 SQLAlchemy 模型(用于测试) ─────────────────────
|
||
|
||
|
||
class Base(DeclarativeBase):
    """Declarative base class for the ORM models defined in this test module."""
|
||
|
||
|
||
class MockEpisodicModel(Base):
    """Stand-in for the EpisodicMemory ORM model, used only by these tests."""

    __tablename__ = "test_episodic_vector_search"

    # Primary key; the column default produces a random UUID rendered as a string.
    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    agent_name = Column(String, default="")
    task_type = Column(String, default="")
    input_summary = Column(String, default="")
    output_summary = Column(String, default="")
    outcome = Column(String, default="success")
    quality_score = Column(Float, default=0.5)
    reflection = Column(String, default="")
    # Declared as a nullable String column; make_mock_entry in this module
    # assigns the attribute directly (a list of floats or None) after construction.
    embedding = Column(String, nullable=True)
    # Column default is the current time in UTC.
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
||
|
||
|
||
# ── Mock 辅助工具 ────────────────────────────────────────
|
||
|
||
|
||
def make_mock_entry(
    id: uuid.UUID | None = None,
    agent_name: str = "test_agent",
    task_type: str = "analysis",
    input_summary: str = "test input",
    output_summary: str = "test output",
    outcome: str = "success",
    quality_score: float = 0.8,
    reflection: str = "",
    embedding: list[float] | None = None,
    created_at: datetime | None = None,
):
    """Build an in-memory MockEpisodicModel instance for use as a fake query row.

    A falsy ``id`` is replaced by a fresh UUID (stored as a string) and a
    falsy ``created_at`` by the current UTC time. ``embedding`` is not passed
    to the constructor; it is set on the instance afterwards.
    """
    entry_id = str(id or uuid.uuid4())
    timestamp = created_at or datetime.now(timezone.utc)

    record = MockEpisodicModel(
        id=entry_id,
        agent_name=agent_name,
        task_type=task_type,
        input_summary=input_summary,
        output_summary=output_summary,
        outcome=outcome,
        quality_score=quality_score,
        reflection=reflection,
        created_at=timestamp,
    )

    # Assign the embedding attribute directly, bypassing the Column restriction.
    record.embedding = embedding
    return record
|
||
|
||
|
||
def make_mock_session_factory(entries: list | None = None):
|
||
"""创建一个 mock session_factory"""
|
||
entries = entries or []
|
||
|
||
mock_session = AsyncMock()
|
||
mock_session.add = MagicMock()
|
||
mock_session.commit = AsyncMock()
|
||
mock_session.rollback = AsyncMock()
|
||
|
||
mock_result = MagicMock()
|
||
mock_scalars = MagicMock()
|
||
mock_scalars.all.return_value = entries
|
||
mock_result.scalars.return_value = mock_scalars
|
||
mock_session.execute = AsyncMock(return_value=mock_result)
|
||
|
||
@asynccontextmanager
|
||
async def factory():
|
||
yield mock_session
|
||
|
||
return factory, mock_session
|
||
|
||
|
||
class _RowMapping(dict):
    """Dict subclass whose keys can also be read as attributes.

    Used as a stand-in for SQLAlchemy mapping rows in the pgvector mock
    tests. On top of the inherited ``row["key"]`` and ``row.get("key")``
    access, ``row.key`` returns the value stored under ``"key"``.
    """

    def __getattr__(self, name: str):
        """Return ``self[name]``; raise AttributeError when the key is absent.

        Python only calls ``__getattr__`` after normal attribute lookup has
        failed, so real dict attributes and methods are never shadowed.
        """
        try:
            return self[name]
        except KeyError:
            # Suppress the internal KeyError so callers (and getattr/hasattr)
            # see a plain AttributeError rather than a chained traceback.
            raise AttributeError(name) from None
|
||
|
||
|
||
def _make_row_mapping(data: dict) -> _RowMapping:
    """Wrap *data* in a _RowMapping so pgvector mock tests can use it as a result row."""
    row = _RowMapping(data)
    return row
|
||
|
||
|
||
# ── Cosine Similarity 测试 ──────────────────────────────
|
||
|
||
|
||
class TestCosineSimilarity:
    """Unit tests for compute_cosine_similarity."""

    def test_identical_vectors_return_one(self):
        """A vector compared with itself has similarity 1."""
        unit_x = [1.0, 0.0, 0.0]
        similarity = compute_cosine_similarity(unit_x, unit_x)
        assert similarity == pytest.approx(1.0)

    def test_orthogonal_vectors_return_zero(self):
        """Orthogonal vectors have similarity 0."""
        similarity = compute_cosine_similarity([1.0, 0.0], [0.0, 1.0])
        assert similarity == pytest.approx(0.0)

    def test_opposite_vectors_return_minus_one(self):
        """Vectors pointing in opposite directions have similarity -1."""
        similarity = compute_cosine_similarity([1.0, 0.0], [-1.0, 0.0])
        assert similarity == pytest.approx(-1.0)

    def test_dimension_mismatch_returns_zero(self):
        """Vectors of different lengths yield exactly 0."""
        two_dims, one_dim = [1.0, 2.0], [1.0]
        similarity = compute_cosine_similarity(two_dims, one_dim)
        assert similarity == 0.0

    def test_empty_vectors_return_zero(self):
        """Two empty vectors yield exactly 0."""
        similarity = compute_cosine_similarity([], [])
        assert similarity == 0.0

    def test_zero_vector_returns_zero(self):
        """An all-zero vector yields exactly 0 against any other vector."""
        zeros, other = [0.0, 0.0], [1.0, 2.0]
        similarity = compute_cosine_similarity(zeros, other)
        assert similarity == 0.0
|
||
|
||
|
||
# ── MockEmbedder 测试 ───────────────────────────────────
|
||
|
||
|
||
class TestMockEmbedder:
    """Unit tests for MockEmbedder."""

    async def test_embed_returns_correct_dimension(self):
        """embed() returns a vector of the configured dimension."""
        mock_embedder = MockEmbedder(dimension=64)
        vector = await mock_embedder.embed("test text")
        assert len(vector) == 64

    async def test_embed_is_deterministic(self):
        """Embedding the same text twice gives equal vectors."""
        mock_embedder = MockEmbedder(dimension=32)
        first = await mock_embedder.embed("hello world")
        second = await mock_embedder.embed("hello world")
        assert first == second

    async def test_embed_different_text_different_vector(self):
        """Different texts give different vectors."""
        mock_embedder = MockEmbedder(dimension=32)
        hello_vec = await mock_embedder.embed("hello")
        world_vec = await mock_embedder.embed("world")
        assert hello_vec != world_vec

    async def test_embed_produces_unit_vector(self):
        """embed() returns a vector of Euclidean length 1 (within 1e-6)."""
        mock_embedder = MockEmbedder(dimension=32)
        vector = await mock_embedder.embed("test")
        squared_norm = sum(component**2 for component in vector)
        norm = squared_norm ** 0.5
        assert norm == pytest.approx(1.0, abs=1e-6)

    def test_get_dimension(self):
        """get_dimension() reports the dimension passed to the constructor."""
        assert MockEmbedder(dimension=256).get_dimension() == 256
|
||
|
||
|
||
# ── Store 测试 ──────────────────────────────────────────
|
||
|
||
|
||
class TestStoreWithEmbedder:
    """Tests for store() with and without an embedder configured."""

    async def test_store_generates_embedding_when_embedder_provided(self):
        """With an embedder, the object handed to session.add carries an embedding."""
        session_factory, session = make_mock_session_factory()
        mock_embedder = MockEmbedder(dimension=32)
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
            embedder=mock_embedder,
        )

        await memory.store("key1", "some value", {"agent_name": "test"})

        # First positional argument of the recorded session.add(...) call.
        stored_entry = session.add.call_args.args[0]
        assert stored_entry.embedding is not None
        assert len(stored_entry.embedding) == 32

    async def test_store_no_embedding_without_embedder(self):
        """Without an embedder, the stored object's embedding stays None."""
        session_factory, session = make_mock_session_factory()
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
        )

        await memory.store("key1", "some value")

        stored_entry = session.add.call_args.args[0]
        assert stored_entry.embedding is None
|
||
|
||
|
||
# ── Search 向量检索测试 ─────────────────────────────────
|
||
|
||
|
||
class TestSearchVectorSearch:
    """Tests for search() on the client-side vector-search path."""

    async def test_search_with_embedder_uses_cosine_similarity(self):
        """With an embedder and alpha=1.0, the entry matching the query ranks first."""
        embedder = MockEmbedder(dimension=32)

        # The "similar" embedding is generated from the exact query text.
        vec_similar = await embedder.embed("financial analysis")
        vec_different = await embedder.embed("completely unrelated topic xyz")

        now = datetime.now(timezone.utc)
        similar_entry = make_mock_entry(
            input_summary="financial analysis report",
            quality_score=0.5,
            embedding=vec_similar,
            created_at=now,
        )
        different_entry = make_mock_entry(
            input_summary="unrelated task",
            quality_score=0.5,
            embedding=vec_different,
            created_at=now,
        )

        factory, _ = make_mock_session_factory([similar_entry, different_entry])

        mem = EpisodicMemory(
            session_factory=factory,
            episodic_model=MockEpisodicModel,
            embedder=embedder,
            alpha=1.0,  # rank by cosine similarity only
            pgvector_enabled=False,  # use client-side cosine similarity
        )

        results = await mem.search("financial analysis")
        assert len(results) == 2
        # The similar entry must come first.
        assert results[0].value["input_summary"] == "financial analysis report"

    async def test_search_fallback_to_time_decay_without_embedder(self):
        """Without an embedder, the more recent of two equal-quality entries scores higher."""
        now = datetime.now(timezone.utc)
        recent_entry = make_mock_entry(
            quality_score=0.8,
            created_at=now - timedelta(hours=1),
        )
        old_entry = make_mock_entry(
            quality_score=0.8,
            created_at=now - timedelta(hours=100),
        )

        factory, _ = make_mock_session_factory([recent_entry, old_entry])

        mem = EpisodicMemory(
            session_factory=factory,
            episodic_model=MockEpisodicModel,
        )

        results = await mem.search("test query")
        assert len(results) == 2
        # Pure time decay: the first result must outscore the second.
        assert results[0].score > results[1].score

    async def test_search_hybrid_scoring_formula(self):
        """Hybrid score is alpha * cosine + (1 - alpha) * time_decay.

        Only the alpha=1.0 end of the formula is exercised here: the entry
        whose embedding matches the query must rank first even though its
        quality score is lower.
        """
        embedder = MockEmbedder(dimension=32)

        vec_similar = await embedder.embed("query text")
        vec_different = await embedder.embed("something else entirely")

        now = datetime.now(timezone.utc)
        # The two entries need distinct input_summary values: with the shared
        # make_mock_entry default the ranking assertion below would hold no
        # matter which entry came first.
        # Similar to the query, but low quality.
        similar_entry = make_mock_entry(
            input_summary="semantically similar entry",
            quality_score=0.5,
            embedding=vec_similar,
            created_at=now,
        )
        # Dissimilar to the query, but high quality.
        different_entry = make_mock_entry(
            input_summary="semantically different entry",
            quality_score=0.9,
            embedding=vec_different,
            created_at=now,
        )

        factory, _ = make_mock_session_factory([similar_entry, different_entry])

        # alpha=1.0 -> rank by cosine similarity only
        mem = EpisodicMemory(
            session_factory=factory,
            episodic_model=MockEpisodicModel,
            embedder=embedder,
            alpha=1.0,
            pgvector_enabled=False,
        )

        results = await mem.search("query text")
        # With alpha=1.0 cosine similarity dominates, so the similar entry is first.
        assert results[0].value["input_summary"] == "semantically similar entry"

    async def test_search_alpha_zero_pure_time_decay(self):
        """With alpha=0 the higher-quality entry ranks first despite lower similarity."""
        embedder = MockEmbedder(dimension=32)

        vec_similar = await embedder.embed("query text")
        vec_different = await embedder.embed("something else")

        now = datetime.now(timezone.utc)
        # Similar to the query, but low quality.
        similar_entry = make_mock_entry(
            quality_score=0.3,
            embedding=vec_similar,
            created_at=now,
        )
        # Dissimilar to the query, but high quality.
        different_entry = make_mock_entry(
            quality_score=0.9,
            embedding=vec_different,
            created_at=now,
        )

        factory, _ = make_mock_session_factory([similar_entry, different_entry])

        mem = EpisodicMemory(
            session_factory=factory,
            episodic_model=MockEpisodicModel,
            embedder=embedder,
            alpha=0.0,  # time decay only
            pgvector_enabled=False,
        )

        results = await mem.search("query text")
        # With alpha=0 time decay dominates, so the high-quality entry is first.
        assert results[0].value["quality_score"] == 0.9

    async def test_search_entry_without_embedding_uses_time_decay(self):
        """An entry lacking an embedding is still returned when an embedder is set.

        NOTE(review): only the result count is asserted; the scores of the
        two entries are not checked.
        """
        embedder = MockEmbedder(dimension=32)

        now = datetime.now(timezone.utc)
        entry_with_embedding = make_mock_entry(
            quality_score=0.5,
            embedding=await embedder.embed("test"),
            created_at=now - timedelta(hours=10),
        )
        entry_without_embedding = make_mock_entry(
            quality_score=0.9,
            embedding=None,
            created_at=now,
        )

        factory, _ = make_mock_session_factory([entry_with_embedding, entry_without_embedding])

        mem = EpisodicMemory(
            session_factory=factory,
            episodic_model=MockEpisodicModel,
            embedder=embedder,
            alpha=0.7,
            pgvector_enabled=False,
        )

        results = await mem.search("test query")
        assert len(results) == 2

    async def test_search_empty_store_returns_empty(self):
        """search() on a store with no entries returns an empty list."""
        factory, _ = make_mock_session_factory([])
        embedder = MockEmbedder(dimension=32)

        # NOTE(review): pgvector_enabled is left at its default here, unlike
        # the other embedder-based tests in this class - confirm this is intended.
        mem = EpisodicMemory(
            session_factory=factory,
            episodic_model=MockEpisodicModel,
            embedder=embedder,
        )

        results = await mem.search("anything")
        assert results == []
|
||
|
||
|
||
# ── Retrieve 向量检索测试 ───────────────────────────────
|
||
|
||
|
||
class TestRetrieveVectorSearch:
    """Tests for retrieve() on the client-side vector-search path."""

    async def test_retrieve_with_embedder_returns_best_match(self):
        """With an embedder, retrieve() returns the entry matching the query."""
        mock_embedder = MockEmbedder(dimension=32)

        matching_vec = await mock_embedder.embed("financial report")
        other_vec = await mock_embedder.embed("weather forecast")

        timestamp = datetime.now(timezone.utc)
        matching = make_mock_entry(
            input_summary="financial report Q4",
            embedding=matching_vec,
            created_at=timestamp,
        )
        other = make_mock_entry(
            input_summary="weather forecast today",
            embedding=other_vec,
            created_at=timestamp,
        )

        session_factory, _ = make_mock_session_factory([matching, other])
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
            embedder=mock_embedder,
            pgvector_enabled=False,
        )

        item = await memory.retrieve("financial report")

        assert item is not None
        assert item.value["input_summary"] == "financial report Q4"
        assert item.metadata["cosine_similarity"] > 0.0

    async def test_retrieve_without_embedder_returns_none(self):
        """Without an embedder, retrieve() returns None."""
        session_factory, _ = make_mock_session_factory([])
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
        )

        assert await memory.retrieve("any key") is None

    async def test_retrieve_empty_store_returns_none(self):
        """retrieve() on a store with no entries returns None."""
        session_factory, _ = make_mock_session_factory([])
        mock_embedder = MockEmbedder(dimension=32)
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
            embedder=mock_embedder,
        )

        assert await memory.retrieve("any key") is None

    async def test_retrieve_no_entries_with_embedding_returns_none(self):
        """retrieve() returns None when the only entry has no embedding."""
        mock_embedder = MockEmbedder(dimension=32)

        timestamp = datetime.now(timezone.utc)
        bare_entry = make_mock_entry(embedding=None, created_at=timestamp)

        session_factory, _ = make_mock_session_factory([bare_entry])
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
            embedder=mock_embedder,
            pgvector_enabled=False,
        )

        assert await memory.retrieve("any key") is None

    async def test_retrieve_returns_memory_item(self):
        """retrieve() returns a MemoryItem carrying the entry's fields and a positive score."""
        mock_embedder = MockEmbedder(dimension=32)

        query_vec = await mock_embedder.embed("test query")
        timestamp = datetime.now(timezone.utc)
        stored = make_mock_entry(
            input_summary="test input",
            output_summary="test output",
            outcome="success",
            quality_score=0.9,
            embedding=query_vec,
            created_at=timestamp,
        )

        session_factory, _ = make_mock_session_factory([stored])
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
            embedder=mock_embedder,
            pgvector_enabled=False,
        )

        item = await memory.retrieve("test query")

        assert isinstance(item, MemoryItem)
        assert item.value["input_summary"] == "test input"
        assert item.value["output_summary"] == "test output"
        assert item.value["outcome"] == "success"
        assert item.score > 0.0
|
||
|
||
|
||
# ── Alpha 参数测试 ──────────────────────────────────────
|
||
|
||
|
||
class TestAlphaParameter:
    """Tests that alpha controls the balance of the hybrid score."""

    async def test_alpha_controls_hybrid_balance(self):
        """alpha shifts ranking between semantic similarity and time decay."""
        mock_embedder = MockEmbedder(dimension=32)

        matching_vec = await mock_embedder.embed("machine learning")
        other_vec = await mock_embedder.embed("cooking recipes")

        timestamp = datetime.now(timezone.utc)
        low_quality_match = make_mock_entry(
            quality_score=0.3,
            embedding=matching_vec,
            created_at=timestamp,
        )
        high_quality_other = make_mock_entry(
            quality_score=0.9,
            embedding=other_vec,
            created_at=timestamp,
        )

        # (alpha, quality_score expected on the top-ranked result)
        scenarios = (
            (1.0, 0.3),  # cosine only -> the similar entry is first
            (0.0, 0.9),  # time decay only -> the high-quality entry is first
        )
        for alpha_value, expected_top_quality in scenarios:
            # Each scenario gets its own mock session factory.
            session_factory, _ = make_mock_session_factory(
                [low_quality_match, high_quality_other]
            )
            memory = EpisodicMemory(
                session_factory=session_factory,
                episodic_model=MockEpisodicModel,
                embedder=mock_embedder,
                alpha=alpha_value,
                pgvector_enabled=False,
            )
            ranked = await memory.search("machine learning")
            assert ranked[0].value["quality_score"] == expected_top_quality

    async def test_default_alpha_is_0_7(self):
        """alpha defaults to 0.7 when not passed to the constructor."""
        session_factory, _ = make_mock_session_factory([])
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
        )

        assert memory._alpha == 0.7
|
||
|
||
|
||
# ── pgvector 参数测试 ───────────────────────────────────
|
||
|
||
|
||
class TestPgvectorParameters:
    """Tests for the pgvector_enabled and table_name constructor parameters."""

    @staticmethod
    def _build_memory(session_factory, **options):
        """Construct an EpisodicMemory on MockEpisodicModel with extra keyword options."""
        return EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
            **options,
        )

    def test_default_pgvector_enabled_is_true(self):
        """pgvector_enabled defaults to True."""
        session_factory, _ = make_mock_session_factory()
        memory = self._build_memory(session_factory)
        assert memory._pgvector_enabled is True

    def test_pgvector_enabled_can_be_disabled(self):
        """pgvector can be switched off via the constructor."""
        session_factory, _ = make_mock_session_factory()
        memory = self._build_memory(session_factory, pgvector_enabled=False)
        assert memory._pgvector_enabled is False

    def test_default_table_name(self):
        """table_name defaults to "episodic_memories"."""
        session_factory, _ = make_mock_session_factory()
        memory = self._build_memory(session_factory)
        assert memory._table_name == "episodic_memories"

    def test_custom_table_name(self):
        """A custom table_name is stored on the instance."""
        session_factory, _ = make_mock_session_factory()
        memory = self._build_memory(session_factory, table_name="custom_memories")
        assert memory._table_name == "custom_memories"

    async def test_search_uses_client_side_when_pgvector_disabled(self):
        """With pgvector_enabled=False, search() still ranks the similar entry first."""
        mock_embedder = MockEmbedder(dimension=32)

        matching_vec = await mock_embedder.embed("test query")
        other_vec = await mock_embedder.embed("unrelated")

        timestamp = datetime.now(timezone.utc)
        matching = make_mock_entry(
            input_summary="similar task",
            quality_score=0.5,
            embedding=matching_vec,
            created_at=timestamp,
        )
        other = make_mock_entry(
            input_summary="different task",
            quality_score=0.5,
            embedding=other_vec,
            created_at=timestamp,
        )

        session_factory, _ = make_mock_session_factory([matching, other])
        memory = self._build_memory(
            session_factory,
            embedder=mock_embedder,
            alpha=1.0,
            pgvector_enabled=False,
        )

        ranked = await memory.search("test query")

        assert len(ranked) == 2
        # Client-side should still rank similar entry first
        assert ranked[0].value["input_summary"] == "similar task"

    async def test_search_uses_client_side_when_no_embedder(self):
        """Without an embedder, search() falls back even when pgvector_enabled=True."""
        timestamp = datetime.now(timezone.utc)
        recent = make_mock_entry(
            quality_score=0.8,
            created_at=timestamp - timedelta(hours=1),
        )
        stale = make_mock_entry(
            quality_score=0.8,
            created_at=timestamp - timedelta(hours=100),
        )

        session_factory, _ = make_mock_session_factory([recent, stale])
        # Enabled but no embedder -> falls back
        memory = self._build_memory(session_factory, pgvector_enabled=True)

        ranked = await memory.search("test query")

        assert len(ranked) == 2
        assert ranked[0].score > ranked[1].score

    async def test_retrieve_uses_client_side_when_pgvector_disabled(self):
        """With pgvector_enabled=False, retrieve() returns the matching entry."""
        mock_embedder = MockEmbedder(dimension=32)

        query_vec = await mock_embedder.embed("test query")
        timestamp = datetime.now(timezone.utc)
        stored = make_mock_entry(
            input_summary="test input",
            embedding=query_vec,
            created_at=timestamp,
        )

        session_factory, _ = make_mock_session_factory([stored])
        memory = self._build_memory(
            session_factory,
            embedder=mock_embedder,
            pgvector_enabled=False,
        )

        item = await memory.retrieve("test query")

        assert item is not None
        assert item.value["input_summary"] == "test input"
|
||
|
||
|
||
# ── pgvector 原生查询 Mock 测试 ─────────────────────────
|
||
|
||
|
||
class TestPgvectorNativeSearch:
    """Tests for the native pgvector ``<=>`` operator path, using a mock session."""

    @staticmethod
    def _session_factory_returning(query_result):
        """Return ``(factory, session)`` whose ``session.execute`` resolves to *query_result*."""
        session = AsyncMock()
        session.execute = AsyncMock(return_value=query_result)

        @asynccontextmanager
        async def factory():
            yield session

        return factory, session

    async def test_search_pgvector_uses_text_query(self):
        """pgvector search() executes a single statement containing ``<=>``."""
        mock_embedder = MockEmbedder(dimension=32)
        query_vec = await mock_embedder.embed("test query")

        timestamp = datetime.now(timezone.utc)

        # Mock the pgvector raw query result as a dict-like MappingRow
        row = _make_row_mapping({
            "id": str(uuid.uuid4()),
            "agent_name": "test_agent",
            "task_type": "analysis",
            "input_summary": "test input",
            "output_summary": "test output",
            "outcome": "success",
            "quality_score": 0.8,
            "reflection": "",
            "embedding": query_vec,
            "created_at": timestamp,
            "distance": 0.1,
        })

        query_result = MagicMock()
        query_result.mappings.return_value.all.return_value = [row]
        session_factory, session = self._session_factory_returning(query_result)

        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
            embedder=mock_embedder,
            pgvector_enabled=True,
            table_name="episodic_memories",
        )

        found = await memory.search("test query")
        assert len(found) == 1
        assert found[0].value["input_summary"] == "test input"

        # Verify that execute was called with a text() query
        session.execute.assert_called_once()
        statement = session.execute.call_args[0][0]
        # The SQL should contain the <=> operator
        assert "<=>" in str(statement)

    async def test_retrieve_pgvector_uses_text_query(self):
        """pgvector retrieve() executes a single statement containing ``<=>``."""
        mock_embedder = MockEmbedder(dimension=32)
        query_vec = await mock_embedder.embed("test query")

        timestamp = datetime.now(timezone.utc)

        # This row carries no "distance" key.
        row = _make_row_mapping({
            "id": str(uuid.uuid4()),
            "agent_name": "test_agent",
            "task_type": "analysis",
            "input_summary": "test input",
            "output_summary": "test output",
            "outcome": "success",
            "quality_score": 0.8,
            "reflection": "",
            "embedding": query_vec,
            "created_at": timestamp,
        })

        query_result = MagicMock()
        query_result.mappings.return_value.first.return_value = row
        session_factory, session = self._session_factory_returning(query_result)

        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
            embedder=mock_embedder,
            pgvector_enabled=True,
        )

        item = await memory.retrieve("test query")
        assert item is not None
        assert item.value["input_summary"] == "test input"

        # Verify that execute was called with a text() query
        session.execute.assert_called_once()
        statement = session.execute.call_args[0][0]
        assert "<=>" in str(statement)

    async def test_search_pgvector_with_filters(self):
        """pgvector search() puts the filter column into a WHERE clause."""
        mock_embedder = MockEmbedder(dimension=32)
        query_vec = await mock_embedder.embed("test query")

        timestamp = datetime.now(timezone.utc)

        row = _make_row_mapping({
            "id": str(uuid.uuid4()),
            "agent_name": "specific_agent",
            "task_type": "analysis",
            "input_summary": "filtered result",
            "output_summary": "output",
            "outcome": "success",
            "quality_score": 0.8,
            "reflection": "",
            "embedding": query_vec,
            "created_at": timestamp,
            "distance": 0.1,
        })

        query_result = MagicMock()
        query_result.mappings.return_value.all.return_value = [row]
        session_factory, session = self._session_factory_returning(query_result)

        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
            embedder=mock_embedder,
            pgvector_enabled=True,
        )

        found = await memory.search("test query", filters={"agent_name": "specific_agent"})
        assert len(found) == 1

        # Verify the SQL query contains WHERE clause
        rendered_sql = str(session.execute.call_args[0][0])
        assert "WHERE" in rendered_sql
        assert "agent_name" in rendered_sql

    async def test_search_pgvector_empty_result(self):
        """pgvector search() returns an empty list when the query yields no rows."""
        mock_embedder = MockEmbedder(dimension=32)

        query_result = MagicMock()
        query_result.mappings.return_value.all.return_value = []
        session_factory, _ = self._session_factory_returning(query_result)

        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
            embedder=mock_embedder,
            pgvector_enabled=True,
        )

        found = await memory.search("nonexistent")
        assert found == []

    async def test_retrieve_pgvector_no_embedding_in_row(self):
        """pgvector retrieve() returns None when the returned row has no embedding."""
        mock_embedder = MockEmbedder(dimension=32)

        row = _make_row_mapping({
            "id": str(uuid.uuid4()),
            "embedding": None,
        })

        query_result = MagicMock()
        query_result.mappings.return_value.first.return_value = row
        session_factory, _ = self._session_factory_returning(query_result)

        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
            embedder=mock_embedder,
            pgvector_enabled=True,
        )

        item = await memory.retrieve("test query")
        assert item is None

    async def test_retrieve_pgvector_no_rows(self):
        """pgvector retrieve() returns None when no row matches."""
        mock_embedder = MockEmbedder(dimension=32)

        query_result = MagicMock()
        query_result.mappings.return_value.first.return_value = None
        session_factory, _ = self._session_factory_returning(query_result)

        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
            embedder=mock_embedder,
            pgvector_enabled=True,
        )

        item = await memory.retrieve("nonexistent")
        assert item is None

    async def test_search_pgvector_time_decay_reranking(self):
        """pgvector search() re-ranks the returned rows; checked here with alpha=1.0."""
        mock_embedder = MockEmbedder(dimension=32)
        matching_vec = await mock_embedder.embed("test query")
        other_vec = await mock_embedder.embed("unrelated")

        timestamp = datetime.now(timezone.utc)

        # Row with high cosine but low quality
        close_row = _make_row_mapping({
            "id": str(uuid.uuid4()),
            "agent_name": "",
            "task_type": "",
            "input_summary": "similar but low quality",
            "output_summary": "",
            "outcome": "success",
            "quality_score": 0.3,
            "reflection": "",
            "embedding": matching_vec,
            "created_at": timestamp,
            "distance": 0.1,
        })

        # Row with lower cosine but high quality
        far_row = _make_row_mapping({
            "id": str(uuid.uuid4()),
            "agent_name": "",
            "task_type": "",
            "input_summary": "different but high quality",
            "output_summary": "",
            "outcome": "success",
            "quality_score": 0.9,
            "reflection": "",
            "embedding": other_vec,
            "created_at": timestamp,
            "distance": 0.5,
        })

        query_result = MagicMock()
        query_result.mappings.return_value.all.return_value = [close_row, far_row]
        session_factory, _ = self._session_factory_returning(query_result)

        # alpha=1.0: pure cosine -> similar entry first
        memory = EpisodicMemory(
            session_factory=session_factory,
            episodic_model=MockEpisodicModel,
            embedder=mock_embedder,
            alpha=1.0,
            pgvector_enabled=True,
        )

        found = await memory.search("test query")
        assert len(found) == 2
        # With alpha=1.0, cosine dominates, so similar entry should be first
        assert found[0].value["input_summary"] == "similar but low quality"
|