diff --git a/src/agentkit/llm/cache.py b/src/agentkit/llm/cache.py index cc7dda6..108d209 100644 --- a/src/agentkit/llm/cache.py +++ b/src/agentkit/llm/cache.py @@ -14,11 +14,14 @@ import logging import time from collections import OrderedDict from dataclasses import dataclass, field -from typing import Any, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable from agentkit.llm.protocol import LLMResponse, TokenUsage, ToolCall from agentkit.utils.vector_math import compute_cosine_similarity +if TYPE_CHECKING: + from agentkit.llm.config import CacheConfig + logger = logging.getLogger(__name__) @@ -61,8 +64,7 @@ def _serialize_response(response: LLMResponse) -> dict: "completion_tokens": response.usage.completion_tokens, }, "tool_calls": [ - {"id": tc.id, "name": tc.name, "arguments": tc.arguments} - for tc in response.tool_calls + {"id": tc.id, "name": tc.name, "arguments": tc.arguments} for tc in response.tool_calls ], "latency_ms": response.latency_ms, } @@ -236,7 +238,9 @@ class InMemoryLLMCache: self._cache.move_to_end(key) existing = self._cache[key] # Preserve existing embedding if new one is None - effective_embedding = query_embedding if query_embedding is not None else existing.query_embedding + effective_embedding = ( + query_embedding if query_embedding is not None else existing.query_embedding + ) else: effective_embedding = query_embedding or [] @@ -276,9 +280,7 @@ class InMemoryLLMCache: return count # Simple prefix matching for pattern - keys_to_remove = [ - k for k in self._cache if k.startswith(pattern.replace("*", "")) - ] + keys_to_remove = [k for k in self._cache if k.startswith(pattern.replace("*", ""))] for key in keys_to_remove: del self._cache[key] self._embeddings.pop(key, None) @@ -338,9 +340,7 @@ class RedisLLMCache: if self._redis is None: import redis.asyncio as aioredis - self._redis = aioredis.from_url( - self._redis_url, decode_responses=True - ) + self._redis = aioredis.from_url(self._redis_url, decode_responses=True) return self._redis async def aclose(self) -> None: @@ -393,9 +393,7 @@ class RedisLLMCache: if data is not None: entry = _deserialize_entry(json.loads(data)) self._hits += 1 - return CacheResult( - hit=True, response=entry.response, match_type="exact" - ) + return CacheResult(hit=True, response=entry.response, match_type="exact") self._misses += 1 return CacheResult(hit=False) except Exception as e: @@ -424,6 +422,7 @@ class RedisLLMCache: if len(cache_keys_list) > max_scan: # Take a random sample to avoid always scanning the same subset import random + cache_keys_list = random.sample(cache_keys_list, max_scan) # Batch fetch embeddings @@ -460,9 +459,7 @@ class RedisLLMCache: if data is not None: entry = _deserialize_entry(json.loads(data)) self._hits += 1 - return CacheResult( - hit=True, response=entry.response, match_type="semantic" - ) + return CacheResult(hit=True, response=entry.response, match_type="semantic") # Data key expired but embedding still exists — mark for cleanup try: await redis.srem(self.INDEX_KEY, best_key) @@ -615,9 +612,7 @@ def create_llm_cache( similarity_threshold=similarity_threshold, ) except ImportError: - logger.warning( - "redis package not available, falling back to in-memory cache" - ) + logger.warning("redis package not available, falling back to in-memory cache") return InMemoryLLMCache( max_entries=max_entries, exact_ttl=exact_ttl, @@ -630,3 +625,200 @@ def create_llm_cache( semantic_ttl=semantic_ttl, similarity_threshold=similarity_threshold, ) + + +# --------------------------------------------------------------------------- +# U17 — LiteLLM 缓存管理器 +# --------------------------------------------------------------------------- + + +@dataclass +class LitellmCacheConfig: + """U17 — LiteLLM 缓存配置(从 CacheConfig 转换)。 + + 与旧 ``CacheConfig`` 的区别: + - ``similarity_threshold`` 固定默认 0.87(plan 规定) + - ``per_user_namespace`` 强制开启(安全要求 a) + - ``backend`` 新增 ``redis_semantic`` 选项(需要 redisvl) + """ + + enabled: bool = False + backend: str = "auto" # "auto" | "redis_semantic" | "redis" | "memory" + redis_url: str = "redis://localhost:6379" + similarity_threshold: float = 0.87 # U17 默认 0.87(plan 规定) + ttl: int = 86400 + embedding_model: str = "text-embedding-ada-002" + per_user_namespace: bool = True # 安全要求 (a) + + @classmethod + def from_cache_config(cls, c: "CacheConfig") -> "LitellmCacheConfig": + """从现有 CacheConfig 转换。 + + - ``similarity_threshold`` 固定 0.87(U17 plan 规定,忽略旧 0.92) + - ``per_user_namespace`` 强制 True(安全要求 a) + - ``embedding_model`` 回退到 "text-embedding-ada-002"(LiteLLM 默认) + """ + return cls( + enabled=c.enabled, + backend=c.backend if c.backend in ("auto", "redis", "memory") else "auto", + redis_url=c.redis_url, + similarity_threshold=0.87, # U17 固定默认,忽略旧 0.92 + ttl=c.semantic_ttl, + embedding_model=c.embedding_model or "text-embedding-ada-002", + per_user_namespace=True, # 强制开启 + ) + + +class LitellmCacheManager: + """U17 — LiteLLM 全局缓存管理器。 + + 职责: + 1. 创建并设置 ``litellm.cache`` 全局实例 + 2. 构建带 user/ACL scope 的 cache key(安全要求 a, b) + 3. 提供 per-call cache 参数(cache_key 或 no-cache) + 4. 检测 LiteLLM 响应的缓存命中标志(用于 usage tracking) + 5. 统计缓存命中率 + + 后端选择优先级(backend="auto" 时): + RedisSemanticCache(需 redisvl)→ RedisCache(精确)→ InMemoryCache + + 安全约束: + - (a) cache key 包含 user_id(per-user namespace) + - (b) cache key 包含 kb_acl_hash(ACL-scope 隔离) + - (c) KB 设置 caching_disabled=True 时禁用缓存 + - (e) User A 的查询不会命中 User B 的缓存 + """ + + def __init__(self, config: LitellmCacheConfig): + self._config = config + self._cache_instance: Any = None # litellm.caching.Cache 实例 + self._hits = 0 + self._misses = 0 + + def enable(self) -> None: + """创建 LiteLLM Cache 实例并赋值给 ``litellm.cache``。""" + import litellm + + self._cache_instance = self._create_cache_instance() + litellm.cache = self._cache_instance + + def disable(self) -> None: + """禁用缓存 — 设置 ``litellm.cache = None``。""" + import litellm + + litellm.cache = None + self._cache_instance = None + + def _create_cache_instance(self) -> Any: + """根据 backend 配置创建 LiteLLM Cache 实例。 + + auto 模式按优先级尝试:RedisSemanticCache → RedisCache → InMemoryCache。 + redisvl 缺失时自动回退(安全要求 d — 不添加为必需依赖)。 + """ + backend = self._config.backend + if backend in ("auto", "redis_semantic"): + # 尝试 RedisSemanticCache(需要 redisvl — lazy import,缺失时 fallback) + try: + from litellm.caching import RedisSemanticCache + + return RedisSemanticCache( + redis_url=self._config.redis_url, + similarity_threshold=self._config.similarity_threshold, + embedding_model=self._config.embedding_model, + ) + except ImportError: + logger.warning( + "RedisSemanticCache 需要 redisvl 包(未安装)," + "回退到 RedisCache(精确匹配,无语义搜索)。" + "安装 redisvl 以启用语义缓存:pip install redisvl" + ) + if backend == "redis_semantic": + raise # 显式要求语义缓存但 redisvl 缺失 — 报错 + except Exception as e: + logger.warning(f"RedisSemanticCache 初始化失败: {e},回退到 RedisCache") + + if backend in ("auto", "redis", "redis_semantic"): + try: + from litellm.caching import RedisCache + + return RedisCache(redis_url=self._config.redis_url) + except Exception as e: + logger.warning(f"RedisCache 初始化失败: {e},回退到 InMemoryCache") + + from litellm.caching import InMemoryCache + + return InMemoryCache() + + def build_cache_key( + self, + model: str, + messages: list[dict[str, str]], + temperature: float, + tools: list[dict] | None = None, + tool_choice: str = "auto", + max_tokens: int = 2000, + user_id: str | None = None, + kb_acl_hash: str | None = None, + ) -> str: + """构建带 user/ACL scope 的 cache key(安全要求 a, b, e)。 + + 委托给 ``cache_key.generate_cache_key``,额外注入 user_id + kb_acl_hash + 作为命名空间隔离,确保 User A 的查询不会命中 User B 的缓存。 + """ + from agentkit.llm.cache_key import generate_cache_key + + return generate_cache_key( + model=model, + messages=messages, + temperature=temperature, + tools=tools, + tool_choice=tool_choice, + max_tokens=max_tokens, + user_id=user_id, + kb_acl_hash=kb_acl_hash, + ) + + def should_cache( + self, + kb_caching_disabled: bool = False, + user_id: str | None = None, + ) -> bool: + """判断当前请求是否应该缓存(安全要求 c)。 + + - KB 设置 caching_disabled=True → 不缓存 + - 其余情况缓存(user_id 为 None 时仍可缓存,但 key 不含 user scope) + """ + _ = user_id # 预留:未来支持 per-user 缓存禁用 + if kb_caching_disabled: + return False + return True + + @staticmethod + def cache_params_for_hit(cache_key: str) -> dict[str, str]: + """返回 litellm acompletion 的 cache 参数(用于期望命中的调用)。""" + return {"cache_key": cache_key} + + @staticmethod + def cache_params_for_no_cache() -> dict[str, bool]: + """返回 litellm acompletion 的 cache 参数(禁用缓存)。""" + return {"no-cache": True} + + def detect_cache_hit(self, response: Any) -> bool: + """检测 LiteLLM 响应是否为缓存命中。 + + LiteLLM 在缓存命中时设置 ``response._hidden_params["cache_key"]``。 + """ + hidden = getattr(response, "_hidden_params", None) + if isinstance(hidden, dict): + if "cache_key" in hidden or hidden.get("cache_hit"): + self._hits += 1 + return True + self._misses += 1 + return False + + def stats(self) -> dict[str, int]: + """返回缓存统计。""" + return { + "total_hits": self._hits, + "total_misses": self._misses, + } diff --git a/src/agentkit/llm/cache_key.py b/src/agentkit/llm/cache_key.py index 63eb2e4..362b993 100644 --- a/src/agentkit/llm/cache_key.py +++ b/src/agentkit/llm/cache_key.py @@ -12,6 +12,8 @@ def generate_cache_key( tools: list[dict[str, Any]] | None = None, tool_choice: str = "auto", max_tokens: int = 2000, + user_id: str | None = None, + kb_acl_hash: str | None = None, ) -> str: """Generate a deterministic SHA-256 cache key from LLM request parameters. @@ -19,6 +21,10 @@ def generate_cache_key( model, system_prompt (extracted from messages), messages content, temperature, tools, tool_choice, and max_tokens. + U17 安全扩展:``user_id`` 和 ``kb_acl_hash`` 非 None 时加入 hash 组件, + 实现 per-user namespace 和 ACL-scope 隔离(安全要求 a, b)。为 None 时 + 行为与旧版完全一致(向后兼容)。 + Args: model: Model identifier (e.g. "openai/gpt-4o"). messages: Chat messages list (may include system prompt as first message). @@ -26,6 +32,8 @@ def generate_cache_key( tools: Optional list of tool definitions. tool_choice: Tool selection mode ("auto", "none", etc.). max_tokens: Maximum response tokens. + user_id: U17 — 用户 ID,用于 per-user cache namespace 隔离。 + kb_acl_hash: U17 — KB ACL-scope hash,用于 ACL 隔离缓存键。 Returns: 64-character hex SHA-256 hash string. @@ -40,6 +48,11 @@ def generate_cache_key( _hash_str(tool_choice), _hash_str(str(max_tokens)), ] + # U17 — per-user namespace + ACL scope hash(安全要求 a, b, e) + if user_id is not None: + components.append(_hash_str(f"user:{user_id}")) + if kb_acl_hash is not None: + components.append(_hash_str(f"acl:{kb_acl_hash}")) combined = "".join(components) return hashlib.sha256(combined.encode()).hexdigest() @@ -61,6 +74,4 @@ def _hash_json(obj: Any) -> str: """SHA-256 hash of a JSON-serializable object.""" if obj is None: return hashlib.sha256(b"null").hexdigest() - return hashlib.sha256( - json.dumps(obj, sort_keys=True, ensure_ascii=False).encode() - ).hexdigest() + return hashlib.sha256(json.dumps(obj, sort_keys=True, ensure_ascii=False).encode()).hexdigest() diff --git a/src/agentkit/llm/gateway.py b/src/agentkit/llm/gateway.py index 48797ba..48ea170 100644 --- a/src/agentkit/llm/gateway.py +++ b/src/agentkit/llm/gateway.py @@ -51,47 +51,19 @@ class LLMGateway: self._usage_tracker = UsageTracker(store=usage_store) if usage_store else UsageTracker() self._config = config or LLMConfig() - # Cache (opt-in, disabled by default) - self._cache: Any = None # LLMCache | None - self._embedder: Any = None # Embedder | None + # Cache (U17 — LiteLLM 缓存管理器,opt-in,默认禁用) + self._cache_manager: Any = None # LitellmCacheManager | None if self._config.cache and self._config.cache.enabled: - from agentkit.llm.cache import create_llm_cache + from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager - self._cache = create_llm_cache( - backend=self._config.cache.backend, - redis_url=self._config.cache.redis_url, - max_entries=self._config.cache.max_entries, - exact_ttl=self._config.cache.exact_ttl, - semantic_ttl=self._config.cache.semantic_ttl, - similarity_threshold=self._config.cache.similarity_threshold, - ) - self._embedder = self._create_embedder(self._config.cache) + litellm_config = LitellmCacheConfig.from_cache_config(self._config.cache) + self._cache_manager = LitellmCacheManager(litellm_config) + self._cache_manager.enable() logger.info( - f"LLM cache enabled (backend={self._config.cache.backend}, " - f"embedder={self._config.cache.embedding_provider}/{self._config.cache.embedding_model})" + f"LLM cache enabled (LiteLLM, backend={self._config.cache.backend}, " + f"similarity_threshold={litellm_config.similarity_threshold})" ) - def _create_embedder(self, cache_config) -> Any: - """Create embedder for semantic cache based on config.""" - try: - from agentkit.memory.embedder import OpenAIEmbedder - - if cache_config.embedding_provider in ("xinference", "local"): - return OpenAIEmbedder( - api_key=cache_config.embedding_api_key or "not-needed", - model=cache_config.embedding_model, - base_url=cache_config.embedding_base_url or "http://localhost:9997/v1", - ) - # Default: OpenAI - return OpenAIEmbedder( - api_key=cache_config.embedding_api_key, - model=cache_config.embedding_model, - base_url=cache_config.embedding_base_url, - ) - except Exception as e: - logger.warning(f"Failed to create embedder for semantic cache: {e}") - return None - def register_provider(self, name: str, provider: LLMProvider) -> None: """注册 Provider""" self._providers[name] = provider @@ -114,6 +86,8 @@ class LLMGateway: user_id: str | None = None, department_ids: list[str] | None = None, db_path: Path | str | None = None, + kb_id: str | None = None, + kb_acl_hash: str | None = None, **kwargs, ) -> LLMResponse: """发送 chat 请求,自动解析别名和 Fallback""" @@ -151,67 +125,40 @@ class LLMGateway: start = time.monotonic() - # ── Cache check ── - cache_key = None - query_embedding = None - if self._cache is not None: - from agentkit.llm.cache_key import generate_cache_key + # ── Cache check (U17 — LiteLLM cache via cache_key in request) ── + # LiteLLM 在 litellm.acompletion 内部处理缓存读写,gateway 只需: + # 1. 构建 per-user + ACL-scoped cache_key(安全要求 a, b) + # 2. 将 cache 参数注入 kwargs 透传到 provider + # 3. 检测响应的 cache_hit 标志,用于 usage tracking(cost=0) + if self._cache_manager is not None: + from agentkit.llm.cache import LitellmCacheManager - cache_key = generate_cache_key( - model=resolved_model, - messages=messages, - temperature=kwargs.get("temperature", 0.7), - tools=tools, - tool_choice=tool_choice, - max_tokens=kwargs.get("max_tokens", 2000), - ) - result = await self._cache.get(cache_key) - if result.hit: - latency_ms = (time.monotonic() - start) * 1000 - await self._record_usage( - agent_name=agent_name, - model=result.response.model, - usage=result.response.usage, - cost=0.0, - latency_ms=latency_ms, - user_id=user_id, - department_ids=department_ids, - ) - if _span is not None: - _span.set_attribute("gen_ai.cache.hit", True) - _span.set_attribute("gen_ai.cache.match_type", result.match_type) - return result.response - - # Semantic match (only for temperature == 0) - temperature = kwargs.get("temperature", 0.7) - if temperature == 0 and self._embedder is not None: + # 解析 KB caching_disabled(安全要求 c) + kb_caching_disabled = False + if kb_id is not None: try: - # Embed last N messages for context-aware semantic matching - # (not just last user message — avoids cross-context false hits) - recent_messages = messages[-3:] if len(messages) > 3 else messages - embed_text = " | ".join( - m.get("content", "") for m in recent_messages if m.get("content") - ) - if embed_text: - query_embedding = await self._embedder.embed(embed_text) - result = await self._cache.semantic_search(query_embedding) - if result.hit: - latency_ms = (time.monotonic() - start) * 1000 - await self._record_usage( - agent_name=agent_name, - model=result.response.model, - usage=result.response.usage, - cost=0.0, - latency_ms=latency_ms, - user_id=user_id, - department_ids=department_ids, - ) - if _span is not None: - _span.set_attribute("gen_ai.cache.hit", True) - _span.set_attribute("gen_ai.cache.match_type", "semantic") - return result.response + from agentkit.rag_platform.settings import get_settings_store + + settings = await get_settings_store().get_settings(kb_id) + if settings is not None: + kb_caching_disabled = settings.caching_disabled except Exception as e: - logger.warning(f"Semantic cache search failed: {e}") + logger.warning(f"Failed to read KB cache settings for kb_id={kb_id}: {e}") + + if self._cache_manager.should_cache(kb_caching_disabled, user_id): + cache_key = self._cache_manager.build_cache_key( + model=resolved_model, + messages=messages, + temperature=kwargs.get("temperature", 0.7), + tools=tools, + tool_choice=tool_choice, + max_tokens=kwargs.get("max_tokens", 2000), + user_id=user_id, + kb_acl_hash=kb_acl_hash, + ) + kwargs["cache"] = LitellmCacheManager.cache_params_for_hit(cache_key) + else: + kwargs["cache"] = LitellmCacheManager.cache_params_for_no_cache() # ── Normal provider call ── models_to_try = self._get_models_to_try(resolved_model) @@ -275,15 +222,15 @@ class LLMGateway: latency_ms = (time.monotonic() - start) * 1000 - # ── Cache write ── - if self._cache is not None and cache_key is not None: - try: - await self._cache.put(cache_key, response, query_embedding) - except Exception as e: - logger.warning(f"Cache write failed: {e}") + # U17 — 检测 LiteLLM 缓存命中(用于 usage tracking cost=0) + is_cache_hit = getattr(response, "cache_hit", False) + if is_cache_hit and self._cache_manager is not None: + self._cache_manager._hits += 1 + elif self._cache_manager is not None: + self._cache_manager._misses += 1 - # 计算成本 - cost = self._calculate_cost(response.model, response.usage) + # 计算成本(缓存命中时 cost=0) + cost = 0.0 if is_cache_hit else self._calculate_cost(response.model, response.usage) # 记录使用量 await self._record_usage( @@ -302,8 +249,8 @@ class LLMGateway: _span.set_attribute("gen_ai.usage.output_tokens", response.usage.completion_tokens) _span.set_attribute("gen_ai.response.model", response.model) _span.set_attribute("gen_ai.duration.ms", int(latency_ms)) - if self._cache is not None: - _span.set_attribute("gen_ai.cache.hit", False) + if self._cache_manager is not None: + _span.set_attribute("gen_ai.cache.hit", is_cache_hit) llm_token_histogram().record( response.usage.total_tokens, {"gen_ai.request.model": resolved_model}, @@ -607,18 +554,10 @@ class LLMGateway: ) # 2. Token + cost limits (daily AND monthly) - await self._check_quota_period( - quota_service, db, dept_id, "daily", "token_limit" - ) - await self._check_quota_period( - quota_service, db, dept_id, "daily", "cost_limit" - ) - await self._check_quota_period( - quota_service, db, dept_id, "monthly", "token_limit" - ) - await self._check_quota_period( - quota_service, db, dept_id, "monthly", "cost_limit" - ) + await self._check_quota_period(quota_service, db, dept_id, "daily", "token_limit") + await self._check_quota_period(quota_service, db, dept_id, "daily", "cost_limit") + await self._check_quota_period(quota_service, db, dept_id, "monthly", "token_limit") + await self._check_quota_period(quota_service, db, dept_id, "monthly", "cost_limit") async def _check_quota_period( self, @@ -639,9 +578,7 @@ class LLMGateway: else: current = await self._get_current_cost_for_quota(dept_id, period) - allowed, _reason = await quota_service.check_quota( - db, dept_id, quota_type, period, current - ) + allowed, _reason = await quota_service.check_quota(db, dept_id, quota_type, period, current) if not allowed: quota = await quota_service.get_quota(db, dept_id, quota_type, period) limit = quota["limit_value"] if quota else None diff --git a/src/agentkit/llm/protocol.py b/src/agentkit/llm/protocol.py index b367573..c6f5d54 100644 --- a/src/agentkit/llm/protocol.py +++ b/src/agentkit/llm/protocol.py @@ -47,6 +47,7 @@ class LLMRequest: temperature: float = 0.7, max_tokens: int = 2000, timeout: float | None = None, + cache: dict[str, Any] | None = None, **kwargs: Any, ): self.messages = messages @@ -57,6 +58,8 @@ class LLMRequest: self.max_tokens = max_tokens self.timeout = timeout self._extra = kwargs + # U17 — LiteLLM cache 参数(cache_key 或 no-cache),透传到 litellm.acompletion + self._cache: dict[str, Any] | None = cache @dataclass @@ -81,6 +84,7 @@ class LLMResponse: usage: TokenUsage tool_calls: list[ToolCall] = field(default_factory=list) latency_ms: float = 0.0 + cache_hit: bool = False # U17 — 缓存命中标记,用于 usage tracking(cost=0) @property def has_tool_calls(self) -> bool: diff --git a/src/agentkit/llm/providers/litellm_provider.py b/src/agentkit/llm/providers/litellm_provider.py index 15336c2..56adce4 100644 --- a/src/agentkit/llm/providers/litellm_provider.py +++ b/src/agentkit/llm/providers/litellm_provider.py @@ -183,6 +183,10 @@ class LitellmProvider(LLMProvider): kwargs["tool_choice"] = request.tool_choice if request.timeout is not None: kwargs["timeout"] = request.timeout + # U17 — 透传 LiteLLM cache 参数(cache_key 或 no-cache)到 litellm.acompletion + cache_params = getattr(request, "_cache", None) + if cache_params is not None: + kwargs["cache"] = cache_params # 合并构造时传入的默认 kwargs(如 max_connections 等provider特定参数) kwargs.update(self._default_kwargs) return kwargs @@ -214,12 +218,19 @@ class LitellmProvider(LLMProvider): model_name = getattr(response, "model", None) or request_model + # U17 — 检测 LiteLLM 缓存命中(_hidden_params 含 cache_key 或 cache_hit) + cache_hit = False + hidden = getattr(response, "_hidden_params", None) + if isinstance(hidden, dict) and ("cache_key" in hidden or hidden.get("cache_hit")): + cache_hit = True + return LLMResponse( content=content, model=model_name, usage=usage, tool_calls=tool_calls, latency_ms=latency_ms, + cache_hit=cache_hit, ) def _parse_stream_chunk( diff --git a/tests/integration/test_p0_hardening.py b/tests/integration/test_p0_hardening.py index 8952db7..f28aec4 100644 --- a/tests/integration/test_p0_hardening.py +++ b/tests/integration/test_p0_hardening.py @@ -287,13 +287,13 @@ class TestCacheIntegration: cache=CacheConfig(enabled=True, backend="memory"), ) gateway = LLMGateway(config=config) - assert gateway._cache is not None + assert gateway._cache_manager is not None def test_gateway_without_cache_config(self): """LLMGateway works without cache (default).""" config = LLMConfig(providers={}) gateway = LLMGateway(config=config) - assert gateway._cache is None + assert gateway._cache_manager is None def test_gateway_cache_disabled(self): """LLMGateway does not initialize cache when disabled.""" @@ -302,7 +302,7 @@ class TestCacheIntegration: cache=CacheConfig(enabled=False), ) gateway = LLMGateway(config=config) - assert gateway._cache is None + assert gateway._cache_manager is None # ── Graceful degradation tests ───────────────────────────── @@ -390,7 +390,7 @@ class TestFullFlowInMemory: redis_url=config.usage_store.get("redis_url", "redis://localhost:6379"), ) gateway = LLMGateway(config=config.llm_config, usage_store=usage_store) - assert gateway._cache is not None + assert gateway._cache_manager is not None cascade_store = create_cascade_state_store( backend=config.cascade_store.get("backend", "memory"), diff --git a/tests/unit/llm/test_cache.py b/tests/unit/llm/test_cache.py new file mode 100644 index 0000000..bc3f142 --- /dev/null +++ b/tests/unit/llm/test_cache.py @@ -0,0 +1,486 @@ +"""U17 — LiteLLM 语义缓存管理器单元测试。 + +覆盖场景(plan 要求的 4 个 + 安全要求 e + 向后兼容 + fallback): +1. 语义相似查询命中缓存(mock litellm.acompletion 模拟缓存行为) +2. 不同 system prompt 不命中缓存 +3. 缓存命中率可统计 +4. 阈值调优生效(similarity_threshold 传入 RedisSemanticCache) +5. User A 不返回 User B 缓存(安全要求 e) +6. kb_acl_hash 隔离 — 不同 ACL hash 产生不同 key +7. kb_caching_disabled 禁用缓存(安全要求 c) +8. cache_params_for_hit / no_cache — 返回正确 dict +9. detect_cache_hit — _hidden_params 含 cache_key 时返回 True +10. LitellmCacheConfig.from_cache_config — 转换正确,similarity_threshold=0.87 +11. LitellmCacheManager.enable/disable — litellm.cache 正确设置/清除 +12. generate_cache_key 向后兼容 — user_id=None, kb_acl_hash=None 时与旧版相同 +13. RedisSemanticCache fallback — redisvl 缺失时 auto backend 回退 + +测试用 ``unittest.mock.patch`` mock ``litellm.cache`` 全局变量,避免测试间污染。 +每个测试后清理 ``litellm.cache = None``。 +""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from agentkit.llm.cache import LitellmCacheConfig, LitellmCacheManager +from agentkit.llm.cache_key import generate_cache_key +from agentkit.llm.config import CacheConfig + + +# --------------------------------------------------------------------------- +# 测试辅助 +# --------------------------------------------------------------------------- + + +def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]: + return [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": user_content}, + ] + + +def _make_litellm_response( + content: str = "Hello!", + model: str = "gpt-4o-mini", + prompt_tokens: int = 10, + completion_tokens: int = 5, + cache_key: str | None = None, +) -> SimpleNamespace: + """构造 LiteLLM 响应(OpenAI ChatCompletion 格式),可选 cache_key 标记。""" + hidden_params: dict[str, Any] = {} + if cache_key is not None: + hidden_params["cache_key"] = cache_key + message = SimpleNamespace(content=content, tool_calls=None) + return SimpleNamespace( + choices=[SimpleNamespace(message=message)], + usage=SimpleNamespace(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens), + model=model, + _hidden_params=hidden_params, + ) + + +@pytest.fixture(autouse=True) +def _cleanup_litellm_cache(): + """每个测试后清理 litellm.cache 全局变量,避免测试间污染。""" + yield + import litellm + + litellm.cache = None + + +# --------------------------------------------------------------------------- +# 1. 语义相似查询命中缓存 +# --------------------------------------------------------------------------- + + +class TestCacheHit: + """场景 1 — 相同请求第二次命中缓存。""" + + async def test_second_request_is_cache_hit(self): + """相同 cache_key 的第二次请求返回缓存响应(cache_hit=True)。""" + config = LitellmCacheConfig(enabled=True, backend="memory") + manager = LitellmCacheManager(config) + + msgs = _make_messages() + cache_key = manager.build_cache_key("gpt-4o", msgs, 0.0) + cache_params = LitellmCacheManager.cache_params_for_hit(cache_key) + + # 模拟 LiteLLM 缓存行为:第一次 miss,第二次 hit + call_count = 0 + cache_store: dict[str, SimpleNamespace] = {} + + async def fake_acompletion(**kwargs): + nonlocal call_count + call_count += 1 + ck = kwargs.get("cache", {}).get("cache_key") + if ck and ck in cache_store: + # 缓存命中 — 返回带 cache_key 标记的响应 + return _make_litellm_response(content="Cached!", cache_key=ck) + # 缓存未命中 — 调用"真实" API 并存储 + resp = _make_litellm_response(content="Fresh response") + if ck: + cache_store[ck] = resp + return resp + + with patch("litellm.acompletion", side_effect=fake_acompletion): + from agentkit.llm.providers.litellm_provider import LitellmProvider + + provider = LitellmProvider(model_prefix="openai/", api_key="test") + from agentkit.llm.protocol import LLMRequest + + # 第一次请求 — miss + req1 = LLMRequest(messages=msgs, model="gpt-4o", temperature=0.0, cache=cache_params) + resp1 = await provider.chat(req1) + assert resp1.cache_hit is False + assert resp1.content == "Fresh response" + + # 第二次相同请求 — hit + req2 = LLMRequest(messages=msgs, model="gpt-4o", temperature=0.0, cache=cache_params) + resp2 = await provider.chat(req2) + assert resp2.cache_hit is True + assert resp2.content == "Cached!" + + assert call_count == 2 # litellm.acompletion 被调用两次(LiteLLM 内部处理缓存) + + +# --------------------------------------------------------------------------- +# 2. 不同 system prompt 不命中缓存 +# --------------------------------------------------------------------------- + + +class TestSystemPromptIsolation: + """场景 2 — 不同 system prompt 产生不同 cache_key,不误命中。""" + + def test_different_system_prompt_different_key(self): + manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) + msgs1 = [ + {"role": "system", "content": "Be concise"}, + {"role": "user", "content": "Hello"}, + ] + msgs2 = [ + {"role": "system", "content": "Be verbose"}, + {"role": "user", "content": "Hello"}, + ] + key1 = manager.build_cache_key("gpt-4o", msgs1, 0.0) + key2 = manager.build_cache_key("gpt-4o", msgs2, 0.0) + assert key1 != key2 + + +# --------------------------------------------------------------------------- +# 3. 缓存命中率可统计 +# --------------------------------------------------------------------------- + + +class TestCacheStats: + """场景 3 — LitellmCacheManager.stats() 返回正确 hits/misses。""" + + def test_stats_after_hits_and_misses(self): + manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) + + # 2 hits + manager.detect_cache_hit(_make_litellm_response(cache_key="k1")) + manager.detect_cache_hit(_make_litellm_response(cache_key="k2")) + # 1 miss + manager.detect_cache_hit(_make_litellm_response()) # 无 cache_key + + stats = manager.stats() + assert stats["total_hits"] == 2 + assert stats["total_misses"] == 1 + + +# --------------------------------------------------------------------------- +# 4. 阈值调优生效 +# --------------------------------------------------------------------------- + + +class TestSimilarityThreshold: + """场景 4 — similarity_threshold 传入 RedisSemanticCache 构造函数。""" + + def test_threshold_passed_to_redis_semantic_cache(self): + """RedisSemanticCache 构造时接收 similarity_threshold=0.87。""" + config = LitellmCacheConfig( + enabled=True, + backend="redis_semantic", + similarity_threshold=0.83, + redis_url="redis://localhost:6379", + ) + manager = LitellmCacheManager(config) + + mock_cache_instance = MagicMock() + with patch( + "litellm.caching.RedisSemanticCache", return_value=mock_cache_instance + ) as mock_cls: + instance = manager._create_cache_instance() + mock_cls.assert_called_once_with( + redis_url="redis://localhost:6379", + similarity_threshold=0.83, + embedding_model="text-embedding-ada-002", + ) + assert instance is mock_cache_instance + + def test_default_threshold_is_087(self): + """from_cache_config 固定 similarity_threshold=0.87。""" + old_config = CacheConfig(similarity_threshold=0.92) + litellm_config = LitellmCacheConfig.from_cache_config(old_config) + assert litellm_config.similarity_threshold == 0.87 + + +# --------------------------------------------------------------------------- +# 5. User A 不返回 User B 缓存(安全要求 e) +# --------------------------------------------------------------------------- + + +class TestUserIsolation: + """安全要求 e — User A 的查询不返回 User B 的缓存响应。""" + + def test_different_users_different_keys(self): + """不同 user_id 产生不同 cache_key。""" + manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) + msgs = _make_messages("What is my salary?") + key_a = manager.build_cache_key("gpt-4o", msgs, 0.0, user_id="user_a") + key_b = manager.build_cache_key("gpt-4o", msgs, 0.0, user_id="user_b") + assert key_a != key_b + + def test_same_user_same_key(self): + """相同 user_id 产生相同 cache_key。""" + manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) + msgs = _make_messages("What is my salary?") + key1 = manager.build_cache_key("gpt-4o", msgs, 0.0, user_id="user_a") + key2 = manager.build_cache_key("gpt-4o", msgs, 0.0, user_id="user_a") + assert key1 == key2 + + def test_no_user_id_same_as_no_user_id(self): + """user_id=None 时两次调用产生相同 key(向后兼容)。""" + manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) + msgs = _make_messages() + key1 = manager.build_cache_key("gpt-4o", msgs, 0.0, user_id=None) + key2 = manager.build_cache_key("gpt-4o", msgs, 0.0, user_id=None) + assert key1 == key2 + + +# --------------------------------------------------------------------------- +# 6. kb_acl_hash 隔离 +# --------------------------------------------------------------------------- + + +class TestACLIsolation: + """安全要求 b — 不同 ACL scope 产生不同 cache_key。""" + + def test_different_acl_hash_different_keys(self): + manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) + msgs = _make_messages("Summarize the document") + key1 = manager.build_cache_key("gpt-4o", msgs, 0.0, kb_acl_hash="acl_v1") + key2 = manager.build_cache_key("gpt-4o", msgs, 0.0, kb_acl_hash="acl_v2") + assert key1 != key2 + + def test_same_acl_hash_same_key(self): + manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) + msgs = _make_messages() + key1 = manager.build_cache_key("gpt-4o", msgs, 0.0, kb_acl_hash="acl_v1") + key2 = manager.build_cache_key("gpt-4o", msgs, 0.0, kb_acl_hash="acl_v1") + assert key1 == key2 + + +# --------------------------------------------------------------------------- +# 7. kb_caching_disabled 禁用缓存 +# --------------------------------------------------------------------------- + + +class TestKBCachingDisabled: + """安全要求 c — KB 设置 caching_disabled=True 时禁用缓存。""" + + def test_should_cache_returns_false_when_disabled(self): + manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) + assert manager.should_cache(kb_caching_disabled=True) is False + + def test_should_cache_returns_true_when_enabled(self): + manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) + assert manager.should_cache(kb_caching_disabled=False) is True + + def test_should_cache_default_true(self): + manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) + assert manager.should_cache() is True + + +# --------------------------------------------------------------------------- +# 8. cache_params_for_hit / no_cache +# --------------------------------------------------------------------------- + + +class TestCacheParams: + def test_cache_params_for_hit(self): + params = LitellmCacheManager.cache_params_for_hit("my_cache_key") + assert params == {"cache_key": "my_cache_key"} + + def test_cache_params_for_no_cache(self): + params = LitellmCacheManager.cache_params_for_no_cache() + assert params == {"no-cache": True} + + +# --------------------------------------------------------------------------- +# 9. detect_cache_hit +# --------------------------------------------------------------------------- + + +class TestDetectCacheHit: + def test_hit_with_cache_key(self): + manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) + resp = _make_litellm_response(cache_key="some_key") + assert manager.detect_cache_hit(resp) is True + + def test_hit_with_cache_hit_flag(self): + manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) + resp = _make_litellm_response() + resp._hidden_params = {"cache_hit": True} + assert manager.detect_cache_hit(resp) is True + + def test_miss_without_cache_key(self): + manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) + resp = _make_litellm_response() + assert manager.detect_cache_hit(resp) is False + + def test_miss_with_no_hidden_params(self): + manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) + resp = SimpleNamespace(_hidden_params=None) + assert manager.detect_cache_hit(resp) is False + + def test_miss_with_no_hidden_params_attr(self): + manager = LitellmCacheManager(LitellmCacheConfig(backend="memory")) + resp = SimpleNamespace() + assert manager.detect_cache_hit(resp) is False + + +# --------------------------------------------------------------------------- +# 10. LitellmCacheConfig.from_cache_config +# --------------------------------------------------------------------------- + + +class TestFromCacheConfig: + def test_basic_conversion(self): + old = CacheConfig( + enabled=True, + backend="memory", + redis_url="redis://myhost:6379", + semantic_ttl=3600, + embedding_model="bge-m3", + ) + litellm_config = LitellmCacheConfig.from_cache_config(old) + assert litellm_config.enabled is True + assert litellm_config.backend == "memory" + assert litellm_config.redis_url == "redis://myhost:6379" + assert litellm_config.ttl == 3600 + assert litellm_config.similarity_threshold == 0.87 # 固定,忽略旧 0.92 + assert litellm_config.per_user_namespace is True # 强制开启 + + def test_unknown_backend_falls_to_auto(self): + old = CacheConfig(backend="unknown_backend") + litellm_config = LitellmCacheConfig.from_cache_config(old) + assert litellm_config.backend == "auto" + + def test_redis_backend_preserved(self): + old = CacheConfig(backend="redis") + litellm_config = LitellmCacheConfig.from_cache_config(old) + assert litellm_config.backend == "redis" + + def test_empty_embedding_model_falls_to_default(self): + old = CacheConfig(embedding_model="") + litellm_config = LitellmCacheConfig.from_cache_config(old) + assert litellm_config.embedding_model == "text-embedding-ada-002" + + +# --------------------------------------------------------------------------- +# 11. LitellmCacheManager.enable/disable +# --------------------------------------------------------------------------- + + +class TestEnableDisable: + def test_enable_sets_litellm_cache(self): + import litellm + + config = LitellmCacheConfig(backend="memory") + manager = LitellmCacheManager(config) + manager.enable() + assert litellm.cache is not None + assert manager._cache_instance is not None + + def test_disable_clears_litellm_cache(self): + import litellm + + config = LitellmCacheConfig(backend="memory") + manager = LitellmCacheManager(config) + manager.enable() + assert litellm.cache is not None + manager.disable() + assert litellm.cache is None + assert manager._cache_instance is None + + +# --------------------------------------------------------------------------- +# 12. generate_cache_key 向后兼容 +# --------------------------------------------------------------------------- + + +class TestBackwardCompatibility: + """user_id=None, kb_acl_hash=None 时与旧版 generate_cache_key 完全一致。""" + + def test_none_user_id_same_as_not_passing(self): + msgs = _make_messages() + key1 = generate_cache_key("gpt-4o", msgs, 0.0, user_id=None, kb_acl_hash=None) + key2 = generate_cache_key("gpt-4o", msgs, 0.0) + assert key1 == key2 + + def test_backward_compat_deterministic(self): + msgs = _make_messages() + key1 = generate_cache_key("gpt-4o", msgs, 0.0, user_id=None, kb_acl_hash=None) + key2 = generate_cache_key("gpt-4o", msgs, 0.0, user_id=None, kb_acl_hash=None) + assert key1 == key2 + + def test_user_id_changes_key(self): + msgs = _make_messages() + key_none = generate_cache_key("gpt-4o", msgs, 0.0) + key_user = generate_cache_key("gpt-4o", msgs, 0.0, user_id="user_a") + assert key_none != key_user + + def test_acl_hash_changes_key(self): + msgs = _make_messages() + key_none = generate_cache_key("gpt-4o", msgs, 0.0) + key_acl = generate_cache_key("gpt-4o", msgs, 0.0, kb_acl_hash="acl_v1") + assert key_none != key_acl + + +# --------------------------------------------------------------------------- +# 13. RedisSemanticCache fallback +# --------------------------------------------------------------------------- + + +class TestRedisSemanticCacheFallback: + """redisvl 缺失时 auto backend 回退到 RedisCache。""" + + def test_auto_fallback_to_redis_cache_when_redisvl_missing(self): + """auto 模式下 RedisSemanticCache 构造 ImportError → 回退到 RedisCache。""" + import litellm.caching + + config = LitellmCacheConfig(backend="auto", redis_url="redis://localhost:6379") + manager = LitellmCacheManager(config) + + mock_redis_cache = MagicMock(name="RedisCacheInstance") + + # 模拟 RedisSemanticCache 构造时 ImportError(redisvl 未安装) + def raise_import_error(*args, **kwargs): + raise ImportError("No module named 'redisvl'") + + with patch.object(litellm.caching, "RedisSemanticCache", side_effect=raise_import_error): + with patch.object( + litellm.caching, "RedisCache", return_value=mock_redis_cache + ) as mock_rc: + instance = manager._create_cache_instance() + mock_rc.assert_called_once_with(redis_url="redis://localhost:6379") + assert instance is mock_redis_cache + + def test_redis_semantic_backend_raises_when_redisvl_missing(self): + """显式 redis_semantic backend + redisvl 缺失 → raise ImportError。""" + import litellm.caching + + config = LitellmCacheConfig(backend="redis_semantic") + manager = LitellmCacheManager(config) + + def raise_import_error(*args, **kwargs): + raise ImportError("No module named 'redisvl'") + + with patch.object(litellm.caching, "RedisSemanticCache", side_effect=raise_import_error): + with pytest.raises(ImportError): + manager._create_cache_instance() + + def test_memory_backend_uses_in_memory_cache(self): + """memory backend 直接使用 InMemoryCache,不尝试 Redis。""" + import litellm.caching + + config = LitellmCacheConfig(backend="memory") + manager = LitellmCacheManager(config) + instance = manager._create_cache_instance() + assert isinstance(instance, litellm.caching.InMemoryCache) diff --git a/tests/unit/test_gateway_cache.py b/tests/unit/test_gateway_cache.py index 9e56b67..f054aea 100644 --- a/tests/unit/test_gateway_cache.py +++ b/tests/unit/test_gateway_cache.py @@ -1,15 +1,20 @@ -"""Integration tests for LLM Cache integration into LLMGateway (U2).""" +"""Integration tests for LLM Cache integration into LLMGateway (U2/U17). + +U17 更新:gateway 改用 ``LitellmCacheManager``(LiteLLM 内置缓存)。 +旧的 ``InMemoryLLMCache`` 手动缓存逻辑已移除,缓存读写由 LiteLLM 内部处理。 +测试用 ``CacheAwareMockProvider`` 模拟 LiteLLM 的缓存行为(检查 cache_key, +命中时返回 ``cache_hit=True`` 的响应)。 +""" import pytest -from agentkit.llm.cache import InMemoryLLMCache from agentkit.llm.config import CacheConfig, LLMConfig from agentkit.llm.gateway import LLMGateway from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage class MockProvider(LLMProvider): - """Mock LLM provider that tracks call count.""" + """Mock LLM provider that tracks call count (no cache awareness).""" def __init__(self, response_content: str = "Mock response"): self.call_count = 0 @@ -24,6 +29,47 @@ class MockProvider(LLMProvider): ) +class CacheAwareMockProvider(LLMProvider): + """Mock provider that simulates LiteLLM's cache behavior. + + Reads ``request._cache`` for cache_key, maintains an internal cache dict. + On cache hit: returns cached response with ``cache_hit=True`` (call_count NOT incremented). + On cache miss: generates fresh response, caches it, returns with ``cache_hit=False``. + """ + + def __init__(self, response_content: str = "Mock response"): + self.call_count = 0 # 仅统计真实调用(缓存未命中) + self._response_content = response_content + self._cache: dict[str, LLMResponse] = {} + + async def chat(self, request: LLMRequest) -> LLMResponse: + cache_params = getattr(request, "_cache", None) or {} + cache_key = cache_params.get("cache_key") + no_cache = cache_params.get("no-cache", False) + + # 缓存命中 — 返回缓存响应(不增加 call_count) + if cache_key and cache_key in self._cache and not no_cache: + cached = self._cache[cache_key] + return LLMResponse( + content=cached.content, + model=cached.model, + usage=cached.usage, + cache_hit=True, + ) + + # 缓存未命中 — 真实调用 + self.call_count += 1 + response = LLMResponse( + content=self._response_content, + model=request.model, + usage=TokenUsage(prompt_tokens=10, completion_tokens=20), + cache_hit=False, + ) + if cache_key and not no_cache: + self._cache[cache_key] = response + return response + + def _make_messages(user_content: str = "Hello") -> list[dict[str, str]]: return [ {"role": "system", "content": "You are a helpful assistant."}, @@ -52,7 +98,7 @@ class TestCacheEnabled: """First request is a cache miss — provider is called.""" config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory")) gateway = LLMGateway(config=config) - provider = MockProvider() + provider = CacheAwareMockProvider() gateway.register_provider("test", provider) msgs = _make_messages() @@ -60,28 +106,30 @@ class TestCacheEnabled: assert provider.call_count == 1 assert response.content == "Mock response" + assert response.cache_hit is False @pytest.mark.asyncio async def test_second_request_is_hit(self): - """Second identical request is a cache hit — provider NOT called.""" + """Second identical request is a cache hit — provider NOT called again.""" config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory")) gateway = LLMGateway(config=config) - provider = MockProvider() + provider = CacheAwareMockProvider() gateway.register_provider("test", provider) msgs = _make_messages() await gateway.chat(msgs, "test/model", temperature=0.0) response = await gateway.chat(msgs, "test/model", temperature=0.0) - assert provider.call_count == 1 # Not called again + assert provider.call_count == 1 # Not called again (cache hit) assert response.content == "Mock response" + assert response.cache_hit is True @pytest.mark.asyncio async def test_cache_hit_usage_has_zero_cost(self): """Cache hit records usage with cost=0.""" config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory")) gateway = LLMGateway(config=config) - provider = MockProvider() + provider = CacheAwareMockProvider() gateway.register_provider("test", provider) msgs = _make_messages() @@ -98,7 +146,7 @@ class TestCacheEnabled: """Different messages produce cache misses.""" config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory")) gateway = LLMGateway(config=config) - provider = MockProvider() + provider = CacheAwareMockProvider() gateway.register_provider("test", provider) await gateway.chat(_make_messages("Hello"), "test/model", temperature=0.0) @@ -110,13 +158,15 @@ class TestCacheEnabled: class TestCacheConfig: def test_config_from_dict(self): """CacheConfig can be loaded from dict.""" - config = LLMConfig.from_dict({ - "cache": { - "enabled": True, - "backend": "memory", - "exact_ttl": 7200, + config = LLMConfig.from_dict( + { + "cache": { + "enabled": True, + "backend": "memory", + "exact_ttl": 7200, + } } - }) + ) assert config.cache is not None assert config.cache.enabled is True assert config.cache.backend == "memory" @@ -129,34 +179,35 @@ class TestCacheConfig: def test_config_from_dict_embedding(self): """Embedding config is loaded correctly.""" - config = LLMConfig.from_dict({ - "cache": { - "enabled": True, - "embedding": { - "provider": "xinference", - "model": "bge-m3", - "base_url": "http://localhost:9997/v1", - }, + config = LLMConfig.from_dict( + { + "cache": { + "enabled": True, + "embedding": { + "provider": "xinference", + "model": "bge-m3", + "base_url": "http://localhost:9997/v1", + }, + } } - }) + ) assert config.cache.embedding_provider == "xinference" assert config.cache.embedding_model == "bge-m3" assert config.cache.embedding_base_url == "http://localhost:9997/v1" def test_gateway_creates_cache_when_enabled(self): - """Gateway creates cache instance when cache.enabled=True.""" + """Gateway creates cache_manager instance when cache.enabled=True.""" config = LLMConfig(cache=CacheConfig(enabled=True, backend="memory")) gateway = LLMGateway(config=config) - assert gateway._cache is not None - assert isinstance(gateway._cache, InMemoryLLMCache) + assert gateway._cache_manager is not None def test_gateway_no_cache_when_disabled(self): - """Gateway has no cache when cache is disabled.""" + """Gateway has no cache_manager when cache is disabled.""" config = LLMConfig(cache=CacheConfig(enabled=False)) gateway = LLMGateway(config=config) - assert gateway._cache is None + assert gateway._cache_manager is None def test_gateway_no_cache_when_no_config(self): - """Gateway has no cache when cache config is absent.""" + """Gateway has no cache_manager when cache config is absent.""" gateway = LLMGateway() - assert gateway._cache is None + assert gateway._cache_manager is None