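"""U17 — unit tests for the LiteLLM semantic cache manager.

Scenarios covered (the 4 required by the plan + security requirement e +
backward compatibility + fallback):

1. Semantically similar queries hit the cache (simulated with an identical
   repeated request; ``litellm.acompletion`` is mocked to emulate caching)
2. A different system prompt does not hit the cache
3. The cache hit rate can be measured
4. Threshold tuning takes effect (similarity_threshold is passed to RedisSemanticCache)
5. User A never receives User B's cached response (security requirement e)
6. kb_acl_hash isolation — different ACL hashes produce different keys
7. kb_caching_disabled disables caching (security requirement c)
8. cache_params_for_hit / no_cache — return the correct dict
9. record_cache_result — records hits/misses in the stats counters
   (exercised by the scenario 3 tests)
10. LitellmCacheConfig.from_cache_config — converts correctly, similarity_threshold=0.87
11. LitellmCacheManager.enable/disable — litellm.cache is set/cleared correctly
12. generate_cache_key backward compatibility — with user_id=None and
    kb_acl_hash=None the key is the same as with the previous call shape
13. RedisSemanticCache fallback — the auto backend falls back when redisvl is missing

External LiteLLM entry points (``litellm.acompletion`` and the cache backend
classes) are mocked with ``unittest.mock.patch``. An autouse fixture resets the
global ``litellm.cache`` after every test so state does not leak between tests.
"""
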
from __future__ import annotations
|
||
|
||
from types import SimpleNamespace
|
||
from typing import Any
|
||
from unittest.mock import MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager
|
||
from agentkit.llm.cache_key import generate_cache_key
|
||
from agentkit.llm.config import CacheConfig
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Test helpers
# ---------------------------------------------------------------------------


def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]:
    """Build a two-message chat transcript: a fixed system prompt plus one user turn.

    Args:
        user_content: Text of the user message (``"Hello"`` by default).

    Returns:
        A new list of OpenAI-style ``{"role", "content"}`` dicts on every call.
    """
    system_turn = {"role": "system", "content": "You are a helpful assistant."}
    user_turn = {"role": "user", "content": user_content}
    return [system_turn, user_turn]


def _make_litellm_response(
    content: str = "Hello!",
    model: str = "gpt-4o-mini",
    prompt_tokens: int = 10,
    completion_tokens: int = 5,
    cache_key: str | None = None,
) -> SimpleNamespace:
    """Build a fake LiteLLM response shaped like an OpenAI ChatCompletion.

    Args:
        content: Text placed on ``choices[0].message.content``.
        model: Value of the response's ``model`` attribute.
        prompt_tokens: ``usage.prompt_tokens``.
        completion_tokens: ``usage.completion_tokens``.
        cache_key: When given, stored as ``_hidden_params["cache_key"]`` so the
            response looks like a cache hit; when ``None``, ``_hidden_params``
            is an empty dict.

    Returns:
        A ``SimpleNamespace`` with ``choices``, ``usage``, ``model`` and
        ``_hidden_params`` attributes.
    """
    hidden: dict[str, Any] = {"cache_key": cache_key} if cache_key is not None else {}

    only_choice = SimpleNamespace(
        message=SimpleNamespace(content=content, tool_calls=None),
    )
    token_usage = SimpleNamespace(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    return SimpleNamespace(
        choices=[only_choice],
        usage=token_usage,
        model=model,
        _hidden_params=hidden,
    )


@pytest.fixture(autouse=True)
def _cleanup_litellm_cache():
    """Isolate the process-global ``litellm.cache`` between tests.

    Takes a snapshot of ``litellm.cache`` before the test runs and puts that
    exact value back during teardown. A cache installed by a test (for example
    via ``LitellmCacheManager.enable()``) therefore never leaks into the next
    test. A value configured outside this module is also not clobbered, which
    forcing ``None`` would do. When nothing was configured beforehand, the
    snapshot is ``None`` and teardown clears ``litellm.cache`` as before.
    """
    import litellm

    # getattr keeps this safe even if the attribute is absent for some reason.
    previous = getattr(litellm, "cache", None)
    yield
    litellm.cache = previous


# ---------------------------------------------------------------------------
# 1. Semantically similar queries hit the cache
# ---------------------------------------------------------------------------


class TestCacheHit:
    """Scenario 1 — repeating an identical request is served from the cache."""

    async def test_second_request_is_cache_hit(self):
        """The second request with the same cache_key returns the cached response (cache_hit=True).

        ``litellm.acompletion`` is replaced by a fake that emulates LiteLLM's own
        caching. On a hit, the fake marks the response with
        ``_hidden_params["cache_key"]``. The assertions expect the provider to
        surface that marker as ``cache_hit``.
        """
        manager = LitellmCacheManager(LitellmCacheConfig(enabled=True, backend="memory"))

        msgs = _make_messages()
        key = manager.build_cache_key("gpt-4o", msgs, 0.0)
        cache_params = LitellmCacheManager.cache_params_for_hit(key)

        # Fake LiteLLM cache: the first call for a key misses and stores the
        # response; later calls with that key return a response tagged as a hit.
        seen_calls: list[dict[str, Any]] = []
        stored: dict[str, SimpleNamespace] = {}

        async def fake_acompletion(**kwargs):
            seen_calls.append(kwargs)
            requested_key = kwargs.get("cache", {}).get("cache_key")
            if requested_key and requested_key in stored:
                # Hit: the response carries the cache_key marker.
                return _make_litellm_response(content="Cached!", cache_key=requested_key)
            # Miss: produce a "fresh" API response and remember it under the key.
            fresh = _make_litellm_response(content="Fresh response")
            if requested_key:
                stored[requested_key] = fresh
            return fresh

        with patch("litellm.acompletion", side_effect=fake_acompletion):
            # Imported while the patch is active, matching the original ordering.
            from agentkit.llm.providers.litellm_provider import LitellmProvider

            provider = LitellmProvider(model_prefix="openai/", api_key="test")

            from agentkit.llm.protocol import LLMRequest

            def same_request():
                return LLMRequest(
                    messages=msgs, model="gpt-4o", temperature=0.0, cache=cache_params
                )

            # First request: miss.
            first = await provider.chat(same_request())
            assert first.cache_hit is False
            assert first.content == "Fresh response"

            # Identical second request: hit.
            second = await provider.chat(same_request())
            assert second.cache_hit is True
            assert second.content == "Cached!"

        # Caching happens inside LiteLLM, so acompletion itself sees both calls.
        assert len(seen_calls) == 2


# ---------------------------------------------------------------------------
# 2. A different system prompt does not hit the cache
# ---------------------------------------------------------------------------


class TestSystemPromptIsolation:
    """Scenario 2 — different system prompts yield different cache keys, so there are no false hits."""

    def test_different_system_prompt_different_key(self):
        manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))

        def conversation(system_prompt):
            # Only the system prompt varies; the user turn is identical.
            return [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": "Hello"},
            ]

        concise_key = manager.build_cache_key("gpt-4o", conversation("Be concise"), 0.0)
        verbose_key = manager.build_cache_key("gpt-4o", conversation("Be verbose"), 0.0)
        assert concise_key != verbose_key


# ---------------------------------------------------------------------------
# 3. The cache hit rate can be measured
# ---------------------------------------------------------------------------


class TestCacheStats:
    """Scenario 3 — LitellmCacheManager.stats() reports the recorded hits and misses."""

    def test_stats_after_hits_and_misses(self):
        manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))

        # Record two hits followed by one miss.
        for was_hit in (True, True, False):
            manager.record_cache_result(was_hit)

        counters = manager.stats()
        assert counters["total_hits"] == 2
        assert counters["total_misses"] == 1


# ---------------------------------------------------------------------------
# 4. Threshold tuning takes effect
# ---------------------------------------------------------------------------


class TestSimilarityThreshold:
    """Scenario 4 — similarity_threshold reaches the RedisSemanticCache constructor."""

    def test_threshold_passed_to_redis_semantic_cache(self):
        """A configured similarity_threshold (0.83 here) is forwarded verbatim to RedisSemanticCache.

        0.83 differs from the 0.87 that ``from_cache_config`` pins. A pass
        therefore shows that the configured value is used rather than that
        pinned value.
        """
        config = LitellmCacheConfig(
            enabled=True,
            backend="redis_semantic",
            similarity_threshold=0.83,
            redis_url="redis://localhost:6379",
        )
        manager = LitellmCacheManager(config)

        mock_cache_instance = MagicMock()
        with patch(
            "litellm.caching.RedisSemanticCache", return_value=mock_cache_instance
        ) as mock_cls:
            instance = manager._create_cache_instance()

        # No embedding_model was configured, so the manager passes its default.
        mock_cls.assert_called_once_with(
            redis_url="redis://localhost:6379",
            similarity_threshold=0.83,
            embedding_model="text-embedding-ada-002",
        )
        assert instance is mock_cache_instance

    def test_default_threshold_is_087(self):
        """from_cache_config pins similarity_threshold to 0.87, ignoring the legacy value (0.92 here)."""
        old_config = CacheConfig(similarity_threshold=0.92)
        litellm_config = LitellmCacheConfig.from_cache_config(old_config)
        assert litellm_config.similarity_threshold == 0.87


# ---------------------------------------------------------------------------
# 5. User A never receives User B's cached response (security requirement e)
# ---------------------------------------------------------------------------


class TestUserIsolation:
    """Security requirement e — a query from User A must not return User B's cached response."""

    def test_different_users_different_keys(self):
        """Different user_id values produce different cache keys."""
        manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        question = _make_messages("What is my salary?")

        key_for = {
            uid: manager.build_cache_key("gpt-4o", question, 0.0, user_id=uid)
            for uid in ("user_a", "user_b")
        }
        assert key_for["user_a"] != key_for["user_b"]

    def test_same_user_same_key(self):
        """Repeated calls with the same user_id produce the same cache key."""
        manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        question = _make_messages("What is my salary?")

        first, second = [
            manager.build_cache_key("gpt-4o", question, 0.0, user_id="user_a")
            for _ in range(2)
        ]
        assert first == second

    def test_no_user_id_same_as_no_user_id(self):
        """With user_id=None, two calls agree on the key (backward compatible)."""
        manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        question = _make_messages()

        first, second = [
            manager.build_cache_key("gpt-4o", question, 0.0, user_id=None)
            for _ in range(2)
        ]
        assert first == second


# ---------------------------------------------------------------------------
# 6. kb_acl_hash isolation
# ---------------------------------------------------------------------------


class TestACLIsolation:
    """Security requirement b — different ACL scopes produce different cache keys."""

    def test_different_acl_hash_different_keys(self):
        manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        prompt = _make_messages("Summarize the document")

        key_by_acl = {
            acl: manager.build_cache_key("gpt-4o", prompt, 0.0, kb_acl_hash=acl)
            for acl in ("acl_v1", "acl_v2")
        }
        assert key_by_acl["acl_v1"] != key_by_acl["acl_v2"]

    def test_same_acl_hash_same_key(self):
        manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        prompt = _make_messages()

        first, second = [
            manager.build_cache_key("gpt-4o", prompt, 0.0, kb_acl_hash="acl_v1")
            for _ in range(2)
        ]
        assert first == second


# ---------------------------------------------------------------------------
# 7. kb_caching_disabled disables caching
# ---------------------------------------------------------------------------


class TestKBCachingDisabled:
    """Security requirement c — caching is turned off when the KB sets caching_disabled=True."""

    @staticmethod
    def _memory_manager():
        # A fresh in-memory manager per test, as each test built its own before.
        return LitellmCacheManager(LitellmCacheConfig(backend="memory"))

    def test_should_cache_returns_false_when_disabled(self):
        assert self._memory_manager().should_cache(kb_caching_disabled=True) is False

    def test_should_cache_returns_true_when_enabled(self):
        assert self._memory_manager().should_cache(kb_caching_disabled=False) is True

    def test_should_cache_default_true(self):
        assert self._memory_manager().should_cache() is True


# ---------------------------------------------------------------------------
# 8. cache_params_for_hit / no_cache
# ---------------------------------------------------------------------------


class TestCacheParams:
    """Scenario 8 — the per-request ``cache`` dicts built by LitellmCacheManager."""

    def test_cache_params_for_hit(self):
        expected = {"cache_key": "my_cache_key"}
        assert LitellmCacheManager.cache_params_for_hit("my_cache_key") == expected

    def test_cache_params_for_no_cache(self):
        expected = {"no-cache": True}
        assert LitellmCacheManager.cache_params_for_no_cache() == expected


# ---------------------------------------------------------------------------
# 10. LitellmCacheConfig.from_cache_config
# ---------------------------------------------------------------------------


class TestFromCacheConfig:
    """Scenario 10 — converting a legacy CacheConfig into a LitellmCacheConfig."""

    def test_basic_conversion(self):
        legacy = CacheConfig(
            enabled=True,
            backend="memory",
            redis_url="redis://myhost:6379",
            semantic_ttl=3600,
            embedding_model="bge-m3",
        )
        converted = LitellmCacheConfig.from_cache_config(legacy)

        assert converted.enabled is True
        assert converted.backend == "memory"
        assert converted.redis_url == "redis://myhost:6379"
        # The legacy semantic_ttl is carried over as ttl.
        assert converted.ttl == 3600
        # Pinned to 0.87, ignoring the legacy 0.92 threshold.
        assert converted.similarity_threshold == 0.87
        # Per-user namespacing is always forced on.
        assert converted.per_user_namespace is True

    def test_unknown_backend_falls_to_auto(self):
        converted = LitellmCacheConfig.from_cache_config(CacheConfig(backend="unknown_backend"))
        assert converted.backend == "auto"

    def test_redis_backend_preserved(self):
        converted = LitellmCacheConfig.from_cache_config(CacheConfig(backend="redis"))
        assert converted.backend == "redis"

    def test_empty_embedding_model_falls_to_default(self):
        converted = LitellmCacheConfig.from_cache_config(CacheConfig(embedding_model=""))
        assert converted.embedding_model == "text-embedding-ada-002"


# ---------------------------------------------------------------------------
# 11. LitellmCacheManager.enable/disable
# ---------------------------------------------------------------------------


class TestEnableDisable:
    """Scenario 11 — enable()/disable() install and remove the global litellm.cache."""

    @staticmethod
    def _memory_manager():
        return LitellmCacheManager(LitellmCacheConfig(backend="memory"))

    def test_enable_sets_litellm_cache(self):
        import litellm

        manager = self._memory_manager()
        manager.enable()

        assert litellm.cache is not None
        assert manager._cache_instance is not None

    def test_disable_clears_litellm_cache(self):
        import litellm

        manager = self._memory_manager()
        manager.enable()
        assert litellm.cache is not None

        # disable() must undo the global installation and drop the instance.
        manager.disable()
        assert litellm.cache is None
        assert manager._cache_instance is None


# ---------------------------------------------------------------------------
# 12. generate_cache_key backward compatibility
# ---------------------------------------------------------------------------


class TestBackwardCompatibility:
    """Scenario 12 — generate_cache_key stays backward compatible.

    Passing ``user_id=None`` / ``kb_acl_hash=None`` must produce the same key
    as omitting them (the older call shape), and the result must be
    deterministic. A non-None value for either argument must change the key.
    """

    def test_none_user_id_same_as_not_passing(self):
        msgs = _make_messages()
        explicit_none = generate_cache_key("gpt-4o", msgs, 0.0, user_id=None, kb_acl_hash=None)
        omitted = generate_cache_key("gpt-4o", msgs, 0.0)
        assert explicit_none == omitted

    def test_backward_compat_deterministic(self):
        msgs = _make_messages()
        first, second = [
            generate_cache_key("gpt-4o", msgs, 0.0, user_id=None, kb_acl_hash=None)
            for _ in range(2)
        ]
        assert first == second

    def test_user_id_changes_key(self):
        msgs = _make_messages()
        baseline = generate_cache_key("gpt-4o", msgs, 0.0)
        scoped = generate_cache_key("gpt-4o", msgs, 0.0, user_id="user_a")
        assert baseline != scoped

    def test_acl_hash_changes_key(self):
        msgs = _make_messages()
        baseline = generate_cache_key("gpt-4o", msgs, 0.0)
        scoped = generate_cache_key("gpt-4o", msgs, 0.0, kb_acl_hash="acl_v1")
        assert baseline != scoped


# ---------------------------------------------------------------------------
# 13. RedisSemanticCache fallback
# ---------------------------------------------------------------------------


class TestRedisSemanticCacheFallback:
    """Scenario 13 — with redisvl missing, the auto backend falls back to RedisCache."""

    # Message of the ImportError raised when redisvl is not installed.
    _REDISVL_MISSING = "No module named 'redisvl'"

    def test_auto_fallback_to_redis_cache_when_redisvl_missing(self):
        """In auto mode, an ImportError from RedisSemanticCache() falls back to RedisCache."""
        import litellm.caching

        manager = LitellmCacheManager(
            LitellmCacheConfig(backend="auto", redis_url="redis://localhost:6379")
        )
        fallback = MagicMock(name="RedisCacheInstance")

        # Constructing RedisSemanticCache raises as if redisvl were not installed.
        with patch.object(
            litellm.caching,
            "RedisSemanticCache",
            side_effect=ImportError(self._REDISVL_MISSING),
        ), patch.object(
            litellm.caching, "RedisCache", return_value=fallback
        ) as redis_cache_cls:
            instance = manager._create_cache_instance()
            redis_cache_cls.assert_called_once_with(redis_url="redis://localhost:6379")

        assert instance is fallback

    def test_redis_semantic_backend_raises_when_redisvl_missing(self):
        """An explicit redis_semantic backend with redisvl missing propagates the ImportError."""
        import litellm.caching

        manager = LitellmCacheManager(LitellmCacheConfig(backend="redis_semantic"))

        with patch.object(
            litellm.caching,
            "RedisSemanticCache",
            side_effect=ImportError(self._REDISVL_MISSING),
        ), pytest.raises(ImportError):
            manager._create_cache_instance()

    def test_memory_backend_uses_in_memory_cache(self):
        """The memory backend uses InMemoryCache directly and never touches Redis."""
        import litellm.caching

        manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        assert isinstance(manager._create_cache_instance(), litellm.caching.InMemoryCache)