"""EmbeddingCache 单元测试 - LRU 缓存 + TTL""" import time import pytest from agentkit.memory.embedder import EmbeddingCache class TestEmbeddingCacheBasic: """EmbeddingCache 基本功能测试""" def test_put_and_get(self): """put 后可以 get 到""" cache = EmbeddingCache(max_size=100, ttl=3600) vec = [0.1, 0.2, 0.3] cache.put("hello", vec) assert cache.get("hello") == vec def test_get_missing_key_returns_none(self): """get 不存在的 key 返回 None""" cache = EmbeddingCache() assert cache.get("nonexistent") is None def test_clear_removes_all_entries(self): """clear 清除所有缓存""" cache = EmbeddingCache() cache.put("a", [1.0]) cache.put("b", [2.0]) cache.clear() assert cache.get("a") is None assert cache.get("b") is None def test_same_text_same_key(self): """相同文本映射到相同缓存 key""" cache = EmbeddingCache() cache.put("hello", [1.0]) cache.put("hello", [2.0]) # overwrite assert cache.get("hello") == [2.0] def test_different_text_different_key(self): """不同文本映射到不同缓存 key""" cache = EmbeddingCache() cache.put("hello", [1.0]) cache.put("world", [2.0]) assert cache.get("hello") == [1.0] assert cache.get("world") == [2.0] class TestEmbeddingCacheLRU: """EmbeddingCache LRU 淘汰测试""" def test_evicts_oldest_when_full(self): """缓存满时淘汰最久未使用的条目""" cache = EmbeddingCache(max_size=3, ttl=3600) cache.put("a", [1.0]) cache.put("b", [2.0]) cache.put("c", [3.0]) # Cache is full (3 entries). Adding "d" should evict "a" cache.put("d", [4.0]) assert cache.get("a") is None assert cache.get("b") == [2.0] assert cache.get("c") == [3.0] assert cache.get("d") == [4.0] def test_get_refreshes_lru_order(self): """get 操作刷新 LRU 顺序,避免被淘汰""" cache = EmbeddingCache(max_size=3, ttl=3600) cache.put("a", [1.0]) cache.put("b", [2.0]) cache.put("c", [3.0]) # Access "a" to refresh its position cache.get("a") # Adding "d" should evict "b" (least recently used) cache.put("d", [4.0]) assert cache.get("a") == [1.0] # Still present assert cache.get("b") is None # Evicted assert cache.get("c") == [3.0] assert cache.get("d") == [4.0] def test_put_existing_key_refreshes_position(self): """put 已存在的 key 刷新 LRU 位置""" cache = EmbeddingCache(max_size=3, ttl=3600) cache.put("a", [1.0]) cache.put("b", [2.0]) cache.put("c", [3.0]) # Re-put "a" to refresh cache.put("a", [10.0]) # Adding "d" should evict "b" cache.put("d", [4.0]) assert cache.get("a") == [10.0] assert cache.get("b") is None assert cache.get("c") == [3.0] def test_max_size_one(self): """max_size=1 时只保留最新条目""" cache = EmbeddingCache(max_size=1, ttl=3600) cache.put("a", [1.0]) cache.put("b", [2.0]) assert cache.get("a") is None assert cache.get("b") == [2.0] class TestEmbeddingCacheTTL: """EmbeddingCache TTL 过期测试""" def test_expired_entry_returns_none(self): """过期条目 get 返回 None""" cache = EmbeddingCache(max_size=100, ttl=0) # TTL=0 means immediately expired cache.put("hello", [1.0]) # With TTL=0, the entry should be expired by the time we get it # (time.monotonic() advances between put and get) result = cache.get("hello") # This may or may not be None depending on timing, so we use a short TTL # Let's test with a small positive TTL instead cache2 = EmbeddingCache(max_size=100, ttl=1) # 1 second TTL cache2.put("hello", [1.0]) assert cache2.get("hello") == [1.0] # Should still be valid def test_non_expired_entry_returns_value(self): """未过期条目 get 返回缓存值""" cache = EmbeddingCache(max_size=100, ttl=3600) cache.put("hello", [1.0]) assert cache.get("hello") == [1.0] def test_ttl_expiration_removes_entry(self): """过期后条目从缓存中移除""" cache = EmbeddingCache(max_size=100, ttl=1) # 1 second cache.put("hello", [1.0]) # Wait for TTL to expire time.sleep(1.1) assert cache.get("hello") is None class TestEmbeddingCacheKeyGeneration: """EmbeddingCache key 生成测试""" def test_key_is_deterministic(self): """相同文本生成相同 key""" key1 = EmbeddingCache._make_key("hello world") key2 = EmbeddingCache._make_key("hello world") assert key1 == key2 def test_different_text_different_key(self): """不同文本生成不同 key""" key1 = EmbeddingCache._make_key("hello") key2 = EmbeddingCache._make_key("world") assert key1 != key2 def test_key_is_sha256_hex(self): """key 是 SHA-256 十六进制字符串""" import hashlib text = "test input" expected = hashlib.sha256(text.encode()).hexdigest() assert EmbeddingCache._make_key(text) == expected def test_unicode_text_handled(self): """Unicode 文本正确处理""" key1 = EmbeddingCache._make_key("你好世界") key2 = EmbeddingCache._make_key("你好世界") assert key1 == key2 # Different unicode text should produce different keys key3 = EmbeddingCache._make_key("こんにちは") assert key1 != key3 class TestEmbeddingCacheEdgeCases: """EmbeddingCache 边界情况测试""" def test_empty_string_key(self): """空字符串可以作为缓存 key""" cache = EmbeddingCache(max_size=10, ttl=3600) cache.put("", [0.0]) assert cache.get("") == [0.0] def test_empty_vector_cached(self): """空向量可以被缓存""" cache = EmbeddingCache(max_size=10, ttl=3600) cache.put("empty_vec", []) assert cache.get("empty_vec") == [] def test_large_vector_cached(self): """大维度向量可以被缓存""" cache = EmbeddingCache(max_size=10, ttl=3600) large_vec = [float(i) for i in range(1536)] cache.put("large", large_vec) assert cache.get("large") == large_vec def test_max_size_zero_never_stores(self): """max_size=0 时无法存储任何条目""" cache = EmbeddingCache(max_size=0, ttl=3600) cache.put("a", [1.0]) # Entry is immediately evicted since max_size=0 assert cache.get("a") is None def test_put_overwrite_preserves_freshness(self): """put 覆盖已存在的 key 时更新值和时间戳""" cache = EmbeddingCache(max_size=3, ttl=3600) cache.put("a", [1.0]) cache.put("b", [2.0]) cache.put("c", [3.0]) # Overwrite "a" with new value — refreshes its LRU position cache.put("a", [10.0]) # Adding "d" should evict "b" (least recently used) cache.put("d", [4.0]) assert cache.get("a") == [10.0] assert cache.get("b") is None def test_expired_entry_is_cleaned_up(self): """过期条目在 get 时被清除,不占用缓存空间""" cache = EmbeddingCache(max_size=2, ttl=1) cache.put("a", [1.0]) # Put "b" slightly later so its TTL extends beyond "a"'s time.sleep(0.3) cache.put("b", [2.0]) # Wait for "a" to expire but not "b" time.sleep(0.8) # "a" should be expired and removed from cache assert cache.get("a") is None # "b" is still valid (put 0.8s ago, TTL=1s) assert cache.get("b") == [2.0] # Now cache has room: we can add "c" cache.put("c", [3.0]) assert cache.get("c") == [3.0] def test_special_characters_in_text(self): """特殊字符文本正确处理""" cache = EmbeddingCache(max_size=10, ttl=3600) special = "hello\nworld\ttab\0null" cache.put(special, [1.0]) assert cache.get(special) == [1.0] def test_very_long_text_key(self): """超长文本可以生成 key 并缓存""" cache = EmbeddingCache(max_size=10, ttl=3600) long_text = "x" * 100_000 cache.put(long_text, [0.5]) assert cache.get(long_text) == [0.5]