fischer-agentkit/src/agentkit/llm/cache.py

825 lines
29 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LLM Response Cache — Exact-match + Semantic-match dual cache for LLM responses.
Architecture:
- LLMCache Protocol: async interface for cache backends
- InMemoryLLMCache: OrderedDict LRU + embedding index
- RedisLLMCache: Redis keys + SET index + lazy init
- create_llm_cache(): Factory with auto-detection
Design doc: docs/plans/2026-06-14-002-u1-llm-cache-architecture.md
"""
import json
import logging
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
from agentkit.utils.vector_math import compute_cosine_similarity
if TYPE_CHECKING:
from agentkit.llm.config import CacheConfig
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Data Classes
# ---------------------------------------------------------------------------
@dataclass
class CacheEntry:
    """A single stored LLM response together with its cache bookkeeping."""

    # The response payload handed back to callers on a hit.
    response: LLMResponse
    # Embedding of the originating query; stays empty when none was supplied.
    query_embedding: list[float] = field(default_factory=list)
    # Creation timestamp written by the storing backend (the clock used is
    # backend-specific: monotonic in memory, wall-clock in Redis).
    created_at: float = 0.0
    # Counter of lookups answered from this entry.
    hit_count: int = 0
@dataclass
class CacheResult:
    """Outcome of one cache lookup (defaults describe a miss)."""

    # True when a usable entry was found.
    hit: bool = False
    # The cached response on a hit; None on a miss.
    response: LLMResponse | None = None
    # How the entry was matched: "exact" | "semantic" | "" (miss)
    match_type: str = ""
# ---------------------------------------------------------------------------
# Serialization helpers (for Redis backend)
# ---------------------------------------------------------------------------
def _serialize_response(response: LLMResponse) -> dict:
"""Serialize LLMResponse to a JSON-compatible dict."""
return {
"content": response.content,
"model": response.model,
"usage": {
"prompt_tokens": response.usage.prompt_tokens,
"completion_tokens": response.usage.completion_tokens,
},
"tool_calls": [
{"id": tc.id, "name": tc.name, "arguments": tc.arguments} for tc in response.tool_calls
],
"latency_ms": response.latency_ms,
}
def _deserialize_response(data: dict) -> LLMResponse:
    """Rebuild an LLMResponse from its serialized dict form.

    ``content`` and ``model`` are required keys (``KeyError`` if absent);
    usage counters, tool calls and latency fall back to 0 / [] / 0.0.
    """
    # Required keys are read first so a malformed payload fails on them
    # before anything else is constructed.
    content = data["content"]
    model = data["model"]

    raw_usage = data.get("usage", {})
    token_usage = TokenUsage(
        prompt_tokens=raw_usage.get("prompt_tokens", 0),
        completion_tokens=raw_usage.get("completion_tokens", 0),
    )

    calls = []
    for raw_call in data.get("tool_calls", []):
        calls.append(
            ToolCall(
                id=raw_call["id"],
                name=raw_call["name"],
                arguments=raw_call["arguments"],
            )
        )

    return LLMResponse(
        content=content,
        model=model,
        usage=token_usage,
        tool_calls=calls,
        latency_ms=data.get("latency_ms", 0.0),
    )
def _serialize_entry(entry: CacheEntry) -> dict:
    """Turn a CacheEntry into a dict that ``json.dumps`` can encode."""
    serialized: dict = {"response": _serialize_response(entry.response)}
    serialized["query_embedding"] = entry.query_embedding
    serialized["created_at"] = entry.created_at
    serialized["hit_count"] = entry.hit_count
    return serialized
def _deserialize_entry(data: dict) -> CacheEntry:
    """Rebuild a CacheEntry from its serialized dict form.

    ``response`` is required; the remaining fields fall back to the
    dataclass-equivalent defaults when missing.
    """
    response = _deserialize_response(data["response"])
    embedding = data.get("query_embedding", [])
    created = data.get("created_at", 0.0)
    hits = data.get("hit_count", 0)
    return CacheEntry(
        response=response,
        query_embedding=embedding,
        created_at=created,
        hit_count=hits,
    )
# ---------------------------------------------------------------------------
# LLMCache Protocol
# ---------------------------------------------------------------------------
@runtime_checkable
class LLMCache(Protocol):
    """Async interface every LLM response cache backend implements."""

    async def get(self, key: str) -> CacheResult:
        """Exact-match lookup by cache key."""
        ...

    async def semantic_search(
        self, query_embedding: list[float], threshold: float | None = None
    ) -> CacheResult:
        """Semantic similarity search across cached entries.

        ``threshold=None`` lets the backend apply its own configured
        similarity threshold, which is what both backends in this module do.
        """
        ...

    async def put(
        self,
        key: str,
        response: LLMResponse,
        query_embedding: list[float] | None = None,
    ) -> None:
        """Store a response in the cache with optional embedding."""
        ...

    async def invalidate(self, pattern: str | None = None) -> int:
        """Invalidate cache entries. Returns count of invalidated entries."""
        ...

    async def stats(self) -> dict[str, int]:
        """Return cache statistics."""
        ...
# ---------------------------------------------------------------------------
# InMemoryLLMCache
# ---------------------------------------------------------------------------
class InMemoryLLMCache:
    """In-memory LLM cache with LRU eviction and semantic search.

    Entries live in an ``OrderedDict`` ordered from least- to most-recently
    used. A parallel ``key -> embedding`` dict is scanned linearly for
    semantic lookups. Timestamps come from ``time.monotonic()``, so they are
    only meaningful inside the current process.
    """

    # How many of the least-recently-used entries put() inspects for expiry.
    _CLEANUP_SCAN_LIMIT = 20

    def __init__(
        self,
        max_entries: int = 10000,
        exact_ttl: int = 3600,
        semantic_ttl: int = 86400,
        similarity_threshold: float = 0.92,
    ):
        """Configure the cache.

        Args:
            max_entries: Capacity; least-recently-used entries are evicted
                once it is exceeded.
            exact_ttl: Maximum age in seconds for an exact-match hit.
            semantic_ttl: Maximum age in seconds for a semantic hit.
            similarity_threshold: Default cosine-similarity cut-off used by
                ``semantic_search`` when no threshold is passed.
        """
        self._max_entries = max_entries
        self._exact_ttl = exact_ttl
        self._semantic_ttl = semantic_ttl
        self._similarity_threshold = similarity_threshold
        self._cache: OrderedDict[str, CacheEntry] = OrderedDict()
        self._embeddings: dict[str, list[float]] = {}
        self._hits = 0
        self._misses = 0

    async def get(self, key: str) -> CacheResult:
        """Exact-match lookup; a hit refreshes the entry's LRU position."""
        now = time.monotonic()
        entry = self._cache.get(key)
        if entry is not None:
            age = now - entry.created_at
            if age <= self._exact_ttl:
                # Hit: update LRU position and stats
                self._cache.move_to_end(key)
                entry.hit_count += 1
                self._hits += 1
                return CacheResult(hit=True, response=entry.response, match_type="exact")
            # Too old for an exact hit. Only drop the entry when it can no
            # longer serve semantic lookups either; otherwise an exact miss
            # would cut the (usually longer) semantic window short.
            if key not in self._embeddings or age > self._semantic_ttl:
                del self._cache[key]
                self._embeddings.pop(key, None)
        self._misses += 1
        return CacheResult(hit=False)

    async def semantic_search(
        self, query_embedding: list[float], threshold: float | None = None
    ) -> CacheResult:
        """Return the most similar non-expired entry at or above the threshold.

        Args:
            query_embedding: Embedding of the incoming query.
            threshold: Minimum cosine similarity; ``None`` uses the
                configured default. An explicit ``0.0`` is honoured.
        """
        if not self._embeddings:
            # Nothing indexed: reported as a miss but not counted in stats.
            return CacheResult(hit=False)

        effective_threshold = (
            self._similarity_threshold if threshold is None else threshold
        )
        now = time.monotonic()
        best_key: str | None = None
        best_sim: float = 0.0
        for key, emb in self._embeddings.items():
            entry = self._cache.get(key)
            # Skip orphaned index rows and entries past the semantic TTL.
            if entry is None or now - entry.created_at > self._semantic_ttl:
                continue
            sim = compute_cosine_similarity(query_embedding, emb)
            if sim > best_sim:
                best_sim = sim
                best_key = key

        if best_key is not None and best_sim >= effective_threshold:
            entry = self._cache[best_key]
            entry.hit_count += 1
            self._cache.move_to_end(best_key)
            self._hits += 1
            return CacheResult(hit=True, response=entry.response, match_type="semantic")

        self._misses += 1
        return CacheResult(hit=False)

    async def put(
        self,
        key: str,
        response: LLMResponse,
        query_embedding: list[float] | None = None,
    ) -> None:
        """Insert or replace an entry, then evict and lazily expire.

        When the key already exists and ``query_embedding`` is ``None`` the
        previously stored embedding is kept.
        """
        now = time.monotonic()
        if key in self._cache:
            # Re-assigning an existing key keeps its position, so move it
            # to the most-recently-used end explicitly.
            self._cache.move_to_end(key)
            existing = self._cache[key]
            # Preserve existing embedding if new one is None
            effective_embedding = (
                query_embedding if query_embedding is not None else existing.query_embedding
            )
        else:
            effective_embedding = query_embedding or []

        self._cache[key] = CacheEntry(
            response=response,
            query_embedding=effective_embedding,
            created_at=now,
            hit_count=0,
        )
        if effective_embedding:
            self._embeddings[key] = effective_embedding
        else:
            # An explicit empty embedding must not leave the previous
            # vector searchable for this key.
            self._embeddings.pop(key, None)

        # Evict LRU entries if over capacity
        while len(self._cache) > self._max_entries:
            evicted_key, _ = self._cache.popitem(last=False)
            self._embeddings.pop(evicted_key, None)

        # Lazy cleanup: look at a bounded number of entries from the LRU end
        # and drop the ones past the semantic TTL. Keys are collected first
        # because the dict must not change size while it is being iterated.
        expired_keys = []
        for position, (candidate, entry) in enumerate(self._cache.items()):
            if position >= self._CLEANUP_SCAN_LIMIT:
                break
            if now - entry.created_at > self._semantic_ttl:
                expired_keys.append(candidate)
        for candidate in expired_keys:
            self._cache.pop(candidate, None)
            self._embeddings.pop(candidate, None)

    async def invalidate(self, pattern: str | None = None) -> int:
        """Remove entries and return how many were removed.

        ``pattern=None`` clears everything. Otherwise every ``*`` is stripped
        from the pattern and the remainder is used as a key prefix.
        """
        if pattern is None:
            count = len(self._cache)
            self._cache.clear()
            self._embeddings.clear()
            return count

        # Simple prefix matching for pattern (computed once, not per key)
        prefix = pattern.replace("*", "")
        keys_to_remove = [k for k in self._cache if k.startswith(prefix)]
        for key in keys_to_remove:
            del self._cache[key]
            self._embeddings.pop(key, None)
        return len(keys_to_remove)

    async def stats(self) -> dict[str, int]:
        """Return entry count and cumulative hit/miss counters."""
        return {
            "total_entries": len(self._cache),
            "total_hits": self._hits,
            "total_misses": self._misses,
        }
# ---------------------------------------------------------------------------
# RedisLLMCache
# ---------------------------------------------------------------------------
class RedisLLMCache:
    """Redis-backed LLM cache with SET index for semantic search.

    Key schema:
        agentkit:llm_cache:{sha256_hex} → JSON(CacheEntry) with TTL
        agentkit:llm_cache_emb:{sha256_hex} → JSON(list[float]) with TTL
        agentkit:llm_cache_index → SET of active cache keys

    When a Redis operation in ``get``/``semantic_search``/``put`` fails, the
    cache degrades to an ``InMemoryLLMCache`` fallback and periodically
    retries Redis.
    """

    KEY_PREFIX = "agentkit:llm_cache:"
    EMB_PREFIX = "agentkit:llm_cache_emb:"
    INDEX_KEY = "agentkit:llm_cache_index"
    # Degraded-mode operations served by the fallback before Redis is retried.
    _RECOVERY_INTERVAL = 100

    def __init__(
        self,
        redis_url: str = "redis://localhost:6379",
        max_entries: int = 10000,
        exact_ttl: int = 3600,
        semantic_ttl: int = 86400,
        similarity_threshold: float = 0.92,
        max_entries_to_scan: int = 500,
        fallback: InMemoryLLMCache | None = None,
    ):
        """Store configuration; no Redis connection is opened here.

        Args:
            redis_url: Redis connection URL.
            max_entries: Capacity passed to the in-memory fallback.
            exact_ttl: Maximum age in seconds for an exact-match hit.
            semantic_ttl: TTL in seconds of the embedding keys.
            similarity_threshold: Default cosine-similarity cut-off.
            max_entries_to_scan: Upper bound on index keys compared per
                semantic search.
            fallback: Optional pre-built in-memory cache used when degraded;
                one is created on demand otherwise.
        """
        self._redis_url = redis_url
        self._max_entries = max_entries
        self._exact_ttl = exact_ttl
        self._semantic_ttl = semantic_ttl
        self._similarity_threshold = similarity_threshold
        self._max_entries_to_scan = max_entries_to_scan
        self._redis: Any = None
        self._fallback: InMemoryLLMCache | None = fallback  # For auto-degradation
        self._degraded = False  # True if Redis is unreachable
        self._degrade_count = 0  # Operations served since degradation began
        self._hits = 0
        self._misses = 0

    async def _get_redis(self) -> Any:
        """Return the Redis client, creating it on first use."""
        if self._redis is None:
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
        return self._redis

    async def aclose(self) -> None:
        """Close the Redis connection pool."""
        if self._redis is not None:
            await self._redis.aclose()
            self._redis = None

    def _degrade_to_fallback(self) -> None:
        """Mark Redis as unreachable and switch to in-memory fallback."""
        if self._degraded:
            return
        self._degraded = True
        self._degrade_count = 0
        if self._fallback is None:
            self._fallback = InMemoryLLMCache(
                max_entries=self._max_entries,
                exact_ttl=self._exact_ttl,
                semantic_ttl=self._semantic_ttl,
                similarity_threshold=self._similarity_threshold,
            )
        logger.warning("Redis cache unreachable, degraded to in-memory fallback")

    def _try_recover(self) -> None:
        """Attempt to recover from degraded state after enough operations.

        Resets the degraded flag optimistically. The next actual Redis
        operation will verify connectivity — if it fails, degradation
        is re-triggered immediately.
        """
        if not self._degraded:
            return
        self._degrade_count += 1
        if self._degrade_count >= self._RECOVERY_INTERVAL:
            self._degrade_count = 0
            self._degraded = False
            logger.info("Redis cache: attempting recovery from degraded state")

    def _serving_from_fallback(self) -> bool:
        """Return True when the current call should be answered by the fallback.

        Also advances the recovery counter, so this may flip the cache back
        to Redis (returning False) once enough degraded calls were served.
        """
        if not self._degraded or self._fallback is None:
            return False
        self._try_recover()
        return self._degraded

    def _exact_expired(self, entry: CacheEntry) -> bool:
        """Return True when *entry* is older than the exact-match TTL.

        The data key lives for ``max(exact_ttl, semantic_ttl)``, so the Redis
        TTL alone cannot enforce the exact window. Entries without a recorded
        timestamp (``created_at == 0``) are treated as still valid.
        """
        if not entry.created_at:
            return False
        return time.time() - entry.created_at > self._exact_ttl

    async def _drop_from_index(self, redis: Any, keys: list[str]) -> None:
        """Best-effort removal of *keys* from the index SET; failures are only logged."""
        try:
            await redis.srem(self.INDEX_KEY, *keys)
        except Exception as e:
            logger.debug("Redis cache index cleanup failed: %s", e)

    async def get(self, key: str) -> CacheResult:
        """Exact-match lookup, served by Redis or by the fallback when degraded."""
        if self._serving_from_fallback():
            return await self._fallback.get(key)
        # Not degraded, or recovery was just attempted — try Redis.
        try:
            redis = await self._get_redis()
            data = await redis.get(f"{self.KEY_PREFIX}{key}")
            if data is not None:
                entry = _deserialize_entry(json.loads(data))
                if not self._exact_expired(entry):
                    self._hits += 1
                    return CacheResult(hit=True, response=entry.response, match_type="exact")
            self._misses += 1
            return CacheResult(hit=False)
        except Exception as e:
            logger.warning(f"Redis cache get failed, returning miss: {e}")
            self._degrade_to_fallback()
            if self._fallback is not None:
                return await self._fallback.get(key)
            return CacheResult(hit=False)

    async def semantic_search(
        self, query_embedding: list[float], threshold: float | None = None
    ) -> CacheResult:
        """Return the most similar indexed entry at or above the threshold.

        Args:
            query_embedding: Embedding of the incoming query.
            threshold: Minimum cosine similarity; ``None`` uses the
                configured default. An explicit ``0.0`` is honoured.
        """
        # Same degraded-mode routing as get()/put(), so a dead Redis is not
        # contacted on every semantic lookup.
        if self._serving_from_fallback():
            return await self._fallback.semantic_search(query_embedding, threshold)
        try:
            redis = await self._get_redis()
            effective_threshold = (
                self._similarity_threshold if threshold is None else threshold
            )

            # Get all cache keys from index
            cache_keys = await redis.smembers(self.INDEX_KEY)
            if not cache_keys:
                return CacheResult(hit=False)

            # Bound the number of embeddings fetched and compared. The subset
            # is a uniform random sample (not the most recent keys), so
            # repeated searches do not always scan the same entries.
            cache_keys_list = list(cache_keys)
            if len(cache_keys_list) > self._max_entries_to_scan:
                import random

                cache_keys_list = random.sample(cache_keys_list, self._max_entries_to_scan)

            # Batch fetch embeddings
            emb_keys = [f"{self.EMB_PREFIX}{k}" for k in cache_keys_list]
            emb_values = await redis.mget(emb_keys)

            best_key: str | None = None
            best_sim: float = 0.0
            stale_keys: list[str] = []  # Index members without an embedding key
            for cache_key, emb_json in zip(cache_keys_list, emb_values):
                if emb_json is None:
                    stale_keys.append(cache_key)
                    continue
                sim = compute_cosine_similarity(query_embedding, json.loads(emb_json))
                if sim > best_sim:
                    best_sim = sim
                    best_key = cache_key

            # Lazy cleanup: remove index members whose embedding is gone.
            # NOTE(review): keys stored without an embedding are dropped from
            # the index here too, after which invalidate()/stats() no longer
            # see them — confirm this is intended.
            if stale_keys:
                await self._drop_from_index(redis, stale_keys)

            if best_key is not None and best_sim >= effective_threshold:
                data = await redis.get(f"{self.KEY_PREFIX}{best_key}")
                if data is not None:
                    entry = _deserialize_entry(json.loads(data))
                    self._hits += 1
                    return CacheResult(hit=True, response=entry.response, match_type="semantic")
                # Data key is gone but the embedding still exists — unindex it.
                await self._drop_from_index(redis, [best_key])

            self._misses += 1
            return CacheResult(hit=False)
        except Exception as e:
            logger.warning(f"Redis semantic search failed, returning miss: {e}")
            self._degrade_to_fallback()
            if self._fallback is not None:
                return await self._fallback.semantic_search(query_embedding, threshold)
            self._misses += 1
            return CacheResult(hit=False)

    async def put(
        self,
        key: str,
        response: LLMResponse,
        query_embedding: list[float] | None = None,
    ) -> None:
        """Store a response (and its embedding, when non-empty) under *key*."""
        if self._serving_from_fallback():
            await self._fallback.put(key, response, query_embedding)
            return
        # Not degraded, or recovery was just attempted — try Redis.
        try:
            redis = await self._get_redis()
            entry = CacheEntry(
                response=response,
                query_embedding=query_embedding or [],
                created_at=time.time(),  # Wall-clock for cross-process comparability in Redis
                hit_count=0,
            )
            pipe = redis.pipeline()
            # Data key TTL must cover both exact and semantic windows
            # so semantic hits don't return None data
            data_ttl = max(self._exact_ttl, self._semantic_ttl)
            pipe.set(
                f"{self.KEY_PREFIX}{key}",
                json.dumps(_serialize_entry(entry)),
                ex=data_ttl,
            )
            if query_embedding:
                pipe.set(
                    f"{self.EMB_PREFIX}{key}",
                    json.dumps(query_embedding),
                    ex=self._semantic_ttl,
                )
            elif query_embedding is not None:
                # An explicit empty embedding: never store an empty vector
                # for similarity scoring, and drop any previous one.
                pipe.delete(f"{self.EMB_PREFIX}{key}")
            pipe.sadd(self.INDEX_KEY, key)
            await pipe.execute()
        except Exception as e:
            logger.warning(f"Redis cache put failed: {e}")
            self._degrade_to_fallback()
            if self._fallback is not None:
                await self._fallback.put(key, response, query_embedding)

    async def invalidate(self, pattern: str | None = None) -> int:
        """Remove entries from Redis and from the fallback, if one exists.

        ``pattern=None`` clears every indexed key. Otherwise every ``*`` is
        stripped from the pattern and the remainder is used as a key prefix.
        Only keys present in the index SET are found on the Redis side.

        Returns:
            Number of entries removed from the fallback plus the number
            removed from Redis (a key present in both is counted twice).
            A Redis failure is logged and contributes 0.
        """
        removed = 0
        # The fallback answers reads while degraded, so it must be
        # invalidated as well or stale entries would keep being served.
        if self._fallback is not None:
            removed += await self._fallback.invalidate(pattern)
        try:
            redis = await self._get_redis()
            cache_keys = await redis.smembers(self.INDEX_KEY)
            if pattern is None:
                targets = list(cache_keys)
            else:
                # Pattern-based invalidation (prefix match)
                prefix = pattern.replace("*", "")
                targets = [k for k in cache_keys if k.startswith(prefix)]
            if not targets:
                return removed

            pipe = redis.pipeline()
            for key in targets:
                pipe.delete(f"{self.KEY_PREFIX}{key}")
                pipe.delete(f"{self.EMB_PREFIX}{key}")
                if pattern is not None:
                    pipe.srem(self.INDEX_KEY, key)
            if pattern is None:
                pipe.delete(self.INDEX_KEY)
            await pipe.execute()
            return removed + len(targets)
        except Exception as e:
            logger.warning(f"Redis cache invalidate failed: {e}")
            return removed

    async def stats(self) -> dict[str, int]:
        """Return index size and this object's hit/miss counters.

        ``total_entries`` is 0 when Redis cannot be queried. Lookups answered
        by the fallback are counted by the fallback, not here.
        """
        try:
            redis = await self._get_redis()
            total_entries = await redis.scard(self.INDEX_KEY)
        except Exception:
            total_entries = 0
        return {
            "total_entries": total_entries,
            "total_hits": self._hits,
            "total_misses": self._misses,
        }
# ---------------------------------------------------------------------------
# Factory
# ---------------------------------------------------------------------------
def create_llm_cache(
    backend: str = "auto",
    redis_url: str = "redis://localhost:6379",
    max_entries: int = 10000,
    exact_ttl: int = 3600,
    semantic_ttl: int = 86400,
    similarity_threshold: float = 0.92,
) -> LLMCache:
    """Build the cache backend selected by ``backend``.

    Args:
        backend: "auto" or "redis" pick Redis when the ``redis`` package can
            be imported and the in-memory cache otherwise; any other value
            (e.g. "memory") picks the in-memory cache.
        redis_url: Redis connection URL (only used for the Redis backend).
        max_entries: Maximum number of cache entries.
        exact_ttl: TTL in seconds for exact-match cache entries.
        semantic_ttl: TTL in seconds for semantic-match embeddings.
        similarity_threshold: Cosine similarity threshold for semantic match.

    Returns:
        An LLMCache instance.
    """
    # Settings common to both backends.
    shared = dict(
        max_entries=max_entries,
        exact_ttl=exact_ttl,
        semantic_ttl=semantic_ttl,
        similarity_threshold=similarity_threshold,
    )

    wants_redis = backend in ("auto", "redis")
    if wants_redis:
        try:
            # Availability probe only; the client is created lazily later.
            import redis.asyncio as aioredis  # noqa: F401
        except ImportError:
            logger.warning("redis package not available, falling back to in-memory cache")
        else:
            return RedisLLMCache(redis_url=redis_url, **shared)

    return InMemoryLLMCache(**shared)
# ---------------------------------------------------------------------------
# U17 — LiteLLM cache manager
# ---------------------------------------------------------------------------
@dataclass
class LitellmCacheConfig:
    """U17 — LiteLLM cache configuration (converted from ``CacheConfig``).

    Differences from the legacy ``CacheConfig``:
    - ``similarity_threshold`` defaults to a fixed 0.87 (as specified by the plan)
    - ``per_user_namespace`` is forced on (security requirement a)
    - ``backend`` gains a ``redis_semantic`` option (requires redisvl)
    """

    enabled: bool = False
    backend: str = "auto"  # "auto" | "redis_semantic" | "redis" | "memory"
    redis_url: str = "redis://localhost:6379"
    similarity_threshold: float = 0.87  # U17 default 0.87 (specified by the plan)
    ttl: int = 86400
    embedding_model: str = "text-embedding-ada-002"
    per_user_namespace: bool = True  # security requirement (a)

    @classmethod
    def from_cache_config(cls, c: "CacheConfig") -> "LitellmCacheConfig":
        """Convert an existing ``CacheConfig``.

        - ``backend`` is kept when it is one of the four values this class
          documents; anything else becomes "auto"
        - ``similarity_threshold`` is fixed at 0.87 (U17 plan; the old 0.92
          is ignored)
        - ``per_user_namespace`` is forced to True (security requirement a)
        - ``ttl`` is taken from ``c.semantic_ttl``
        - ``embedding_model`` falls back to "text-embedding-ada-002" when
          the source value is empty
        """
        # "redis_semantic" is a valid backend of this class and must survive
        # the conversion instead of being coerced to "auto".
        known_backends = ("auto", "redis_semantic", "redis", "memory")
        backend = c.backend if c.backend in known_backends else "auto"
        return cls(
            enabled=c.enabled,
            backend=backend,
            redis_url=c.redis_url,
            similarity_threshold=0.87,  # fixed U17 default; the old 0.92 is ignored
            ttl=c.semantic_ttl,
            embedding_model=c.embedding_model or "text-embedding-ada-002",
            per_user_namespace=True,  # always on
        )
class LitellmCacheManager:
"""U17 — LiteLLM 全局缓存管理器。
职责:
1. 创建并设置 ``litellm.cache`` 全局实例
2. 构建带 user/ACL scope 的 cache key安全要求 a, b
3. 提供 per-call cache 参数cache_key 或 no-cache
4. 检测 LiteLLM 响应的缓存命中标志(用于 usage tracking
5. 统计缓存命中率
后端选择优先级backend="auto" 时):
RedisSemanticCache需 redisvl→ RedisCache精确→ InMemoryCache
安全约束:
- (a) cache key 包含 user_idper-user namespace
- (b) cache key 包含 kb_acl_hashACL-scope 隔离)
- (c) KB 设置 caching_disabled=True 时禁用缓存
- (e) User A 的查询不会命中 User B 的缓存
"""
def __init__(self, config: LitellmCacheConfig):
self._config = config
self._cache_instance: Any = None # litellm.caching.Cache 实例
self._hits = 0
self._misses = 0
def enable(self) -> None:
"""创建 LiteLLM Cache 实例并赋值给 ``litellm.cache``。"""
import litellm
self._cache_instance = self._create_cache_instance()
litellm.cache = self._cache_instance
def disable(self) -> None:
"""禁用缓存 — 设置 ``litellm.cache = None``。"""
import litellm
litellm.cache = None
self._cache_instance = None
def _create_cache_instance(self) -> Any:
"""根据 backend 配置创建 LiteLLM Cache 实例。
auto 模式按优先级尝试RedisSemanticCache → RedisCache → InMemoryCache。
redisvl 缺失时自动回退(安全要求 d — 不添加为必需依赖)。
"""
backend = self._config.backend
if backend in ("auto", "redis_semantic"):
# 尝试 RedisSemanticCache需要 redisvl — lazy import缺失时 fallback
try:
from litellm.caching import RedisSemanticCache
return RedisSemanticCache(
redis_url=self._config.redis_url,
similarity_threshold=self._config.similarity_threshold,
embedding_model=self._config.embedding_model,
)
except ImportError:
logger.warning(
"RedisSemanticCache 需要 redisvl 包(未安装),"
"回退到 RedisCache精确匹配无语义搜索"
"安装 redisvl 以启用语义缓存pip install redisvl"
)
if backend == "redis_semantic":
raise # 显式要求语义缓存但 redisvl 缺失 — 报错
except Exception as e:
logger.warning(f"RedisSemanticCache 初始化失败: {e},回退到 RedisCache")
if backend in ("auto", "redis", "redis_semantic"):
try:
from litellm.caching import RedisCache
return RedisCache(redis_url=self._config.redis_url)
except Exception as e:
logger.warning(f"RedisCache 初始化失败: {e},回退到 InMemoryCache")
from litellm.caching import InMemoryCache
return InMemoryCache()
def build_cache_key(
self,
model: str,
messages: list[dict[str, str]],
temperature: float,
tools: list[dict] | None = None,
tool_choice: str = "auto",
max_tokens: int = 2000,
user_id: str | None = None,
kb_acl_hash: str | None = None,
) -> str:
"""构建带 user/ACL scope 的 cache key安全要求 a, b, e
委托给 ``cache_key.generate_cache_key``,额外注入 user_id + kb_acl_hash
作为命名空间隔离,确保 User A 的查询不会命中 User B 的缓存。
"""
from agentkit.llm.cache_key import generate_cache_key
return generate_cache_key(
model=model,
messages=messages,
temperature=temperature,
tools=tools,
tool_choice=tool_choice,
max_tokens=max_tokens,
user_id=user_id,
kb_acl_hash=kb_acl_hash,
)
def should_cache(
self,
kb_caching_disabled: bool = False,
user_id: str | None = None,
) -> bool:
"""判断当前请求是否应该缓存(安全要求 c
- KB 设置 caching_disabled=True → 不缓存
- 其余情况缓存user_id 为 None 时仍可缓存,但 key 不含 user scope
"""
_ = user_id # 预留:未来支持 per-user 缓存禁用
if kb_caching_disabled:
return False
return True
@staticmethod
def cache_params_for_hit(cache_key: str) -> dict[str, str]:
"""返回 litellm acompletion 的 cache 参数(用于期望命中的调用)。"""
return {"cache_key": cache_key}
@staticmethod
def cache_params_for_no_cache() -> dict[str, bool]:
"""返回 litellm acompletion 的 cache 参数(禁用缓存)。"""
return {"no-cache": True}
def detect_cache_hit(self, response: Any) -> bool:
"""检测 LiteLLM 响应是否为缓存命中。
LiteLLM 在缓存命中时设置 ``response._hidden_params["cache_key"]``。
"""
hidden = getattr(response, "_hidden_params", None)
if isinstance(hidden, dict):
if "cache_key" in hidden or hidden.get("cache_hit"):
self._hits += 1
return True
self._misses += 1
return False
def stats(self) -> dict[str, int]:
"""返回缓存统计。"""
return {
"total_hits": self._hits,
"total_misses": self._misses,
}