feat(llm): U17 — LiteLLM 语义缓存替换 + per-user/ACL scope 安全隔离
- 新增 LitellmCacheManager:配置 litellm.cache 全局,三级后端 fallback (RedisSemanticCache -> RedisCache -> InMemoryCache),redisvl lazy import
- cache_key 扩展 user_id + kb_acl_hash 参数(安全要求 a/b/e)
- gateway 集成:读取 KB caching_disabled flag(安全要求 c),构建带 scope 的 cache_key,命中时 cost=0
- LLMResponse 新增 cache_hit 字段;LLMRequest 新增 cache 参数
- litellm_provider 透传 cache 参数 + 检测 _hidden_params 缓存命中
- 33 个新测试覆盖 13 场景(含 User A != User B 缓存隔离)
- 旧 InMemoryLLMCache/RedisLLMCache 保留向后兼容
This commit is contained in:
parent
86541d7172
commit
793476cafa
|
|
@ -14,11 +14,14 @@ import logging
|
||||||
import time
|
import time
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Protocol, runtime_checkable
|
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||||
|
|
||||||
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
|
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
|
||||||
from agentkit.utils.vector_math import compute_cosine_similarity
|
from agentkit.utils.vector_math import compute_cosine_similarity
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from agentkit.llm.config import CacheConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -61,8 +64,7 @@ def _serialize_response(response: LLMResponse) -> dict:
|
||||||
"completion_tokens": response.usage.completion_tokens,
|
"completion_tokens": response.usage.completion_tokens,
|
||||||
},
|
},
|
||||||
"tool_calls": [
|
"tool_calls": [
|
||||||
{"id": tc.id, "name": tc.name, "arguments": tc.arguments}
|
{"id": tc.id, "name": tc.name, "arguments": tc.arguments} for tc in response.tool_calls
|
||||||
for tc in response.tool_calls
|
|
||||||
],
|
],
|
||||||
"latency_ms": response.latency_ms,
|
"latency_ms": response.latency_ms,
|
||||||
}
|
}
|
||||||
|
|
@ -236,7 +238,9 @@ class InMemoryLLMCache:
|
||||||
self._cache.move_to_end(key)
|
self._cache.move_to_end(key)
|
||||||
existing = self._cache[key]
|
existing = self._cache[key]
|
||||||
# Preserve existing embedding if new one is None
|
# Preserve existing embedding if new one is None
|
||||||
effective_embedding = query_embedding if query_embedding is not None else existing.query_embedding
|
effective_embedding = (
|
||||||
|
query_embedding if query_embedding is not None else existing.query_embedding
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
effective_embedding = query_embedding or []
|
effective_embedding = query_embedding or []
|
||||||
|
|
||||||
|
|
@ -276,9 +280,7 @@ class InMemoryLLMCache:
|
||||||
return count
|
return count
|
||||||
|
|
||||||
# Simple prefix matching for pattern
|
# Simple prefix matching for pattern
|
||||||
keys_to_remove = [
|
keys_to_remove = [k for k in self._cache if k.startswith(pattern.replace("*", ""))]
|
||||||
k for k in self._cache if k.startswith(pattern.replace("*", ""))
|
|
||||||
]
|
|
||||||
for key in keys_to_remove:
|
for key in keys_to_remove:
|
||||||
del self._cache[key]
|
del self._cache[key]
|
||||||
self._embeddings.pop(key, None)
|
self._embeddings.pop(key, None)
|
||||||
|
|
@ -338,9 +340,7 @@ class RedisLLMCache:
|
||||||
if self._redis is None:
|
if self._redis is None:
|
||||||
import redis.asyncio as aioredis
|
import redis.asyncio as aioredis
|
||||||
|
|
||||||
self._redis = aioredis.from_url(
|
self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
|
||||||
self._redis_url, decode_responses=True
|
|
||||||
)
|
|
||||||
return self._redis
|
return self._redis
|
||||||
|
|
||||||
async def aclose(self) -> None:
|
async def aclose(self) -> None:
|
||||||
|
|
@ -393,9 +393,7 @@ class RedisLLMCache:
|
||||||
if data is not None:
|
if data is not None:
|
||||||
entry = _deserialize_entry(json.loads(data))
|
entry = _deserialize_entry(json.loads(data))
|
||||||
self._hits += 1
|
self._hits += 1
|
||||||
return CacheResult(
|
return CacheResult(hit=True, response=entry.response, match_type="exact")
|
||||||
hit=True, response=entry.response, match_type="exact"
|
|
||||||
)
|
|
||||||
self._misses += 1
|
self._misses += 1
|
||||||
return CacheResult(hit=False)
|
return CacheResult(hit=False)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -424,6 +422,7 @@ class RedisLLMCache:
|
||||||
if len(cache_keys_list) > max_scan:
|
if len(cache_keys_list) > max_scan:
|
||||||
# Take a random sample to avoid always scanning the same subset
|
# Take a random sample to avoid always scanning the same subset
|
||||||
import random
|
import random
|
||||||
|
|
||||||
cache_keys_list = random.sample(cache_keys_list, max_scan)
|
cache_keys_list = random.sample(cache_keys_list, max_scan)
|
||||||
|
|
||||||
# Batch fetch embeddings
|
# Batch fetch embeddings
|
||||||
|
|
@ -460,9 +459,7 @@ class RedisLLMCache:
|
||||||
if data is not None:
|
if data is not None:
|
||||||
entry = _deserialize_entry(json.loads(data))
|
entry = _deserialize_entry(json.loads(data))
|
||||||
self._hits += 1
|
self._hits += 1
|
||||||
return CacheResult(
|
return CacheResult(hit=True, response=entry.response, match_type="semantic")
|
||||||
hit=True, response=entry.response, match_type="semantic"
|
|
||||||
)
|
|
||||||
# Data key expired but embedding still exists — mark for cleanup
|
# Data key expired but embedding still exists — mark for cleanup
|
||||||
try:
|
try:
|
||||||
await redis.srem(self.INDEX_KEY, best_key)
|
await redis.srem(self.INDEX_KEY, best_key)
|
||||||
|
|
@ -615,9 +612,7 @@ def create_llm_cache(
|
||||||
similarity_threshold=similarity_threshold,
|
similarity_threshold=similarity_threshold,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.warning(
|
logger.warning("redis package not available, falling back to in-memory cache")
|
||||||
"redis package not available, falling back to in-memory cache"
|
|
||||||
)
|
|
||||||
return InMemoryLLMCache(
|
return InMemoryLLMCache(
|
||||||
max_entries=max_entries,
|
max_entries=max_entries,
|
||||||
exact_ttl=exact_ttl,
|
exact_ttl=exact_ttl,
|
||||||
|
|
@ -630,3 +625,200 @@ def create_llm_cache(
|
||||||
semantic_ttl=semantic_ttl,
|
semantic_ttl=semantic_ttl,
|
||||||
similarity_threshold=similarity_threshold,
|
similarity_threshold=similarity_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# U17 — LiteLLM 缓存管理器
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LitellmCacheConfig:
    """U17 — settings for the LiteLLM-backed cache, derived from ``CacheConfig``.

    Differences from the legacy ``CacheConfig``:

    - ``similarity_threshold`` defaults to 0.87 (value mandated by the U17 plan).
    - ``per_user_namespace`` is always enabled (security requirement a).
    - ``backend`` additionally accepts ``"redis_semantic"`` (requires redisvl).
    """

    enabled: bool = False
    backend: str = "auto"  # "auto" | "redis_semantic" | "redis" | "memory"
    redis_url: str = "redis://localhost:6379"
    similarity_threshold: float = 0.87  # U17 default mandated by the plan
    ttl: int = 86400
    embedding_model: str = "text-embedding-ada-002"
    per_user_namespace: bool = True  # security requirement (a)

    @classmethod
    def from_cache_config(cls, c: "CacheConfig") -> "LitellmCacheConfig":
        """Build a ``LitellmCacheConfig`` from an existing ``CacheConfig``.

        Args:
            c: Source configuration. Only ``enabled``, ``backend``,
                ``redis_url``, ``semantic_ttl`` and ``embedding_model`` are read.

        Returns:
            A new config where:

            - ``backend`` is kept when it is one of the four documented values
              (``"auto"``, ``"redis_semantic"``, ``"redis"``, ``"memory"``);
              any other value is replaced by ``"auto"``.
            - ``similarity_threshold`` is pinned to 0.87; the source value is
              deliberately ignored (U17 plan).
            - ``ttl`` is taken from ``c.semantic_ttl``.
            - ``embedding_model`` falls back to ``"text-embedding-ada-002"``
              when the source value is empty.
            - ``per_user_namespace`` is forced to ``True`` (security
              requirement a).
        """
        # "redis_semantic" must survive the conversion: an explicit request for
        # the semantic backend is documented on this class, and rewriting it to
        # "auto" would turn a strict request into a silent best-effort one.
        known_backends = ("auto", "redis_semantic", "redis", "memory")
        return cls(
            enabled=c.enabled,
            backend=c.backend if c.backend in known_backends else "auto",
            redis_url=c.redis_url,
            similarity_threshold=0.87,  # fixed by U17; legacy 0.92 is ignored
            ttl=c.semantic_ttl,
            embedding_model=c.embedding_model or "text-embedding-ada-002",
            per_user_namespace=True,  # always on
        )
|
||||||
|
|
||||||
|
|
||||||
|
class LitellmCacheManager:
|
||||||
|
"""U17 — LiteLLM 全局缓存管理器。
|
||||||
|
|
||||||
|
职责:
|
||||||
|
1. 创建并设置 ``litellm.cache`` 全局实例
|
||||||
|
2. 构建带 user/ACL scope 的 cache key(安全要求 a, b)
|
||||||
|
3. 提供 per-call cache 参数(cache_key 或 no-cache)
|
||||||
|
4. 检测 LiteLLM 响应的缓存命中标志(用于 usage tracking)
|
||||||
|
5. 统计缓存命中率
|
||||||
|
|
||||||
|
后端选择优先级(backend="auto" 时):
|
||||||
|
RedisSemanticCache(需 redisvl)→ RedisCache(精确)→ InMemoryCache
|
||||||
|
|
||||||
|
安全约束:
|
||||||
|
- (a) cache key 包含 user_id(per-user namespace)
|
||||||
|
- (b) cache key 包含 kb_acl_hash(ACL-scope 隔离)
|
||||||
|
- (c) KB 设置 caching_disabled=True 时禁用缓存
|
||||||
|
- (e) User A 的查询不会命中 User B 的缓存
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: LitellmCacheConfig):
|
||||||
|
self._config = config
|
||||||
|
self._cache_instance: Any = None # litellm.caching.Cache 实例
|
||||||
|
self._hits = 0
|
||||||
|
self._misses = 0
|
||||||
|
|
||||||
|
def enable(self) -> None:
|
||||||
|
"""创建 LiteLLM Cache 实例并赋值给 ``litellm.cache``。"""
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
self._cache_instance = self._create_cache_instance()
|
||||||
|
litellm.cache = self._cache_instance
|
||||||
|
|
||||||
|
def disable(self) -> None:
|
||||||
|
"""禁用缓存 — 设置 ``litellm.cache = None``。"""
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
litellm.cache = None
|
||||||
|
self._cache_instance = None
|
||||||
|
|
||||||
|
def _create_cache_instance(self) -> Any:
|
||||||
|
"""根据 backend 配置创建 LiteLLM Cache 实例。
|
||||||
|
|
||||||
|
auto 模式按优先级尝试:RedisSemanticCache → RedisCache → InMemoryCache。
|
||||||
|
redisvl 缺失时自动回退(安全要求 d — 不添加为必需依赖)。
|
||||||
|
"""
|
||||||
|
backend = self._config.backend
|
||||||
|
if backend in ("auto", "redis_semantic"):
|
||||||
|
# 尝试 RedisSemanticCache(需要 redisvl — lazy import,缺失时 fallback)
|
||||||
|
try:
|
||||||
|
from litellm.caching import RedisSemanticCache
|
||||||
|
|
||||||
|
return RedisSemanticCache(
|
||||||
|
redis_url=self._config.redis_url,
|
||||||
|
similarity_threshold=self._config.similarity_threshold,
|
||||||
|
embedding_model=self._config.embedding_model,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
logger.warning(
|
||||||
|
"RedisSemanticCache 需要 redisvl 包(未安装),"
|
||||||
|
"回退到 RedisCache(精确匹配,无语义搜索)。"
|
||||||
|
"安装 redisvl 以启用语义缓存:pip install redisvl"
|
||||||
|
)
|
||||||
|
if backend == "redis_semantic":
|
||||||
|
raise # 显式要求语义缓存但 redisvl 缺失 — 报错
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"RedisSemanticCache 初始化失败: {e},回退到 RedisCache")
|
||||||
|
|
||||||
|
if backend in ("auto", "redis", "redis_semantic"):
|
||||||
|
try:
|
||||||
|
from litellm.caching import RedisCache
|
||||||
|
|
||||||
|
return RedisCache(redis_url=self._config.redis_url)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"RedisCache 初始化失败: {e},回退到 InMemoryCache")
|
||||||
|
|
||||||
|
from litellm.caching import InMemoryCache
|
||||||
|
|
||||||
|
return InMemoryCache()
|
||||||
|
|
||||||
|
def build_cache_key(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
temperature: float,
|
||||||
|
tools: list[dict] | None = None,
|
||||||
|
tool_choice: str = "auto",
|
||||||
|
max_tokens: int = 2000,
|
||||||
|
user_id: str | None = None,
|
||||||
|
kb_acl_hash: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""构建带 user/ACL scope 的 cache key(安全要求 a, b, e)。
|
||||||
|
|
||||||
|
委托给 ``cache_key.generate_cache_key``,额外注入 user_id + kb_acl_hash
|
||||||
|
作为命名空间隔离,确保 User A 的查询不会命中 User B 的缓存。
|
||||||
|
"""
|
||||||
|
from agentkit.llm.cache_key import generate_cache_key
|
||||||
|
|
||||||
|
return generate_cache_key(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=temperature,
|
||||||
|
tools=tools,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
user_id=user_id,
|
||||||
|
kb_acl_hash=kb_acl_hash,
|
||||||
|
)
|
||||||
|
|
||||||
|
def should_cache(
|
||||||
|
self,
|
||||||
|
kb_caching_disabled: bool = False,
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""判断当前请求是否应该缓存(安全要求 c)。
|
||||||
|
|
||||||
|
- KB 设置 caching_disabled=True → 不缓存
|
||||||
|
- 其余情况缓存(user_id 为 None 时仍可缓存,但 key 不含 user scope)
|
||||||
|
"""
|
||||||
|
_ = user_id # 预留:未来支持 per-user 缓存禁用
|
||||||
|
if kb_caching_disabled:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cache_params_for_hit(cache_key: str) -> dict[str, str]:
|
||||||
|
"""返回 litellm acompletion 的 cache 参数(用于期望命中的调用)。"""
|
||||||
|
return {"cache_key": cache_key}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cache_params_for_no_cache() -> dict[str, bool]:
|
||||||
|
"""返回 litellm acompletion 的 cache 参数(禁用缓存)。"""
|
||||||
|
return {"no-cache": True}
|
||||||
|
|
||||||
|
def detect_cache_hit(self, response: Any) -> bool:
|
||||||
|
"""检测 LiteLLM 响应是否为缓存命中。
|
||||||
|
|
||||||
|
LiteLLM 在缓存命中时设置 ``response._hidden_params["cache_key"]``。
|
||||||
|
"""
|
||||||
|
hidden = getattr(response, "_hidden_params", None)
|
||||||
|
if isinstance(hidden, dict):
|
||||||
|
if "cache_key" in hidden or hidden.get("cache_hit"):
|
||||||
|
self._hits += 1
|
||||||
|
return True
|
||||||
|
self._misses += 1
|
||||||
|
return False
|
||||||
|
|
||||||
|
def stats(self) -> dict[str, int]:
|
||||||
|
"""返回缓存统计。"""
|
||||||
|
return {
|
||||||
|
"total_hits": self._hits,
|
||||||
|
"total_misses": self._misses,
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,8 @@ def generate_cache_key(
|
||||||
tools: list[dict[str, Any]] | None = None,
|
tools: list[dict[str, Any]] | None = None,
|
||||||
tool_choice: str = "auto",
|
tool_choice: str = "auto",
|
||||||
max_tokens: int = 2000,
|
max_tokens: int = 2000,
|
||||||
|
user_id: str | None = None,
|
||||||
|
kb_acl_hash: str | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate a deterministic SHA-256 cache key from LLM request parameters.
|
"""Generate a deterministic SHA-256 cache key from LLM request parameters.
|
||||||
|
|
||||||
|
|
@ -19,6 +21,10 @@ def generate_cache_key(
|
||||||
model, system_prompt (extracted from messages), messages content,
|
model, system_prompt (extracted from messages), messages content,
|
||||||
temperature, tools, tool_choice, and max_tokens.
|
temperature, tools, tool_choice, and max_tokens.
|
||||||
|
|
||||||
|
U17 安全扩展:``user_id`` 和 ``kb_acl_hash`` 非 None 时加入 hash 组件,
|
||||||
|
实现 per-user namespace 和 ACL-scope 隔离(安全要求 a, b)。为 None 时
|
||||||
|
行为与旧版完全一致(向后兼容)。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model: Model identifier (e.g. "openai/gpt-4o").
|
model: Model identifier (e.g. "openai/gpt-4o").
|
||||||
messages: Chat messages list (may include system prompt as first message).
|
messages: Chat messages list (may include system prompt as first message).
|
||||||
|
|
@ -26,6 +32,8 @@ def generate_cache_key(
|
||||||
tools: Optional list of tool definitions.
|
tools: Optional list of tool definitions.
|
||||||
tool_choice: Tool selection mode ("auto", "none", etc.).
|
tool_choice: Tool selection mode ("auto", "none", etc.).
|
||||||
max_tokens: Maximum response tokens.
|
max_tokens: Maximum response tokens.
|
||||||
|
user_id: U17 — 用户 ID,用于 per-user cache namespace 隔离。
|
||||||
|
kb_acl_hash: U17 — KB ACL-scope hash,用于 ACL 隔离缓存键。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
64-character hex SHA-256 hash string.
|
64-character hex SHA-256 hash string.
|
||||||
|
|
@ -40,6 +48,11 @@ def generate_cache_key(
|
||||||
_hash_str(tool_choice),
|
_hash_str(tool_choice),
|
||||||
_hash_str(str(max_tokens)),
|
_hash_str(str(max_tokens)),
|
||||||
]
|
]
|
||||||
|
# U17 — per-user namespace + ACL scope hash(安全要求 a, b, e)
|
||||||
|
if user_id is not None:
|
||||||
|
components.append(_hash_str(f"user:{user_id}"))
|
||||||
|
if kb_acl_hash is not None:
|
||||||
|
components.append(_hash_str(f"acl:{kb_acl_hash}"))
|
||||||
combined = "".join(components)
|
combined = "".join(components)
|
||||||
return hashlib.sha256(combined.encode()).hexdigest()
|
return hashlib.sha256(combined.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
@ -61,6 +74,4 @@ def _hash_json(obj: Any) -> str:
|
||||||
"""SHA-256 hash of a JSON-serializable object."""
|
"""SHA-256 hash of a JSON-serializable object."""
|
||||||
if obj is None:
|
if obj is None:
|
||||||
return hashlib.sha256(b"null").hexdigest()
|
return hashlib.sha256(b"null").hexdigest()
|
||||||
return hashlib.sha256(
|
return hashlib.sha256(json.dumps(obj, sort_keys=True, ensure_ascii=False).encode()).hexdigest()
|
||||||
json.dumps(obj, sort_keys=True, ensure_ascii=False).encode()
|
|
||||||
).hexdigest()
|
|
||||||
|
|
|
||||||
|
|
@ -51,47 +51,19 @@ class LLMGateway:
|
||||||
self._usage_tracker = UsageTracker(store=usage_store) if usage_store else UsageTracker()
|
self._usage_tracker = UsageTracker(store=usage_store) if usage_store else UsageTracker()
|
||||||
self._config = config or LLMConfig()
|
self._config = config or LLMConfig()
|
||||||
|
|
||||||
# Cache (opt-in, disabled by default)
|
# Cache (U17 — LiteLLM 缓存管理器,opt-in,默认禁用)
|
||||||
self._cache: Any = None # LLMCache | None
|
self._cache_manager: Any = None # LitellmCacheManager | None
|
||||||
self._embedder: Any = None # Embedder | None
|
|
||||||
if self._config.cache and self._config.cache.enabled:
|
if self._config.cache and self._config.cache.enabled:
|
||||||
from agentkit.llm.cache import create_llm_cache
|
from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager
|
||||||
|
|
||||||
self._cache = create_llm_cache(
|
litellm_config = LitellmCacheConfig.from_cache_config(self._config.cache)
|
||||||
backend=self._config.cache.backend,
|
self._cache_manager = LitellmCacheManager(litellm_config)
|
||||||
redis_url=self._config.cache.redis_url,
|
self._cache_manager.enable()
|
||||||
max_entries=self._config.cache.max_entries,
|
|
||||||
exact_ttl=self._config.cache.exact_ttl,
|
|
||||||
semantic_ttl=self._config.cache.semantic_ttl,
|
|
||||||
similarity_threshold=self._config.cache.similarity_threshold,
|
|
||||||
)
|
|
||||||
self._embedder = self._create_embedder(self._config.cache)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"LLM cache enabled (backend={self._config.cache.backend}, "
|
f"LLM cache enabled (LiteLLM, backend={self._config.cache.backend}, "
|
||||||
f"embedder={self._config.cache.embedding_provider}/{self._config.cache.embedding_model})"
|
f"similarity_threshold={litellm_config.similarity_threshold})"
|
||||||
)
|
)
|
||||||
|
|
||||||
def _create_embedder(self, cache_config) -> Any:
|
|
||||||
"""Create embedder for semantic cache based on config."""
|
|
||||||
try:
|
|
||||||
from agentkit.memory.embedder import OpenAIEmbedder
|
|
||||||
|
|
||||||
if cache_config.embedding_provider in ("xinference", "local"):
|
|
||||||
return OpenAIEmbedder(
|
|
||||||
api_key=cache_config.embedding_api_key or "not-needed",
|
|
||||||
model=cache_config.embedding_model,
|
|
||||||
base_url=cache_config.embedding_base_url or "http://localhost:9997/v1",
|
|
||||||
)
|
|
||||||
# Default: OpenAI
|
|
||||||
return OpenAIEmbedder(
|
|
||||||
api_key=cache_config.embedding_api_key,
|
|
||||||
model=cache_config.embedding_model,
|
|
||||||
base_url=cache_config.embedding_base_url,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to create embedder for semantic cache: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def register_provider(self, name: str, provider: LLMProvider) -> None:
|
def register_provider(self, name: str, provider: LLMProvider) -> None:
|
||||||
"""注册 Provider"""
|
"""注册 Provider"""
|
||||||
self._providers[name] = provider
|
self._providers[name] = provider
|
||||||
|
|
@ -114,6 +86,8 @@ class LLMGateway:
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
department_ids: list[str] | None = None,
|
department_ids: list[str] | None = None,
|
||||||
db_path: Path | str | None = None,
|
db_path: Path | str | None = None,
|
||||||
|
kb_id: str | None = None,
|
||||||
|
kb_acl_hash: str | None = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""发送 chat 请求,自动解析别名和 Fallback"""
|
"""发送 chat 请求,自动解析别名和 Fallback"""
|
||||||
|
|
@ -151,67 +125,40 @@ class LLMGateway:
|
||||||
|
|
||||||
start = time.monotonic()
|
start = time.monotonic()
|
||||||
|
|
||||||
# ── Cache check ──
|
# ── Cache check (U17 — LiteLLM cache via cache_key in request) ──
|
||||||
cache_key = None
|
# LiteLLM 在 litellm.acompletion 内部处理缓存读写,gateway 只需:
|
||||||
query_embedding = None
|
# 1. 构建 per-user + ACL-scoped cache_key(安全要求 a, b)
|
||||||
if self._cache is not None:
|
# 2. 将 cache 参数注入 kwargs 透传到 provider
|
||||||
from agentkit.llm.cache_key import generate_cache_key
|
# 3. 检测响应的 cache_hit 标志,用于 usage tracking(cost=0)
|
||||||
|
if self._cache_manager is not None:
|
||||||
|
from agentkit.llm.cache import LitellmCacheManager
|
||||||
|
|
||||||
cache_key = generate_cache_key(
|
# 解析 KB caching_disabled(安全要求 c)
|
||||||
model=resolved_model,
|
kb_caching_disabled = False
|
||||||
messages=messages,
|
if kb_id is not None:
|
||||||
temperature=kwargs.get("temperature", 0.7),
|
|
||||||
tools=tools,
|
|
||||||
tool_choice=tool_choice,
|
|
||||||
max_tokens=kwargs.get("max_tokens", 2000),
|
|
||||||
)
|
|
||||||
result = await self._cache.get(cache_key)
|
|
||||||
if result.hit:
|
|
||||||
latency_ms = (time.monotonic() - start) * 1000
|
|
||||||
await self._record_usage(
|
|
||||||
agent_name=agent_name,
|
|
||||||
model=result.response.model,
|
|
||||||
usage=result.response.usage,
|
|
||||||
cost=0.0,
|
|
||||||
latency_ms=latency_ms,
|
|
||||||
user_id=user_id,
|
|
||||||
department_ids=department_ids,
|
|
||||||
)
|
|
||||||
if _span is not None:
|
|
||||||
_span.set_attribute("gen_ai.cache.hit", True)
|
|
||||||
_span.set_attribute("gen_ai.cache.match_type", result.match_type)
|
|
||||||
return result.response
|
|
||||||
|
|
||||||
# Semantic match (only for temperature == 0)
|
|
||||||
temperature = kwargs.get("temperature", 0.7)
|
|
||||||
if temperature == 0 and self._embedder is not None:
|
|
||||||
try:
|
try:
|
||||||
# Embed last N messages for context-aware semantic matching
|
from agentkit.rag_platform.settings import get_settings_store
|
||||||
# (not just last user message — avoids cross-context false hits)
|
|
||||||
recent_messages = messages[-3:] if len(messages) > 3 else messages
|
settings = await get_settings_store().get_settings(kb_id)
|
||||||
embed_text = " | ".join(
|
if settings is not None:
|
||||||
m.get("content", "") for m in recent_messages if m.get("content")
|
kb_caching_disabled = settings.caching_disabled
|
||||||
)
|
|
||||||
if embed_text:
|
|
||||||
query_embedding = await self._embedder.embed(embed_text)
|
|
||||||
result = await self._cache.semantic_search(query_embedding)
|
|
||||||
if result.hit:
|
|
||||||
latency_ms = (time.monotonic() - start) * 1000
|
|
||||||
await self._record_usage(
|
|
||||||
agent_name=agent_name,
|
|
||||||
model=result.response.model,
|
|
||||||
usage=result.response.usage,
|
|
||||||
cost=0.0,
|
|
||||||
latency_ms=latency_ms,
|
|
||||||
user_id=user_id,
|
|
||||||
department_ids=department_ids,
|
|
||||||
)
|
|
||||||
if _span is not None:
|
|
||||||
_span.set_attribute("gen_ai.cache.hit", True)
|
|
||||||
_span.set_attribute("gen_ai.cache.match_type", "semantic")
|
|
||||||
return result.response
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Semantic cache search failed: {e}")
|
logger.warning(f"Failed to read KB cache settings for kb_id={kb_id}: {e}")
|
||||||
|
|
||||||
|
if self._cache_manager.should_cache(kb_caching_disabled, user_id):
|
||||||
|
cache_key = self._cache_manager.build_cache_key(
|
||||||
|
model=resolved_model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=kwargs.get("temperature", 0.7),
|
||||||
|
tools=tools,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
max_tokens=kwargs.get("max_tokens", 2000),
|
||||||
|
user_id=user_id,
|
||||||
|
kb_acl_hash=kb_acl_hash,
|
||||||
|
)
|
||||||
|
kwargs["cache"] = LitellmCacheManager.cache_params_for_hit(cache_key)
|
||||||
|
else:
|
||||||
|
kwargs["cache"] = LitellmCacheManager.cache_params_for_no_cache()
|
||||||
|
|
||||||
# ── Normal provider call ──
|
# ── Normal provider call ──
|
||||||
models_to_try = self._get_models_to_try(resolved_model)
|
models_to_try = self._get_models_to_try(resolved_model)
|
||||||
|
|
@ -275,15 +222,15 @@ class LLMGateway:
|
||||||
|
|
||||||
latency_ms = (time.monotonic() - start) * 1000
|
latency_ms = (time.monotonic() - start) * 1000
|
||||||
|
|
||||||
# ── Cache write ──
|
# U17 — 检测 LiteLLM 缓存命中(用于 usage tracking cost=0)
|
||||||
if self._cache is not None and cache_key is not None:
|
is_cache_hit = getattr(response, "cache_hit", False)
|
||||||
try:
|
if is_cache_hit and self._cache_manager is not None:
|
||||||
await self._cache.put(cache_key, response, query_embedding)
|
self._cache_manager._hits += 1
|
||||||
except Exception as e:
|
elif self._cache_manager is not None:
|
||||||
logger.warning(f"Cache write failed: {e}")
|
self._cache_manager._misses += 1
|
||||||
|
|
||||||
# 计算成本
|
# 计算成本(缓存命中时 cost=0)
|
||||||
cost = self._calculate_cost(response.model, response.usage)
|
cost = 0.0 if is_cache_hit else self._calculate_cost(response.model, response.usage)
|
||||||
|
|
||||||
# 记录使用量
|
# 记录使用量
|
||||||
await self._record_usage(
|
await self._record_usage(
|
||||||
|
|
@ -302,8 +249,8 @@ class LLMGateway:
|
||||||
_span.set_attribute("gen_ai.usage.output_tokens", response.usage.completion_tokens)
|
_span.set_attribute("gen_ai.usage.output_tokens", response.usage.completion_tokens)
|
||||||
_span.set_attribute("gen_ai.response.model", response.model)
|
_span.set_attribute("gen_ai.response.model", response.model)
|
||||||
_span.set_attribute("gen_ai.duration.ms", int(latency_ms))
|
_span.set_attribute("gen_ai.duration.ms", int(latency_ms))
|
||||||
if self._cache is not None:
|
if self._cache_manager is not None:
|
||||||
_span.set_attribute("gen_ai.cache.hit", False)
|
_span.set_attribute("gen_ai.cache.hit", is_cache_hit)
|
||||||
llm_token_histogram().record(
|
llm_token_histogram().record(
|
||||||
response.usage.total_tokens,
|
response.usage.total_tokens,
|
||||||
{"gen_ai.request.model": resolved_model},
|
{"gen_ai.request.model": resolved_model},
|
||||||
|
|
@ -607,18 +554,10 @@ class LLMGateway:
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. Token + cost limits (daily AND monthly)
|
# 2. Token + cost limits (daily AND monthly)
|
||||||
await self._check_quota_period(
|
await self._check_quota_period(quota_service, db, dept_id, "daily", "token_limit")
|
||||||
quota_service, db, dept_id, "daily", "token_limit"
|
await self._check_quota_period(quota_service, db, dept_id, "daily", "cost_limit")
|
||||||
)
|
await self._check_quota_period(quota_service, db, dept_id, "monthly", "token_limit")
|
||||||
await self._check_quota_period(
|
await self._check_quota_period(quota_service, db, dept_id, "monthly", "cost_limit")
|
||||||
quota_service, db, dept_id, "daily", "cost_limit"
|
|
||||||
)
|
|
||||||
await self._check_quota_period(
|
|
||||||
quota_service, db, dept_id, "monthly", "token_limit"
|
|
||||||
)
|
|
||||||
await self._check_quota_period(
|
|
||||||
quota_service, db, dept_id, "monthly", "cost_limit"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _check_quota_period(
|
async def _check_quota_period(
|
||||||
self,
|
self,
|
||||||
|
|
@ -639,9 +578,7 @@ class LLMGateway:
|
||||||
else:
|
else:
|
||||||
current = await self._get_current_cost_for_quota(dept_id, period)
|
current = await self._get_current_cost_for_quota(dept_id, period)
|
||||||
|
|
||||||
allowed, _reason = await quota_service.check_quota(
|
allowed, _reason = await quota_service.check_quota(db, dept_id, quota_type, period, current)
|
||||||
db, dept_id, quota_type, period, current
|
|
||||||
)
|
|
||||||
if not allowed:
|
if not allowed:
|
||||||
quota = await quota_service.get_quota(db, dept_id, quota_type, period)
|
quota = await quota_service.get_quota(db, dept_id, quota_type, period)
|
||||||
limit = quota["limit_value"] if quota else None
|
limit = quota["limit_value"] if quota else None
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,7 @@ class LLMRequest:
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_tokens: int = 2000,
|
max_tokens: int = 2000,
|
||||||
timeout: float | None = None,
|
timeout: float | None = None,
|
||||||
|
cache: dict[str, Any] | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
|
|
@ -57,6 +58,8 @@ class LLMRequest:
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self._extra = kwargs
|
self._extra = kwargs
|
||||||
|
# U17 — LiteLLM cache 参数(cache_key 或 no-cache),透传到 litellm.acompletion
|
||||||
|
self._cache: dict[str, Any] | None = cache
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -81,6 +84,7 @@ class LLMResponse:
|
||||||
usage: TokenUsage
|
usage: TokenUsage
|
||||||
tool_calls: list[ToolCall] = field(default_factory=list)
|
tool_calls: list[ToolCall] = field(default_factory=list)
|
||||||
latency_ms: float = 0.0
|
latency_ms: float = 0.0
|
||||||
|
cache_hit: bool = False # U17 — 缓存命中标记,用于 usage tracking(cost=0)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def has_tool_calls(self) -> bool:
|
def has_tool_calls(self) -> bool:
|
||||||
|
|
|
||||||
|
|
@ -183,6 +183,10 @@ class LitellmProvider(LLMProvider):
|
||||||
kwargs["tool_choice"] = request.tool_choice
|
kwargs["tool_choice"] = request.tool_choice
|
||||||
if request.timeout is not None:
|
if request.timeout is not None:
|
||||||
kwargs["timeout"] = request.timeout
|
kwargs["timeout"] = request.timeout
|
||||||
|
# U17 — 透传 LiteLLM cache 参数(cache_key 或 no-cache)到 litellm.acompletion
|
||||||
|
cache_params = getattr(request, "_cache", None)
|
||||||
|
if cache_params is not None:
|
||||||
|
kwargs["cache"] = cache_params
|
||||||
# 合并构造时传入的默认 kwargs(如 max_connections 等provider特定参数)
|
# 合并构造时传入的默认 kwargs(如 max_connections 等provider特定参数)
|
||||||
kwargs.update(self._default_kwargs)
|
kwargs.update(self._default_kwargs)
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
@ -214,12 +218,19 @@ class LitellmProvider(LLMProvider):
|
||||||
|
|
||||||
model_name = getattr(response, "model", None) or request_model
|
model_name = getattr(response, "model", None) or request_model
|
||||||
|
|
||||||
|
# U17 — 检测 LiteLLM 缓存命中(_hidden_params 含 cache_key 或 cache_hit)
|
||||||
|
cache_hit = False
|
||||||
|
hidden = getattr(response, "_hidden_params", None)
|
||||||
|
if isinstance(hidden, dict) and ("cache_key" in hidden or hidden.get("cache_hit")):
|
||||||
|
cache_hit = True
|
||||||
|
|
||||||
return LLMResponse(
|
return LLMResponse(
|
||||||
content=content,
|
content=content,
|
||||||
model=model_name,
|
model=model_name,
|
||||||
usage=usage,
|
usage=usage,
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
latency_ms=latency_ms,
|
latency_ms=latency_ms,
|
||||||
|
cache_hit=cache_hit,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _parse_stream_chunk(
|
def _parse_stream_chunk(
|
||||||
|
|
|
||||||
|
|
@ -287,13 +287,13 @@ class TestCacheIntegration:
|
||||||
cache=CacheConfig(enabled=True, backend="memory"),
|
cache=CacheConfig(enabled=True, backend="memory"),
|
||||||
)
|
)
|
||||||
gateway = LLMGateway(config=config)
|
gateway = LLMGateway(config=config)
|
||||||
assert gateway._cache is not None
|
assert gateway._cache_manager is not None
|
||||||
|
|
||||||
def test_gateway_without_cache_config(self):
|
def test_gateway_without_cache_config(self):
|
||||||
"""LLMGateway works without cache (default)."""
|
"""LLMGateway works without cache (default)."""
|
||||||
config = LLMConfig(providers={})
|
config = LLMConfig(providers={})
|
||||||
gateway = LLMGateway(config=config)
|
gateway = LLMGateway(config=config)
|
||||||
assert gateway._cache is None
|
assert gateway._cache_manager is None
|
||||||
|
|
||||||
def test_gateway_cache_disabled(self):
|
def test_gateway_cache_disabled(self):
|
||||||
"""LLMGateway does not initialize cache when disabled."""
|
"""LLMGateway does not initialize cache when disabled."""
|
||||||
|
|
@ -302,7 +302,7 @@ class TestCacheIntegration:
|
||||||
cache=CacheConfig(enabled=False),
|
cache=CacheConfig(enabled=False),
|
||||||
)
|
)
|
||||||
gateway = LLMGateway(config=config)
|
gateway = LLMGateway(config=config)
|
||||||
assert gateway._cache is None
|
assert gateway._cache_manager is None
|
||||||
|
|
||||||
|
|
||||||
# ── Graceful degradation tests ─────────────────────────────
|
# ── Graceful degradation tests ─────────────────────────────
|
||||||
|
|
@ -390,7 +390,7 @@ class TestFullFlowInMemory:
|
||||||
redis_url=config.usage_store.get("redis_url", "redis://localhost:6379"),
|
redis_url=config.usage_store.get("redis_url", "redis://localhost:6379"),
|
||||||
)
|
)
|
||||||
gateway = LLMGateway(config=config.llm_config, usage_store=usage_store)
|
gateway = LLMGateway(config=config.llm_config, usage_store=usage_store)
|
||||||
assert gateway._cache is not None
|
assert gateway._cache_manager is not None
|
||||||
|
|
||||||
cascade_store = create_cascade_state_store(
|
cascade_store = create_cascade_state_store(
|
||||||
backend=config.cascade_store.get("backend", "memory"),
|
backend=config.cascade_store.get("backend", "memory"),
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,486 @@
|
||||||
|
"""U17 — LiteLLM 语义缓存管理器单元测试。
|
||||||
|
|
||||||
|
覆盖场景(plan 要求的 4 个 + 安全要求 e + 向后兼容 + fallback):
|
||||||
|
1. 语义相似查询命中缓存(mock litellm.acompletion 模拟缓存行为)
|
||||||
|
2. 不同 system prompt 不命中缓存
|
||||||
|
3. 缓存命中率可统计
|
||||||
|
4. 阈值调优生效(similarity_threshold 传入 RedisSemanticCache)
|
||||||
|
5. User A 不返回 User B 缓存(安全要求 e)
|
||||||
|
6. kb_acl_hash 隔离 — 不同 ACL hash 产生不同 key
|
||||||
|
7. kb_caching_disabled 禁用缓存(安全要求 c)
|
||||||
|
8. cache_params_for_hit / no_cache — 返回正确 dict
|
||||||
|
9. detect_cache_hit — _hidden_params 含 cache_key 时返回 True
|
||||||
|
10. LitellmCacheConfig.from_cache_config — 转换正确,similarity_threshold=0.87
|
||||||
|
11. LitellmCacheManager.enable/disable — litellm.cache 正确设置/清除
|
||||||
|
12. generate_cache_key 向后兼容 — user_id=None, kb_acl_hash=None 时与旧版相同
|
||||||
|
13. RedisSemanticCache fallback — redisvl 缺失时 auto backend 回退
|
||||||
|
|
||||||
|
测试用 ``unittest.mock.patch`` mock ``litellm.acompletion`` 与 ``litellm.caching`` 后端类;``litellm.cache`` 全局变量由 autouse fixture 重置,避免测试间污染。
|
||||||
|
每个测试后清理 ``litellm.cache = None``。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager
|
||||||
|
from agentkit.llm.cache_key import generate_cache_key
|
||||||
|
from agentkit.llm.config import CacheConfig
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 测试辅助
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]:
|
||||||
|
return [
|
||||||
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
{"role": "user", "content": user_content},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _make_litellm_response(
|
||||||
|
content: str = "Hello!",
|
||||||
|
model: str = "gpt-4o-mini",
|
||||||
|
prompt_tokens: int = 10,
|
||||||
|
completion_tokens: int = 5,
|
||||||
|
cache_key: str | None = None,
|
||||||
|
) -> SimpleNamespace:
|
||||||
|
"""构造 LiteLLM 响应(OpenAI ChatCompletion 格式),可选 cache_key 标记。"""
|
||||||
|
hidden_params: dict[str, Any] = {}
|
||||||
|
if cache_key is not None:
|
||||||
|
hidden_params["cache_key"] = cache_key
|
||||||
|
message = SimpleNamespace(content=content, tool_calls=None)
|
||||||
|
return SimpleNamespace(
|
||||||
|
choices=[SimpleNamespace(message=message)],
|
||||||
|
usage=SimpleNamespace(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
|
||||||
|
model=model,
|
||||||
|
_hidden_params=hidden_params,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _cleanup_litellm_cache():
|
||||||
|
"""每个测试后清理 litellm.cache 全局变量,避免测试间污染。"""
|
||||||
|
yield
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
litellm.cache = None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 1. 语义相似查询命中缓存
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheHit:
|
||||||
|
"""场景 1 — 相同请求第二次命中缓存。"""
|
||||||
|
|
||||||
|
async def test_second_request_is_cache_hit(self):
|
||||||
|
"""相同 cache_key 的第二次请求返回缓存响应(cache_hit=True)。"""
|
||||||
|
config = LitellmCacheConfig(enabled=True, backend="memory")
|
||||||
|
manager = LitellmCacheManager(config)
|
||||||
|
|
||||||
|
msgs = _make_messages()
|
||||||
|
cache_key = manager.build_cache_key("gpt-4o", msgs, 0.0)
|
||||||
|
cache_params = LitellmCacheManager.cache_params_for_hit(cache_key)
|
||||||
|
|
||||||
|
# 模拟 LiteLLM 缓存行为:第一次 miss,第二次 hit
|
||||||
|
call_count = 0
|
||||||
|
cache_store: dict[str, SimpleNamespace] = {}
|
||||||
|
|
||||||
|
async def fake_acompletion(**kwargs):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
ck = kwargs.get("cache", {}).get("cache_key")
|
||||||
|
if ck and ck in cache_store:
|
||||||
|
# 缓存命中 — 返回带 cache_key 标记的响应
|
||||||
|
return _make_litellm_response(content="Cached!", cache_key=ck)
|
||||||
|
# 缓存未命中 — 调用"真实" API 并存储
|
||||||
|
resp = _make_litellm_response(content="Fresh response")
|
||||||
|
if ck:
|
||||||
|
cache_store[ck] = resp
|
||||||
|
return resp
|
||||||
|
|
||||||
|
with patch("litellm.acompletion", side_effect=fake_acompletion):
|
||||||
|
from agentkit.llm.providers.litellm_provider import LitellmProvider
|
||||||
|
|
||||||
|
provider = LitellmProvider(model_prefix="openai/", api_key="test")
|
||||||
|
from agentkit.llm.protocol import LLMRequest
|
||||||
|
|
||||||
|
# 第一次请求 — miss
|
||||||
|
req1 = LLMRequest(messages=msgs, model="gpt-4o", temperature=0.0, cache=cache_params)
|
||||||
|
resp1 = await provider.chat(req1)
|
||||||
|
assert resp1.cache_hit is False
|
||||||
|
assert resp1.content == "Fresh response"
|
||||||
|
|
||||||
|
# 第二次相同请求 — hit
|
||||||
|
req2 = LLMRequest(messages=msgs, model="gpt-4o", temperature=0.0, cache=cache_params)
|
||||||
|
resp2 = await provider.chat(req2)
|
||||||
|
assert resp2.cache_hit is True
|
||||||
|
assert resp2.content == "Cached!"
|
||||||
|
|
||||||
|
assert call_count == 2 # litellm.acompletion 被调用两次(LiteLLM 内部处理缓存)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 2. 不同 system prompt 不命中缓存
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSystemPromptIsolation:
|
||||||
|
"""场景 2 — 不同 system prompt 产生不同 cache_key,不误命中。"""
|
||||||
|
|
||||||
|
def test_different_system_prompt_different_key(self):
|
||||||
|
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||||
|
msgs1 = [
|
||||||
|
{"role": "system", "content": "Be concise"},
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
]
|
||||||
|
msgs2 = [
|
||||||
|
{"role": "system", "content": "Be verbose"},
|
||||||
|
{"role": "user", "content": "Hello"},
|
||||||
|
]
|
||||||
|
key1 = manager.build_cache_key("gpt-4o", msgs1, 0.0)
|
||||||
|
key2 = manager.build_cache_key("gpt-4o", msgs2, 0.0)
|
||||||
|
assert key1 != key2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 3. 缓存命中率可统计
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheStats:
|
||||||
|
"""场景 3 — LitellmCacheManager.stats() 返回正确 hits/misses。"""
|
||||||
|
|
||||||
|
def test_stats_after_hits_and_misses(self):
|
||||||
|
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||||
|
|
||||||
|
# 2 hits
|
||||||
|
manager.detect_cache_hit(_make_litellm_response(cache_key="k1"))
|
||||||
|
manager.detect_cache_hit(_make_litellm_response(cache_key="k2"))
|
||||||
|
# 1 miss
|
||||||
|
manager.detect_cache_hit(_make_litellm_response()) # 无 cache_key
|
||||||
|
|
||||||
|
stats = manager.stats()
|
||||||
|
assert stats["total_hits"] == 2
|
||||||
|
assert stats["total_misses"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 4. 阈值调优生效
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSimilarityThreshold:
|
||||||
|
"""场景 4 — similarity_threshold 传入 RedisSemanticCache 构造函数。"""
|
||||||
|
|
||||||
|
def test_threshold_passed_to_redis_semantic_cache(self):
|
||||||
|
"""RedisSemanticCache 构造时接收配置中的 similarity_threshold(本测试为 0.83)。"""
|
||||||
|
config = LitellmCacheConfig(
|
||||||
|
enabled=True,
|
||||||
|
backend="redis_semantic",
|
||||||
|
similarity_threshold=0.83,
|
||||||
|
redis_url="redis://localhost:6379",
|
||||||
|
)
|
||||||
|
manager = LitellmCacheManager(config)
|
||||||
|
|
||||||
|
mock_cache_instance = MagicMock()
|
||||||
|
with patch(
|
||||||
|
"litellm.caching.RedisSemanticCache", return_value=mock_cache_instance
|
||||||
|
) as mock_cls:
|
||||||
|
instance = manager._create_cache_instance()
|
||||||
|
mock_cls.assert_called_once_with(
|
||||||
|
redis_url="redis://localhost:6379",
|
||||||
|
similarity_threshold=0.83,
|
||||||
|
embedding_model="text-embedding-ada-002",
|
||||||
|
)
|
||||||
|
assert instance is mock_cache_instance
|
||||||
|
|
||||||
|
def test_default_threshold_is_087(self):
|
||||||
|
"""from_cache_config 固定 similarity_threshold=0.87。"""
|
||||||
|
old_config = CacheConfig(similarity_threshold=0.92)
|
||||||
|
litellm_config = LitellmCacheConfig.from_cache_config(old_config)
|
||||||
|
assert litellm_config.similarity_threshold == 0.87
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 5. User A 不返回 User B 缓存(安全要求 e)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestUserIsolation:
|
||||||
|
"""安全要求 e — User A 的查询不返回 User B 的缓存响应。"""
|
||||||
|
|
||||||
|
def test_different_users_different_keys(self):
|
||||||
|
"""不同 user_id 产生不同 cache_key。"""
|
||||||
|
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||||
|
msgs = _make_messages("What is my salary?")
|
||||||
|
key_a = manager.build_cache_key("gpt-4o", msgs, 0.0, user_id="user_a")
|
||||||
|
key_b = manager.build_cache_key("gpt-4o", msgs, 0.0, user_id="user_b")
|
||||||
|
assert key_a != key_b
|
||||||
|
|
||||||
|
def test_same_user_same_key(self):
|
||||||
|
"""相同 user_id 产生相同 cache_key。"""
|
||||||
|
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||||
|
msgs = _make_messages("What is my salary?")
|
||||||
|
key1 = manager.build_cache_key("gpt-4o", msgs, 0.0, user_id="user_a")
|
||||||
|
key2 = manager.build_cache_key("gpt-4o", msgs, 0.0, user_id="user_a")
|
||||||
|
assert key1 == key2
|
||||||
|
|
||||||
|
def test_no_user_id_same_as_no_user_id(self):
|
||||||
|
"""user_id=None 时两次调用产生相同 key(向后兼容)。"""
|
||||||
|
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||||
|
msgs = _make_messages()
|
||||||
|
key1 = manager.build_cache_key("gpt-4o", msgs, 0.0, user_id=None)
|
||||||
|
key2 = manager.build_cache_key("gpt-4o", msgs, 0.0, user_id=None)
|
||||||
|
assert key1 == key2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 6. kb_acl_hash 隔离
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestACLIsolation:
|
||||||
|
"""安全要求 b — 不同 ACL scope 产生不同 cache_key。"""
|
||||||
|
|
||||||
|
def test_different_acl_hash_different_keys(self):
|
||||||
|
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||||
|
msgs = _make_messages("Summarize the document")
|
||||||
|
key1 = manager.build_cache_key("gpt-4o", msgs, 0.0, kb_acl_hash="acl_v1")
|
||||||
|
key2 = manager.build_cache_key("gpt-4o", msgs, 0.0, kb_acl_hash="acl_v2")
|
||||||
|
assert key1 != key2
|
||||||
|
|
||||||
|
def test_same_acl_hash_same_key(self):
|
||||||
|
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||||
|
msgs = _make_messages()
|
||||||
|
key1 = manager.build_cache_key("gpt-4o", msgs, 0.0, kb_acl_hash="acl_v1")
|
||||||
|
key2 = manager.build_cache_key("gpt-4o", msgs, 0.0, kb_acl_hash="acl_v1")
|
||||||
|
assert key1 == key2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 7. kb_caching_disabled 禁用缓存
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestKBCachingDisabled:
|
||||||
|
"""安全要求 c — KB 设置 caching_disabled=True 时禁用缓存。"""
|
||||||
|
|
||||||
|
def test_should_cache_returns_false_when_disabled(self):
|
||||||
|
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||||
|
assert manager.should_cache(kb_caching_disabled=True) is False
|
||||||
|
|
||||||
|
def test_should_cache_returns_true_when_enabled(self):
|
||||||
|
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||||
|
assert manager.should_cache(kb_caching_disabled=False) is True
|
||||||
|
|
||||||
|
def test_should_cache_default_true(self):
|
||||||
|
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||||
|
assert manager.should_cache() is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 8. cache_params_for_hit / no_cache
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheParams:
|
||||||
|
def test_cache_params_for_hit(self):
|
||||||
|
params = LitellmCacheManager.cache_params_for_hit("my_cache_key")
|
||||||
|
assert params == {"cache_key": "my_cache_key"}
|
||||||
|
|
||||||
|
def test_cache_params_for_no_cache(self):
|
||||||
|
params = LitellmCacheManager.cache_params_for_no_cache()
|
||||||
|
assert params == {"no-cache": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 9. detect_cache_hit
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDetectCacheHit:
|
||||||
|
def test_hit_with_cache_key(self):
|
||||||
|
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||||
|
resp = _make_litellm_response(cache_key="some_key")
|
||||||
|
assert manager.detect_cache_hit(resp) is True
|
||||||
|
|
||||||
|
def test_hit_with_cache_hit_flag(self):
|
||||||
|
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||||
|
resp = _make_litellm_response()
|
||||||
|
resp._hidden_params = {"cache_hit": True}
|
||||||
|
assert manager.detect_cache_hit(resp) is True
|
||||||
|
|
||||||
|
def test_miss_without_cache_key(self):
|
||||||
|
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||||
|
resp = _make_litellm_response()
|
||||||
|
assert manager.detect_cache_hit(resp) is False
|
||||||
|
|
||||||
|
def test_miss_with_no_hidden_params(self):
|
||||||
|
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||||
|
resp = SimpleNamespace(_hidden_params=None)
|
||||||
|
assert manager.detect_cache_hit(resp) is False
|
||||||
|
|
||||||
|
def test_miss_with_no_hidden_params_attr(self):
|
||||||
|
manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
|
||||||
|
resp = SimpleNamespace()
|
||||||
|
assert manager.detect_cache_hit(resp) is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 10. LitellmCacheConfig.from_cache_config
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFromCacheConfig:
|
||||||
|
def test_basic_conversion(self):
|
||||||
|
old = CacheConfig(
|
||||||
|
enabled=True,
|
||||||
|
backend="memory",
|
||||||
|
redis_url="redis://myhost:6379",
|
||||||
|
semantic_ttl=3600,
|
||||||
|
embedding_model="bge-m3",
|
||||||
|
)
|
||||||
|
litellm_config = LitellmCacheConfig.from_cache_config(old)
|
||||||
|
assert litellm_config.enabled is True
|
||||||
|
assert litellm_config.backend == "memory"
|
||||||
|
assert litellm_config.redis_url == "redis://myhost:6379"
|
||||||
|
assert litellm_config.ttl == 3600
|
||||||
|
assert litellm_config.similarity_threshold == 0.87 # 固定,忽略旧 0.92
|
||||||
|
assert litellm_config.per_user_namespace is True # 强制开启
|
||||||
|
|
||||||
|
def test_unknown_backend_falls_to_auto(self):
|
||||||
|
old = CacheConfig(backend="unknown_backend")
|
||||||
|
litellm_config = LitellmCacheConfig.from_cache_config(old)
|
||||||
|
assert litellm_config.backend == "auto"
|
||||||
|
|
||||||
|
def test_redis_backend_preserved(self):
|
||||||
|
old = CacheConfig(backend="redis")
|
||||||
|
litellm_config = LitellmCacheConfig.from_cache_config(old)
|
||||||
|
assert litellm_config.backend == "redis"
|
||||||
|
|
||||||
|
def test_empty_embedding_model_falls_to_default(self):
|
||||||
|
old = CacheConfig(embedding_model="")
|
||||||
|
litellm_config = LitellmCacheConfig.from_cache_config(old)
|
||||||
|
assert litellm_config.embedding_model == "text-embedding-ada-002"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 11. LitellmCacheManager.enable/disable
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnableDisable:
|
||||||
|
def test_enable_sets_litellm_cache(self):
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
config = LitellmCacheConfig(backend="memory")
|
||||||
|
manager = LitellmCacheManager(config)
|
||||||
|
manager.enable()
|
||||||
|
assert litellm.cache is not None
|
||||||
|
assert manager._cache_instance is not None
|
||||||
|
|
||||||
|
def test_disable_clears_litellm_cache(self):
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
config = LitellmCacheConfig(backend="memory")
|
||||||
|
manager = LitellmCacheManager(config)
|
||||||
|
manager.enable()
|
||||||
|
assert litellm.cache is not None
|
||||||
|
manager.disable()
|
||||||
|
assert litellm.cache is None
|
||||||
|
assert manager._cache_instance is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 12. generate_cache_key 向后兼容
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackwardCompatibility:
|
||||||
|
"""user_id=None, kb_acl_hash=None 时与旧版 generate_cache_key 完全一致。"""
|
||||||
|
|
||||||
|
def test_none_user_id_same_as_not_passing(self):
|
||||||
|
msgs = _make_messages()
|
||||||
|
key1 = generate_cache_key("gpt-4o", msgs, 0.0, user_id=None, kb_acl_hash=None)
|
||||||
|
key2 = generate_cache_key("gpt-4o", msgs, 0.0)
|
||||||
|
assert key1 == key2
|
||||||
|
|
||||||
|
def test_backward_compat_deterministic(self):
|
||||||
|
msgs = _make_messages()
|
||||||
|
key1 = generate_cache_key("gpt-4o", msgs, 0.0, user_id=None, kb_acl_hash=None)
|
||||||
|
key2 = generate_cache_key("gpt-4o", msgs, 0.0, user_id=None, kb_acl_hash=None)
|
||||||
|
assert key1 == key2
|
||||||
|
|
||||||
|
def test_user_id_changes_key(self):
|
||||||
|
msgs = _make_messages()
|
||||||
|
key_none = generate_cache_key("gpt-4o", msgs, 0.0)
|
||||||
|
key_user = generate_cache_key("gpt-4o", msgs, 0.0, user_id="user_a")
|
||||||
|
assert key_none != key_user
|
||||||
|
|
||||||
|
def test_acl_hash_changes_key(self):
|
||||||
|
msgs = _make_messages()
|
||||||
|
key_none = generate_cache_key("gpt-4o", msgs, 0.0)
|
||||||
|
key_acl = generate_cache_key("gpt-4o", msgs, 0.0, kb_acl_hash="acl_v1")
|
||||||
|
assert key_none != key_acl
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 13. RedisSemanticCache fallback
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRedisSemanticCacheFallback:
|
||||||
|
"""redisvl 缺失时 auto backend 回退到 RedisCache。"""
|
||||||
|
|
||||||
|
def test_auto_fallback_to_redis_cache_when_redisvl_missing(self):
|
||||||
|
"""auto 模式下 RedisSemanticCache 构造 ImportError → 回退到 RedisCache。"""
|
||||||
|
import litellm.caching
|
||||||
|
|
||||||
|
config = LitellmCacheConfig(backend="auto", redis_url="redis://localhost:6379")
|
||||||
|
manager = LitellmCacheManager(config)
|
||||||
|
|
||||||
|
mock_redis_cache = MagicMock(name="RedisCacheInstance")
|
||||||
|
|
||||||
|
# 模拟 RedisSemanticCache 构造时 ImportError(redisvl 未安装)
|
||||||
|
def raise_import_error(*args, **kwargs):
|
||||||
|
raise ImportError("No module named 'redisvl'")
|
||||||
|
|
||||||
|
with patch.object(litellm.caching, "RedisSemanticCache", side_effect=raise_import_error):
|
||||||
|
with patch.object(
|
||||||
|
litellm.caching, "RedisCache", return_value=mock_redis_cache
|
||||||
|
) as mock_rc:
|
||||||
|
instance = manager._create_cache_instance()
|
||||||
|
mock_rc.assert_called_once_with(redis_url="redis://localhost:6379")
|
||||||
|
assert instance is mock_redis_cache
|
||||||
|
|
||||||
|
def test_redis_semantic_backend_raises_when_redisvl_missing(self):
|
||||||
|
"""显式 redis_semantic backend + redisvl 缺失 → raise ImportError。"""
|
||||||
|
import litellm.caching
|
||||||
|
|
||||||
|
config = LitellmCacheConfig(backend="redis_semantic")
|
||||||
|
manager = LitellmCacheManager(config)
|
||||||
|
|
||||||
|
def raise_import_error(*args, **kwargs):
|
||||||
|
raise ImportError("No module named 'redisvl'")
|
||||||
|
|
||||||
|
with patch.object(litellm.caching, "RedisSemanticCache", side_effect=raise_import_error):
|
||||||
|
with pytest.raises(ImportError):
|
||||||
|
manager._create_cache_instance()
|
||||||
|
|
||||||
|
def test_memory_backend_uses_in_memory_cache(self):
|
||||||
|
"""memory backend 直接使用 InMemoryCache,不尝试 Redis。"""
|
||||||
|
import litellm.caching
|
||||||
|
|
||||||
|
config = LitellmCacheConfig(backend="memory")
|
||||||
|
manager = LitellmCacheManager(config)
|
||||||
|
instance = manager._create_cache_instance()
|
||||||
|
assert isinstance(instance, litellm.caching.InMemoryCache)
|
||||||
|
|
@ -1,15 +1,20 @@
|
||||||
"""Integration tests for LLM Cache integration into LLMGateway (U2)."""
|
"""Integration tests for LLM Cache integration into LLMGateway (U2/U17).
|
||||||
|
|
||||||
|
U17 更新:gateway 改用 ``LitellmCacheManager``(LiteLLM 内置缓存)。
|
||||||
|
旧的 ``InMemoryLLMCache`` 手动缓存逻辑已移除,缓存读写由 LiteLLM 内部处理。
|
||||||
|
测试用 ``CacheAwareMockProvider`` 模拟 LiteLLM 的缓存行为(检查 cache_key,
|
||||||
|
命中时返回 ``cache_hit=True`` 的响应)。
|
||||||
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from agentkit.llm.cache import InMemoryLLMCache
|
|
||||||
from agentkit.llm.config import CacheConfig, LLMConfig
|
from agentkit.llm.config import CacheConfig, LLMConfig
|
||||||
from agentkit.llm.gateway import LLMGateway
|
from agentkit.llm.gateway import LLMGateway
|
||||||
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage
|
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage
|
||||||
|
|
||||||
|
|
||||||
class MockProvider(LLMProvider):
|
class MockProvider(LLMProvider):
|
||||||
"""Mock LLM provider that tracks call count."""
|
"""Mock LLM provider that tracks call count (no cache awareness)."""
|
||||||
|
|
||||||
def __init__(self, response_content: str = "Mock response"):
|
def __init__(self, response_content: str = "Mock response"):
|
||||||
self.call_count = 0
|
self.call_count = 0
|
||||||
|
|
@ -24,6 +29,47 @@ class MockProvider(LLMProvider):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CacheAwareMockProvider(LLMProvider):
|
||||||
|
"""Mock provider that simulates LiteLLM's cache behavior.
|
||||||
|
|
||||||
|
Reads ``request._cache`` for cache_key, maintains an internal cache dict.
|
||||||
|
On cache hit: returns cached response with ``cache_hit=True`` (call_count NOT incremented).
|
||||||
|
On cache miss: generates fresh response, caches it, returns with ``cache_hit=False``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, response_content: str = "Mock response"):
|
||||||
|
self.call_count = 0 # 仅统计真实调用(缓存未命中)
|
||||||
|
self._response_content = response_content
|
||||||
|
self._cache: dict[str, LLMResponse] = {}
|
||||||
|
|
||||||
|
async def chat(self, request: LLMRequest) -> LLMResponse:
|
||||||
|
cache_params = getattr(request, "_cache", None) or {}
|
||||||
|
cache_key = cache_params.get("cache_key")
|
||||||
|
no_cache = cache_params.get("no-cache", False)
|
||||||
|
|
||||||
|
# 缓存命中 — 返回缓存响应(不增加 call_count)
|
||||||
|
if cache_key and cache_key in self._cache and not no_cache:
|
||||||
|
cached = self._cache[cache_key]
|
||||||
|
return LLMResponse(
|
||||||
|
content=cached.content,
|
||||||
|
model=cached.model,
|
||||||
|
usage=cached.usage,
|
||||||
|
cache_hit=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 缓存未命中 — 真实调用
|
||||||
|
self.call_count += 1
|
||||||
|
response = LLMResponse(
|
||||||
|
content=self._response_content,
|
||||||
|
model=request.model,
|
||||||
|
usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
|
||||||
|
cache_hit=False,
|
||||||
|
)
|
||||||
|
if cache_key and not no_cache:
|
||||||
|
self._cache[cache_key] = response
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]:
|
def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]:
|
||||||
return [
|
return [
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
|
|
@ -52,7 +98,7 @@ class TestCacheEnabled:
|
||||||
"""First request is a cache miss — provider is called."""
|
"""First request is a cache miss — provider is called."""
|
||||||
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
|
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
|
||||||
gateway = LLMGateway(config=config)
|
gateway = LLMGateway(config=config)
|
||||||
provider = MockProvider()
|
provider = CacheAwareMockProvider()
|
||||||
gateway.register_provider("test", provider)
|
gateway.register_provider("test", provider)
|
||||||
|
|
||||||
msgs = _make_messages()
|
msgs = _make_messages()
|
||||||
|
|
@ -60,28 +106,30 @@ class TestCacheEnabled:
|
||||||
|
|
||||||
assert provider.call_count == 1
|
assert provider.call_count == 1
|
||||||
assert response.content == "Mock response"
|
assert response.content == "Mock response"
|
||||||
|
assert response.cache_hit is False
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_second_request_is_hit(self):
|
async def test_second_request_is_hit(self):
|
||||||
"""Second identical request is a cache hit — provider NOT called."""
|
"""Second identical request is a cache hit — provider NOT called again."""
|
||||||
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
|
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
|
||||||
gateway = LLMGateway(config=config)
|
gateway = LLMGateway(config=config)
|
||||||
provider = MockProvider()
|
provider = CacheAwareMockProvider()
|
||||||
gateway.register_provider("test", provider)
|
gateway.register_provider("test", provider)
|
||||||
|
|
||||||
msgs = _make_messages()
|
msgs = _make_messages()
|
||||||
await gateway.chat(msgs, "test/model", temperature=0.0)
|
await gateway.chat(msgs, "test/model", temperature=0.0)
|
||||||
response = await gateway.chat(msgs, "test/model", temperature=0.0)
|
response = await gateway.chat(msgs, "test/model", temperature=0.0)
|
||||||
|
|
||||||
assert provider.call_count == 1 # Not called again
|
assert provider.call_count == 1 # Not called again (cache hit)
|
||||||
assert response.content == "Mock response"
|
assert response.content == "Mock response"
|
||||||
|
assert response.cache_hit is True
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cache_hit_usage_has_zero_cost(self):
|
async def test_cache_hit_usage_has_zero_cost(self):
|
||||||
"""Cache hit records usage with cost=0."""
|
"""Cache hit records usage with cost=0."""
|
||||||
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
|
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
|
||||||
gateway = LLMGateway(config=config)
|
gateway = LLMGateway(config=config)
|
||||||
provider = MockProvider()
|
provider = CacheAwareMockProvider()
|
||||||
gateway.register_provider("test", provider)
|
gateway.register_provider("test", provider)
|
||||||
|
|
||||||
msgs = _make_messages()
|
msgs = _make_messages()
|
||||||
|
|
@ -98,7 +146,7 @@ class TestCacheEnabled:
|
||||||
"""Different messages produce cache misses."""
|
"""Different messages produce cache misses."""
|
||||||
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
|
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
|
||||||
gateway = LLMGateway(config=config)
|
gateway = LLMGateway(config=config)
|
||||||
provider = MockProvider()
|
provider = CacheAwareMockProvider()
|
||||||
gateway.register_provider("test", provider)
|
gateway.register_provider("test", provider)
|
||||||
|
|
||||||
await gateway.chat(_make_messages("Hello"), "test/model", temperature=0.0)
|
await gateway.chat(_make_messages("Hello"), "test/model", temperature=0.0)
|
||||||
|
|
@ -110,13 +158,15 @@ class TestCacheEnabled:
|
||||||
class TestCacheConfig:
|
class TestCacheConfig:
|
||||||
def test_config_from_dict(self):
|
def test_config_from_dict(self):
|
||||||
"""CacheConfig can be loaded from dict."""
|
"""CacheConfig can be loaded from dict."""
|
||||||
config = LLMConfig.from_dict({
|
config = LLMConfig.from_dict(
|
||||||
"cache": {
|
{
|
||||||
"enabled": True,
|
"cache": {
|
||||||
"backend": "memory",
|
"enabled": True,
|
||||||
"exact_ttl": 7200,
|
"backend": "memory",
|
||||||
|
"exact_ttl": 7200,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
)
|
||||||
assert config.cache is not None
|
assert config.cache is not None
|
||||||
assert config.cache.enabled is True
|
assert config.cache.enabled is True
|
||||||
assert config.cache.backend == "memory"
|
assert config.cache.backend == "memory"
|
||||||
|
|
@ -129,34 +179,35 @@ class TestCacheConfig:
|
||||||
|
|
||||||
def test_config_from_dict_embedding(self):
|
def test_config_from_dict_embedding(self):
|
||||||
"""Embedding config is loaded correctly."""
|
"""Embedding config is loaded correctly."""
|
||||||
config = LLMConfig.from_dict({
|
config = LLMConfig.from_dict(
|
||||||
"cache": {
|
{
|
||||||
"enabled": True,
|
"cache": {
|
||||||
"embedding": {
|
"enabled": True,
|
||||||
"provider": "xinference",
|
"embedding": {
|
||||||
"model": "bge-m3",
|
"provider": "xinference",
|
||||||
"base_url": "http://localhost:9997/v1",
|
"model": "bge-m3",
|
||||||
},
|
"base_url": "http://localhost:9997/v1",
|
||||||
|
},
|
||||||
|
}
|
||||||
}
|
}
|
||||||
})
|
)
|
||||||
assert config.cache.embedding_provider == "xinference"
|
assert config.cache.embedding_provider == "xinference"
|
||||||
assert config.cache.embedding_model == "bge-m3"
|
assert config.cache.embedding_model == "bge-m3"
|
||||||
assert config.cache.embedding_base_url == "http://localhost:9997/v1"
|
assert config.cache.embedding_base_url == "http://localhost:9997/v1"
|
||||||
|
|
||||||
def test_gateway_creates_cache_when_enabled(self):
|
def test_gateway_creates_cache_when_enabled(self):
|
||||||
"""Gateway creates cache instance when cache.enabled=True."""
|
"""Gateway creates cache_manager instance when cache.enabled=True."""
|
||||||
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
|
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
|
||||||
gateway = LLMGateway(config=config)
|
gateway = LLMGateway(config=config)
|
||||||
assert gateway._cache is not None
|
assert gateway._cache_manager is not None
|
||||||
assert isinstance(gateway._cache, InMemoryLLMCache)
|
|
||||||
|
|
||||||
def test_gateway_no_cache_when_disabled(self):
|
def test_gateway_no_cache_when_disabled(self):
|
||||||
"""Gateway has no cache when cache is disabled."""
|
"""Gateway has no cache_manager when cache is disabled."""
|
||||||
config = LLMConfig(cache=CacheConfig(enabled=False))
|
config = LLMConfig(cache=CacheConfig(enabled=False))
|
||||||
gateway = LLMGateway(config=config)
|
gateway = LLMGateway(config=config)
|
||||||
assert gateway._cache is None
|
assert gateway._cache_manager is None
|
||||||
|
|
||||||
def test_gateway_no_cache_when_no_config(self):
|
def test_gateway_no_cache_when_no_config(self):
|
||||||
"""Gateway has no cache when cache config is absent."""
|
"""Gateway has no cache_manager when cache config is absent."""
|
||||||
gateway = LLMGateway()
|
gateway = LLMGateway()
|
||||||
assert gateway._cache is None
|
assert gateway._cache_manager is None
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue