"""Unit tests for LLM Cache Core (U1).

Tests cover:
- CacheKey generation (deterministic, component isolation)
- InMemoryLLMCache (exact match, semantic match, TTL, LRU, stats)
- Serialization helpers (response and entry round-trips)
- RedisLLMCache (same tests with mocked Redis)
- Factory function (backend selection, fallback)
"""

import asyncio
import json
import time
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from agentkit.llm.cache import (
    CacheEntry,
    CacheResult,
    InMemoryLLMCache,
    RedisLLMCache,
    create_llm_cache,
    _serialize_response,
    _deserialize_response,
    _serialize_entry,
    _deserialize_entry,
)
from agentkit.llm.cache_key import generate_cache_key
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


def _make_response(
    content: str = "Hello",
    model: str = "gpt-4o",
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
    tool_calls: list[ToolCall] | None = None,
) -> LLMResponse:
    """Build an ``LLMResponse`` for use as a cache payload in tests.

    Args:
        content: Text placed in the response's ``content`` field.
        model: Model name recorded on the response.
        prompt_tokens: Prompt token count stored in ``usage``.
        completion_tokens: Completion token count stored in ``usage``.
        tool_calls: Tool calls to attach; ``None`` (or an empty list) yields
            a fresh empty list.

    Returns:
        An ``LLMResponse`` whose ``latency_ms`` is always ``100.0``.
    """
    return LLMResponse(
        content=content,
        model=model,
        usage=TokenUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        ),
        tool_calls=tool_calls or [],
        latency_ms=100.0,
    )
def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]:
|
|
return [
|
|
{"role": "system", "content": "You are a helpful assistant."},
|
|
{"role": "user", "content": user_content},
|
|
]
|
|
|
|
|
|
def _make_embedding(base_val: float = 1.0, dim: int = 128) -> list[float]:
|
|
"""Create a unit vector for similarity testing."""
|
|
vec = [base_val] * dim
|
|
magnitude = sum(x**2 for x in vec) ** 0.5
|
|
return [x / magnitude for x in vec] if magnitude > 0 else vec
|
|
|
|
|
|
def _make_similar_embedding(base: list[float], noise: float = 0.01) -> list[float]:
|
|
"""Create a vector similar to base with small noise."""
|
|
vec = [x + noise for x in base]
|
|
magnitude = sum(x**2 for x in vec) ** 0.5
|
|
return [x / magnitude for x in vec] if magnitude > 0 else vec
|
|
|
|
|
|
def _make_different_embedding(dim: int = 128) -> list[float]:
|
|
"""Create a vector with very low cosine similarity to _make_embedding()."""
|
|
# _make_embedding(1.0) is all-positive unit vector.
|
|
# Negate first half to create near-orthogonal vector.
|
|
half = dim // 2
|
|
vec = [-1.0] * half + [1.0] * (dim - half)
|
|
magnitude = sum(x**2 for x in vec) ** 0.5
|
|
return [x / magnitude for x in vec] if magnitude > 0 else vec
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# CacheKey Tests
# ---------------------------------------------------------------------------


class TestCacheKey:
    """Checks that ``generate_cache_key`` is stable and sensitive to each input."""

    def test_deterministic(self):
        """Same inputs produce same key."""
        msgs = _make_messages()
        key1 = generate_cache_key("gpt-4o", msgs, 0.0)
        key2 = generate_cache_key("gpt-4o", msgs, 0.0)
        assert key1 == key2
        assert len(key1) == 64  # SHA-256 hex

    def test_different_model(self):
        """Different model produces different key."""
        msgs = _make_messages()
        key1 = generate_cache_key("gpt-4o", msgs, 0.0)
        key2 = generate_cache_key("gpt-3.5-turbo", msgs, 0.0)
        assert key1 != key2

    def test_different_temperature(self):
        """Different temperature produces different key."""
        msgs = _make_messages()
        key1 = generate_cache_key("gpt-4o", msgs, 0.0)
        key2 = generate_cache_key("gpt-4o", msgs, 0.7)
        assert key1 != key2

    def test_different_messages(self):
        """Different messages produce different key."""
        key1 = generate_cache_key("gpt-4o", _make_messages("Hello"), 0.0)
        key2 = generate_cache_key("gpt-4o", _make_messages("World"), 0.0)
        assert key1 != key2

    def test_different_tools(self):
        """Different tools produce different key."""
        msgs = _make_messages()
        tools1 = [{"type": "function", "function": {"name": "f1"}}]
        tools2 = [{"type": "function", "function": {"name": "f2"}}]
        key1 = generate_cache_key("gpt-4o", msgs, 0.0, tools=tools1)
        key2 = generate_cache_key("gpt-4o", msgs, 0.0, tools=tools2)
        assert key1 != key2

    def test_none_tools_same_as_no_tools(self):
        """None tools and no tools produce same key."""
        msgs = _make_messages()
        key1 = generate_cache_key("gpt-4o", msgs, 0.0, tools=None)
        key2 = generate_cache_key("gpt-4o", msgs, 0.0)
        assert key1 == key2

    def test_system_prompt_extracted_from_messages(self):
        """System prompt is extracted from messages[0] with role=system."""
        msgs = [
            {"role": "system", "content": "Be concise"},
            {"role": "user", "content": "Hello"},
        ]
        key1 = generate_cache_key("gpt-4o", msgs, 0.0)

        # Same user turn, different system prompt: the key must change.
        msgs2 = [
            {"role": "system", "content": "Be verbose"},
            {"role": "user", "content": "Hello"},
        ]
        key2 = generate_cache_key("gpt-4o", msgs2, 0.0)
        assert key1 != key2

    def test_max_tokens_affects_key(self):
        """Different max_tokens produce different key."""
        msgs = _make_messages()
        key1 = generate_cache_key("gpt-4o", msgs, 0.0, max_tokens=2000)
        key2 = generate_cache_key("gpt-4o", msgs, 0.0, max_tokens=4000)
        assert key1 != key2
# ---------------------------------------------------------------------------
# InMemoryLLMCache Tests
# ---------------------------------------------------------------------------


class TestInMemoryLLMCache:
    """Behavioral tests for ``InMemoryLLMCache``.

    Covers exact and semantic lookups, TTL expiry, LRU eviction,
    invalidation, statistics and overwriting of existing keys.
    """

    @pytest.mark.asyncio
    async def test_exact_match_hit(self):
        """A stored key is returned as an exact hit with its response."""
        cache = InMemoryLLMCache()
        key = "test_key_1"
        response = _make_response("Cached answer")

        await cache.put(key, response)
        result = await cache.get(key)

        assert result.hit is True
        assert result.match_type == "exact"
        assert result.response.content == "Cached answer"

    @pytest.mark.asyncio
    async def test_exact_match_miss(self):
        """Looking up an unknown key is a miss with no response."""
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response())

        result = await cache.get("key2")
        assert result.hit is False
        assert result.response is None

    @pytest.mark.asyncio
    async def test_semantic_match_hit(self):
        """An embedding close to a stored one is a semantic hit."""
        cache = InMemoryLLMCache(similarity_threshold=0.9)
        emb1 = _make_embedding(1.0, dim=64)
        emb_similar = _make_similar_embedding(emb1, noise=0.001)

        await cache.put("key1", _make_response("Cached"), query_embedding=emb1)
        result = await cache.semantic_search(emb_similar, threshold=0.9)

        assert result.hit is True
        assert result.match_type == "semantic"
        assert result.response.content == "Cached"

    @pytest.mark.asyncio
    async def test_semantic_match_miss(self):
        """A dissimilar embedding does not match a stored entry."""
        cache = InMemoryLLMCache(similarity_threshold=0.9)
        emb1 = _make_embedding(1.0, dim=64)
        emb_different = _make_different_embedding(dim=64)

        await cache.put("key1", _make_response("Cached"), query_embedding=emb1)
        result = await cache.semantic_search(emb_different, threshold=0.9)

        assert result.hit is False

    @pytest.mark.asyncio
    async def test_semantic_match_empty_cache(self):
        """Semantic search on an empty cache is a miss."""
        cache = InMemoryLLMCache()
        result = await cache.semantic_search(_make_embedding(dim=64))
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_ttl_expiry_exact(self):
        """An exact-match entry is gone once ``exact_ttl`` has elapsed."""
        cache = InMemoryLLMCache(exact_ttl=1)  # 1 second TTL
        await cache.put("key1", _make_response())

        # Wait for expiry without blocking the event loop.
        await asyncio.sleep(1.1)
        result = await cache.get("key1")
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_ttl_expiry_semantic(self):
        """A semantic entry is gone once ``semantic_ttl`` has elapsed."""
        cache = InMemoryLLMCache(semantic_ttl=1)
        emb = _make_embedding(dim=64)
        await cache.put("key1", _make_response(), query_embedding=emb)

        # Wait for expiry without blocking the event loop.
        await asyncio.sleep(1.1)
        result = await cache.semantic_search(emb)
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_lru_eviction(self):
        """Inserting beyond ``max_entries`` evicts the oldest entry."""
        cache = InMemoryLLMCache(max_entries=3)

        for i in range(4):
            await cache.put(f"key{i}", _make_response(f"Response {i}"))

        # key0 should be evicted (oldest)
        result = await cache.get("key0")
        assert result.hit is False

        # key1-key3 should still be present
        for i in range(1, 4):
            result = await cache.get(f"key{i}")
            assert result.hit is True

    @pytest.mark.asyncio
    async def test_lru_access_refreshes(self):
        """Reading an entry protects it from the next eviction."""
        cache = InMemoryLLMCache(max_entries=3)

        await cache.put("key0", _make_response("R0"))
        await cache.put("key1", _make_response("R1"))
        await cache.put("key2", _make_response("R2"))

        # Access key0 to move it to most-recently-used
        await cache.get("key0")

        # Adding key3 should evict key1 (now LRU)
        await cache.put("key3", _make_response("R3"))

        result = await cache.get("key1")
        assert result.hit is False

        result = await cache.get("key0")
        assert result.hit is True

    @pytest.mark.asyncio
    async def test_invalidate_all(self):
        """``invalidate()`` without a pattern removes and counts every entry."""
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response())
        await cache.put("key2", _make_response())

        count = await cache.invalidate()
        assert count == 2

        result = await cache.get("key1")
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_invalidate_pattern(self):
        """``invalidate("abc_*")`` removes only the keys matching the pattern."""
        cache = InMemoryLLMCache()
        await cache.put("abc_1", _make_response())
        await cache.put("abc_2", _make_response())
        await cache.put("xyz_1", _make_response())

        count = await cache.invalidate("abc_*")
        assert count == 2

        result = await cache.get("xyz_1")
        assert result.hit is True

    @pytest.mark.asyncio
    async def test_stats(self):
        """Stats report the entry count plus hit and miss counters."""
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response())
        await cache.get("key1")  # hit
        await cache.get("key2")  # miss

        stats = await cache.stats()
        assert stats["total_entries"] == 1
        assert stats["total_hits"] == 1
        assert stats["total_misses"] == 1

    @pytest.mark.asyncio
    async def test_tool_calls_cached(self):
        """Tool calls on a response survive a put/get round-trip."""
        cache = InMemoryLLMCache()
        tool_calls = [
            ToolCall(id="call_1", name="search", arguments={"query": "test"})
        ]
        response = _make_response(tool_calls=tool_calls)

        await cache.put("key1", response)
        result = await cache.get("key1")

        assert result.hit is True
        assert len(result.response.tool_calls) == 1
        assert result.response.tool_calls[0].name == "search"

    @pytest.mark.asyncio
    async def test_put_without_embedding(self):
        """An entry stored without an embedding is exact-only."""
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response(), query_embedding=None)

        # Exact match should still work
        result = await cache.get("key1")
        assert result.hit is True

        # Semantic search should return miss (no embeddings)
        result = await cache.semantic_search(_make_embedding(dim=64))
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_put_updates_existing_key(self):
        """A second put on the same key replaces the stored response."""
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response("Old"))
        await cache.put("key1", _make_response("New"))

        result = await cache.get("key1")
        assert result.hit is True
        assert result.response.content == "New"
# ---------------------------------------------------------------------------
# Serialization Tests
# ---------------------------------------------------------------------------


class TestSerialization:
    """Round-trip tests for the response and entry (de)serialization helpers."""

    def test_serialize_deserialize_response(self):
        """A response with usage and a tool call survives a round-trip."""
        response = _make_response(
            content="Test",
            model="gpt-4o",
            prompt_tokens=5,
            completion_tokens=10,
            tool_calls=[ToolCall(id="c1", name="tool1", arguments={"k": "v"})],
        )
        serialized = _serialize_response(response)
        deserialized = _deserialize_response(serialized)

        assert deserialized.content == "Test"
        assert deserialized.model == "gpt-4o"
        assert deserialized.usage.prompt_tokens == 5
        assert deserialized.usage.completion_tokens == 10
        assert len(deserialized.tool_calls) == 1
        assert deserialized.tool_calls[0].name == "tool1"

    def test_serialize_deserialize_entry(self):
        """A cache entry keeps its response, embedding and metadata."""
        entry = CacheEntry(
            response=_make_response(),
            query_embedding=[0.1, 0.2, 0.3],
            created_at=12345.0,
            hit_count=5,
        )
        serialized = _serialize_entry(entry)
        deserialized = _deserialize_entry(serialized)

        assert deserialized.response.content == "Hello"
        assert deserialized.query_embedding == [0.1, 0.2, 0.3]
        assert deserialized.created_at == 12345.0
        assert deserialized.hit_count == 5

    def test_serialize_response_no_tool_calls(self):
        """A response without tool calls serializes to an empty list."""
        response = _make_response()
        serialized = _serialize_response(response)
        assert serialized["tool_calls"] == []

        deserialized = _deserialize_response(serialized)
        assert deserialized.tool_calls == []
# ---------------------------------------------------------------------------
# RedisLLMCache Tests (with mocked Redis)
# ---------------------------------------------------------------------------


class TestRedisLLMCache:
    """Tests for ``RedisLLMCache`` against an in-process fake Redis client.

    Each test injects the fake by assigning it to ``cache._redis``.
    """

    def _make_mock_redis(self):
        """Create a mock Redis client that simulates basic operations.

        String values live in ``mock._data`` and set members in
        ``mock._sets``. Tests read and pre-populate both directly.

        Returns:
            An ``AsyncMock`` whose get/set/mget/sadd/smembers/scard/delete/
            srem coroutines and ``pipeline()`` operate on those two dicts.
        """
        mock = AsyncMock()
        mock._data = {}
        mock._sets = {}

        async def mock_get(key):
            return mock._data.get(key)

        async def mock_set(key, value, ex=None):
            # Expiry (``ex``) is accepted but not simulated.
            mock._data[key] = value

        async def mock_mget(keys):
            return [mock._data.get(k) for k in keys]

        async def mock_sadd(key, *members):
            mock._sets.setdefault(key, set()).update(members)

        async def mock_smembers(key):
            return mock._sets.get(key, set())

        async def mock_scard(key):
            return len(mock._sets.get(key, set()))

        async def mock_delete(*keys):
            for k in keys:
                mock._data.pop(k, None)

        async def mock_srem(key, *members):
            if key in mock._sets:
                mock._sets[key] -= set(members)

        mock.get = mock_get
        mock.set = mock_set
        mock.mget = mock_mget
        mock.sadd = mock_sadd
        mock.smembers = mock_smembers
        mock.scard = mock_scard
        mock.delete = mock_delete
        mock.srem = mock_srem

        # Pipeline mock — collects commands and executes them on execute()
        class MockPipeline:
            def __init__(self):
                self._commands = []

            def set(self, key, value, ex=None):
                self._commands.append(("set", key, value, ex))

            def sadd(self, key, *members):
                self._commands.append(("sadd", key, members))

            def delete(self, *keys):
                self._commands.append(("delete", keys))

            def srem(self, key, *members):
                self._commands.append(("srem", key, members))

            async def execute(self):
                # Drain the queue first so a second execute() on the same
                # pipeline cannot replay commands that already ran.
                queued, self._commands = self._commands, []
                for cmd in queued:
                    op = cmd[0]
                    if op == "set":
                        _, key, value, _ex = cmd
                        mock._data[key] = value
                    elif op == "sadd":
                        _, key, members = cmd
                        mock._sets.setdefault(key, set()).update(members)
                    elif op == "delete":
                        _, keys = cmd
                        for k in keys:
                            mock._data.pop(k, None)
                    elif op == "srem":
                        _, key, members = cmd
                        if key in mock._sets:
                            mock._sets[key] -= set(members)

        # Hand out a fresh pipeline per call. A single shared instance
        # (``return_value=MockPipeline()``) would accumulate commands across
        # unrelated operations and re-run them on every execute().
        mock.pipeline = MagicMock(
            side_effect=lambda *args, **kwargs: MockPipeline()
        )

        return mock

    @pytest.mark.asyncio
    async def test_exact_match_hit(self):
        """A serialized entry already present in Redis is an exact hit."""
        cache = RedisLLMCache()
        mock_redis = self._make_mock_redis()
        cache._redis = mock_redis

        key = "test_key"
        response = _make_response("Cached")
        entry = CacheEntry(response=response, created_at=time.monotonic(), hit_count=0)
        entry_json = json.dumps(_serialize_entry(entry))

        # Simulate Redis already has the data
        mock_redis._data[f"{cache.KEY_PREFIX}{key}"] = entry_json

        result = await cache.get(key)
        assert result.hit is True
        assert result.match_type == "exact"
        assert result.response.content == "Cached"

    @pytest.mark.asyncio
    async def test_exact_match_miss(self):
        """A key absent from Redis is a miss."""
        cache = RedisLLMCache()
        mock_redis = self._make_mock_redis()
        cache._redis = mock_redis

        result = await cache.get("nonexistent_key")
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_redis_failure_returns_miss(self):
        """An exception raised by the Redis client degrades to a miss."""
        cache = RedisLLMCache()
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=Exception("Connection refused"))
        cache._redis = mock_redis

        result = await cache.get("any_key")
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_put_stores_data(self):
        """``put`` writes the entry, its embedding and the index membership."""
        cache = RedisLLMCache()
        mock_redis = self._make_mock_redis()
        cache._redis = mock_redis

        key = "test_key"
        response = _make_response()
        emb = [0.1, 0.2, 0.3]

        await cache.put(key, response, query_embedding=emb)

        # Verify data was stored
        assert f"{cache.KEY_PREFIX}{key}" in mock_redis._data
        assert f"{cache.EMB_PREFIX}{key}" in mock_redis._data
        assert key in mock_redis._sets.get(cache.INDEX_KEY, set())

    @pytest.mark.asyncio
    async def test_invalidate_all(self):
        """``invalidate()`` reports how many indexed entries it removed."""
        cache = RedisLLMCache()
        mock_redis = self._make_mock_redis()
        cache._redis = mock_redis

        # Pre-populate
        mock_redis._sets[cache.INDEX_KEY] = {"key1", "key2"}
        mock_redis._data[f"{cache.KEY_PREFIX}key1"] = "data1"
        mock_redis._data[f"{cache.KEY_PREFIX}key2"] = "data2"

        count = await cache.invalidate()
        assert count == 2

    @pytest.mark.asyncio
    async def test_stats(self):
        """``total_entries`` reflects the size of the index set."""
        cache = RedisLLMCache()
        mock_redis = self._make_mock_redis()
        cache._redis = mock_redis
        mock_redis._sets[cache.INDEX_KEY] = {"k1", "k2", "k3"}

        stats = await cache.stats()
        assert stats["total_entries"] == 3
# ---------------------------------------------------------------------------
# Factory Tests
# ---------------------------------------------------------------------------


class TestCreateLLMCache:
    """Tests for backend selection and option forwarding in ``create_llm_cache``."""

    def test_memory_backend(self):
        """``backend="memory"`` yields an in-memory cache."""
        cache = create_llm_cache(backend="memory")
        assert isinstance(cache, InMemoryLLMCache)

    def test_auto_backend_fallback(self):
        """When redis package is not available, auto falls back to InMemory."""
        # Force ImportError by making redis.asyncio unimportable: a ``None``
        # entry in sys.modules makes the import statement raise.
        with patch.dict("sys.modules", {"redis.asyncio": None}):
            cache = create_llm_cache(backend="auto")
            assert isinstance(cache, InMemoryLLMCache)

    def test_redis_backend_with_redis_available(self):
        """When redis.asyncio is available, auto/redis returns RedisLLMCache."""
        # The precondition is environmental; skip instead of failing when
        # the optional redis dependency is not installed.
        pytest.importorskip("redis.asyncio")
        cache = create_llm_cache(backend="redis")
        assert isinstance(cache, RedisLLMCache)

    def test_auto_backend_with_redis_available(self):
        """With redis importable, ``backend="auto"`` selects the Redis cache."""
        # Without redis installed "auto" falls back to the in-memory cache
        # (see test_auto_backend_fallback), so this is only meaningful when
        # the package can be imported.
        pytest.importorskip("redis.asyncio")
        cache = create_llm_cache(backend="auto")
        assert isinstance(cache, RedisLLMCache)

    def test_custom_parameters(self):
        """Keyword options are forwarded to the in-memory cache unchanged."""
        cache = create_llm_cache(
            backend="memory",
            max_entries=500,
            exact_ttl=7200,
            semantic_ttl=172800,
            similarity_threshold=0.95,
        )
        assert isinstance(cache, InMemoryLLMCache)
        assert cache._max_entries == 500
        assert cache._exact_ttl == 7200
        assert cache._semantic_ttl == 172800
        assert cache._similarity_threshold == 0.95