163 lines
6.0 KiB
Python
163 lines
6.0 KiB
Python
"""Integration tests for LLM Cache integration into LLMGateway (U2)."""
|
|
|
|
import pytest
|
|
|
|
from agentkit.llm.cache import InMemoryLLMCache
|
|
from agentkit.llm.config import CacheConfig, LLMConfig
|
|
from agentkit.llm.gateway import LLMGateway
|
|
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage
|
|
|
|
|
|
class MockProvider(LLMProvider):
    """Stand-in LLM provider that counts how often ``chat`` is awaited."""

    def __init__(self, response_content: str = "Mock response"):
        # Canned reply text returned by every chat() call.
        self._response_content = response_content
        # Number of chat() invocations observed so far.
        self.call_count = 0

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Record the invocation and return the canned reply.

        The response echoes ``request.model`` and always reports the same
        fixed token usage (10 prompt tokens, 20 completion tokens).
        """
        self.call_count += 1
        usage = TokenUsage(prompt_tokens=10, completion_tokens=20)
        return LLMResponse(
            content=self._response_content,
            model=request.model,
            usage=usage,
        )
|
|
|
|
|
|
def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]:
|
|
return [
|
|
{"role": "system", "content": "You are a helpful assistant."},
|
|
{"role": "user", "content": user_content},
|
|
]
|
|
|
|
|
|
class TestCacheDisabled:
    """Gateway behaviour when no cache configuration is supplied."""

    @pytest.mark.asyncio
    async def test_no_cache_by_default(self):
        """A default gateway forwards every request, even identical ones, to the provider."""
        gw = LLMGateway()
        mock = MockProvider()
        gw.register_provider("test", mock)

        conversation = _make_messages()
        # Send the very same conversation twice; each send must reach the provider.
        for _ in range(2):
            await gw.chat(conversation, "test/model")

        assert mock.call_count == 2
|
|
|
|
|
|
class TestCacheEnabled:
    """Gateway behaviour with the in-memory cache switched on."""

    @staticmethod
    def _gateway_with_provider():
        """Return a cache-enabled gateway plus the mock provider registered as ``test``."""
        cache_cfg = CacheConfig(enabled=True, backend="memory")
        gw = LLMGateway(config=LLMConfig(cache=cache_cfg))
        mock = MockProvider()
        gw.register_provider("test", mock)
        return gw, mock

    @pytest.mark.asyncio
    async def test_first_request_is_miss(self):
        """The first request cannot be served from cache, so the provider is called."""
        gw, mock = self._gateway_with_provider()

        conversation = _make_messages()
        reply = await gw.chat(conversation, "test/model", temperature=0.0)

        assert mock.call_count == 1
        assert reply.content == "Mock response"

    @pytest.mark.asyncio
    async def test_second_request_is_hit(self):
        """Repeating an identical request is answered without calling the provider again."""
        gw, mock = self._gateway_with_provider()

        conversation = _make_messages()
        await gw.chat(conversation, "test/model", temperature=0.0)
        reply = await gw.chat(conversation, "test/model", temperature=0.0)

        # Still one provider call: the repeat did not reach the provider.
        assert mock.call_count == 1
        assert reply.content == "Mock response"

    @pytest.mark.asyncio
    async def test_cache_hit_usage_has_zero_cost(self):
        """Both the miss and the hit are recorded as usage, and total cost stays 0."""
        gw, mock = self._gateway_with_provider()

        conversation = _make_messages()
        for _ in range(2):
            await gw.chat(
                conversation, "test/model", agent_name="agent1", temperature=0.0
            )

        agent_usage = gw.get_usage(agent_name="agent1")
        # No cost configuration is supplied in this test, so the total is 0.0
        # for both calls; each call still yields its own usage record.
        assert agent_usage.total_cost == 0.0
        assert len(agent_usage.records) == 2

    @pytest.mark.asyncio
    async def test_different_messages_are_miss(self):
        """Requests whose user message differs each reach the provider."""
        gw, mock = self._gateway_with_provider()

        for user_text in ("Hello", "World"):
            await gw.chat(_make_messages(user_text), "test/model", temperature=0.0)

        assert mock.call_count == 2
|
|
|
|
|
|
class TestCacheConfig:
    """Parsing of cache settings and the gateway's cache construction."""

    def test_config_from_dict(self):
        """CacheConfig can be loaded from dict."""
        config = LLMConfig.from_dict({
            "cache": {
                "enabled": True,
                "backend": "memory",
                "exact_ttl": 7200,
            }
        })
        assert config.cache is not None
        assert config.cache.enabled is True
        assert config.cache.backend == "memory"
        assert config.cache.exact_ttl == 7200

    def test_config_from_dict_no_cache(self):
        """No cache section in config → cache is None."""
        config = LLMConfig.from_dict({})
        assert config.cache is None

    def test_config_from_dict_embedding(self):
        """Embedding config is loaded correctly."""
        config = LLMConfig.from_dict({
            "cache": {
                "enabled": True,
                "embedding": {
                    "provider": "xinference",
                    "model": "bge-m3",
                    "base_url": "http://localhost:9997/v1",
                },
            }
        })
        # Guard first (as test_config_from_dict does) so a missing cache section
        # fails with a clear assertion instead of an AttributeError on None.
        assert config.cache is not None
        assert config.cache.embedding_provider == "xinference"
        assert config.cache.embedding_model == "bge-m3"
        assert config.cache.embedding_base_url == "http://localhost:9997/v1"

    def test_gateway_creates_cache_when_enabled(self):
        """Gateway creates cache instance when cache.enabled=True."""
        config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
        gateway = LLMGateway(config=config)
        # NOTE: inspects the private ``_cache`` attribute of the gateway.
        assert gateway._cache is not None
        assert isinstance(gateway._cache, InMemoryLLMCache)

    def test_gateway_no_cache_when_disabled(self):
        """Gateway has no cache when cache is disabled."""
        config = LLMConfig(cache=CacheConfig(enabled=False))
        gateway = LLMGateway(config=config)
        assert gateway._cache is None

    def test_gateway_no_cache_when_no_config(self):
        """Gateway has no cache when cache config is absent."""
        gateway = LLMGateway()
        assert gateway._cache is None
|