fischer-agentkit/tests/unit/test_gateway_cache.py

214 lines
8.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Integration tests for LLM Cache integration into LLMGateway (U2/U17).
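U17 update: the gateway now uses ``LitellmCacheManager`` (LiteLLM's built-in cache).
The old ``InMemoryLLMCache`` manual caching logic has been removed; cache reads and
writes are handled inside LiteLLM. The tests use ``CacheAwareMockProvider`` to mimic
LiteLLM's cache behaviour (it checks ``cache_key`` and, on a hit, returns a response
with ``cache_hit=True``).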
"""
import pytest
from agentkit.llm.config import CacheConfig, LLMConfig
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage
class MockProvider(LLMProvider):
    """Cache-unaware stand-in provider that counts every ``chat`` invocation."""

    def __init__(self, response_content: str = "Mock response") -> None:
        # Canned text placed in the ``content`` of every response.
        self._response_content = response_content
        # Number of times ``chat`` has been awaited so far.
        self.call_count = 0

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Record the call and reply with the canned content for ``request.model``."""
        self.call_count = self.call_count + 1
        reply = LLMResponse(
            content=self._response_content,
            model=request.model,
            usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
        )
        return reply
class CacheAwareMockProvider(LLMProvider):
    """Stand-in provider that imitates LiteLLM-style response caching.

    Cache parameters are read from ``request._cache`` and responses are kept
    in an internal dict keyed by ``cache_key``.

    * Hit: a copy of the stored response is returned with ``cache_hit=True``;
      ``call_count`` is left untouched.
    * Miss: a new response is built with ``cache_hit=False``, stored when a
      key is present and caching is not bypassed, and ``call_count`` grows by one.
    """

    def __init__(self, response_content: str = "Mock response") -> None:
        self._response_content = response_content
        self._cache: dict[str, LLMResponse] = {}
        # Counts real calls only, i.e. cache misses.
        self.call_count = 0

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Answer ``request`` from the internal cache when possible, else generate."""
        params = getattr(request, "_cache", None) or {}
        key = params.get("cache_key")
        bypass = params.get("no-cache", False)

        # Stored values are always LLMResponse objects, so ``None`` means "absent".
        stored = self._cache.get(key) if key else None

        if stored is not None and not bypass:
            # Cache hit: return a flagged copy without counting a real call.
            return LLMResponse(
                content=stored.content,
                model=stored.model,
                usage=stored.usage,
                cache_hit=True,
            )

        # Cache miss: this is a real provider call.
        self.call_count += 1
        fresh = LLMResponse(
            content=self._response_content,
            model=request.model,
            usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
            cache_hit=False,
        )
        if key and not bypass:
            self._cache[key] = fresh
        return fresh
def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]:
return [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": user_content},
]
class TestCacheDisabled:
    """Gateway behaviour when no cache configuration is supplied."""

    @pytest.mark.asyncio
    async def test_no_cache_by_default(self):
        """Without a cache config, two identical requests both reach the provider."""
        provider = MockProvider()
        gateway = LLMGateway()
        gateway.register_provider("test", provider)

        messages = _make_messages()
        for _ in range(2):
            await gateway.chat(messages, "test/model")

        assert provider.call_count == 2
class TestCacheEnabled:
    """Gateway behaviour with the ``memory`` cache backend enabled."""

    @staticmethod
    def _cached_gateway():
        """Return ``(gateway, provider)`` with a cache-aware provider registered as ``test``."""
        cache_on = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
        gateway = LLMGateway(config=cache_on)
        provider = CacheAwareMockProvider()
        gateway.register_provider("test", provider)
        return gateway, provider

    @pytest.mark.asyncio
    async def test_first_request_is_miss(self):
        """The first request misses the cache, so the provider is called."""
        gateway, provider = self._cached_gateway()
        messages = _make_messages()

        response = await gateway.chat(messages, "test/model", temperature=0.0)

        assert provider.call_count == 1
        assert response.content == "Mock response"
        assert response.cache_hit is False

    @pytest.mark.asyncio
    async def test_second_request_is_hit(self):
        """A repeated identical request is served from cache without a new provider call."""
        gateway, provider = self._cached_gateway()
        messages = _make_messages()

        await gateway.chat(messages, "test/model", temperature=0.0)
        response = await gateway.chat(messages, "test/model", temperature=0.0)

        # Still one real call: the second request was a cache hit.
        assert provider.call_count == 1
        assert response.content == "Mock response"
        assert response.cache_hit is True

    @pytest.mark.asyncio
    async def test_cache_hit_usage_has_zero_cost(self):
        """A miss followed by a hit yields two usage records with zero total cost."""
        gateway, _provider = self._cached_gateway()
        messages = _make_messages()

        for _ in range(2):
            await gateway.chat(
                messages, "test/model", agent_name="agent1", temperature=0.0
            )

        usage = gateway.get_usage(agent_name="agent1")
        # No cost configuration is supplied here, so the miss and the hit
        # are both expected to contribute zero cost.
        assert usage.total_cost == 0.0
        assert len(usage.records) == 2

    @pytest.mark.asyncio
    async def test_different_messages_are_miss(self):
        """Requests with different user content each miss the cache."""
        gateway, provider = self._cached_gateway()

        for text in ("Hello", "World"):
            await gateway.chat(_make_messages(text), "test/model", temperature=0.0)

        assert provider.call_count == 2
class TestCacheConfig:
    """Checks for ``CacheConfig`` loading and the gateway's cache-manager wiring."""

    def test_config_from_dict(self):
        """A ``cache`` section in the dict populates the cache config fields."""
        config = LLMConfig.from_dict(
            {
                "cache": {
                    "enabled": True,
                    "backend": "memory",
                    "exact_ttl": 7200,
                }
            }
        )
        assert config.cache is not None
        assert config.cache.enabled is True
        assert config.cache.backend == "memory"
        assert config.cache.exact_ttl == 7200

    def test_config_from_dict_no_cache(self):
        """With no ``cache`` section in the dict, ``config.cache`` is ``None``."""
        config = LLMConfig.from_dict({})
        assert config.cache is None

    def test_config_from_dict_embedding(self):
        """A nested ``embedding`` section is exposed as ``embedding_*`` fields."""
        config = LLMConfig.from_dict(
            {
                "cache": {
                    "enabled": True,
                    "embedding": {
                        "provider": "xinference",
                        "model": "bge-m3",
                        "base_url": "http://localhost:9997/v1",
                    },
                }
            }
        )
        # Guard first so a missing cache section fails as a clear assertion
        # rather than an AttributeError on ``None``.
        assert config.cache is not None
        assert config.cache.embedding_provider == "xinference"
        assert config.cache.embedding_model == "bge-m3"
        assert config.cache.embedding_base_url == "http://localhost:9997/v1"

    def test_gateway_creates_cache_when_enabled(self):
        """The gateway has a ``_cache_manager`` when ``cache.enabled`` is true."""
        config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
        gateway = LLMGateway(config=config)
        assert gateway._cache_manager is not None

    def test_gateway_no_cache_when_disabled(self):
        """The gateway has no ``_cache_manager`` when the cache is disabled."""
        config = LLMConfig(cache=CacheConfig(enabled=False))
        gateway = LLMGateway(config=config)
        assert gateway._cache_manager is None

    def test_gateway_no_cache_when_no_config(self):
        """The gateway has no ``_cache_manager`` when built without a config."""
        gateway = LLMGateway()
        assert gateway._cache_manager is None