feat(llm): U17 — LiteLLM 语义缓存替换 + per-user/ACL scope 安全隔离

- 新增 LitellmCacheManager:配置 litellm.cache 全局,三级后端 fallback
  (RedisSemanticCache -> RedisCache -> InMemoryCache),redisvl lazy import
- cache_key 扩展 user_id + kb_acl_hash 参数(安全要求 a/b/e)
- gateway 集成:读取 KB caching_disabled flag(安全要求 c),构建带 scope
  的 cache_key,命中时 cost=0
- LLMResponse 新增 cache_hit 字段;LLMRequest 新增 cache 参数
- litellm_provider 透传 cache 参数 + 检测 _hidden_params 缓存命中
- 33 个新测试覆盖 13 场景(含 User A != User B 缓存隔离)
- 旧 InMemoryLLMCache/RedisLLMCache 保留向后兼容
This commit is contained in:
chiguyong 2026-06-25 22:49:59 +08:00
parent 86541d7172
commit 793476cafa
8 changed files with 868 additions and 176 deletions

View File

@ -14,11 +14,14 @@ import logging
import time import time
from collections import OrderedDict from collections import OrderedDict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Protocol, runtime_checkable from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall
from agentkit.utils.vector_math import compute_cosine_similarity from agentkit.utils.vector_math import compute_cosine_similarity
if TYPE_CHECKING:
from agentkit.llm.config import CacheConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -61,8 +64,7 @@ def _serialize_response(response: LLMResponse) -> dict:
"completion_tokens": response.usage.completion_tokens, "completion_tokens": response.usage.completion_tokens,
}, },
"tool_calls": [ "tool_calls": [
{"id": tc.id, "name": tc.name, "arguments": tc.arguments} {"id": tc.id, "name": tc.name, "arguments": tc.arguments} for tc in response.tool_calls
for tc in response.tool_calls
], ],
"latency_ms": response.latency_ms, "latency_ms": response.latency_ms,
} }
@ -236,7 +238,9 @@ class InMemoryLLMCache:
self._cache.move_to_end(key) self._cache.move_to_end(key)
existing = self._cache[key] existing = self._cache[key]
# Preserve existing embedding if new one is None # Preserve existing embedding if new one is None
effective_embedding = query_embedding if query_embedding is not None else existing.query_embedding effective_embedding = (
query_embedding if query_embedding is not None else existing.query_embedding
)
else: else:
effective_embedding = query_embedding or [] effective_embedding = query_embedding or []
@ -276,9 +280,7 @@ class InMemoryLLMCache:
return count return count
# Simple prefix matching for pattern # Simple prefix matching for pattern
keys_to_remove = [ keys_to_remove = [k for k in self._cache if k.startswith(pattern.replace("*", ""))]
k for k in self._cache if k.startswith(pattern.replace("*", ""))
]
for key in keys_to_remove: for key in keys_to_remove:
del self._cache[key] del self._cache[key]
self._embeddings.pop(key, None) self._embeddings.pop(key, None)
@ -338,9 +340,7 @@ class RedisLLMCache:
if self._redis is None: if self._redis is None:
import redis.asyncio as aioredis import redis.asyncio as aioredis
self._redis = aioredis.from_url( self._redis = aioredis.from_url(self._redis_url, decode_responses=True)
self._redis_url, decode_responses=True
)
return self._redis return self._redis
async def aclose(self) -> None: async def aclose(self) -> None:
@ -393,9 +393,7 @@ class RedisLLMCache:
if data is not None: if data is not None:
entry = _deserialize_entry(json.loads(data)) entry = _deserialize_entry(json.loads(data))
self._hits += 1 self._hits += 1
return CacheResult( return CacheResult(hit=True, response=entry.response, match_type="exact")
hit=True, response=entry.response, match_type="exact"
)
self._misses += 1 self._misses += 1
return CacheResult(hit=False) return CacheResult(hit=False)
except Exception as e: except Exception as e:
@ -424,6 +422,7 @@ class RedisLLMCache:
if len(cache_keys_list) > max_scan: if len(cache_keys_list) > max_scan:
# Take a random sample to avoid always scanning the same subset # Take a random sample to avoid always scanning the same subset
import random import random
cache_keys_list = random.sample(cache_keys_list, max_scan) cache_keys_list = random.sample(cache_keys_list, max_scan)
# Batch fetch embeddings # Batch fetch embeddings
@ -460,9 +459,7 @@ class RedisLLMCache:
if data is not None: if data is not None:
entry = _deserialize_entry(json.loads(data)) entry = _deserialize_entry(json.loads(data))
self._hits += 1 self._hits += 1
return CacheResult( return CacheResult(hit=True, response=entry.response, match_type="semantic")
hit=True, response=entry.response, match_type="semantic"
)
# Data key expired but embedding still exists — mark for cleanup # Data key expired but embedding still exists — mark for cleanup
try: try:
await redis.srem(self.INDEX_KEY, best_key) await redis.srem(self.INDEX_KEY, best_key)
@ -615,9 +612,7 @@ def create_llm_cache(
similarity_threshold=similarity_threshold, similarity_threshold=similarity_threshold,
) )
except ImportError: except ImportError:
logger.warning( logger.warning("redis package not available, falling back to in-memory cache")
"redis package not available, falling back to in-memory cache"
)
return InMemoryLLMCache( return InMemoryLLMCache(
max_entries=max_entries, max_entries=max_entries,
exact_ttl=exact_ttl, exact_ttl=exact_ttl,
@ -630,3 +625,200 @@ def create_llm_cache(
semantic_ttl=semantic_ttl, semantic_ttl=semantic_ttl,
similarity_threshold=similarity_threshold, similarity_threshold=similarity_threshold,
) )
# ---------------------------------------------------------------------------
# U17 — LiteLLM 缓存管理器
# ---------------------------------------------------------------------------
@dataclass
class LitellmCacheConfig:
    """U17 — LiteLLM cache settings (converted from the legacy ``CacheConfig``).

    Differences from the legacy ``CacheConfig``:
    - ``similarity_threshold`` defaults to 0.87 (mandated by the U17 plan).
    - ``per_user_namespace`` is always on (security requirement a).
    - ``backend`` gains a ``redis_semantic`` option (requires redisvl).
    """

    enabled: bool = False
    backend: str = "auto"  # "auto" | "redis_semantic" | "redis" | "memory"
    redis_url: str = "redis://localhost:6379"
    similarity_threshold: float = 0.87  # U17 default (mandated by the plan)
    ttl: int = 86400
    embedding_model: str = "text-embedding-ada-002"
    per_user_namespace: bool = True  # security requirement (a)

    @classmethod
    def from_cache_config(cls, c: "CacheConfig") -> "LitellmCacheConfig":
        """Build a ``LitellmCacheConfig`` from an existing ``CacheConfig``.

        - ``similarity_threshold`` is pinned to 0.87; the value carried by
          ``c`` is ignored (U17 plan).
        - ``per_user_namespace`` is forced to ``True`` (security requirement a).
        - ``embedding_model`` falls back to ``"text-embedding-ada-002"`` when
          ``c.embedding_model`` is empty.
        - ``ttl`` is taken from ``c.semantic_ttl``.
        - ``backend`` is passed through when it is one of the values this
          class documents; anything else becomes ``"auto"``.

        Args:
            c: Source configuration; only ``enabled``, ``backend``,
                ``redis_url``, ``semantic_ttl`` and ``embedding_model`` are read.

        Returns:
            The converted configuration.
        """
        # "redis_semantic" must survive the conversion: collapsing it to "auto"
        # would turn an explicit request for the semantic backend into a silent
        # best-effort fallback chain.
        known_backends = ("auto", "redis_semantic", "redis", "memory")
        return cls(
            enabled=c.enabled,
            backend=c.backend if c.backend in known_backends else "auto",
            redis_url=c.redis_url,
            similarity_threshold=0.87,  # fixed U17 default; legacy value ignored
            ttl=c.semantic_ttl,
            embedding_model=c.embedding_model or "text-embedding-ada-002",
            per_user_namespace=True,  # always enforced
        )
class LitellmCacheManager:
"""U17 — LiteLLM 全局缓存管理器。
职责
1. 创建并设置 ``litellm.cache`` 全局实例
2. 构建带 user/ACL scope cache key安全要求 a, b
3. 提供 per-call cache 参数cache_key no-cache
4. 检测 LiteLLM 响应的缓存命中标志用于 usage tracking
5. 统计缓存命中率
后端选择优先级backend="auto"
RedisSemanticCache redisvl RedisCache精确 InMemoryCache
安全约束
- (a) cache key 包含 user_idper-user namespace
- (b) cache key 包含 kb_acl_hashACL-scope 隔离
- (c) KB 设置 caching_disabled=True 时禁用缓存
- (e) User A 的查询不会命中 User B 的缓存
"""
def __init__(self, config: LitellmCacheConfig):
self._config = config
self._cache_instance: Any = None # litellm.caching.Cache 实例
self._hits = 0
self._misses = 0
def enable(self) -> None:
"""创建 LiteLLM Cache 实例并赋值给 ``litellm.cache``。"""
import litellm
self._cache_instance = self._create_cache_instance()
litellm.cache = self._cache_instance
def disable(self) -> None:
"""禁用缓存 — 设置 ``litellm.cache = None``。"""
import litellm
litellm.cache = None
self._cache_instance = None
def _create_cache_instance(self) -> Any:
"""根据 backend 配置创建 LiteLLM Cache 实例。
auto 模式按优先级尝试RedisSemanticCache RedisCache InMemoryCache
redisvl 缺失时自动回退安全要求 d 不添加为必需依赖
"""
backend = self._config.backend
if backend in ("auto", "redis_semantic"):
# 尝试 RedisSemanticCache需要 redisvl — lazy import缺失时 fallback
try:
from litellm.caching import RedisSemanticCache
return RedisSemanticCache(
redis_url=self._config.redis_url,
similarity_threshold=self._config.similarity_threshold,
embedding_model=self._config.embedding_model,
)
except ImportError:
logger.warning(
"RedisSemanticCache 需要 redisvl 包(未安装),"
"回退到 RedisCache精确匹配无语义搜索"
"安装 redisvl 以启用语义缓存pip install redisvl"
)
if backend == "redis_semantic":
raise # 显式要求语义缓存但 redisvl 缺失 — 报错
except Exception as e:
logger.warning(f"RedisSemanticCache 初始化失败: {e},回退到 RedisCache")
if backend in ("auto", "redis", "redis_semantic"):
try:
from litellm.caching import RedisCache
return RedisCache(redis_url=self._config.redis_url)
except Exception as e:
logger.warning(f"RedisCache 初始化失败: {e},回退到 InMemoryCache")
from litellm.caching import InMemoryCache
return InMemoryCache()
def build_cache_key(
self,
model: str,
messages: list[dict[str, str]],
temperature: float,
tools: list[dict] | None = None,
tool_choice: str = "auto",
max_tokens: int = 2000,
user_id: str | None = None,
kb_acl_hash: str | None = None,
) -> str:
"""构建带 user/ACL scope 的 cache key安全要求 a, b, e
委托给 ``cache_key.generate_cache_key``额外注入 user_id + kb_acl_hash
作为命名空间隔离确保 User A 的查询不会命中 User B 的缓存
"""
from agentkit.llm.cache_key import generate_cache_key
return generate_cache_key(
model=model,
messages=messages,
temperature=temperature,
tools=tools,
tool_choice=tool_choice,
max_tokens=max_tokens,
user_id=user_id,
kb_acl_hash=kb_acl_hash,
)
def should_cache(
self,
kb_caching_disabled: bool = False,
user_id: str | None = None,
) -> bool:
"""判断当前请求是否应该缓存(安全要求 c
- KB 设置 caching_disabled=True 不缓存
- 其余情况缓存user_id None 时仍可缓存 key 不含 user scope
"""
_ = user_id # 预留:未来支持 per-user 缓存禁用
if kb_caching_disabled:
return False
return True
@staticmethod
def cache_params_for_hit(cache_key: str) -> dict[str, str]:
"""返回 litellm acompletion 的 cache 参数(用于期望命中的调用)。"""
return {"cache_key": cache_key}
@staticmethod
def cache_params_for_no_cache() -> dict[str, bool]:
"""返回 litellm acompletion 的 cache 参数(禁用缓存)。"""
return {"no-cache": True}
def detect_cache_hit(self, response: Any) -> bool:
"""检测 LiteLLM 响应是否为缓存命中。
LiteLLM 在缓存命中时设置 ``response._hidden_params["cache_key"]``
"""
hidden = getattr(response, "_hidden_params", None)
if isinstance(hidden, dict):
if "cache_key" in hidden or hidden.get("cache_hit"):
self._hits += 1
return True
self._misses += 1
return False
def stats(self) -> dict[str, int]:
"""返回缓存统计。"""
return {
"total_hits": self._hits,
"total_misses": self._misses,
}

View File

@ -12,6 +12,8 @@ def generate_cache_key(
tools: list[dict[str, Any]] | None = None, tools: list[dict[str, Any]] | None = None,
tool_choice: str = "auto", tool_choice: str = "auto",
max_tokens: int = 2000, max_tokens: int = 2000,
user_id: str | None = None,
kb_acl_hash: str | None = None,
) -> str: ) -> str:
"""Generate a deterministic SHA-256 cache key from LLM request parameters. """Generate a deterministic SHA-256 cache key from LLM request parameters.
@ -19,6 +21,10 @@ def generate_cache_key(
model, system_prompt (extracted from messages), messages content, model, system_prompt (extracted from messages), messages content,
temperature, tools, tool_choice, and max_tokens. temperature, tools, tool_choice, and max_tokens.
U17 安全扩展``user_id`` ``kb_acl_hash`` None 时加入 hash 组件
实现 per-user namespace ACL-scope 隔离安全要求 a, b None
行为与旧版完全一致向后兼容
Args: Args:
model: Model identifier (e.g. "openai/gpt-4o"). model: Model identifier (e.g. "openai/gpt-4o").
messages: Chat messages list (may include system prompt as first message). messages: Chat messages list (may include system prompt as first message).
@ -26,6 +32,8 @@ def generate_cache_key(
tools: Optional list of tool definitions. tools: Optional list of tool definitions.
tool_choice: Tool selection mode ("auto", "none", etc.). tool_choice: Tool selection mode ("auto", "none", etc.).
max_tokens: Maximum response tokens. max_tokens: Maximum response tokens.
user_id: U17 用户 ID用于 per-user cache namespace 隔离
kb_acl_hash: U17 KB ACL-scope hash用于 ACL 隔离缓存键
Returns: Returns:
64-character hex SHA-256 hash string. 64-character hex SHA-256 hash string.
@ -40,6 +48,11 @@ def generate_cache_key(
_hash_str(tool_choice), _hash_str(tool_choice),
_hash_str(str(max_tokens)), _hash_str(str(max_tokens)),
] ]
# U17 — per-user namespace + ACL scope hash安全要求 a, b, e
if user_id is not None:
components.append(_hash_str(f"user:{user_id}"))
if kb_acl_hash is not None:
components.append(_hash_str(f"acl:{kb_acl_hash}"))
combined = "".join(components) combined = "".join(components)
return hashlib.sha256(combined.encode()).hexdigest() return hashlib.sha256(combined.encode()).hexdigest()
@ -61,6 +74,4 @@ def _hash_json(obj: Any) -> str:
"""SHA-256 hash of a JSON-serializable object.""" """SHA-256 hash of a JSON-serializable object."""
if obj is None: if obj is None:
return hashlib.sha256(b"null").hexdigest() return hashlib.sha256(b"null").hexdigest()
return hashlib.sha256( return hashlib.sha256(json.dumps(obj, sort_keys=True, ensure_ascii=False).encode()).hexdigest()
json.dumps(obj, sort_keys=True, ensure_ascii=False).encode()
).hexdigest()

View File

@ -51,47 +51,19 @@ class LLMGateway:
self._usage_tracker = UsageTracker(store=usage_store) if usage_store else UsageTracker() self._usage_tracker = UsageTracker(store=usage_store) if usage_store else UsageTracker()
self._config = config or LLMConfig() self._config = config or LLMConfig()
# Cache (opt-in, disabled by default) # Cache (U17 — LiteLLM 缓存管理器opt-in默认禁用)
self._cache: Any = None # LLMCache | None self._cache_manager: Any = None # LitellmCacheManager | None
self._embedder: Any = None # Embedder | None
if self._config.cache and self._config.cache.enabled: if self._config.cache and self._config.cache.enabled:
from agentkit.llm.cache import create_llm_cache from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager
self._cache = create_llm_cache( litellm_config = LitellmCacheConfig.from_cache_config(self._config.cache)
backend=self._config.cache.backend, self._cache_manager = LitellmCacheManager(litellm_config)
redis_url=self._config.cache.redis_url, self._cache_manager.enable()
max_entries=self._config.cache.max_entries,
exact_ttl=self._config.cache.exact_ttl,
semantic_ttl=self._config.cache.semantic_ttl,
similarity_threshold=self._config.cache.similarity_threshold,
)
self._embedder = self._create_embedder(self._config.cache)
logger.info( logger.info(
f"LLM cache enabled (backend={self._config.cache.backend}, " f"LLM cache enabled (LiteLLM, backend={self._config.cache.backend}, "
f"embedder={self._config.cache.embedding_provider}/{self._config.cache.embedding_model})" f"similarity_threshold={litellm_config.similarity_threshold})"
) )
def _create_embedder(self, cache_config) -> Any:
"""Create embedder for semantic cache based on config."""
try:
from agentkit.memory.embedder import OpenAIEmbedder
if cache_config.embedding_provider in ("xinference", "local"):
return OpenAIEmbedder(
api_key=cache_config.embedding_api_key or "not-needed",
model=cache_config.embedding_model,
base_url=cache_config.embedding_base_url or "http://localhost:9997/v1",
)
# Default: OpenAI
return OpenAIEmbedder(
api_key=cache_config.embedding_api_key,
model=cache_config.embedding_model,
base_url=cache_config.embedding_base_url,
)
except Exception as e:
logger.warning(f"Failed to create embedder for semantic cache: {e}")
return None
def register_provider(self, name: str, provider: LLMProvider) -> None: def register_provider(self, name: str, provider: LLMProvider) -> None:
"""注册 Provider""" """注册 Provider"""
self._providers[name] = provider self._providers[name] = provider
@ -114,6 +86,8 @@ class LLMGateway:
user_id: str | None = None, user_id: str | None = None,
department_ids: list[str] | None = None, department_ids: list[str] | None = None,
db_path: Path | str | None = None, db_path: Path | str | None = None,
kb_id: str | None = None,
kb_acl_hash: str | None = None,
**kwargs, **kwargs,
) -> LLMResponse: ) -> LLMResponse:
"""发送 chat 请求,自动解析别名和 Fallback""" """发送 chat 请求,自动解析别名和 Fallback"""
@ -151,67 +125,40 @@ class LLMGateway:
start = time.monotonic() start = time.monotonic()
# ── Cache check ── # ── Cache check (U17 — LiteLLM cache via cache_key in request) ──
cache_key = None # LiteLLM 在 litellm.acompletion 内部处理缓存读写gateway 只需:
query_embedding = None # 1. 构建 per-user + ACL-scoped cache_key安全要求 a, b
if self._cache is not None: # 2. 将 cache 参数注入 kwargs 透传到 provider
from agentkit.llm.cache_key import generate_cache_key # 3. 检测响应的 cache_hit 标志,用于 usage trackingcost=0
if self._cache_manager is not None:
from agentkit.llm.cache import LitellmCacheManager
cache_key = generate_cache_key( # 解析 KB caching_disabled安全要求 c
model=resolved_model, kb_caching_disabled = False
messages=messages, if kb_id is not None:
temperature=kwargs.get("temperature", 0.7),
tools=tools,
tool_choice=tool_choice,
max_tokens=kwargs.get("max_tokens", 2000),
)
result = await self._cache.get(cache_key)
if result.hit:
latency_ms = (time.monotonic() - start) * 1000
await self._record_usage(
agent_name=agent_name,
model=result.response.model,
usage=result.response.usage,
cost=0.0,
latency_ms=latency_ms,
user_id=user_id,
department_ids=department_ids,
)
if _span is not None:
_span.set_attribute("gen_ai.cache.hit", True)
_span.set_attribute("gen_ai.cache.match_type", result.match_type)
return result.response
# Semantic match (only for temperature == 0)
temperature = kwargs.get("temperature", 0.7)
if temperature == 0 and self._embedder is not None:
try: try:
# Embed last N messages for context-aware semantic matching from agentkit.rag_platform.settings import get_settings_store
# (not just last user message — avoids cross-context false hits)
recent_messages = messages[-3:] if len(messages) > 3 else messages settings = await get_settings_store().get_settings(kb_id)
embed_text = " | ".join( if settings is not None:
m.get("content", "") for m in recent_messages if m.get("content") kb_caching_disabled = settings.caching_disabled
)
if embed_text:
query_embedding = await self._embedder.embed(embed_text)
result = await self._cache.semantic_search(query_embedding)
if result.hit:
latency_ms = (time.monotonic() - start) * 1000
await self._record_usage(
agent_name=agent_name,
model=result.response.model,
usage=result.response.usage,
cost=0.0,
latency_ms=latency_ms,
user_id=user_id,
department_ids=department_ids,
)
if _span is not None:
_span.set_attribute("gen_ai.cache.hit", True)
_span.set_attribute("gen_ai.cache.match_type", "semantic")
return result.response
except Exception as e: except Exception as e:
logger.warning(f"Semantic cache search failed: {e}") logger.warning(f"Failed to read KB cache settings for kb_id={kb_id}: {e}")
if self._cache_manager.should_cache(kb_caching_disabled, user_id):
cache_key = self._cache_manager.build_cache_key(
model=resolved_model,
messages=messages,
temperature=kwargs.get("temperature", 0.7),
tools=tools,
tool_choice=tool_choice,
max_tokens=kwargs.get("max_tokens", 2000),
user_id=user_id,
kb_acl_hash=kb_acl_hash,
)
kwargs["cache"] = LitellmCacheManager.cache_params_for_hit(cache_key)
else:
kwargs["cache"] = LitellmCacheManager.cache_params_for_no_cache()
# ── Normal provider call ── # ── Normal provider call ──
models_to_try = self._get_models_to_try(resolved_model) models_to_try = self._get_models_to_try(resolved_model)
@ -275,15 +222,15 @@ class LLMGateway:
latency_ms = (time.monotonic() - start) * 1000 latency_ms = (time.monotonic() - start) * 1000
# ── Cache write ── # U17 — 检测 LiteLLM 缓存命中(用于 usage tracking cost=0
if self._cache is not None and cache_key is not None: is_cache_hit = getattr(response, "cache_hit", False)
try: if is_cache_hit and self._cache_manager is not None:
await self._cache.put(cache_key, response, query_embedding) self._cache_manager._hits += 1
except Exception as e: elif self._cache_manager is not None:
logger.warning(f"Cache write failed: {e}") self._cache_manager._misses += 1
# 计算成本 # 计算成本(缓存命中时 cost=0
cost = self._calculate_cost(response.model, response.usage) cost = 0.0 if is_cache_hit else self._calculate_cost(response.model, response.usage)
# 记录使用量 # 记录使用量
await self._record_usage( await self._record_usage(
@ -302,8 +249,8 @@ class LLMGateway:
_span.set_attribute("gen_ai.usage.output_tokens", response.usage.completion_tokens) _span.set_attribute("gen_ai.usage.output_tokens", response.usage.completion_tokens)
_span.set_attribute("gen_ai.response.model", response.model) _span.set_attribute("gen_ai.response.model", response.model)
_span.set_attribute("gen_ai.duration.ms", int(latency_ms)) _span.set_attribute("gen_ai.duration.ms", int(latency_ms))
if self._cache is not None: if self._cache_manager is not None:
_span.set_attribute("gen_ai.cache.hit", False) _span.set_attribute("gen_ai.cache.hit", is_cache_hit)
llm_token_histogram().record( llm_token_histogram().record(
response.usage.total_tokens, response.usage.total_tokens,
{"gen_ai.request.model": resolved_model}, {"gen_ai.request.model": resolved_model},
@ -607,18 +554,10 @@ class LLMGateway:
) )
# 2. Token + cost limits (daily AND monthly) # 2. Token + cost limits (daily AND monthly)
await self._check_quota_period( await self._check_quota_period(quota_service, db, dept_id, "daily", "token_limit")
quota_service, db, dept_id, "daily", "token_limit" await self._check_quota_period(quota_service, db, dept_id, "daily", "cost_limit")
) await self._check_quota_period(quota_service, db, dept_id, "monthly", "token_limit")
await self._check_quota_period( await self._check_quota_period(quota_service, db, dept_id, "monthly", "cost_limit")
quota_service, db, dept_id, "daily", "cost_limit"
)
await self._check_quota_period(
quota_service, db, dept_id, "monthly", "token_limit"
)
await self._check_quota_period(
quota_service, db, dept_id, "monthly", "cost_limit"
)
async def _check_quota_period( async def _check_quota_period(
self, self,
@ -639,9 +578,7 @@ class LLMGateway:
else: else:
current = await self._get_current_cost_for_quota(dept_id, period) current = await self._get_current_cost_for_quota(dept_id, period)
allowed, _reason = await quota_service.check_quota( allowed, _reason = await quota_service.check_quota(db, dept_id, quota_type, period, current)
db, dept_id, quota_type, period, current
)
if not allowed: if not allowed:
quota = await quota_service.get_quota(db, dept_id, quota_type, period) quota = await quota_service.get_quota(db, dept_id, quota_type, period)
limit = quota["limit_value"] if quota else None limit = quota["limit_value"] if quota else None

View File

@ -47,6 +47,7 @@ class LLMRequest:
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: int = 2000, max_tokens: int = 2000,
timeout: float | None = None, timeout: float | None = None,
cache: dict[str, Any] | None = None,
**kwargs: Any, **kwargs: Any,
): ):
self.messages = messages self.messages = messages
@ -57,6 +58,8 @@ class LLMRequest:
self.max_tokens = max_tokens self.max_tokens = max_tokens
self.timeout = timeout self.timeout = timeout
self._extra = kwargs self._extra = kwargs
# U17 — LiteLLM cache 参数cache_key 或 no-cache透传到 litellm.acompletion
self._cache: dict[str, Any] | None = cache
@dataclass @dataclass
@ -81,6 +84,7 @@ class LLMResponse:
usage: TokenUsage usage: TokenUsage
tool_calls: list[ToolCall] = field(default_factory=list) tool_calls: list[ToolCall] = field(default_factory=list)
latency_ms: float = 0.0 latency_ms: float = 0.0
cache_hit: bool = False # U17 — 缓存命中标记,用于 usage trackingcost=0
@property @property
def has_tool_calls(self) -> bool: def has_tool_calls(self) -> bool:

View File

@ -183,6 +183,10 @@ class LitellmProvider(LLMProvider):
kwargs["tool_choice"] = request.tool_choice kwargs["tool_choice"] = request.tool_choice
if request.timeout is not None: if request.timeout is not None:
kwargs["timeout"] = request.timeout kwargs["timeout"] = request.timeout
# U17 — 透传 LiteLLM cache 参数cache_key 或 no-cache到 litellm.acompletion
cache_params = getattr(request, "_cache", None)
if cache_params is not None:
kwargs["cache"] = cache_params
# 合并构造时传入的默认 kwargs如 max_connections 等provider特定参数 # 合并构造时传入的默认 kwargs如 max_connections 等provider特定参数
kwargs.update(self._default_kwargs) kwargs.update(self._default_kwargs)
return kwargs return kwargs
@ -214,12 +218,19 @@ class LitellmProvider(LLMProvider):
model_name = getattr(response, "model", None) or request_model model_name = getattr(response, "model", None) or request_model
# U17 — 检测 LiteLLM 缓存命中_hidden_params 含 cache_key 或 cache_hit
cache_hit = False
hidden = getattr(response, "_hidden_params", None)
if isinstance(hidden, dict) and ("cache_key" in hidden or hidden.get("cache_hit")):
cache_hit = True
return LLMResponse( return LLMResponse(
content=content, content=content,
model=model_name, model=model_name,
usage=usage, usage=usage,
tool_calls=tool_calls, tool_calls=tool_calls,
latency_ms=latency_ms, latency_ms=latency_ms,
cache_hit=cache_hit,
) )
def _parse_stream_chunk( def _parse_stream_chunk(

View File

@ -287,13 +287,13 @@ class TestCacheIntegration:
cache=CacheConfig(enabled=True, backend="memory"), cache=CacheConfig(enabled=True, backend="memory"),
) )
gateway = LLMGateway(config=config) gateway = LLMGateway(config=config)
assert gateway._cache is not None assert gateway._cache_manager is not None
def test_gateway_without_cache_config(self): def test_gateway_without_cache_config(self):
"""LLMGateway works without cache (default).""" """LLMGateway works without cache (default)."""
config = LLMConfig(providers={}) config = LLMConfig(providers={})
gateway = LLMGateway(config=config) gateway = LLMGateway(config=config)
assert gateway._cache is None assert gateway._cache_manager is None
def test_gateway_cache_disabled(self): def test_gateway_cache_disabled(self):
"""LLMGateway does not initialize cache when disabled.""" """LLMGateway does not initialize cache when disabled."""
@ -302,7 +302,7 @@ class TestCacheIntegration:
cache=CacheConfig(enabled=False), cache=CacheConfig(enabled=False),
) )
gateway = LLMGateway(config=config) gateway = LLMGateway(config=config)
assert gateway._cache is None assert gateway._cache_manager is None
# ── Graceful degradation tests ───────────────────────────── # ── Graceful degradation tests ─────────────────────────────
@ -390,7 +390,7 @@ class TestFullFlowInMemory:
redis_url=config.usage_store.get("redis_url", "redis://localhost:6379"), redis_url=config.usage_store.get("redis_url", "redis://localhost:6379"),
) )
gateway = LLMGateway(config=config.llm_config, usage_store=usage_store) gateway = LLMGateway(config=config.llm_config, usage_store=usage_store)
assert gateway._cache is not None assert gateway._cache_manager is not None
cascade_store = create_cascade_state_store( cascade_store = create_cascade_state_store(
backend=config.cascade_store.get("backend", "memory"), backend=config.cascade_store.get("backend", "memory"),

View File

@ -0,0 +1,486 @@
"""U17 — LiteLLM 语义缓存管理器单元测试。
覆盖场景plan 要求的 4 + 安全要求 e + 向后兼容 + fallback
1. 语义相似查询命中缓存mock litellm.acompletion 模拟缓存行为
2. 不同 system prompt 不命中缓存
3. 缓存命中率可统计
4. 阈值调优生效similarity_threshold 传入 RedisSemanticCache
5. User A 不返回 User B 缓存安全要求 e
6. kb_acl_hash 隔离 不同 ACL hash 产生不同 key
7. kb_caching_disabled 禁用缓存安全要求 c
8. cache_params_for_hit / no_cache 返回正确 dict
9. detect_cache_hit _hidden_params cache_key 时返回 True
10. LitellmCacheConfig.from_cache_config 转换正确similarity_threshold=0.87
11. LitellmCacheManager.enable/disable litellm.cache 正确设置/清除
12. generate_cache_key 向后兼容 user_id=None, kb_acl_hash=None 时与旧版相同
13. RedisSemanticCache fallback redisvl 缺失时 auto backend 回退
测试用 ``unittest.mock.patch`` mock ``litellm.cache`` 全局变量避免测试间污染
每个测试后清理 ``litellm.cache = None``
"""
from __future__ import annotations
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager
from agentkit.llm.cache_key import generate_cache_key
from agentkit.llm.config import CacheConfig
# ---------------------------------------------------------------------------
# 测试辅助
# ---------------------------------------------------------------------------
def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]:
    """Return a two-message chat: a fixed system prompt plus one user turn."""
    system_turn = {"role": "system", "content": "You are a helpful assistant."}
    user_turn = {"role": "user", "content": user_content}
    return [system_turn, user_turn]
def _make_litellm_response(
content: str = "Hello!",
model: str = "gpt-4o-mini",
prompt_tokens: int = 10,
completion_tokens: int = 5,
cache_key: str | None = None,
) -> SimpleNamespace:
"""构造 LiteLLM 响应OpenAI ChatCompletion 格式),可选 cache_key 标记。"""
hidden_params: dict[str, Any] = {}
if cache_key is not None:
hidden_params["cache_key"] = cache_key
message = SimpleNamespace(content=content, tool_calls=None)
return SimpleNamespace(
choices=[SimpleNamespace(message=message)],
usage=SimpleNamespace(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
model=model,
_hidden_params=hidden_params,
)
@pytest.fixture(autouse=True)
def _cleanup_litellm_cache():
    """Reset the global ``litellm.cache`` after every test to avoid cross-test pollution."""
    # Everything after the yield runs as teardown, once the test body has finished.
    yield
    import litellm
    # Unconditionally clear the global; any value a test installed is discarded.
    litellm.cache = None
# ---------------------------------------------------------------------------
# 1. 语义相似查询命中缓存
# ---------------------------------------------------------------------------
class TestCacheHit:
    """Scenario 1 — an identical request is served from cache the second time."""

    async def test_second_request_is_cache_hit(self):
        """A repeat request with the same cache_key comes back with cache_hit=True."""
        manager = LitellmCacheManager(LitellmCacheConfig(enabled=True, backend="memory"))
        msgs = _make_messages()
        cache_params = LitellmCacheManager.cache_params_for_hit(
            manager.build_cache_key("gpt-4o", msgs, 0.0)
        )

        # Emulate LiteLLM's caching: first call for a key misses, later calls hit.
        invocations: list[str | None] = []
        seen_keys: set[str] = set()

        async def fake_acompletion(**kwargs):
            requested_key = kwargs.get("cache", {}).get("cache_key")
            invocations.append(requested_key)
            if requested_key and requested_key in seen_keys:
                # Hit — the response carries the cache_key marker.
                return _make_litellm_response(content="Cached!", cache_key=requested_key)
            # Miss — pretend to call the real API and remember the key.
            fresh = _make_litellm_response(content="Fresh response")
            if requested_key:
                seen_keys.add(requested_key)
            return fresh

        with patch("litellm.acompletion", side_effect=fake_acompletion):
            # Imported while the patch is active, exactly as the scenario requires.
            from agentkit.llm.providers.litellm_provider import LitellmProvider

            provider = LitellmProvider(model_prefix="openai/", api_key="test")

            from agentkit.llm.protocol import LLMRequest

            def new_request():
                return LLMRequest(
                    messages=msgs, model="gpt-4o", temperature=0.0, cache=cache_params
                )

            # First request — miss.
            first = await provider.chat(new_request())
            assert first.cache_hit is False
            assert first.content == "Fresh response"

            # Second identical request — hit.
            second = await provider.chat(new_request())
            assert second.cache_hit is True
            assert second.content == "Cached!"

        # acompletion is invoked both times; LiteLLM handles caching internally.
        assert len(invocations) == 2
# ---------------------------------------------------------------------------
# 2. 不同 system prompt 不命中缓存
# ---------------------------------------------------------------------------
class TestSystemPromptIsolation:
    """Scenario 2 — different system prompts yield different cache keys (no false hit)."""

    def test_different_system_prompt_different_key(self):
        """Changing only the system prompt must change the generated key."""

        def conversation(system_prompt: str) -> list[dict[str, str]]:
            return [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": "Hello"},
            ]

        manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        concise_key = manager.build_cache_key("gpt-4o", conversation("Be concise"), 0.0)
        verbose_key = manager.build_cache_key("gpt-4o", conversation("Be verbose"), 0.0)
        assert concise_key != verbose_key
# ---------------------------------------------------------------------------
# 3. 缓存命中率可统计
# ---------------------------------------------------------------------------
class TestCacheStats:
    """Scenario 3 — ``LitellmCacheManager.stats()`` reports the right hits/misses."""

    def test_stats_after_hits_and_misses(self):
        """Two marked responses and one unmarked response give 2 hits and 1 miss."""
        manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))

        # Two hits: responses that carry a cache_key marker.
        for marker in ("k1", "k2"):
            manager.detect_cache_hit(_make_litellm_response(cache_key=marker))
        # One miss: response without a cache_key.
        manager.detect_cache_hit(_make_litellm_response())

        counters = manager.stats()
        assert counters["total_hits"] == 2
        assert counters["total_misses"] == 1
# ---------------------------------------------------------------------------
# 4. 阈值调优生效
# ---------------------------------------------------------------------------
class TestSimilarityThreshold:
    """Scenario 4: the configured similarity threshold reaches the cache backend."""

    def test_threshold_passed_to_redis_semantic_cache(self):
        """A custom threshold (0.83) is forwarded unchanged to ``RedisSemanticCache``."""
        # Single source of truth for the value under test, so the config and the
        # assertion below cannot drift apart. It intentionally differs from the
        # 0.87 that ``from_cache_config`` pins (see the next test).
        custom_threshold = 0.83
        config = LitellmCacheConfig(
            enabled=True,
            backend="redis_semantic",
            similarity_threshold=custom_threshold,
            redis_url="redis://localhost:6379",
        )
        manager = LitellmCacheManager(config)
        mock_cache_instance = MagicMock()

        # Patch the backend class so no Redis connection is attempted.
        with patch(
            "litellm.caching.RedisSemanticCache", return_value=mock_cache_instance
        ) as mock_cls:
            instance = manager._create_cache_instance()

        mock_cls.assert_called_once_with(
            redis_url="redis://localhost:6379",
            similarity_threshold=custom_threshold,
            embedding_model="text-embedding-ada-002",
        )
        assert instance is mock_cache_instance

    def test_default_threshold_is_087(self):
        """``from_cache_config`` pins the threshold to 0.87, ignoring the legacy 0.92."""
        old_config = CacheConfig(similarity_threshold=0.92)
        litellm_config = LitellmCacheConfig.from_cache_config(old_config)
        assert litellm_config.similarity_threshold == 0.87
# ---------------------------------------------------------------------------
# 5. User A must not receive User B's cached responses (security requirement e)
# ---------------------------------------------------------------------------
class TestUserIsolation:
    """Security requirement (e): User A's query never returns User B's cached response."""

    def test_different_users_different_keys(self):
        """Distinct user_id values produce distinct cache keys."""
        cache_manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        conversation = _make_messages("What is my salary?")
        per_user_keys = [
            cache_manager.build_cache_key("gpt-4o", conversation, 0.0, user_id=uid)
            for uid in ("user_a", "user_b")
        ]
        assert per_user_keys[0] != per_user_keys[1]

    def test_same_user_same_key(self):
        """The same user_id produces the same cache key on repeated calls."""
        cache_manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        conversation = _make_messages("What is my salary?")
        first, second = (
            cache_manager.build_cache_key("gpt-4o", conversation, 0.0, user_id="user_a")
            for _ in range(2)
        )
        assert first == second

    def test_no_user_id_same_as_no_user_id(self):
        """With user_id=None two calls give the same key (backward compatibility)."""
        cache_manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        conversation = _make_messages()
        first, second = (
            cache_manager.build_cache_key("gpt-4o", conversation, 0.0, user_id=None)
            for _ in range(2)
        )
        assert first == second
# ---------------------------------------------------------------------------
# 6. kb_acl_hash 隔离
# ---------------------------------------------------------------------------
class TestACLIsolation:
    """Security requirement (b): different ACL scopes produce different cache keys."""

    def test_different_acl_hash_different_keys(self):
        """Changing kb_acl_hash changes the resulting cache key."""
        cache_manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        conversation = _make_messages("Summarize the document")
        keys_by_acl = {
            acl: cache_manager.build_cache_key("gpt-4o", conversation, 0.0, kb_acl_hash=acl)
            for acl in ("acl_v1", "acl_v2")
        }
        assert keys_by_acl["acl_v1"] != keys_by_acl["acl_v2"]

    def test_same_acl_hash_same_key(self):
        """An identical kb_acl_hash gives an identical cache key."""
        cache_manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        conversation = _make_messages()
        first, second = (
            cache_manager.build_cache_key("gpt-4o", conversation, 0.0, kb_acl_hash="acl_v1")
            for _ in range(2)
        )
        assert first == second
# ---------------------------------------------------------------------------
# 7. kb_caching_disabled 禁用缓存
# ---------------------------------------------------------------------------
class TestKBCachingDisabled:
    """Security requirement (c): a KB with caching_disabled=True turns caching off."""

    def test_should_cache_returns_false_when_disabled(self):
        """should_cache() is False when the KB disables caching."""
        cache_manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        decision = cache_manager.should_cache(kb_caching_disabled=True)
        assert decision is False

    def test_should_cache_returns_true_when_enabled(self):
        """should_cache() is True when the KB does not disable caching."""
        cache_manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        decision = cache_manager.should_cache(kb_caching_disabled=False)
        assert decision is True

    def test_should_cache_default_true(self):
        """should_cache() is True when called without arguments."""
        cache_manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        decision = cache_manager.should_cache()
        assert decision is True
# ---------------------------------------------------------------------------
# 8. cache_params_for_hit / no_cache
# ---------------------------------------------------------------------------
class TestCacheParams:
    """Scenario 8: shape of the per-request cache parameter dicts."""

    def test_cache_params_for_hit(self):
        """cache_params_for_hit wraps the given key under ``cache_key``."""
        expected = {"cache_key": "my_cache_key"}
        assert LitellmCacheManager.cache_params_for_hit("my_cache_key") == expected

    def test_cache_params_for_no_cache(self):
        """cache_params_for_no_cache returns the ``no-cache`` flag set to True."""
        expected = {"no-cache": True}
        assert LitellmCacheManager.cache_params_for_no_cache() == expected
# ---------------------------------------------------------------------------
# 9. detect_cache_hit
# ---------------------------------------------------------------------------
class TestDetectCacheHit:
    """Scenario 9: detect_cache_hit() classification of LiteLLM-style responses."""

    def test_hit_with_cache_key(self):
        """A response built with a cache_key is reported as a hit."""
        cache_manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        response = _make_litellm_response(cache_key="some_key")
        outcome = cache_manager.detect_cache_hit(response)
        assert outcome is True

    def test_hit_with_cache_hit_flag(self):
        """``_hidden_params`` with cache_hit=True is reported as a hit."""
        cache_manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        response = _make_litellm_response()
        response._hidden_params = {"cache_hit": True}
        outcome = cache_manager.detect_cache_hit(response)
        assert outcome is True

    def test_miss_without_cache_key(self):
        """A response built without a cache_key is reported as a miss."""
        cache_manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        response = _make_litellm_response()
        outcome = cache_manager.detect_cache_hit(response)
        assert outcome is False

    def test_miss_with_no_hidden_params(self):
        """``_hidden_params`` set to None is reported as a miss."""
        cache_manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        response = SimpleNamespace(_hidden_params=None)
        outcome = cache_manager.detect_cache_hit(response)
        assert outcome is False

    def test_miss_with_no_hidden_params_attr(self):
        """An object lacking ``_hidden_params`` entirely is reported as a miss."""
        cache_manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        response = SimpleNamespace()
        outcome = cache_manager.detect_cache_hit(response)
        assert outcome is False
# ---------------------------------------------------------------------------
# 10. LitellmCacheConfig.from_cache_config
# ---------------------------------------------------------------------------
class TestFromCacheConfig:
    """Scenario 10: converting a legacy ``CacheConfig`` via ``from_cache_config``."""

    def test_basic_conversion(self):
        """Plain fields are carried over; threshold and per-user namespace are forced."""
        legacy = CacheConfig(
            enabled=True,
            backend="memory",
            redis_url="redis://myhost:6379",
            semantic_ttl=3600,
            embedding_model="bge-m3",
        )
        converted = LitellmCacheConfig.from_cache_config(legacy)

        assert converted.enabled is True
        assert converted.backend == "memory"
        assert converted.redis_url == "redis://myhost:6379"
        # The legacy semantic_ttl is expected to surface as ``ttl``.
        assert converted.ttl == 3600
        # Pinned by the conversion rather than inherited from the legacy config.
        assert converted.similarity_threshold == 0.87
        # Expected to be switched on unconditionally.
        assert converted.per_user_namespace is True

    def test_unknown_backend_falls_to_auto(self):
        """An unrecognised backend name is mapped to ``auto``."""
        converted = LitellmCacheConfig.from_cache_config(CacheConfig(backend="unknown_backend"))
        assert converted.backend == "auto"

    def test_redis_backend_preserved(self):
        """The ``redis`` backend name passes through unchanged."""
        converted = LitellmCacheConfig.from_cache_config(CacheConfig(backend="redis"))
        assert converted.backend == "redis"

    def test_empty_embedding_model_falls_to_default(self):
        """An empty embedding model falls back to ``text-embedding-ada-002``."""
        converted = LitellmCacheConfig.from_cache_config(CacheConfig(embedding_model=""))
        assert converted.embedding_model == "text-embedding-ada-002"
# ---------------------------------------------------------------------------
# 11. LitellmCacheManager.enable/disable
# ---------------------------------------------------------------------------
class TestEnableDisable:
    """Scenario 11: enable()/disable() install and remove the global ``litellm.cache``.

    ``litellm.cache`` is process-wide state, so each test records the value it
    found and restores it in ``finally``. Without that, a cache enabled here
    would leak into whichever test happens to run next.
    """

    def test_enable_sets_litellm_cache(self):
        """enable() populates both ``litellm.cache`` and the manager's instance."""
        import litellm

        previous_cache = litellm.cache
        config = LitellmCacheConfig(backend="memory")
        manager = LitellmCacheManager(config)
        try:
            manager.enable()
            assert litellm.cache is not None
            assert manager._cache_instance is not None
        finally:
            # Undo the global mutation even when an assertion above fails.
            litellm.cache = previous_cache

    def test_disable_clears_litellm_cache(self):
        """disable() resets ``litellm.cache`` and the manager's instance to None."""
        import litellm

        previous_cache = litellm.cache
        config = LitellmCacheConfig(backend="memory")
        manager = LitellmCacheManager(config)
        try:
            manager.enable()
            assert litellm.cache is not None
            manager.disable()
            assert litellm.cache is None
            assert manager._cache_instance is None
        finally:
            # Covers the path where a failure occurs between enable() and disable().
            litellm.cache = previous_cache
# ---------------------------------------------------------------------------
# 12. generate_cache_key 向后兼容
# ---------------------------------------------------------------------------
class TestBackwardCompatibility:
    """Scenario 12: with user_id=None and kb_acl_hash=None, generate_cache_key matches the
    key produced when those arguments are omitted."""

    def test_none_user_id_same_as_not_passing(self):
        """Explicit None scope arguments equal omitting them."""
        conversation = _make_messages()
        explicit_none = generate_cache_key(
            "gpt-4o", conversation, 0.0, user_id=None, kb_acl_hash=None
        )
        omitted = generate_cache_key("gpt-4o", conversation, 0.0)
        assert explicit_none == omitted

    def test_backward_compat_deterministic(self):
        """Repeated calls with None scope arguments return the same key."""
        conversation = _make_messages()
        first, second = (
            generate_cache_key("gpt-4o", conversation, 0.0, user_id=None, kb_acl_hash=None)
            for _ in range(2)
        )
        assert first == second

    def test_user_id_changes_key(self):
        """Supplying a user_id yields a key different from the unscoped one."""
        conversation = _make_messages()
        unscoped = generate_cache_key("gpt-4o", conversation, 0.0)
        user_scoped = generate_cache_key("gpt-4o", conversation, 0.0, user_id="user_a")
        assert unscoped != user_scoped

    def test_acl_hash_changes_key(self):
        """Supplying a kb_acl_hash yields a key different from the unscoped one."""
        conversation = _make_messages()
        unscoped = generate_cache_key("gpt-4o", conversation, 0.0)
        acl_scoped = generate_cache_key("gpt-4o", conversation, 0.0, kb_acl_hash="acl_v1")
        assert unscoped != acl_scoped
# ---------------------------------------------------------------------------
# 13. RedisSemanticCache fallback
# ---------------------------------------------------------------------------
class TestRedisSemanticCacheFallback:
    """Scenario 13: backend selection when redisvl is unavailable."""

    def test_auto_fallback_to_redis_cache_when_redisvl_missing(self):
        """In ``auto`` mode an ImportError from RedisSemanticCache falls back to RedisCache."""
        import litellm.caching

        manager = LitellmCacheManager(
            LitellmCacheConfig(backend="auto", redis_url="redis://localhost:6379")
        )
        redis_stub = MagicMock(name="RedisCacheInstance")

        # Simulate redisvl not being installed: constructing the semantic cache raises.
        semantic_patch = patch.object(
            litellm.caching,
            "RedisSemanticCache",
            side_effect=ImportError("No module named 'redisvl'"),
        )
        redis_patch = patch.object(litellm.caching, "RedisCache", return_value=redis_stub)

        with semantic_patch, redis_patch as redis_cls:
            created = manager._create_cache_instance()

        redis_cls.assert_called_once_with(redis_url="redis://localhost:6379")
        assert created is redis_stub

    def test_redis_semantic_backend_raises_when_redisvl_missing(self):
        """An explicit ``redis_semantic`` backend propagates the ImportError."""
        import litellm.caching

        manager = LitellmCacheManager(LitellmCacheConfig(backend="redis_semantic"))
        semantic_patch = patch.object(
            litellm.caching,
            "RedisSemanticCache",
            side_effect=ImportError("No module named 'redisvl'"),
        )

        with semantic_patch, pytest.raises(ImportError):
            manager._create_cache_instance()

    def test_memory_backend_uses_in_memory_cache(self):
        """The ``memory`` backend yields an InMemoryCache instance."""
        import litellm.caching

        manager = LitellmCacheManager(LitellmCacheConfig(backend="memory"))
        created = manager._create_cache_instance()
        assert isinstance(created, litellm.caching.InMemoryCache)

View File

@ -1,15 +1,20 @@
"""Integration tests for LLM Cache integration into LLMGateway (U2).""" """Integration tests for LLM Cache integration into LLMGateway (U2/U17).
U17 更新gateway 改用 ``LitellmCacheManager``LiteLLM 内置缓存
旧的 ``InMemoryLLMCache`` 手动缓存逻辑已移除缓存读写由 LiteLLM 内部处理
测试用 ``CacheAwareMockProvider`` 模拟 LiteLLM 的缓存行为检查 cache_key
命中时返回 ``cache_hit=True`` 的响应
"""
import pytest import pytest
from agentkit.llm.cache import InMemoryLLMCache
from agentkit.llm.config import CacheConfig, LLMConfig from agentkit.llm.config import CacheConfig, LLMConfig
from agentkit.llm.gateway import LLMGateway from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage
class MockProvider(LLMProvider): class MockProvider(LLMProvider):
"""Mock LLM provider that tracks call count.""" """Mock LLM provider that tracks call count (no cache awareness)."""
def __init__(self, response_content: str = "Mock response"): def __init__(self, response_content: str = "Mock response"):
self.call_count = 0 self.call_count = 0
@ -24,6 +29,47 @@ class MockProvider(LLMProvider):
) )
class CacheAwareMockProvider(LLMProvider):
    """Test double that imitates LiteLLM-style response caching.

    Per-request cache options are read from ``request._cache`` (treated as an
    empty dict when the attribute is missing or falsy). Responses are stored in
    an internal dict keyed by the ``cache_key`` option.

    * Hit: a copy of the stored response is returned with ``cache_hit=True``
      and ``call_count`` is left untouched.
    * Miss: ``call_count`` is incremented, a fresh response with
      ``cache_hit=False`` is returned, and it is stored unless the request has
      no cache key or sets ``no-cache``.
    """

    def __init__(self, response_content: str = "Mock response"):
        self._response_content = response_content
        self._cache: dict[str, LLMResponse] = {}
        self.call_count = 0  # counts only real (uncached) calls

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Serve ``request`` from the internal cache when possible, else generate a reply."""
        options = getattr(request, "_cache", None) or {}
        key = options.get("cache_key")
        bypass = options.get("no-cache", False)

        # Cache hit: replay the stored response without counting a call.
        if key and key in self._cache and not bypass:
            return self._replay(self._cache[key])

        # Cache miss: this counts as a real provider call.
        self.call_count += 1
        fresh = LLMResponse(
            content=self._response_content,
            model=request.model,
            usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
            cache_hit=False,
        )
        if key and not bypass:
            self._cache[key] = fresh
        return fresh

    @staticmethod
    def _replay(stored: LLMResponse) -> LLMResponse:
        """Build the cache-hit copy of ``stored`` (same content/model/usage, cache_hit=True)."""
        return LLMResponse(
            content=stored.content,
            model=stored.model,
            usage=stored.usage,
            cache_hit=True,
        )
def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]: def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]:
return [ return [
{"role": "system", "content": "You are a helpful assistant."}, {"role": "system", "content": "You are a helpful assistant."},
@ -52,7 +98,7 @@ class TestCacheEnabled:
"""First request is a cache miss — provider is called.""" """First request is a cache miss — provider is called."""
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory")) config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
gateway = LLMGateway(config=config) gateway = LLMGateway(config=config)
provider = MockProvider() provider = CacheAwareMockProvider()
gateway.register_provider("test", provider) gateway.register_provider("test", provider)
msgs = _make_messages() msgs = _make_messages()
@ -60,28 +106,30 @@ class TestCacheEnabled:
assert provider.call_count == 1 assert provider.call_count == 1
assert response.content == "Mock response" assert response.content == "Mock response"
assert response.cache_hit is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_second_request_is_hit(self): async def test_second_request_is_hit(self):
"""Second identical request is a cache hit — provider NOT called.""" """Second identical request is a cache hit — provider NOT called again."""
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory")) config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
gateway = LLMGateway(config=config) gateway = LLMGateway(config=config)
provider = MockProvider() provider = CacheAwareMockProvider()
gateway.register_provider("test", provider) gateway.register_provider("test", provider)
msgs = _make_messages() msgs = _make_messages()
await gateway.chat(msgs, "test/model", temperature=0.0) await gateway.chat(msgs, "test/model", temperature=0.0)
response = await gateway.chat(msgs, "test/model", temperature=0.0) response = await gateway.chat(msgs, "test/model", temperature=0.0)
assert provider.call_count == 1 # Not called again assert provider.call_count == 1 # Not called again (cache hit)
assert response.content == "Mock response" assert response.content == "Mock response"
assert response.cache_hit is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_cache_hit_usage_has_zero_cost(self): async def test_cache_hit_usage_has_zero_cost(self):
"""Cache hit records usage with cost=0.""" """Cache hit records usage with cost=0."""
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory")) config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
gateway = LLMGateway(config=config) gateway = LLMGateway(config=config)
provider = MockProvider() provider = CacheAwareMockProvider()
gateway.register_provider("test", provider) gateway.register_provider("test", provider)
msgs = _make_messages() msgs = _make_messages()
@ -98,7 +146,7 @@ class TestCacheEnabled:
"""Different messages produce cache misses.""" """Different messages produce cache misses."""
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory")) config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
gateway = LLMGateway(config=config) gateway = LLMGateway(config=config)
provider = MockProvider() provider = CacheAwareMockProvider()
gateway.register_provider("test", provider) gateway.register_provider("test", provider)
await gateway.chat(_make_messages("Hello"), "test/model", temperature=0.0) await gateway.chat(_make_messages("Hello"), "test/model", temperature=0.0)
@ -110,13 +158,15 @@ class TestCacheEnabled:
class TestCacheConfig: class TestCacheConfig:
def test_config_from_dict(self): def test_config_from_dict(self):
"""CacheConfig can be loaded from dict.""" """CacheConfig can be loaded from dict."""
config = LLMConfig.from_dict({ config = LLMConfig.from_dict(
"cache": { {
"enabled": True, "cache": {
"backend": "memory", "enabled": True,
"exact_ttl": 7200, "backend": "memory",
"exact_ttl": 7200,
}
} }
}) )
assert config.cache is not None assert config.cache is not None
assert config.cache.enabled is True assert config.cache.enabled is True
assert config.cache.backend == "memory" assert config.cache.backend == "memory"
@ -129,34 +179,35 @@ class TestCacheConfig:
def test_config_from_dict_embedding(self): def test_config_from_dict_embedding(self):
"""Embedding config is loaded correctly.""" """Embedding config is loaded correctly."""
config = LLMConfig.from_dict({ config = LLMConfig.from_dict(
"cache": { {
"enabled": True, "cache": {
"embedding": { "enabled": True,
"provider": "xinference", "embedding": {
"model": "bge-m3", "provider": "xinference",
"base_url": "http://localhost:9997/v1", "model": "bge-m3",
}, "base_url": "http://localhost:9997/v1",
},
}
} }
}) )
assert config.cache.embedding_provider == "xinference" assert config.cache.embedding_provider == "xinference"
assert config.cache.embedding_model == "bge-m3" assert config.cache.embedding_model == "bge-m3"
assert config.cache.embedding_base_url == "http://localhost:9997/v1" assert config.cache.embedding_base_url == "http://localhost:9997/v1"
def test_gateway_creates_cache_when_enabled(self): def test_gateway_creates_cache_when_enabled(self):
"""Gateway creates cache instance when cache.enabled=True.""" """Gateway creates cache_manager instance when cache.enabled=True."""
config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory")) config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory"))
gateway = LLMGateway(config=config) gateway = LLMGateway(config=config)
assert gateway._cache is not None assert gateway._cache_manager is not None
assert isinstance(gateway._cache, InMemoryLLMCache)
def test_gateway_no_cache_when_disabled(self): def test_gateway_no_cache_when_disabled(self):
"""Gateway has no cache when cache is disabled.""" """Gateway has no cache_manager when cache is disabled."""
config = LLMConfig(cache=CacheConfig(enabled=False)) config = LLMConfig(cache=CacheConfig(enabled=False))
gateway = LLMGateway(config=config) gateway = LLMGateway(config=config)
assert gateway._cache is None assert gateway._cache_manager is None
def test_gateway_no_cache_when_no_config(self): def test_gateway_no_cache_when_no_config(self):
"""Gateway has no cache when cache config is absent.""" """Gateway has no cache_manager when cache config is absent."""
gateway = LLMGateway() gateway = LLMGateway()
assert gateway._cache is None assert gateway._cache_manager is None