feat(llm): U17 — LiteLLM 语义缓存替换 + per-user/ACL scope 安全隔离
- 新增 LitellmCacheManager:配置 litellm.cache 全局,三级后端 fallback (RedisSemanticCache -> RedisCache -> InMemoryCache),redisvl lazy import - cache_key 扩展 user_id + kb_acl_hash 参数(安全要求 a/b/e) - gateway 集成:读取 KB caching_disabled flag(安全要求 c),构建带 scope 的 cache_key,命中时 cost=0 - LLMResponse 新增 cache_hit 字段;LLMRequest 新增 cache 参数 - litellm_provider 透传 cache 参数 + 检测 _hidden_params 缓存命中 - 33 个新测试覆盖 13 场景(含 User A != User B 缓存隔离) - 旧 InMemoryLLMCache/RedisLLMCache 保留向后兼容
This commit is contained in:
parent
86541d7172
commit
793476cafa
|
|
@ -14,11 +14,14 @@ import logging
|
|||
import time
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Protocol, runtime_checkable
|
||||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||
|
||||
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
|
||||
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentkit.llm.config import CacheConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -61,8 +64,7 @@ def _serialize_response(response: LLMResponse) -> dict:
|
|||
"completion_tokens": response.usage.completion_tokens,
|
||||
},
|
||||
"tool_calls": [
|
||||
{"id": tc.id, "name": tc.name, "arguments": tc.arguments}
|
||||
for tc in response.tool_calls
|
||||
{"id": tc.id, "name": tc.name, "arguments": tc.arguments} for tc in response.tool_calls
|
||||
],
|
||||
"latency_ms": response.latency_ms,
|
||||
}
|
||||
|
|
@ -236,7 +238,9 @@ class InMemoryLLMCache:
|
|||
self._cache.move_to_end(key)
|
||||
existing = self._cache[key]
|
||||
# Preserve existing embedding if new one is None
|
||||
effective_embedding = query_embedding if query_embedding is not None else existing.query_embedding
|
||||
effective_embedding = (
|
||||
query_embedding if query_embedding is not None else existing.query_embedding
|
||||
)
|
||||
else:
|
||||
effective_embedding = query_embedding or []
|
||||
|
||||
|
|
@ -276,9 +280,7 @@ class InMemoryLLMCache:
|
|||
return count
|
||||
|
||||
# Simple prefix matching for pattern
|
||||
keys_to_remove = [
|
||||
k for k in self._cache if k.startswith(pattern.replace("*", ""))
|
||||
]
|
||||
keys_to_remove = [k for k in self._cache if k.startswith(pattern.replace("*", ""))]
|
||||
for key in keys_to_remove:
|
||||
del self._cache[key]
|
||||
self._embeddings.pop(key, None)
|
||||
|
|
@ -338,9 +340,7 @@ class RedisLLMCache:
|
|||
if self._redis is None:
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
self._redis = aioredis.from_url(
|
||||
self._redis_url, decode_responses=True
|
||||
)
|
||||
self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
|
||||
return self._redis
|
||||
|
||||
async def aclose(self) -> None:
|
||||
|
|
@ -393,9 +393,7 @@ class RedisLLMCache:
|
|||
if data is not None:
|
||||
entry = _deserialize_entry(json.loads(data))
|
||||
self._hits += 1
|
||||
return CacheResult(
|
||||
hit=True, response=entry.response, match_type="exact"
|
||||
)
|
||||
return CacheResult(hit=True, response=entry.response, match_type="exact")
|
||||
self._misses += 1
|
||||
return CacheResult(hit=False)
|
||||
except Exception as e:
|
||||
|
|
@ -424,6 +422,7 @@ class RedisLLMCache:
|
|||
if len(cache_keys_list) > max_scan:
|
||||
# Take a random sample to avoid always scanning the same subset
|
||||
import random
|
||||
|
||||
cache_keys_list = random.sample(cache_keys_list, max_scan)
|
||||
|
||||
# Batch fetch embeddings
|
||||
|
|
@ -460,9 +459,7 @@ class RedisLLMCache:
|
|||
if data is not None:
|
||||
entry = _deserialize_entry(json.loads(data))
|
||||
self._hits += 1
|
||||
return CacheResult(
|
||||
hit=True, response=entry.response, match_type="semantic"
|
||||
)
|
||||
return CacheResult(hit=True, response=entry.response, match_type="semantic")
|
||||
# Data key expired but embedding still exists — mark for cleanup
|
||||
try:
|
||||
await redis.srem(self.INDEX_KEY, best_key)
|
||||
|
|
@ -614,19 +611,214 @@ def create_llm_cache(
|
|||
semantic_ttl=semantic_ttl,
|
||||
similarity_threshold=similarity_threshold,
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning("redis package not available, falling back to in-memory cache")
|
||||
return InMemoryLLMCache(
|
||||
max_entries=max_entries,
|
||||
exact_ttl=exact_ttl,
|
||||
semantic_ttl=semantic_ttl,
|
||||
similarity_threshold=similarity_threshold,
|
||||
)
|
||||
return InMemoryLLMCache(
|
||||
max_entries=max_entries,
|
||||
exact_ttl=exact_ttl,
|
||||
semantic_ttl=semantic_ttl,
|
||||
similarity_threshold=similarity_threshold,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# U17 — LiteLLM 缓存管理器
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class LitellmCacheConfig:
|
||||
"""U17 — LiteLLM 缓存配置(从 CacheConfig 转换)。
|
||||
|
||||
与旧 ``CacheConfig`` 的区别:
|
||||
- ``similarity_threshold`` 固定默认 0.87(plan 规定)
|
||||
- ``per_user_namespace`` 强制开启(安全要求 a)
|
||||
- ``backend`` 新增 ``redis_semantic`` 选项(需要 redisvl)
|
||||
"""
|
||||
|
||||
enabled: bool = False
|
||||
backend: str = "auto" # "auto" | "redis_semantic" | "redis" | "memory"
|
||||
redis_url: str = "redis://localhost:6379"
|
||||
similarity_threshold: float = 0.87 # U17 默认 0.87(plan 规定)
|
||||
ttl: int = 86400
|
||||
embedding_model: str = "text-embedding-ada-002"
|
||||
per_user_namespace: bool = True # 安全要求 (a)
|
||||
|
||||
@classmethod
|
||||
def from_cache_config(cls, c: "CacheConfig") -> "LitellmCacheConfig":
|
||||
"""从现有 CacheConfig 转换。
|
||||
|
||||
- ``similarity_threshold`` 固定 0.87(U17 plan 规定,忽略旧 0.92)
|
||||
- ``per_user_namespace`` 强制 True(安全要求 a)
|
||||
- ``embedding_model`` 回退到 "text-embedding-ada-002"(LiteLLM 默认)
|
||||
"""
|
||||
return cls(
|
||||
enabled=c.enabled,
|
||||
backend=c.backend if c.backend in ("auto", "redis", "memory") else "auto",
|
||||
redis_url=c.redis_url,
|
||||
similarity_threshold=0.87, # U17 固定默认,忽略旧 0.92
|
||||
ttl=c.semantic_ttl,
|
||||
embedding_model=c.embedding_model or "text-embedding-ada-002",
|
||||
per_user_namespace=True, # 强制开启
|
||||
)
|
||||
|
||||
|
||||
class LitellmCacheManager:
|
||||
"""U17 — LiteLLM 全局缓存管理器。
|
||||
|
||||
职责:
|
||||
1. 创建并设置 ``litellm.cache`` 全局实例
|
||||
2. 构建带 user/ACL scope 的 cache key(安全要求 a, b)
|
||||
3. 提供 per-call cache 参数(cache_key 或 no-cache)
|
||||
4. 检测 LiteLLM 响应的缓存命中标志(用于 usage tracking)
|
||||
5. 统计缓存命中率
|
||||
|
||||
后端选择优先级(backend="auto" 时):
|
||||
RedisSemanticCache(需 redisvl)→ RedisCache(精确)→ InMemoryCache
|
||||
|
||||
安全约束:
|
||||
- (a) cache key 包含 user_id(per-user namespace)
|
||||
- (b) cache key 包含 kb_acl_hash(ACL-scope 隔离)
|
||||
- (c) KB 设置 caching_disabled=True 时禁用缓存
|
||||
- (e) User A 的查询不会命中 User B 的缓存
|
||||
"""
|
||||
|
||||
def __init__(self, config: LitellmCacheConfig):
|
||||
self._config = config
|
||||
self._cache_instance: Any = None # litellm.caching.Cache 实例
|
||||
self._hits = 0
|
||||
self._misses = 0
|
||||
|
||||
def enable(self) -> None:
|
||||
"""创建 LiteLLM Cache 实例并赋值给 ``litellm.cache``。"""
|
||||
import litellm
|
||||
|
||||
self._cache_instance = self._create_cache_instance()
|
||||
litellm.cache = self._cache_instance
|
||||
|
||||
def disable(self) -> None:
|
||||
"""禁用缓存 — 设置 ``litellm.cache = None``。"""
|
||||
import litellm
|
||||
|
||||
litellm.cache = None
|
||||
self._cache_instance = None
|
||||
|
||||
def _create_cache_instance(self) -> Any:
|
||||
"""根据 backend 配置创建 LiteLLM Cache 实例。
|
||||
|
||||
auto 模式按优先级尝试:RedisSemanticCache → RedisCache → InMemoryCache。
|
||||
redisvl 缺失时自动回退(安全要求 d — 不添加为必需依赖)。
|
||||
"""
|
||||
backend = self._config.backend
|
||||
if backend in ("auto", "redis_semantic"):
|
||||
# 尝试 RedisSemanticCache(需要 redisvl — lazy import,缺失时 fallback)
|
||||
try:
|
||||
from litellm.caching import RedisSemanticCache
|
||||
|
||||
return RedisSemanticCache(
|
||||
redis_url=self._config.redis_url,
|
||||
similarity_threshold=self._config.similarity_threshold,
|
||||
embedding_model=self._config.embedding_model,
|
||||
)
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"redis package not available, falling back to in-memory cache"
|
||||
"RedisSemanticCache 需要 redisvl 包(未安装),"
|
||||
"回退到 RedisCache(精确匹配,无语义搜索)。"
|
||||
"安装 redisvl 以启用语义缓存:pip install redisvl"
|
||||
)
|
||||
return InMemoryLLMCache(
|
||||
max_entries=max_entries,
|
||||
exact_ttl=exact_ttl,
|
||||
semantic_ttl=semantic_ttl,
|
||||
similarity_threshold=similarity_threshold,
|
||||
)
|
||||
return InMemoryLLMCache(
|
||||
max_entries=max_entries,
|
||||
exact_ttl=exact_ttl,
|
||||
semantic_ttl=semantic_ttl,
|
||||
similarity_threshold=similarity_threshold,
|
||||
if backend == "redis_semantic":
|
||||
raise # 显式要求语义缓存但 redisvl 缺失 — 报错
|
||||
except Exception as e:
|
||||
logger.warning(f"RedisSemanticCache 初始化失败: {e},回退到 RedisCache")
|
||||
|
||||
if backend in ("auto", "redis", "redis_semantic"):
|
||||
try:
|
||||
from litellm.caching import RedisCache
|
||||
|
||||
return RedisCache(redis_url=self._config.redis_url)
|
||||
except Exception as e:
|
||||
logger.warning(f"RedisCache 初始化失败: {e},回退到 InMemoryCache")
|
||||
|
||||
from litellm.caching import InMemoryCache
|
||||
|
||||
return InMemoryCache()
|
||||
|
||||
def build_cache_key(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[dict[str, str]],
|
||||
temperature: float,
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: str = "auto",
|
||||
max_tokens: int = 2000,
|
||||
user_id: str | None = None,
|
||||
kb_acl_hash: str | None = None,
|
||||
) -> str:
|
||||
"""构建带 user/ACL scope 的 cache key(安全要求 a, b, e)。
|
||||
|
||||
委托给 ``cache_key.generate_cache_key``,额外注入 user_id + kb_acl_hash
|
||||
作为命名空间隔离,确保 User A 的查询不会命中 User B 的缓存。
|
||||
"""
|
||||
from agentkit.llm.cache_key import generate_cache_key
|
||||
|
||||
return generate_cache_key(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
max_tokens=max_tokens,
|
||||
user_id=user_id,
|
||||
kb_acl_hash=kb_acl_hash,
|
||||
)
|
||||
|
||||
def should_cache(
|
||||
self,
|
||||
kb_caching_disabled: bool = False,
|
||||
user_id: str | None = None,
|
||||
) -> bool:
|
||||
"""判断当前请求是否应该缓存(安全要求 c)。
|
||||
|
||||
- KB 设置 caching_disabled=True → 不缓存
|
||||
- 其余情况缓存(user_id 为 None 时仍可缓存,但 key 不含 user scope)
|
||||
"""
|
||||
_ = user_id # 预留:未来支持 per-user 缓存禁用
|
||||
if kb_caching_disabled:
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def cache_params_for_hit(cache_key: str) -> dict[str, str]:
|
||||
"""返回 litellm acompletion 的 cache 参数(用于期望命中的调用)。"""
|
||||
return {"cache_key": cache_key}
|
||||
|
||||
@staticmethod
|
||||
def cache_params_for_no_cache() -> dict[str, bool]:
|
||||
"""返回 litellm acompletion 的 cache 参数(禁用缓存)。"""
|
||||
return {"no-cache": True}
|
||||
|
||||
def detect_cache_hit(self, response: Any) -> bool:
|
||||
"""检测 LiteLLM 响应是否为缓存命中。
|
||||
|
||||
LiteLLM 在缓存命中时设置 ``response._hidden_params["cache_key"]``。
|
||||
"""
|
||||
hidden = getattr(response, "_hidden_params", None)
|
||||
if isinstance(hidden, dict):
|
||||
if "cache_key" in hidden or hidden.get("cache_hit"):
|
||||
self._hits += 1
|
||||
return True
|
||||
self._misses += 1
|
||||
return False
|
||||
|
||||
def stats(self) -> dict[str, int]:
|
||||
"""返回缓存统计。"""
|
||||
return {
|
||||
"total_hits": self._hits,
|
||||
"total_misses": self._misses,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ def generate_cache_key(
|
|||
tools: list[dict[str, Any]] | None = None,
|
||||
tool_choice: str = "auto",
|
||||
max_tokens: int = 2000,
|
||||
user_id: str | None = None,
|
||||
kb_acl_hash: str | None = None,
|
||||
) -> str:
|
||||
"""Generate a deterministic SHA-256 cache key from LLM request parameters.
|
||||
|
||||
|
|
@ -19,6 +21,10 @@ def generate_cache_key(
|
|||
model, system_prompt (extracted from messages), messages content,
|
||||
temperature, tools, tool_choice, and max_tokens.
|
||||
|
||||
U17 安全扩展:``user_id`` 和 ``kb_acl_hash`` 非 None 时加入 hash 组件,
|
||||
实现 per-user namespace 和 ACL-scope 隔离(安全要求 a, b)。为 None 时
|
||||
行为与旧版完全一致(向后兼容)。
|
||||
|
||||
Args:
|
||||
model: Model identifier (e.g. "openai/gpt-4o").
|
||||
messages: Chat messages list (may include system prompt as first message).
|
||||
|
|
@ -26,6 +32,8 @@ def generate_cache_key(
|
|||
tools: Optional list of tool definitions.
|
||||
tool_choice: Tool selection mode ("auto", "none", etc.).
|
||||
max_tokens: Maximum response tokens.
|
||||
user_id: U17 — 用户 ID,用于 per-user cache namespace 隔离。
|
||||
kb_acl_hash: U17 — KB ACL-scope hash,用于 ACL 隔离缓存键。
|
||||
|
||||
Returns:
|
||||
64-character hex SHA-256 hash string.
|
||||
|
|
@ -40,6 +48,11 @@ def generate_cache_key(
|
|||
_hash_str(tool_choice),
|
||||
_hash_str(str(max_tokens)),
|
||||
]
|
||||
# U17 — per-user namespace + ACL scope hash(安全要求 a, b, e)
|
||||
if user_id is not None:
|
||||
components.append(_hash_str(f"user:{user_id}"))
|
||||
if kb_acl_hash is not None:
|
||||
components.append(_hash_str(f"acl:{kb_acl_hash}"))
|
||||
combined = "".join(components)
|
||||
return hashlib.sha256(combined.encode()).hexdigest()
|
||||
|
||||
|
|
@ -61,6 +74,4 @@ def _hash_json(obj: Any) -> str:
|
|||
"""SHA-256 hash of a JSON-serializable object."""
|
||||
if obj is None:
|
||||
return hashlib.sha256(b"null").hexdigest()
|
||||
return hashlib.sha256(
|
||||
json.dumps(obj, sort_keys=True, ensure_ascii=False).encode()
|
||||
).hexdigest()
|
||||
return hashlib.sha256(json.dumps(obj, sort_keys=True, ensure_ascii=False).encode()).hexdigest()
|
||||
|
|
|
|||
|
|
@ -51,47 +51,19 @@ class LLMGateway:
|
|||
self._usage_tracker = UsageTracker(store=usage_store) if usage_store else UsageTracker()
|
||||
self._config = config or LLMConfig()
|
||||
|
||||
# Cache (opt-in, disabled by default)
|
||||
self._cache: Any = None # LLMCache | None
|
||||
self._embedder: Any = None # Embedder | None
|
||||
# Cache (U17 — LiteLLM 缓存管理器,opt-in,默认禁用)
|
||||
self._cache_manager: Any = None # LitellmCacheManager | None
|
||||
if self._config.cache and self._config.cache.enabled:
|
||||
from agentkit.llm.cache import create_llm_cache
|
||||
from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager
|
||||
|
||||
self._cache = create_llm_cache(
|
||||
backend=self._config.cache.backend,
|
||||
redis_url=self._config.cache.redis_url,
|
||||
max_entries=self._config.cache.max_entries,
|
||||
exact_ttl=self._config.cache.exact_ttl,
|
||||
semantic_ttl=self._config.cache.semantic_ttl,
|
||||
similarity_threshold=self._config.cache.similarity_threshold,
|
||||
)
|
||||
self._embedder = self._create_embedder(self._config.cache)
|
||||
litellm_config = LitellmCacheConfig.from_cache_config(self._config.cache)
|
||||
self._cache_manager = LitellmCacheManager(litellm_config)
|
||||
self._cache_manager.enable()
|
||||
logger.info(
|
||||
f"LLM cache enabled (backend={self._config.cache.backend}, "
|
||||
f"embedder={self._config.cache.embedding_provider}/{self._config.cache.embedding_model})"
|
||||
f"LLM cache enabled (LiteLLM, backend={self._config.cache.backend}, "
|
||||
f"similarity_threshold={litellm_config.similarity_threshold})"
|
||||
)
|
||||
|
||||
def _create_embedder(self, cache_config) -> Any:
|
||||
"""Create embedder for semantic cache based on config."""
|
||||
try:
|
||||
from agentkit.memory.embedder import OpenAIEmbedder
|
||||
|
||||
if cache_config.embedding_provider in ("xinference", "local"):
|
||||
return OpenAIEmbedder(
|
||||
api_key=cache_config.embedding_api_key or "not-needed",
|
||||
model=cache_config.embedding_model,
|
||||
base_url=cache_config.embedding_base_url or "http://localhost:9997/v1",
|
||||
)
|
||||
# Default: OpenAI
|
||||
return OpenAIEmbedder(
|
||||
api_key=cache_config.embedding_api_key,
|
||||
model=cache_config.embedding_model,
|
||||
base_url=cache_config.embedding_base_url,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create embedder for semantic cache: {e}")
|
||||
return None
|
||||
|
||||
def register_provider(self, name: str, provider: LLMProvider) -> None:
|
||||
"""注册 Provider"""
|
||||
self._providers[name] = provider
|
||||
|
|
@ -114,6 +86,8 @@ class LLMGateway:
|
|||
user_id: str | None = None,
|
||||
department_ids: list[str] | None = None,
|
||||
db_path: Path | str | None = None,
|
||||
kb_id: str | None = None,
|
||||
kb_acl_hash: str | None = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""发送 chat 请求,自动解析别名和 Fallback"""
|
||||
|
|
@ -151,67 +125,40 @@ class LLMGateway:
|
|||
|
||||
start = time.monotonic()
|
||||
|
||||
# ── Cache check ──
|
||||
cache_key = None
|
||||
query_embedding = None
|
||||
if self._cache is not None:
|
||||
from agentkit.llm.cache_key import generate_cache_key
|
||||
# ── Cache check (U17 — LiteLLM cache via cache_key in request) ──
|
||||
# LiteLLM 在 litellm.acompletion 内部处理缓存读写,gateway 只需:
|
||||
# 1. 构建 per-user + ACL-scoped cache_key(安全要求 a, b)
|
||||
# 2. 将 cache 参数注入 kwargs 透传到 provider
|
||||
# 3. 检测响应的 cache_hit 标志,用于 usage tracking(cost=0)
|
||||
if self._cache_manager is not None:
|
||||
from agentkit.llm.cache import LitellmCacheManager
|
||||
|
||||
cache_key = generate_cache_key(
|
||||
# 解析 KB caching_disabled(安全要求 c)
|
||||
kb_caching_disabled = False
|
||||
if kb_id is not None:
|
||||
try:
|
||||
from agentkit.rag_platform.settings import get_settings_store
|
||||
|
||||
settings = await get_settings_store().get_settings(kb_id)
|
||||
if settings is not None:
|
||||
kb_caching_disabled = settings.caching_disabled
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to read KB cache settings for kb_id={kb_id}: {e}")
|
||||
|
||||
if self._cache_manager.should_cache(kb_caching_disabled, user_id):
|
||||
cache_key = self._cache_manager.build_cache_key(
|
||||
model=resolved_model,
|
||||
messages=messages,
|
||||
temperature=kwargs.get("temperature", 0.7),
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
max_tokens=kwargs.get("max_tokens", 2000),
|
||||
)
|
||||
result = await self._cache.get(cache_key)
|
||||
if result.hit:
|
||||
latency_ms = (time.monotonic() - start) * 1000
|
||||
await self._record_usage(
|
||||
agent_name=agent_name,
|
||||
model=result.response.model,
|
||||
usage=result.response.usage,
|
||||
cost=0.0,
|
||||
latency_ms=latency_ms,
|
||||
user_id=user_id,
|
||||
department_ids=department_ids,
|
||||
kb_acl_hash=kb_acl_hash,
|
||||
)
|
||||
if _span is not None:
|
||||
_span.set_attribute("gen_ai.cache.hit", True)
|
||||
_span.set_attribute("gen_ai.cache.match_type", result.match_type)
|
||||
return result.response
|
||||
|
||||
# Semantic match (only for temperature == 0)
|
||||
temperature = kwargs.get("temperature", 0.7)
|
||||
if temperature == 0 and self._embedder is not None:
|
||||
try:
|
||||
# Embed last N messages for context-aware semantic matching
|
||||
# (not just last user message — avoids cross-context false hits)
|
||||
recent_messages = messages[-3:] if len(messages) > 3 else messages
|
||||
embed_text = " | ".join(
|
||||
m.get("content", "") for m in recent_messages if m.get("content")
|
||||
)
|
||||
if embed_text:
|
||||
query_embedding = await self._embedder.embed(embed_text)
|
||||
result = await self._cache.semantic_search(query_embedding)
|
||||
if result.hit:
|
||||
latency_ms = (time.monotonic() - start) * 1000
|
||||
await self._record_usage(
|
||||
agent_name=agent_name,
|
||||
model=result.response.model,
|
||||
usage=result.response.usage,
|
||||
cost=0.0,
|
||||
latency_ms=latency_ms,
|
||||
user_id=user_id,
|
||||
department_ids=department_ids,
|
||||
)
|
||||
if _span is not None:
|
||||
_span.set_attribute("gen_ai.cache.hit", True)
|
||||
_span.set_attribute("gen_ai.cache.match_type", "semantic")
|
||||
return result.response
|
||||
except Exception as e:
|
||||
logger.warning(f"Semantic cache search failed: {e}")
|
||||
kwargs["cache"] = LitellmCacheManager.cache_params_for_hit(cache_key)
|
||||
else:
|
||||
kwargs["cache"] = LitellmCacheManager.cache_params_for_no_cache()
|
||||
|
||||
# ── Normal provider call ──
|
||||
models_to_try = self._get_models_to_try(resolved_model)
|
||||
|
|
@ -275,15 +222,15 @@ class LLMGateway:
|
|||
|
||||
latency_ms = (time.monotonic() - start) * 1000
|
||||
|
||||
# ── Cache write ──
|
||||
if self._cache is not None and cache_key is not None:
|
||||
try:
|
||||
await self._cache.put(cache_key, response, query_embedding)
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache write failed: {e}")
|
||||
# U17 — 检测 LiteLLM 缓存命中(用于 usage tracking cost=0)
|
||||
is_cache_hit = getattr(response, "cache_hit", False)
|
||||
if is_cache_hit and self._cache_manager is not None:
|
||||
self._cache_manager._hits += 1
|
||||
elif self._cache_manager is not None:
|
||||
self._cache_manager._misses += 1
|
||||
|
||||
# 计算成本
|
||||
cost = self._calculate_cost(response.model, response.usage)
|
||||
# 计算成本(缓存命中时 cost=0)
|
||||
cost = 0.0 if is_cache_hit else self._calculate_cost(response.model, response.usage)
|
||||
|
||||
# 记录使用量
|
||||
await self._record_usage(
|
||||
|
|
@ -302,8 +249,8 @@ class LLMGateway:
|
|||
_span.set_attribute("gen_ai.usage.output_tokens", response.usage.completion_tokens)
|
||||
_span.set_attribute("gen_ai.response.model", response.model)
|
||||
_span.set_attribute("gen_ai.duration.ms", int(latency_ms))
|
||||
if self._cache is not None:
|
||||
_span.set_attribute("gen_ai.cache.hit", False)
|
||||
if self._cache_manager is not None:
|
||||
_span.set_attribute("gen_ai.cache.hit", is_cache_hit)
|
||||
llm_token_histogram().record(
|
||||
response.usage.total_tokens,
|
||||
{"gen_ai.request.model": resolved_model},
|
||||
|
|
@ -607,18 +554,10 @@ class LLMGateway:
|
|||
)
|
||||
|
||||
# 2. Token + cost limits (daily AND monthly)
|
||||
await self._check_quota_period(
|
||||
quota_service, db, dept_id, "daily", "token_limit"
|
||||
)
|
||||
await self._check_quota_period(
|
||||
quota_service, db, dept_id, "daily", "cost_limit"
|
||||
)
|
||||
await self._check_quota_period(
|
||||
quota_service, db, dept_id, "monthly", "token_limit"
|
||||
)
|
||||
await self._check_quota_period(
|
||||
quota_service, db, dept_id, "monthly", "cost_limit"
|
||||
)
|
||||
await self._check_quota_period(quota_service, db, dept_id, "daily", "token_limit")
|
||||
await self._check_quota_period(quota_service, db, dept_id, "daily", "cost_limit")
|
||||
await self._check_quota_period(quota_service, db, dept_id, "monthly", "token_limit")
|
||||
await self._check_quota_period(quota_service, db, dept_id, "monthly", "cost_limit")
|
||||
|
||||
async def _check_quota_period(
|
||||
self,
|
||||
|
|
@ -639,9 +578,7 @@ class LLMGateway:
|
|||
else:
|
||||
current = await self._get_current_cost_for_quota(dept_id, period)
|
||||
|
||||
allowed, _reason = await quota_service.check_quota(
|
||||
db, dept_id, quota_type, period, current
|
||||
)
|
||||
allowed, _reason = await quota_service.check_quota(db, dept_id, quota_type, period, current)
|
||||
if not allowed:
|
||||
quota = await quota_service.get_quota(db, dept_id, quota_type, period)
|
||||
limit = quota["limit_value"] if quota else None
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ class LLMRequest:
|
|||
temperature: float = 0.7,
|
||||
max_tokens: int = 2000,
|
||||
timeout: float | None = None,
|
||||
cache: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
self.messages = messages
|
||||
|
|
@ -57,6 +58,8 @@ class LLMRequest:
|
|||
self.max_tokens = max_tokens
|
||||
self.timeout = timeout
|
||||
self._extra = kwargs
|
||||
# U17 — LiteLLM cache 参数(cache_key 或 no-cache),透传到 litellm.acompletion
|
||||
self._cache: dict[str, Any] | None = cache
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -81,6 +84,7 @@ class LLMResponse:
|
|||
usage: TokenUsage
|
||||
tool_calls: list[ToolCall] = field(default_factory=list)
|
||||
latency_ms: float = 0.0
|
||||
cache_hit: bool = False # U17 — 缓存命中标记,用于 usage tracking(cost=0)
|
||||
|
||||
@property
|
||||
def has_tool_calls(self) -> bool:
|
||||
|
|
|
|||
|
|
@ -183,6 +183,10 @@ class LitellmProvider(LLMProvider):
|
|||
kwargs["tool_choice"] = request.tool_choice
|
||||
if request.timeout is not None:
|
||||
kwargs["timeout"] = request.timeout
|
||||
# U17 — 透传 LiteLLM cache 参数(cache_key 或 no-cache)到 litellm.acompletion
|
||||
cache_params = getattr(request, "_cache", None)
|
||||
if cache_params is not None:
|
||||
kwargs["cache"] = cache_params
|
||||
# 合并构造时传入的默认 kwargs(如 max_connections 等provider特定参数)
|
||||
kwargs.update(self._default_kwargs)
|
||||
return kwargs
|
||||
|
|
@ -214,12 +218,19 @@ class LitellmProvider(LLMProvider):
|
|||
|
||||
model_name = getattr(response, "model", None) or request_model
|
||||
|
||||
# U17 — 检测 LiteLLM 缓存命中(_hidden_params 含 cache_key 或 cache_hit)
|
||||
cache_hit = False
|
||||
hidden = getattr(response, "_hidden_params", None)
|
||||
if isinstance(hidden, dict) and ("cache_key" in hidden or hidden.get("cache_hit")):
|
||||
cache_hit = True
|
||||
|
||||
return LLMResponse(
|
||||
content=content,
|
||||
model=model_name,
|
||||
usage=usage,
|
||||
tool_calls=tool_calls,
|
||||
latency_ms=latency_ms,
|
||||
cache_hit=cache_hit,
|
||||
)
|
||||
|
||||
def _parse_stream_chunk(
|
||||
|
|
|
|||
|
|
@ -287,13 +287,13 @@ class TestCacheIntegration:
|
|||
cache=CacheConfig(enabled=True, backend="memory"),
|
||||
)
|
||||
gateway = LLMGateway(config=config)
|
||||
assert gateway._cache is not None
|
||||
assert gateway._cache_manager is not None
|
||||
|
||||
def test_gateway_without_cache_config(self):
|
||||
"""LLMGateway works without cache (default)."""
|
||||
config = LLMConfig(providers={})
|
||||
gateway = LLMGateway(config=config)
|
||||
assert gateway._cache is None
|
||||
assert gateway._cache_manager is None
|
||||
|
||||
def test_gateway_cache_disabled(self):
|
||||
"""LLMGateway does not initialize cache when disabled."""
|
||||
|
|
@ -302,7 +302,7 @@ class TestCacheIntegration:
|
|||
cache=CacheConfig(enabled=False),
|
||||
)
|
||||
gateway = LLMGateway(config=config)
|
||||
assert gateway._cache is None
|
||||
assert gateway._cache_manager is None
|
||||
|
||||
|
||||
# ── Graceful degradation tests ─────────────────────────────
|
||||
|
|
@ -390,7 +390,7 @@ class TestFullFlowInMemory:
|
|||
redis_url=config.usage_store.get("redis_url", "redis://localhost:6379"),
|
||||
)
|
||||
gateway = LLMGateway(config=config.llm_config, usage_store=usage_store)
|
||||
assert gateway._cache is not None
|
||||
assert gateway._cache_manager is not None
|
||||
|
||||
cascade_store = create_cascade_state_store(
|
||||
backend=config.cascade_store.get("backend", "memory"),
|
||||
|
|
|
|||
|
|
@ -0,0 +1,486 @@
|
|||
"""U17 — LiteLLM 语义缓存管理器单元测试。
|
||||
|
||||
覆盖场景(plan 要求的 4 个 + 安全要求 e + 向后兼容 + fallback):
|
||||
1. 语义相似查询命中缓存(mock litellm.acompletion 模拟缓存行为)
|
||||
2. 不同 system prompt 不命中缓存
|
||||
3. 缓存命中率可统计
|
||||
4. 阈值调优生效(similarity_threshold 传入 RedisSemanticCache)
|
||||
5. User A 不返回 User B 缓存(安全要求 e)
|
||||
6. kb_acl_hash 隔离 — 不同 ACL hash 产生不同 key
|
||||
7. kb_caching_disabled 禁用缓存(安全要求 c)
|
||||
8. cache_params_for_hit / no_cache — 返回正确 dict
|
||||
9. detect_cache_hit — _hidden_params 含 cache_key 时返回 True
|
||||
10. LitellmCacheConfig.from_cache_config — 转换正确,similarity_threshold=0.87
|
||||
11. LitellmCacheManager.enable/disable — litellm.cache 正确设置/清除
|
||||
12. generate_cache_key 向后兼容 — user_id=None, kb_acl_hash=None 时与旧版相同
|
||||
13. RedisSemanticCache fallback — redisvl 缺失时 auto backend 回退
|
||||
|
||||
测试用 ``unittest.mock.patch`` mock ``litellm.cache`` 全局变量,避免测试间污染。
|
||||
每个测试后清理 ``litellm.cache = None``。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager
|
||||
from agentkit.llm.cache_key import generate_cache_key
|
||||
from agentkit.llm.config import CacheConfig
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 测试辅助
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]:
|
||||
return [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": user_content},
|
||||
]
|
||||
|
||||
|
||||
def _make_litellm_response(
|
||||
content: str = "Hello!",
|
||||
model: str = "gpt-4o-mini",
|
||||
prompt_tokens: int = 10,
|
||||
completion_tokens: int = 5,
|
||||
cache_key: str | None = None,
|
||||
) -> SimpleNamespace:
|
||||
"""构造 LiteLLM 响应(OpenAI ChatCompletion 格式),可选 cache_key 标记。"""
|
||||
hidden_params: dict[str, Any] = {}
|
||||
if cache_key is not None:
|
||||
hidden_params["cache_key"] = cache_key
|
||||
message = SimpleNamespace(content=content, tool_calls=None)
|
||||
return SimpleNamespace(
|
||||
choices=[SimpleNamespace(message=message)],
|
||||
usage=SimpleNamespace(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
|
||||
model=model,
|
||||
_hidden_params=hidden_params,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _cleanup_litellm_cache():
|
||||
"""每个测试后清理 litellm.cache 全局变量,避免测试间污染。"""
|
||||
yield
|
||||
import litellm
|
||||
|
||||
litellm.cache = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. 语义相似查询命中缓存
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCacheHit:
|
||||
"""场景 1 — 相同请求第二次命中缓存。"""
|
||||
|
||||
async def test_second_request_is_cache_hit(self):
|
||||
"""相同 cache_key 的第二次请求返回缓存响应(cache_hit=True)。"""
|
||||
config = LitellmCacheConfig(enabled=True, backend="memory")
|
||||
manager = LitellmCacheManager(config)
|
||||
|
||||
msgs = _make_messages()
|
||||
cache_key = manager.build_cache_key("gpt-4o", msgs, 0.0)
|
||||
cache_params = LitellmCacheManager.cache_params_for_hit(cache_key)
|
||||
|
||||
# 模拟 LiteLLM 缓存行为:第一次 miss,第二次 hit
|
||||
call_count = 0
|
||||
cache_store: dict[str, SimpleNamespace] = {}
|
||||
|
||||
async def fake_acompletion(**kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
ck = kwargs.get("cache", {}).get("cache_key")
|
||||
if ck and ck in cache_store:
|
||||
# 缓存命中 — 返回带 cache_key 标记的响应
|
||||
return _make_litellm_response(content="Cached!", cache_key=ck)
|
||||
# 缓存未命中 — 调用"真实" API 并存储
|
||||
resp = _make_litellm_response(content="Fresh response")
|
||||
if ck:
|
||||
cache_store[ck] = resp
|
||||
return resp
|
||||
|
||||
with patch("litellm.acompletion", side_effect=fake_acompletion):
|
||||
from agentkit.llm.providers.litellm_provider import LitellmProvider
|
||||
|
||||
provider = LitellmProvider(model_prefix="openai/", api_key="test")
|
||||
from agentkit.llm.protocol import LLMRequest
|
||||
|
||||
# 第一次请求 — miss
|
||||
req1 = LLMRequest(messages=msgs, model="gpt-4o", temperature=0.0, cache=cache_params)
|
||||
resp1 = await provider.chat(req1)
|
||||
assert resp1.cache_hit is False
|
||||
assert resp1.content == "Fresh response"
|
||||
|
||||
# 第二次相同请求 — hit
|
||||
req2 = LLMRequest(messages=msgs, model="gpt-4o", temperature=0.0, cache=cache_params)
|
||||
resp2 = await provider.chat(req2)
|
||||
assert resp2.cache_hit is True
|
||||
assert resp2.content == "Cached!"
|
||||
|
||||
assert call_count == 2 # litellm.acompletion 被调用两次(LiteLLM 内部处理缓存)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. 不同 system prompt 不命中缓存
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSystemPromptIsolation:
|
||||
"""场景 2 — 不同 system prompt 产生不同 cache_key,不误命中。"""
|
||||
|
||||
def test_different_system_prompt_different_key(self):
|
||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||
msgs1 = [
|
||||
{"role": "system", "content": "Be concise"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
msgs2 = [
|
||||
{"role": "system", "content": "Be verbose"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
key1 = manager.build_cache_key("gpt-4o", msgs1, 0.0)
|
||||
key2 = manager.build_cache_key("gpt-4o", msgs2, 0.0)
|
||||
assert key1 != key2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. 缓存命中率可统计
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCacheStats:
|
||||
"""场景 3 — LitellmCacheManager.stats() 返回正确 hits/misses。"""
|
||||
|
||||
def test_stats_after_hits_and_misses(self):
|
||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||
|
||||
# 2 hits
|
||||
manager.detect_cache_hit(_make_litellm_response(cache_key="k1"))
|
||||
manager.detect_cache_hit(_make_litellm_response(cache_key="k2"))
|
||||
# 1 miss
|
||||
manager.detect_cache_hit(_make_litellm_response()) # 无 cache_key
|
||||
|
||||
stats = manager.stats()
|
||||
assert stats["total_hits"] == 2
|
||||
assert stats["total_misses"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. 阈值调优生效
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSimilarityThreshold:
|
||||
"""场景 4 — similarity_threshold 传入 RedisSemanticCache 构造函数。"""
|
||||
|
||||
def test_threshold_passed_to_redis_semantic_cache(self):
|
||||
"""RedisSemanticCache 构造时接收 similarity_threshold=0.87。"""
|
||||
config = LitellmCacheConfig(
|
||||
enabled=True,
|
||||
backend="redis_semantic",
|
||||
similarity_threshold=0.83,
|
||||
redis_url="redis://localhost:6379",
|
||||
)
|
||||
manager = LitellmCacheManager(config)
|
||||
|
||||
mock_cache_instance = MagicMock()
|
||||
with patch(
|
||||
"litellm.caching.RedisSemanticCache", return_value=mock_cache_instance
|
||||
) as mock_cls:
|
||||
instance = manager._create_cache_instance()
|
||||
mock_cls.assert_called_once_with(
|
||||
redis_url="redis://localhost:6379",
|
||||
similarity_threshold=0.83,
|
||||
embedding_model="text-embedding-ada-002",
|
||||
)
|
||||
assert instance is mock_cache_instance
|
||||
|
||||
def test_default_threshold_is_087(self):
|
||||
"""from_cache_config 固定 similarity_threshold=0.87。"""
|
||||
old_config = CacheConfig(similarity_threshold=0.92)
|
||||
litellm_config = LitellmCacheConfig.from_cache_config(old_config)
|
||||
assert litellm_config.similarity_threshold == 0.87
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. User A 不返回 User B 缓存(安全要求 e)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUserIsolation:
|
||||
"""安全要求 e — User A 的查询不返回 User B 的缓存响应。"""
|
||||
|
||||
def test_different_users_different_keys(self):
|
||||
"""不同 user_id 产生不同 cache_key。"""
|
||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||
msgs = _make_messages("What is my salary?")
|
||||
key_a = manager.build_cache_key("gpt-4o", msgs, 0.0, user_id="user_a")
|
||||
key_b = manager.build_cache_key("gpt-4o", msgs, 0.0, user_id="user_b")
|
||||
assert key_a != key_b
|
||||
|
||||
def test_same_user_same_key(self):
|
||||
"""相同 user_id 产生相同 cache_key。"""
|
||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||
msgs = _make_messages("What is my salary?")
|
||||
key1 = manager.build_cache_key("gpt-4o", msgs, 0.0, user_id="user_a")
|
||||
key2 = manager.build_cache_key("gpt-4o", msgs, 0.0, user_id="user_a")
|
||||
assert key1 == key2
|
||||
|
||||
def test_no_user_id_same_as_no_user_id(self):
|
||||
"""user_id=None 时两次调用产生相同 key(向后兼容)。"""
|
||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||
msgs = _make_messages()
|
||||
key1 = manager.build_cache_key("gpt-4o", msgs, 0.0, user_id=None)
|
||||
key2 = manager.build_cache_key("gpt-4o", msgs, 0.0, user_id=None)
|
||||
assert key1 == key2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. kb_acl_hash 隔离
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestACLIsolation:
|
||||
"""安全要求 b — 不同 ACL scope 产生不同 cache_key。"""
|
||||
|
||||
def test_different_acl_hash_different_keys(self):
|
||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||
msgs = _make_messages("Summarize the document")
|
||||
key1 = manager.build_cache_key("gpt-4o", msgs, 0.0, kb_acl_hash="acl_v1")
|
||||
key2 = manager.build_cache_key("gpt-4o", msgs, 0.0, kb_acl_hash="acl_v2")
|
||||
assert key1 != key2
|
||||
|
||||
def test_same_acl_hash_same_key(self):
|
||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||
msgs = _make_messages()
|
||||
key1 = manager.build_cache_key("gpt-4o", msgs, 0.0, kb_acl_hash="acl_v1")
|
||||
key2 = manager.build_cache_key("gpt-4o", msgs, 0.0, kb_acl_hash="acl_v1")
|
||||
assert key1 == key2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. kb_caching_disabled 禁用缓存
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestKBCachingDisabled:
|
||||
"""安全要求 c — KB 设置 caching_disabled=True 时禁用缓存。"""
|
||||
|
||||
def test_should_cache_returns_false_when_disabled(self):
|
||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||
assert manager.should_cache(kb_caching_disabled=True) is False
|
||||
|
||||
def test_should_cache_returns_true_when_enabled(self):
|
||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||
assert manager.should_cache(kb_caching_disabled=False) is True
|
||||
|
||||
def test_should_cache_default_true(self):
|
||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||
assert manager.should_cache() is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. cache_params_for_hit / no_cache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCacheParams:
|
||||
def test_cache_params_for_hit(self):
|
||||
params = LitellmCacheManager.cache_params_for_hit("my_cache_key")
|
||||
assert params == {"cache_key": "my_cache_key"}
|
||||
|
||||
def test_cache_params_for_no_cache(self):
|
||||
params = LitellmCacheManager.cache_params_for_no_cache()
|
||||
assert params == {"no-cache": True}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 9. detect_cache_hit
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDetectCacheHit:
|
||||
def test_hit_with_cache_key(self):
|
||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||
resp = _make_litellm_response(cache_key="some_key")
|
||||
assert manager.detect_cache_hit(resp) is True
|
||||
|
||||
def test_hit_with_cache_hit_flag(self):
|
||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||
resp = _make_litellm_response()
|
||||
resp._hidden_params = {"cache_hit": True}
|
||||
assert manager.detect_cache_hit(resp) is True
|
||||
|
||||
def test_miss_without_cache_key(self):
|
||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||
resp = _make_litellm_response()
|
||||
assert manager.detect_cache_hit(resp) is False
|
||||
|
||||
def test_miss_with_no_hidden_params(self):
|
||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||
resp = SimpleNamespace(_hidden_params=None)
|
||||
assert manager.detect_cache_hit(resp) is False
|
||||
|
||||
def test_miss_with_no_hidden_params_attr(self):
|
||||
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||
resp = SimpleNamespace()
|
||||
assert manager.detect_cache_hit(resp) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 10. LitellmCacheConfig.from_cache_config
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFromCacheConfig:
|
||||
def test_basic_conversion(self):
|
||||
old = CacheConfig(
|
||||
enabled=True,
|
||||
backend="memory",
|
||||
redis_url="redis://myhost:6379",
|
||||
semantic_ttl=3600,
|
||||
embedding_model="bge-m3",
|
||||
)
|
||||
litellm_config = LitellmCacheConfig.from_cache_config(old)
|
||||
assert litellm_config.enabled is True
|
||||
assert litellm_config.backend == "memory"
|
||||
assert litellm_config.redis_url == "redis://myhost:6379"
|
||||
assert litellm_config.ttl == 3600
|
||||
assert litellm_config.similarity_threshold == 0.87 # 固定,忽略旧 0.92
|
||||
assert litellm_config.per_user_namespace is True # 强制开启
|
||||
|
||||
def test_unknown_backend_falls_to_auto(self):
|
||||
old = CacheConfig(backend="unknown_backend")
|
||||
litellm_config = LitellmCacheConfig.from_cache_config(old)
|
||||
assert litellm_config.backend == "auto"
|
||||
|
||||
def test_redis_backend_preserved(self):
|
||||
old = CacheConfig(backend="redis")
|
||||
litellm_config = LitellmCacheConfig.from_cache_config(old)
|
||||
assert litellm_config.backend == "redis"
|
||||
|
||||
def test_empty_embedding_model_falls_to_default(self):
|
||||
old = CacheConfig(embedding_model="")
|
||||
litellm_config = LitellmCacheConfig.from_cache_config(old)
|
||||
assert litellm_config.embedding_model == "text-embedding-ada-002"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 11. LitellmCacheManager.enable/disable
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnableDisable:
|
||||
def test_enable_sets_litellm_cache(self):
|
||||
import litellm
|
||||
|
||||
config = LitellmCacheConfig(backend="memory")
|
||||
manager = LitellmCacheManager(config)
|
||||
manager.enable()
|
||||
assert litellm.cache is not None
|
||||
assert manager._cache_instance is not None
|
||||
|
||||
def test_disable_clears_litellm_cache(self):
|
||||
import litellm
|
||||
|
||||
config = LitellmCacheConfig(backend="memory")
|
||||
manager = LitellmCacheManager(config)
|
||||
manager.enable()
|
||||
assert litellm.cache is not None
|
||||
manager.disable()
|
||||
assert litellm.cache is None
|
||||
assert manager._cache_instance is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 12. generate_cache_key 向后兼容
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBackwardCompatibility:
|
||||
"""user_id=None, kb_acl_hash=None 时与旧版 generate_cache_key 完全一致。"""
|
||||
|
||||
def test_none_user_id_same_as_not_passing(self):
|
||||
msgs = _make_messages()
|
||||
key1 = generate_cache_key("gpt-4o", msgs, 0.0, user_id=None, kb_acl_hash=None)
|
||||
key2 = generate_cache_key("gpt-4o", msgs, 0.0)
|
||||
assert key1 == key2
|
||||
|
||||
def test_backward_compat_deterministic(self):
|
||||
msgs = _make_messages()
|
||||
key1 = generate_cache_key("gpt-4o", msgs, 0.0, user_id=None, kb_acl_hash=None)
|
||||
key2 = generate_cache_key("gpt-4o", msgs, 0.0, user_id=None, kb_acl_hash=None)
|
||||
assert key1 == key2
|
||||
|
||||
def test_user_id_changes_key(self):
|
||||
msgs = _make_messages()
|
||||
key_none = generate_cache_key("gpt-4o", msgs, 0.0)
|
||||
key_user = generate_cache_key("gpt-4o", msgs, 0.0, user_id="user_a")
|
||||
assert key_none != key_user
|
||||
|
||||
def test_acl_hash_changes_key(self):
|
||||
msgs = _make_messages()
|
||||
key_none = generate_cache_key("gpt-4o", msgs, 0.0)
|
||||
key_acl = generate_cache_key("gpt-4o", msgs, 0.0, kb_acl_hash="acl_v1")
|
||||
assert key_none != key_acl
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 13. RedisSemanticCache fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRedisSemanticCacheFallback:
|
||||
"""redisvl 缺失时 auto backend 回退到 RedisCache。"""
|
||||
|
||||
def test_auto_fallback_to_redis_cache_when_redisvl_missing(self):
|
||||
"""auto 模式下 RedisSemanticCache 构造 ImportError → 回退到 RedisCache。"""
|
||||
import litellm.caching
|
||||
|
||||
config = LitellmCacheConfig(backend="auto", redis_url="redis://localhost:6379")
|
||||
manager = LitellmCacheManager(config)
|
||||
|
||||
mock_redis_cache = MagicMock(name="RedisCacheInstance")
|
||||
|
||||
# 模拟 RedisSemanticCache 构造时 ImportError(redisvl 未安装)
|
||||
def raise_import_error(*args, **kwargs):
|
||||
raise ImportError("No module named 'redisvl'")
|
||||
|
||||
with patch.object(litellm.caching, "RedisSemanticCache", side_effect=raise_import_error):
|
||||
with patch.object(
|
||||
litellm.caching, "RedisCache", return_value=mock_redis_cache
|
||||
) as mock_rc:
|
||||
instance = manager._create_cache_instance()
|
||||
mock_rc.assert_called_once_with(redis_url="redis://localhost:6379")
|
||||
assert instance is mock_redis_cache
|
||||
|
||||
def test_redis_semantic_backend_raises_when_redisvl_missing(self):
|
||||
"""显式 redis_semantic backend + redisvl 缺失 → raise ImportError。"""
|
||||
import litellm.caching
|
||||
|
||||
config = LitellmCacheConfig(backend="redis_semantic")
|
||||
manager = LitellmCacheManager(config)
|
||||
|
||||
def raise_import_error(*args, **kwargs):
|
||||
raise ImportError("No module named 'redisvl'")
|
||||
|
||||
with patch.object(litellm.caching, "RedisSemanticCache", side_effect=raise_import_error):
|
||||
with pytest.raises(ImportError):
|
||||
manager._create_cache_instance()
|
||||
|
||||
def test_memory_backend_uses_in_memory_cache(self):
|
||||
"""memory backend 直接使用 InMemoryCache,不尝试 Redis。"""
|
||||
import litellm.caching
|
||||
|
||||
config = LitellmCacheConfig(backend="memory")
|
||||
manager = LitellmCacheManager(config)
|
||||
instance = manager._create_cache_instance()
|
||||
assert isinstance(instance, litellm.caching.InMemoryCache)
|
||||
|
|
@ -1,15 +1,20 @@
|
|||
"""Integration tests for LLM Cache integration into LLMGateway (U2)."""
|
||||
"""Integration tests for LLM Cache integration into LLMGateway (U2/U17).
|
||||
|
||||
U17 更新:gateway 改用 ``LitellmCacheManager``(LiteLLM 内置缓存)。
|
||||
旧的 ``InMemoryLLMCache`` 手动缓存逻辑已移除,缓存读写由 LiteLLM 内部处理。
|
||||
测试用 ``CacheAwareMockProvider`` 模拟 LiteLLM 的缓存行为(检查 cache_key,
|
||||
命中时返回 ``cache_hit=True`` 的响应)。
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.llm.cache import InMemoryLLMCache
|
||||
from agentkit.llm.config import CacheConfig, LLMConfig
|
||||
from agentkit.llm.gateway import LLMGateway
|
||||
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage
|
||||
|
||||
|
||||
class MockProvider(LLMProvider):
|
||||
"""Mock LLM provider that tracks call count."""
|
||||
"""Mock LLM provider that tracks call count (no cache awareness)."""
|
||||
|
||||
def __init__(self, response_content: str = "Mock response"):
|
||||
self.call_count = 0
|
||||
|
|
@ -24,6 +29,47 @@ class MockProvider(LLMProvider):
|
|||
)
|
||||
|
||||
|
||||
class CacheAwareMockProvider(LLMProvider):
|
||||
"""Mock provider that simulates LiteLLM's cache behavior.
|
||||
|
||||
Reads ``request._cache`` for cache_key, maintains an internal cache dict.
|
||||
On cache hit: returns cached response with ``cache_hit=True`` (call_count NOT incremented).
|
||||
On cache miss: generates fresh response, caches it, returns with ``cache_hit=False``.
|
||||
"""
|
||||
|
||||
def __init__(self, response_content: str = "Mock response"):
|
||||
self.call_count = 0 # 仅统计真实调用(缓存未命中)
|
||||
self._response_content = response_content
|
||||
self._cache: dict[str, LLMResponse] = {}
|
||||
|
||||
async def chat(self, request: LLMRequest) -> LLMResponse:
|
||||
cache_params = getattr(request, "_cache", None) or {}
|
||||
cache_key = cache_params.get("cache_key")
|
||||
no_cache = cache_params.get("no-cache", False)
|
||||
|
||||
# 缓存命中 — 返回缓存响应(不增加 call_count)
|
||||
if cache_key and cache_key in self._cache and not no_cache:
|
||||
cached = self._cache[cache_key]
|
||||
return LLMResponse(
|
||||
content=cached.content,
|
||||
model=cached.model,
|
||||
usage=cached.usage,
|
||||
cache_hit=True,
|
||||
)
|
||||
|
||||
# 缓存未命中 — 真实调用
|
||||
self.call_count += 1
|
||||
response = LLMResponse(
|
||||
content=self._response_content,
|
||||
model=request.model,
|
||||
usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
|
||||
cache_hit=False,
|
||||
)
|
||||
if cache_key and not no_cache:
|
||||
self._cache[cache_key] = response
|
||||
return response
|
||||
|
||||
|
||||
def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]:
|
||||
return [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
|
|
@ -52,7 +98,7 @@ class TestCacheEnabled:
|
|||
"""First request is a cache miss — provider is called."""
|
||||
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
|
||||
gateway = LLMGateway(config=config)
|
||||
provider = MockProvider()
|
||||
provider = CacheAwareMockProvider()
|
||||
gateway.register_provider("test", provider)
|
||||
|
||||
msgs = _make_messages()
|
||||
|
|
@ -60,28 +106,30 @@ class TestCacheEnabled:
|
|||
|
||||
assert provider.call_count == 1
|
||||
assert response.content == "Mock response"
|
||||
assert response.cache_hit is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_second_request_is_hit(self):
|
||||
"""Second identical request is a cache hit — provider NOT called."""
|
||||
"""Second identical request is a cache hit — provider NOT called again."""
|
||||
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
|
||||
gateway = LLMGateway(config=config)
|
||||
provider = MockProvider()
|
||||
provider = CacheAwareMockProvider()
|
||||
gateway.register_provider("test", provider)
|
||||
|
||||
msgs = _make_messages()
|
||||
await gateway.chat(msgs, "test/model", temperature=0.0)
|
||||
response = await gateway.chat(msgs, "test/model", temperature=0.0)
|
||||
|
||||
assert provider.call_count == 1 # Not called again
|
||||
assert provider.call_count == 1 # Not called again (cache hit)
|
||||
assert response.content == "Mock response"
|
||||
assert response.cache_hit is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cache_hit_usage_has_zero_cost(self):
|
||||
"""Cache hit records usage with cost=0."""
|
||||
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
|
||||
gateway = LLMGateway(config=config)
|
||||
provider = MockProvider()
|
||||
provider = CacheAwareMockProvider()
|
||||
gateway.register_provider("test", provider)
|
||||
|
||||
msgs = _make_messages()
|
||||
|
|
@ -98,7 +146,7 @@ class TestCacheEnabled:
|
|||
"""Different messages produce cache misses."""
|
||||
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
|
||||
gateway = LLMGateway(config=config)
|
||||
provider = MockProvider()
|
||||
provider = CacheAwareMockProvider()
|
||||
gateway.register_provider("test", provider)
|
||||
|
||||
await gateway.chat(_make_messages("Hello"), "test/model", temperature=0.0)
|
||||
|
|
@ -110,13 +158,15 @@ class TestCacheEnabled:
|
|||
class TestCacheConfig:
|
||||
def test_config_from_dict(self):
|
||||
"""CacheConfig can be loaded from dict."""
|
||||
config = LLMConfig.from_dict({
|
||||
config = LLMConfig.from_dict(
|
||||
{
|
||||
"cache": {
|
||||
"enabled": True,
|
||||
"backend": "memory",
|
||||
"exact_ttl": 7200,
|
||||
}
|
||||
})
|
||||
}
|
||||
)
|
||||
assert config.cache is not None
|
||||
assert config.cache.enabled is True
|
||||
assert config.cache.backend == "memory"
|
||||
|
|
@ -129,7 +179,8 @@ class TestCacheConfig:
|
|||
|
||||
def test_config_from_dict_embedding(self):
|
||||
"""Embedding config is loaded correctly."""
|
||||
config = LLMConfig.from_dict({
|
||||
config = LLMConfig.from_dict(
|
||||
{
|
||||
"cache": {
|
||||
"enabled": True,
|
||||
"embedding": {
|
||||
|
|
@ -138,25 +189,25 @@ class TestCacheConfig:
|
|||
"base_url": "http://localhost:9997/v1",
|
||||
},
|
||||
}
|
||||
})
|
||||
}
|
||||
)
|
||||
assert config.cache.embedding_provider == "xinference"
|
||||
assert config.cache.embedding_model == "bge-m3"
|
||||
assert config.cache.embedding_base_url == "http://localhost:9997/v1"
|
||||
|
||||
def test_gateway_creates_cache_when_enabled(self):
|
||||
"""Gateway creates cache instance when cache.enabled=True."""
|
||||
"""Gateway creates cache_manager instance when cache.enabled=True."""
|
||||
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
|
||||
gateway = LLMGateway(config=config)
|
||||
assert gateway._cache is not None
|
||||
assert isinstance(gateway._cache, InMemoryLLMCache)
|
||||
assert gateway._cache_manager is not None
|
||||
|
||||
def test_gateway_no_cache_when_disabled(self):
|
||||
"""Gateway has no cache when cache is disabled."""
|
||||
"""Gateway has no cache_manager when cache is disabled."""
|
||||
config = LLMConfig(cache=CacheConfig(enabled=False))
|
||||
gateway = LLMGateway(config=config)
|
||||
assert gateway._cache is None
|
||||
assert gateway._cache_manager is None
|
||||
|
||||
def test_gateway_no_cache_when_no_config(self):
|
||||
"""Gateway has no cache when cache config is absent."""
|
||||
"""Gateway has no cache_manager when cache config is absent."""
|
||||
gateway = LLMGateway()
|
||||
assert gateway._cache is None
|
||||
assert gateway._cache_manager is None
|
||||
|
|
|
|||
Loading…
Reference in New Issue