# fischer-agentkit/tests/unit/test_gateway_cache.py
#
# 163 lines
# 6.0 KiB
# Python

"""Integration tests for LLM Cache integration into LLMGateway (U2)."""
import pytest
from agentkit.llm.cache import InMemoryLLMCache
from agentkit.llm.config import CacheConfig, LLMConfig
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage
class MockProvider(LLMProvider):
    """Test double for an LLM provider that counts ``chat`` invocations.

    Every call returns the same canned content, so a test can read
    ``call_count`` to see how many requests actually reached the provider.
    """

    def __init__(self, response_content: str = "Mock response"):
        # Canned text returned by every chat() call.
        self._response_content = response_content
        # Number of chat() calls received so far.
        self.call_count = 0

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Count the call and return the canned content for ``request.model``."""
        self.call_count += 1
        fixed_usage = TokenUsage(prompt_tokens=10, completion_tokens=20)
        return LLMResponse(
            content=self._response_content,
            model=request.model,
            usage=fixed_usage,
        )
def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]:
    """Return a fresh two-message conversation: a fixed system prompt plus one user turn.

    Args:
        user_content: Text of the user message; defaults to ``"Hello"``.

    Returns:
        A new list holding the system message followed by the user message,
        each as a ``{"role": ..., "content": ...}`` dict.
    """
    return [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": user_content},
    ]
class TestCacheDisabled:
    """Gateway behaviour when no cache configuration is supplied."""

    @pytest.mark.asyncio
    async def test_no_cache_by_default(self):
        """Without cache config, two identical requests both reach the provider."""
        backend = MockProvider()
        gw = LLMGateway()
        gw.register_provider("test", backend)

        conversation = _make_messages()
        # Send the very same request twice; neither may be served from a cache.
        for _ in range(2):
            await gw.chat(conversation, "test/model")

        assert backend.call_count == 2
class TestCacheEnabled:
    """Behaviour of a gateway whose in-memory cache is enabled."""

    @staticmethod
    def _cached_gateway() -> tuple[LLMGateway, MockProvider]:
        """Build a gateway with the memory cache enabled and a mock provider named "test".

        Returns:
            The gateway and the registered provider, so tests can inspect
            ``provider.call_count``.
        """
        config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
        gateway = LLMGateway(config=config)
        provider = MockProvider()
        gateway.register_provider("test", provider)
        return gateway, provider

    @pytest.mark.asyncio
    async def test_first_request_is_miss(self):
        """First request is a cache miss — provider is called."""
        gateway, provider = self._cached_gateway()
        msgs = _make_messages()

        response = await gateway.chat(msgs, "test/model", temperature=0.0)

        assert provider.call_count == 1
        assert response.content == "Mock response"

    @pytest.mark.asyncio
    async def test_second_request_is_hit(self):
        """Second identical request is a cache hit — provider NOT called."""
        gateway, provider = self._cached_gateway()
        msgs = _make_messages()

        await gateway.chat(msgs, "test/model", temperature=0.0)
        response = await gateway.chat(msgs, "test/model", temperature=0.0)

        assert provider.call_count == 1  # Not called again
        assert response.content == "Mock response"

    @pytest.mark.asyncio
    async def test_cache_hit_usage_has_zero_cost(self):
        """Cache hit still records a usage entry and adds no cost."""
        gateway, provider = self._cached_gateway()
        msgs = _make_messages()

        await gateway.chat(msgs, "test/model", agent_name="agent1", temperature=0.0)
        await gateway.chat(msgs, "test/model", agent_name="agent1", temperature=0.0)

        # Guard: the second call must really have been served from the cache,
        # otherwise the usage assertions below say nothing about cache hits.
        assert provider.call_count == 1

        usage = gateway.get_usage(agent_name="agent1")
        # No cost config is set up here, so the miss costs 0 as well; the total
        # can only stay at 0 if the cache hit added no cost either.
        assert usage.total_cost == 0.0
        # Both the miss and the hit are recorded.
        assert len(usage.records) == 2

    @pytest.mark.asyncio
    async def test_different_messages_are_miss(self):
        """Different messages produce cache misses."""
        gateway, provider = self._cached_gateway()

        await gateway.chat(_make_messages("Hello"), "test/model", temperature=0.0)
        await gateway.chat(_make_messages("World"), "test/model", temperature=0.0)

        assert provider.call_count == 2
class TestCacheConfig:
    """Loading cache settings via ``LLMConfig.from_dict`` and gateway cache creation."""

    def test_config_from_dict(self):
        """CacheConfig can be loaded from dict."""
        config = LLMConfig.from_dict({
            "cache": {
                "enabled": True,
                "backend": "memory",
                "exact_ttl": 7200,
            }
        })
        assert config.cache is not None
        assert config.cache.enabled is True
        assert config.cache.backend == "memory"
        assert config.cache.exact_ttl == 7200

    def test_config_from_dict_no_cache(self):
        """No cache section in config → cache is None."""
        config = LLMConfig.from_dict({})
        assert config.cache is None

    def test_config_from_dict_embedding(self):
        """Embedding config is loaded correctly."""
        config = LLMConfig.from_dict({
            "cache": {
                "enabled": True,
                "embedding": {
                    "provider": "xinference",
                    "model": "bge-m3",
                    "base_url": "http://localhost:9997/v1",
                },
            }
        })
        # Fail with a clear assertion (not an AttributeError) if the cache
        # section was not parsed at all.
        assert config.cache is not None
        # The nested "embedding" keys are exposed as flat embedding_* attributes.
        assert config.cache.embedding_provider == "xinference"
        assert config.cache.embedding_model == "bge-m3"
        assert config.cache.embedding_base_url == "http://localhost:9997/v1"

    def test_gateway_creates_cache_when_enabled(self):
        """Gateway creates cache instance when cache.enabled=True."""
        config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
        gateway = LLMGateway(config=config)
        # These tests inspect the gateway's private _cache attribute directly.
        assert gateway._cache is not None
        assert isinstance(gateway._cache, InMemoryLLMCache)

    def test_gateway_no_cache_when_disabled(self):
        """Gateway has no cache when cache is disabled."""
        config = LLMConfig(cache=CacheConfig(enabled=False))
        gateway = LLMGateway(config=config)
        assert gateway._cache is None

    def test_gateway_no_cache_when_no_config(self):
        """Gateway has no cache when cache config is absent."""
        gateway = LLMGateway()
        assert gateway._cache is None