214 lines
8.1 KiB
Python
214 lines
8.1 KiB
Python
"""Integration tests for LLM Cache integration into LLMGateway (U2/U17).

U17 update: the gateway now uses ``LitellmCacheManager`` (LiteLLM's built-in
cache). The old manual caching logic based on ``InMemoryLLMCache`` has been
removed; cache reads and writes are handled internally by LiteLLM.

The tests use ``CacheAwareMockProvider`` to simulate LiteLLM's caching behavior
(it checks the cache_key and, on a hit, returns a response with
``cache_hit=True``).
"""

import pytest
|
||
|
||
from agentkit.llm.config import CacheConfig, LLMConfig
|
||
from agentkit.llm.gateway import LLMGateway
|
||
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage
|
||
|
||
|
||
class MockProvider(LLMProvider):
    """LLM provider stub with no notion of caching.

    Every ``chat`` call is treated as a real provider call: it bumps
    ``call_count`` and answers with the same canned content and fixed usage.
    """

    def __init__(self, response_content: str = "Mock response"):
        self._response_content = response_content
        # Number of times ``chat`` has been invoked.
        self.call_count = 0

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Record the call and reply with the canned content for ``request.model``."""
        self.call_count = self.call_count + 1
        usage = TokenUsage(prompt_tokens=10, completion_tokens=20)
        return LLMResponse(
            content=self._response_content,
            model=request.model,
            usage=usage,
        )


class CacheAwareMockProvider(LLMProvider):
    """Mock provider that simulates LiteLLM's cache behavior.

    Reads ``request._cache`` (a dict of cache controls) for ``cache_key`` and
    keeps an internal ``cache_key -> LLMResponse`` dict.

    - On cache hit: returns the cached content with ``cache_hit=True``;
      ``call_count`` is NOT incremented.
    - On cache miss: generates a fresh response, caches it (unless told not
      to), and returns it with ``cache_hit=False``.

    Cache controls honored:

    - ``"no-cache": True`` -- skip the lookup and do not store the result.
      NOTE(review): LiteLLM itself may still store on ``no-cache``; this mock
      keeps its historical "neither read nor write" meaning.
    - ``"no-store": True`` -- lookups still apply, but a fresh response is
      not written to the cache.
    """

    def __init__(self, response_content: str = "Mock response"):
        # Counts only real calls (cache misses).
        self.call_count = 0
        self._response_content = response_content
        self._cache: dict[str, LLMResponse] = {}

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Serve from the internal cache on a hit, otherwise make a "real" call.

        Args:
            request: The LLM request. Cache controls are read from its
                optional ``_cache`` attribute (presumably attached by the
                gateway's cache manager -- verify against the gateway).

        Returns:
            The cached response (``cache_hit=True``) or a fresh one
            (``cache_hit=False``).
        """
        # A missing or None ``_cache`` attribute means "no caching".
        cache_params = getattr(request, "_cache", None) or {}
        cache_key = cache_params.get("cache_key")
        no_cache = cache_params.get("no-cache", False)
        no_store = cache_params.get("no-store", False)

        # Cache hit: return the cached response (call_count unchanged).
        if cache_key and not no_cache and cache_key in self._cache:
            cached = self._cache[cache_key]
            return LLMResponse(
                content=cached.content,
                model=cached.model,
                usage=cached.usage,
                cache_hit=True,
            )

        # Cache miss: perform a real call.
        self.call_count += 1
        response = LLMResponse(
            content=self._response_content,
            model=request.model,
            usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
            cache_hit=False,
        )
        # Store only when a key exists and neither control forbids writing.
        if cache_key and not (no_cache or no_store):
            self._cache[cache_key] = response
        return response


def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]:
    """Build a fresh two-message chat: a fixed system prompt, then *user_content*."""
    system_turn = {"role": "system", "content": "You are a helpful assistant."}
    user_turn = {"role": "user", "content": user_content}
    return [system_turn, user_turn]


class TestCacheDisabled:
    """Gateway behavior when no cache is configured."""

    @pytest.mark.asyncio
    async def test_no_cache_by_default(self):
        """With the default config there is no cache, so every request reaches the provider."""
        provider = MockProvider()
        gateway = LLMGateway()
        gateway.register_provider("test", provider)

        messages = _make_messages()
        for _ in range(2):
            await gateway.chat(messages, "test/model")

        assert provider.call_count == 2


class TestCacheEnabled:
    """Gateway behavior with the in-memory cache backend enabled."""

    @staticmethod
    def _make_gateway() -> tuple[LLMGateway, CacheAwareMockProvider]:
        """Return a cache-enabled gateway plus the cache-aware provider registered as "test"."""
        config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
        gateway = LLMGateway(config=config)
        provider = CacheAwareMockProvider()
        gateway.register_provider("test", provider)
        return gateway, provider

    @pytest.mark.asyncio
    async def test_first_request_is_miss(self):
        """First request is a cache miss — provider is called."""
        gateway, provider = self._make_gateway()

        msgs = _make_messages()
        response = await gateway.chat(msgs, "test/model", temperature=0.0)

        assert provider.call_count == 1
        assert response.content == "Mock response"
        assert response.cache_hit is False

    @pytest.mark.asyncio
    async def test_second_request_is_hit(self):
        """Second identical request is a cache hit — provider NOT called again."""
        gateway, provider = self._make_gateway()

        msgs = _make_messages()
        await gateway.chat(msgs, "test/model", temperature=0.0)
        response = await gateway.chat(msgs, "test/model", temperature=0.0)

        assert provider.call_count == 1  # Not called again (cache hit)
        assert response.content == "Mock response"
        assert response.cache_hit is True

    @pytest.mark.asyncio
    async def test_cache_hit_usage_has_zero_cost(self):
        """Cache hit records usage with cost=0."""
        gateway, provider = self._make_gateway()

        msgs = _make_messages()
        first = await gateway.chat(msgs, "test/model", agent_name="agent1", temperature=0.0)
        second = await gateway.chat(msgs, "test/model", agent_name="agent1", temperature=0.0)

        # Confirm the second request really was served from cache; without
        # this the test would pass even if caching were broken.
        assert first.cache_hit is False
        assert second.cache_hit is True
        assert provider.call_count == 1

        usage = gateway.get_usage(agent_name="agent1")
        # Both the miss and the hit are recorded as usage entries.
        assert len(usage.records) == 2
        # No cost config is set, so the total is 0 for both records.
        # NOTE(review): this does not isolate the cache-hit record's cost;
        # configure model pricing to assert the hit specifically costs 0.
        assert usage.total_cost == 0.0

    @pytest.mark.asyncio
    async def test_different_messages_are_miss(self):
        """Different messages produce cache misses."""
        gateway, provider = self._make_gateway()

        hello = await gateway.chat(_make_messages("Hello"), "test/model", temperature=0.0)
        world = await gateway.chat(_make_messages("World"), "test/model", temperature=0.0)

        assert provider.call_count == 2
        assert hello.cache_hit is False
        assert world.cache_hit is False


class TestCacheConfig:
    """Loading of ``CacheConfig`` and how the gateway reacts to it."""

    def test_config_from_dict(self):
        """CacheConfig can be loaded from dict."""
        config = LLMConfig.from_dict(
            {
                "cache": {
                    "enabled": True,
                    "backend": "memory",
                    "exact_ttl": 7200,
                }
            }
        )
        assert config.cache is not None
        assert config.cache.enabled is True
        assert config.cache.backend == "memory"
        assert config.cache.exact_ttl == 7200

    def test_config_from_dict_no_cache(self):
        """No cache section in config → cache is None."""
        config = LLMConfig.from_dict({})
        assert config.cache is None

    def test_config_from_dict_embedding(self):
        """Embedding config is loaded correctly."""
        config = LLMConfig.from_dict(
            {
                "cache": {
                    "enabled": True,
                    "embedding": {
                        "provider": "xinference",
                        "model": "bge-m3",
                        "base_url": "http://localhost:9997/v1",
                    },
                }
            }
        )
        # Guard first so a missing cache section fails with a clear assertion
        # instead of an AttributeError on ``None``.
        assert config.cache is not None
        assert config.cache.enabled is True
        # The nested "embedding" mapping is flattened into embedding_* fields.
        assert config.cache.embedding_provider == "xinference"
        assert config.cache.embedding_model == "bge-m3"
        assert config.cache.embedding_base_url == "http://localhost:9997/v1"

    def test_gateway_creates_cache_when_enabled(self):
        """Gateway creates cache_manager instance when cache.enabled=True."""
        config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
        gateway = LLMGateway(config=config)
        assert gateway._cache_manager is not None

    def test_gateway_no_cache_when_disabled(self):
        """Gateway has no cache_manager when cache is disabled."""
        config = LLMConfig(cache=CacheConfig(enabled=False))
        gateway = LLMGateway(config=config)
        assert gateway._cache_manager is None

    def test_gateway_no_cache_when_no_config(self):
        """Gateway has no cache_manager when cache config is absent."""
        gateway = LLMGateway()
        assert gateway._cache_manager is None