183 lines
7.0 KiB
Python
183 lines
7.0 KiB
Python
"""U1 — Gateway KB cache fail-closed 行为测试。
|
||
|
||
验证安全要求 R1:KB settings 读取失败时必须 fail-closed(禁用缓存),
|
||
不得 fail-open(默认启用缓存)。
|
||
|
||
覆盖场景:
|
||
1. settings 正常读取 caching_disabled=False → 缓存启用
|
||
2. settings 正常读取 caching_disabled=True → 缓存禁用
|
||
3. get_settings_store() 抛异常 → fail-closed,缓存禁用
|
||
4. settings 返回 None(KB 不存在)→ fail-closed,缓存禁用
|
||
5. kb_id=None(非 RAG 请求)→ 不查 settings,缓存正常启用
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from types import SimpleNamespace
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
from agentkit.llm.gateway import LLMGateway
|
||
from agentkit.llm.protocol import LLMRequest, LLMResponse, TokenUsage
|
||
|
||
|
||
def _make_response() -> LLMResponse:
    """Build a minimal ``LLMResponse`` for the mock provider to return.

    Returns:
        An ``LLMResponse`` with content ``"ok"``, model ``"test-model"`` and a
        ``TokenUsage`` of 5 prompt tokens and 3 completion tokens.
    """
    return LLMResponse(
        content="ok",
        model="test-model",
        usage=TokenUsage(prompt_tokens=5, completion_tokens=3),
    )
|
||
|
||
|
||
def _make_gateway_with_cache() -> LLMGateway:
    """Build an ``LLMGateway`` whose cache manager is a mock (avoids the litellm dependency).

    Per the original author's note, ``gateway.chat()`` calls
    ``LitellmCacheManager.cache_params_for_hit/no_cache`` as static methods on
    the class, and only ``should_cache`` and ``build_cache_key`` are called
    through the instance — so stubbing those two on the mock is enough.

    Returns:
        A gateway whose ``_cache_manager`` is a ``MagicMock`` with a working
        ``should_cache`` and a ``build_cache_key`` that returns
        ``"mock_cache_key"``.
    """

    def _should_cache(kb_disabled, uid):
        # False when KB caching is disabled or when there is no user id.
        return not kb_disabled and uid is not None

    gateway = LLMGateway()  # do not enable the real cache (litellm may not be installed)
    mock_manager = MagicMock()
    mock_manager.should_cache = _should_cache
    mock_manager.build_cache_key = MagicMock(return_value="mock_cache_key")
    mock_manager.record_cache_result = MagicMock()
    gateway._cache_manager = mock_manager
    return gateway
|
||
|
||
|
||
def _register_mock_provider(gateway: LLMGateway) -> MagicMock:
    """Register a mock provider under the name ``"test"`` and return it.

    The provider's async ``chat`` resolves to ``_make_response()``. Tests read
    the recorded ``chat`` call afterwards to see which cache parameters the
    gateway attached to the request.

    Per the original author's note, a model id such as ``"test/test-model"``
    is split on ``"/"`` by ``_resolve_model`` into provider ``"test"`` and
    model ``"test-model"``, which is why the provider is registered as
    ``"test"``.

    Args:
        gateway: Gateway on which the provider is registered.

    Returns:
        The ``MagicMock`` provider that was registered.
    """
    provider = MagicMock()
    provider.chat = AsyncMock(return_value=_make_response())
    gateway.register_provider("test", provider)
    return provider
|
||
|
||
|
||
# Model id passed to gateway.chat() by every test; its "test" prefix matches
# the provider name registered in _register_mock_provider.
_MODEL = "test/test-model"
|
||
|
||
|
||
def _get_cache_arg(provider: MagicMock) -> dict:
|
||
"""从 provider.chat 调用中提取 cache 参数。"""
|
||
call_args = provider.chat.call_args
|
||
# provider.chat(req) — req 是第一个位置参数
|
||
req: LLMRequest = call_args.args[0]
|
||
return req._cache or {}
|
||
|
||
|
||
class TestGatewayCacheFailClosed:
    """U1 — fail closed when the KB settings cannot be read.

    Every test patches ``agentkit.rag_platform.settings.get_settings_store``,
    calls ``gateway.chat()`` against a mock provider, and then inspects the
    ``_cache`` dict of the request the provider received.

    NOTE(review): the coroutine tests carry no asyncio marker — presumably the
    project runs pytest-asyncio in auto mode; confirm.
    """

    async def test_settings_caching_false_enables_cache(self):
        """Settings read with caching_disabled=False → cache_key reaches the provider (cache enabled)."""
        gateway = _make_gateway_with_cache()
        provider = _register_mock_provider(gateway)

        mock_settings = SimpleNamespace(caching_disabled=False)
        with patch("agentkit.rag_platform.settings.get_settings_store") as mock_get_store:
            mock_store = MagicMock()
            mock_store.get_settings = AsyncMock(return_value=mock_settings)
            mock_get_store.return_value = mock_store

            await gateway.chat(
                messages=[{"role": "user", "content": "hi"}],
                model=_MODEL,
                user_id="u1",
                kb_id="kb1",
                kb_acl_hash="acl1",
            )

        cache_arg = _get_cache_arg(provider)
        assert "cache_key" in cache_arg, f"Expected cache_key (cache enabled), got {cache_arg}"

    async def test_settings_caching_true_disables_cache(self):
        """Settings read with caching_disabled=True → no-cache reaches the provider (cache disabled)."""
        gateway = _make_gateway_with_cache()
        provider = _register_mock_provider(gateway)

        mock_settings = SimpleNamespace(caching_disabled=True)
        with patch("agentkit.rag_platform.settings.get_settings_store") as mock_get_store:
            mock_store = MagicMock()
            mock_store.get_settings = AsyncMock(return_value=mock_settings)
            mock_get_store.return_value = mock_store

            await gateway.chat(
                messages=[{"role": "user", "content": "hi"}],
                model=_MODEL,
                user_id="u1",
                kb_id="kb1",
            )

        cache_arg = _get_cache_arg(provider)
        assert cache_arg.get("no-cache") is True, f"Expected no-cache=True, got {cache_arg}"

    async def test_settings_exception_fail_closed(self):
        """``store.get_settings()`` raises → fail-closed (no-cache).

        The patched ``get_settings_store()`` itself returns normally; it is the
        store's ``get_settings`` coroutine that raises ``RuntimeError``.
        """
        gateway = _make_gateway_with_cache()
        provider = _register_mock_provider(gateway)

        with patch("agentkit.rag_platform.settings.get_settings_store") as mock_get_store:
            mock_store = MagicMock()
            mock_store.get_settings = AsyncMock(side_effect=RuntimeError("DB down"))
            mock_get_store.return_value = mock_store

            await gateway.chat(
                messages=[{"role": "user", "content": "hi"}],
                model=_MODEL,
                user_id="u1",
                kb_id="kb1",
            )

        cache_arg = _get_cache_arg(provider)
        assert cache_arg.get("no-cache") is True, (
            f"fail-closed: 读取异常应禁用缓存,但 got {cache_arg}"
        )

    async def test_settings_none_fail_closed(self):
        """Settings lookup returns None (KB does not exist) → fail-closed (no-cache)."""
        gateway = _make_gateway_with_cache()
        provider = _register_mock_provider(gateway)

        with patch("agentkit.rag_platform.settings.get_settings_store") as mock_get_store:
            mock_store = MagicMock()
            mock_store.get_settings = AsyncMock(return_value=None)
            mock_get_store.return_value = mock_store

            await gateway.chat(
                messages=[{"role": "user", "content": "hi"}],
                model=_MODEL,
                user_id="u1",
                kb_id="kb_nonexistent",
            )

        cache_arg = _get_cache_arg(provider)
        assert cache_arg.get("no-cache") is True, (
            f"Expected no-cache for None settings, got {cache_arg}"
        )

    async def test_no_kb_id_skips_settings_lookup(self):
        """kb_id omitted (non-RAG request) → settings are not queried and the cache stays enabled."""
        gateway = _make_gateway_with_cache()
        provider = _register_mock_provider(gateway)

        with patch("agentkit.rag_platform.settings.get_settings_store") as mock_get_store:
            mock_store = MagicMock()
            mock_store.get_settings = AsyncMock()
            mock_get_store.return_value = mock_store

            await gateway.chat(
                messages=[{"role": "user", "content": "hi"}],
                model=_MODEL,
                user_id="u1",
                # kb_id intentionally not passed
            )

            # Settings must not be queried.
            mock_store.get_settings.assert_not_called()

        # Cache should be enabled (cache_key passed through).
        cache_arg = _get_cache_arg(provider)
        assert "cache_key" in cache_arg, f"Expected cache_key (no kb_id), got {cache_arg}"
|