239 lines
8.4 KiB
Python
239 lines
8.4 KiB
Python
"""EmbeddingCache 单元测试 - LRU 缓存 + TTL"""
|
|
|
|
import time
|
|
|
|
import pytest
|
|
|
|
from agentkit.memory.embedder import EmbeddingCache
|
|
|
|
|
|
class TestEmbeddingCacheBasic:
    """Core put/get/clear behaviour of EmbeddingCache."""

    def test_put_and_get(self):
        """A value stored with put() is returned by get()."""
        store = EmbeddingCache(max_size=100, ttl=3600)
        embedding = [0.1, 0.2, 0.3]
        store.put("hello", embedding)
        assert store.get("hello") == embedding

    def test_get_missing_key_returns_none(self):
        """get() on a text that was never stored yields None."""
        store = EmbeddingCache()
        assert store.get("nonexistent") is None

    def test_clear_removes_all_entries(self):
        """clear() drops every cached entry."""
        store = EmbeddingCache()
        for text, embedding in (("a", [1.0]), ("b", [2.0])):
            store.put(text, embedding)
        store.clear()
        for text in ("a", "b"):
            assert store.get(text) is None

    def test_same_text_same_key(self):
        """Storing the same text twice overwrites the earlier value."""
        store = EmbeddingCache()
        store.put("hello", [1.0])
        # Same text maps to the same slot, so this replaces [1.0].
        store.put("hello", [2.0])
        assert store.get("hello") == [2.0]

    def test_different_text_different_key(self):
        """Distinct texts occupy distinct slots and do not clobber each other."""
        store = EmbeddingCache()
        expected = {"hello": [1.0], "world": [2.0]}
        for text, embedding in expected.items():
            store.put(text, embedding)
        for text, embedding in expected.items():
            assert store.get(text) == embedding
|
|
|
|
|
|
class TestEmbeddingCacheLRU:
    """LRU eviction behaviour of EmbeddingCache."""

    def test_evicts_oldest_when_full(self):
        """When the cache is full, the least recently used entry is evicted."""
        cache = EmbeddingCache(max_size=3, ttl=3600)
        cache.put("a", [1.0])
        cache.put("b", [2.0])
        cache.put("c", [3.0])
        # Cache is full (3 entries). Adding "d" should evict "a", the oldest.
        cache.put("d", [4.0])
        assert cache.get("a") is None
        assert cache.get("b") == [2.0]
        assert cache.get("c") == [3.0]
        assert cache.get("d") == [4.0]

    def test_get_refreshes_lru_order(self):
        """A get() marks the entry as recently used, protecting it from eviction."""
        cache = EmbeddingCache(max_size=3, ttl=3600)
        cache.put("a", [1.0])
        cache.put("b", [2.0])
        cache.put("c", [3.0])
        # Touch "a" so it becomes the most recently used entry.
        cache.get("a")
        # Adding "d" should now evict "b" (least recently used).
        cache.put("d", [4.0])
        assert cache.get("a") == [1.0]  # still present
        assert cache.get("b") is None  # evicted
        assert cache.get("c") == [3.0]
        assert cache.get("d") == [4.0]

    def test_put_existing_key_refreshes_position(self):
        """Re-putting an existing key marks it as recently used."""
        cache = EmbeddingCache(max_size=3, ttl=3600)
        cache.put("a", [1.0])
        cache.put("b", [2.0])
        cache.put("c", [3.0])
        # Re-put "a" to refresh its LRU position (and replace its value).
        cache.put("a", [10.0])
        # Adding "d" should evict "b", now the least recently used.
        cache.put("d", [4.0])
        assert cache.get("a") == [10.0]
        assert cache.get("b") is None
        assert cache.get("c") == [3.0]
        # "d" must survive too: exactly one entry ("b") should have been evicted.
        assert cache.get("d") == [4.0]

    def test_max_size_one(self):
        """With max_size=1 only the most recently stored entry is kept."""
        cache = EmbeddingCache(max_size=1, ttl=3600)
        cache.put("a", [1.0])
        cache.put("b", [2.0])
        assert cache.get("a") is None
        assert cache.get("b") == [2.0]
|
|
|
|
|
|
class TestEmbeddingCacheTTL:
    """TTL (time-to-live) expiry behaviour of EmbeddingCache."""

    def test_expired_entry_returns_none(self):
        """get() reports an expired entry as a miss (returns None)."""
        # With ttl=0 an entry is stale as soon as any time has elapsed. A short
        # sleep guarantees the clock has advanced even where its resolution is
        # coarse, so the result no longer depends on timing luck.
        # NOTE(review): assumes EmbeddingCache does not treat ttl=0 as
        # "never expire" — confirm against the implementation.
        cache = EmbeddingCache(max_size=100, ttl=0)
        cache.put("hello", [1.0])
        time.sleep(0.05)
        assert cache.get("hello") is None

    def test_short_ttl_entry_valid_before_expiry(self):
        """An entry with a short positive TTL is still returned right after put()."""
        cache = EmbeddingCache(max_size=100, ttl=1)  # 1 second TTL
        cache.put("hello", [1.0])
        assert cache.get("hello") == [1.0]

    def test_non_expired_entry_returns_value(self):
        """get() returns the cached value while the entry is within its TTL."""
        cache = EmbeddingCache(max_size=100, ttl=3600)
        cache.put("hello", [1.0])
        assert cache.get("hello") == [1.0]

    def test_ttl_expiration_removes_entry(self):
        """Once the TTL has elapsed, get() returns None for the entry."""
        cache = EmbeddingCache(max_size=100, ttl=1)  # 1 second
        cache.put("hello", [1.0])
        # Wait just past the TTL.
        time.sleep(1.1)
        assert cache.get("hello") is None
|
|
|
|
|
|
class TestEmbeddingCacheKeyGeneration:
    """Tests for EmbeddingCache._make_key, the text-to-cache-key mapping."""

    def test_key_is_deterministic(self):
        """Hashing the same text twice yields the same key."""
        text = "hello world"
        first, second = (EmbeddingCache._make_key(text) for _ in range(2))
        assert first == second

    def test_different_text_different_key(self):
        """Different texts hash to different keys."""
        hello_key, world_key = map(EmbeddingCache._make_key, ("hello", "world"))
        assert hello_key != world_key

    def test_key_is_sha256_hex(self):
        """The key equals the SHA-256 hex digest of the UTF-8 encoded text."""
        import hashlib

        sample = "test input"
        digest = hashlib.sha256(sample.encode()).hexdigest()
        assert EmbeddingCache._make_key(sample) == digest

    def test_unicode_text_handled(self):
        """Non-ASCII text produces stable and distinct keys."""
        chinese_first = EmbeddingCache._make_key("你好世界")
        chinese_again = EmbeddingCache._make_key("你好世界")
        assert chinese_first == chinese_again
        # A different non-ASCII string must not collide with the first one.
        japanese = EmbeddingCache._make_key("こんにちは")
        assert chinese_first != japanese
|
|
|
|
|
|
class TestEmbeddingCacheEdgeCases:
    """Boundary and unusual-input tests for EmbeddingCache."""

    def test_empty_string_key(self):
        """The empty string can be used as a cache key."""
        cache = EmbeddingCache(max_size=10, ttl=3600)
        cache.put("", [0.0])
        assert cache.get("") == [0.0]

    def test_empty_vector_cached(self):
        """An empty vector can be cached and read back."""
        cache = EmbeddingCache(max_size=10, ttl=3600)
        cache.put("empty_vec", [])
        assert cache.get("empty_vec") == []

    def test_large_vector_cached(self):
        """A high-dimensional vector (1536 floats) round-trips through the cache."""
        cache = EmbeddingCache(max_size=10, ttl=3600)
        large_vec = [float(i) for i in range(1536)]
        cache.put("large", large_vec)
        assert cache.get("large") == large_vec

    def test_max_size_zero_never_stores(self):
        """With max_size=0 nothing can be stored."""
        cache = EmbeddingCache(max_size=0, ttl=3600)
        cache.put("a", [1.0])
        # The entry is evicted immediately because capacity is zero.
        assert cache.get("a") is None

    def test_put_overwrite_preserves_freshness(self):
        """Overwriting a key replaces its value and refreshes its LRU position."""
        # NOTE(review): only the value and LRU order are observed here; whether
        # the overwrite also resets the entry's TTL timestamp is not exercised.
        cache = EmbeddingCache(max_size=3, ttl=3600)
        cache.put("a", [1.0])
        cache.put("b", [2.0])
        cache.put("c", [3.0])
        # Overwrite "a" with a new value; this makes it the most recently used.
        cache.put("a", [10.0])
        # Adding "d" should evict "b" (least recently used).
        cache.put("d", [4.0])
        assert cache.get("a") == [10.0]
        assert cache.get("b") is None
        assert cache.get("d") == [4.0]

    def test_expired_entry_is_cleaned_up(self):
        """An expired entry reads as a miss while fresher entries stay usable."""
        cache = EmbeddingCache(max_size=2, ttl=1)
        cache.put("a", [1.0])
        # Stagger the puts so "a" expires while "b" is still fresh.
        time.sleep(0.5)
        cache.put("b", [2.0])
        # Now ~1.1s after "a" was stored and ~0.6s after "b". The ~0.4s of
        # headroom on "b" absorbs sleep overshoot on a busy machine.
        time.sleep(0.6)
        assert cache.get("a") is None
        assert cache.get("b") == [2.0]
        # A new entry fits alongside "b" without displacing it.
        # NOTE(review): via get/put alone this cannot distinguish eager removal
        # of "a" from LRU eviction of the stale "a"; a size accessor would be
        # needed to prove the expired entry no longer occupies a slot.
        cache.put("c", [3.0])
        assert cache.get("c") == [3.0]
        assert cache.get("b") == [2.0]

    def test_special_characters_in_text(self):
        """Text containing newline, tab and NUL characters is handled correctly."""
        cache = EmbeddingCache(max_size=10, ttl=3600)
        special = "hello\nworld\ttab\0null"
        cache.put(special, [1.0])
        assert cache.get(special) == [1.0]

    def test_very_long_text_key(self):
        """Very long text (100k characters) can be keyed and cached."""
        cache = EmbeddingCache(max_size=10, ttl=3600)
        long_text = "x" * 100_000
        cache.put(long_text, [0.5])
        assert cache.get(long_text) == [0.5]
|