diff --git a/docs/plans/2026-06-26-006-fix-p0-security-multi-instance-hardening-plan.md b/docs/plans/2026-06-26-006-fix-p0-security-multi-instance-hardening-plan.md new file mode 100644 index 0000000..910c784 --- /dev/null +++ b/docs/plans/2026-06-26-006-fix-p0-security-multi-instance-hardening-plan.md @@ -0,0 +1,255 @@ +--- +title: "fix: P0 安全与多实例一致性加固" +date: 2026-06-26 +type: fix +status: planned +origin: 代码走查报告(portal-platform-evolution 合并后审查) +--- + +# fix: P0 安全与多实例一致性加固 + +## Summary + +修复代码走查发现的 4 项高优先级问题,分两组:(1) 安全 fail-open 默认值 — Gateway KB cache fail-closed + MCP dangerous-tool 黑名单过滤;(2) 多实例部署一致性 — SecretsStore 与 Channels 全局状态迁移到 Redis,消除多 worker 下的安全与状态共享缺口。 + +## Problem Frame + +`portal-platform-evolution` 合并引入了多端渠道(U10-U12)、MCP 协议(U13/U16)、LiteLLM 缓存(U17)等特性。代码走查发现两类系统性风险: + +1. **安全 fail-open**:Gateway 在 KB settings 读取异常时默认启用缓存(可能泄漏禁用缓存的 KB 数据);MCP 端点暴露所有工具(含 ShellTool 等危险工具),绕过 chat 流程的 confirmation 机制。 +2. **多实例不一致**:SecretsStore 和 Channels 路由使用模块级内存字典存储凭证、nonce、限流、backpressure 状态,多 worker 部署下失效,与 README 宣称的 K8s 多实例部署目标冲突。 + +## Requirements + +- **R1**:Gateway KB cache 在 settings 读取失败时必须 fail-closed(禁用缓存),不得 fail-open +- **R2**:MCP `/tools/call` 和 `/tools/list`(含 JSON-RPC 端点)必须过滤危险工具,禁止绕过 chat confirmation 流程 +- **R3**:SecretsStore 加密凭证必须存储在 Redis,多 worker 共享,保留 AES-256-GCM 加密层 +- **R4**:Channels webhook nonce dedup、rate limit、backpressure、渠道配置必须存储在 Redis,多 worker 一致 +- **R5**:所有改动不破坏现有单进程行为(Redis 不可用时降级到内存,需显式 log warning) +- **R6**:每个修复单元包含可运行的验证测试 + +## Key Technical Decisions + +### KTD1: H4 危险工具过滤采用黑名单 + +**决策**:使用 `_MCP_BLOCKED_TOOLS: frozenset[str]` 黑名单,而非在 Tool 基类增加 `requires_confirmation` 元数据字段。 + +**理由**:黑名单改动最小(ponytail 原则),仅影响 mcp/server.py 一个文件。元数据方案需修改 Tool ABC + 所有工具子类 + ToolRegistry,改动面过大。黑名单天花板已在代码注释标注(`ponytail: 黑名单需手动维护,新增危险工具需同步更新`)。 + +### KTD2: H1/H2 Redis 连接复用 app.state.working_redis_client + +**决策**:复用 `app.py` lifespan 已创建的 `app.state.working_redis_client`(`aioredis.from_url(redis_url, decode_responses=True)`),通过 FastAPI `Depends` 或 `request.app.state` 注入到 channels 路由和 SecretsStore。 + +**理由**:避免新建 Redis 连接池(资源浪费),与 WorkingMemory、TaskManager 等现有子系统共享同一连接。`decode_responses=True` 与 RedisBus 模式一致。 + +### KTD3: 降级策略 — Redis 不可用时 fail-closed + +**决策**:Redis 连接失败时,webhook 端点返回 503(而非降级到内存),secrets 读取返回 None(而非空字典)。 + +**理由**:安全优先。降级到内存会导致多 worker 间状态不一致,nonce dedup 失效可能引发重放攻击。fail-closed 符合 AGENTS.md 安全约定。 + +## Implementation Units + +### U1. Gateway KB cache fail-closed + +**Goal**: KB settings 读取异常时禁用缓存,而非默认启用。 + +**Requirements**: R1 + +**Dependencies**: 无 + +**Files**: +- `src/agentkit/llm/gateway.py`(修改第 137-146 行) +- `tests/unit/llm/test_gateway_cache_failclosed.py`(新建) + +**Approach**: 将 `kb_caching_disabled` 默认值从 `False` 改为 `True`(fail-closed),仅在 settings 成功读取且 `caching_disabled=False` 时才启用缓存。读取异常或 settings 为 None 时保持 `True`。 + +**Patterns to follow**: 现有 `should_cache` 已有 `per_user_namespace + user_id=None` 防护逻辑,本单元补齐 KB settings 读取层的防护。 + +**Test scenarios**: +- **Happy path**: settings 正常读取 `caching_disabled=False` → 缓存启用,`should_cache` 返回 True +- **Happy path**: settings 正常读取 `caching_disabled=True` → 缓存禁用,`should_cache` 返回 False +- **Error path**: `get_settings_store().get_settings()` 抛异常 → `kb_caching_disabled=True`,`should_cache` 返回 False(fail-closed) +- **Edge case**: settings 返回 None(KB 不存在)→ `kb_caching_disabled=True` +- **Edge case**: `kb_id=None`(非 RAG 请求)→ 不查 settings,缓存正常启用 + +**Verification**: `pytest tests/unit/llm/test_gateway_cache_failclosed.py -v` 全绿;mock settings store 抛异常时断言 `should_cache` 返回 False。 + +--- + +### U2. MCP dangerous-tool 黑名单过滤 + +**Goal**: 阻止危险工具通过 MCP 端点执行,绕过 chat confirmation 流程。 + +**Requirements**: R2 + +**Dependencies**: 无 + +**Files**: +- `src/agentkit/mcp/server.py`(修改 `_all_tools`、`_find_tool`、REST `call_tool`、JSON-RPC `tools/call`) +- `tests/unit/mcp/test_server_dangerous_tools.py`(新建) + +**Approach**: 在 `server.py` 顶部定义 `_MCP_BLOCKED_TOOLS` 黑名单(`shell`、`file_write`、`file_delete`、`terminal_execute`)。`_all_tools()` 和 `_find_tool()` 过滤黑名单工具。`list_tools` 和 JSON-RPC `tools/list` 不返回黑名单工具。`call_tool` 对黑名单工具返回 404。 + +**Technical design**(directional guidance): + +```python +# ponytail: 黑名单需手动维护,新增危险工具需同步更新。 +# 天花板:若工具名动态生成或别名注册,黑名单可能漏判。 +# 升级路径:Tool 基类增加 requires_confirmation 元数据字段。 +_MCP_BLOCKED_TOOLS: frozenset[str] = frozenset({ + "shell", "file_write", "file_delete", "terminal_execute", +}) +``` + +**Patterns to follow**: `shell.py` 已有 `_DANGEROUS_BINARIES` frozenset 模式。 + +**Test scenarios**: +- **Happy path**: `GET /tools/list` 不返回黑名单工具 +- **Happy path**: `POST /tools/call` with `{"name": "safe_tool"}` 正常执行 +- **Error path**: `POST /tools/call` with `{"name": "shell"}` → 404 +- **Error path**: `POST /tools/call` with `{"name": "file_delete"}` → 404 +- **Integration**: JSON-RPC `tools/call` with `{"method": "tools/call", "params": {"name": "shell"}}` → `isError: True` +- **Edge case**: JSON-RPC `tools/list` 不包含黑名单工具 + +**Verification**: `pytest tests/unit/mcp/test_server_dangerous_tools.py -v` 全绿;黑名单工具在 list 和 call 端点均不可见。 + +--- + +### U3. SecretsStore Redis 迁移 + +**Goal**: 加密凭证存储迁移到 Redis,多 worker 共享,保留 AES-256-GCM 加密层。 + +**Requirements**: R3, R5 + +**Dependencies**: 无(可与 U4 并行,共享 Redis 注入基础设施) + +**Files**: +- `src/agentkit/channels/secrets.py`(修改 `SecretsStore.__init__`、`set_secret`、`get_secret`、`delete_secret`、`list_keys`) +- `src/agentkit/server/routes/channels.py`(修改 `_get_secrets_store` 注入 Redis) +- `src/agentkit/server/app.py`(lifespan 中将 `working_redis_client` 传入 SecretsStore) +- `tests/unit/channels/test_secrets_redis.py`(新建) + +**Approach**: `SecretsStore.__init__` 新增 `redis` 参数(`aioredis.Redis | None`)。`_store` 字典替换为 Redis 调用:`set_secret` → `redis.set(f"secrets:{key}", entry.model_dump_json())`;`get_secret` → `redis.get` + `SecretEntry.model_validate_json`;`delete_secret` → `redis.delete`;`list_keys` → `redis.scan_iter`。加密层(`encrypt`/`decrypt`/`_derive_key`)不变。Redis 为 None 时降级到内存(仅开发模式,log warning)。 + +**Patterns to follow**: `RedisBus._get_redis()` 懒初始化模式;`SharedWorkspace` 的 `redis_client` 注入模式。 + +**Test scenarios**: +- **Happy path**: `set_secret("k", "v")` 后 `get_secret("k")` 返回 `"v"`(Redis mock) +- **Happy path**: `delete_secret("k")` 后 `get_secret("k")` 返回 None +- **Happy path**: `list_keys(prefix="feishu:")` 返回匹配前缀的 key 列表 +- **Edge case**: Redis 为 None(降级模式)→ 内存字典行为,log warning +- **Error path**: Redis 连接失败 → `get_secret` 抛异常或返回 None(fail-closed) +- **Integration**: 加密-存储-读取-解密往返:明文 → encrypt → Redis SET → Redis GET → decrypt → 明文一致 + +**Verification**: `pytest tests/unit/channels/test_secrets_redis.py -v` 全绿;多 worker 场景下 worker A `set_secret` 后 worker B `get_secret` 能读到(需 Redis faker 或集成测试)。 + +--- + +### U4. Channels 全局状态 Redis 迁移 + +**Goal**: nonce dedup、rate limit、backpressure、渠道配置迁移到 Redis,多 worker 一致。 + +**Requirements**: R4, R5, R6 + +**Dependencies**: U3(共享 Redis 注入基础设施) + +**Files**: +- `src/agentkit/server/routes/channels.py`(修改 `_channels`、`_rate_limits`、`_seen_nonces`、`_pending_webhook_tasks` 相关逻辑) +- `tests/unit/channels/test_webhook_redis_state.py`(新建) + +**Approach**: + +1. **nonce dedup** → Redis `SET NX EX`:`await redis.set(f"nonce:{nonce}", "1", ex=300, nx=True)` 返回 None 表示重复。 +2. **rate limit** → Redis ZSET 滑动窗口:`ZREMRANGEBYSCORE` 清理过期 + `ZADD` + `ZCARD` 计数。 +3. **backpressure** → Redis `INCR`/`DECR` 共享计数器:`INCR webhook:concurrent`,超限 `DECR` 并返回 429,执行完毕 `DECR`。`EXPIRE 30s` 防 crash 计数不归零。 +4. **渠道配置** → Redis Hash:`HSET channel:{id}` 存配置 JSON,`HGETALL` 读取。`_adapter_cache` 保留进程内缓存(只读快照,配置变更时 invalidate)。 +5. **Redis 注入**:webhook handler 通过 `request.app.state.working_redis_client` 获取连接。 + +**Technical design**(directional guidance): + +```python +# nonce dedup — 原子操作,TTL 自动过期 +is_new = await redis.set(f"nonce:{nonce}", "1", ex=int(_NONCE_TTL), nx=True) +if not is_new: + return Response(status_code=200) # 重复事件,飞书要求 3s 内响应 + +# rate limit — ZSET 滑动窗口 +key = f"ratelimit:{client_ip}" +now = time.time() +pipe = redis.pipeline() +pipe.zremrangebyscore(key, 0, now - _RATE_LIMIT_WINDOW) +pipe.zadd(key, {str(now): now}) +pipe.zcard(key) +pipe.expire(key, int(_RATE_LIMIT_WINDOW)) +_, _, count, _ = await pipe.execute() +if count > _RATE_LIMIT_MAX: + return Response(status_code=429) + +# backpressure — 共享计数器 +current = await redis.incr("webhook:concurrent") +if current > _WEBHOOK_MAX_CONCURRENT * 2: + await redis.decr("webhook:concurrent") + return Response(status_code=429) +try: + # ... process webhook ... +finally: + await redis.decr("webhook:concurrent") +``` + +**Patterns to follow**: Redis pipeline 模式(`usage_store.py` 已有);`_adapter_cache` 保留进程内缓存(`_invalidate_adapter_cache` 逻辑不变)。 + +**Test scenarios**: +- **Happy path**: 首次 nonce → `SET NX` 返回 True,请求正常处理 +- **Happy path**: 重复 nonce → `SET NX` 返回 None,返回 200(不重复处理) +- **Happy path**: rate limit 窗口内 100 请求通过,第 101 请求返回 429 +- **Happy path**: backpressure 并发 < 2x 上限时请求通过 +- **Error path**: backpressure 并发 >= 2x 上限时返回 429 +- **Edge case**: nonce TTL 过期后相同 nonce 可再次使用(mock `time.sleep` 或 Redis TTL) +- **Edge case**: rate limit 窗口滚动后计数重置 +- **Integration**: 渠道配置写入 Redis 后,新 worker 的 `_adapter_cache` miss 时从 Redis 读取并缓存 + +**Verification**: `pytest tests/unit/channels/test_webhook_redis_state.py -v` 全绿;`ruff check src/ && ruff format src/` 通过;现有 `tests/unit/channels/test_wecom.py` 等不回归。 + +--- + +## Scope Boundaries + +### In Scope + +- 4 项高优先级问题修复(H3/H4/H1/H2) +- Redis 连接复用现有 `app.state.working_redis_client` +- 单元测试覆盖每个修复单元 + +### Out of Scope + +- 中优先级问题(M1-M8):feishu/slack decode 异常、dingtalk 时间戳窗口、mcp/client 性能、app.py shutdown 超时等 +- 低优先级问题(L1-L3):X-Forwarded-For 信任、生产 guard 触发条件、Any 类型 +- PostgreSQL 持久化迁移(Redis 已满足多实例共享需求) +- MCP confirmation 协议实现(黑名单方案足够,元数据方案为后续升级路径) +- Tool 基类 `requires_confirmation` 元数据重构 + +### Deferred to Follow-Up Work + +- M1 feishu/slack `body.decode("utf-8")` 异常捕获 → 独立 PR +- M4 mcp/server.py 异常消息脱敏 + REST/JSON-RPC 代码去重 → 独立 PR +- M6 app.py shutdown `asyncio.gather` 超时 → 独立 PR +- H4 升级路径:Tool 基类 `requires_confirmation` 元数据 → 下个迭代 + +## Risks & Dependencies + +- **风险 1**:H1/H2 Redis 迁移后,Redis 不可用会导致 webhook 端点完全失效(fail-closed)。缓解:lifespan 启动时 ping Redis,不可用时 log error 但不阻止启动(单进程降级模式仍可用)。 +- **风险 2**:U4 backpressure `INCR/DECR` 非原子,crash 可能导致计数不归零。缓解:`EXPIRE 30s` 安全网 + 定期清理。 +- **依赖 1**:U4 依赖 U3 完成的 Redis 注入基础设施。 +- **依赖 2**:所有单元测试使用 `fakeredis` 或 `unittest.mock.AsyncMock` mock Redis,不依赖真实 Redis 实例。 + +## System-Wide Impact + +- **运维**:多 worker 部署(gunicorn -w 4)现在可正确共享渠道状态,无需 sticky session +- **安全**:MCP 端点不再暴露危险工具;KB 缓存 fail-closed 防止数据泄漏 +- **性能**:Redis 网络往返替代内存字典,单请求延迟增加 ~1ms(可接受);backpressure 跨 worker 生效 +- **兼容性**:单进程开发模式(无 Redis)降级到内存存储,行为不变 + +## Open Questions + +无 — 所有技术决策已在 KTD1-KTD3 中明确。 diff --git a/src/agentkit/channels/secrets.py b/src/agentkit/channels/secrets.py index 6213e14..a2e4f3a 100644 --- a/src/agentkit/channels/secrets.py +++ b/src/agentkit/channels/secrets.py @@ -7,7 +7,8 @@ KTD8 关键决策: (由 server 启动钩子实现,本模块提供 ``assert_production_master_key`` 辅助)。 - Master key 轮换采用双密钥窗口策略(key_id 字段标记)。 -当前为内存存储实现,PG 迁移预留接口(``_store: dict`` → 未来替换为 ORM session)。 +U3:Redis 后端支持多 worker 共享。``redis`` 参数为 None 时降级到内存字典 +(仅开发模式),生产环境必须注入 Redis 客户端。 """ from __future__ import annotations @@ -16,6 +17,7 @@ import base64 import logging import os +import redis.asyncio as aioredis from pydantic import BaseModel, ConfigDict logger = logging.getLogger(__name__) @@ -74,21 +76,41 @@ class SecretsStore: 使用 AES-256-GCM 加密,HKDF with per-row salt 派生 per-row 密钥。 Master key 从环境变量 ``AGENTKIT_MASTER_KEY`` 读取(开发 fallback); 生产环境应通过 KMS 提供 master key 并显式传入。 + + U3:``redis`` 参数注入 Redis 客户端后,凭证存储在 Redis(多 worker 共享)。 + ``redis=None`` 时降级到内存字典(仅开发模式)。 """ - def __init__(self, master_key: bytes | None = None, *, key_source: str = "env"): + # Redis key 前缀 — 遵循 codebase 约定(agentkit::) + _REDIS_PREFIX = "agentkit:secrets:" + + def __init__( + self, + master_key: bytes | None = None, + *, + key_source: str = "env", + redis: aioredis.Redis | None = None, + ): """初始化 secrets store。 Args: master_key: 显式传入的 master key。若为 None 则从环境变量加载。 key_source: master key 来源标记,传给 ``assert_production_master_key``。 生产环境应设为 "kms"。 + redis: Redis 客户端(``redis.asyncio.Redis``)。注入后凭证存储在 Redis, + 多 worker 共享。None 时降级到内存字典(仅开发模式)。 """ self._master_key = master_key or self._load_master_key() # 生产 guard:若 key_source="env" 且在生产模式下,构造即失败。 assert_production_master_key(self._master_key, source=key_source) - # 内存存储(PG 迁移预留接口:替换为 ORM session 即可) + self._redis = redis + # 内存降级存储(redis=None 时使用) self._store: dict[str, SecretEntry] = {} + if redis is None: + logger.warning( + "SecretsStore 运行于内存降级模式(redis=None)— " + "多 worker 部署下状态不共享,仅限开发使用。" + ) def _load_master_key(self) -> bytes: """从环境变量加载 master key(开发 fallback)。""" @@ -148,18 +170,30 @@ class SecretsStore: """存储加密凭证。覆盖同名 key。""" entry = self.encrypt(value) entry.key = key - self._store[key] = entry + if self._redis is not None: + await self._redis.set(self._REDIS_PREFIX + key, entry.model_dump_json()) + else: + self._store[key] = entry return entry async def get_secret(self, key: str) -> str | None: """读取并解密凭证。key 不存在返回 None。""" - entry = self._store.get(key) - if entry is None: - return None + if self._redis is not None: + raw = await self._redis.get(self._REDIS_PREFIX + key) + if raw is None: + return None + entry = SecretEntry.model_validate_json(raw) + else: + entry = self._store.get(key) + if entry is None: + return None return self.decrypt(entry) async def delete_secret(self, key: str) -> bool: """删除凭证。返回是否删除成功。""" + if self._redis is not None: + deleted = await self._redis.delete(self._REDIS_PREFIX + key) + return deleted > 0 if key in self._store: del self._store[key] return True @@ -167,6 +201,12 @@ class SecretsStore: async def list_keys(self, prefix: str | None = None) -> list[str]: """列出凭证键。可选前缀过滤。""" + if self._redis is not None: + pattern = self._REDIS_PREFIX + (prefix or "") + "*" + keys: list[str] = [] + async for k in self._redis.scan_iter(match=pattern): + keys.append(k[len(self._REDIS_PREFIX) :]) + return keys if prefix: return [k for k in self._store if k.startswith(prefix)] return list(self._store.keys()) diff --git a/src/agentkit/llm/gateway.py b/src/agentkit/llm/gateway.py index f0d62e7..337d395 100644 --- a/src/agentkit/llm/gateway.py +++ b/src/agentkit/llm/gateway.py @@ -134,7 +134,11 @@ class LLMGateway: from agentkit.llm.cache import LitellmCacheManager # 解析 KB caching_disabled(安全要求 c) - kb_caching_disabled = False + # 非 RAG 请求(kb_id=None)→ 默认启用缓存(无 KB 数据需保护)。 + # RAG 请求(kb_id!=None)→ fail-closed:默认禁用缓存,仅在 settings + # 明确返回 caching_disabled=False 时启用。防止 DB 异常时 fail-open + # 导致禁用缓存的 KB 数据泄漏到缓存。 + kb_caching_disabled = kb_id is not None if kb_id is not None: try: from agentkit.rag_platform.settings import get_settings_store @@ -142,8 +146,10 @@ class LLMGateway: settings = await get_settings_store().get_settings(kb_id) if settings is not None: kb_caching_disabled = settings.caching_disabled + # settings 为 None(KB 不存在)→ 保持 True(fail-closed) except Exception as e: logger.warning(f"Failed to read KB cache settings for kb_id={kb_id}: {e}") + # 读取异常 → 保持 True(fail-closed,禁用缓存) if self._cache_manager.should_cache(kb_caching_disabled, user_id): cache_key = self._cache_manager.build_cache_key( diff --git a/src/agentkit/mcp/server.py b/src/agentkit/mcp/server.py index 4195d5e..e8253ca 100644 --- a/src/agentkit/mcp/server.py +++ b/src/agentkit/mcp/server.py @@ -25,6 +25,30 @@ logger = logging.getLogger(__name__) # 认证由主 app 的 AuthMiddleware 处理,这里只做权限校验。 _mcp_member_auth = require_permission(Permission.CHAT) +# ponytail: 危险工具黑名单 — 这些工具依赖 chat confirmation 流程(WebSocket), +# 通过 MCP 暴露会绕过用户确认机制。黑名单需手动维护,新增危险工具需同步更新。 +# 天花板:若工具名动态生成或别名注册,黑名单可能漏判。 +# 升级路径:Tool 基类增加 requires_confirmation 元数据字段,按属性自动过滤。 +# U5c: 与 publisher.py 的 _DANGEROUS_TOOL_NAMES 保持一致(单一真相源)。 +_MCP_BLOCKED_TOOLS: frozenset[str] = frozenset( + { + "terminal", # 终端执行(与 ShellTool 协同)— 危险命令需 confirmation + "shell", # ShellTool — 危险命令需 confirmation + "file_write", # 文件写入 + "file_read", # 文件读取(可能泄露敏感配置) + "file_delete", # 文件删除 + } +) + + +def _serialize_tool(tool: Tool) -> dict[str, Any]: + """将 Tool 序列化为 MCP 协议响应字典。""" + return { + "name": tool.name, + "description": tool.description, + "inputSchema": tool.input_schema or {}, + } + def create_mcp_router( tool_registry: Any = None, @@ -52,7 +76,7 @@ def create_mcp_router( router = APIRouter(tags=["mcp"]) def _all_tools() -> list[Tool]: - """合并 ToolRegistry 与已发布工具。""" + """合并 ToolRegistry 与已发布工具,过滤危险工具。""" tools: list[Tool] = [] if tool_registry is not None: tools.extend(tool_registry.list_tools()) @@ -61,10 +85,14 @@ def create_mcp_router( tools.extend(published_tools_getter()) except Exception: logger.exception("published_tools_getter failed") - return tools + # 过滤危险工具 — 防止绕过 chat confirmation 流程 + return [t for t in tools if t.name not in _MCP_BLOCKED_TOOLS] def _find_tool(name: str) -> Tool | None: - """按名查找工具(先 registry 后已发布)。""" + """按名查找工具(先 registry 后已发布),危险工具返回 None。""" + if name in _MCP_BLOCKED_TOOLS: + logger.warning(f"MCP tool '{name}' blocked — requires chat confirmation flow") + return None if tool_registry is not None: try: return tool_registry.get(name) @@ -80,16 +108,7 @@ def create_mcp_router( async def list_tools(_user: dict = Depends(_mcp_member_auth)) -> dict[str, Any]: """列出所有可用的 MCP 工具。""" tools = _all_tools() - return { - "tools": [ - { - "name": t.name, - "description": t.description, - "inputSchema": t.input_schema or {}, - } - for t in tools - ] - } + return {"tools": [_serialize_tool(t) for t in tools]} @router.post("/tools/call") async def call_tool( @@ -152,16 +171,7 @@ def create_mcp_router( } elif method == "tools/list": tools = _all_tools() - result = { - "tools": [ - { - "name": t.name, - "description": t.description, - "inputSchema": t.input_schema or {}, - } - for t in tools - ] - } + result = {"tools": [_serialize_tool(t) for t in tools]} elif method == "tools/call": tool_name = params.get("name", "") arguments = params.get("arguments", {}) @@ -171,6 +181,13 @@ def create_mcp_router( "isError": True, "content": [{"type": "text", "text": "Tool not found"}], } + elif tool_name in _MCP_BLOCKED_TOOLS: + # 显式黑名单检查 — 与 legacy JSON-RPC 一致,提供清晰审计反馈 + logger.warning(f"MCP tool '{tool_name}' blocked — requires chat confirmation flow") + result = { + "isError": True, + "content": [{"type": "text", "text": f"Tool '{tool_name}' is blocked via MCP"}], + } else: tool = _find_tool(tool_name) if tool is None: @@ -240,16 +257,9 @@ class MCPServer: if self._tool_registry is None: return {"tools": []} tools = self._tool_registry.list_tools() - return { - "tools": [ - { - "name": t.name, - "description": t.description, - "inputSchema": t.input_schema or {}, - } - for t in tools - ] - } + # 过滤危险工具(与 create_mcp_router 一致) + tools = [t for t in tools if t.name not in _MCP_BLOCKED_TOOLS] + return {"tools": [_serialize_tool(t) for t in tools]} @app.post("/tools/call") async def call_tool(request: dict): @@ -259,6 +269,9 @@ class MCPServer: if not tool_name or self._tool_registry is None: return {"error": "Tool not specified or registry not configured"} + if tool_name in _MCP_BLOCKED_TOOLS: + return {"error": f"Tool '{tool_name}' is blocked via MCP"} + try: tool = self._tool_registry.get(tool_name) result = await tool.safe_execute(**arguments) @@ -302,12 +315,7 @@ class MCPServer: tools = self._tool_registry.list_tools() result = { "tools": [ - { - "name": t.name, - "description": t.description, - "inputSchema": t.input_schema or {}, - } - for t in tools + _serialize_tool(t) for t in tools if t.name not in _MCP_BLOCKED_TOOLS ] } elif method == "tools/call": @@ -319,6 +327,13 @@ class MCPServer: "isError": True, "content": [{"type": "text", "text": "Tool not found"}], } + elif tool_name in _MCP_BLOCKED_TOOLS: + result = { + "isError": True, + "content": [ + {"type": "text", "text": f"Tool '{tool_name}' is blocked via MCP"} + ], + } else: try: tool = self._tool_registry.get(tool_name) diff --git a/src/agentkit/server/app.py b/src/agentkit/server/app.py index d01bede..d34e0d3 100644 --- a/src/agentkit/server/app.py +++ b/src/agentkit/server/app.py @@ -975,9 +975,19 @@ def create_app( redis_url = server_config.memory["working"].get( "redis_url", "redis://localhost:6379" ) - redis_client = aioredis.from_url(redis_url, decode_responses=True) + # U5c: socket_timeout 防止单点 Redis 故障时网关请求挂死。 + redis_client = aioredis.from_url( + redis_url, + decode_responses=True, + socket_timeout=5.0, + socket_connect_timeout=5.0, + ) working = WorkingMemory(redis=redis_client) app.state.working_redis_client = redis_client + # U3:注入 Redis 到 channels 模块(SecretsStore 多 worker 共享) + from agentkit.server.routes.channels import _set_redis_client + + _set_redis_client(redis_client) if server_config.memory.get("semantic", {}).get("enabled"): sem_conf = server_config.memory["semantic"] diff --git a/src/agentkit/server/routes/channels.py b/src/agentkit/server/routes/channels.py index 22bf8e1..8b6033f 100644 --- a/src/agentkit/server/routes/channels.py +++ b/src/agentkit/server/routes/channels.py @@ -23,6 +23,7 @@ import re import time from collections import OrderedDict from typing import Any +from uuid import uuid4 from fastapi import APIRouter, Depends, HTTPException, Request from fastapi.responses import Response @@ -52,17 +53,29 @@ _CHANNEL_ID_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$") # 凭证字段名校验:字母数字下划线,1-64 字符 _SECRET_NAME_RE = re.compile(r"^[a-zA-Z0-9_]{1,64}$") -# ponytail: 模块级单例 store。当前为内存实现;PG 迁移后改为请求级 session。 -# 天花板:多进程部署下状态不共享;升级路径:注入 app.state.secrets_store, -# 由 lifespan 绑定到 PG session factory。 +# ponytail: 模块级单例 store。U3 后支持 Redis 后端(多 worker 共享)。 +# _redis_client 由 app.py lifespan 通过 _set_redis_client() 注入。 +# redis=None 时 SecretsStore 降级到内存字典(仅开发模式)。 _secrets_store: SecretsStore | None = None +_redis_client: Any = None + + +def _set_redis_client(redis: Any) -> None: + """注入 Redis 客户端(由 app.py lifespan 调用)。 + + 设置后下次 _get_secrets_store() 会用该 redis 客户端构造 SecretsStore。 + """ + global _redis_client + _redis_client = redis + # 重置 store 以便下次 _get_secrets_store 用新 redis 重建 + _reset_secrets_store() def _get_secrets_store() -> SecretsStore: - """获取全局 secrets store 单例(懒加载)。""" + """获取全局 secrets store 单例(懒加载,注入 redis)。""" global _secrets_store if _secrets_store is None: - _secrets_store = SecretsStore() + _secrets_store = SecretsStore(redis=_redis_client) return _secrets_store @@ -164,6 +177,71 @@ def _reset_webhook_state() -> None: _pending_webhook_tasks.clear() +# --------------------------------------------------------------------------- +# U4: Redis 后端 — nonce dedup / rate limit / backpressure +# +# 多 worker 部署下,内存状态不共享。Redis 后端确保 nonce dedup、限流、 +# backpressure 跨 worker 一致。redis=None 时降级到上方内存实现。 +# ponytail: 天花板 — _channels 渠道配置 dict 仍为进程内存储,多 worker 下 +# 配置变更不可见。升级路径:_channels 迁移到 Redis Hash + read-through cache。 +# --------------------------------------------------------------------------- + +# Redis key 前缀 — 遵循 codebase 约定(agentkit::) +_REDIS_NS = "agentkit:channels:webhook:" +_NONCE_KEY = _REDIS_NS + "nonce:{nonce}" +_RATELIMIT_KEY = _REDIS_NS + "ratelimit:{ip}" +_BACKPRESSURE_KEY = _REDIS_NS + "concurrent" +_BACKPRESSURE_TTL = 30 # 秒 — crash 安全网,防止计数不归零 + + +def _get_redis_from_request(request: Request) -> Any: + """从 request.app.state 提取 working_redis_client(可能为 None)。""" + return getattr(request.app.state, "working_redis_client", None) + + +async def _check_rate_limit_redis(redis: Any, client_ip: str) -> bool: + """Redis ZSET 滑动窗口限流。返回 True 放行,False 超限。""" + key = _RATELIMIT_KEY.format(ip=client_ip) + now = time.time() + # U5c: member 加 uuid 后缀避免同时间戳碰撞(3 评审员一致标记)。 + pipe = redis.pipeline() + pipe.zremrangebyscore(key, 0, now - _RATE_LIMIT_WINDOW) + pipe.zadd(key, {f"{now}:{uuid4().hex}": now}) + pipe.zcard(key) + pipe.expire(key, int(_RATE_LIMIT_WINDOW)) + _, _, count, _ = await pipe.execute() + return count <= _RATE_LIMIT_MAX + + +async def _check_nonce_dedup_redis(redis: Any, nonce: str) -> bool: + """Redis SET NX EX nonce 去重。返回 True 新 nonce,False 重复。""" + key = _NONCE_KEY.format(nonce=nonce) + # SET key 1 EX ttl NX — 原子操作,返回 True 表示新增成功 + is_new = await redis.set(key, "1", ex=int(_NONCE_TTL), nx=True) + return bool(is_new) + + +async def _acquire_backpressure_slot(redis: Any) -> bool: + """Redis INCR 共享并发计数。返回 True 获得槽位,False 超限。 + + 调用方在任务完成后必须调用 _release_backpressure_slot(redis)。 + pipeline 合并 INCR + EXPIRE 为单次往返。 + """ + pipe = redis.pipeline() + pipe.incr(_BACKPRESSURE_KEY) + pipe.expire(_BACKPRESSURE_KEY, _BACKPRESSURE_TTL) + current, _ = await pipe.execute() + if current > _WEBHOOK_MAX_CONCURRENT * 2: + await redis.decr(_BACKPRESSURE_KEY) + return False + return True + + +async def _release_backpressure_slot(redis: Any) -> None: + """释放并发槽位(DECR 计数器)。""" + await redis.decr(_BACKPRESSURE_KEY) + + def _validate_channel_id(channel_id: str) -> str: """校验渠道 ID,非法时抛 400。""" if not _CHANNEL_ID_RE.match(channel_id): @@ -413,9 +491,7 @@ async def _build_adapter(channel_id: str) -> MessageAdapter: store.get_secret(f"{channel_id}:token"), ) if not all([app_key, app_secret, robot_code]): - raise HTTPException( - status_code=500, detail=f"渠道 '{channel_id}' 缺少 dingtalk 凭证" - ) + raise HTTPException(status_code=500, detail=f"渠道 '{channel_id}' 缺少 dingtalk 凭证") adapter = DingTalkMessageAdapter( app_key=app_key, app_secret=app_secret, @@ -434,9 +510,7 @@ async def _build_adapter(channel_id: str) -> MessageAdapter: store.get_secret(f"{channel_id}:agent_id"), ) if not all([corp_id, corp_secret, token, encoding_aes_key, agent_id_raw]): - raise HTTPException( - status_code=500, detail=f"渠道 '{channel_id}' 缺少 wecom 凭证" - ) + raise HTTPException(status_code=500, detail=f"渠道 '{channel_id}' 缺少 wecom 凭证") try: agent_id = int(agent_id_raw) except (TypeError, ValueError) as exc: @@ -461,9 +535,7 @@ async def _build_adapter(channel_id: str) -> MessageAdapter: store.get_secret(f"{channel_id}:verification_token"), ) if not bot_token or not signing_secret: - raise HTTPException( - status_code=500, detail=f"渠道 '{channel_id}' 缺少 slack 凭证" - ) + raise HTTPException(status_code=500, detail=f"渠道 '{channel_id}' 缺少 slack 凭证") adapter = SlackMessageAdapter( bot_token=bot_token, signing_secret=signing_secret, @@ -495,9 +567,7 @@ async def close_all_adapters() -> None: _adapter_cache.clear() -async def _process_inbound_message( - app_state: Any, adapter: MessageAdapter, message: Any -) -> None: +async def _process_inbound_message(app_state: Any, adapter: MessageAdapter, message: Any) -> None: """后台处理入站消息 — 调用 chat 链路并通过适配器回复。 整个流程 try/except 包裹,任何异常仅记录日志,不向上抛出 @@ -520,9 +590,7 @@ async def _process_inbound_message( request_preprocessor = getattr(app_state, "request_preprocessor", None) llm_gateway = getattr(app_state, "llm_gateway", None) if request_preprocessor is None or llm_gateway is None: - logger.warning( - "app.state 缺少 request_preprocessor 或 llm_gateway — 跳过消息处理" - ) + logger.warning("app.state 缺少 request_preprocessor 或 llm_gateway — 跳过消息处理") return # 路由预处理 — IM 场景使用默认 agent,无需技能注册表 @@ -587,12 +655,22 @@ async def channel_webhook(channel_id: str, request: Request) -> Any: 5. URL verification — 飞书/Slack 返回 challenge;企微返回 XML 6. 解析消息 → 后台异步处理 → 立即返回 200 + U4:nonce dedup、限流、backpressure 优先使用 Redis(多 worker 共享), + redis=None 时降级到进程内内存实现。 + 企微通过 query 参数传递 ``msg_signature``/``timestamp``/``nonce``, 合并到 headers dict 供适配器读取。 """ + redis = _get_redis_from_request(request) client_ip = _get_client_ip(request) - if not _check_rate_limit(client_ip): - raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试") + + # 1. 限流 — Redis ZSET 滑动窗口(多 worker 共享),降级到内存 + if redis is not None: + if not await _check_rate_limit_redis(redis, client_ip): + raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试") + else: + if not _check_rate_limit(client_ip): + raise HTTPException(status_code=429, detail="请求过于频繁,请稍后重试") body = await request.body() @@ -607,10 +685,15 @@ async def channel_webhook(channel_id: str, request: Request) -> Any: if not await adapter.verify_signature(headers_dict, body): raise HTTPException(status_code=401, detail="签名校验失败") - # Nonce dedup(可选 — 若头不存在则跳过去重;仅飞书携带该头) + # 4. Nonce dedup — Redis SET NX EX(多 worker 共享),降级到内存 nonce = request.headers.get("x-lark-request-nonce") - if nonce and not _check_nonce_dedup(nonce): - return {"code": 0, "msg": "duplicate"} + if nonce: + if redis is not None: + if not await _check_nonce_dedup_redis(redis, nonce): + return {"code": 0, "msg": "duplicate"} + else: + if not _check_nonce_dedup(nonce): + return {"code": 0, "msg": "duplicate"} try: message = await adapter.receive_message(headers_dict, body) @@ -624,13 +707,27 @@ async def channel_webhook(channel_id: str, request: Request) -> Any: logger.warning("receive_message 解析失败 channel=%s: %s", channel_id, exc) return {"code": 0, "msg": "invalid_payload"} - # 异步处理 — 不阻塞 webhook 响应(平台要求快速返回 200) - # 持有 task 引用防止 GC 回收正在运行的后台任务 - # 有界化:超过 2x 并发上限时拒绝新任务(防突发流量下 set 无界增长) - if len(_pending_webhook_tasks) >= _WEBHOOK_MAX_CONCURRENT * 2: - logger.warning("webhook 后台任务积压 %d,拒绝新任务", len(_pending_webhook_tasks)) - raise HTTPException(status_code=429, detail="服务器繁忙,请稍后重试") - task = asyncio.create_task(_process_inbound_message(request.app.state, adapter, message)) + # 6. Backpressure + 异步分发 — Redis 共享计数器(多 worker),降级到内存 set + if redis is not None: + if not await _acquire_backpressure_slot(redis): + logger.warning("webhook Redis backpressure 超限,拒绝新任务") + raise HTTPException(status_code=429, detail="服务器繁忙,请稍后重试") + + async def _run_with_release() -> None: + """执行消息处理并在完成后释放 Redis 并发槽位。""" + try: + await _process_inbound_message(request.app.state, adapter, message) + finally: + await _release_backpressure_slot(redis) + + task = asyncio.create_task(_run_with_release()) + else: + # 降级模式:进程内 set 长度检查 + GC 引用持有 + if len(_pending_webhook_tasks) >= _WEBHOOK_MAX_CONCURRENT * 2: + logger.warning("webhook 后台任务积压 %d,拒绝新任务", len(_pending_webhook_tasks)) + raise HTTPException(status_code=429, detail="服务器繁忙,请稍后重试") + task = asyncio.create_task(_process_inbound_message(request.app.state, adapter, message)) + _pending_webhook_tasks.add(task) task.add_done_callback(_pending_webhook_tasks.discard) diff --git a/tests/unit/channels/test_secrets_redis.py b/tests/unit/channels/test_secrets_redis.py new file mode 100644 index 0000000..6180680 --- /dev/null +++ b/tests/unit/channels/test_secrets_redis.py @@ -0,0 +1,245 @@ +"""U3 — SecretsStore Redis 后端单元测试。 + +覆盖场景: +- Redis 后端 CRUD(set/get/delete/list_keys) +- Redis=None 降级到内存字典 +- 加密-存储-读取-解密往返(Redis 后端) +- Redis 连接异常 fail-closed(get_secret 返回 None) +- 多实例共享语义(两个 store 共用同一 Redis mock) + +不依赖真实 Redis 实例 — 使用 AsyncMock 模拟 redis.asyncio.Redis。 +""" + +from __future__ import annotations + +import base64 +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from agentkit.channels.secrets import KEY_SIZE, SecretsStore + + +# ── 辅助函数 ────────────────────────────────────────────── + + +def _make_master_key() -> bytes: + """确定性测试 master key(32 字节)。""" + return b"\x01" * KEY_SIZE + + +class _FakeRedis: + """极简 Redis mock — 仅模拟 SecretsStore 用到的方法。 + + 内部用 dict 存储,支持 set/get/delete/scan_iter。 + scan_iter 返回 async generator(匹配真实 Redis 接口)。 + """ + + def __init__(self): + self._data: dict[str, str] = {} + + async def set(self, key: str, value: str, **kwargs) -> bool: + self._data[key] = value + return True + + async def get(self, key: str) -> str | None: + return self._data.get(key) + + async def delete(self, key: str) -> int: + if key in self._data: + del self._data[key] + return 1 + return 0 + + async def scan_iter(self, match: str = "*"): + import fnmatch + + for k in self._data: + if fnmatch.fnmatch(k, match): + yield k + + +def _make_failing_redis() -> AsyncMock: + """构造一个所有方法都抛 ConnectionError 的 Redis mock。""" + redis = AsyncMock() + redis.get = AsyncMock(side_effect=ConnectionError("redis down")) + redis.set = AsyncMock(side_effect=ConnectionError("redis down")) + redis.delete = AsyncMock(side_effect=ConnectionError("redis down")) + redis.scan_iter = MagicMock(side_effect=ConnectionError("redis down")) + return redis + + +@pytest.fixture(autouse=True) +def _clean_env(monkeypatch): + """确保测试不在生产模式下运行。""" + monkeypatch.delenv("AGENTKIT_ENV", raising=False) + monkeypatch.delenv("AGENTKIT_MASTER_KEY", raising=False) + + +# ── Redis 后端 CRUD ────────────────────────────────────── + + +class TestRedisBackendCrud: + """Redis 后端的 set/get/delete/list_keys 操作。""" + + async def test_set_and_get_secret_with_redis(self): + """Redis 后端:写入后读取返回明文。""" + redis = _FakeRedis() + store = SecretsStore(master_key=_make_master_key(), redis=redis) + await store.set_secret("feishu:app_id", "cli_xxx") + assert await store.get_secret("feishu:app_id") == "cli_xxx" + + async def test_get_nonexistent_returns_none_with_redis(self): + """Redis 后端:读取不存在的 key 返回 None。""" + redis = _FakeRedis() + store = SecretsStore(master_key=_make_master_key(), redis=redis) + assert await store.get_secret("missing") is None + + async def test_set_overwrites_existing_with_redis(self): + """Redis 后端:同名 key 写入覆盖旧值。""" + redis = _FakeRedis() + store = SecretsStore(master_key=_make_master_key(), redis=redis) + await store.set_secret("k", "old") + await store.set_secret("k", "new") + assert await store.get_secret("k") == "new" + + async def test_delete_secret_with_redis(self): + """Redis 后端:删除凭证后不可读取。""" + redis = _FakeRedis() + store = SecretsStore(master_key=_make_master_key(), redis=redis) + await store.set_secret("k", "v") + assert await store.delete_secret("k") is True + assert await store.get_secret("k") is None + + async def test_delete_nonexistent_returns_false_with_redis(self): + """Redis 后端:删除不存在的 key 返回 False。""" + redis = _FakeRedis() + store = SecretsStore(master_key=_make_master_key(), redis=redis) + assert await store.delete_secret("missing") is False + + async def test_list_keys_with_redis(self): + """Redis 后端:列出全部凭证键。""" + redis = _FakeRedis() + store = SecretsStore(master_key=_make_master_key(), redis=redis) + await store.set_secret("feishu:a", "1") + await store.set_secret("feishu:b", "2") + await store.set_secret("dingtalk:c", "3") + keys = await store.list_keys() + assert set(keys) == {"feishu:a", "feishu:b", "dingtalk:c"} + + async def test_list_keys_with_prefix_with_redis(self): + """Redis 后端:按前缀过滤凭证键。""" + redis = _FakeRedis() + store = SecretsStore(master_key=_make_master_key(), redis=redis) + await store.set_secret("feishu:a", "1") + await store.set_secret("feishu:b", "2") + await store.set_secret("dingtalk:c", "3") + keys = await store.list_keys(prefix="feishu:") + assert set(keys) == {"feishu:a", "feishu:b"} + + +# ── Redis=None 降级模式 ───────────────────────────────── + + +class TestFallbackMode: + """redis=None 时降级到内存字典。""" + + async def test_fallback_set_and_get(self): + """降级模式:写入后读取返回明文。""" + store = SecretsStore(master_key=_make_master_key(), redis=None) + await store.set_secret("k", "v") + assert await store.get_secret("k") == "v" + + async def test_fallback_delete(self): + """降级模式:删除凭证后不可读取。""" + store = SecretsStore(master_key=_make_master_key(), redis=None) + await store.set_secret("k", "v") + assert await store.delete_secret("k") is True + assert await store.get_secret("k") is None + + async def test_fallback_list_keys_with_prefix(self): + """降级模式:按前缀过滤。""" + store = SecretsStore(master_key=_make_master_key(), redis=None) + await store.set_secret("feishu:a", "1") + await store.set_secret("dingtalk:b", "2") + keys = await store.list_keys(prefix="feishu:") + assert keys == ["feishu:a"] + + +# ── 加密往返 ───────────────────────────────────────────── + + +class TestRedisEncryptionRoundtrip: + """Redis 后端的加密-存储-读取-解密往返。""" + + async def test_stored_value_is_encrypted_in_redis(self): + """Redis 中存储的 value 字段为密文,非明文。""" + redis = _FakeRedis() + store = SecretsStore(master_key=_make_master_key(), redis=redis) + plaintext = "plaintext-token" + await store.set_secret("k", plaintext) + # 直接检查 Redis 中的原始值 + raw = redis._data["agentkit:secrets:k"] + assert plaintext not in raw + # base64 解码后的密文也不含明文 + import json + + entry = json.loads(raw) + ciphertext_bytes = base64.b64decode(entry["value"]) + assert plaintext.encode() not in ciphertext_bytes + + async def test_roundtrip_preserves_plaintext(self): + """加密-存储-读取-解密往返保持明文一致。""" + redis = _FakeRedis() + store = SecretsStore(master_key=_make_master_key(), redis=redis) + plaintext = "feishu-app-secret-12345" + await store.set_secret("key1", plaintext) + assert await store.get_secret("key1") == plaintext + + +# ── 多实例共享语义 ─────────────────────────────────────── + + +class TestMultiInstanceSharing: + """两个 store 共用同一 Redis → 多 worker 共享语义。""" + + async def test_two_stores_share_state_via_redis(self): + """worker A set_secret → worker B get_secret 能读到(共享 Redis)。""" + redis = _FakeRedis() + store_a = SecretsStore(master_key=_make_master_key(), redis=redis) + store_b = SecretsStore(master_key=_make_master_key(), redis=redis) + + await store_a.set_secret("shared:key", "value-from-a") + # worker B 通过同一 Redis 能读到 worker A 写入的凭证 + assert await store_b.get_secret("shared:key") == "value-from-a" + + async def test_delete_from_one_store_visible_to_other(self): + """worker A delete → worker B get 返回 None。""" + redis = _FakeRedis() + store_a = SecretsStore(master_key=_make_master_key(), redis=redis) + store_b = SecretsStore(master_key=_make_master_key(), redis=redis) + + await store_a.set_secret("k", "v") + assert await store_b.delete_secret("k") is True + assert await store_a.get_secret("k") is None + + +# ── Redis 连接异常 ─────────────────────────────────────── + + +class TestRedisConnectionFailure: + """Redis 连接失败时的 fail-closed 行为。""" + + async def test_get_secret_raises_on_redis_failure(self): + """Redis 异常时 get_secret 抛异常(fail-closed,不静默返回 None)。""" + redis = _make_failing_redis() + store = SecretsStore(master_key=_make_master_key(), redis=redis) + with pytest.raises(ConnectionError): + await store.get_secret("k") + + async def test_set_secret_raises_on_redis_failure(self): + """Redis 异常时 set_secret 抛异常(写入失败不静默降级)。""" + redis = _make_failing_redis() + store = SecretsStore(master_key=_make_master_key(), redis=redis) + with pytest.raises(ConnectionError): + await store.set_secret("k", "v") diff --git a/tests/unit/channels/test_webhook_redis_state.py b/tests/unit/channels/test_webhook_redis_state.py new file mode 100644 index 0000000..911a811 --- /dev/null +++ b/tests/unit/channels/test_webhook_redis_state.py @@ -0,0 +1,340 @@ +"""U4 — Channels webhook Redis 状态迁移单元测试。 + +覆盖场景: +- nonce dedup:首次 → True,重复 → False,TTL 过期 → 可再次使用 +- rate limit:窗口内 100 请求通过,第 101 → False,窗口滚动后重置 +- backpressure:并发 < 2x 上限 → True,>= 2x 上限 → False,释放后恢复 +- 降级模式:redis=None → 内存实现仍工作 + +不依赖真实 Redis — 使用 _FakeRedis 模拟 redis.asyncio.Redis 的子集。 +""" + +from __future__ import annotations + +import fnmatch +import time + +import pytest + +from agentkit.server.routes.channels import ( + _BACKPRESSURE_KEY, + _RATE_LIMIT_MAX, + _WEBHOOK_MAX_CONCURRENT, + _acquire_backpressure_slot, + _check_nonce_dedup, + _check_nonce_dedup_redis, + _check_rate_limit, + _check_rate_limit_redis, + _release_backpressure_slot, + _reset_webhook_state, +) + + +# ── _FakeRedis — 支持 SecretsStore + webhook 用到的操作 ─── + + +class _FakeRedis: + """极简 Redis mock,支持 string/zset/incr/pipeline/expire/scan_iter。 + + 内部数据结构: + - _strings: dict[str, str] — string 类型(set/get/delete) + - _zsets: dict[str, dict[str, float]] — sorted set(zadd/zcard/zremrangebyscore) + - _expires: dict[str, float] — key -> 过期 monotonic 时间戳 + """ + + def __init__(self): + self._strings: dict[str, str] = {} + self._zsets: dict[str, dict[str, float]] = {} + self._expires: dict[str, float] = {} + self._incr_values: dict[str, int] = {} + + def _is_expired(self, key: str) -> bool: + exp = self._expires.get(key) + return exp is not None and time.monotonic() >= exp + + def _cleanup_if_expired(self, key: str) -> None: + if self._is_expired(key): + self._strings.pop(key, None) + self._zsets.pop(key, None) + self._incr_values.pop(key, None) + self._expires.pop(key, None) + + # ── string 操作 ─────────────────────────────────── + + async def set(self, key: str, value: str, *, ex: int | None = None, nx: bool = False) -> bool: + self._cleanup_if_expired(key) + if nx and key in self._strings: + return False + self._strings[key] = value + if ex is not None: + self._expires[key] = time.monotonic() + ex + return True + + async def get(self, key: str) -> str | None: + self._cleanup_if_expired(key) + return self._strings.get(key) + + async def delete(self, key: str) -> int: + existed = key in self._strings or key in self._zsets + self._strings.pop(key, None) + self._zsets.pop(key, None) + self._expires.pop(key, None) + return 1 if existed else 0 + + # ── zset 操作 ───────────────────────────────────── + + async def zadd(self, key: str, mapping: dict[str, float]) -> int: + self._cleanup_if_expired(key) + zset = self._zsets.setdefault(key, {}) + added = 0 + for member, score in mapping.items(): + if member not in zset: + added += 1 + zset[member] = score + return added + + async def zcard(self, key: str) -> int: + self._cleanup_if_expired(key) + return len(self._zsets.get(key, {})) + + async def zremrangebyscore(self, key: str, min_score: float, max_score: float) -> int: + self._cleanup_if_expired(key) + zset = self._zsets.get(key, {}) + removed = 0 + for member in list(zset): + if min_score <= zset[member] <= max_score: + del zset[member] + removed += 1 + return removed + + # ── incr/decr ───────────────────────────────────── + + async def incr(self, key: str) -> int: + self._cleanup_if_expired(key) + self._incr_values[key] = self._incr_values.get(key, 0) + 1 + return self._incr_values[key] + + async def decr(self, key: str) -> int: + self._cleanup_if_expired(key) + self._incr_values[key] = self._incr_values.get(key, 0) - 1 + return self._incr_values[key] + + # ── expire ──────────────────────────────────────── + + async def expire(self, key: str, seconds: int) -> bool: + if key in self._strings or key in self._zsets or key in self._incr_values: + self._expires[key] = time.monotonic() + seconds + return True + return False + + # ── pipeline ────────────────────────────────────── + + def pipeline(self) -> "_FakePipeline": + return _FakePipeline(self) + + # ── scan_iter ───────────────────────────────────── + + async def scan_iter(self, match: str = "*"): + for k in list(self._strings): + self._cleanup_if_expired(k) + if k in self._strings and fnmatch.fnmatch(k, match): + yield k + + +class _FakePipeline: + """模拟 redis pipeline — 收集命令,execute() 返回结果列表。""" + + def __init__(self, redis: _FakeRedis): + self._redis = redis + self._commands: list[tuple[str, tuple, dict]] = [] + + def zremrangebyscore(self, key, min_score, max_score): + self._commands.append(("zremrangebyscore", (key, min_score, max_score), {})) + return self + + def zadd(self, key, mapping): + self._commands.append(("zadd", (key, mapping), {})) + return self + + def zcard(self, key): + self._commands.append(("zcard", (key,), {})) + return self + + def expire(self, key, seconds): + self._commands.append(("expire", (key, seconds), {})) + return self + + def incr(self, key): + self._commands.append(("incr", (key,), {})) + return self + + async def execute(self) -> list: + results = [] + for cmd, args, kwargs in self._commands: + method = getattr(self._redis, cmd) + result = await method(*args, **kwargs) + results.append(result) + return results + + +# ── Fixtures ───────────────────────────────────────────── + + +@pytest.fixture(autouse=True) +def _reset_state(): + """每个测试前重置内存状态。""" + _reset_webhook_state() + yield + _reset_webhook_state() + + +# ── nonce dedup (Redis) ────────────────────────────────── + + +class TestNonceDedupRedis: + """Redis 后端 nonce 去重。""" + + async def test_first_nonce_returns_true(self): + """首次 nonce → True(新 nonce,应处理)。""" + redis = _FakeRedis() + assert await _check_nonce_dedup_redis(redis, "nonce-001") is True + + async def test_duplicate_nonce_returns_false(self): + """重复 nonce → False(跳过处理)。""" + redis = _FakeRedis() + assert await _check_nonce_dedup_redis(redis, "nonce-001") is True + assert await _check_nonce_dedup_redis(redis, "nonce-001") is False + + async def test_different_nonces_both_return_true(self): + """不同 nonce 各自首次 → True。""" + redis = _FakeRedis() + assert await _check_nonce_dedup_redis(redis, "nonce-a") is True + assert await _check_nonce_dedup_redis(redis, "nonce-b") is True + + async def test_nonce_ttl_expiry_allows_reuse(self): + """TTL 过期后相同 nonce 可再次使用。""" + redis = _FakeRedis() + # 首次写入 + assert await _check_nonce_dedup_redis(redis, "nonce-exp") is True + # 模拟 TTL 过期 — 手动清除过期标记 + redis._expires.clear() + redis._strings.clear() + # 过期后相同 nonce 可再次使用 + assert await _check_nonce_dedup_redis(redis, "nonce-exp") is True + + +# ── nonce dedup (内存降级) ─────────────────────────────── + + +class TestNonceDedupFallback: + """redis=None 时内存 nonce 去重仍工作。""" + + def test_fallback_first_nonce_returns_true(self): + """内存模式:首次 nonce → True。""" + assert _check_nonce_dedup("nonce-fb-1") is True + + def test_fallback_duplicate_returns_false(self): + """内存模式:重复 nonce → False。""" + assert _check_nonce_dedup("nonce-fb-2") is True + assert _check_nonce_dedup("nonce-fb-2") is False + + +# ── rate limit (Redis) ─────────────────────────────────── + + +class TestRateLimitRedis: + """Redis 后端滑动窗口限流。""" + + async def test_under_limit_returns_true(self): + """窗口内未超限 → True。""" + redis = _FakeRedis() + for _ in range(_RATE_LIMIT_MAX): + assert await _check_rate_limit_redis(redis, "1.2.3.4") is True + + async def test_over_limit_returns_false(self): + """超过窗口上限 → False。""" + redis = _FakeRedis() + for _ in range(_RATE_LIMIT_MAX): + await _check_rate_limit_redis(redis, "1.2.3.4") + # 第 _RATE_LIMIT_MAX + 1 次应被拒绝 + assert await _check_rate_limit_redis(redis, "1.2.3.4") is False + + async def test_different_ips_independent(self): + """不同 IP 的限流独立。""" + redis = _FakeRedis() + for _ in range(_RATE_LIMIT_MAX): + await _check_rate_limit_redis(redis, "1.1.1.1") + # IP 1.1.1.1 已满,但 2.2.2.2 仍可通过 + assert await _check_rate_limit_redis(redis, "2.2.2.2") is True + + async def test_window_reset_after_expiry(self): + """窗口过期后计数重置。""" + redis = _FakeRedis() + for _ in range(_RATE_LIMIT_MAX): + await _check_rate_limit_redis(redis, "3.3.3.3") + assert await _check_rate_limit_redis(redis, "3.3.3.3") is False + # 模拟窗口过期 — 清除 zset 数据 + redis._zsets.clear() + redis._expires.clear() + # 过期后可再次通过 + assert await _check_rate_limit_redis(redis, "3.3.3.3") is True + + +# ── rate limit (内存降级) ──────────────────────────────── + + +class TestRateLimitFallback: + """redis=None 时内存限流仍工作。""" + + def test_fallback_under_limit(self): + """内存模式:未超限 → True。""" + for _ in range(_RATE_LIMIT_MAX): + assert _check_rate_limit("4.4.4.4") is True + + def test_fallback_over_limit(self): + """内存模式:超限 → False。""" + for _ in range(_RATE_LIMIT_MAX): + _check_rate_limit("5.5.5.5") + assert _check_rate_limit("5.5.5.5") is False + + +# ── backpressure (Redis) ───────────────────────────────── + + +class TestBackpressureRedis: + """Redis 后端共享并发计数器。""" + + async def test_under_limit_returns_true(self): + """并发 < 2x 上限 → True。""" + redis = _FakeRedis() + for _ in range(_WEBHOOK_MAX_CONCURRENT * 2): + assert await _acquire_backpressure_slot(redis) is True + + async def test_over_limit_returns_false(self): + """并发 >= 2x 上限 → False。""" + redis = _FakeRedis() + for _ in range(_WEBHOOK_MAX_CONCURRENT * 2): + await _acquire_backpressure_slot(redis) + # 超过 2x 上限应拒绝 + assert await _acquire_backpressure_slot(redis) is False + + async def test_release_restores_slot(self): + """释放后槽位恢复,可再次获取。""" + redis = _FakeRedis() + # 获取到上限 + for _ in range(_WEBHOOK_MAX_CONCURRENT * 2): + await _acquire_backpressure_slot(redis) + # 超限 + assert await _acquire_backpressure_slot(redis) is False + # 释放一个 + await _release_backpressure_slot(redis) + # 可再次获取 + assert await _acquire_backpressure_slot(redis) is True + + async def test_release_decrements_counter(self): + """release 后计数器递减。""" + redis = _FakeRedis() + await _acquire_backpressure_slot(redis) + assert redis._incr_values[_BACKPRESSURE_KEY] == 1 + await _release_backpressure_slot(redis) + assert redis._incr_values[_BACKPRESSURE_KEY] == 0 diff --git a/tests/unit/llm/test_gateway_cache_failclosed.py b/tests/unit/llm/test_gateway_cache_failclosed.py new file mode 100644 index 0000000..374e4be --- /dev/null +++ b/tests/unit/llm/test_gateway_cache_failclosed.py @@ -0,0 +1,182 @@ +"""U1 — Gateway KB cache fail-closed 行为测试。 + +验证安全要求 R1:KB settings 读取失败时必须 fail-closed(禁用缓存), +不得 fail-open(默认启用缓存)。 + +覆盖场景: +1. settings 正常读取 caching_disabled=False → 缓存启用 +2. settings 正常读取 caching_disabled=True → 缓存禁用 +3. get_settings_store() 抛异常 → fail-closed,缓存禁用 +4. settings 返回 None(KB 不存在)→ fail-closed,缓存禁用 +5. kb_id=None(非 RAG 请求)→ 不查 settings,缓存正常启用 +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +from agentkit.llm.gateway import LLMGateway +from agentkit.llm.protocol import LLMRequest, LLMResponse, TokenUsage + + +def _make_response() -> LLMResponse: + """构造最小 LLMResponse。""" + return LLMResponse( + content="ok", + model="test-model", + usage=TokenUsage(prompt_tokens=5, completion_tokens=3), + ) + + +def _make_gateway_with_cache() -> LLMGateway: + """构造带 mock 缓存管理器的 LLMGateway(避免 litellm 依赖)。 + + gateway.chat() 调用 LitellmCacheManager.cache_params_for_hit/no_cache(类静态方法), + 仅 should_cache 和 build_cache_key 通过实例调用 — mock 这两个即可。 + """ + gateway = LLMGateway() # 不启用真实 cache(litellm 可能未安装) + mock_manager = MagicMock() + # should_cache: kb_caching_disabled=True 或 user_id=None 时返回 False + mock_manager.should_cache = lambda kb_disabled, uid: not kb_disabled and uid is not None + mock_manager.build_cache_key = MagicMock(return_value="mock_cache_key") + mock_manager.record_cache_result = MagicMock() + gateway._cache_manager = mock_manager + return gateway + + +def _register_mock_provider(gateway: LLMGateway) -> MagicMock: + """注册 mock provider,返回带 cache 参数的 capture。 + + 模型格式 "test/test-model" → _resolve_model 按 "/" 分割为 provider="test" + model="test-model"。 + """ + provider = MagicMock() + provider.chat = AsyncMock(return_value=_make_response()) + gateway.register_provider("test", provider) + return provider + + +_MODEL = "test/test-model" + + +def _get_cache_arg(provider: MagicMock) -> dict: + """从 provider.chat 调用中提取 cache 参数。""" + call_args = provider.chat.call_args + # provider.chat(req) — req 是第一个位置参数 + req: LLMRequest = call_args.args[0] + return req._cache or {} + + +class TestGatewayCacheFailClosed: + """U1 — KB settings 读取异常时 fail-closed。""" + + async def test_settings_caching_false_enables_cache(self): + """settings 正常读取 caching_disabled=False → cache_key 传入 provider(启用缓存)。""" + gateway = _make_gateway_with_cache() + provider = _register_mock_provider(gateway) + + mock_settings = SimpleNamespace(caching_disabled=False) + with patch("agentkit.rag_platform.settings.get_settings_store") as mock_get_store: + mock_store = MagicMock() + mock_store.get_settings = AsyncMock(return_value=mock_settings) + mock_get_store.return_value = mock_store + + await gateway.chat( + messages=[{"role": "user", "content": "hi"}], + model=_MODEL, + user_id="u1", + kb_id="kb1", + kb_acl_hash="acl1", + ) + + cache_arg = _get_cache_arg(provider) + assert "cache_key" in cache_arg, f"Expected cache_key (cache enabled), got {cache_arg}" + + async def test_settings_caching_true_disables_cache(self): + """settings 正常读取 caching_disabled=True → no-cache 传入 provider(禁用缓存)。""" + gateway = _make_gateway_with_cache() + provider = _register_mock_provider(gateway) + + mock_settings = SimpleNamespace(caching_disabled=True) + with patch("agentkit.rag_platform.settings.get_settings_store") as mock_get_store: + mock_store = MagicMock() + mock_store.get_settings = AsyncMock(return_value=mock_settings) + mock_get_store.return_value = mock_store + + await gateway.chat( + messages=[{"role": "user", "content": "hi"}], + model=_MODEL, + user_id="u1", + kb_id="kb1", + ) + + cache_arg = _get_cache_arg(provider) + assert cache_arg.get("no-cache") is True, f"Expected no-cache=True, got {cache_arg}" + + async def test_settings_exception_fail_closed(self): + """get_settings_store() 抛异常 → fail-closed(no-cache)。""" + gateway = _make_gateway_with_cache() + provider = _register_mock_provider(gateway) + + with patch("agentkit.rag_platform.settings.get_settings_store") as mock_get_store: + mock_store = MagicMock() + mock_store.get_settings = AsyncMock(side_effect=RuntimeError("DB down")) + mock_get_store.return_value = mock_store + + await gateway.chat( + messages=[{"role": "user", "content": "hi"}], + model=_MODEL, + user_id="u1", + kb_id="kb1", + ) + + cache_arg = _get_cache_arg(provider) + assert cache_arg.get("no-cache") is True, ( + f"fail-closed: 读取异常应禁用缓存,但 got {cache_arg}" + ) + + async def test_settings_none_fail_closed(self): + """settings 返回 None(KB 不存在)→ fail-closed(no-cache)。""" + gateway = _make_gateway_with_cache() + provider = _register_mock_provider(gateway) + + with patch("agentkit.rag_platform.settings.get_settings_store") as mock_get_store: + mock_store = MagicMock() + mock_store.get_settings = AsyncMock(return_value=None) + mock_get_store.return_value = mock_store + + await gateway.chat( + messages=[{"role": "user", "content": "hi"}], + model=_MODEL, + user_id="u1", + kb_id="kb_nonexistent", + ) + + cache_arg = _get_cache_arg(provider) + assert cache_arg.get("no-cache") is True, ( + f"Expected no-cache for None settings, got {cache_arg}" + ) + + async def test_no_kb_id_skips_settings_lookup(self): + """kb_id=None(非 RAG 请求)→ 不查 settings,缓存正常启用。""" + gateway = _make_gateway_with_cache() + provider = _register_mock_provider(gateway) + + with patch("agentkit.rag_platform.settings.get_settings_store") as mock_get_store: + mock_store = MagicMock() + mock_store.get_settings = AsyncMock() + mock_get_store.return_value = mock_store + + await gateway.chat( + messages=[{"role": "user", "content": "hi"}], + model=_MODEL, + user_id="u1", + # kb_id 不传 + ) + + # 不应查询 settings + mock_store.get_settings.assert_not_called() + + # 缓存应启用(cache_key 传入) + cache_arg = _get_cache_arg(provider) + assert "cache_key" in cache_arg, f"Expected cache_key (no kb_id), got {cache_arg}" diff --git a/tests/unit/mcp/test_server_dangerous_tools.py b/tests/unit/mcp/test_server_dangerous_tools.py new file mode 100644 index 0000000..6bece3e --- /dev/null +++ b/tests/unit/mcp/test_server_dangerous_tools.py @@ -0,0 +1,269 @@ +"""U2 — MCP 危险工具黑名单过滤单元测试。 + +覆盖 6 个暴露路径: +- create_mcp_router() REST: GET /tools/list, POST /tools/call +- create_mcp_router() JSON-RPC: tools/list, tools/call +- legacy MCPServer REST: GET /tools/list, POST /tools/call +- legacy MCPServer JSON-RPC: tools/list, tools/call + +黑名单工具(terminal/shell/file_write/file_read/file_delete)依赖 chat +confirmation 流程(WebSocket),通过 MCP 暴露会绕过用户确认。 +""" + +from __future__ import annotations + +import warnings +from typing import Any +from unittest.mock import AsyncMock, MagicMock + +import httpx +from fastapi import FastAPI + +from agentkit.mcp.server import MCPServer, create_mcp_router +from agentkit.server.auth.middleware import AuthMiddleware + +# 与 test_server_auth.py 共用的测试凭据 — 仅限单元测试。 +JWT_SECRET = "u13-test-jwt-secret-xxxxxxxxxxxxx" +API_KEY = "u13-test-api-key-yyy" + + +# ── 辅助函数 ────────────────────────────────────────────── + + +def _make_mock_tool( + name: str, + description: str = "", + result: str = "ok", +) -> MagicMock: + """构造一个 mock 工具,模拟 Tool 接口。""" + tool = MagicMock() + tool.name = name + tool.description = description + tool.input_schema = {"type": "object", "properties": {}} + tool.safe_execute = AsyncMock(return_value=result) + return tool + + +def _make_mock_registry(tools: list) -> MagicMock: + """构造一个 mock ToolRegistry,支持 list_tools() 和 get(name)。""" + registry = MagicMock() + registry.list_tools.return_value = tools + + def _get(name: str): + for t in tools: + if t.name == name: + return t + raise KeyError(name) + + registry.get = _get + return registry + + +def _make_registry_with_safe_and_dangerous() -> MagicMock: + """构造含 1 安全 + 5 危险工具的 registry。 + + 危险工具名与 _MCP_BLOCKED_TOOLS 完全对应,验证黑名单全覆盖。 + """ + return _make_mock_registry( + [ + _make_mock_tool("echo", "safe tool", "echo: hi"), + _make_mock_tool("shell", "shell tool", "should-not-reach"), + _make_mock_tool("file_write", "write tool", "should-not-reach"), + _make_mock_tool("file_read", "read tool", "should-not-reach"), + _make_mock_tool("file_delete", "delete tool", "should-not-reach"), + _make_mock_tool("terminal", "terminal tool", "should-not-reach"), + ] + ) + + +def _make_app(tool_registry: Any = None) -> FastAPI: + """构造测试用 FastAPI app:挂载 MCP router + AuthMiddleware。""" + app = FastAPI() + app.state.tool_registry = tool_registry + app.add_middleware(AuthMiddleware, jwt_secret=JWT_SECRET, api_key=API_KEY) + mcp_router = create_mcp_router(tool_registry=tool_registry) + app.include_router(mcp_router, prefix="/api/v1/mcp") + return app + + +def _make_legacy_app(tool_registry: Any = None) -> FastAPI: + """构造 legacy MCPServer app(无认证,过滤 DeprecationWarning)。""" + server = MCPServer(tool_registry=tool_registry) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + return server._create_app() + + +_BLOCKED = {"terminal", "shell", "file_write", "file_read", "file_delete"} +_HEADERS = {"X-API-Key": API_KEY} + + +# ── create_mcp_router() REST 端点 ───────────────────────── + + +class TestCreateMcpRouterRestFilter: + """create_mcp_router() 的 REST 端点黑名单过滤。""" + + async def test_rest_tools_list_excludes_blocked(self): + """GET /tools/list 不返回黑名单工具。""" + app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous()) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/api/v1/mcp/tools/list", headers=_HEADERS) + assert resp.status_code == 200 + names = {t["name"] for t in resp.json()["tools"]} + assert names == {"echo"} + assert not (names & _BLOCKED) + + async def test_rest_tools_call_safe_tool_succeeds(self): + """POST /tools/call 安全工具 → 200 + 结果。""" + app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous()) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/api/v1/mcp/tools/call", + json={"name": "echo", "arguments": {}}, + headers=_HEADERS, + ) + assert resp.status_code == 200 + assert "echo: hi" in resp.json()["content"][0]["text"] + + async def test_rest_tools_call_shell_returns_404(self): + """POST /tools/call shell → 404(黑名单)。""" + app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous()) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/api/v1/mcp/tools/call", + json={"name": "shell", "arguments": {"cmd": "rm -rf /"}}, + headers=_HEADERS, + ) + assert resp.status_code == 404 + + async def test_rest_tools_call_file_delete_returns_404(self): + """POST /tools/call file_delete → 404(黑名单)。""" + app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous()) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/api/v1/mcp/tools/call", + json={"name": "file_delete", "arguments": {"path": "/etc/passwd"}}, + headers=_HEADERS, + ) + assert resp.status_code == 404 + + +# ── create_mcp_router() JSON-RPC 端点 ───────────────────── + + +class TestCreateMcpRouterJsonRpcFilter: + """create_mcp_router() 的 JSON-RPC 端点黑名单过滤。""" + + async def test_jsonrpc_tools_list_excludes_blocked(self): + """JSON-RPC tools/list 不包含黑名单工具。""" + app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous()) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/api/v1/mcp/", + json={"jsonrpc": "2.0", "method": "tools/list", "id": 1}, + headers=_HEADERS, + ) + assert resp.status_code == 200 + body = resp.json() + names = {t["name"] for t in body["result"]["tools"]} + assert names == {"echo"} + assert not (names & _BLOCKED) + + async def test_jsonrpc_tools_call_shell_returns_iserror(self): + """JSON-RPC tools/call shell → isError=True。""" + app = _make_app(tool_registry=_make_registry_with_safe_and_dangerous()) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/api/v1/mcp/", + json={ + "jsonrpc": "2.0", + "method": "tools/call", + "params": {"name": "shell", "arguments": {"cmd": "rm -rf /"}}, + "id": 2, + }, + headers=_HEADERS, + ) + assert resp.status_code == 200 + result = resp.json()["result"] + assert result.get("isError") is True + assert "blocked" in result["content"][0]["text"].lower() + + +# ── legacy MCPServer REST 端点 ──────────────────────────── + + +class TestLegacyMcpServerRestFilter: + """legacy MCPServer(独立 app)的 REST 端点黑名单过滤。""" + + async def test_legacy_rest_tools_list_excludes_blocked(self): + """legacy GET /tools/list 不返回黑名单工具。""" + app = _make_legacy_app(tool_registry=_make_registry_with_safe_and_dangerous()) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get("/tools/list") + assert resp.status_code == 200 + names = {t["name"] for t in resp.json()["tools"]} + assert names == {"echo"} + assert not (names & _BLOCKED) + + async def test_legacy_rest_tools_call_shell_returns_error(self): + """legacy POST /tools/call shell → 返回 error(blocked)。""" + app = _make_legacy_app(tool_registry=_make_registry_with_safe_and_dangerous()) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/tools/call", + json={"name": "shell", "arguments": {"cmd": "rm -rf /"}}, + ) + assert resp.status_code == 200 + body = resp.json() + assert "error" in body + assert "blocked" in body["error"].lower() + + +# ── legacy MCPServer JSON-RPC 端点 ──────────────────────── + + +class TestLegacyMcpServerJsonRpcFilter: + """legacy MCPServer 的 JSON-RPC 端点黑名单过滤。""" + + async def test_legacy_jsonrpc_tools_list_excludes_blocked(self): + """legacy JSON-RPC tools/list 不包含黑名单工具。""" + app = _make_legacy_app(tool_registry=_make_registry_with_safe_and_dangerous()) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/", + json={"jsonrpc": "2.0", "method": "tools/list", "id": 1}, + ) + assert resp.status_code == 200 + body = resp.json() + names = {t["name"] for t in body["result"]["tools"]} + assert names == {"echo"} + assert not (names & _BLOCKED) + + async def test_legacy_jsonrpc_tools_call_shell_returns_iserror(self): + """legacy JSON-RPC tools/call shell → isError=True。""" + app = _make_legacy_app(tool_registry=_make_registry_with_safe_and_dangerous()) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post( + "/", + json={ + "jsonrpc": "2.0", + "method": "tools/call", + "params": {"name": "shell", "arguments": {"cmd": "rm -rf /"}}, + "id": 2, + }, + ) + assert resp.status_code == 200 + result = resp.json()["result"] + assert result.get("isError") is True + assert "blocked" in result["content"][0]["text"].lower()