825 lines
29 KiB
Python
825 lines
29 KiB
Python
"""LLM Response Cache — Exact-match + Semantic-match dual cache for LLM responses.
|
||
|
||
Architecture:
|
||
- LLMCache Protocol: async interface for cache backends
|
||
- InMemoryLLMCache: OrderedDict LRU + embedding index
|
||
- RedisLLMCache: Redis keys + SET index + lazy init
|
||
- create_llm_cache(): Factory with auto-detection
|
||
|
||
Design doc: docs/plans/2026-06-14-002-u1-llm-cache-architecture.md
|
||
"""
|
||
|
||
import json
|
||
import logging
|
||
import time
|
||
from collections import OrderedDict
|
||
from dataclasses import dataclass, field
|
||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||
|
||
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
|
||
from agentkit.utils.vector_math import compute_cosine_similarity
|
||
|
||
if TYPE_CHECKING:
|
||
from agentkit.llm.config import CacheConfig
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Data Classes
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
@dataclass
class CacheEntry:
    """One stored LLM response together with the cache's bookkeeping for it."""

    # The response object handed back to the caller on a hit.
    response: LLMResponse
    # Embedding of the originating query; an empty list means none was stored.
    query_embedding: list[float] = field(default_factory=list)
    # Timestamp written by the backend at store time.  The in-memory backend
    # stores time.monotonic(); the Redis backend stores time.time().
    created_at: float = 0.0
    # Number of hits recorded against this entry.
    hit_count: int = 0
|
||
|
||
|
||
@dataclass
class CacheResult:
    """Outcome of a single cache lookup."""

    # True when a cached response was found.
    hit: bool = False
    # The cached response on a hit; None on a miss.
    response: LLMResponse | None = None
    # "exact" or "semantic" on a hit; "" on a miss.
    match_type: str = ""
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Serialization helpers (for Redis backend)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _serialize_response(response: LLMResponse) -> dict:
    """Flatten an ``LLMResponse`` into a plain dict meant for JSON encoding.

    Only the fields listed below are carried over: content, model, the two
    token counters, the tool calls (id / name / arguments) and latency_ms.
    """
    payload: dict = {
        "content": response.content,
        "model": response.model,
    }
    token_usage = response.usage
    payload["usage"] = {
        "prompt_tokens": token_usage.prompt_tokens,
        "completion_tokens": token_usage.completion_tokens,
    }
    calls = []
    for call in response.tool_calls:
        calls.append({"id": call.id, "name": call.name, "arguments": call.arguments})
    payload["tool_calls"] = calls
    payload["latency_ms"] = response.latency_ms
    return payload
|
||
|
||
|
||
def _deserialize_response(data: dict) -> LLMResponse:
    """Rebuild an ``LLMResponse`` from the dict made by ``_serialize_response``.

    ``content`` and ``model`` are required (a missing key raises ``KeyError``).
    ``usage``, ``tool_calls`` and ``latency_ms`` are optional and fall back to
    zero counters, no tool calls and ``0.0`` respectively.
    """
    usage_fields = data.get("usage", {})
    content = data["content"]
    model = data["model"]

    token_usage = TokenUsage(
        prompt_tokens=usage_fields.get("prompt_tokens", 0),
        completion_tokens=usage_fields.get("completion_tokens", 0),
    )

    calls = []
    for raw_call in data.get("tool_calls", []):
        calls.append(
            ToolCall(
                id=raw_call["id"],
                name=raw_call["name"],
                arguments=raw_call["arguments"],
            )
        )

    return LLMResponse(
        content=content,
        model=model,
        usage=token_usage,
        tool_calls=calls,
        latency_ms=data.get("latency_ms", 0.0),
    )
|
||
|
||
|
||
def _serialize_entry(entry: CacheEntry) -> dict:
    """Turn a ``CacheEntry`` into a plain dict meant for JSON encoding.

    The nested response is converted with ``_serialize_response``; the other
    three fields are copied as-is.
    """
    encoded_response = _serialize_response(entry.response)
    return dict(
        response=encoded_response,
        query_embedding=entry.query_embedding,
        created_at=entry.created_at,
        hit_count=entry.hit_count,
    )
|
||
|
||
|
||
def _deserialize_entry(data: dict) -> CacheEntry:
    """Rebuild a ``CacheEntry`` from the dict made by ``_serialize_entry``.

    ``response`` is required (``KeyError`` if absent).  Missing bookkeeping
    fields default to an empty embedding, ``created_at=0.0`` and
    ``hit_count=0``.
    """
    restored_response = _deserialize_response(data["response"])
    embedding = data.get("query_embedding", [])
    stored_at = data.get("created_at", 0.0)
    hits = data.get("hit_count", 0)
    return CacheEntry(
        response=restored_response,
        query_embedding=embedding,
        created_at=stored_at,
        hit_count=hits,
    )
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# LLMCache Protocol
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
@runtime_checkable
class LLMCache(Protocol):
    """Async interface implemented by every LLM response cache backend.

    The class is ``runtime_checkable``, so ``isinstance(obj, LLMCache)`` works,
    but that check only confirms the methods exist, not their signatures.
    """

    async def get(self, key: str) -> CacheResult:
        """Look up ``key`` by exact match.

        Returns:
            A ``CacheResult``; ``hit`` is False on a miss.
        """
        ...

    async def semantic_search(
        self, query_embedding: list[float], threshold: float | None = None
    ) -> CacheResult:
        """Find the cached entry most similar to ``query_embedding``.

        Args:
            query_embedding: Embedding vector of the incoming query.
            threshold: Minimum similarity required for a hit.  ``None`` means
                "use the backend's configured threshold", which is how both
                backends in this module declare the parameter.

        Returns:
            A ``CacheResult``; ``hit`` is False when nothing is close enough.
        """
        ...

    async def put(
        self,
        key: str,
        response: LLMResponse,
        query_embedding: list[float] | None = None,
    ) -> None:
        """Store ``response`` under ``key``, optionally with its query embedding."""
        ...

    async def invalidate(self, pattern: str | None = None) -> int:
        """Remove cache entries and return how many were removed.

        Args:
            pattern: ``None`` removes everything; otherwise a key pattern
                selecting the entries to remove.
        """
        ...

    async def stats(self) -> dict[str, int]:
        """Return cache statistics as a name -> count mapping."""
        ...
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# InMemoryLLMCache
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class InMemoryLLMCache:
    """In-process LLM cache with LRU eviction and brute-force semantic search.

    Entries live in an ``OrderedDict`` kept in least-recently-used order (the
    front is the least recently used entry).  A parallel ``key -> embedding``
    dict holds the vectors that ``semantic_search`` scans linearly.

    Two TTLs apply to every entry, both measured with ``time.monotonic()``
    from the moment of the last ``put``:

    * ``exact_ttl`` bounds how long ``get`` may return the entry;
    * ``semantic_ttl`` bounds how long ``semantic_search`` may return it.

    An entry is physically removed only once neither lookup path can use it
    any more, so an exact-match expiry does not discard an entry that is still
    valid for semantic search.
    """

    # put() inspects at most this many least-recently-used entries for expiry,
    # so the cleanup cost per call stays constant regardless of cache size.
    _CLEANUP_SCAN_LIMIT = 20

    def __init__(
        self,
        max_entries: int = 10000,
        exact_ttl: int = 3600,
        semantic_ttl: int = 86400,
        similarity_threshold: float = 0.92,
    ):
        """Create an empty cache.

        Args:
            max_entries: Capacity; least-recently-used entries are evicted
                once it is exceeded.
            exact_ttl: Seconds an entry stays valid for exact-match ``get``.
            semantic_ttl: Seconds an entry stays valid for ``semantic_search``.
            similarity_threshold: Default minimum similarity for a semantic hit.
        """
        self._max_entries = max_entries
        self._exact_ttl = exact_ttl
        self._semantic_ttl = semantic_ttl
        self._similarity_threshold = similarity_threshold

        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._embeddings: dict[str, list[float]] = {}

        self._hits = 0
        self._misses = 0

    def _is_dead(self, key: str, entry: CacheEntry, now: float) -> bool:
        """Return True when neither ``get`` nor ``semantic_search`` can use ``entry``."""
        age = now - entry.created_at
        if age <= self._exact_ttl:
            return False
        # Past the exact window: only a live embedding keeps the entry useful.
        return age > self._semantic_ttl or key not in self._embeddings

    def _drop(self, key: str) -> None:
        """Remove ``key`` from both the entry store and the embedding index."""
        self._cache.pop(key, None)
        self._embeddings.pop(key, None)

    async def get(self, key: str) -> CacheResult:
        """Exact-match lookup.

        A hit refreshes the entry's LRU position and bumps its hit counter.
        An entry older than ``exact_ttl`` is reported as a miss; it is removed
        only if it is also unusable for semantic search.
        """
        now = time.monotonic()
        entry = self._cache.get(key)

        if entry is not None:
            if now - entry.created_at <= self._exact_ttl:
                self._cache.move_to_end(key)
                entry.hit_count += 1
                self._hits += 1
                return CacheResult(hit=True, response=entry.response, match_type="exact")
            if self._is_dead(key, entry, now):
                self._drop(key)

        self._misses += 1
        return CacheResult(hit=False)

    async def semantic_search(
        self, query_embedding: list[float], threshold: float | None = None
    ) -> CacheResult:
        """Return the most similar non-expired entry if it clears the threshold.

        Args:
            query_embedding: Embedding of the incoming query.
            threshold: Minimum similarity for a hit; ``None`` selects the
                threshold configured at construction.  An explicit ``0.0`` is
                honoured rather than being replaced by the default.

        Returns:
            A semantic hit, or a miss.  When the embedding index is empty the
            miss is returned without touching the hit/miss counters.
        """
        if not self._embeddings:
            return CacheResult(hit=False)

        effective_threshold = (
            self._similarity_threshold if threshold is None else threshold
        )
        now = time.monotonic()
        best_key: str | None = None
        best_sim: float = 0.0

        for key, emb in self._embeddings.items():
            entry = self._cache.get(key)
            if entry is None or now - entry.created_at > self._semantic_ttl:
                continue
            sim = compute_cosine_similarity(query_embedding, emb)
            if sim > best_sim:
                best_sim = sim
                best_key = key

        if best_key is not None and best_sim >= effective_threshold:
            entry = self._cache[best_key]
            entry.hit_count += 1
            self._cache.move_to_end(best_key)
            self._hits += 1
            return CacheResult(hit=True, response=entry.response, match_type="semantic")

        self._misses += 1
        return CacheResult(hit=False)

    async def put(
        self,
        key: str,
        response: LLMResponse,
        query_embedding: list[float] | None = None,
    ) -> None:
        """Store ``response`` under ``key`` and make it the most recently used entry.

        Args:
            key: Cache key.
            response: Response to cache.
            query_embedding: Embedding of the query.  ``None`` keeps the
                embedding already stored for ``key`` (if any); an empty list
                removes it from the semantic index.
        """
        now = time.monotonic()

        existing = self._cache.get(key)
        if existing is not None and query_embedding is None:
            # Preserve the existing embedding when the caller supplies none.
            effective_embedding = existing.query_embedding
        else:
            effective_embedding = query_embedding or []

        self._cache[key] = CacheEntry(
            response=response,
            query_embedding=effective_embedding,
            created_at=now,
            hit_count=0,
        )
        # Assigning to an existing key keeps its old position; refresh it.
        self._cache.move_to_end(key)

        if effective_embedding:
            self._embeddings[key] = effective_embedding
        else:
            # Never leave an older vector pointing at the new response.
            self._embeddings.pop(key, None)

        # Evict least-recently-used entries while over capacity.
        while len(self._cache) > self._max_entries:
            evicted_key, _ = self._cache.popitem(last=False)
            self._embeddings.pop(evicted_key, None)

        # Lazy cleanup: inspect a bounded number of entries from the LRU end
        # and drop the ones no lookup can use any more.  Keys are collected
        # first because the dict must not change size while being iterated.
        expired_keys = []
        for scanned, (k, entry) in enumerate(self._cache.items()):
            if scanned >= self._CLEANUP_SCAN_LIMIT:
                break
            if self._is_dead(k, entry, now):
                expired_keys.append(k)
        for k in expired_keys:
            self._drop(k)

    async def invalidate(self, pattern: str | None = None) -> int:
        """Remove entries and return how many were removed.

        Args:
            pattern: ``None`` clears the whole cache.  Otherwise every ``*``
                is stripped from the pattern and the remainder is used as a
                plain key prefix.
        """
        if pattern is None:
            count = len(self._cache)
            self._cache.clear()
            self._embeddings.clear()
            return count

        prefix = pattern.replace("*", "")
        keys_to_remove = [k for k in self._cache if k.startswith(prefix)]
        for key in keys_to_remove:
            self._drop(key)
        return len(keys_to_remove)

    async def stats(self) -> dict[str, int]:
        """Return the current entry count and the cumulative hit/miss counters."""
        return {
            "total_entries": len(self._cache),
            "total_hits": self._hits,
            "total_misses": self._misses,
        }
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# RedisLLMCache
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class RedisLLMCache:
    """Redis-backed LLM cache with a SET index for semantic search.

    Key schema (``{key}`` is the caller's cache key, documented as a sha256
    hex digest):
        agentkit:llm_cache:{key}      -> JSON(CacheEntry), TTL = max(exact_ttl, semantic_ttl)
        agentkit:llm_cache_emb:{key}  -> JSON(list[float]), TTL = semantic_ttl
        agentkit:llm_cache_index      -> SET of keys that were stored with an embedding

    The data key outlives ``exact_ttl`` so a semantic hit can still load its
    response; ``get`` therefore enforces ``exact_ttl`` itself using the
    wall-clock ``created_at`` stored in the entry.

    When a Redis operation fails the cache degrades to an in-memory fallback
    and optimistically retries Redis after every 100 fallback operations.
    ``max_entries`` is only forwarded to that fallback; nothing in this class
    limits the number of keys held in Redis.
    """

    KEY_PREFIX = "agentkit:llm_cache:"
    EMB_PREFIX = "agentkit:llm_cache_emb:"
    INDEX_KEY = "agentkit:llm_cache_index"

    # Number of fallback operations between two Redis recovery attempts.
    _RECOVERY_INTERVAL = 100
    # COUNT hint passed to SCAN while enumerating data keys in invalidate().
    _SCAN_COUNT = 500

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        max_entries: int = 10000,
        exact_ttl: int = 3600,
        semantic_ttl: int = 86400,
        similarity_threshold: float = 0.92,
        max_entries_to_scan: int = 500,
        fallback: InMemoryLLMCache | None = None,
    ):
        """Configure the cache; no connection is opened until first use.

        Args:
            redis_url: Redis connection URL.
            max_entries: Capacity of the in-memory fallback.
            exact_ttl: Seconds an entry stays valid for exact-match ``get``.
            semantic_ttl: Seconds an embedding stays valid for semantic search.
            similarity_threshold: Default minimum similarity for a semantic hit.
            max_entries_to_scan: Upper bound on embeddings fetched per
                semantic search; larger indexes are randomly sampled.
            fallback: Optional pre-built in-memory cache used while degraded;
                one is created on demand when omitted.
        """
        self._redis_url = redis_url
        self._max_entries = max_entries
        self._exact_ttl = exact_ttl
        self._semantic_ttl = semantic_ttl
        self._similarity_threshold = similarity_threshold
        self._max_entries_to_scan = max_entries_to_scan
        self._redis: Any = None
        self._fallback: InMemoryLLMCache | None = fallback  # For auto-degradation
        self._degraded = False  # True while Redis is considered unreachable
        self._degrade_count = 0  # Fallback operations since degradation began

        self._hits = 0
        self._misses = 0

    async def _get_redis(self):
        """Return the Redis client, creating it on first use (lazy import)."""
        if self._redis is None:
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
        return self._redis

    async def aclose(self) -> None:
        """Close the Redis client if one was created."""
        if self._redis is not None:
            await self._redis.aclose()
            self._redis = None

    def _degrade_to_fallback(self) -> None:
        """Mark Redis as unreachable and make sure a fallback cache exists."""
        if self._degraded:
            return
        self._degraded = True
        self._degrade_count = 0
        if self._fallback is None:
            self._fallback = InMemoryLLMCache(
                max_entries=self._max_entries,
                exact_ttl=self._exact_ttl,
                semantic_ttl=self._semantic_ttl,
                similarity_threshold=self._similarity_threshold,
            )
        logger.warning("Redis cache unreachable, degraded to in-memory fallback")

    def _try_recover(self) -> None:
        """Count one degraded operation and periodically clear the degraded flag.

        The flag is reset optimistically; the next real Redis operation
        verifies connectivity and re-triggers degradation if it fails.
        """
        if not self._degraded:
            return
        self._degrade_count += 1
        if self._degrade_count >= self._RECOVERY_INTERVAL:
            self._degrade_count = 0
            self._degraded = False
            logger.info("Redis cache: attempting recovery from degraded state")

    def _serving_from_fallback(self) -> bool:
        """Return True when the current operation must use the fallback cache.

        Also advances the recovery counter, so a periodic call returns False
        and lets the caller fall through to Redis.
        """
        if not self._degraded or self._fallback is None:
            return False
        self._try_recover()
        return self._degraded

    async def get(self, key: str) -> CacheResult:
        """Exact-match lookup.

        An entry older than ``exact_ttl`` (by its stored wall-clock
        ``created_at``) is a miss even though its data key still exists for
        semantic lookups.  Any Redis or decoding failure degrades to the
        fallback cache, which then answers the lookup.
        """
        if self._serving_from_fallback():
            return await self._fallback.get(key)

        try:
            redis = await self._get_redis()
            raw = await redis.get(f"{self.KEY_PREFIX}{key}")
            if raw is not None:
                entry = _deserialize_entry(json.loads(raw))
                if time.time() - entry.created_at <= self._exact_ttl:
                    self._hits += 1
                    return CacheResult(
                        hit=True, response=entry.response, match_type="exact"
                    )
            self._misses += 1
            return CacheResult(hit=False)
        except Exception as e:
            logger.warning("Redis cache get failed, returning miss: %s", e)
            self._degrade_to_fallback()
            if self._fallback is not None:
                return await self._fallback.get(key)
            return CacheResult(hit=False)

    async def semantic_search(
        self, query_embedding: list[float], threshold: float | None = None
    ) -> CacheResult:
        """Return the most similar indexed entry if it clears the threshold.

        Args:
            query_embedding: Embedding of the incoming query.
            threshold: Minimum similarity for a hit; ``None`` selects the
                configured threshold.  An explicit ``0.0`` is honoured.

        Returns:
            A semantic hit, or a miss.  An empty index returns a miss without
            touching the hit/miss counters.
        """
        if self._serving_from_fallback():
            return await self._fallback.semantic_search(query_embedding, threshold)

        try:
            redis = await self._get_redis()
            effective_threshold = (
                self._similarity_threshold if threshold is None else threshold
            )

            indexed = await redis.smembers(self.INDEX_KEY)
            if not indexed:
                return CacheResult(hit=False)

            # Bound the work per search.  A random sample (rather than a fixed
            # slice) avoids always scanning the same subset of a large index.
            candidates = list(indexed)
            scan_budget = max(0, self._max_entries_to_scan)
            if len(candidates) > scan_budget:
                import random

                candidates = random.sample(candidates, scan_budget)
            if not candidates:
                self._misses += 1
                return CacheResult(hit=False)

            # Batch-fetch the embeddings in one round trip.
            emb_values = await redis.mget(
                [f"{self.EMB_PREFIX}{k}" for k in candidates]
            )

            best_key: str | None = None
            best_sim: float = 0.0
            stale_keys: list[str] = []  # Indexed keys whose embedding expired

            for cache_key, emb_json in zip(candidates, emb_values):
                if emb_json is None:
                    stale_keys.append(cache_key)
                    continue
                sim = compute_cosine_similarity(query_embedding, json.loads(emb_json))
                if sim > best_sim:
                    best_sim = sim
                    best_key = cache_key

            # Lazy, best-effort cleanup of stale index members.
            if stale_keys:
                try:
                    pipe = redis.pipeline()
                    for k in stale_keys:
                        pipe.srem(self.INDEX_KEY, k)
                    await pipe.execute()
                except Exception as cleanup_error:
                    logger.debug(
                        "Redis cache: stale index cleanup failed: %s", cleanup_error
                    )

            if best_key is not None and best_sim >= effective_threshold:
                raw = await redis.get(f"{self.KEY_PREFIX}{best_key}")
                if raw is not None:
                    entry = _deserialize_entry(json.loads(raw))
                    self._hits += 1
                    return CacheResult(
                        hit=True, response=entry.response, match_type="semantic"
                    )
                # Data key is gone but the embedding survived: unindex it.
                try:
                    await redis.srem(self.INDEX_KEY, best_key)
                except Exception as cleanup_error:
                    logger.debug(
                        "Redis cache: orphan index cleanup failed: %s", cleanup_error
                    )

            self._misses += 1
            return CacheResult(hit=False)
        except Exception as e:
            logger.warning("Redis semantic search failed, returning miss: %s", e)
            self._degrade_to_fallback()
            if self._fallback is not None:
                return await self._fallback.semantic_search(query_embedding, threshold)
            self._misses += 1
            return CacheResult(hit=False)

    async def put(
        self,
        key: str,
        response: LLMResponse,
        query_embedding: list[float] | None = None,
    ) -> None:
        """Store ``response`` under ``key``; index it when an embedding is given.

        The data key, the embedding key and the index update are sent in one
        pipeline.  An absent or empty embedding stores the data key only.  On
        failure the cache degrades and the entry is written to the fallback.
        """
        if self._serving_from_fallback():
            await self._fallback.put(key, response, query_embedding)
            return

        try:
            redis = await self._get_redis()
            entry = CacheEntry(
                response=response,
                query_embedding=query_embedding or [],
                created_at=time.time(),  # Wall-clock: comparable across processes
                hit_count=0,
            )

            pipe = redis.pipeline()
            # The data key must cover both the exact and the semantic window,
            # otherwise a semantic hit could find its data already expired.
            data_ttl = max(self._exact_ttl, self._semantic_ttl)
            pipe.set(
                f"{self.KEY_PREFIX}{key}",
                json.dumps(_serialize_entry(entry)),
                ex=data_ttl,
            )
            if query_embedding:
                pipe.set(
                    f"{self.EMB_PREFIX}{key}",
                    json.dumps(query_embedding),
                    ex=self._semantic_ttl,
                )
                pipe.sadd(self.INDEX_KEY, key)
            await pipe.execute()
        except Exception as e:
            logger.warning("Redis cache put failed: %s", e)
            self._degrade_to_fallback()
            if self._fallback is not None:
                await self._fallback.put(key, response, query_embedding)

    async def invalidate(self, pattern: str | None = None) -> int:
        """Remove entries from Redis and from the fallback cache.

        Data keys are enumerated with SCAN, so entries stored without an
        embedding (which never enter the index) are removed too.  Index
        members whose data key has already expired are removed as well.

        Args:
            pattern: ``None`` removes everything.  Otherwise every ``*`` is
                stripped and the remainder is used as a plain key prefix.

        Returns:
            Number of entries removed (fallback entries plus distinct Redis
            keys).  If Redis fails, only the fallback count is returned.
        """
        prefix = "" if pattern is None else pattern.replace("*", "")

        removed = 0
        if self._fallback is not None:
            # Otherwise a degraded cache would keep serving invalidated data.
            removed += await self._fallback.invalidate(pattern)

        try:
            redis = await self._get_redis()

            doomed: set[str] = set()
            # Match on the fixed prefix only and filter in Python, so glob
            # characters inside the caller's prefix are never interpreted.
            async for full_key in redis.scan_iter(
                match=f"{self.KEY_PREFIX}*", count=self._SCAN_COUNT
            ):
                short_key = full_key[len(self.KEY_PREFIX):]
                if short_key.startswith(prefix):
                    doomed.add(short_key)
            for member in await redis.smembers(self.INDEX_KEY):
                if member.startswith(prefix):
                    doomed.add(member)

            if not doomed:
                return removed

            pipe = redis.pipeline()
            for short_key in doomed:
                pipe.delete(f"{self.KEY_PREFIX}{short_key}")
                pipe.delete(f"{self.EMB_PREFIX}{short_key}")
                if pattern is not None:
                    pipe.srem(self.INDEX_KEY, short_key)
            if pattern is None:
                pipe.delete(self.INDEX_KEY)
            await pipe.execute()
            return removed + len(doomed)
        except Exception as e:
            logger.warning("Redis cache invalidate failed: %s", e)
            return removed

    async def stats(self) -> dict[str, int]:
        """Return index size and this object's cumulative hit/miss counters.

        ``total_entries`` is the size of the embedding index and is reported
        as 0 when Redis cannot be reached.
        """
        try:
            redis = await self._get_redis()
            total_entries = await redis.scard(self.INDEX_KEY)
        except Exception:
            total_entries = 0
        return {
            "total_entries": total_entries,
            "total_hits": self._hits,
            "total_misses": self._misses,
        }
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# Factory
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def create_llm_cache(
    backend: str = "auto",
    redis_url: str = "redis://localhost:6379",
    max_entries: int = 10000,
    exact_ttl: int = 3600,
    semantic_ttl: int = 86400,
    similarity_threshold: float = 0.92,
) -> LLMCache:
    """Build the cache backend selected by ``backend``.

    Args:
        backend: "auto" or "redis" pick the Redis backend when the ``redis``
            package can be imported and the in-memory backend otherwise; any
            other value picks the in-memory backend.
        redis_url: Redis connection URL (Redis backend only).
        max_entries: Maximum number of cache entries.
        exact_ttl: TTL in seconds for exact-match cache entries.
        semantic_ttl: TTL in seconds for semantic-match embeddings.
        similarity_threshold: Cosine similarity threshold for semantic match.

    Returns:
        An ``LLMCache`` instance.  No connection is attempted here; only the
        importability of the ``redis`` package is checked.
    """
    # Settings shared by both backends.
    common = {
        "max_entries": max_entries,
        "exact_ttl": exact_ttl,
        "semantic_ttl": semantic_ttl,
        "similarity_threshold": similarity_threshold,
    }

    if backend not in ("auto", "redis"):
        return InMemoryLLMCache(**common)

    try:
        import redis.asyncio as aioredis  # noqa: F401
    except ImportError:
        logger.warning("redis package not available, falling back to in-memory cache")
        return InMemoryLLMCache(**common)

    return RedisLLMCache(redis_url=redis_url, **common)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# U17 — LiteLLM cache manager
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
@dataclass
class LitellmCacheConfig:
    """U17 — LiteLLM cache configuration (converted from ``CacheConfig``).

    Differences from the older ``CacheConfig``:
    - ``similarity_threshold`` defaults to a fixed 0.87 (set by the plan)
    - ``per_user_namespace`` is forced on (security requirement a)
    - ``backend`` gains a ``redis_semantic`` option (requires redisvl)
    """

    enabled: bool = False
    # One of "auto", "redis_semantic", "redis", "memory".
    backend: str = "auto"
    redis_url: str = "redis://localhost:6379"
    # U17 default of 0.87 (set by the plan).
    similarity_threshold: float = 0.87
    ttl: int = 86400
    embedding_model: str = "text-embedding-ada-002"
    # Security requirement (a).
    per_user_namespace: bool = True

    @classmethod
    def from_cache_config(cls, c: "CacheConfig") -> "LitellmCacheConfig":
        """Convert an existing ``CacheConfig``.

        - ``similarity_threshold`` is pinned to 0.87 (U17 plan; the old 0.92
          is ignored)
        - ``per_user_namespace`` is forced to True (security requirement a)
        - ``embedding_model`` falls back to "text-embedding-ada-002" when the
          source value is empty
        - a ``backend`` other than "auto", "redis" or "memory" becomes "auto"
        """
        enabled = c.enabled
        backend = c.backend if c.backend in ("auto", "redis", "memory") else "auto"
        redis_url = c.redis_url
        ttl = c.semantic_ttl
        model = c.embedding_model or "text-embedding-ada-002"
        return cls(
            enabled=enabled,
            backend=backend,
            redis_url=redis_url,
            similarity_threshold=0.87,
            ttl=ttl,
            embedding_model=model,
            per_user_namespace=True,
        )
|
||
|
||
|
||
class LitellmCacheManager:
    """U17 — manager for LiteLLM's global cache.

    Responsibilities:
    1. Create the cache instance and install it as ``litellm.cache``
    2. Build cache keys scoped by user / ACL (security requirements a, b)
    3. Provide per-call cache parameters (cache_key or no-cache)
    4. Detect the cache-hit marker on LiteLLM responses (for usage tracking)
    5. Keep cache hit / miss counters

    Backend priority when backend="auto":
        RedisSemanticCache (needs redisvl) → RedisCache (exact) → InMemoryCache

    Security constraints:
    - (a) the cache key contains user_id (per-user namespace)
    - (b) the cache key contains kb_acl_hash (ACL-scope isolation)
    - (c) caching is disabled when a KB sets caching_disabled=True
    - (e) User A's query never hits User B's cache
    """

    def __init__(self, config: LitellmCacheConfig):
        """Keep ``config`` and start with no cache installed and zeroed counters."""
        self._config = config
        self._cache_instance: Any = None  # object installed as litellm.cache
        self._hits = 0
        self._misses = 0

    def enable(self) -> None:
        """Create the cache instance and assign it to ``litellm.cache``."""
        import litellm

        # NOTE(review): the backend object itself is installed as the global
        # cache — confirm LiteLLM accepts it without a Cache wrapper.
        instance = self._create_cache_instance()
        self._cache_instance = instance
        litellm.cache = instance

    def disable(self) -> None:
        """Turn caching off by setting ``litellm.cache = None``."""
        import litellm

        litellm.cache = None
        self._cache_instance = None

    def _create_cache_instance(self) -> Any:
        """Create the LiteLLM cache backend selected by the configured backend.

        "auto" tries RedisSemanticCache → RedisCache → InMemoryCache in that
        order.  A failed import falls through to the next option (security
        requirement d — redisvl is not a hard dependency), except that an
        explicit "redis_semantic" backend re-raises the ``ImportError``.
        """
        # NOTE(review): config.ttl is not forwarded to any backend below —
        # confirm whether that is intended.
        backend = self._config.backend
        wants_semantic = backend in ("auto", "redis_semantic")
        wants_redis = wants_semantic or backend == "redis"

        if wants_semantic:
            # Lazy import: fall back when the semantic backend is unavailable.
            try:
                from litellm.caching import RedisSemanticCache

                return RedisSemanticCache(
                    redis_url=self._config.redis_url,
                    similarity_threshold=self._config.similarity_threshold,
                    embedding_model=self._config.embedding_model,
                )
            except ImportError:
                logger.warning(
                    "RedisSemanticCache 需要 redisvl 包(未安装),"
                    "回退到 RedisCache(精确匹配,无语义搜索)。"
                    "安装 redisvl 以启用语义缓存:pip install redisvl"
                )
                if backend == "redis_semantic":
                    # Semantic caching was requested explicitly: surface the error.
                    raise
            except Exception as e:
                logger.warning(f"RedisSemanticCache 初始化失败: {e},回退到 RedisCache")

        if wants_redis:
            try:
                from litellm.caching import RedisCache

                return RedisCache(redis_url=self._config.redis_url)
            except Exception as e:
                logger.warning(f"RedisCache 初始化失败: {e},回退到 InMemoryCache")

        from litellm.caching import InMemoryCache

        return InMemoryCache()

    def build_cache_key(
        self,
        model: str,
        messages: list[dict[str, str]],
        temperature: float,
        tools: list[dict] | None = None,
        tool_choice: str = "auto",
        max_tokens: int = 2000,
        user_id: str | None = None,
        kb_acl_hash: str | None = None,
    ) -> str:
        """Build a cache key scoped by user and ACL (security requirements a, b, e).

        Delegates to ``cache_key.generate_cache_key``, passing ``user_id`` and
        ``kb_acl_hash`` along with the request parameters so they act as a
        namespace and User A's query does not hit User B's cache.
        """
        from agentkit.llm.cache_key import generate_cache_key

        key_fields = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "tools": tools,
            "tool_choice": tool_choice,
            "max_tokens": max_tokens,
            "user_id": user_id,
            "kb_acl_hash": kb_acl_hash,
        }
        return generate_cache_key(**key_fields)

    def should_cache(
        self,
        kb_caching_disabled: bool = False,
        user_id: str | None = None,
    ) -> bool:
        """Decide whether the current request may be cached (security requirement c).

        - a KB with caching_disabled=True → do not cache
        - everything else is cached (a ``None`` user_id is still cacheable,
          but the key then carries no user scope)

        ``user_id`` is currently unused; it is reserved for a future per-user
        cache opt-out.
        """
        return not kb_caching_disabled

    @staticmethod
    def cache_params_for_hit(cache_key: str) -> dict[str, str]:
        """Return the litellm acompletion cache parameters for a cacheable call."""
        return {"cache_key": cache_key}

    @staticmethod
    def cache_params_for_no_cache() -> dict[str, bool]:
        """Return the litellm acompletion cache parameters that disable caching."""
        return {"no-cache": True}

    def detect_cache_hit(self, response: Any) -> bool:
        """Report whether ``response`` came from the cache and update the counters.

        A response counts as a hit when its ``_hidden_params`` dict contains a
        "cache_key" entry (whatever its value) or a truthy "cache_hit" entry.
        NOTE(review): confirm against LiteLLM which of these it sets on a hit.
        """
        hidden = getattr(response, "_hidden_params", None)
        was_hit = isinstance(hidden, dict) and (
            "cache_key" in hidden or bool(hidden.get("cache_hit"))
        )
        if was_hit:
            self._hits += 1
        else:
            self._misses += 1
        return was_hit

    def stats(self) -> dict[str, int]:
        """Return the hit and miss counters."""
        return dict(total_hits=self._hits, total_misses=self._misses)
|