"""Unit tests for LLM Cache Core (U1).

Tests cover:
- CacheKey generation (deterministic, component isolation)
- InMemoryLLMCache (exact match, semantic match, TTL, LRU, stats)
- RedisLLMCache (core behaviours against a mocked Redis client)
- Factory function (backend selection, fallback)
"""

import asyncio
import json
import time
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from agentkit.llm.cache import (
    CacheEntry,
    CacheResult,
    InMemoryLLMCache,
    RedisLLMCache,
    create_llm_cache,
    _serialize_response,
    _deserialize_response,
    _serialize_entry,
    _deserialize_entry,
)
from agentkit.llm.cache_key import generate_cache_key
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


def _make_response(
    content: str = "Hello",
    model: str = "gpt-4o",
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
    tool_calls: list[ToolCall] | None = None,
) -> LLMResponse:
    """Build an ``LLMResponse`` with the given payload and a fixed 100 ms latency."""
    return LLMResponse(
        content=content,
        model=model,
        usage=TokenUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
        ),
        tool_calls=tool_calls or [],
        latency_ms=100.0,
    )


def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]:
    """Return a two-message chat: a fixed system prompt plus one user turn."""
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_content},
    ]


def _normalize(vec: list[float]) -> list[float]:
    """Scale *vec* to unit length; a zero vector is returned unchanged."""
    magnitude = sum(x**2 for x in vec) ** 0.5
    return [x / magnitude for x in vec] if magnitude > 0 else vec


def _make_embedding(base_val: float = 1.0, dim: int = 128) -> list[float]:
    """Return ``[base_val] * dim`` normalized to unit length.

    All components are equal, so every positive ``base_val`` yields the same
    direction. ``base_val == 0`` returns the (non-unit) zero vector.
    """
    return _normalize([base_val] * dim)


def _make_similar_embedding(base: list[float], noise: float = 0.01) -> list[float]:
    """Return a unit vector close to, but not identical with, *base*.

    Components alternately receive ``+noise`` and ``-noise``. Adding the same
    offset to every component would leave a uniform vector (such as the one
    from ``_make_embedding``) pointing in exactly the same direction after
    normalization, turning the "similar" vector into an exact copy.
    """
    perturbed = [x + (noise if i % 2 == 0 else -noise) for i, x in enumerate(base)]
    return _normalize(perturbed)


def _make_different_embedding(dim: int = 128) -> list[float]:
    """Return a unit vector with ~0 cosine similarity to ``_make_embedding()``.

    ``_make_embedding(1.0)`` is an all-positive uniform vector. Negating the
    first half of a uniform vector makes the dot product exactly zero for even
    ``dim`` (orthogonal) and ``1/dim`` for odd ``dim``.
    """
    half = dim // 2
    return _normalize([-1.0] * half + [1.0] * (dim - half))


# ---------------------------------------------------------------------------
# CacheKey Tests
# ---------------------------------------------------------------------------


class TestCacheKey:
    def test_deterministic(self):
        """Same inputs produce same key."""
        msgs = _make_messages()
        key1 = generate_cache_key("gpt-4o", msgs, 0.0)
        key2 = generate_cache_key("gpt-4o", msgs, 0.0)
        assert key1 == key2
        assert len(key1) == 64  # SHA-256 hex

    def test_different_model(self):
        """Different model produces different key."""
        msgs = _make_messages()
        key1 = generate_cache_key("gpt-4o", msgs, 0.0)
        key2 = generate_cache_key("gpt-3.5-turbo", msgs, 0.0)
        assert key1 != key2

    def test_different_temperature(self):
        """Different temperature produces different key."""
        msgs = _make_messages()
        key1 = generate_cache_key("gpt-4o", msgs, 0.0)
        key2 = generate_cache_key("gpt-4o", msgs, 0.7)
        assert key1 != key2

    def test_different_messages(self):
        """Different messages produce different key."""
        key1 = generate_cache_key("gpt-4o", _make_messages("Hello"), 0.0)
        key2 = generate_cache_key("gpt-4o", _make_messages("World"), 0.0)
        assert key1 != key2

    def test_different_tools(self):
        """Different tools produce different key."""
        msgs = _make_messages()
        tools1 = [{"type": "function", "function": {"name": "f1"}}]
        tools2 = [{"type": "function", "function": {"name": "f2"}}]
        key1 = generate_cache_key("gpt-4o", msgs, 0.0, tools=tools1)
        key2 = generate_cache_key("gpt-4o", msgs, 0.0, tools=tools2)
        assert key1 != key2

    def test_none_tools_same_as_no_tools(self):
        """None tools and no tools produce same key."""
        msgs = _make_messages()
        key1 = generate_cache_key("gpt-4o", msgs, 0.0, tools=None)
        key2 = generate_cache_key("gpt-4o", msgs, 0.0)
        assert key1 == key2

    def test_system_prompt_extracted_from_messages(self):
        """System prompt is extracted from messages[0] with role=system."""
        msgs = [
            {"role": "system", "content": "Be concise"},
            {"role": "user", "content": "Hello"},
        ]
        key1 = generate_cache_key("gpt-4o", msgs, 0.0)

        msgs2 = [
            {"role": "system", "content": "Be verbose"},
            {"role": "user", "content": "Hello"},
        ]
        key2 = generate_cache_key("gpt-4o", msgs2, 0.0)
        assert key1 != key2

    def test_max_tokens_affects_key(self):
        """Different max_tokens produce different key."""
        msgs = _make_messages()
        key1 = generate_cache_key("gpt-4o", msgs, 0.0, max_tokens=2000)
        key2 = generate_cache_key("gpt-4o", msgs, 0.0, max_tokens=4000)
        assert key1 != key2


# ---------------------------------------------------------------------------
# InMemoryLLMCache Tests
# ---------------------------------------------------------------------------


class TestInMemoryLLMCache:
    @pytest.mark.asyncio
    async def test_exact_match_hit(self):
        cache = InMemoryLLMCache()
        key = "test_key_1"
        response = _make_response("Cached answer")

        await cache.put(key, response)
        result = await cache.get(key)

        assert result.hit is True
        assert result.match_type == "exact"
        assert result.response.content == "Cached answer"

    @pytest.mark.asyncio
    async def test_exact_match_miss(self):
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response())

        result = await cache.get("key2")
        assert result.hit is False
        assert result.response is None

    @pytest.mark.asyncio
    async def test_semantic_match_hit(self):
        cache = InMemoryLLMCache(similarity_threshold=0.9)
        emb1 = _make_embedding(1.0, dim=64)
        emb_similar = _make_similar_embedding(emb1, noise=0.001)
        # Guard: the query must genuinely differ from the stored vector,
        # otherwise this test only exercises an exact-vector match.
        assert emb_similar != emb1

        await cache.put("key1", _make_response("Cached"), query_embedding=emb1)
        result = await cache.semantic_search(emb_similar, threshold=0.9)

        assert result.hit is True
        assert result.match_type == "semantic"
        assert result.response.content == "Cached"

    @pytest.mark.asyncio
    async def test_semantic_match_miss(self):
        cache = InMemoryLLMCache(similarity_threshold=0.9)
        emb1 = _make_embedding(1.0, dim=64)
        emb_different = _make_different_embedding(dim=64)

        await cache.put("key1", _make_response("Cached"), query_embedding=emb1)
        result = await cache.semantic_search(emb_different, threshold=0.9)

        assert result.hit is False

    @pytest.mark.asyncio
    async def test_semantic_match_empty_cache(self):
        cache = InMemoryLLMCache()
        result = await cache.semantic_search(_make_embedding(dim=64))
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_ttl_expiry_exact(self):
        cache = InMemoryLLMCache(exact_ttl=1)  # 1 second TTL
        await cache.put("key1", _make_response())

        # Wait for expiry without blocking the event loop.
        await asyncio.sleep(1.1)

        result = await cache.get("key1")
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_ttl_expiry_semantic(self):
        cache = InMemoryLLMCache(semantic_ttl=1)
        emb = _make_embedding(dim=64)
        await cache.put("key1", _make_response(), query_embedding=emb)

        # Wait for expiry without blocking the event loop.
        await asyncio.sleep(1.1)

        result = await cache.semantic_search(emb)
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_lru_eviction(self):
        cache = InMemoryLLMCache(max_entries=3)

        for i in range(4):
            await cache.put(f"key{i}", _make_response(f"Response {i}"))

        # key0 should be evicted (oldest)
        result = await cache.get("key0")
        assert result.hit is False

        # key1-key3 should still be present
        for i in range(1, 4):
            result = await cache.get(f"key{i}")
            assert result.hit is True

    @pytest.mark.asyncio
    async def test_lru_access_refreshes(self):
        cache = InMemoryLLMCache(max_entries=3)

        await cache.put("key0", _make_response("R0"))
        await cache.put("key1", _make_response("R1"))
        await cache.put("key2", _make_response("R2"))

        # Access key0 to move it to most-recently-used
        await cache.get("key0")

        # Adding key3 should evict key1 (now LRU)
        await cache.put("key3", _make_response("R3"))

        result = await cache.get("key1")
        assert result.hit is False

        result = await cache.get("key0")
        assert result.hit is True

    @pytest.mark.asyncio
    async def test_invalidate_all(self):
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response())
        await cache.put("key2", _make_response())

        count = await cache.invalidate()
        assert count == 2

        result = await cache.get("key1")
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_invalidate_pattern(self):
        cache = InMemoryLLMCache()
        await cache.put("abc_1", _make_response())
        await cache.put("abc_2", _make_response())
        await cache.put("xyz_1", _make_response())

        count = await cache.invalidate("abc_*")
        assert count == 2

        result = await cache.get("xyz_1")
        assert result.hit is True

    @pytest.mark.asyncio
    async def test_stats(self):
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response())
        await cache.get("key1")  # hit
        await cache.get("key2")  # miss

        stats = await cache.stats()
        assert stats["total_entries"] == 1
        assert stats["total_hits"] == 1
        assert stats["total_misses"] == 1

    @pytest.mark.asyncio
    async def test_tool_calls_cached(self):
        cache = InMemoryLLMCache()
        tool_calls = [
            ToolCall(id="call_1", name="search", arguments={"query": "test"})
        ]
        response = _make_response(tool_calls=tool_calls)

        await cache.put("key1", response)
        result = await cache.get("key1")

        assert result.hit is True
        assert len(result.response.tool_calls) == 1
        assert result.response.tool_calls[0].name == "search"

    @pytest.mark.asyncio
    async def test_put_without_embedding(self):
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response(), query_embedding=None)

        # Exact match should still work
        result = await cache.get("key1")
        assert result.hit is True

        # Semantic search should return miss (no embeddings)
        result = await cache.semantic_search(_make_embedding(dim=64))
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_put_updates_existing_key(self):
        cache = InMemoryLLMCache()
        await cache.put("key1", _make_response("Old"))
        await cache.put("key1", _make_response("New"))

        result = await cache.get("key1")
        assert result.hit is True
        assert result.response.content == "New"
# ---------------------------------------------------------------------------
# Serialization Tests
# ---------------------------------------------------------------------------


class TestSerialization:
    def test_serialize_deserialize_response(self):
        response = _make_response(
            content="Test",
            model="gpt-4o",
            prompt_tokens=5,
            completion_tokens=10,
            tool_calls=[ToolCall(id="c1", name="tool1", arguments={"k": "v"})],
        )
        serialized = _serialize_response(response)
        deserialized = _deserialize_response(serialized)

        assert deserialized.content == "Test"
        assert deserialized.model == "gpt-4o"
        assert deserialized.usage.prompt_tokens == 5
        assert deserialized.usage.completion_tokens == 10
        assert len(deserialized.tool_calls) == 1
        assert deserialized.tool_calls[0].name == "tool1"

    def test_serialize_deserialize_entry(self):
        entry = CacheEntry(
            response=_make_response(),
            query_embedding=[0.1, 0.2, 0.3],
            created_at=12345.0,
            hit_count=5,
        )
        serialized = _serialize_entry(entry)
        deserialized = _deserialize_entry(serialized)

        assert deserialized.response.content == "Hello"
        assert deserialized.query_embedding == [0.1, 0.2, 0.3]
        assert deserialized.created_at == 12345.0
        assert deserialized.hit_count == 5

    def test_serialize_response_no_tool_calls(self):
        response = _make_response()
        serialized = _serialize_response(response)
        assert serialized["tool_calls"] == []

        deserialized = _deserialize_response(serialized)
        assert deserialized.tool_calls == []


# ---------------------------------------------------------------------------
# RedisLLMCache Tests (with mocked Redis)
# ---------------------------------------------------------------------------


class TestRedisLLMCache:
    def _make_mock_redis(self):
        """Build an in-memory stand-in for a ``redis.asyncio`` client.

        String values live in ``mock._data`` and sets in ``mock._sets`` so
        tests can seed and inspect state directly. Return values follow Redis
        semantics: ``SET`` -> True, ``SADD``/``SREM``/``DEL`` -> number of
        affected members/keys. TTLs passed via ``ex`` are accepted but ignored.
        """
        mock = AsyncMock()
        mock._data = {}
        mock._sets = {}

        # State transitions shared by the direct commands and the pipeline,
        # so both code paths behave identically.
        def _set(key, value):
            mock._data[key] = value
            return True

        def _sadd(key, members):
            target = mock._sets.setdefault(key, set())
            added = set(members) - target
            target.update(added)
            return len(added)

        def _srem(key, members):
            target = mock._sets.get(key)
            if target is None:
                return 0
            removed = target & set(members)
            target -= removed
            return len(removed)

        def _delete(keys):
            # Real DEL removes a key regardless of its type (string or set).
            removed = 0
            for k in keys:
                if mock._data.pop(k, None) is not None:
                    removed += 1
                if mock._sets.pop(k, None) is not None:
                    removed += 1
            return removed

        async def mock_get(key):
            return mock._data.get(key)

        async def mock_set(key, value, ex=None):
            return _set(key, value)

        async def mock_mget(keys, *args):
            # redis-py accepts either one iterable of keys or keys as varargs.
            names = [keys] if isinstance(keys, (str, bytes)) else list(keys)
            names.extend(args)
            return [mock._data.get(k) for k in names]

        async def mock_sadd(key, *members):
            return _sadd(key, members)

        async def mock_smembers(key):
            # Return a copy, as a real client does; handing out the live set
            # would let a later SREM mutate a result the caller already holds.
            return set(mock._sets.get(key, set()))

        async def mock_scard(key):
            return len(mock._sets.get(key, set()))

        async def mock_delete(*keys):
            return _delete(keys)

        async def mock_srem(key, *members):
            return _srem(key, members)

        mock.get = mock_get
        mock.set = mock_set
        mock.mget = mock_mget
        mock.sadd = mock_sadd
        mock.smembers = mock_smembers
        mock.scard = mock_scard
        mock.delete = mock_delete
        mock.srem = mock_srem

        class MockPipeline:
            """Queue commands and apply them, in order, on ``execute()``.

            Command methods return the pipeline itself, as redis-py does, so
            calls can be chained.
            """

            def __init__(self):
                self._commands = []

            def set(self, key, value, ex=None):
                self._commands.append(lambda: _set(key, value))
                return self

            def sadd(self, key, *members):
                self._commands.append(lambda: _sadd(key, members))
                return self

            def delete(self, *keys):
                self._commands.append(lambda: _delete(keys))
                return self

            def srem(self, key, *members):
                self._commands.append(lambda: _srem(key, members))
                return self

            async def execute(self):
                # Drain the queue so a reused pipeline never replays commands.
                queued, self._commands = self._commands, []
                return [command() for command in queued]

            async def __aenter__(self):
                return self

            async def __aexit__(self, *exc_info):
                self._commands = []
                return False

        # A fresh pipeline per call, like ``redis.asyncio.Redis.pipeline()``.
        # A single shared instance (return_value=...) would accumulate
        # commands and re-run them on every later execute().
        mock.pipeline = MagicMock(side_effect=lambda *args, **kwargs: MockPipeline())

        return mock

    @pytest.mark.asyncio
    async def test_exact_match_hit(self):
        cache = RedisLLMCache()
        mock_redis = self._make_mock_redis()
        cache._redis = mock_redis

        key = "test_key"
        response = _make_response("Cached")
        entry = CacheEntry(response=response, created_at=time.monotonic(), hit_count=0)
        entry_json = json.dumps(_serialize_entry(entry))

        # Simulate Redis already has the data
        mock_redis._data[f"{cache.KEY_PREFIX}{key}"] = entry_json

        result = await cache.get(key)
        assert result.hit is True
        assert result.match_type == "exact"
        assert result.response.content == "Cached"

    @pytest.mark.asyncio
    async def test_exact_match_miss(self):
        cache = RedisLLMCache()
        mock_redis = self._make_mock_redis()
        cache._redis = mock_redis

        result = await cache.get("nonexistent_key")
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_redis_failure_returns_miss(self):
        cache = RedisLLMCache()
        mock_redis = AsyncMock()
        mock_redis.get = AsyncMock(side_effect=Exception("Connection refused"))
        cache._redis = mock_redis

        result = await cache.get("any_key")
        assert result.hit is False

    @pytest.mark.asyncio
    async def test_put_stores_data(self):
        cache = RedisLLMCache()
        mock_redis = self._make_mock_redis()
        cache._redis = mock_redis

        key = "test_key"
        response = _make_response()
        emb = [0.1, 0.2, 0.3]

        await cache.put(key, response, query_embedding=emb)

        # Verify data was stored
        assert f"{cache.KEY_PREFIX}{key}" in mock_redis._data
        assert f"{cache.EMB_PREFIX}{key}" in mock_redis._data
        assert key in mock_redis._sets.get(cache.INDEX_KEY, set())

    @pytest.mark.asyncio
    async def test_put_then_get_round_trip(self):
        """A response written with put() is returned by a following get()."""
        cache = RedisLLMCache()
        cache._redis = self._make_mock_redis()

        await cache.put("rt_key", _make_response("Round trip"))
        result = await cache.get("rt_key")

        assert result.hit is True
        assert result.response.content == "Round trip"

    @pytest.mark.asyncio
    async def test_invalidate_all(self):
        cache = RedisLLMCache()
        mock_redis = self._make_mock_redis()
        cache._redis = mock_redis

        # Pre-populate
        mock_redis._sets[cache.INDEX_KEY] = {"key1", "key2"}
        mock_redis._data[f"{cache.KEY_PREFIX}key1"] = "data1"
        mock_redis._data[f"{cache.KEY_PREFIX}key2"] = "data2"

        count = await cache.invalidate()
        assert count == 2

    @pytest.mark.asyncio
    async def test_stats(self):
        cache = RedisLLMCache()
        mock_redis = self._make_mock_redis()
        cache._redis = mock_redis

        mock_redis._sets[cache.INDEX_KEY] = {"k1", "k2", "k3"}

        stats = await cache.stats()
        assert stats["total_entries"] == 3


# ---------------------------------------------------------------------------
# Factory Tests
# ---------------------------------------------------------------------------


class TestCreateLLMCache:
    def test_memory_backend(self):
        cache = create_llm_cache(backend="memory")
        assert isinstance(cache, InMemoryLLMCache)

    def test_auto_backend_fallback(self):
        """When redis package is not available, auto falls back to InMemory."""
        # Force ImportError by making redis.asyncio unimportable
        with patch.dict("sys.modules", {"redis.asyncio": None}):
            cache = create_llm_cache(backend="auto")
            assert isinstance(cache, InMemoryLLMCache)

    def test_redis_backend_with_redis_available(self):
        """When redis.asyncio is importable, backend="redis" returns RedisLLMCache."""
        # The premise of this test is an installed redis package; skip rather
        # than fail in environments without it.
        pytest.importorskip("redis.asyncio")
        cache = create_llm_cache(backend="redis")
        assert isinstance(cache, RedisLLMCache)

    def test_auto_backend_with_redis_available(self):
        """When redis.asyncio is importable, backend="auto" returns RedisLLMCache."""
        pytest.importorskip("redis.asyncio")
        cache = create_llm_cache(backend="auto")
        assert isinstance(cache, RedisLLMCache)

    def test_custom_parameters(self):
        cache = create_llm_cache(
            backend="memory",
            max_entries=500,
            exact_ttl=7200,
            semantic_ttl=172800,
            similarity_threshold=0.95,
        )
        assert isinstance(cache, InMemoryLLMCache)
        assert cache._max_entries == 500
        assert cache._exact_ttl == 7200
        assert cache._semantic_ttl == 172800
        assert cache._similarity_threshold == 0.95