# fischer-agentkit/tests/unit/test_embedding_cache.py
#
# 239 lines
# 8.4 KiB
# Python

"""EmbeddingCache 单元测试 - LRU 缓存 + TTL"""
import time
import pytest
from agentkit.memory.embedder import EmbeddingCache
class TestEmbeddingCacheBasic:
    """Basic EmbeddingCache behaviour: store, look up, clear and overwrite."""

    def test_put_and_get(self):
        """A vector stored with put() is returned by get()."""
        embedding = [0.1, 0.2, 0.3]
        cache = EmbeddingCache(max_size=100, ttl=3600)
        cache.put("hello", embedding)
        assert cache.get("hello") == embedding

    def test_get_missing_key_returns_none(self):
        """get() for text that was never stored returns None."""
        assert EmbeddingCache().get("nonexistent") is None

    def test_clear_removes_all_entries(self):
        """clear() drops every cached entry."""
        cache = EmbeddingCache()
        for text, vector in (("a", [1.0]), ("b", [2.0])):
            cache.put(text, vector)
        cache.clear()
        for text in ("a", "b"):
            assert cache.get(text) is None

    def test_same_text_same_key(self):
        """Identical text maps to the same slot, so a second put() overwrites."""
        cache = EmbeddingCache()
        cache.put("hello", [1.0])
        cache.put("hello", [2.0])  # same text again: replaces the first vector
        assert cache.get("hello") == [2.0]

    def test_different_text_different_key(self):
        """Distinct texts occupy distinct slots and do not clobber each other."""
        cache = EmbeddingCache()
        stored = {"hello": [1.0], "world": [2.0]}
        for text, vector in stored.items():
            cache.put(text, vector)
        for text, vector in stored.items():
            assert cache.get(text) == vector
class TestEmbeddingCacheLRU:
    """Least-recently-used eviction in EmbeddingCache."""

    @staticmethod
    def _put_abc(cache):
        """Insert "a", "b", "c" (in that order) with vectors [1.0], [2.0], [3.0]."""
        for text, value in zip("abc", (1.0, 2.0, 3.0)):
            cache.put(text, [value])

    def test_evicts_oldest_when_full(self):
        """Inserting into a full cache evicts the least recently used entry."""
        cache = EmbeddingCache(max_size=3, ttl=3600)
        self._put_abc(cache)
        # The cache now holds its maximum of three entries; "d" displaces "a".
        cache.put("d", [4.0])
        assert cache.get("a") is None
        assert cache.get("b") == [2.0]
        assert cache.get("c") == [3.0]
        assert cache.get("d") == [4.0]

    def test_get_refreshes_lru_order(self):
        """Reading an entry marks it as recently used, so it survives eviction."""
        cache = EmbeddingCache(max_size=3, ttl=3600)
        self._put_abc(cache)
        cache.get("a")  # touch "a"; "b" becomes the least recently used entry
        cache.put("d", [4.0])
        assert cache.get("a") == [1.0]
        assert cache.get("b") is None
        assert cache.get("c") == [3.0]
        assert cache.get("d") == [4.0]

    def test_put_existing_key_refreshes_position(self):
        """Re-putting an existing key marks it as recently used."""
        cache = EmbeddingCache(max_size=3, ttl=3600)
        self._put_abc(cache)
        cache.put("a", [10.0])  # rewrite "a"; "b" becomes the eviction candidate
        cache.put("d", [4.0])
        assert cache.get("a") == [10.0]
        assert cache.get("b") is None
        assert cache.get("c") == [3.0]

    def test_max_size_one(self):
        """A single-slot cache keeps only the most recent entry."""
        cache = EmbeddingCache(max_size=1, ttl=3600)
        for text, value in (("a", 1.0), ("b", 2.0)):
            cache.put(text, [value])
        assert cache.get("a") is None
        assert cache.get("b") == [2.0]
class TestEmbeddingCacheTTL:
    """Time-to-live expiry in EmbeddingCache."""

    def test_expired_entry_returns_none(self):
        """get() returns None once an entry's TTL has elapsed.

        With ttl=0 an entry counts as expired as soon as any time has passed
        since put(). The short sleep makes sure the clock has visibly advanced
        between put() and get(). Without it, a coarse clock can report the same
        instant for both calls, and the outcome depends on timing.
        """
        # Control: a short but positive TTL must not expire immediately.
        fresh = EmbeddingCache(max_size=100, ttl=1)
        fresh.put("hello", [1.0])
        assert fresh.get("hello") == [1.0]

        cache = EmbeddingCache(max_size=100, ttl=0)  # TTL=0: expired right away
        cache.put("hello", [1.0])
        time.sleep(0.05)  # guarantee a measurable gap between put() and get()
        assert cache.get("hello") is None

    def test_non_expired_entry_returns_value(self):
        """get() returns the cached value while the entry is within its TTL."""
        cache = EmbeddingCache(max_size=100, ttl=3600)
        cache.put("hello", [1.0])
        assert cache.get("hello") == [1.0]

    def test_ttl_expiration_removes_entry(self):
        """After the TTL elapses, the entry is no longer returned."""
        cache = EmbeddingCache(max_size=100, ttl=1)  # 1 second
        cache.put("hello", [1.0])
        time.sleep(1.1)  # wait just past the TTL
        assert cache.get("hello") is None
class TestEmbeddingCacheKeyGeneration:
    """Cache-key derivation via EmbeddingCache._make_key."""

    def test_key_is_deterministic(self):
        """The same text always yields the same key."""
        make = EmbeddingCache._make_key
        first, second = (make("hello world") for _ in range(2))
        assert first == second

    def test_different_text_different_key(self):
        """Different texts yield different keys."""
        make = EmbeddingCache._make_key
        assert make("hello") != make("world")

    def test_key_is_sha256_hex(self):
        """The key is the SHA-256 hex digest of the text."""
        import hashlib

        sample = "test input"
        assert EmbeddingCache._make_key(sample) == hashlib.sha256(sample.encode()).hexdigest()

    def test_unicode_text_handled(self):
        """Non-ASCII text produces stable, distinct keys."""
        make = EmbeddingCache._make_key
        chinese = make("你好世界")
        assert chinese == make("你好世界")
        # A different non-ASCII string must not collide with the first one.
        assert chinese != make("こんにちは")
class TestEmbeddingCacheEdgeCases:
    """Boundary and unusual-input cases for EmbeddingCache."""

    def test_empty_string_key(self):
        """The empty string is a valid cache key."""
        cache = EmbeddingCache(max_size=10, ttl=3600)
        cache.put("", [0.0])
        assert cache.get("") == [0.0]

    def test_empty_vector_cached(self):
        """An empty vector can be stored and retrieved."""
        cache = EmbeddingCache(max_size=10, ttl=3600)
        cache.put("empty_vec", [])
        assert cache.get("empty_vec") == []

    def test_large_vector_cached(self):
        """A high-dimensional (1536-element) vector round-trips unchanged."""
        cache = EmbeddingCache(max_size=10, ttl=3600)
        large_vec = [float(i) for i in range(1536)]
        cache.put("large", large_vec)
        assert cache.get("large") == large_vec

    def test_max_size_zero_never_stores(self):
        """With max_size=0, nothing can be retrieved after put()."""
        cache = EmbeddingCache(max_size=0, ttl=3600)
        cache.put("a", [1.0])
        # The entry is evicted immediately because max_size=0.
        assert cache.get("a") is None

    def test_put_overwrite_preserves_freshness(self):
        """Overwriting a key updates its value and refreshes its LRU position."""
        cache = EmbeddingCache(max_size=3, ttl=3600)
        cache.put("a", [1.0])
        cache.put("b", [2.0])
        cache.put("c", [3.0])
        # Overwrite "a" with a new value; this makes it the most recently used entry.
        cache.put("a", [10.0])
        # Adding "d" must therefore evict "b", now the least recently used entry.
        cache.put("d", [4.0])
        assert cache.get("a") == [10.0]
        assert cache.get("b") is None

    def test_expired_entry_is_cleaned_up(self):
        """An expired entry reads as None and does not block a new insertion.

        Timeline with ttl=1s: "a" is stored at t=0 and "b" at t~0.6, and both
        are read at t~1.2. So "a" is ~0.2s past its TTL while "b" still has
        ~0.4s left. time.sleep() may overshoot (never undershoot), which only
        ages "b" further, so "b" gets the larger safety margin.
        """
        cache = EmbeddingCache(max_size=2, ttl=1)
        cache.put("a", [1.0])
        time.sleep(0.6)
        cache.put("b", [2.0])
        time.sleep(0.6)
        # "a" has outlived its TTL; "b" has not.
        assert cache.get("a") is None
        assert cache.get("b") == [2.0]
        # Inserting "c" must succeed without pushing out the still-valid "b".
        cache.put("c", [3.0])
        assert cache.get("c") == [3.0]
        assert cache.get("b") == [2.0]

    def test_special_characters_in_text(self):
        """Text with newlines, tabs and NUL characters is handled correctly."""
        cache = EmbeddingCache(max_size=10, ttl=3600)
        special = "hello\nworld\ttab\0null"
        cache.put(special, [1.0])
        assert cache.get(special) == [1.0]

    def test_very_long_text_key(self):
        """Very long text (100k characters) can be keyed and cached."""
        cache = EmbeddingCache(max_size=10, ttl=3600)
        long_text = "x" * 100_000
        cache.put(long_text, [0.5])
        assert cache.get(long_text) == [0.5]