fischer-agentkit/tests/unit/test_working_memory.py

189 lines
7.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""WorkingMemory 单元测试 - 基于 Redis 的短期任务记忆"""
import asyncio
import json
import pytest
from agentkit.memory.working import WorkingMemory
# ── Redis 可用性检测 ──────────────────────────────────────
def _redis_available():
"""检测 Redis 是否可用,不可用则跳过测试"""
import redis as sync_redis
try:
r = sync_redis.Redis(host="localhost", port=6381, db=0)
r.ping()
r.close()
return True
except Exception:
return False
# Probe Redis once, at module import (i.e. test collection); the resulting
# marker skips whatever it decorates when the probe reports Redis unusable.
_REDIS_UNAVAILABLE = not _redis_available()
skip_if_no_redis = pytest.mark.skipif(
    _REDIS_UNAVAILABLE, reason="Redis not available at localhost:6381"
)
# ── WorkingMemory 测试 ───────────────────────────────────
@skip_if_no_redis
@pytest.mark.redis
class TestWorkingMemory:
    """WorkingMemory tests against a real Redis connection.

    Every test takes the ``redis_client`` and ``clean_redis`` fixtures, which
    are defined outside this file. NOTE(review): presumably they yield a
    client for the test Redis instance and wipe test keys around each test;
    confirm against the fixture definitions. The methods are ``async def``
    with no per-test asyncio marker, so an async pytest plugin must be
    configured to run them (e.g. asyncio auto mode) — confirm in the config.
    """

    async def test_store_and_retrieve(self, redis_client, clean_redis):
        """store() followed by retrieve() returns the same key and value."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        await mem.store("key1", {"name": "alice", "age": 30})

        item = await mem.retrieve("key1")
        assert item is not None
        assert item.key == "key1"
        assert item.value["name"] == "alice"
        assert item.value["age"] == 30

    async def test_ttl_expiration(self, redis_client, clean_redis):
        """retrieve() returns None once the entry's TTL has elapsed."""
        mem = WorkingMemory(
            redis=redis_client, key_prefix="test:working", default_ttl=1
        )
        await mem.store("short_lived", "will expire soon")

        # Immediately after storing, the entry must still be present.
        item = await mem.retrieve("short_lived")
        assert item is not None

        # Sleep past default_ttl=1 (assumed to be seconds) with some margin.
        await asyncio.sleep(1.5)
        item = await mem.retrieve("short_lived")
        assert item is None

    async def test_get_context(self, redis_client, clean_redis):
        """get_context() returns the matching entries as a context string."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        await mem.store("task:1", "Generate AI report")
        await mem.store("task:2", "Analyze data trends")

        context = await mem.get_context("task")
        # get_context() is expected to delegate to search(), which matches
        # entries by key prefix.
        assert isinstance(context, str)
        # At least one of the stored values must appear in the context.
        assert "AI report" in context or "data trends" in context

    async def test_key_prefix_isolation(self, redis_client, clean_redis):
        """Instances with different key_prefix values do not share entries."""
        mem_a = WorkingMemory(redis=redis_client, key_prefix="test:agent_a")
        mem_b = WorkingMemory(redis=redis_client, key_prefix="test:agent_b")

        # Same logical key under two prefixes must not collide.
        await mem_a.store("shared_key", "value_from_a")
        await mem_b.store("shared_key", "value_from_b")

        item_a = await mem_a.retrieve("shared_key")
        item_b = await mem_b.retrieve("shared_key")
        assert item_a is not None
        assert item_b is not None
        assert item_a.value == "value_from_a"
        assert item_b.value == "value_from_b"

    async def test_delete_then_retrieve(self, redis_client, clean_redis):
        """delete() reports success and the entry is then gone."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        await mem.store("to_delete", "temporary data")

        result = await mem.delete("to_delete")
        assert result is True

        item = await mem.retrieve("to_delete")
        assert item is None

    async def test_delete_nonexistent_key(self, redis_client, clean_redis):
        """delete() of a key that was never stored returns False."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        result = await mem.delete("nonexistent_key")
        assert result is False

    async def test_store_complex_nested_dict(self, redis_client, clean_redis):
        """A deeply nested dict survives a store/retrieve round trip intact."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        complex_data = {
            "level1": {
                "level2": {
                    "level3": [1, 2, 3],
                    "nested_str": "deep value",
                },
                "items": [{"id": i, "name": f"item_{i}"} for i in range(5)],
            },
            "count": 42,
        }
        await mem.store("complex", complex_data)

        item = await mem.retrieve("complex")
        assert item is not None
        assert item.value["level1"]["level2"]["level3"] == [1, 2, 3]
        assert item.value["level1"]["level2"]["nested_str"] == "deep value"
        assert len(item.value["level1"]["items"]) == 5
        assert item.value["count"] == 42
        # The spot checks above could miss a dropped or altered field.
        assert item.value == complex_data

    async def test_search_by_key_prefix(self, redis_client, clean_redis):
        """search() returns exactly the entries whose key has the prefix."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        await mem.store("user:profile", {"name": "alice"})
        await mem.store("user:settings", {"theme": "dark"})
        await mem.store("task:report", {"type": "monthly"})

        # Search for keys starting with "user:".
        results = await mem.search("user:")
        assert len(results) >= 2

        keys = [item.key for item in results]
        assert "user:profile" in keys
        assert "user:settings" in keys
        assert "task:report" not in keys

    async def test_search_top_k_limit(self, redis_client, clean_redis):
        """search() honours top_k and still returns matching entries."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        for i in range(10):
            await mem.store(f"item:{i:02d}", f"value_{i}")

        results = await mem.search("item:", top_k=3)
        # A bare `<= 3` would also pass for an empty result; ten entries
        # match, so at least one must come back.
        assert 0 < len(results) <= 3
        assert all(item.key.startswith("item:") for item in results)

    async def test_retrieve_nonexistent(self, redis_client, clean_redis):
        """retrieve() of a key that was never stored returns None."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        item = await mem.retrieve("does_not_exist")
        assert item is None

    async def test_store_with_metadata(self, redis_client, clean_redis):
        """Metadata passed to store() is returned intact by retrieve()."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        await mem.store("meta_key", "some value", {"tag": "important", "priority": 1})

        item = await mem.retrieve("meta_key")
        assert item is not None
        assert item.metadata["tag"] == "important"
        assert item.metadata["priority"] == 1

    async def test_clear(self, redis_client, clean_redis):
        """clear(prefix=...) removes only the entries under that key prefix."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        await mem.store("a:1", "value_a1")
        await mem.store("a:2", "value_a2")
        await mem.store("b:1", "value_b1")

        count = await mem.clear(prefix="a:")
        assert count >= 2

        # Entries under "a:" must be gone.
        assert await mem.retrieve("a:1") is None
        assert await mem.retrieve("a:2") is None

        # Entries under "b:" must survive untouched.
        item = await mem.retrieve("b:1")
        assert item is not None
        assert item.value == "value_b1"