fischer-agentkit/tests/unit/test_llm_cache.py

605 lines
20 KiB
Python

"""Unit tests for LLM Cache Core (U1).
Tests cover:
- CacheKey generation (deterministic, component isolation)
- InMemoryLLMCache (exact match, semantic match, TTL, LRU, invalidation, stats)
- Serialization helpers (response and entry round-trips)
- RedisLLMCache with mocked Redis (exact match, failure fallback, put, invalidate, stats)
- Factory function (backend selection, fallback)
"""
import asyncio
import json
import time
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from agentkit.llm.cache import (
    CacheEntry,
    CacheResult,
    InMemoryLLMCache,
    RedisLLMCache,
    create_llm_cache,
    _serialize_response,
    _deserialize_response,
    _serialize_entry,
    _deserialize_entry,
)
from agentkit.llm.cache_key import generate_cache_key
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _make_response(
    content: str = "Hello",
    model: str = "gpt-4o",
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
    tool_calls: list[ToolCall] | None = None,
) -> LLMResponse:
    """Build an ``LLMResponse`` for tests with ``latency_ms`` fixed at 100.0.

    ``tool_calls`` falls back to an empty list when it is omitted or falsy.
    """
    token_usage = TokenUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    calls = tool_calls if tool_calls else []
    return LLMResponse(
        content=content,
        model=model,
        usage=token_usage,
        tool_calls=calls,
        latency_ms=100.0,
    )
def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]:
return [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": user_content},
]
def _make_embedding(base_val: float = 1.0, dim: int = 128) -> list[float]:
"""Create a unit vector for similarity testing."""
vec = [base_val] * dim
magnitude = sum(x**2 for x in vec) ** 0.5
return [x / magnitude for x in vec] if magnitude > 0 else vec
def _make_similar_embedding(base: list[float], noise: float = 0.01) -> list[float]:
"""Create a vector similar to base with small noise."""
vec = [x + noise for x in base]
magnitude = sum(x**2 for x in vec) ** 0.5
return [x / magnitude for x in vec] if magnitude > 0 else vec
def _make_different_embedding(dim: int = 128) -> list[float]:
"""Create a vector with very low cosine similarity to _make_embedding()."""
# _make_embedding(1.0) is all-positive unit vector.
# Negate first half to create near-orthogonal vector.
half = dim // 2
vec = [-1.0] * half + [1.0] * (dim - half)
magnitude = sum(x**2 for x in vec) ** 0.5
return [x / magnitude for x in vec] if magnitude > 0 else vec
# ---------------------------------------------------------------------------
# CacheKey Tests
# ---------------------------------------------------------------------------
class TestCacheKey:
    """Checks that ``generate_cache_key`` is stable and sensitive to each input."""

    def test_deterministic(self):
        """Calling twice with identical inputs yields the identical key."""
        messages = _make_messages()
        first = generate_cache_key("gpt-4o", messages, 0.0)
        second = generate_cache_key("gpt-4o", messages, 0.0)
        assert first == second
        assert len(first) == 64  # SHA-256 hex

    def test_different_model(self):
        """Changing only the model changes the key."""
        messages = _make_messages()
        gpt4_key = generate_cache_key("gpt-4o", messages, 0.0)
        gpt35_key = generate_cache_key("gpt-3.5-turbo", messages, 0.0)
        assert gpt4_key != gpt35_key

    def test_different_temperature(self):
        """Changing only the temperature changes the key."""
        messages = _make_messages()
        cold_key = generate_cache_key("gpt-4o", messages, 0.0)
        warm_key = generate_cache_key("gpt-4o", messages, 0.7)
        assert cold_key != warm_key

    def test_different_messages(self):
        """Changing only the user message changes the key."""
        hello_key = generate_cache_key("gpt-4o", _make_messages("Hello"), 0.0)
        world_key = generate_cache_key("gpt-4o", _make_messages("World"), 0.0)
        assert hello_key != world_key

    def test_different_tools(self):
        """Changing only the tool definitions changes the key."""
        messages = _make_messages()
        first_tools = [{"type": "function", "function": {"name": "f1"}}]
        second_tools = [{"type": "function", "function": {"name": "f2"}}]
        first = generate_cache_key("gpt-4o", messages, 0.0, tools=first_tools)
        second = generate_cache_key("gpt-4o", messages, 0.0, tools=second_tools)
        assert first != second

    def test_none_tools_same_as_no_tools(self):
        """Passing ``tools=None`` is equivalent to omitting ``tools``."""
        messages = _make_messages()
        explicit_none = generate_cache_key("gpt-4o", messages, 0.0, tools=None)
        omitted = generate_cache_key("gpt-4o", messages, 0.0)
        assert explicit_none == omitted

    def test_system_prompt_extracted_from_messages(self):
        """Two conversations differing only in the system message get different keys."""

        def key_for(system_prompt):
            conversation = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": "Hello"},
            ]
            return generate_cache_key("gpt-4o", conversation, 0.0)

        concise_key = key_for("Be concise")
        verbose_key = key_for("Be verbose")
        assert concise_key != verbose_key

    def test_max_tokens_affects_key(self):
        """Changing only ``max_tokens`` changes the key."""
        messages = _make_messages()
        small = generate_cache_key("gpt-4o", messages, 0.0, max_tokens=2000)
        large = generate_cache_key("gpt-4o", messages, 0.0, max_tokens=4000)
        assert small != large
# ---------------------------------------------------------------------------
# InMemoryLLMCache Tests
# ---------------------------------------------------------------------------
class TestInMemoryLLMCache:
    """Exercise ``InMemoryLLMCache``: exact and semantic lookup, TTL, LRU, invalidation, stats."""

    @pytest.mark.asyncio
    async def test_exact_match_hit(self):
        """A stored key comes back as an ``exact`` hit carrying its response."""
        cache = InMemoryLLMCache()
        key = "test_key_1"
        response = _make_response("Cached answer")
        await cache.put(key, response)
        result = await cache.get(key)
        assert result.hit is True
        assert result.match_type == "exact"
        assert result.response.content == "Cached answer"

    @pytest.mark.asyncio
    async def test_exact_match_miss(self):
        """Looking up a key that was never stored is a miss with no response."""
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response())
        result = await cache.get("key2")
        assert result.hit is False
        assert result.response is None

    @pytest.mark.asyncio
    async def test_semantic_match_hit(self):
        """An embedding close to a stored one is a ``semantic`` hit."""
        cache = InMemoryLLMCache(similarity_threshold=0.9)
        emb1 = _make_embedding(1.0, dim=64)
        emb_similar = _make_similar_embedding(emb1, noise=0.001)
        await cache.put("key1", _make_response("Cached"), query_embedding=emb1)
        result = await cache.semantic_search(emb_similar, threshold=0.9)
        assert result.hit is True
        assert result.match_type == "semantic"
        assert result.response.content == "Cached"

    @pytest.mark.asyncio
    async def test_semantic_match_miss(self):
        """An embedding far from every stored one is a miss."""
        cache = InMemoryLLMCache(similarity_threshold=0.9)
        emb1 = _make_embedding(1.0, dim=64)
        emb_different = _make_different_embedding(dim=64)
        await cache.put("key1", _make_response("Cached"), query_embedding=emb1)
        result = await cache.semantic_search(emb_different, threshold=0.9)
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_semantic_match_empty_cache(self):
        """Semantic search on an empty cache is a miss."""
        cache = InMemoryLLMCache()
        result = await cache.semantic_search(_make_embedding(dim=64))
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_ttl_expiry_exact(self):
        """An entry is no longer returned by ``get`` once ``exact_ttl`` has elapsed."""
        cache = InMemoryLLMCache(exact_ttl=1)  # 1 second TTL
        await cache.put("key1", _make_response())
        # Wait for expiry. asyncio.sleep yields to the event loop, whereas
        # time.sleep would block it for the whole wait.
        await asyncio.sleep(1.1)
        result = await cache.get("key1")
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_ttl_expiry_semantic(self):
        """An entry is no longer matched semantically once ``semantic_ttl`` has elapsed."""
        cache = InMemoryLLMCache(semantic_ttl=1)
        emb = _make_embedding(dim=64)
        await cache.put("key1", _make_response(), query_embedding=emb)
        # Non-blocking wait past the TTL (see test_ttl_expiry_exact).
        await asyncio.sleep(1.1)
        result = await cache.semantic_search(emb)
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_lru_eviction(self):
        """Inserting beyond ``max_entries`` evicts the oldest entry."""
        cache = InMemoryLLMCache(max_entries=3)
        for i in range(4):
            await cache.put(f"key{i}", _make_response(f"Response {i}"))
        # key0 should be evicted (oldest)
        result = await cache.get("key0")
        assert result.hit is False
        # key1-key3 should still be present
        for i in range(1, 4):
            result = await cache.get(f"key{i}")
            assert result.hit is True

    @pytest.mark.asyncio
    async def test_lru_access_refreshes(self):
        """Reading an entry protects it from the next eviction."""
        cache = InMemoryLLMCache(max_entries=3)
        await cache.put("key0", _make_response("R0"))
        await cache.put("key1", _make_response("R1"))
        await cache.put("key2", _make_response("R2"))
        # Access key0 to move it to most-recently-used
        await cache.get("key0")
        # Adding key3 should evict key1 (now LRU)
        await cache.put("key3", _make_response("R3"))
        result = await cache.get("key1")
        assert result.hit is False
        result = await cache.get("key0")
        assert result.hit is True

    @pytest.mark.asyncio
    async def test_invalidate_all(self):
        """``invalidate()`` with no pattern removes everything and returns the count."""
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response())
        await cache.put("key2", _make_response())
        count = await cache.invalidate()
        assert count == 2
        result = await cache.get("key1")
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_invalidate_pattern(self):
        """``invalidate("abc_*")`` removes only the matching keys."""
        cache = InMemoryLLMCache()
        await cache.put("abc_1", _make_response())
        await cache.put("abc_2", _make_response())
        await cache.put("xyz_1", _make_response())
        count = await cache.invalidate("abc_*")
        assert count == 2
        result = await cache.get("xyz_1")
        assert result.hit is True

    @pytest.mark.asyncio
    async def test_stats(self):
        """``stats()`` reports entry, hit and miss counts."""
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response())
        await cache.get("key1")  # hit
        await cache.get("key2")  # miss
        stats = await cache.stats()
        assert stats["total_entries"] == 1
        assert stats["total_hits"] == 1
        assert stats["total_misses"] == 1

    @pytest.mark.asyncio
    async def test_tool_calls_cached(self):
        """Tool calls on a response survive a put/get round-trip."""
        cache = InMemoryLLMCache()
        tool_calls = [
            ToolCall(id="call_1", name="search", arguments={"query": "test"})
        ]
        response = _make_response(tool_calls=tool_calls)
        await cache.put("key1", response)
        result = await cache.get("key1")
        assert result.hit is True
        assert len(result.response.tool_calls) == 1
        assert result.response.tool_calls[0].name == "search"

    @pytest.mark.asyncio
    async def test_put_without_embedding(self):
        """An entry stored without an embedding is found by key but not semantically."""
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response(), query_embedding=None)
        # Exact match should still work
        result = await cache.get("key1")
        assert result.hit is True
        # Semantic search should return miss (no embeddings)
        result = await cache.semantic_search(_make_embedding(dim=64))
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_put_updates_existing_key(self):
        """A second ``put`` under the same key replaces the stored response."""
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response("Old"))
        await cache.put("key1", _make_response("New"))
        result = await cache.get("key1")
        assert result.hit is True
        assert result.response.content == "New"
# ---------------------------------------------------------------------------
# Serialization Tests
# ---------------------------------------------------------------------------
class TestSerialization:
    """Round-trip checks for the response and entry (de)serialization helpers."""

    def test_serialize_deserialize_response(self):
        """Content, model, token usage and tool calls survive a round-trip."""
        original = _make_response(
            content="Test",
            model="gpt-4o",
            prompt_tokens=5,
            completion_tokens=10,
            tool_calls=[ToolCall(id="c1", name="tool1", arguments={"k": "v"})],
        )
        restored = _deserialize_response(_serialize_response(original))
        assert restored.content == "Test"
        assert restored.model == "gpt-4o"
        assert restored.usage.prompt_tokens == 5
        assert restored.usage.completion_tokens == 10
        assert len(restored.tool_calls) == 1
        assert restored.tool_calls[0].name == "tool1"

    def test_serialize_deserialize_entry(self):
        """Response, embedding, creation time and hit count survive a round-trip."""
        original = CacheEntry(
            response=_make_response(),
            query_embedding=[0.1, 0.2, 0.3],
            created_at=12345.0,
            hit_count=5,
        )
        payload = _serialize_entry(original)
        restored = _deserialize_entry(payload)
        assert restored.response.content == "Hello"
        assert restored.query_embedding == [0.1, 0.2, 0.3]
        assert restored.created_at == 12345.0
        assert restored.hit_count == 5

    def test_serialize_response_no_tool_calls(self):
        """A response without tool calls serializes and restores an empty list."""
        payload = _serialize_response(_make_response())
        assert payload["tool_calls"] == []
        restored = _deserialize_response(payload)
        assert restored.tool_calls == []
# ---------------------------------------------------------------------------
# RedisLLMCache Tests (with mocked Redis)
# ---------------------------------------------------------------------------
class TestRedisLLMCache:
    """Exercise ``RedisLLMCache`` against an in-process fake of the Redis client."""

    def _make_mock_redis(self):
        """Create a mock Redis client that simulates basic operations.

        String values live in ``mock._data`` and set members in
        ``mock._sets``; tests read and pre-populate both directly.
        """
        mock = AsyncMock()
        mock._data = {}
        mock._sets = {}

        async def mock_get(key):
            return mock._data.get(key)

        async def mock_set(key, value, ex=None):
            # ``ex`` is accepted but ignored: the fake never expires keys.
            mock._data[key] = value

        async def mock_mget(keys):
            return [mock._data.get(k) for k in keys]

        async def mock_sadd(key, *members):
            mock._sets.setdefault(key, set()).update(members)

        async def mock_smembers(key):
            return mock._sets.get(key, set())

        async def mock_scard(key):
            return len(mock._sets.get(key, set()))

        async def mock_delete(*keys):
            for k in keys:
                mock._data.pop(k, None)

        async def mock_srem(key, *members):
            if key in mock._sets:
                mock._sets[key] -= set(members)

        mock.get = mock_get
        mock.set = mock_set
        mock.mget = mock_mget
        mock.sadd = mock_sadd
        mock.smembers = mock_smembers
        mock.scard = mock_scard
        mock.delete = mock_delete
        mock.srem = mock_srem

        # Pipeline mock — collects commands and executes them on execute()
        class MockPipeline:
            def __init__(self):
                self._commands = []

            def set(self, key, value, ex=None):
                self._commands.append(("set", key, value, ex))

            def sadd(self, key, *members):
                self._commands.append(("sadd", key, members))

            def delete(self, *keys):
                self._commands.append(("delete", keys))

            def srem(self, key, *members):
                self._commands.append(("srem", key, members))

            async def execute(self):
                # Drain the queue first so a second execute() on the same
                # pipeline never replays commands that were already applied.
                commands, self._commands = self._commands, []
                for cmd in commands:
                    if cmd[0] == "set":
                        _, key, value, _ex = cmd
                        mock._data[key] = value
                    elif cmd[0] == "sadd":
                        _, key, members = cmd
                        mock._sets.setdefault(key, set()).update(members)
                    elif cmd[0] == "delete":
                        _, keys = cmd
                        for k in keys:
                            mock._data.pop(k, None)
                    elif cmd[0] == "srem":
                        _, key, members = cmd
                        if key in mock._sets:
                            mock._sets[key] -= set(members)

        # side_effect rather than return_value: each pipeline() call must
        # hand out a fresh, empty pipeline instead of one shared instance
        # whose queued commands would leak between operations.
        mock.pipeline = MagicMock(
            side_effect=lambda *args, **kwargs: MockPipeline()
        )
        return mock

    @pytest.mark.asyncio
    async def test_exact_match_hit(self):
        """A serialized entry already present under the prefixed key is an ``exact`` hit."""
        cache = RedisLLMCache()
        mock_redis = self._make_mock_redis()
        cache._redis = mock_redis
        key = "test_key"
        response = _make_response("Cached")
        entry = CacheEntry(response=response, created_at=time.monotonic(), hit_count=0)
        entry_json = json.dumps(_serialize_entry(entry))
        # Simulate Redis already has the data
        mock_redis._data[f"{cache.KEY_PREFIX}{key}"] = entry_json
        result = await cache.get(key)
        assert result.hit is True
        assert result.match_type == "exact"
        assert result.response.content == "Cached"

    @pytest.mark.asyncio
    async def test_exact_match_miss(self):
        """A key absent from the fake store is a miss."""
        cache = RedisLLMCache()
        mock_redis = self._make_mock_redis()
        cache._redis = mock_redis
        result = await cache.get("nonexistent_key")
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_redis_failure_returns_miss(self):
        """An exception raised by the Redis client's ``get`` is reported as a miss."""
        cache = RedisLLMCache()
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=Exception("Connection refused"))
        cache._redis = mock_redis
        result = await cache.get("any_key")
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_put_stores_data(self):
        """``put`` writes the entry, its embedding, and registers the key in the index set."""
        cache = RedisLLMCache()
        mock_redis = self._make_mock_redis()
        cache._redis = mock_redis
        key = "test_key"
        response = _make_response()
        emb = [0.1, 0.2, 0.3]
        await cache.put(key, response, query_embedding=emb)
        # Verify data was stored
        assert f"{cache.KEY_PREFIX}{key}" in mock_redis._data
        assert f"{cache.EMB_PREFIX}{key}" in mock_redis._data
        assert key in mock_redis._sets.get(cache.INDEX_KEY, set())

    @pytest.mark.asyncio
    async def test_invalidate_all(self):
        """``invalidate()`` returns the number of indexed keys it removed."""
        cache = RedisLLMCache()
        mock_redis = self._make_mock_redis()
        cache._redis = mock_redis
        # Pre-populate
        mock_redis._sets[cache.INDEX_KEY] = {"key1", "key2"}
        mock_redis._data[f"{cache.KEY_PREFIX}key1"] = "data1"
        mock_redis._data[f"{cache.KEY_PREFIX}key2"] = "data2"
        count = await cache.invalidate()
        assert count == 2

    @pytest.mark.asyncio
    async def test_stats(self):
        """``stats()`` reports the size of the index set as ``total_entries``."""
        cache = RedisLLMCache()
        mock_redis = self._make_mock_redis()
        cache._redis = mock_redis
        mock_redis._sets[cache.INDEX_KEY] = {"k1", "k2", "k3"}
        stats = await cache.stats()
        assert stats["total_entries"] == 3
# ---------------------------------------------------------------------------
# Factory Tests
# ---------------------------------------------------------------------------
class TestCreateLLMCache:
    """Check which cache class ``create_llm_cache`` returns for each ``backend`` value."""

    def test_memory_backend(self):
        """``backend="memory"`` returns an ``InMemoryLLMCache``."""
        cache = create_llm_cache(backend="memory")
        assert isinstance(cache, InMemoryLLMCache)

    def test_auto_backend_fallback(self):
        """When redis package is not available, auto falls back to InMemory."""
        # Force ImportError by making redis.asyncio unimportable: a None
        # entry in sys.modules makes the import statement fail.
        with patch.dict("sys.modules", {"redis.asyncio": None}):
            cache = create_llm_cache(backend="auto")
            assert isinstance(cache, InMemoryLLMCache)

    def test_redis_backend_with_redis_available(self):
        """When redis.asyncio is available, ``backend="redis"`` returns RedisLLMCache."""
        # Skip rather than fail on environments without the redis package.
        pytest.importorskip("redis.asyncio")
        cache = create_llm_cache(backend="redis")
        assert isinstance(cache, RedisLLMCache)

    def test_auto_backend_with_redis_available(self):
        """When redis.asyncio is available, ``backend="auto"`` returns RedisLLMCache."""
        # Skip rather than fail on environments without the redis package.
        pytest.importorskip("redis.asyncio")
        cache = create_llm_cache(backend="auto")
        assert isinstance(cache, RedisLLMCache)

    def test_custom_parameters(self):
        """Keyword arguments are stored on the created in-memory cache."""
        cache = create_llm_cache(
            backend="memory",
            max_entries=500,
            exact_ttl=7200,
            semantic_ttl=172800,
            similarity_threshold=0.95,
        )
        assert isinstance(cache, InMemoryLLMCache)
        assert cache._max_entries == 500
        assert cache._exact_ttl == 7200
        assert cache._semantic_ttl == 172800
        assert cache._similarity_threshold == 0.95