189 lines
7.1 KiB
Python
189 lines
7.1 KiB
Python
"""WorkingMemory 单元测试 - 基于 Redis 的短期任务记忆"""
|
||
|
||
import asyncio
|
||
import json
|
||
|
||
import pytest
|
||
|
||
from agentkit.memory.working import WorkingMemory
|
||
|
||
|
||
# ── Redis availability check ─────────────────────────────
|
||
|
||
|
||
def _redis_available():
|
||
"""检测 Redis 是否可用,不可用则跳过测试"""
|
||
import redis as sync_redis
|
||
|
||
try:
|
||
r = sync_redis.Redis(host="localhost", port=6381, db=0)
|
||
r.ping()
|
||
r.close()
|
||
return True
|
||
except Exception:
|
||
return False
|
||
|
||
|
||
# Reusable skip marker. The Redis probe is evaluated once, when this module
# is imported, and its result gates every test decorated with this marker.
skip_if_no_redis = pytest.mark.skipif(
    not _redis_available(), reason="Redis not available at localhost:6381"
)
|
||
|
||
|
||
# ── WorkingMemory tests ──────────────────────────────────
|
||
|
||
|
||
@skip_if_no_redis
@pytest.mark.redis
class TestWorkingMemory:
    """Tests that exercise WorkingMemory through a real Redis connection."""

    @staticmethod
    def _memory(redis_client, key_prefix="test:working", **options):
        """Build a WorkingMemory on ``redis_client`` under ``key_prefix``."""
        return WorkingMemory(redis=redis_client, key_prefix=key_prefix, **options)

    async def test_store_and_retrieve(self, redis_client, clean_redis):
        """A value written with store comes back unchanged from retrieve."""
        memory = self._memory(redis_client)
        await memory.store("key1", {"name": "alice", "age": 30})

        fetched = await memory.retrieve("key1")
        assert fetched is not None
        assert fetched.key == "key1"
        assert fetched.value["name"] == "alice"
        assert fetched.value["age"] == 30

    async def test_ttl_expiration(self, redis_client, clean_redis):
        """retrieve returns None once the entry's TTL has elapsed."""
        memory = self._memory(redis_client, default_ttl=1)
        await memory.store("short_lived", "will expire soon")

        # Read right away: the entry should still be present.
        before_expiry = await memory.retrieve("short_lived")
        assert before_expiry is not None

        # Wait for the TTL to run out, then read again.
        await asyncio.sleep(1.5)
        after_expiry = await memory.retrieve("short_lived")
        assert after_expiry is None

    async def test_get_context(self, redis_client, clean_redis):
        """get_context() returns a formatted context string."""
        memory = self._memory(redis_client)
        await memory.store("task:1", "Generate AI report")
        await memory.store("task:2", "Analyze data trends")

        context = await memory.get_context("task")

        # Original author's note: get_context calls search, and search
        # matches on key prefix.
        assert isinstance(context, str)
        # At least one of the stored values must show up in the text.
        assert any(
            fragment in context for fragment in ("AI report", "data trends")
        )

    async def test_key_prefix_isolation(self, redis_client, clean_redis):
        """Instances with different key_prefix values do not see each other."""
        memory_a = self._memory(redis_client, key_prefix="test:agent_a")
        memory_b = self._memory(redis_client, key_prefix="test:agent_b")

        await memory_a.store("shared_key", "value_from_a")
        await memory_b.store("shared_key", "value_from_b")

        from_a = await memory_a.retrieve("shared_key")
        from_b = await memory_b.retrieve("shared_key")

        assert from_a is not None
        assert from_b is not None
        assert from_a.value == "value_from_a"
        assert from_b.value == "value_from_b"

    async def test_delete_then_retrieve(self, redis_client, clean_redis):
        """After delete, retrieve returns None for the same key."""
        memory = self._memory(redis_client)
        await memory.store("to_delete", "temporary data")

        deleted = await memory.delete("to_delete")
        assert deleted is True

        leftover = await memory.retrieve("to_delete")
        assert leftover is None

    async def test_delete_nonexistent_key(self, redis_client, clean_redis):
        """Deleting a key that was never stored returns False."""
        memory = self._memory(redis_client)
        deleted = await memory.delete("nonexistent_key")
        assert deleted is False

    async def test_store_complex_nested_dict(self, redis_client, clean_redis):
        """A deeply nested dict is restored intact by retrieve."""
        memory = self._memory(redis_client)
        entries = [{"id": i, "name": f"item_{i}"} for i in range(5)]
        innermost = {"level3": [1, 2, 3], "nested_str": "deep value"}
        payload = {
            "level1": {"level2": innermost, "items": entries},
            "count": 42,
        }
        await memory.store("complex", payload)

        fetched = await memory.retrieve("complex")
        assert fetched is not None
        restored_level2 = fetched.value["level1"]["level2"]
        assert restored_level2["level3"] == [1, 2, 3]
        assert restored_level2["nested_str"] == "deep value"
        assert len(fetched.value["level1"]["items"]) == 5
        assert fetched.value["count"] == 42

    async def test_search_by_key_prefix(self, redis_client, clean_redis):
        """search matches entries by key prefix pattern."""
        memory = self._memory(redis_client)
        seed = (
            ("user:profile", {"name": "alice"}),
            ("user:settings", {"theme": "dark"}),
            ("task:report", {"type": "monthly"}),
        )
        for key, value in seed:
            await memory.store(key, value)

        # Look up the keys that start with "user:".
        matches = await memory.search("user:")
        assert len(matches) >= 2
        found_keys = [match.key for match in matches]
        assert "user:profile" in found_keys
        assert "user:settings" in found_keys
        assert "task:report" not in found_keys

    async def test_search_top_k_limit(self, redis_client, clean_redis):
        """top_k caps the number of results returned by search."""
        memory = self._memory(redis_client)
        for index in range(10):
            await memory.store(f"item:{index:02d}", f"value_{index}")

        matches = await memory.search("item:", top_k=3)
        assert len(matches) <= 3

    async def test_retrieve_nonexistent(self, redis_client, clean_redis):
        """retrieve returns None for a key that was never stored."""
        memory = self._memory(redis_client)
        missing = await memory.retrieve("does_not_exist")
        assert missing is None

    async def test_store_with_metadata(self, redis_client, clean_redis):
        """Metadata passed to store is restored by retrieve."""
        memory = self._memory(redis_client)
        await memory.store(
            "meta_key", "some value", {"tag": "important", "priority": 1}
        )

        fetched = await memory.retrieve("meta_key")
        assert fetched is not None
        assert fetched.metadata["tag"] == "important"
        assert fetched.metadata["priority"] == 1

    async def test_clear(self, redis_client, clean_redis):
        """clear removes every Working Memory entry under the given prefix."""
        memory = self._memory(redis_client)
        await memory.store("a:1", "value_a1")
        await memory.store("a:2", "value_a2")
        await memory.store("b:1", "value_b1")

        removed = await memory.clear(prefix="a:")
        assert removed >= 2

        # Entries under the "a:" prefix should be gone.
        for cleared_key in ("a:1", "a:2"):
            assert await memory.retrieve(cleared_key) is None
        # The entry under the "b:" prefix should survive.
        survivor = await memory.retrieve("b:1")
        assert survivor is not None
|