From 5572387c01594809ac68f2efef1a93b19ffb44ee Mon Sep 17 00:00:00 2001 From: chiguyong Date: Thu, 25 Jun 2026 20:13:37 +0800 Subject: [PATCH] =?UTF-8?q?feat(channels):=20U10=20=E2=80=94=20message=20a?= =?UTF-8?q?dapter=20ABC=20+=20AES-256-GCM=20secrets=20store=20+=20channel?= =?UTF-8?q?=20CRUD=20routes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 2 + src/agentkit/channels/__init__.py | 27 ++ src/agentkit/channels/base.py | 79 ++++++ src/agentkit/channels/secrets.py | 172 ++++++++++++ src/agentkit/server/app.py | 2 + src/agentkit/server/routes/__init__.py | 2 + src/agentkit/server/routes/channels.py | 253 +++++++++++++++++ tests/unit/channels/__init__.py | 0 tests/unit/channels/test_base.py | 368 +++++++++++++++++++++++++ tests/unit/channels/test_secrets.py | 245 ++++++++++++++++ 10 files changed, 1150 insertions(+) create mode 100644 src/agentkit/channels/__init__.py create mode 100644 src/agentkit/channels/base.py create mode 100644 src/agentkit/channels/secrets.py create mode 100644 src/agentkit/server/routes/channels.py create mode 100644 tests/unit/channels/__init__.py create mode 100644 tests/unit/channels/test_base.py create mode 100644 tests/unit/channels/test_secrets.py diff --git a/pyproject.toml b/pyproject.toml index 9f90092..5244566 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,8 @@ dependencies = [ "pyjwt>=2.8", "bcrypt>=4.0", "aiosqlite>=0.20", + # 加密 secrets store (U10 — 多端消息适配器 AES-256-GCM) + "cryptography>=42.0", # Calendar & schedule (RRULE expansion) "python-dateutil>=2.9", # Calendar ICS import/export (U8) diff --git a/src/agentkit/channels/__init__.py b/src/agentkit/channels/__init__.py new file mode 100644 index 0000000..a1a86af --- /dev/null +++ b/src/agentkit/channels/__init__.py @@ -0,0 +1,27 @@ +"""多端消息适配器子系统 (U10)。 + +提供渠道适配器 ABC 与加密 secrets store 基础设施: +- :class:`MessageAdapter` — 所有平台适配器的抽象基类 +- :class:`SecretsStore` — AES-256-GCM 加密凭证存储 (KTD8) +""" + +from __future__ import annotations + +from agentkit.channels.base import ( + ChannelType, + IncomingMessage, + MessageAdapter, + MessageDirection, + OutgoingMessage, +) +from agentkit.channels.secrets import SecretEntry, SecretsStore + +__all__ = [ + "ChannelType", + "IncomingMessage", + "MessageAdapter", + "MessageDirection", + "OutgoingMessage", + "SecretEntry", + "SecretsStore", +] diff --git a/src/agentkit/channels/base.py b/src/agentkit/channels/base.py new file mode 100644 index 0000000..ebda830 --- /dev/null +++ b/src/agentkit/channels/base.py @@ -0,0 +1,79 @@ +"""消息适配器 ABC — 所有渠道适配器的基类。 + +与 KBAdapter 的 authenticate()/close() 生命周期方法对齐: +verify_signature() 对应 authenticate()。 +""" + +from __future__ import annotations + +import abc +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + + +class ChannelType(str, Enum): + """支持的消息平台渠道。""" + + FEISHU = "feishu" + DINGTALK = "dingtalk" + WECOM = "wecom" + SLACK = "slack" + + +class MessageDirection(str, Enum): + """消息流向。""" + + INBOUND = "inbound" + OUTBOUND = "outbound" + + +@dataclass +class IncomingMessage: + """标准化入站消息 — 所有渠道适配器将平台特定格式转换为此结构。""" + + channel: ChannelType + platform_message_id: str + user_id: str # 平台用户 ID + chat_id: str # 群组/会话 ID + content: str # 消息文本 + raw_event: dict[str, Any] = field(default_factory=dict) # 原始事件 + timestamp: str = "" + + +@dataclass +class OutgoingMessage: + """标准化出站消息。""" + + channel: ChannelType + chat_id: str + content: str + reply_to_message_id: str | None = None + + +class MessageAdapter(abc.ABC): + """消息适配器 ABC。 + + 生命周期: + __init__ → verify_signature() → receive_message() → send_message() → close() + + 子类必须实现全部抽象方法。verify_signature 失败时调用方应拒绝处理 + (webhook 端点 fail-closed:Redis 不可用或签名校验失败均返回 503/401, + 不可跳过 nonce dedup 直接处理消息)。 + """ + + @abc.abstractmethod + async def verify_signature(self, headers: dict[str, str], body: bytes) -> bool: + """验证平台签名/token。返回 True 表示请求可信。""" + + @abc.abstractmethod + async def receive_message(self, headers: dict[str, str], body: bytes) -> IncomingMessage: + """从 webhook 请求中解析标准化消息。""" + + @abc.abstractmethod + async def send_message(self, message: OutgoingMessage) -> bool: + """向平台发送消息。返回 True 表示发送成功。""" + + @abc.abstractmethod + async def close(self) -> None: + """释放资源(HTTP 客户端、连接池等)。""" diff --git a/src/agentkit/channels/secrets.py b/src/agentkit/channels/secrets.py new file mode 100644 index 0000000..6213e14 --- /dev/null +++ b/src/agentkit/channels/secrets.py @@ -0,0 +1,172 @@ +"""加密 DB 列 secrets store — AES-256-GCM (KTD8)。 + +KTD8 关键决策: +- 每行使用随机 96-bit nonce + HKDF with per-row salt 派生 per-row 密钥。 +- 生产环境 master key 必须来自云 KMS;环境变量 AGENTKIT_MASTER_KEY 仅作开发 fallback。 +- 应用启动 guard:生产模式下若仅检测到环境变量 key 应拒绝启动 + (由 server 启动钩子实现,本模块提供 ``assert_production_master_key`` 辅助)。 +- Master key 轮换采用双密钥窗口策略(key_id 字段标记)。 + +当前为内存存储实现,PG 迁移预留接口(``_store: dict`` → 未来替换为 ORM session)。 +""" + +from __future__ import annotations + +import base64 +import logging +import os + +from pydantic import BaseModel, ConfigDict + +logger = logging.getLogger(__name__) + +# AES-256-GCM 参数 +NONCE_SIZE = 12 # 96-bit nonce (GCM 推荐长度) +KEY_SIZE = 32 # 256-bit key +SALT_SIZE = 16 # per-row salt +HKDF_INFO = b"agentkit-channels-secrets-v1" + +# 生产环境标记:当此环境变量存在时启用生产 guard。 +_PRODUCTION_ENV_VAR = "AGENTKIT_ENV" +_PRODUCTION_VALUE = "production" + + +class SecretEntry(BaseModel): + """凭证条目 — 对应 DB 一行加密列。""" + + model_config = ConfigDict() + + key: str # 唯一键(如 "feishu:app_id:xxx") + value: str # base64 编码的密文 + nonce: str # base64 编码的 96-bit nonce + salt: str # base64 编码的 per-row salt + key_id: str = "default" # master key ID(用于双密钥轮换窗口) + created_at: str = "" + updated_at: str = "" + + +def assert_production_master_key(master_key: bytes | None, *, source: str = "env") -> None: + """生产环境启动 guard。 + + 若运行于生产模式(``AGENTKIT_ENV=production``)且 master key 来源于 + 环境变量(source="env")或为空,则拒绝启动。 + + Args: + master_key: 已加载的 master key 字节,可能为 None。 + source: master key 来源标记。"env" 表示环境变量 fallback, + "kms" 表示云 KMS。生产模式仅接受 "kms"。 + + Raises: + RuntimeError: 生产模式下 master key 来源不可信。 + """ + if os.environ.get(_PRODUCTION_ENV_VAR) == _PRODUCTION_VALUE: + if source != "kms" or not master_key: + raise RuntimeError( + "生产环境 master key 必须来自云 KMS(source=kms);" + "环境变量 AGENTKIT_MASTER_KEY 仅作开发 fallback。" + "请通过 KMS 注入 master key 并以 source='kms' 调用本 guard。" + ) + + +class SecretsStore: + """加密凭证存储。 + + 使用 AES-256-GCM 加密,HKDF with per-row salt 派生 per-row 密钥。 + Master key 从环境变量 ``AGENTKIT_MASTER_KEY`` 读取(开发 fallback); + 生产环境应通过 KMS 提供 master key 并显式传入。 + """ + + def __init__(self, master_key: bytes | None = None, *, key_source: str = "env"): + """初始化 secrets store。 + + Args: + master_key: 显式传入的 master key。若为 None 则从环境变量加载。 + key_source: master key 来源标记,传给 ``assert_production_master_key``。 + 生产环境应设为 "kms"。 + """ + self._master_key = master_key or self._load_master_key() + # 生产 guard:若 key_source="env" 且在生产模式下,构造即失败。 + assert_production_master_key(self._master_key, source=key_source) + # 内存存储(PG 迁移预留接口:替换为 ORM session 即可) + self._store: dict[str, SecretEntry] = {} + + def _load_master_key(self) -> bytes: + """从环境变量加载 master key(开发 fallback)。""" + key_b64 = os.environ.get("AGENTKIT_MASTER_KEY") + if key_b64: + return base64.b64decode(key_b64) + # ponytail: 开发模式生成临时 key。生产环境必须通过 KMS 提供 + # (由 assert_production_master_key 在 __init__ 中强制)。 + logger.warning("AGENTKIT_MASTER_KEY 未设置 — 使用临时 key(仅限开发)") + return os.urandom(KEY_SIZE) + + def _derive_key(self, salt: bytes) -> bytes: + """HKDF 从 master key + per-row salt 派生 per-row 密钥。""" + from cryptography.hazmat.primitives import hashes + from cryptography.hazmat.primitives.kdf.hkdf import HKDF + + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=KEY_SIZE, + salt=salt, + info=HKDF_INFO, + ) + return hkdf.derive(self._master_key) + + def encrypt(self, plaintext: str) -> SecretEntry: + """加密凭证,返回 SecretEntry(key 字段留空,由调用方设置)。""" + from cryptography.hazmat.primitives.ciphers.aead import AESGCM + + salt = os.urandom(SALT_SIZE) + key = self._derive_key(salt) + nonce = os.urandom(NONCE_SIZE) + + aesgcm = AESGCM(key) + ciphertext = aesgcm.encrypt(nonce, plaintext.encode(), None) + + return SecretEntry( + key="", + value=base64.b64encode(ciphertext).decode(), + nonce=base64.b64encode(nonce).decode(), + salt=base64.b64encode(salt).decode(), + ) + + def decrypt(self, entry: SecretEntry) -> str: + """解密凭证。密钥不匹配或密文损坏时抛出 cryptography 异常。""" + from cryptography.hazmat.primitives.ciphers.aead import AESGCM + + salt = base64.b64decode(entry.salt) + key = self._derive_key(salt) + nonce = base64.b64decode(entry.nonce) + ciphertext = base64.b64decode(entry.value) + + aesgcm = AESGCM(key) + plaintext = aesgcm.decrypt(nonce, ciphertext, None) + return plaintext.decode() + + async def set_secret(self, key: str, value: str) -> SecretEntry: + """存储加密凭证。覆盖同名 key。""" + entry = self.encrypt(value) + entry.key = key + self._store[key] = entry + return entry + + async def get_secret(self, key: str) -> str | None: + """读取并解密凭证。key 不存在返回 None。""" + entry = self._store.get(key) + if entry is None: + return None + return self.decrypt(entry) + + async def delete_secret(self, key: str) -> bool: + """删除凭证。返回是否删除成功。""" + if key in self._store: + del self._store[key] + return True + return False + + async def list_keys(self, prefix: str | None = None) -> list[str]: + """列出凭证键。可选前缀过滤。""" + if prefix: + return [k for k in self._store if k.startswith(prefix)] + return list(self._store.keys()) diff --git a/src/agentkit/server/app.py b/src/agentkit/server/app.py index 9ee145a..a288f5b 100644 --- a/src/agentkit/server/app.py +++ b/src/agentkit/server/app.py @@ -52,6 +52,7 @@ from agentkit.server.routes import ( admin as admin_routes_module, calendar as calendar_routes, bitable as bitable_routes, + channels as channels_routes, ) from agentkit.server.auth.jwt_utils import get_jwt_secret from agentkit.server.auth.middleware import AuthMiddleware @@ -1082,6 +1083,7 @@ def create_app( app.include_router(documents.router, prefix="/api/v1") app.include_router(calendar_routes.router, prefix="/api/v1") app.include_router(bitable_routes.router, prefix="/api/v1") + app.include_router(channels_routes.router, prefix="/api/v1") # Serve GUI when in GUI mode gui_mode = os.environ.get("AGENTKIT_GUI_MODE") diff --git a/src/agentkit/server/routes/__init__.py b/src/agentkit/server/routes/__init__.py index 03a96dd..d0786e7 100644 --- a/src/agentkit/server/routes/__init__.py +++ b/src/agentkit/server/routes/__init__.py @@ -17,6 +17,7 @@ from agentkit.server.routes import ( workflows, terminal, experts, + channels, ) __all__ = [ @@ -36,4 +37,5 @@ __all__ = [ "workflows", "terminal", "experts", + "channels", ] diff --git a/src/agentkit/server/routes/channels.py b/src/agentkit/server/routes/channels.py new file mode 100644 index 0000000..472c663 --- /dev/null +++ b/src/agentkit/server/routes/channels.py @@ -0,0 +1,253 @@ +"""渠道管理端点 — 消息渠道配置的 CRUD。 + +端点: +- ``GET /channels`` — 列出已配置渠道 +- ``POST /channels`` — 注册新渠道(凭证加密存储) +- ``GET /channels/{id}`` — 获取渠道配置(不返回凭证明文) +- ``PUT /channels/{id}`` — 更新渠道配置 +- ``DELETE /channels/{id}`` — 删除渠道 + +凭证通过 :class:`SecretsStore` 加密存储;响应中绝不返回凭证明文, +仅返回 secret 字段名列表以便前端管理。 + +Webhook 入站端点(接收平台消息)的 fail-closed 行为:Redis 不可用或 +签名校验失败时返回 503,不可跳过 nonce dedup 直接处理消息。具体 +webhook 端点由各平台适配器子模块实现,本模块仅提供配置管理。 +""" + +from __future__ import annotations + +import logging +import re +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, Field + +from agentkit.channels.base import ChannelType +from agentkit.channels.secrets import SecretsStore +from agentkit.server.auth.dependencies import require_permission +from agentkit.server.auth.permissions import Permission + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["channels"]) + +# 渠道 ID 校验:小写字母、数字、下划线、连字符,1-64 字符 +_CHANNEL_ID_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$") +# 凭证字段名校验:字母数字下划线,1-64 字符 +_SECRET_NAME_RE = re.compile(r"^[a-zA-Z0-9_]{1,64}$") + +# ponytail: 模块级单例 store。当前为内存实现;PG 迁移后改为请求级 session。 +# 天花板:多进程部署下状态不共享;升级路径:注入 app.state.secrets_store, +# 由 lifespan 绑定到 PG session factory。 +_secrets_store: SecretsStore | None = None + + +def _get_secrets_store() -> SecretsStore: + """获取全局 secrets store 单例(懒加载)。""" + global _secrets_store + if _secrets_store is None: + _secrets_store = SecretsStore() + return _secrets_store + + +def _reset_secrets_store() -> None: + """重置全局 store(仅供测试使用)。""" + global _secrets_store + _secrets_store = None + + +# --------------------------------------------------------------------------- +# 内存渠道配置存储(PG 迁移预留接口) +# --------------------------------------------------------------------------- + +_channels: dict[str, dict[str, Any]] = {} + + +def _validate_channel_id(channel_id: str) -> str: + """校验渠道 ID,非法时抛 400。""" + if not _CHANNEL_ID_RE.match(channel_id): + raise HTTPException( + status_code=400, + detail=( + f"非法渠道 ID '{channel_id}':仅允许小写字母、数字、下划线、连字符(1-64 字符)" + ), + ) + return channel_id + + +def _validate_secret_names(names: list[str]) -> list[str]: + """校验凭证字段名列表。""" + for name in names: + if not _SECRET_NAME_RE.match(name): + raise HTTPException( + status_code=400, + detail=f"非法凭证字段名 '{name}':仅允许字母、数字、下划线(1-64 字符)", + ) + return names + + +# --------------------------------------------------------------------------- +# 请求 / 响应模型 +# --------------------------------------------------------------------------- + + +class ChannelCreateRequest(BaseModel): + """创建渠道请求。secrets 为明文键值对,服务端加密后存储。""" + + channel_id: str = Field(..., description="渠道唯一 ID") + channel_type: ChannelType + name: str = Field(..., description="渠道显示名称") + config: dict[str, Any] = Field(default_factory=dict, description="非敏感配置") + secrets: dict[str, str] = Field( + default_factory=dict, description="凭证键值对(明文,服务端加密存储)" + ) + + +class ChannelUpdateRequest(BaseModel): + """更新渠道请求。所有字段可选。secrets 提供则覆盖。""" + + name: str | None = None + config: dict[str, Any] | None = None + secrets: dict[str, str] | None = None + + +class ChannelInfo(BaseModel): + """渠道信息响应 — 不含凭证明文。""" + + channel_id: str + channel_type: ChannelType + name: str + config: dict[str, Any] + secret_keys: list[str] = Field(default_factory=list, description="已存储的凭证字段名") + + +# --------------------------------------------------------------------------- +# 端点 +# --------------------------------------------------------------------------- + + +@router.get("/channels") +async def list_channels( + _user: Any = Depends(require_permission(Permission.SYSTEM_CONFIG)), +) -> dict[str, Any]: + """列出所有已配置渠道。""" + items = [ + ChannelInfo( + channel_id=cid, + channel_type=ChannelType(cfg["channel_type"]), + name=cfg["name"], + config=cfg.get("config", {}), + secret_keys=cfg.get("secret_keys", []), + ).model_dump() + for cid, cfg in _channels.items() + ] + return {"channels": items, "total": len(items)} + + +@router.post("/channels", status_code=201) +async def create_channel( + payload: ChannelCreateRequest, + _user: Any = Depends(require_permission(Permission.SYSTEM_CONFIG)), +) -> ChannelInfo: + """注册新渠道。凭证经 SecretsStore 加密存储。""" + _validate_channel_id(payload.channel_id) + _validate_secret_names(list(payload.secrets.keys())) + + if payload.channel_id in _channels: + raise HTTPException(status_code=409, detail=f"渠道 '{payload.channel_id}' 已存在") + + store = _get_secrets_store() + secret_keys: list[str] = [] + for name, value in payload.secrets.items(): + secret_key = f"{payload.channel_id}:{name}" + await store.set_secret(secret_key, value) + secret_keys.append(name) + + _channels[payload.channel_id] = { + "channel_type": payload.channel_type.value, + "name": payload.name, + "config": payload.config, + "secret_keys": secret_keys, + } + + return ChannelInfo( + channel_id=payload.channel_id, + channel_type=payload.channel_type, + name=payload.name, + config=payload.config, + secret_keys=secret_keys, + ) + + +@router.get("/channels/{channel_id}") +async def get_channel( + channel_id: str, + _user: Any = Depends(require_permission(Permission.SYSTEM_CONFIG)), +) -> ChannelInfo: + """获取单个渠道配置(不返回凭证明文)。""" + cfg = _channels.get(channel_id) + if cfg is None: + raise HTTPException(status_code=404, detail=f"渠道 '{channel_id}' 不存在") + return ChannelInfo( + channel_id=channel_id, + channel_type=ChannelType(cfg["channel_type"]), + name=cfg["name"], + config=cfg.get("config", {}), + secret_keys=cfg.get("secret_keys", []), + ) + + +@router.put("/channels/{channel_id}") +async def update_channel( + channel_id: str, + payload: ChannelUpdateRequest, + _user: Any = Depends(require_permission(Permission.SYSTEM_CONFIG)), +) -> ChannelInfo: + """更新渠道配置。secrets 提供则覆盖对应凭证。""" + cfg = _channels.get(channel_id) + if cfg is None: + raise HTTPException(status_code=404, detail=f"渠道 '{channel_id}' 不存在") + + if payload.secrets is not None: + _validate_secret_names(list(payload.secrets.keys())) + store = _get_secrets_store() + # 覆盖:先移除旧凭证字段再写入新值 + existing = set(cfg.get("secret_keys", [])) + new_names = set(payload.secrets.keys()) + for name in existing - new_names: + await store.delete_secret(f"{channel_id}:{name}") + for name, value in payload.secrets.items(): + await store.set_secret(f"{channel_id}:{name}", value) + cfg["secret_keys"] = list(new_names) + + if payload.name is not None: + cfg["name"] = payload.name + if payload.config is not None: + cfg["config"] = payload.config + + return ChannelInfo( + channel_id=channel_id, + channel_type=ChannelType(cfg["channel_type"]), + name=cfg["name"], + config=cfg.get("config", {}), + secret_keys=cfg.get("secret_keys", []), + ) + + +@router.delete("/channels/{channel_id}") +async def delete_channel( + channel_id: str, + _user: Any = Depends(require_permission(Permission.SYSTEM_CONFIG)), +) -> dict[str, Any]: + """删除渠道及其全部凭证。""" + cfg = _channels.pop(channel_id, None) + if cfg is None: + raise HTTPException(status_code=404, detail=f"渠道 '{channel_id}' 不存在") + + store = _get_secrets_store() + for name in cfg.get("secret_keys", []): + await store.delete_secret(f"{channel_id}:{name}") + + return {"deleted": channel_id} diff --git a/tests/unit/channels/__init__.py b/tests/unit/channels/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/channels/test_base.py b/tests/unit/channels/test_base.py new file mode 100644 index 0000000..69196a0 --- /dev/null +++ b/tests/unit/channels/test_base.py @@ -0,0 +1,368 @@ +"""MessageAdapter ABC 与渠道管理端点测试 (U10)。 + +覆盖场景: +- MessageAdapter ABC 不能直接实例化 +- 具体子类实现协议方法后可正常工作 +- ChannelType / MessageDirection 枚举 +- IncomingMessage / OutgoingMessage 数据类 +- 渠道管理端点 CRUD 工作(GET/POST/GET/PUT/DELETE) +""" + +from __future__ import annotations + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from agentkit.channels import ( + ChannelType, + IncomingMessage, + MessageAdapter, + MessageDirection, + OutgoingMessage, +) +from agentkit.channels.secrets import KEY_SIZE, SecretsStore +from agentkit.server.routes import channels as channels_routes + + +# --------------------------------------------------------------------------- +# ABC 协议测试 +# --------------------------------------------------------------------------- + + +class TestMessageAdapterAbc: + """MessageAdapter 抽象基类协议。""" + + def test_abc_cannot_be_instantiated_directly(self): + """ABC 不能直接实例化(缺少抽象方法实现)。""" + with pytest.raises(TypeError): + MessageAdapter() # type: ignore[abstract] + + def test_subclass_missing_method_cannot_instantiate(self): + """子类未实现全部抽象方法时仍不能实例化。""" + + class PartialAdapter(MessageAdapter): + async def verify_signature(self, headers, body): + return True + + with pytest.raises(TypeError): + PartialAdapter() # type: ignore[abstract] + + def test_concrete_subclass_works(self): + """完整实现的子类可正常实例化并调用方法。""" + adapter = _StubAdapter() + assert isinstance(adapter, MessageAdapter) + + async def test_concrete_subclass_lifecycle(self): + """具体子类的完整生命周期调用。""" + adapter = _StubAdapter() + headers = {"X-Signature": "valid"} + body = b'{"msg":"hi"}' + + # verify_signature + assert await adapter.verify_signature(headers, body) is True + # receive_message + msg = await adapter.receive_message(headers, body) + assert isinstance(msg, IncomingMessage) + assert msg.channel == ChannelType.FEISHU + assert msg.content == "hi" + # send_message + out = OutgoingMessage(channel=ChannelType.FEISHU, chat_id="c1", content="ok") + assert await adapter.send_message(out) is True + # close + await adapter.close() + assert adapter.closed is True + + async def test_verify_signature_failure(self): + """verify_signature 返回 False 时调用方应拒绝。""" + adapter = _StubAdapter() + # _StubAdapter 对 "bad" 签名返回 False + assert await adapter.verify_signature({"X-Signature": "bad"}, b"") is False + + +# --------------------------------------------------------------------------- +# 枚举与数据类测试 +# --------------------------------------------------------------------------- + + +class TestEnumsAndDataclasses: + """枚举值与数据类字段。""" + + def test_channel_type_values(self): + """ChannelType 包含飞书、钉钉、企微、Slack。""" + assert ChannelType.FEISHU.value == "feishu" + assert ChannelType.DINGTALK.value == "dingtalk" + assert ChannelType.WECOM.value == "wecom" + assert ChannelType.SLACK.value == "slack" + + def test_message_direction_values(self): + """MessageDirection 包含 inbound/outbound。""" + assert MessageDirection.INBOUND.value == "inbound" + assert MessageDirection.OUTBOUND.value == "outbound" + + def test_incoming_message_defaults(self): + """IncomingMessage 默认 raw_event 为空 dict,timestamp 为空。""" + msg = IncomingMessage( + channel=ChannelType.FEISHU, + platform_message_id="m1", + user_id="u1", + chat_id="c1", + content="hello", + ) + assert msg.raw_event == {} + assert msg.timestamp == "" + # 默认值独立性(每次实例化应得到独立 dict) + msg2 = IncomingMessage( + channel=ChannelType.FEISHU, + platform_message_id="m2", + user_id="u2", + chat_id="c2", + content="hi", + ) + msg.raw_event["k"] = "v" + assert "k" not in msg2.raw_event + + def test_outgoing_message_optional_reply_to(self): + """OutgoingMessage 的 reply_to_message_id 默认 None。""" + msg = OutgoingMessage(channel=ChannelType.SLACK, chat_id="c1", content="reply") + assert msg.reply_to_message_id is None + + +# --------------------------------------------------------------------------- +# 渠道管理端点 CRUD 测试 +# --------------------------------------------------------------------------- + + +@pytest.fixture +def app(monkeypatch): + """构造仅挂载 channels 路由的最小 FastAPI 应用。 + + 使用确定性 master key 并清理模块级状态以保证测试隔离。 + 注入伪造的 admin 用户中间件,使 SYSTEM_CONFIG 权限校验通过。 + """ + monkeypatch.delenv("AGENTKIT_ENV", raising=False) + monkeypatch.delenv("AGENTKIT_MASTER_KEY", raising=False) + + # 注入确定性 store,避免依赖环境变量 + channels_routes._secrets_store = SecretsStore(master_key=b"\x01" * KEY_SIZE) + channels_routes._channels.clear() + + application = FastAPI() + application.include_router(channels_routes.router, prefix="/api/v1") + + # 伪造 admin 用户,使 require_permission(SYSTEM_CONFIG) 通过 + # (SYSTEM_CONFIG 为高风险权限,dev mode 下会 401,故需显式注入用户) + @application.middleware("http") + async def _fake_admin_auth(request, call_next): + request.state.current_user = { + "user_id": "u1", + "username": "admin", + "role": "admin", + } + return await call_next(request) + + return application + + +@pytest.fixture +def client(app): + return TestClient(app) + + +class TestChannelRoutesCrud: + """渠道管理端点 CRUD。""" + + def test_list_channels_empty(self, client): + """空状态下列出渠道返回空列表。""" + resp = client.get("/api/v1/channels") + assert resp.status_code == 200 + data = resp.json() + assert data["total"] == 0 + assert data["channels"] == [] + + def test_create_and_get_channel(self, client): + """创建渠道后可读取配置(不含凭证明文)。""" + resp = client.post( + "/api/v1/channels", + json={ + "channel_id": "feishu-prod", + "channel_type": "feishu", + "name": "飞书生产", + "config": {"webhook_path": "/hook/feishu"}, + "secrets": {"app_id": "cli_xxx", "app_secret": "topsecret"}, + }, + ) + assert resp.status_code == 201, resp.text + created = resp.json() + assert created["channel_id"] == "feishu-prod" + assert created["channel_type"] == "feishu" + assert created["name"] == "飞书生产" + # secret_keys 返回字段名,不含明文 + assert set(created["secret_keys"]) == {"app_id", "app_secret"} + assert "topsecret" not in resp.text + + # GET 单个渠道 + resp = client.get("/api/v1/channels/feishu-prod") + assert resp.status_code == 200 + got = resp.json() + assert got["channel_id"] == "feishu-prod" + assert set(got["secret_keys"]) == {"app_id", "app_secret"} + assert "topsecret" not in resp.text + + def test_create_duplicate_returns_409(self, client): + """重复创建同 ID 渠道返回 409。""" + payload = { + "channel_id": "dingtalk-1", + "channel_type": "dingtalk", + "name": "钉钉", + } + client.post("/api/v1/channels", json=payload) + resp = client.post("/api/v1/channels", json=payload) + assert resp.status_code == 409 + + def test_create_invalid_channel_id_returns_400(self, client): + """非法渠道 ID 返回 400。""" + resp = client.post( + "/api/v1/channels", + json={ + "channel_id": "Invalid ID!", + "channel_type": "feishu", + "name": "x", + }, + ) + assert resp.status_code == 400 + + def test_get_nonexistent_returns_404(self, client): + """获取不存在渠道返回 404。""" + resp = client.get("/api/v1/channels/missing") + assert resp.status_code == 404 + + def test_update_channel(self, client): + """更新渠道配置与凭证。""" + client.post( + "/api/v1/channels", + json={ + "channel_id": "wecom-1", + "channel_type": "wecom", + "name": "企微", + "secrets": {"token": "old-token"}, + }, + ) + resp = client.put( + "/api/v1/channels/wecom-1", + json={ + "name": "企微更新", + "config": {"enabled": True}, + "secrets": {"token": "new-token", "extra": "v"}, + }, + ) + assert resp.status_code == 200, resp.text + updated = resp.json() + assert updated["name"] == "企微更新" + assert updated["config"] == {"enabled": True} + assert set(updated["secret_keys"]) == {"token", "extra"} + + # 凭证已在 secrets store 中加密更新(明文不可见) + assert "new-token" not in resp.text + + def test_delete_channel(self, client): + """删除渠道后不可再获取。""" + client.post( + "/api/v1/channels", + json={ + "channel_id": "slack-1", + "channel_type": "slack", + "name": "Slack", + }, + ) + resp = client.delete("/api/v1/channels/slack-1") + assert resp.status_code == 200 + assert resp.json() == {"deleted": "slack-1"} + + # 再获取应 404 + assert client.get("/api/v1/channels/slack-1").status_code == 404 + + def test_delete_nonexistent_returns_404(self, client): + """删除不存在渠道返回 404。""" + resp = client.delete("/api/v1/channels/missing") + assert resp.status_code == 404 + + def test_list_after_create(self, client): + """创建后列出渠道包含新建项。""" + client.post( + "/api/v1/channels", + json={ + "channel_id": "feishu-a", + "channel_type": "feishu", + "name": "A", + }, + ) + client.post( + "/api/v1/channels", + json={ + "channel_id": "dingtalk-b", + "channel_type": "dingtalk", + "name": "B", + }, + ) + resp = client.get("/api/v1/channels") + assert resp.status_code == 200 + data = resp.json() + assert data["total"] == 2 + ids = {c["channel_id"] for c in data["channels"]} + assert ids == {"feishu-a", "dingtalk-b"} + + def test_secrets_encrypted_in_store(self, client): + """创建渠道后 secrets store 内部存储为密文。""" + client.post( + "/api/v1/channels", + json={ + "channel_id": "feishu-enc", + "channel_type": "feishu", + "name": "enc", + "secrets": {"app_secret": "plaintext-value"}, + }, + ) + store = channels_routes._get_secrets_store() + entry = store._store["feishu-enc:app_secret"] + # 内部存储的 value 不含明文 + assert "plaintext-value" not in entry.value + # 但可正确解密(store 为纯内存对象,可在新事件循环中调用) + import asyncio + + decrypted = asyncio.run(store.get_secret("feishu-enc:app_secret")) + assert decrypted == "plaintext-value" + + +# --------------------------------------------------------------------------- +# 辅助:具体子类桩 +# --------------------------------------------------------------------------- + + +class _StubAdapter(MessageAdapter): + """用于测试的桩适配器。""" + + def __init__(self): + self.closed = False + + async def verify_signature(self, headers: dict[str, str], body: bytes) -> bool: + return headers.get("X-Signature") == "valid" + + async def receive_message(self, headers: dict[str, str], body: bytes) -> IncomingMessage: + import json + + data = json.loads(body) + return IncomingMessage( + channel=ChannelType.FEISHU, + platform_message_id="m1", + user_id="u1", + chat_id="c1", + content=data.get("msg", ""), + raw_event=data, + ) + + async def send_message(self, message: OutgoingMessage) -> bool: + return True + + async def close(self) -> None: + self.closed = True diff --git a/tests/unit/channels/test_secrets.py b/tests/unit/channels/test_secrets.py new file mode 100644 index 0000000..3111e86 --- /dev/null +++ b/tests/unit/channels/test_secrets.py @@ -0,0 +1,245 @@ +"""SecretsStore 加密凭证存储测试 (U10 / KTD8)。 + +覆盖场景: +- secrets 写入后加密存储(非明文) +- secrets 读取时解密 +- 不同 master key / 不同次加密产生不同密文 +- 错误 master key 解密失败 +- set/get/delete/list CRUD 操作 +- 生产环境启动 guard +""" + +from __future__ import annotations + +import base64 +import os + +import pytest +from cryptography.exceptions import InvalidTag + +from agentkit.channels.secrets import ( + KEY_SIZE, + SecretsStore, + assert_production_master_key, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def master_key() -> bytes: + """确定性的测试 master key(32 字节)。""" + return b"\x01" * KEY_SIZE + + +@pytest.fixture +def store(master_key: bytes) -> SecretsStore: + """使用确定性 master key 的 SecretsStore。""" + return SecretsStore(master_key=master_key) + + +@pytest.fixture(autouse=True) +def _clean_production_env(monkeypatch): + """确保测试不在生产模式下运行(移除 AGENTKIT_ENV=production)。""" + monkeypatch.delenv("AGENTKIT_ENV", raising=False) + monkeypatch.delenv("AGENTKIT_MASTER_KEY", raising=False) + + +# --------------------------------------------------------------------------- +# 加密 / 解密往返 +# --------------------------------------------------------------------------- + + +class TestEncryptDecrypt: + """加密与解密核心行为。""" + + def test_roundtrip_returns_plaintext(self, store: SecretsStore): + """加密后解密应返回原始明文。""" + plaintext = "feishu-app-secret-12345" + entry = store.encrypt(plaintext) + assert store.decrypt(entry) == plaintext + + def test_encrypted_value_is_not_plaintext(self, store: SecretsStore): + """写入后密文不应包含明文(非明文存储)。""" + plaintext = "super-secret-token" + entry = store.encrypt(plaintext) + # 密文 base64 不应包含明文 + assert plaintext not in entry.value + assert plaintext not in entry.nonce + assert plaintext not in entry.salt + # 解码后的密文也不应包含明文字节 + ciphertext_bytes = base64.b64decode(entry.value) + assert plaintext.encode() not in ciphertext_bytes + + def test_each_encrypt_produces_different_ciphertext(self, store: SecretsStore): + """同一明文每次加密应产生不同密文(随机 nonce+salt)。""" + plaintext = "same-secret" + entry1 = store.encrypt(plaintext) + entry2 = store.encrypt(plaintext) + assert entry1.value != entry2.value + assert entry1.nonce != entry2.nonce + assert entry1.salt != entry2.salt + # 两者都能正确解密 + assert store.decrypt(entry1) == plaintext + assert store.decrypt(entry2) == plaintext + + def test_different_master_keys_produce_different_ciphertext(self): + """不同 master key 派生不同 per-row 密钥。""" + plaintext = "shared-secret" + store_a = SecretsStore(master_key=b"\x01" * KEY_SIZE) + store_b = SecretsStore(master_key=b"\x02" * KEY_SIZE) + entry_a = store_a.encrypt(plaintext) + # store_b 无法解密 store_a 加密的条目 + with pytest.raises(InvalidTag): + store_b.decrypt(entry_a) + # store_a 自身可解密 + assert store_a.decrypt(entry_a) == plaintext + + def test_decrypt_with_wrong_key_fails(self, store: SecretsStore): + """错误 master key 解密应抛出 InvalidTag。""" + entry = store.encrypt("secret-data") + wrong_store = SecretsStore(master_key=b"\xab" * KEY_SIZE) + with pytest.raises(InvalidTag): + wrong_store.decrypt(entry) + + def test_entry_has_correct_metadata(self, store: SecretsStore): + """加密条目应包含 nonce、salt、key_id 等元数据。""" + entry = store.encrypt("data") + assert entry.key == "" + assert entry.nonce + assert entry.salt + assert entry.key_id == "default" + # nonce 解码后为 12 字节 + assert len(base64.b64decode(entry.nonce)) == 12 + # salt 解码后为 16 字节 + assert len(base64.b64decode(entry.salt)) == 16 + + +# --------------------------------------------------------------------------- +# CRUD 异步操作 +# --------------------------------------------------------------------------- + + +class TestSecretsCrud: + """set/get/delete/list 异步操作。""" + + async def test_set_and_get_secret(self, store: SecretsStore): + """写入后读取应返回明文。""" + await store.set_secret("feishu:app_id", "cli_xxx") + assert await store.get_secret("feishu:app_id") == "cli_xxx" + + async def test_get_nonexistent_secret_returns_none(self, store: SecretsStore): + """读取不存在的 key 返回 None。""" + assert await store.get_secret("missing") is None + + async def test_set_overwrites_existing(self, store: SecretsStore): + """同名 key 写入应覆盖旧值。""" + await store.set_secret("k", "old") + await store.set_secret("k", "new") + assert await store.get_secret("k") == "new" + + async def test_delete_secret(self, store: SecretsStore): + """删除凭证后不可再读取。""" + await store.set_secret("k", "v") + assert await store.delete_secret("k") is True + assert await store.get_secret("k") is None + + async def test_delete_nonexistent_returns_false(self, store: SecretsStore): + """删除不存在的 key 返回 False。""" + assert await store.delete_secret("missing") is False + + async def test_list_keys(self, store: SecretsStore): + """列出全部凭证键。""" + await store.set_secret("feishu:a", "1") + await store.set_secret("feishu:b", "2") + await store.set_secret("dingtalk:c", "3") + keys = await store.list_keys() + assert set(keys) == {"feishu:a", "feishu:b", "dingtalk:c"} + + async def test_list_keys_with_prefix(self, store: SecretsStore): + """按前缀过滤凭证键。""" + await store.set_secret("feishu:a", "1") + await store.set_secret("feishu:b", "2") + await store.set_secret("dingtalk:c", "3") + keys = await store.list_keys(prefix="feishu:") + assert set(keys) == {"feishu:a", "feishu:b"} + + async def test_stored_value_is_encrypted_not_plaintext(self, store: SecretsStore): + """写入后内部存储的 value 字段应为密文,非明文。""" + plaintext = "plaintext-token" + await store.set_secret("k", plaintext) + # 直接检查内存中的 entry.value,确认非明文 + entry = store._store["k"] + assert plaintext not in entry.value + assert plaintext not in base64.b64decode(entry.value).decode(errors="ignore") + + +# --------------------------------------------------------------------------- +# 生产环境 guard +# --------------------------------------------------------------------------- + + +class TestProductionGuard: + """生产环境 master key 启动 guard。""" + + def test_dev_mode_allows_env_source(self, master_key: bytes, monkeypatch): + """开发模式下 env 来源 master key 允许构造。""" + monkeypatch.delenv("AGENTKIT_ENV", raising=False) + # 不应抛异常 + store = SecretsStore(master_key=master_key, key_source="env") + assert store is not None + + def test_production_mode_rejects_env_source(self, master_key: bytes, monkeypatch): + """生产模式下 env 来源 master key 应拒绝启动。""" + monkeypatch.setenv("AGENTKIT_ENV", "production") + with pytest.raises(RuntimeError, match="云 KMS"): + SecretsStore(master_key=master_key, key_source="env") + + def test_production_mode_rejects_missing_key(self, monkeypatch): + """生产模式下无 master key 应拒绝启动。""" + monkeypatch.setenv("AGENTKIT_ENV", "production") + with pytest.raises(RuntimeError, match="云 KMS"): + SecretsStore(master_key=None, key_source="env") + + def test_production_mode_allows_kms_source(self, master_key: bytes, monkeypatch): + """生产模式下 KMS 来源 master key 允许构造。""" + monkeypatch.setenv("AGENTKIT_ENV", "production") + store = SecretsStore(master_key=master_key, key_source="kms") + assert store is not None + + def test_assert_production_master_key_helper(self, master_key: bytes, monkeypatch): + """直接测试 guard 辅助函数。""" + monkeypatch.setenv("AGENTKIT_ENV", "production") + # env 来源拒绝 + with pytest.raises(RuntimeError): + assert_production_master_key(master_key, source="env") + # kms 来源通过 + assert_production_master_key(master_key, source="kms") + # 空密钥拒绝 + with pytest.raises(RuntimeError): + assert_production_master_key(None, source="kms") + + +# --------------------------------------------------------------------------- +# 环境变量加载 fallback +# --------------------------------------------------------------------------- + + +class TestEnvLoading: + """开发环境从环境变量加载 master key。""" + + def test_loads_master_key_from_env(self, monkeypatch): + """AGENTKIT_MASTER_KEY 环境变量应被加载。""" + key = os.urandom(KEY_SIZE) + monkeypatch.setenv("AGENTKIT_MASTER_KEY", base64.b64encode(key).decode()) + store = SecretsStore() # 不显式传入 master_key + assert store._master_key == key + + def test_ephemeral_key_when_no_env(self, monkeypatch): + """无环境变量时生成临时 key(开发 fallback)。""" + monkeypatch.delenv("AGENTKIT_MASTER_KEY", raising=False) + store = SecretsStore() + assert len(store._master_key) == KEY_SIZE