fix(security): P0 安全加固 + 多实例部署一致性 (U1-U4 + U5c)
Deploy to Production / deploy (push) Has been cancelled Details

U1: LLM gateway KB 缓存 fail-closed — 异常时默认禁用缓存防止 KB 数据泄漏
U2: MCP 危险工具黑名单过滤 — 6+1 端点覆盖,防止绕过 chat confirmation
U3: SecretsStore Redis 迁移 — 多 worker 共享凭证,内存降级保留开发模式
U4: channels webhook Redis 状态 — ZSET 滑动窗口限流 + nonce dedup + backpressure
U5c: ce-code-review 修复批次:
  - P0: 统一 MCP 黑名单与 publisher.py 一致 (terminal_execute -> terminal, +file_read)
  - P1: ZSET 限流 member 加 uuid 后缀避免同时间戳碰撞
  - P1: SecretsStore redis 参数 Any -> aioredis.Redis | None (AGENTS.md 合规)
  - P1: Redis client 添加 socket_timeout 防止单点故障请求挂死

测试: 171 scoped tests pass, ruff clean
This commit is contained in:
chiguyong 2026-06-26 04:05:33 +08:00
parent c62d435c43
commit 31c65e01b8
10 changed files with 1539 additions and 80 deletions

View File

@ -0,0 +1,255 @@
---
title: "fix: P0 安全与多实例一致性加固"
date: 2026-06-26
type: fix
status: planned
origin: 代码走查报告portal-platform-evolution 合并后审查)
---
# fix: P0 安全与多实例一致性加固
## Summary
修复代码走查发现的 4 项高优先级问题,分两组:(1) 安全 fail-open 默认值 — Gateway KB cache fail-closed + MCP dangerous-tool 黑名单过滤;(2) 多实例部署一致性 — SecretsStore 与 Channels 全局状态迁移到 Redis消除多 worker 下的安全与状态共享缺口。
## Problem Frame
`portal-platform-evolution` 合并引入了多端渠道U10-U12、MCP 协议U13/U16、LiteLLM 缓存U17等特性。代码走查发现两类系统性风险
1. **安全 fail-open**Gateway 在 KB settings 读取异常时默认启用缓存(可能泄漏禁用缓存的 KB 数据MCP 端点暴露所有工具(含 ShellTool 等危险工具),绕过 chat 流程的 confirmation 机制。
2. **多实例不一致**SecretsStore 和 Channels 路由使用模块级内存字典存储凭证、nonce、限流、backpressure 状态,多 worker 部署下失效,与 README 宣称的 K8s 多实例部署目标冲突。
## Requirements
- **R1**Gateway KB cache 在 settings 读取失败时必须 fail-closed禁用缓存不得 fail-open
- **R2**MCP `/tools/call``/tools/list`(含 JSON-RPC 端点)必须过滤危险工具,禁止绕过 chat confirmation 流程
- **R3**SecretsStore 加密凭证必须存储在 Redis多 worker 共享,保留 AES-256-GCM 加密层
- **R4**Channels webhook nonce dedup、rate limit、backpressure、渠道配置必须存储在 Redis多 worker 一致
- **R5**所有改动不破坏现有单进程行为Redis 不可用时降级到内存,需显式 log warning
- **R6**:每个修复单元包含可运行的验证测试
## Key Technical Decisions
### KTD1: H4 危险工具过滤采用黑名单
**决策**:使用 `_MCP_BLOCKED_TOOLS: frozenset[str]` 黑名单,而非在 Tool 基类增加 `requires_confirmation` 元数据字段。
**理由**黑名单改动最小ponytail 原则),仅影响 mcp/server.py 一个文件。元数据方案需修改 Tool ABC + 所有工具子类 + ToolRegistry改动面过大。黑名单天花板已在代码注释标注`ponytail: 黑名单需手动维护,新增危险工具需同步更新`)。
### KTD2: H1/H2 Redis 连接复用 app.state.working_redis_client
**决策**:复用 `app.py` lifespan 已创建的 `app.state.working_redis_client``aioredis.from_url(redis_url, decode_responses=True)`),通过 FastAPI `Depends``request.app.state` 注入到 channels 路由和 SecretsStore。
**理由**:避免新建 Redis 连接池(资源浪费),与 WorkingMemory、TaskManager 等现有子系统共享同一连接。`decode_responses=True` 与 RedisBus 模式一致。
### KTD3: 降级策略 — Redis 不可用时 fail-closed
**决策**Redis 连接失败时webhook 端点返回 503而非降级到内存secrets 读取返回 None而非空字典
**理由**:安全优先。降级到内存会导致多 worker 间状态不一致nonce dedup 失效可能引发重放攻击。fail-closed 符合 AGENTS.md 安全约定。
## Implementation Units
### U1. Gateway KB cache fail-closed
**Goal**: KB settings 读取异常时禁用缓存,而非默认启用。
**Requirements**: R1
**Dependencies**: 无
**Files**:
- `src/agentkit/llm/gateway.py`(修改第 137-146 行)
- `tests/unit/llm/test_gateway_cache_failclosed.py`(新建)
**Approach**: 将 `kb_caching_disabled` 默认值从 `False` 改为 `True`fail-closed仅在 settings 成功读取且 `caching_disabled=False` 时才启用缓存。读取异常或 settings 为 None 时保持 `True`
**Patterns to follow**: 现有 `should_cache` 已有 `per_user_namespace + user_id=None` 防护逻辑,本单元补齐 KB settings 读取层的防护。
**Test scenarios**:
- **Happy path**: settings 正常读取 `caching_disabled=False` → 缓存启用,`should_cache` 返回 True
- **Happy path**: settings 正常读取 `caching_disabled=True` → 缓存禁用,`should_cache` 返回 False
- **Error path**: `get_settings_store().get_settings()` 抛异常 → `kb_caching_disabled=True``should_cache` 返回 Falsefail-closed
- **Edge case**: settings 返回 NoneKB 不存在)→ `kb_caching_disabled=True`
- **Edge case**: `kb_id=None`(非 RAG 请求)→ 不查 settings缓存正常启用
**Verification**: `pytest tests/unit/llm/test_gateway_cache_failclosed.py -v` 全绿mock settings store 抛异常时断言 `should_cache` 返回 False。
---
### U2. MCP dangerous-tool 黑名单过滤
**Goal**: 阻止危险工具通过 MCP 端点执行,绕过 chat confirmation 流程。
**Requirements**: R2
**Dependencies**: 无
**Files**:
- `src/agentkit/mcp/server.py`(修改 `_all_tools`、`_find_tool`、REST `call_tool`、JSON-RPC `tools/call`
- `tests/unit/mcp/test_server_dangerous_tools.py`(新建)
**Approach**: 在 `server.py` 顶部定义 `_MCP_BLOCKED_TOOLS` 黑名单(`shell`、`file_write`、`file_delete`、`terminal_execute`)。`_all_tools()` 和 `_find_tool()` 过滤黑名单工具。`list_tools` 和 JSON-RPC `tools/list` 不返回黑名单工具。`call_tool` 对黑名单工具返回 404。
**Technical design**directional guidance:
```python
# ponytail: 黑名单需手动维护,新增危险工具需同步更新。
# 天花板:若工具名动态生成或别名注册,黑名单可能漏判。
# 升级路径Tool 基类增加 requires_confirmation 元数据字段。
_MCP_BLOCKED_TOOLS: frozenset[str] = frozenset({
"shell", "file_write", "file_delete", "terminal_execute",
})
```
**Patterns to follow**: `shell.py` 已有 `_DANGEROUS_BINARIES` frozenset 模式。
**Test scenarios**:
- **Happy path**: `GET /tools/list` 不返回黑名单工具
- **Happy path**: `POST /tools/call` with `{"name": "safe_tool"}` 正常执行
- **Error path**: `POST /tools/call` with `{"name": "shell"}` → 404
- **Error path**: `POST /tools/call` with `{"name": "file_delete"}` → 404
- **Integration**: JSON-RPC `tools/call` with `{"method": "tools/call", "params": {"name": "shell"}}``isError: True`
- **Edge case**: JSON-RPC `tools/list` 不包含黑名单工具
**Verification**: `pytest tests/unit/mcp/test_server_dangerous_tools.py -v` 全绿;黑名单工具在 list 和 call 端点均不可见。
---
### U3. SecretsStore Redis 迁移
**Goal**: 加密凭证存储迁移到 Redis多 worker 共享,保留 AES-256-GCM 加密层。
**Requirements**: R3, R5
**Dependencies**: 无(可与 U4 并行,共享 Redis 注入基础设施)
**Files**:
- `src/agentkit/channels/secrets.py`(修改 `SecretsStore.__init__`、`set_secret`、`get_secret`、`delete_secret`、`list_keys`
- `src/agentkit/server/routes/channels.py`(修改 `_get_secrets_store` 注入 Redis
- `src/agentkit/server/app.py`lifespan 中将 `working_redis_client` 传入 SecretsStore
- `tests/unit/channels/test_secrets_redis.py`(新建)
**Approach**: `SecretsStore.__init__` 新增 `redis` 参数(`aioredis.Redis | None`)。`_store` 字典替换为 Redis 调用:`set_secret` → `redis.set(f"secrets:{key}", entry.model_dump_json())``get_secret` → `redis.get` + `SecretEntry.model_validate_json``delete_secret` → `redis.delete``list_keys` → `redis.scan_iter`。加密层(`encrypt`/`decrypt`/`_derive_key`不变。Redis 为 None 时降级到内存仅开发模式log warning
**Patterns to follow**: `RedisBus._get_redis()` 懒初始化模式;`SharedWorkspace` 的 `redis_client` 注入模式。
**Test scenarios**:
- **Happy path**: `set_secret("k", "v")``get_secret("k")` 返回 `"v"`Redis mock
- **Happy path**: `delete_secret("k")``get_secret("k")` 返回 None
- **Happy path**: `list_keys(prefix="feishu:")` 返回匹配前缀的 key 列表
- **Edge case**: Redis 为 None降级模式→ 内存字典行为log warning
- **Error path**: Redis 连接失败 → `get_secret` 抛异常或返回 Nonefail-closed
- **Integration**: 加密-存储-读取-解密往返:明文 → encrypt → Redis SET → Redis GET → decrypt → 明文一致
**Verification**: `pytest tests/unit/channels/test_secrets_redis.py -v` 全绿;多 worker 场景下 worker A `set_secret` 后 worker B `get_secret` 能读到(需 Redis faker 或集成测试)。
---
### U4. Channels 全局状态 Redis 迁移
**Goal**: nonce dedup、rate limit、backpressure、渠道配置迁移到 Redis多 worker 一致。
**Requirements**: R4, R5, R6
**Dependencies**: U3共享 Redis 注入基础设施)
**Files**:
- `src/agentkit/server/routes/channels.py`(修改 `_channels`、`_rate_limits`、`_seen_nonces`、`_pending_webhook_tasks` 相关逻辑)
- `tests/unit/channels/test_webhook_redis_state.py`(新建)
**Approach**:
1. **nonce dedup** → Redis `SET NX EX``await redis.set(f"nonce:{nonce}", "1", ex=300, nx=True)` 返回 None 表示重复。
2. **rate limit** → Redis ZSET 滑动窗口:`ZREMRANGEBYSCORE` 清理过期 + `ZADD` + `ZCARD` 计数。
3. **backpressure** → Redis `INCR`/`DECR` 共享计数器:`INCR webhook:concurrent`,超限 `DECR` 并返回 429执行完毕 `DECR`。`EXPIRE 30s` 防 crash 计数不归零。
4. **渠道配置** → Redis Hash`HSET channel:{id}` 存配置 JSON`HGETALL` 读取。`_adapter_cache` 保留进程内缓存(只读快照,配置变更时 invalidate
5. **Redis 注入**webhook handler 通过 `request.app.state.working_redis_client` 获取连接。
**Technical design**directional guidance:
```python
# nonce dedup — 原子操作TTL 自动过期
is_new = await redis.set(f"nonce:{nonce}", "1", ex=int(_NONCE_TTL), nx=True)
if not is_new:
return Response(status_code=200) # 重复事件,飞书要求 3s 内响应
# rate limit — ZSET 滑动窗口
key = f"ratelimit:{client_ip}"
now = time.time()
pipe = redis.pipeline()
pipe.zremrangebyscore(key, 0, now - _RATE_LIMIT_WINDOW)
pipe.zadd(key, {str(now): now})
pipe.zcard(key)
pipe.expire(key, int(_RATE_LIMIT_WINDOW))
_, _, count, _ = await pipe.execute()
if count > _RATE_LIMIT_MAX:
return Response(status_code=429)
# backpressure — 共享计数器
current = await redis.incr("webhook:concurrent")
if current > _WEBHOOK_MAX_CONCURRENT * 2:
await redis.decr("webhook:concurrent")
return Response(status_code=429)
try:
# ... process webhook ...
finally:
await redis.decr("webhook:concurrent")
```
**Patterns to follow**: Redis pipeline 模式(`usage_store.py` 已有);`_adapter_cache` 保留进程内缓存(`_invalidate_adapter_cache` 逻辑不变)。
**Test scenarios**:
- **Happy path**: 首次 nonce → `SET NX` 返回 True请求正常处理
- **Happy path**: 重复 nonce → `SET NX` 返回 None返回 200不重复处理
- **Happy path**: rate limit 窗口内 100 请求通过,第 101 请求返回 429
- **Happy path**: backpressure 并发 < 2x 上限时请求通过
- **Error path**: backpressure 并发 >= 2x 上限时返回 429
- **Edge case**: nonce TTL 过期后相同 nonce 可再次使用mock `time.sleep` 或 Redis TTL
- **Edge case**: rate limit 窗口滚动后计数重置
- **Integration**: 渠道配置写入 Redis 后,新 worker 的 `_adapter_cache` miss 时从 Redis 读取并缓存
**Verification**: `pytest tests/unit/channels/test_webhook_redis_state.py -v` 全绿;`ruff check src/ && ruff format src/` 通过;现有 `tests/unit/channels/test_wecom.py` 等不回归。
---
## Scope Boundaries
### In Scope
- 4 项高优先级问题修复H3/H4/H1/H2
- Redis 连接复用现有 `app.state.working_redis_client`
- 单元测试覆盖每个修复单元
### Out of Scope
- 中优先级问题M1-M8feishu/slack decode 异常、dingtalk 时间戳窗口、mcp/client 性能、app.py shutdown 超时等
- 低优先级问题L1-L3X-Forwarded-For 信任、生产 guard 触发条件、Any 类型
- PostgreSQL 持久化迁移Redis 已满足多实例共享需求)
- MCP confirmation 协议实现(黑名单方案足够,元数据方案为后续升级路径)
- Tool 基类 `requires_confirmation` 元数据重构
### Deferred to Follow-Up Work
- M1 feishu/slack `body.decode("utf-8")` 异常捕获 → 独立 PR
- M4 mcp/server.py 异常消息脱敏 + REST/JSON-RPC 代码去重 → 独立 PR
- M6 app.py shutdown `asyncio.gather` 超时 → 独立 PR
- H4 升级路径Tool 基类 `requires_confirmation` 元数据 → 下个迭代
## Risks & Dependencies
- **风险 1**H1/H2 Redis 迁移后Redis 不可用会导致 webhook 端点完全失效fail-closed。缓解lifespan 启动时 ping Redis不可用时 log error 但不阻止启动(单进程降级模式仍可用)。
- **风险 2**U4 backpressure `INCR/DECR` 非原子crash 可能导致计数不归零。缓解:`EXPIRE 30s` 安全网 + 定期清理。
- **依赖 1**U4 依赖 U3 完成的 Redis 注入基础设施。
- **依赖 2**:所有单元测试使用 `fakeredis``unittest.mock.AsyncMock` mock Redis不依赖真实 Redis 实例。
## System-Wide Impact
- **运维**:多 worker 部署gunicorn -w 4现在可正确共享渠道状态无需 sticky session
- **安全**MCP 端点不再暴露危险工具KB 缓存 fail-closed 防止数据泄漏
- **性能**Redis 网络往返替代内存字典,单请求延迟增加 ~1ms可接受backpressure 跨 worker 生效
- **兼容性**:单进程开发模式(无 Redis降级到内存存储行为不变
## Open Questions
无 — 所有技术决策已在 KTD1-KTD3 中明确。

View File

@ -7,7 +7,8 @@ KTD8 关键决策:
server 启动钩子实现本模块提供 ``assert_production_master_key`` 辅助 server 启动钩子实现本模块提供 ``assert_production_master_key`` 辅助
- Master key 轮换采用双密钥窗口策略key_id 字段标记 - Master key 轮换采用双密钥窗口策略key_id 字段标记
当前为内存存储实现PG 迁移预留接口``_store: dict`` 未来替换为 ORM session U3Redis 后端支持多 worker 共享``redis`` 参数为 None 时降级到内存字典
仅开发模式生产环境必须注入 Redis 客户端
""" """
from __future__ import annotations from __future__ import annotations
@ -16,6 +17,7 @@ import base64
import logging import logging
import os import os
import redis.asyncio as aioredis
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -74,21 +76,41 @@ class SecretsStore:
使用 AES-256-GCM 加密HKDF with per-row salt 派生 per-row 密钥 使用 AES-256-GCM 加密HKDF with per-row salt 派生 per-row 密钥
Master key 从环境变量 ``AGENTKIT_MASTER_KEY`` 读取开发 fallback Master key 从环境变量 ``AGENTKIT_MASTER_KEY`` 读取开发 fallback
生产环境应通过 KMS 提供 master key 并显式传入 生产环境应通过 KMS 提供 master key 并显式传入
U3``redis`` 参数注入 Redis 客户端后凭证存储在 Redis worker 共享
``redis=None`` 时降级到内存字典仅开发模式
""" """
def __init__(self, master_key: bytes | None = None, *, key_source: str = "env"): # Redis key 前缀 — 遵循 codebase 约定agentkit:<subsystem>:
_REDIS_PREFIX = "agentkit:secrets:"
def __init__(
self,
master_key: bytes | None = None,
*,
key_source: str = "env",
redis: aioredis.Redis | None = None,
):
"""初始化 secrets store。 """初始化 secrets store。
Args: Args:
master_key: 显式传入的 master key若为 None 则从环境变量加载 master_key: 显式传入的 master key若为 None 则从环境变量加载
key_source: master key 来源标记传给 ``assert_production_master_key`` key_source: master key 来源标记传给 ``assert_production_master_key``
生产环境应设为 "kms" 生产环境应设为 "kms"
redis: Redis 客户端``redis.asyncio.Redis``注入后凭证存储在 Redis
worker 共享None 时降级到内存字典仅开发模式
""" """
self._master_key = master_key or self._load_master_key() self._master_key = master_key or self._load_master_key()
# 生产 guard若 key_source="env" 且在生产模式下,构造即失败。 # 生产 guard若 key_source="env" 且在生产模式下,构造即失败。
assert_production_master_key(self._master_key, source=key_source) assert_production_master_key(self._master_key, source=key_source)
# 内存存储PG 迁移预留接口:替换为 ORM session 即可) self._redis = redis
# 内存降级存储redis=None 时使用)
self._store: dict[str, SecretEntry] = {} self._store: dict[str, SecretEntry] = {}
if redis is None:
logger.warning(
"SecretsStore 运行于内存降级模式redis=None"
"多 worker 部署下状态不共享,仅限开发使用。"
)
def _load_master_key(self) -> bytes: def _load_master_key(self) -> bytes:
"""从环境变量加载 master key开发 fallback""" """从环境变量加载 master key开发 fallback"""
@ -148,18 +170,30 @@ class SecretsStore:
"""存储加密凭证。覆盖同名 key。""" """存储加密凭证。覆盖同名 key。"""
entry = self.encrypt(value) entry = self.encrypt(value)
entry.key = key entry.key = key
self._store[key] = entry if self._redis is not None:
await self._redis.set(self._REDIS_PREFIX + key, entry.model_dump_json())
else:
self._store[key] = entry
return entry return entry
async def get_secret(self, key: str) -> str | None: async def get_secret(self, key: str) -> str | None:
"""读取并解密凭证。key 不存在返回 None。""" """读取并解密凭证。key 不存在返回 None。"""
entry = self._store.get(key) if self._redis is not None:
if entry is None: raw = await self._redis.get(self._REDIS_PREFIX + key)
return None if raw is None:
return None
entry = SecretEntry.model_validate_json(raw)
else:
entry = self._store.get(key)
if entry is None:
return None
return self.decrypt(entry) return self.decrypt(entry)
async def delete_secret(self, key: str) -> bool: async def delete_secret(self, key: str) -> bool:
"""删除凭证。返回是否删除成功。""" """删除凭证。返回是否删除成功。"""
if self._redis is not None:
deleted = await self._redis.delete(self._REDIS_PREFIX + key)
return deleted > 0
if key in self._store: if key in self._store:
del self._store[key] del self._store[key]
return True return True
@ -167,6 +201,12 @@ class SecretsStore:
async def list_keys(self, prefix: str | None = None) -> list[str]: async def list_keys(self, prefix: str | None = None) -> list[str]:
"""列出凭证键。可选前缀过滤。""" """列出凭证键。可选前缀过滤。"""
if self._redis is not None:
pattern = self._REDIS_PREFIX + (prefix or "") + "*"
keys: list[str] = []
async for k in self._redis.scan_iter(match=pattern):
keys.append(k[len(self._REDIS_PREFIX) :])
return keys
if prefix: if prefix:
return [k for k in self._store if k.startswith(prefix)] return [k for k in self._store if k.startswith(prefix)]
return list(self._store.keys()) return list(self._store.keys())

View File

@ -134,7 +134,11 @@ class LLMGateway:
from agentkit.llm.cache import LitellmCacheManager from agentkit.llm.cache import LitellmCacheManager
# 解析 KB caching_disabled安全要求 c # 解析 KB caching_disabled安全要求 c
kb_caching_disabled = False # 非 RAG 请求kb_id=None→ 默认启用缓存(无 KB 数据需保护)。
# RAG 请求kb_id!=None→ fail-closed默认禁用缓存仅在 settings
# 明确返回 caching_disabled=False 时启用。防止 DB 异常时 fail-open
# 导致禁用缓存的 KB 数据泄漏到缓存。
kb_caching_disabled = kb_id is not None
if kb_id is not None: if kb_id is not None:
try: try:
from agentkit.rag_platform.settings import get_settings_store from agentkit.rag_platform.settings import get_settings_store
@ -142,8 +146,10 @@ class LLMGateway:
settings = await get_settings_store().get_settings(kb_id) settings = await get_settings_store().get_settings(kb_id)
if settings is not None: if settings is not None:
kb_caching_disabled = settings.caching_disabled kb_caching_disabled = settings.caching_disabled
# settings 为 NoneKB 不存在)→ 保持 Truefail-closed
except Exception as e: except Exception as e:
logger.warning(f"Failed to read KB cache settings for kb_id={kb_id}: {e}") logger.warning(f"Failed to read KB cache settings for kb_id={kb_id}: {e}")
# 读取异常 → 保持 Truefail-closed禁用缓存
if self._cache_manager.should_cache(kb_caching_disabled, user_id): if self._cache_manager.should_cache(kb_caching_disabled, user_id):
cache_key = self._cache_manager.build_cache_key( cache_key = self._cache_manager.build_cache_key(

View File

@ -25,6 +25,30 @@ logger = logging.getLogger(__name__)
# 认证由主 app 的 AuthMiddleware 处理,这里只做权限校验。 # 认证由主 app 的 AuthMiddleware 处理,这里只做权限校验。
_mcp_member_auth = require_permission(Permission.CHAT) _mcp_member_auth = require_permission(Permission.CHAT)
# ponytail: 危险工具黑名单 — 这些工具依赖 chat confirmation 流程WebSocket
# 通过 MCP 暴露会绕过用户确认机制。黑名单需手动维护,新增危险工具需同步更新。
# 天花板:若工具名动态生成或别名注册,黑名单可能漏判。
# 升级路径Tool 基类增加 requires_confirmation 元数据字段,按属性自动过滤。
# U5c: 与 publisher.py 的 _DANGEROUS_TOOL_NAMES 保持一致(单一真相源)。
_MCP_BLOCKED_TOOLS: frozenset[str] = frozenset(
{
"terminal", # 终端执行(与 ShellTool 协同)— 危险命令需 confirmation
"shell", # ShellTool — 危险命令需 confirmation
"file_write", # 文件写入
"file_read", # 文件读取(可能泄露敏感配置)
"file_delete", # 文件删除
}
)
def _serialize_tool(tool: Tool) -> dict[str, Any]:
"""将 Tool 序列化为 MCP 协议响应字典。"""
return {
"name": tool.name,
"description": tool.description,
"inputSchema": tool.input_schema or {},
}
def create_mcp_router( def create_mcp_router(
tool_registry: Any = None, tool_registry: Any = None,
@ -52,7 +76,7 @@ def create_mcp_router(
router = APIRouter(tags=["mcp"]) router = APIRouter(tags=["mcp"])
def _all_tools() -> list[Tool]: def _all_tools() -> list[Tool]:
"""合并 ToolRegistry 与已发布工具""" """合并 ToolRegistry 与已发布工具,过滤危险工具"""
tools: list[Tool] = [] tools: list[Tool] = []
if tool_registry is not None: if tool_registry is not None:
tools.extend(tool_registry.list_tools()) tools.extend(tool_registry.list_tools())
@ -61,10 +85,14 @@ def create_mcp_router(
tools.extend(published_tools_getter()) tools.extend(published_tools_getter())
except Exception: except Exception:
logger.exception("published_tools_getter failed") logger.exception("published_tools_getter failed")
return tools # 过滤危险工具 — 防止绕过 chat confirmation 流程
return [t for t in tools if t.name not in _MCP_BLOCKED_TOOLS]
def _find_tool(name: str) -> Tool | None: def _find_tool(name: str) -> Tool | None:
"""按名查找工具(先 registry 后已发布)。""" """按名查找工具(先 registry 后已发布),危险工具返回 None。"""
if name in _MCP_BLOCKED_TOOLS:
logger.warning(f"MCP tool '{name}' blocked — requires chat confirmation flow")
return None
if tool_registry is not None: if tool_registry is not None:
try: try:
return tool_registry.get(name) return tool_registry.get(name)
@ -80,16 +108,7 @@ def create_mcp_router(
async def list_tools(_user: dict = Depends(_mcp_member_auth)) -> dict[str, Any]: async def list_tools(_user: dict = Depends(_mcp_member_auth)) -> dict[str, Any]:
"""列出所有可用的 MCP 工具。""" """列出所有可用的 MCP 工具。"""
tools = _all_tools() tools = _all_tools()
return { return {"tools": [_serialize_tool(t) for t in tools]}
"tools": [
{
"name": t.name,
"description": t.description,
"inputSchema": t.input_schema or {},
}
for t in tools
]
}
@router.post("/tools/call") @router.post("/tools/call")
async def call_tool( async def call_tool(
@ -152,16 +171,7 @@ def create_mcp_router(
} }
elif method == "tools/list": elif method == "tools/list":
tools = _all_tools() tools = _all_tools()
result = { result = {"tools": [_serialize_tool(t) for t in tools]}
"tools": [
{
"name": t.name,
"description": t.description,
"inputSchema": t.input_schema or {},
}
for t in tools
]
}
elif method == "tools/call": elif method == "tools/call":
tool_name = params.get("name", "") tool_name = params.get("name", "")
arguments = params.get("arguments", {}) arguments = params.get("arguments", {})
@ -171,6 +181,13 @@ def create_mcp_router(
"isError": True, "isError": True,
"content": [{"type": "text", "text": "Tool not found"}], "content": [{"type": "text", "text": "Tool not found"}],
} }
elif tool_name in _MCP_BLOCKED_TOOLS:
# 显式黑名单检查 — 与 legacy JSON-RPC 一致,提供清晰审计反馈
logger.warning(f"MCP tool '{tool_name}' blocked — requires chat confirmation flow")
result = {
"isError": True,
"content": [{"type": "text", "text": f"Tool '{tool_name}' is blocked via MCP"}],
}
else: else:
tool = _find_tool(tool_name) tool = _find_tool(tool_name)
if tool is None: if tool is None:
@ -240,16 +257,9 @@ class MCPServer:
if self._tool_registry is None: if self._tool_registry is None:
return {"tools": []} return {"tools": []}
tools = self._tool_registry.list_tools() tools = self._tool_registry.list_tools()
return { # 过滤危险工具(与 create_mcp_router 一致)
"tools": [ tools = [t for t in tools if t.name not in _MCP_BLOCKED_TOOLS]
{ return {"tools": [_serialize_tool(t) for t in tools]}
"name": t.name,
"description": t.description,
"inputSchema": t.input_schema or {},
}
for t in tools
]
}
@app.post("/tools/call") @app.post("/tools/call")
async def call_tool(request: dict): async def call_tool(request: dict):
@ -259,6 +269,9 @@ class MCPServer:
if not tool_name or self._tool_registry is None: if not tool_name or self._tool_registry is None:
return {"error": "Tool not specified or registry not configured"} return {"error": "Tool not specified or registry not configured"}
if tool_name in _MCP_BLOCKED_TOOLS:
return {"error": f"Tool '{tool_name}' is blocked via MCP"}
try: try:
tool = self._tool_registry.get(tool_name) tool = self._tool_registry.get(tool_name)
result = await tool.safe_execute(**arguments) result = await tool.safe_execute(**arguments)
@ -302,12 +315,7 @@ class MCPServer:
tools = self._tool_registry.list_tools() tools = self._tool_registry.list_tools()
result = { result = {
"tools": [ "tools": [
{ _serialize_tool(t) for t in tools if t.name not in _MCP_BLOCKED_TOOLS
"name": t.name,
"description": t.description,
"inputSchema": t.input_schema or {},
}
for t in tools
] ]
} }
elif method == "tools/call": elif method == "tools/call":
@ -319,6 +327,13 @@ class MCPServer:
"isError": True, "isError": True,
"content": [{"type": "text", "text": "Tool not found"}], "content": [{"type": "text", "text": "Tool not found"}],
} }
elif tool_name in _MCP_BLOCKED_TOOLS:
result = {
"isError": True,
"content": [
{"type": "text", "text": f"Tool '{tool_name}' is blocked via MCP"}
],
}
else: else:
try: try:
tool = self._tool_registry.get(tool_name) tool = self._tool_registry.get(tool_name)

View File

@ -975,9 +975,19 @@ def create_app(
redis_url = server_config.memory["working"].get( redis_url = server_config.memory["working"].get(
"redis_url", "redis://localhost:6379" "redis_url", "redis://localhost:6379"
) )
redis_client = aioredis.from_url(redis_url, decode_responses=True) # U5c: socket_timeout 防止单点 Redis 故障时网关请求挂死。
redis_client = aioredis.from_url(
redis_url,
decode_responses=True,
socket_timeout=5.0,
socket_connect_timeout=5.0,
)
working = WorkingMemory(redis=redis_client) working = WorkingMemory(redis=redis_client)
app.state.working_redis_client = redis_client app.state.working_redis_client = redis_client
# U3注入 Redis 到 channels 模块SecretsStore 多 worker 共享)
from agentkit.server.routes.channels import _set_redis_client
_set_redis_client(redis_client)
if server_config.memory.get("semantic", {}).get("enabled"): if server_config.memory.get("semantic", {}).get("enabled"):
sem_conf = server_config.memory["semantic"] sem_conf = server_config.memory["semantic"]

View File

@ -23,6 +23,7 @@ import re
import time import time
from collections import OrderedDict from collections import OrderedDict
from typing import Any from typing import Any
from uuid import uuid4
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import Response from fastapi.responses import Response
@ -52,17 +53,29 @@ _CHANNEL_ID_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
# 凭证字段名校验字母数字下划线1-64 字符 # 凭证字段名校验字母数字下划线1-64 字符
_SECRET_NAME_RE = re.compile(r"^[a-zA-Z0-9_]{1,64}$") _SECRET_NAME_RE = re.compile(r"^[a-zA-Z0-9_]{1,64}$")
# ponytail: 模块级单例 store。当前为内存实现PG 迁移后改为请求级 session # ponytail: 模块级单例 store。U3 后支持 Redis 后端(多 worker 共享)
# 天花板:多进程部署下状态不共享;升级路径:注入 app.state.secrets_store # _redis_client 由 app.py lifespan 通过 _set_redis_client() 注入。
# 由 lifespan 绑定到 PG session factory # redis=None 时 SecretsStore 降级到内存字典(仅开发模式)
_secrets_store: SecretsStore | None = None _secrets_store: SecretsStore | None = None
_redis_client: Any = None
def _set_redis_client(redis: Any) -> None:
"""注入 Redis 客户端(由 app.py lifespan 调用)。
设置后下次 _get_secrets_store() 会用该 redis 客户端构造 SecretsStore
"""
global _redis_client
_redis_client = redis
# 重置 store 以便下次 _get_secrets_store 用新 redis 重建
_reset_secrets_store()
def _get_secrets_store() -> SecretsStore: def _get_secrets_store() -> SecretsStore:
"""获取全局 secrets store 单例(懒加载)。""" """获取全局 secrets store 单例(懒加载,注入 redis)。"""
global _secrets_store global _secrets_store
if _secrets_store is None: if _secrets_store is None:
_secrets_store = SecretsStore() _secrets_store = SecretsStore(redis=_redis_client)
return _secrets_store return _secrets_store
@ -164,6 +177,71 @@ def _reset_webhook_state() -> None:
_pending_webhook_tasks.clear() _pending_webhook_tasks.clear()
# ---------------------------------------------------------------------------
# U4: Redis 后端 — nonce dedup / rate limit / backpressure
#
# 多 worker 部署下内存状态不共享。Redis 后端确保 nonce dedup、限流、
# backpressure 跨 worker 一致。redis=None 时降级到上方内存实现。
# ponytail: 天花板 — _channels 渠道配置 dict 仍为进程内存储,多 worker 下
# 配置变更不可见。升级路径_channels 迁移到 Redis Hash + read-through cache。
# ---------------------------------------------------------------------------
# Redis key 前缀 — 遵循 codebase 约定agentkit:<subsystem>:
_REDIS_NS = "agentkit:channels:webhook:"
_NONCE_KEY = _REDIS_NS + "nonce:{nonce}"
_RATELIMIT_KEY = _REDIS_NS + "ratelimit:{ip}"
_BACKPRESSURE_KEY = _REDIS_NS + "concurrent"
_BACKPRESSURE_TTL = 30 # 秒 — crash 安全网,防止计数不归零
def _get_redis_from_request(request: Request) -> Any:
"""从 request.app.state 提取 working_redis_client可能为 None"""
return getattr(request.app.state, "working_redis_client", None)
async def _check_rate_limit_redis(redis: Any, client_ip: str) -> bool:
"""Redis ZSET 滑动窗口限流。返回 True 放行False 超限。"""
key = _RATELIMIT_KEY.format(ip=client_ip)
now = time.time()
# U5c: member 加 uuid 后缀避免同时间戳碰撞3 评审员一致标记)。
pipe = redis.pipeline()
pipe.zremrangebyscore(key, 0, now - _RATE_LIMIT_WINDOW)
pipe.zadd(key, {f"{now}:{uuid4().hex}": now})
pipe.zcard(key)
pipe.expire(key, int(_RATE_LIMIT_WINDOW))
_, _, count, _ = await pipe.execute()
return count <= _RATE_LIMIT_MAX
async def _check_nonce_dedup_redis(redis: Any, nonce: str) -> bool:
"""Redis SET NX EX nonce 去重。返回 True 新 nonceFalse 重复。"""
key = _NONCE_KEY.format(nonce=nonce)
# SET key 1 EX ttl NX — 原子操作,返回 True 表示新增成功
is_new = await redis.set(key, "1", ex=int(_NONCE_TTL), nx=True)
return bool(is_new)
async def _acquire_backpressure_slot(redis: Any) -> bool:
"""Redis INCR 共享并发计数。返回 True 获得槽位False 超限。
调用方在任务完成后必须调用 _release_backpressure_slot(redis)
pipeline 合并 INCR + EXPIRE 为单次往返
"""
pipe = redis.pipeline()
pipe.incr(_BACKPRESSURE_KEY)
pipe.expire(_BACKPRESSURE_KEY, _BACKPRESSURE_TTL)
current, _ = await pipe.execute()
if current > _WEBHOOK_MAX_CONCURRENT * 2:
await redis.decr(_BACKPRESSURE_KEY)
return False
return True
async def _release_backpressure_slot(redis: Any) -> None:
"""释放并发槽位DECR 计数器)。"""
await redis.decr(_BACKPRESSURE_KEY)
def _validate_channel_id(channel_id: str) -> str: def _validate_channel_id(channel_id: str) -> str:
"""校验渠道 ID非法时抛 400。""" """校验渠道 ID非法时抛 400。"""
if not _CHANNEL_ID_RE.match(channel_id): if not _CHANNEL_ID_RE.match(channel_id):
@ -413,9 +491,7 @@ async def _build_adapter(channel_id: str) -> MessageAdapter:
store.get_secret(f"{channel_id}:token"), store.get_secret(f"{channel_id}:token"),
) )
if not all([app_key, app_secret, robot_code]): if not all([app_key, app_secret, robot_code]):
raise HTTPException( raise HTTPException(status_code=500, detail=f"渠道 '{channel_id}' 缺少 dingtalk 凭证")
status_code=500, detail=f"渠道 '{channel_id}' 缺少 dingtalk 凭证"
)
adapter = DingTalkMessageAdapter( adapter = DingTalkMessageAdapter(
app_key=app_key, app_key=app_key,
app_secret=app_secret, app_secret=app_secret,
@ -434,9 +510,7 @@ async def _build_adapter(channel_id: str) -> MessageAdapter:
store.get_secret(f"{channel_id}:agent_id"), store.get_secret(f"{channel_id}:agent_id"),
) )
if not all([corp_id, corp_secret, token, encoding_aes_key, agent_id_raw]): if not all([corp_id, corp_secret, token, encoding_aes_key, agent_id_raw]):
raise HTTPException( raise HTTPException(status_code=500, detail=f"渠道 '{channel_id}' 缺少 wecom 凭证")
status_code=500, detail=f"渠道 '{channel_id}' 缺少 wecom 凭证"
)
try: try:
agent_id = int(agent_id_raw) agent_id = int(agent_id_raw)
except (TypeError, ValueError) as exc: except (TypeError, ValueError) as exc:
@ -461,9 +535,7 @@ async def _build_adapter(channel_id: str) -> MessageAdapter:
store.get_secret(f"{channel_id}:verification_token"), store.get_secret(f"{channel_id}:verification_token"),
) )
if not bot_token or not signing_secret: if not bot_token or not signing_secret:
raise HTTPException( raise HTTPException(status_code=500, detail=f"渠道 '{channel_id}' 缺少 slack 凭证")
status_code=500, detail=f"渠道 '{channel_id}' 缺少 slack 凭证"
)
adapter = SlackMessageAdapter( adapter = SlackMessageAdapter(
bot_token=bot_token, bot_token=bot_token,
signing_secret=signing_secret, signing_secret=signing_secret,
@ -495,9 +567,7 @@ async def close_all_adapters() -> None:
_adapter_cache.clear() _adapter_cache.clear()
async def _process_inbound_message( async def _process_inbound_message(app_state: Any, adapter: MessageAdapter, message: Any) -> None:
app_state: Any, adapter: MessageAdapter, message: Any
) -> None:
"""后台处理入站消息 — 调用 chat 链路并通过适配器回复。 """后台处理入站消息 — 调用 chat 链路并通过适配器回复。
整个流程 try/except 包裹任何异常仅记录日志不向上抛出 整个流程 try/except 包裹任何异常仅记录日志不向上抛出
@ -520,9 +590,7 @@ async def _process_inbound_message(
request_preprocessor = getattr(app_state, "request_preprocessor", None) request_preprocessor = getattr(app_state, "request_preprocessor", None)
llm_gateway = getattr(app_state, "llm_gateway", None) llm_gateway = getattr(app_state, "llm_gateway", None)
if request_preprocessor is None or llm_gateway is None: if request_preprocessor is None or llm_gateway is None:
logger.warning( logger.warning("app.state 缺少 request_preprocessor 或 llm_gateway — 跳过消息处理")
"app.state 缺少 request_preprocessor 或 llm_gateway — 跳过消息处理"
)
return return
# 路由预处理 — IM 场景使用默认 agent无需技能注册表 # 路由预处理 — IM 场景使用默认 agent无需技能注册表
@ -587,12 +655,22 @@ async def channel_webhook(channel_id: str, request: Request) -> Any:
5. URL verification 飞书/Slack 返回 challenge企微返回 XML 5. URL verification 飞书/Slack 返回 challenge企微返回 XML
6. 解析消息 后台异步处理 立即返回 200 6. 解析消息 后台异步处理 立即返回 200
U4nonce dedup限流backpressure 优先使用 Redis worker 共享
redis=None 时降级到进程内内存实现
企微通过 query 参数传递 ``msg_signature``/``timestamp``/``nonce`` 企微通过 query 参数传递 ``msg_signature``/``timestamp``/``nonce``
合并到 headers dict 供适配器读取 合并到 headers dict 供适配器读取
""" """
redis = _get_redis_from_request(request)
client_ip = _get_client_ip(request) client_ip = _get_client_ip(request)
if not _check_rate_limit(client_ip):
raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试") # 1. 限流 — Redis ZSET 滑动窗口(多 worker 共享),降级到内存
if redis is not None:
if not await _check_rate_limit_redis(redis, client_ip):
raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试")
else:
if not _check_rate_limit(client_ip):
raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试")
body = await request.body() body = await request.body()
@ -607,10 +685,15 @@ async def channel_webhook(channel_id: str, request: Request) -> Any:
if not await adapter.verify_signature(headers_dict, body): if not await adapter.verify_signature(headers_dict, body):
raise HTTPException(status_code=401, detail="签名校验失败") raise HTTPException(status_code=401, detail="签名校验失败")
# Nonce dedup可选 — 若头不存在则跳过去重;仅飞书携带该头) # 4. Nonce dedup — Redis SET NX EX多 worker 共享),降级到内存
nonce = request.headers.get("x-lark-request-nonce") nonce = request.headers.get("x-lark-request-nonce")
if nonce and not _check_nonce_dedup(nonce): if nonce:
return {"code": 0, "msg": "duplicate"} if redis is not None:
if not await _check_nonce_dedup_redis(redis, nonce):
return {"code": 0, "msg": "duplicate"}
else:
if not _check_nonce_dedup(nonce):
return {"code": 0, "msg": "duplicate"}
try: try:
message = await adapter.receive_message(headers_dict, body) message = await adapter.receive_message(headers_dict, body)
@ -624,13 +707,27 @@ async def channel_webhook(channel_id: str, request: Request) -> Any:
logger.warning("receive_message 解析失败 channel=%s: %s", channel_id, exc) logger.warning("receive_message 解析失败 channel=%s: %s", channel_id, exc)
return {"code": 0, "msg": "invalid_payload"} return {"code": 0, "msg": "invalid_payload"}
# 异步处理 — 不阻塞 webhook 响应(平台要求快速返回 200 # 6. Backpressure + 异步分发 — Redis 共享计数器(多 worker降级到内存 set
# 持有 task 引用防止 GC 回收正在运行的后台任务 if redis is not None:
# 有界化:超过 2x 并发上限时拒绝新任务(防突发流量下 set 无界增长) if not await _acquire_backpressure_slot(redis):
if len(_pending_webhook_tasks) >= _WEBHOOK_MAX_CONCURRENT * 2: logger.warning("webhook Redis backpressure 超限,拒绝新任务")
logger.warning("webhook 后台任务积压 %d,拒绝新任务", len(_pending_webhook_tasks)) raise HTTPException(status_code=429, detail="服务器繁忙,请稍后重试")
raise HTTPException(status_code=429, detail="服务器繁忙,请稍后重试")
task = asyncio.create_task(_process_inbound_message(request.app.state, adapter, message)) async def _run_with_release() -> None:
"""执行消息处理并在完成后释放 Redis 并发槽位。"""
try:
await _process_inbound_message(request.app.state, adapter, message)
finally:
await _release_backpressure_slot(redis)
task = asyncio.create_task(_run_with_release())
else:
# 降级模式:进程内 set 长度检查 + GC 引用持有
if len(_pending_webhook_tasks) >= _WEBHOOK_MAX_CONCURRENT * 2:
logger.warning("webhook 后台任务积压 %d,拒绝新任务", len(_pending_webhook_tasks))
raise HTTPException(status_code=429, detail="服务器繁忙,请稍后重试")
task = asyncio.create_task(_process_inbound_message(request.app.state, adapter, message))
_pending_webhook_tasks.add(task) _pending_webhook_tasks.add(task)
task.add_done_callback(_pending_webhook_tasks.discard) task.add_done_callback(_pending_webhook_tasks.discard)

View File

@ -0,0 +1,245 @@
"""U3 — SecretsStore Redis 后端单元测试。
覆盖场景
- Redis 后端 CRUDset/get/delete/list_keys
- Redis=None 降级到内存字典
- 加密-存储-读取-解密往返Redis 后端
- Redis 连接异常 fail-closedget_secret 返回 None
- 多实例共享语义两个 store 共用同一 Redis mock
不依赖真实 Redis 实例 使用 AsyncMock 模拟 redis.asyncio.Redis
"""
from __future__ import annotations
import base64
from unittest.mock import AsyncMock, MagicMock
import pytest
from agentkit.channels.secrets import KEY_SIZE, SecretsStore
# ── 辅助函数 ──────────────────────────────────────────────
def _make_master_key() -> bytes:
"""确定性测试 master key32 字节)。"""
return b"\x01" * KEY_SIZE
class _FakeRedis:
"""极简 Redis mock — 仅模拟 SecretsStore 用到的方法。
内部用 dict 存储支持 set/get/delete/scan_iter
scan_iter 返回 async generator匹配真实 Redis 接口
"""
def __init__(self):
self._data: dict[str, str] = {}
async def set(self, key: str, value: str, **kwargs) -> bool:
self._data[key] = value
return True
async def get(self, key: str) -> str | None:
return self._data.get(key)
async def delete(self, key: str) -> int:
if key in self._data:
del self._data[key]
return 1
return 0
async def scan_iter(self, match: str = "*"):
import fnmatch
for k in self._data:
if fnmatch.fnmatch(k, match):
yield k
def _make_failing_redis() -> AsyncMock:
"""构造一个所有方法都抛 ConnectionError 的 Redis mock。"""
redis = AsyncMock()
redis.get = AsyncMock(side_effect=ConnectionError("redis down"))
redis.set = AsyncMock(side_effect=ConnectionError("redis down"))
redis.delete = AsyncMock(side_effect=ConnectionError("redis down"))
redis.scan_iter = MagicMock(side_effect=ConnectionError("redis down"))
return redis
@pytest.fixture(autouse=True)
def _clean_env(monkeypatch):
"""确保测试不在生产模式下运行。"""
monkeypatch.delenv("AGENTKIT_ENV", raising=False)
monkeypatch.delenv("AGENTKIT_MASTER_KEY", raising=False)
# ── Redis 后端 CRUD ──────────────────────────────────────
class TestRedisBackendCrud:
"""Redis 后端的 set/get/delete/list_keys 操作。"""
async def test_set_and_get_secret_with_redis(self):
"""Redis 后端:写入后读取返回明文。"""
redis = _FakeRedis()
store = SecretsStore(master_key=_make_master_key(), redis=redis)
await store.set_secret("feishu:app_id", "cli_xxx")
assert await store.get_secret("feishu:app_id") == "cli_xxx"
async def test_get_nonexistent_returns_none_with_redis(self):
"""Redis 后端:读取不存在的 key 返回 None。"""
redis = _FakeRedis()
store = SecretsStore(master_key=_make_master_key(), redis=redis)
assert await store.get_secret("missing") is None
async def test_set_overwrites_existing_with_redis(self):
"""Redis 后端:同名 key 写入覆盖旧值。"""
redis = _FakeRedis()
store = SecretsStore(master_key=_make_master_key(), redis=redis)
await store.set_secret("k", "old")
await store.set_secret("k", "new")
assert await store.get_secret("k") == "new"
async def test_delete_secret_with_redis(self):
"""Redis 后端:删除凭证后不可读取。"""
redis = _FakeRedis()
store = SecretsStore(master_key=_make_master_key(), redis=redis)
await store.set_secret("k", "v")
assert await store.delete_secret("k") is True
assert await store.get_secret("k") is None
async def test_delete_nonexistent_returns_false_with_redis(self):
"""Redis 后端:删除不存在的 key 返回 False。"""
redis = _FakeRedis()
store = SecretsStore(master_key=_make_master_key(), redis=redis)
assert await store.delete_secret("missing") is False
async def test_list_keys_with_redis(self):
"""Redis 后端:列出全部凭证键。"""
redis = _FakeRedis()
store = SecretsStore(master_key=_make_master_key(), redis=redis)
await store.set_secret("feishu:a", "1")
await store.set_secret("feishu:b", "2")
await store.set_secret("dingtalk:c", "3")
keys = await store.list_keys()
assert set(keys) == {"feishu:a", "feishu:b", "dingtalk:c"}
async def test_list_keys_with_prefix_with_redis(self):
"""Redis 后端:按前缀过滤凭证键。"""
redis = _FakeRedis()
store = SecretsStore(master_key=_make_master_key(), redis=redis)
await store.set_secret("feishu:a", "1")
await store.set_secret("feishu:b", "2")
await store.set_secret("dingtalk:c", "3")
keys = await store.list_keys(prefix="feishu:")
assert set(keys) == {"feishu:a", "feishu:b"}
# ── Redis=None 降级模式 ─────────────────────────────────
class TestFallbackMode:
"""redis=None 时降级到内存字典。"""
async def test_fallback_set_and_get(self):
"""降级模式:写入后读取返回明文。"""
store = SecretsStore(master_key=_make_master_key(), redis=None)
await store.set_secret("k", "v")
assert await store.get_secret("k") == "v"
async def test_fallback_delete(self):
"""降级模式:删除凭证后不可读取。"""
store = SecretsStore(master_key=_make_master_key(), redis=None)
await store.set_secret("k", "v")
assert await store.delete_secret("k") is True
assert await store.get_secret("k") is None
async def test_fallback_list_keys_with_prefix(self):
"""降级模式:按前缀过滤。"""
store = SecretsStore(master_key=_make_master_key(), redis=None)
await store.set_secret("feishu:a", "1")
await store.set_secret("dingtalk:b", "2")
keys = await store.list_keys(prefix="feishu:")
assert keys == ["feishu:a"]
# ── 加密往返 ─────────────────────────────────────────────
class TestRedisEncryptionRoundtrip:
"""Redis 后端的加密-存储-读取-解密往返。"""
async def test_stored_value_is_encrypted_in_redis(self):
"""Redis 中存储的 value 字段为密文,非明文。"""
redis = _FakeRedis()
store = SecretsStore(master_key=_make_master_key(), redis=redis)
plaintext = "plaintext-token"
await store.set_secret("k", plaintext)
# 直接检查 Redis 中的原始值
raw = redis._data["agentkit:secrets:k"]
assert plaintext not in raw
# base64 解码后的密文也不含明文
import json
entry = json.loads(raw)
ciphertext_bytes = base64.b64decode(entry["value"])
assert plaintext.encode() not in ciphertext_bytes
async def test_roundtrip_preserves_plaintext(self):
"""加密-存储-读取-解密往返保持明文一致。"""
redis = _FakeRedis()
store = SecretsStore(master_key=_make_master_key(), redis=redis)
plaintext = "feishu-app-secret-12345"
await store.set_secret("key1", plaintext)
assert await store.get_secret("key1") == plaintext
# ── 多实例共享语义 ───────────────────────────────────────
class TestMultiInstanceSharing:
"""两个 store 共用同一 Redis → 多 worker 共享语义。"""
async def test_two_stores_share_state_via_redis(self):
"""worker A set_secret → worker B get_secret 能读到(共享 Redis"""
redis = _FakeRedis()
store_a = SecretsStore(master_key=_make_master_key(), redis=redis)
store_b = SecretsStore(master_key=_make_master_key(), redis=redis)
await store_a.set_secret("shared:key", "value-from-a")
# worker B 通过同一 Redis 能读到 worker A 写入的凭证
assert await store_b.get_secret("shared:key") == "value-from-a"
async def test_delete_from_one_store_visible_to_other(self):
"""worker A delete → worker B get 返回 None。"""
redis = _FakeRedis()
store_a = SecretsStore(master_key=_make_master_key(), redis=redis)
store_b = SecretsStore(master_key=_make_master_key(), redis=redis)
await store_a.set_secret("k", "v")
assert await store_b.delete_secret("k") is True
assert await store_a.get_secret("k") is None
# ── Redis 连接异常 ───────────────────────────────────────
class TestRedisConnectionFailure:
"""Redis 连接失败时的 fail-closed 行为。"""
async def test_get_secret_raises_on_redis_failure(self):
"""Redis 异常时 get_secret 抛异常fail-closed不静默返回 None"""
redis = _make_failing_redis()
store = SecretsStore(master_key=_make_master_key(), redis=redis)
with pytest.raises(ConnectionError):
await store.get_secret("k")
async def test_set_secret_raises_on_redis_failure(self):
"""Redis 异常时 set_secret 抛异常(写入失败不静默降级)。"""
redis = _make_failing_redis()
store = SecretsStore(master_key=_make_master_key(), redis=redis)
with pytest.raises(ConnectionError):
await store.set_secret("k", "v")

View File

@ -0,0 +1,340 @@
"""U4 — Channels webhook Redis 状态迁移单元测试。
覆盖场景
- nonce dedup首次 True重复 FalseTTL 过期 可再次使用
- rate limit窗口内 100 请求通过 101 False窗口滚动后重置
- backpressure并发 < 2x 上限 True>= 2x 上限 False释放后恢复
- 降级模式redis=None 内存实现仍工作
不依赖真实 Redis 使用 _FakeRedis 模拟 redis.asyncio.Redis 的子集
"""
from __future__ import annotations
import fnmatch
import time
import pytest
from agentkit.server.routes.channels import (
_BACKPRESSURE_KEY,
_RATE_LIMIT_MAX,
_WEBHOOK_MAX_CONCURRENT,
_acquire_backpressure_slot,
_check_nonce_dedup,
_check_nonce_dedup_redis,
_check_rate_limit,
_check_rate_limit_redis,
_release_backpressure_slot,
_reset_webhook_state,
)
# ── _FakeRedis — 支持 SecretsStore + webhook 用到的操作 ───
class _FakeRedis:
"""极简 Redis mock支持 string/zset/incr/pipeline/expire/scan_iter。
内部数据结构
- _strings: dict[str, str] string 类型set/get/delete
- _zsets: dict[str, dict[str, float]] sorted setzadd/zcard/zremrangebyscore
- _expires: dict[str, float] key -> 过期 monotonic 时间戳
"""
def __init__(self):
self._strings: dict[str, str] = {}
self._zsets: dict[str, dict[str, float]] = {}
self._expires: dict[str, float] = {}
self._incr_values: dict[str, int] = {}
def _is_expired(self, key: str) -> bool:
exp = self._expires.get(key)
return exp is not None and time.monotonic() >= exp
def _cleanup_if_expired(self, key: str) -> None:
if self._is_expired(key):
self._strings.pop(key, None)
self._zsets.pop(key, None)
self._incr_values.pop(key, None)
self._expires.pop(key, None)
# ── string 操作 ───────────────────────────────────
async def set(self, key: str, value: str, *, ex: int | None = None, nx: bool = False) -> bool:
self._cleanup_if_expired(key)
if nx and key in self._strings:
return False
self._strings[key] = value
if ex is not None:
self._expires[key] = time.monotonic() + ex
return True
async def get(self, key: str) -> str | None:
self._cleanup_if_expired(key)
return self._strings.get(key)
async def delete(self, key: str) -> int:
existed = key in self._strings or key in self._zsets
self._strings.pop(key, None)
self._zsets.pop(key, None)
self._expires.pop(key, None)
return 1 if existed else 0
# ── zset 操作 ─────────────────────────────────────
async def zadd(self, key: str, mapping: dict[str, float]) -> int:
self._cleanup_if_expired(key)
zset = self._zsets.setdefault(key, {})
added = 0
for member, score in mapping.items():
if member not in zset:
added += 1
zset[member] = score
return added
async def zcard(self, key: str) -> int:
self._cleanup_if_expired(key)
return len(self._zsets.get(key, {}))
async def zremrangebyscore(self, key: str, min_score: float, max_score: float) -> int:
self._cleanup_if_expired(key)
zset = self._zsets.get(key, {})
removed = 0
for member in list(zset):
if min_score <= zset[member] <= max_score:
del zset[member]
removed += 1
return removed
# ── incr/decr ─────────────────────────────────────
async def incr(self, key: str) -> int:
self._cleanup_if_expired(key)
self._incr_values[key] = self._incr_values.get(key, 0) + 1
return self._incr_values[key]
async def decr(self, key: str) -> int:
self._cleanup_if_expired(key)
self._incr_values[key] = self._incr_values.get(key, 0) - 1
return self._incr_values[key]
# ── expire ────────────────────────────────────────
async def expire(self, key: str, seconds: int) -> bool:
if key in self._strings or key in self._zsets or key in self._incr_values:
self._expires[key] = time.monotonic() + seconds
return True
return False
# ── pipeline ──────────────────────────────────────
def pipeline(self) -> "_FakePipeline":
return _FakePipeline(self)
# ── scan_iter ─────────────────────────────────────
async def scan_iter(self, match: str = "*"):
for k in list(self._strings):
self._cleanup_if_expired(k)
if k in self._strings and fnmatch.fnmatch(k, match):
yield k
class _FakePipeline:
"""模拟 redis pipeline — 收集命令execute() 返回结果列表。"""
def __init__(self, redis: _FakeRedis):
self._redis = redis
self._commands: list[tuple[str, tuple, dict]] = []
def zremrangebyscore(self, key, min_score, max_score):
self._commands.append(("zremrangebyscore", (key, min_score, max_score), {}))
return self
def zadd(self, key, mapping):
self._commands.append(("zadd", (key, mapping), {}))
return self
def zcard(self, key):
self._commands.append(("zcard", (key,), {}))
return self
def expire(self, key, seconds):
self._commands.append(("expire", (key, seconds), {}))
return self
def incr(self, key):
self._commands.append(("incr", (key,), {}))
return self
async def execute(self) -> list:
results = []
for cmd, args, kwargs in self._commands:
method = getattr(self._redis, cmd)
result = await method(*args, **kwargs)
results.append(result)
return results
# ── Fixtures ─────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _reset_state():
"""每个测试前重置内存状态。"""
_reset_webhook_state()
yield
_reset_webhook_state()
# ── nonce dedup (Redis) ──────────────────────────────────
class TestNonceDedupRedis:
"""Redis 后端 nonce 去重。"""
async def test_first_nonce_returns_true(self):
"""首次 nonce → True新 nonce应处理"""
redis = _FakeRedis()
assert await _check_nonce_dedup_redis(redis, "nonce-001") is True
async def test_duplicate_nonce_returns_false(self):
"""重复 nonce → False跳过处理"""
redis = _FakeRedis()
assert await _check_nonce_dedup_redis(redis, "nonce-001") is True
assert await _check_nonce_dedup_redis(redis, "nonce-001") is False
async def test_different_nonces_both_return_true(self):
"""不同 nonce 各自首次 → True。"""
redis = _FakeRedis()
assert await _check_nonce_dedup_redis(redis, "nonce-a") is True
assert await _check_nonce_dedup_redis(redis, "nonce-b") is True
async def test_nonce_ttl_expiry_allows_reuse(self):
"""TTL 过期后相同 nonce 可再次使用。"""
redis = _FakeRedis()
# 首次写入
assert await _check_nonce_dedup_redis(redis, "nonce-exp") is True
# 模拟 TTL 过期 — 手动清除过期标记
redis._expires.clear()
redis._strings.clear()
# 过期后相同 nonce 可再次使用
assert await _check_nonce_dedup_redis(redis, "nonce-exp") is True
# ── nonce dedup (内存降级) ───────────────────────────────
class TestNonceDedupFallback:
"""redis=None 时内存 nonce 去重仍工作。"""
def test_fallback_first_nonce_returns_true(self):
"""内存模式:首次 nonce → True。"""
assert _check_nonce_dedup("nonce-fb-1") is True
def test_fallback_duplicate_returns_false(self):
"""内存模式:重复 nonce → False。"""
assert _check_nonce_dedup("nonce-fb-2") is True
assert _check_nonce_dedup("nonce-fb-2") is False
# ── rate limit (Redis) ───────────────────────────────────
class TestRateLimitRedis:
"""Redis 后端滑动窗口限流。"""
async def test_under_limit_returns_true(self):
"""窗口内未超限 → True。"""
redis = _FakeRedis()
for _ in range(_RATE_LIMIT_MAX):
assert await _check_rate_limit_redis(redis, "1.2.3.4") is True
async def test_over_limit_returns_false(self):
"""超过窗口上限 → False。"""
redis = _FakeRedis()
for _ in range(_RATE_LIMIT_MAX):
await _check_rate_limit_redis(redis, "1.2.3.4")
# 第 _RATE_LIMIT_MAX + 1 次应被拒绝
assert await _check_rate_limit_redis(redis, "1.2.3.4") is False
async def test_different_ips_independent(self):
"""不同 IP 的限流独立。"""
redis = _FakeRedis()
for _ in range(_RATE_LIMIT_MAX):
await _check_rate_limit_redis(redis, "1.1.1.1")
# IP 1.1.1.1 已满,但 2.2.2.2 仍可通过
assert await _check_rate_limit_redis(redis, "2.2.2.2") is True
async def test_window_reset_after_expiry(self):
"""窗口过期后计数重置。"""
redis = _FakeRedis()
for _ in range(_RATE_LIMIT_MAX):
await _check_rate_limit_redis(redis, "3.3.3.3")
assert await _check_rate_limit_redis(redis, "3.3.3.3") is False
# 模拟窗口过期 — 清除 zset 数据
redis._zsets.clear()
redis._expires.clear()
# 过期后可再次通过
assert await _check_rate_limit_redis(redis, "3.3.3.3") is True
# ── rate limit (内存降级) ────────────────────────────────
class TestRateLimitFallback:
"""redis=None 时内存限流仍工作。"""
def test_fallback_under_limit(self):
"""内存模式:未超限 → True。"""
for _ in range(_RATE_LIMIT_MAX):
assert _check_rate_limit("4.4.4.4") is True
def test_fallback_over_limit(self):
"""内存模式:超限 → False。"""
for _ in range(_RATE_LIMIT_MAX):
_check_rate_limit("5.5.5.5")
assert _check_rate_limit("5.5.5.5") is False
# ── backpressure (Redis) ─────────────────────────────────
class TestBackpressureRedis:
"""Redis 后端共享并发计数器。"""
async def test_under_limit_returns_true(self):
"""并发 < 2x 上限 → True。"""
redis = _FakeRedis()
for _ in range(_WEBHOOK_MAX_CONCURRENT * 2):
assert await _acquire_backpressure_slot(redis) is True
async def test_over_limit_returns_false(self):
"""并发 >= 2x 上限 → False。"""
redis = _FakeRedis()
for _ in range(_WEBHOOK_MAX_CONCURRENT * 2):
await _acquire_backpressure_slot(redis)
# 超过 2x 上限应拒绝
assert await _acquire_backpressure_slot(redis) is False
async def test_release_restores_slot(self):
"""释放后槽位恢复,可再次获取。"""
redis = _FakeRedis()
# 获取到上限
for _ in range(_WEBHOOK_MAX_CONCURRENT * 2):
await _acquire_backpressure_slot(redis)
# 超限
assert await _acquire_backpressure_slot(redis) is False
# 释放一个
await _release_backpressure_slot(redis)
# 可再次获取
assert await _acquire_backpressure_slot(redis) is True
async def test_release_decrements_counter(self):
"""release 后计数器递减。"""
redis = _FakeRedis()
await _acquire_backpressure_slot(redis)
assert redis._incr_values[_BACKPRESSURE_KEY] == 1
await _release_backpressure_slot(redis)
assert redis._incr_values[_BACKPRESSURE_KEY] == 0

View File

@ -0,0 +1,182 @@
"""U1 — Gateway KB cache fail-closed 行为测试。
验证安全要求 R1KB settings 读取失败时必须 fail-closed禁用缓存
不得 fail-open默认启用缓存
覆盖场景
1. settings 正常读取 caching_disabled=False 缓存启用
2. settings 正常读取 caching_disabled=True 缓存禁用
3. get_settings_store() 抛异常 fail-closed缓存禁用
4. settings 返回 NoneKB 不存在 fail-closed缓存禁用
5. kb_id=None RAG 请求 不查 settings缓存正常启用
"""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
from agentkit.llm.gateway import LLMGateway
from agentkit.llm.protocol import LLMRequest, LLMResponse, TokenUsage
def _make_response() -> LLMResponse:
"""构造最小 LLMResponse。"""
return LLMResponse(
content="ok",
model="test-model",
usage=TokenUsage(prompt_tokens=5, completion_tokens=3),
)
def _make_gateway_with_cache() -> LLMGateway:
"""构造带 mock 缓存管理器的 LLMGateway避免 litellm 依赖)。
gateway.chat() 调用 LitellmCacheManager.cache_params_for_hit/no_cache类静态方法
should_cache build_cache_key 通过实例调用 mock 这两个即可
"""
gateway = LLMGateway() # 不启用真实 cachelitellm 可能未安装)
mock_manager = MagicMock()
# should_cache: kb_caching_disabled=True 或 user_id=None 时返回 False
mock_manager.should_cache = lambda kb_disabled, uid: not kb_disabled and uid is not None
mock_manager.build_cache_key = MagicMock(return_value="mock_cache_key")
mock_manager.record_cache_result = MagicMock()
gateway._cache_manager = mock_manager
return gateway
def _register_mock_provider(gateway: LLMGateway) -> MagicMock:
"""注册 mock provider返回带 cache 参数的 capture。
模型格式 "test/test-model" _resolve_model "/" 分割为 provider="test" + model="test-model"
"""
provider = MagicMock()
provider.chat = AsyncMock(return_value=_make_response())
gateway.register_provider("test", provider)
return provider
_MODEL = "test/test-model"
def _get_cache_arg(provider: MagicMock) -> dict:
"""从 provider.chat 调用中提取 cache 参数。"""
call_args = provider.chat.call_args
# provider.chat(req) — req 是第一个位置参数
req: LLMRequest = call_args.args[0]
return req._cache or {}
class TestGatewayCacheFailClosed:
"""U1 — KB settings 读取异常时 fail-closed。"""
async def test_settings_caching_false_enables_cache(self):
"""settings 正常读取 caching_disabled=False → cache_key 传入 provider启用缓存"""
gateway = _make_gateway_with_cache()
provider = _register_mock_provider(gateway)
mock_settings = SimpleNamespace(caching_disabled=False)
with patch("agentkit.rag_platform.settings.get_settings_store") as mock_get_store:
mock_store = MagicMock()
mock_store.get_settings = AsyncMock(return_value=mock_settings)
mock_get_store.return_value = mock_store
await gateway.chat(
messages=[{"role": "user", "content": "hi"}],
model=_MODEL,
user_id="u1",
kb_id="kb1",
kb_acl_hash="acl1",
)
cache_arg = _get_cache_arg(provider)
assert "cache_key" in cache_arg, f"Expected cache_key (cache enabled), got {cache_arg}"
async def test_settings_caching_true_disables_cache(self):
"""settings 正常读取 caching_disabled=True → no-cache 传入 provider禁用缓存"""
gateway = _make_gateway_with_cache()
provider = _register_mock_provider(gateway)
mock_settings = SimpleNamespace(caching_disabled=True)
with patch("agentkit.rag_platform.settings.get_settings_store") as mock_get_store:
mock_store = MagicMock()
mock_store.get_settings = AsyncMock(return_value=mock_settings)
mock_get_store.return_value = mock_store
await gateway.chat(
messages=[{"role": "user", "content": "hi"}],
model=_MODEL,
user_id="u1",
kb_id="kb1",
)
cache_arg = _get_cache_arg(provider)
assert cache_arg.get("no-cache") is True, f"Expected no-cache=True, got {cache_arg}"
async def test_settings_exception_fail_closed(self):
"""get_settings_store() 抛异常 → fail-closedno-cache"""
gateway = _make_gateway_with_cache()
provider = _register_mock_provider(gateway)
with patch("agentkit.rag_platform.settings.get_settings_store") as mock_get_store:
mock_store = MagicMock()
mock_store.get_settings = AsyncMock(side_effect=RuntimeError("DB down"))
mock_get_store.return_value = mock_store
await gateway.chat(
messages=[{"role": "user", "content": "hi"}],
model=_MODEL,
user_id="u1",
kb_id="kb1",
)
cache_arg = _get_cache_arg(provider)
assert cache_arg.get("no-cache") is True, (
f"fail-closed: 读取异常应禁用缓存,但 got {cache_arg}"
)
async def test_settings_none_fail_closed(self):
"""settings 返回 NoneKB 不存在)→ fail-closedno-cache"""
gateway = _make_gateway_with_cache()
provider = _register_mock_provider(gateway)
with patch("agentkit.rag_platform.settings.get_settings_store") as mock_get_store:
mock_store = MagicMock()
mock_store.get_settings = AsyncMock(return_value=None)
mock_get_store.return_value = mock_store
await gateway.chat(
messages=[{"role": "user", "content": "hi"}],
model=_MODEL,
user_id="u1",
kb_id="kb_nonexistent",
)
cache_arg = _get_cache_arg(provider)
assert cache_arg.get("no-cache") is True, (
f"Expected no-cache for None settings, got {cache_arg}"
)
async def test_no_kb_id_skips_settings_lookup(self):
"""kb_id=None非 RAG 请求)→ 不查 settings缓存正常启用。"""
gateway = _make_gateway_with_cache()
provider = _register_mock_provider(gateway)
with patch("agentkit.rag_platform.settings.get_settings_store") as mock_get_store:
mock_store = MagicMock()
mock_store.get_settings = AsyncMock()
mock_get_store.return_value = mock_store
await gateway.chat(
messages=[{"role": "user", "content": "hi"}],
model=_MODEL,
user_id="u1",
# kb_id 不传
)
# 不应查询 settings
mock_store.get_settings.assert_not_called()
# 缓存应启用cache_key 传入)
cache_arg = _get_cache_arg(provider)
assert "cache_key" in cache_arg, f"Expected cache_key (no kb_id), got {cache_arg}"

View File

@ -0,0 +1,269 @@
"""U2 — MCP 危险工具黑名单过滤单元测试。
覆盖 6 个暴露路径
- create_mcp_router() REST: GET /tools/list, POST /tools/call
- create_mcp_router() JSON-RPC: tools/list, tools/call
- legacy MCPServer REST: GET /tools/list, POST /tools/call
- legacy MCPServer JSON-RPC: tools/list, tools/call
黑名单工具terminal/shell/file_write/file_read/file_delete依赖 chat
confirmation 流程WebSocket通过 MCP 暴露会绕过用户确认
"""
from __future__ import annotations
import warnings
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import httpx
from fastapi import FastAPI
from agentkit.mcp.server import MCPServer, create_mcp_router
from agentkit.server.auth.middleware import AuthMiddleware
# 与 test_server_auth.py 共用的测试凭据 — 仅限单元测试。
JWT_SECRET = "u13-test-jwt-secret-xxxxxxxxxxxxx"
API_KEY = "u13-test-api-key-yyy"
# ── 辅助函数 ──────────────────────────────────────────────
def _make_mock_tool(
name: str,
description: str = "",
result: str = "ok",
) -> MagicMock:
"""构造一个 mock 工具,模拟 Tool 接口。"""
tool = MagicMock()
tool.name = name
tool.description = description
tool.input_schema = {"type": "object", "properties": {}}
tool.safe_execute = AsyncMock(return_value=result)
return tool
def _make_mock_registry(tools: list) -> MagicMock:
"""构造一个 mock ToolRegistry支持 list_tools() 和 get(name)。"""
registry = MagicMock()
registry.list_tools.return_value = tools
def _get(name: str):
for t in tools:
if t.name == name:
return t
raise KeyError(name)
registry.get = _get
return registry
def _make_registry_with_safe_and_dangerous() -> MagicMock:
"""构造含 1 安全 + 5 危险工具的 registry。
危险工具名与 _MCP_BLOCKED_TOOLS 完全对应验证黑名单全覆盖
"""
return _make_mock_registry(
[
_make_mock_tool("echo", "safe tool", "echo: hi"),
_make_mock_tool("shell", "shell tool", "should-not-reach"),
_make_mock_tool("file_write", "write tool", "should-not-reach"),
_make_mock_tool("file_read", "read tool", "should-not-reach"),
_make_mock_tool("file_delete", "delete tool", "should-not-reach"),
_make_mock_tool("terminal", "terminal tool", "should-not-reach"),
]
)
def _make_app(tool_registry: Any = None) -> FastAPI:
"""构造测试用 FastAPI app挂载 MCP router + AuthMiddleware。"""
app = FastAPI()
app.state.tool_registry = tool_registry
app.add_middleware(AuthMiddleware, jwt_secret=JWT_SECRET, api_key=API_KEY)
mcp_router = create_mcp_router(tool_registry=tool_registry)
app.include_router(mcp_router, prefix="/api/v1/mcp")
return app
def _make_legacy_app(tool_registry: Any = None) -> FastAPI:
"""构造 legacy MCPServer app无认证过滤 DeprecationWarning"""
server = MCPServer(tool_registry=tool_registry)
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
return server._create_app()
_BLOCKED = {"terminal", "shell", "file_write", "file_read", "file_delete"}
_HEADERS = {"X-API-Key": API_KEY}
# ── create_mcp_router() REST 端点 ─────────────────────────
class TestCreateMcpRouterRestFilter:
"""create_mcp_router() 的 REST 端点黑名单过滤。"""
async def test_rest_tools_list_excludes_blocked(self):
"""GET /tools/list 不返回黑名单工具。"""
app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous())
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.get("/api/v1/mcp/tools/list", headers=_HEADERS)
assert resp.status_code == 200
names = {t["name"] for t in resp.json()["tools"]}
assert names == {"echo"}
assert not (names & _BLOCKED)
async def test_rest_tools_call_safe_tool_succeeds(self):
"""POST /tools/call 安全工具 → 200 + 结果。"""
app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous())
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.post(
"/api/v1/mcp/tools/call",
json={"name": "echo", "arguments": {}},
headers=_HEADERS,
)
assert resp.status_code == 200
assert "echo: hi" in resp.json()["content"][0]["text"]
async def test_rest_tools_call_shell_returns_404(self):
"""POST /tools/call shell → 404黑名单"""
app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous())
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.post(
"/api/v1/mcp/tools/call",
json={"name": "shell", "arguments": {"cmd": "rm -rf /"}},
headers=_HEADERS,
)
assert resp.status_code == 404
async def test_rest_tools_call_file_delete_returns_404(self):
"""POST /tools/call file_delete → 404黑名单"""
app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous())
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.post(
"/api/v1/mcp/tools/call",
json={"name": "file_delete", "arguments": {"path": "/etc/passwd"}},
headers=_HEADERS,
)
assert resp.status_code == 404
# ── create_mcp_router() JSON-RPC 端点 ─────────────────────
class TestCreateMcpRouterJsonRpcFilter:
"""create_mcp_router() 的 JSON-RPC 端点黑名单过滤。"""
async def test_jsonrpc_tools_list_excludes_blocked(self):
"""JSON-RPC tools/list 不包含黑名单工具。"""
app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous())
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.post(
"/api/v1/mcp/",
json={"jsonrpc": "2.0", "method": "tools/list", "id": 1},
headers=_HEADERS,
)
assert resp.status_code == 200
body = resp.json()
names = {t["name"] for t in body["result"]["tools"]}
assert names == {"echo"}
assert not (names & _BLOCKED)
async def test_jsonrpc_tools_call_shell_returns_iserror(self):
"""JSON-RPC tools/call shell → isError=True。"""
app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous())
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.post(
"/api/v1/mcp/",
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {"name": "shell", "arguments": {"cmd": "rm -rf /"}},
"id": 2,
},
headers=_HEADERS,
)
assert resp.status_code == 200
result = resp.json()["result"]
assert result.get("isError") is True
assert "blocked" in result["content"][0]["text"].lower()
# ── legacy MCPServer REST 端点 ────────────────────────────
class TestLegacyMcpServerRestFilter:
"""legacy MCPServer独立 app的 REST 端点黑名单过滤。"""
async def test_legacy_rest_tools_list_excludes_blocked(self):
"""legacy GET /tools/list 不返回黑名单工具。"""
app = _make_legacy_app(tool_registry=_make_registry_with_safe_and_dangerous())
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.get("/tools/list")
assert resp.status_code == 200
names = {t["name"] for t in resp.json()["tools"]}
assert names == {"echo"}
assert not (names & _BLOCKED)
async def test_legacy_rest_tools_call_shell_returns_error(self):
"""legacy POST /tools/call shell → 返回 errorblocked"""
app = _make_legacy_app(tool_registry=_make_registry_with_safe_and_dangerous())
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.post(
"/tools/call",
json={"name": "shell", "arguments": {"cmd": "rm -rf /"}},
)
assert resp.status_code == 200
body = resp.json()
assert "error" in body
assert "blocked" in body["error"].lower()
# ── legacy MCPServer JSON-RPC 端点 ────────────────────────
class TestLegacyMcpServerJsonRpcFilter:
"""legacy MCPServer 的 JSON-RPC 端点黑名单过滤。"""
async def test_legacy_jsonrpc_tools_list_excludes_blocked(self):
"""legacy JSON-RPC tools/list 不包含黑名单工具。"""
app = _make_legacy_app(tool_registry=_make_registry_with_safe_and_dangerous())
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.post(
"/",
json={"jsonrpc": "2.0", "method": "tools/list", "id": 1},
)
assert resp.status_code == 200
body = resp.json()
names = {t["name"] for t in body["result"]["tools"]}
assert names == {"echo"}
assert not (names & _BLOCKED)
async def test_legacy_jsonrpc_tools_call_shell_returns_iserror(self):
"""legacy JSON-RPC tools/call shell → isError=True。"""
app = _make_legacy_app(tool_registry=_make_registry_with_safe_and_dangerous())
transport = httpx.ASGITransport(app=app)
async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.post(
"/",
json={
"jsonrpc": "2.0",
"method": "tools/call",
"params": {"name": "shell", "arguments": {"cmd": "rm -rf /"}},
"id": 2,
},
)
assert resp.status_code == 200
result = resp.json()["result"]
assert result.get("isError") is True
assert "blocked" in result["content"][0]["text"].lower()