"""Unit tests for WorkingMemory - Redis-backed short-term task memory."""

import asyncio
import json

import pytest

from agentkit.memory.working import WorkingMemory

# ── Redis availability probe ──────────────────────────────

# Location of the Redis instance these tests expect. The skip reason below is
# built from the same values so the message cannot drift from the probe.
_REDIS_HOST = "localhost"
_REDIS_PORT = 6381


def _redis_available():
    """Return True if a Redis server answers PING at the test host/port.

    Evaluated at collection time to decide whether the Redis-backed tests
    should be skipped. Any failure (the ``redis`` package not installed,
    connection refused, timeout) is reported as "unavailable" so the suite
    skips rather than erroring during collection.
    """
    try:
        import redis as sync_redis
    except ImportError:
        return False

    client = None
    try:
        client = sync_redis.Redis(
            host=_REDIS_HOST,
            port=_REDIS_PORT,
            db=0,
            # Bound the probe so an unreachable host cannot stall collection.
            socket_connect_timeout=1,
            socket_timeout=1,
        )
        client.ping()
        return True
    except Exception:
        return False
    finally:
        # Release the connection on both the success and failure paths;
        # previously close() was skipped whenever ping() raised.
        if client is not None:
            client.close()


skip_if_no_redis = pytest.mark.skipif(
    not _redis_available(),
    reason=f"Redis not available at {_REDIS_HOST}:{_REDIS_PORT}",
)


# ── WorkingMemory tests ───────────────────────────────────


@skip_if_no_redis
@pytest.mark.redis
class TestWorkingMemory:
    """WorkingMemory tests against a real Redis connection.

    Relies on the ``redis_client`` and ``clean_redis`` fixtures, which are
    defined outside this module.
    """

    async def test_store_and_retrieve(self, redis_client, clean_redis):
        """store + retrieve round-trips the same value."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        await mem.store("key1", {"name": "alice", "age": 30})

        item = await mem.retrieve("key1")
        assert item is not None
        assert item.key == "key1"
        assert item.value["name"] == "alice"
        assert item.value["age"] == 30

    async def test_ttl_expiration(self, redis_client, clean_redis):
        """retrieve returns None once the TTL has elapsed."""
        mem = WorkingMemory(
            redis=redis_client, key_prefix="test:working", default_ttl=1
        )
        await mem.store("short_lived", "will expire soon")

        # Fetching immediately should still find the item.
        item = await mem.retrieve("short_lived")
        assert item is not None

        # Wait past the 1-second TTL.
        await asyncio.sleep(1.5)

        item = await mem.retrieve("short_lived")
        assert item is None

    async def test_get_context(self, redis_client, clean_redis):
        """get_context() returns a formatted context string."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        await mem.store("task:1", "Generate AI report")
        await mem.store("task:2", "Analyze data trends")

        context = await mem.get_context("task")
        # get_context delegates to search, which matches on key prefix.
        assert isinstance(context, str)
        # At least one of the stored values should appear.
        assert "AI report" in context or "data trends" in context

    async def test_key_prefix_isolation(self, redis_client, clean_redis):
        """WorkingMemory instances with different key_prefix are isolated."""
        mem_a = WorkingMemory(redis=redis_client, key_prefix="test:agent_a")
        mem_b = WorkingMemory(redis=redis_client, key_prefix="test:agent_b")

        await mem_a.store("shared_key", "value_from_a")
        await mem_b.store("shared_key", "value_from_b")

        item_a = await mem_a.retrieve("shared_key")
        item_b = await mem_b.retrieve("shared_key")

        assert item_a is not None
        assert item_b is not None
        assert item_a.value == "value_from_a"
        assert item_b.value == "value_from_b"

    async def test_delete_then_retrieve(self, redis_client, clean_redis):
        """retrieve returns None after delete."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        await mem.store("to_delete", "temporary data")

        result = await mem.delete("to_delete")
        assert result is True

        item = await mem.retrieve("to_delete")
        assert item is None

    async def test_delete_nonexistent_key(self, redis_client, clean_redis):
        """Deleting a key that does not exist returns False."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        result = await mem.delete("nonexistent_key")
        assert result is False

    async def test_store_complex_nested_dict(self, redis_client, clean_redis):
        """A deeply nested dict is restored intact by retrieve."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        complex_data = {
            "level1": {
                "level2": {
                    "level3": [1, 2, 3],
                    "nested_str": "deep value",
                },
                "items": [{"id": i, "name": f"item_{i}"} for i in range(5)],
            },
            "count": 42,
        }
        await mem.store("complex", complex_data)

        item = await mem.retrieve("complex")
        assert item is not None
        assert item.value["level1"]["level2"]["level3"] == [1, 2, 3]
        assert item.value["level1"]["level2"]["nested_str"] == "deep value"
        assert len(item.value["level1"]["items"]) == 5
        assert item.value["count"] == 42

    async def test_search_by_key_prefix(self, redis_client, clean_redis):
        """search matches keys by prefix pattern."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        await mem.store("user:profile", {"name": "alice"})
        await mem.store("user:settings", {"theme": "dark"})
        await mem.store("task:report", {"type": "monthly"})

        # Search for keys beginning with "user:".
        results = await mem.search("user:")
        assert len(results) >= 2
        keys = [item.key for item in results]
        assert "user:profile" in keys
        assert "user:settings" in keys
        assert "task:report" not in keys

    async def test_search_top_k_limit(self, redis_client, clean_redis):
        """search honours the top_k limit on result count."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        for i in range(10):
            await mem.store(f"item:{i:02d}", f"value_{i}")

        results = await mem.search("item:", top_k=3)
        assert len(results) <= 3

    async def test_retrieve_nonexistent(self, redis_client, clean_redis):
        """retrieve on a missing key returns None."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        item = await mem.retrieve("does_not_exist")
        assert item is None

    async def test_store_with_metadata(self, redis_client, clean_redis):
        """Metadata passed to store is restored by retrieve."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        await mem.store("meta_key", "some value", {"tag": "important", "priority": 1})

        item = await mem.retrieve("meta_key")
        assert item is not None
        assert item.metadata["tag"] == "important"
        assert item.metadata["priority"] == 1

    async def test_clear(self, redis_client, clean_redis):
        """clear removes every Working Memory entry under the given prefix."""
        mem = WorkingMemory(redis=redis_client, key_prefix="test:working")
        await mem.store("a:1", "value_a1")
        await mem.store("a:2", "value_a2")
        await mem.store("b:1", "value_b1")

        count = await mem.clear(prefix="a:")
        assert count >= 2

        # Entries under the "a:" prefix should be gone.
        assert await mem.retrieve("a:1") is None
        assert await mem.retrieve("a:2") is None

        # Entries under the "b:" prefix should be kept.
        item = await mem.retrieve("b:1")
        assert item is not None