feat(channels): U10 — message adapter ABC + AES-256-GCM secrets store + channel CRUD routes
This commit is contained in:
parent
af96cb49bd
commit
5572387c01
|
|
@ -24,6 +24,8 @@ dependencies = [
|
|||
"pyjwt>=2.8",
|
||||
"bcrypt>=4.0",
|
||||
"aiosqlite>=0.20",
|
||||
# 加密 secrets store (U10 — 多端消息适配器 AES-256-GCM)
|
||||
"cryptography>=42.0",
|
||||
# Calendar & schedule (RRULE expansion)
|
||||
"python-dateutil>=2.9",
|
||||
# Calendar ICS import/export (U8)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,27 @@
|
|||
"""多端消息适配器子系统 (U10)。
|
||||
|
||||
提供渠道适配器 ABC 与加密 secrets store 基础设施:
|
||||
- :class:`MessageAdapter` — 所有平台适配器的抽象基类
|
||||
- :class:`SecretsStore` — AES-256-GCM 加密凭证存储 (KTD8)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from agentkit.channels.base import (
|
||||
ChannelType,
|
||||
IncomingMessage,
|
||||
MessageAdapter,
|
||||
MessageDirection,
|
||||
OutgoingMessage,
|
||||
)
|
||||
from agentkit.channels.secrets import SecretEntry, SecretsStore
|
||||
|
||||
__all__ = [
|
||||
"ChannelType",
|
||||
"IncomingMessage",
|
||||
"MessageAdapter",
|
||||
"MessageDirection",
|
||||
"OutgoingMessage",
|
||||
"SecretEntry",
|
||||
"SecretsStore",
|
||||
]
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
"""消息适配器 ABC — 所有渠道适配器的基类。
|
||||
|
||||
与 KBAdapter 的 authenticate()/close() 生命周期方法对齐:
|
||||
verify_signature() 对应 authenticate()。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ChannelType(str, Enum):
|
||||
"""支持的消息平台渠道。"""
|
||||
|
||||
FEISHU = "feishu"
|
||||
DINGTALK = "dingtalk"
|
||||
WECOM = "wecom"
|
||||
SLACK = "slack"
|
||||
|
||||
|
||||
class MessageDirection(str, Enum):
|
||||
"""消息流向。"""
|
||||
|
||||
INBOUND = "inbound"
|
||||
OUTBOUND = "outbound"
|
||||
|
||||
|
||||
@dataclass
|
||||
class IncomingMessage:
|
||||
"""标准化入站消息 — 所有渠道适配器将平台特定格式转换为此结构。"""
|
||||
|
||||
channel: ChannelType
|
||||
platform_message_id: str
|
||||
user_id: str # 平台用户 ID
|
||||
chat_id: str # 群组/会话 ID
|
||||
content: str # 消息文本
|
||||
raw_event: dict[str, Any] = field(default_factory=dict) # 原始事件
|
||||
timestamp: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutgoingMessage:
|
||||
"""标准化出站消息。"""
|
||||
|
||||
channel: ChannelType
|
||||
chat_id: str
|
||||
content: str
|
||||
reply_to_message_id: str | None = None
|
||||
|
||||
|
||||
class MessageAdapter(abc.ABC):
|
||||
"""消息适配器 ABC。
|
||||
|
||||
生命周期:
|
||||
__init__ → verify_signature() → receive_message() → send_message() → close()
|
||||
|
||||
子类必须实现全部抽象方法。verify_signature 失败时调用方应拒绝处理
|
||||
(webhook 端点 fail-closed:Redis 不可用或签名校验失败均返回 503/401,
|
||||
不可跳过 nonce dedup 直接处理消息)。
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def verify_signature(self, headers: dict[str, str], body: bytes) -> bool:
|
||||
"""验证平台签名/token。返回 True 表示请求可信。"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def receive_message(self, headers: dict[str, str], body: bytes) -> IncomingMessage:
|
||||
"""从 webhook 请求中解析标准化消息。"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def send_message(self, message: OutgoingMessage) -> bool:
|
||||
"""向平台发送消息。返回 True 表示发送成功。"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def close(self) -> None:
|
||||
"""释放资源(HTTP 客户端、连接池等)。"""
|
||||
|
|
@ -0,0 +1,172 @@
|
|||
"""加密 DB 列 secrets store — AES-256-GCM (KTD8)。
|
||||
|
||||
KTD8 关键决策:
|
||||
- 每行使用随机 96-bit nonce + HKDF with per-row salt 派生 per-row 密钥。
|
||||
- 生产环境 master key 必须来自云 KMS;环境变量 AGENTKIT_MASTER_KEY 仅作开发 fallback。
|
||||
- 应用启动 guard:生产模式下若仅检测到环境变量 key 应拒绝启动
|
||||
(由 server 启动钩子实现,本模块提供 ``assert_production_master_key`` 辅助)。
|
||||
- Master key 轮换采用双密钥窗口策略(key_id 字段标记)。
|
||||
|
||||
当前为内存存储实现,PG 迁移预留接口(``_store: dict`` → 未来替换为 ORM session)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# AES-256-GCM 参数
|
||||
NONCE_SIZE = 12 # 96-bit nonce (GCM 推荐长度)
|
||||
KEY_SIZE = 32 # 256-bit key
|
||||
SALT_SIZE = 16 # per-row salt
|
||||
HKDF_INFO = b"agentkit-channels-secrets-v1"
|
||||
|
||||
# 生产环境标记:当此环境变量存在时启用生产 guard。
|
||||
_PRODUCTION_ENV_VAR = "AGENTKIT_ENV"
|
||||
_PRODUCTION_VALUE = "production"
|
||||
|
||||
|
||||
class SecretEntry(BaseModel):
|
||||
"""凭证条目 — 对应 DB 一行加密列。"""
|
||||
|
||||
model_config = ConfigDict()
|
||||
|
||||
key: str # 唯一键(如 "feishu:app_id:xxx")
|
||||
value: str # base64 编码的密文
|
||||
nonce: str # base64 编码的 96-bit nonce
|
||||
salt: str # base64 编码的 per-row salt
|
||||
key_id: str = "default" # master key ID(用于双密钥轮换窗口)
|
||||
created_at: str = ""
|
||||
updated_at: str = ""
|
||||
|
||||
|
||||
def assert_production_master_key(master_key: bytes | None, *, source: str = "env") -> None:
|
||||
"""生产环境启动 guard。
|
||||
|
||||
若运行于生产模式(``AGENTKIT_ENV=production``)且 master key 来源于
|
||||
环境变量(source="env")或为空,则拒绝启动。
|
||||
|
||||
Args:
|
||||
master_key: 已加载的 master key 字节,可能为 None。
|
||||
source: master key 来源标记。"env" 表示环境变量 fallback,
|
||||
"kms" 表示云 KMS。生产模式仅接受 "kms"。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 生产模式下 master key 来源不可信。
|
||||
"""
|
||||
if os.environ.get(_PRODUCTION_ENV_VAR) == _PRODUCTION_VALUE:
|
||||
if source != "kms" or not master_key:
|
||||
raise RuntimeError(
|
||||
"生产环境 master key 必须来自云 KMS(source=kms);"
|
||||
"环境变量 AGENTKIT_MASTER_KEY 仅作开发 fallback。"
|
||||
"请通过 KMS 注入 master key 并以 source='kms' 调用本 guard。"
|
||||
)
|
||||
|
||||
|
||||
class SecretsStore:
|
||||
"""加密凭证存储。
|
||||
|
||||
使用 AES-256-GCM 加密,HKDF with per-row salt 派生 per-row 密钥。
|
||||
Master key 从环境变量 ``AGENTKIT_MASTER_KEY`` 读取(开发 fallback);
|
||||
生产环境应通过 KMS 提供 master key 并显式传入。
|
||||
"""
|
||||
|
||||
def __init__(self, master_key: bytes | None = None, *, key_source: str = "env"):
|
||||
"""初始化 secrets store。
|
||||
|
||||
Args:
|
||||
master_key: 显式传入的 master key。若为 None 则从环境变量加载。
|
||||
key_source: master key 来源标记,传给 ``assert_production_master_key``。
|
||||
生产环境应设为 "kms"。
|
||||
"""
|
||||
self._master_key = master_key or self._load_master_key()
|
||||
# 生产 guard:若 key_source="env" 且在生产模式下,构造即失败。
|
||||
assert_production_master_key(self._master_key, source=key_source)
|
||||
# 内存存储(PG 迁移预留接口:替换为 ORM session 即可)
|
||||
self._store: dict[str, SecretEntry] = {}
|
||||
|
||||
def _load_master_key(self) -> bytes:
|
||||
"""从环境变量加载 master key(开发 fallback)。"""
|
||||
key_b64 = os.environ.get("AGENTKIT_MASTER_KEY")
|
||||
if key_b64:
|
||||
return base64.b64decode(key_b64)
|
||||
# ponytail: 开发模式生成临时 key。生产环境必须通过 KMS 提供
|
||||
# (由 assert_production_master_key 在 __init__ 中强制)。
|
||||
logger.warning("AGENTKIT_MASTER_KEY 未设置 — 使用临时 key(仅限开发)")
|
||||
return os.urandom(KEY_SIZE)
|
||||
|
||||
def _derive_key(self, salt: bytes) -> bytes:
|
||||
"""HKDF 从 master key + per-row salt 派生 per-row 密钥。"""
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
||||
|
||||
hkdf = HKDF(
|
||||
algorithm=hashes.SHA256(),
|
||||
length=KEY_SIZE,
|
||||
salt=salt,
|
||||
info=HKDF_INFO,
|
||||
)
|
||||
return hkdf.derive(self._master_key)
|
||||
|
||||
def encrypt(self, plaintext: str) -> SecretEntry:
|
||||
"""加密凭证,返回 SecretEntry(key 字段留空,由调用方设置)。"""
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
salt = os.urandom(SALT_SIZE)
|
||||
key = self._derive_key(salt)
|
||||
nonce = os.urandom(NONCE_SIZE)
|
||||
|
||||
aesgcm = AESGCM(key)
|
||||
ciphertext = aesgcm.encrypt(nonce, plaintext.encode(), None)
|
||||
|
||||
return SecretEntry(
|
||||
key="",
|
||||
value=base64.b64encode(ciphertext).decode(),
|
||||
nonce=base64.b64encode(nonce).decode(),
|
||||
salt=base64.b64encode(salt).decode(),
|
||||
)
|
||||
|
||||
def decrypt(self, entry: SecretEntry) -> str:
|
||||
"""解密凭证。密钥不匹配或密文损坏时抛出 cryptography 异常。"""
|
||||
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
|
||||
|
||||
salt = base64.b64decode(entry.salt)
|
||||
key = self._derive_key(salt)
|
||||
nonce = base64.b64decode(entry.nonce)
|
||||
ciphertext = base64.b64decode(entry.value)
|
||||
|
||||
aesgcm = AESGCM(key)
|
||||
plaintext = aesgcm.decrypt(nonce, ciphertext, None)
|
||||
return plaintext.decode()
|
||||
|
||||
async def set_secret(self, key: str, value: str) -> SecretEntry:
|
||||
"""存储加密凭证。覆盖同名 key。"""
|
||||
entry = self.encrypt(value)
|
||||
entry.key = key
|
||||
self._store[key] = entry
|
||||
return entry
|
||||
|
||||
async def get_secret(self, key: str) -> str | None:
|
||||
"""读取并解密凭证。key 不存在返回 None。"""
|
||||
entry = self._store.get(key)
|
||||
if entry is None:
|
||||
return None
|
||||
return self.decrypt(entry)
|
||||
|
||||
async def delete_secret(self, key: str) -> bool:
|
||||
"""删除凭证。返回是否删除成功。"""
|
||||
if key in self._store:
|
||||
del self._store[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
async def list_keys(self, prefix: str | None = None) -> list[str]:
|
||||
"""列出凭证键。可选前缀过滤。"""
|
||||
if prefix:
|
||||
return [k for k in self._store if k.startswith(prefix)]
|
||||
return list(self._store.keys())
|
||||
|
|
@ -52,6 +52,7 @@ from agentkit.server.routes import (
|
|||
admin as admin_routes_module,
|
||||
calendar as calendar_routes,
|
||||
bitable as bitable_routes,
|
||||
channels as channels_routes,
|
||||
)
|
||||
from agentkit.server.auth.jwt_utils import get_jwt_secret
|
||||
from agentkit.server.auth.middleware import AuthMiddleware
|
||||
|
|
@ -1082,6 +1083,7 @@ def create_app(
|
|||
app.include_router(documents.router, prefix="/api/v1")
|
||||
app.include_router(calendar_routes.router, prefix="/api/v1")
|
||||
app.include_router(bitable_routes.router, prefix="/api/v1")
|
||||
app.include_router(channels_routes.router, prefix="/api/v1")
|
||||
|
||||
# Serve GUI when in GUI mode
|
||||
gui_mode = os.environ.get("AGENTKIT_GUI_MODE")
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from agentkit.server.routes import (
|
|||
workflows,
|
||||
terminal,
|
||||
experts,
|
||||
channels,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -36,4 +37,5 @@ __all__ = [
|
|||
"workflows",
|
||||
"terminal",
|
||||
"experts",
|
||||
"channels",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,253 @@
|
|||
"""渠道管理端点 — 消息渠道配置的 CRUD。
|
||||
|
||||
端点:
|
||||
- ``GET /channels`` — 列出已配置渠道
|
||||
- ``POST /channels`` — 注册新渠道(凭证加密存储)
|
||||
- ``GET /channels/{id}`` — 获取渠道配置(不返回凭证明文)
|
||||
- ``PUT /channels/{id}`` — 更新渠道配置
|
||||
- ``DELETE /channels/{id}`` — 删除渠道
|
||||
|
||||
凭证通过 :class:`SecretsStore` 加密存储;响应中绝不返回凭证明文,
|
||||
仅返回 secret 字段名列表以便前端管理。
|
||||
|
||||
Webhook 入站端点(接收平台消息)的 fail-closed 行为:Redis 不可用或
|
||||
签名校验失败时返回 503,不可跳过 nonce dedup 直接处理消息。具体
|
||||
webhook 端点由各平台适配器子模块实现,本模块仅提供配置管理。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from agentkit.channels.base import ChannelType
|
||||
from agentkit.channels.secrets import SecretsStore
|
||||
from agentkit.server.auth.dependencies import require_permission
|
||||
from agentkit.server.auth.permissions import Permission
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["channels"])
|
||||
|
||||
# 渠道 ID 校验:小写字母、数字、下划线、连字符,1-64 字符
|
||||
_CHANNEL_ID_RE = re.compile(r"^[a-z0-9][a-z0-9_-]{0,63}$")
|
||||
# 凭证字段名校验:字母数字下划线,1-64 字符
|
||||
_SECRET_NAME_RE = re.compile(r"^[a-zA-Z0-9_]{1,64}$")
|
||||
|
||||
# ponytail: 模块级单例 store。当前为内存实现;PG 迁移后改为请求级 session。
|
||||
# 天花板:多进程部署下状态不共享;升级路径:注入 app.state.secrets_store,
|
||||
# 由 lifespan 绑定到 PG session factory。
|
||||
_secrets_store: SecretsStore | None = None
|
||||
|
||||
|
||||
def _get_secrets_store() -> SecretsStore:
|
||||
"""获取全局 secrets store 单例(懒加载)。"""
|
||||
global _secrets_store
|
||||
if _secrets_store is None:
|
||||
_secrets_store = SecretsStore()
|
||||
return _secrets_store
|
||||
|
||||
|
||||
def _reset_secrets_store() -> None:
|
||||
"""重置全局 store(仅供测试使用)。"""
|
||||
global _secrets_store
|
||||
_secrets_store = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 内存渠道配置存储(PG 迁移预留接口)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_channels: dict[str, dict[str, Any]] = {}
|
||||
|
||||
|
||||
def _validate_channel_id(channel_id: str) -> str:
|
||||
"""校验渠道 ID,非法时抛 400。"""
|
||||
if not _CHANNEL_ID_RE.match(channel_id):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"非法渠道 ID '{channel_id}':仅允许小写字母、数字、下划线、连字符(1-64 字符)"
|
||||
),
|
||||
)
|
||||
return channel_id
|
||||
|
||||
|
||||
def _validate_secret_names(names: list[str]) -> list[str]:
|
||||
"""校验凭证字段名列表。"""
|
||||
for name in names:
|
||||
if not _SECRET_NAME_RE.match(name):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"非法凭证字段名 '{name}':仅允许字母、数字、下划线(1-64 字符)",
|
||||
)
|
||||
return names
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 请求 / 响应模型
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ChannelCreateRequest(BaseModel):
|
||||
"""创建渠道请求。secrets 为明文键值对,服务端加密后存储。"""
|
||||
|
||||
channel_id: str = Field(..., description="渠道唯一 ID")
|
||||
channel_type: ChannelType
|
||||
name: str = Field(..., description="渠道显示名称")
|
||||
config: dict[str, Any] = Field(default_factory=dict, description="非敏感配置")
|
||||
secrets: dict[str, str] = Field(
|
||||
default_factory=dict, description="凭证键值对(明文,服务端加密存储)"
|
||||
)
|
||||
|
||||
|
||||
class ChannelUpdateRequest(BaseModel):
|
||||
"""更新渠道请求。所有字段可选。secrets 提供则覆盖。"""
|
||||
|
||||
name: str | None = None
|
||||
config: dict[str, Any] | None = None
|
||||
secrets: dict[str, str] | None = None
|
||||
|
||||
|
||||
class ChannelInfo(BaseModel):
|
||||
"""渠道信息响应 — 不含凭证明文。"""
|
||||
|
||||
channel_id: str
|
||||
channel_type: ChannelType
|
||||
name: str
|
||||
config: dict[str, Any]
|
||||
secret_keys: list[str] = Field(default_factory=list, description="已存储的凭证字段名")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 端点
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/channels")
|
||||
async def list_channels(
|
||||
_user: Any = Depends(require_permission(Permission.SYSTEM_CONFIG)),
|
||||
) -> dict[str, Any]:
|
||||
"""列出所有已配置渠道。"""
|
||||
items = [
|
||||
ChannelInfo(
|
||||
channel_id=cid,
|
||||
channel_type=ChannelType(cfg["channel_type"]),
|
||||
name=cfg["name"],
|
||||
config=cfg.get("config", {}),
|
||||
secret_keys=cfg.get("secret_keys", []),
|
||||
).model_dump()
|
||||
for cid, cfg in _channels.items()
|
||||
]
|
||||
return {"channels": items, "total": len(items)}
|
||||
|
||||
|
||||
@router.post("/channels", status_code=201)
|
||||
async def create_channel(
|
||||
payload: ChannelCreateRequest,
|
||||
_user: Any = Depends(require_permission(Permission.SYSTEM_CONFIG)),
|
||||
) -> ChannelInfo:
|
||||
"""注册新渠道。凭证经 SecretsStore 加密存储。"""
|
||||
_validate_channel_id(payload.channel_id)
|
||||
_validate_secret_names(list(payload.secrets.keys()))
|
||||
|
||||
if payload.channel_id in _channels:
|
||||
raise HTTPException(status_code=409, detail=f"渠道 '{payload.channel_id}' 已存在")
|
||||
|
||||
store = _get_secrets_store()
|
||||
secret_keys: list[str] = []
|
||||
for name, value in payload.secrets.items():
|
||||
secret_key = f"{payload.channel_id}:{name}"
|
||||
await store.set_secret(secret_key, value)
|
||||
secret_keys.append(name)
|
||||
|
||||
_channels[payload.channel_id] = {
|
||||
"channel_type": payload.channel_type.value,
|
||||
"name": payload.name,
|
||||
"config": payload.config,
|
||||
"secret_keys": secret_keys,
|
||||
}
|
||||
|
||||
return ChannelInfo(
|
||||
channel_id=payload.channel_id,
|
||||
channel_type=payload.channel_type,
|
||||
name=payload.name,
|
||||
config=payload.config,
|
||||
secret_keys=secret_keys,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/channels/{channel_id}")
|
||||
async def get_channel(
|
||||
channel_id: str,
|
||||
_user: Any = Depends(require_permission(Permission.SYSTEM_CONFIG)),
|
||||
) -> ChannelInfo:
|
||||
"""获取单个渠道配置(不返回凭证明文)。"""
|
||||
cfg = _channels.get(channel_id)
|
||||
if cfg is None:
|
||||
raise HTTPException(status_code=404, detail=f"渠道 '{channel_id}' 不存在")
|
||||
return ChannelInfo(
|
||||
channel_id=channel_id,
|
||||
channel_type=ChannelType(cfg["channel_type"]),
|
||||
name=cfg["name"],
|
||||
config=cfg.get("config", {}),
|
||||
secret_keys=cfg.get("secret_keys", []),
|
||||
)
|
||||
|
||||
|
||||
@router.put("/channels/{channel_id}")
|
||||
async def update_channel(
|
||||
channel_id: str,
|
||||
payload: ChannelUpdateRequest,
|
||||
_user: Any = Depends(require_permission(Permission.SYSTEM_CONFIG)),
|
||||
) -> ChannelInfo:
|
||||
"""更新渠道配置。secrets 提供则覆盖对应凭证。"""
|
||||
cfg = _channels.get(channel_id)
|
||||
if cfg is None:
|
||||
raise HTTPException(status_code=404, detail=f"渠道 '{channel_id}' 不存在")
|
||||
|
||||
if payload.secrets is not None:
|
||||
_validate_secret_names(list(payload.secrets.keys()))
|
||||
store = _get_secrets_store()
|
||||
# 覆盖:先移除旧凭证字段再写入新值
|
||||
existing = set(cfg.get("secret_keys", []))
|
||||
new_names = set(payload.secrets.keys())
|
||||
for name in existing - new_names:
|
||||
await store.delete_secret(f"{channel_id}:{name}")
|
||||
for name, value in payload.secrets.items():
|
||||
await store.set_secret(f"{channel_id}:{name}", value)
|
||||
cfg["secret_keys"] = list(new_names)
|
||||
|
||||
if payload.name is not None:
|
||||
cfg["name"] = payload.name
|
||||
if payload.config is not None:
|
||||
cfg["config"] = payload.config
|
||||
|
||||
return ChannelInfo(
|
||||
channel_id=channel_id,
|
||||
channel_type=ChannelType(cfg["channel_type"]),
|
||||
name=cfg["name"],
|
||||
config=cfg.get("config", {}),
|
||||
secret_keys=cfg.get("secret_keys", []),
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/channels/{channel_id}")
|
||||
async def delete_channel(
|
||||
channel_id: str,
|
||||
_user: Any = Depends(require_permission(Permission.SYSTEM_CONFIG)),
|
||||
) -> dict[str, Any]:
|
||||
"""删除渠道及其全部凭证。"""
|
||||
cfg = _channels.pop(channel_id, None)
|
||||
if cfg is None:
|
||||
raise HTTPException(status_code=404, detail=f"渠道 '{channel_id}' 不存在")
|
||||
|
||||
store = _get_secrets_store()
|
||||
for name in cfg.get("secret_keys", []):
|
||||
await store.delete_secret(f"{channel_id}:{name}")
|
||||
|
||||
return {"deleted": channel_id}
|
||||
|
|
@ -0,0 +1,368 @@
|
|||
"""MessageAdapter ABC 与渠道管理端点测试 (U10)。
|
||||
|
||||
覆盖场景:
|
||||
- MessageAdapter ABC 不能直接实例化
|
||||
- 具体子类实现协议方法后可正常工作
|
||||
- ChannelType / MessageDirection 枚举
|
||||
- IncomingMessage / OutgoingMessage 数据类
|
||||
- 渠道管理端点 CRUD 工作(GET/POST/GET/PUT/DELETE)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from agentkit.channels import (
|
||||
ChannelType,
|
||||
IncomingMessage,
|
||||
MessageAdapter,
|
||||
MessageDirection,
|
||||
OutgoingMessage,
|
||||
)
|
||||
from agentkit.channels.secrets import KEY_SIZE, SecretsStore
|
||||
from agentkit.server.routes import channels as channels_routes
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ABC 协议测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageAdapterAbc:
|
||||
"""MessageAdapter 抽象基类协议。"""
|
||||
|
||||
def test_abc_cannot_be_instantiated_directly(self):
|
||||
"""ABC 不能直接实例化(缺少抽象方法实现)。"""
|
||||
with pytest.raises(TypeError):
|
||||
MessageAdapter() # type: ignore[abstract]
|
||||
|
||||
def test_subclass_missing_method_cannot_instantiate(self):
|
||||
"""子类未实现全部抽象方法时仍不能实例化。"""
|
||||
|
||||
class PartialAdapter(MessageAdapter):
|
||||
async def verify_signature(self, headers, body):
|
||||
return True
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
PartialAdapter() # type: ignore[abstract]
|
||||
|
||||
def test_concrete_subclass_works(self):
|
||||
"""完整实现的子类可正常实例化并调用方法。"""
|
||||
adapter = _StubAdapter()
|
||||
assert isinstance(adapter, MessageAdapter)
|
||||
|
||||
async def test_concrete_subclass_lifecycle(self):
|
||||
"""具体子类的完整生命周期调用。"""
|
||||
adapter = _StubAdapter()
|
||||
headers = {"X-Signature": "valid"}
|
||||
body = b'{"msg":"hi"}'
|
||||
|
||||
# verify_signature
|
||||
assert await adapter.verify_signature(headers, body) is True
|
||||
# receive_message
|
||||
msg = await adapter.receive_message(headers, body)
|
||||
assert isinstance(msg, IncomingMessage)
|
||||
assert msg.channel == ChannelType.FEISHU
|
||||
assert msg.content == "hi"
|
||||
# send_message
|
||||
out = OutgoingMessage(channel=ChannelType.FEISHU, chat_id="c1", content="ok")
|
||||
assert await adapter.send_message(out) is True
|
||||
# close
|
||||
await adapter.close()
|
||||
assert adapter.closed is True
|
||||
|
||||
async def test_verify_signature_failure(self):
|
||||
"""verify_signature 返回 False 时调用方应拒绝。"""
|
||||
adapter = _StubAdapter()
|
||||
# _StubAdapter 对 "bad" 签名返回 False
|
||||
assert await adapter.verify_signature({"X-Signature": "bad"}, b"") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 枚举与数据类测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnumsAndDataclasses:
|
||||
"""枚举值与数据类字段。"""
|
||||
|
||||
def test_channel_type_values(self):
|
||||
"""ChannelType 包含飞书、钉钉、企微、Slack。"""
|
||||
assert ChannelType.FEISHU.value == "feishu"
|
||||
assert ChannelType.DINGTALK.value == "dingtalk"
|
||||
assert ChannelType.WECOM.value == "wecom"
|
||||
assert ChannelType.SLACK.value == "slack"
|
||||
|
||||
def test_message_direction_values(self):
|
||||
"""MessageDirection 包含 inbound/outbound。"""
|
||||
assert MessageDirection.INBOUND.value == "inbound"
|
||||
assert MessageDirection.OUTBOUND.value == "outbound"
|
||||
|
||||
def test_incoming_message_defaults(self):
|
||||
"""IncomingMessage 默认 raw_event 为空 dict,timestamp 为空。"""
|
||||
msg = IncomingMessage(
|
||||
channel=ChannelType.FEISHU,
|
||||
platform_message_id="m1",
|
||||
user_id="u1",
|
||||
chat_id="c1",
|
||||
content="hello",
|
||||
)
|
||||
assert msg.raw_event == {}
|
||||
assert msg.timestamp == ""
|
||||
# 默认值独立性(每次实例化应得到独立 dict)
|
||||
msg2 = IncomingMessage(
|
||||
channel=ChannelType.FEISHU,
|
||||
platform_message_id="m2",
|
||||
user_id="u2",
|
||||
chat_id="c2",
|
||||
content="hi",
|
||||
)
|
||||
msg.raw_event["k"] = "v"
|
||||
assert "k" not in msg2.raw_event
|
||||
|
||||
def test_outgoing_message_optional_reply_to(self):
|
||||
"""OutgoingMessage 的 reply_to_message_id 默认 None。"""
|
||||
msg = OutgoingMessage(channel=ChannelType.SLACK, chat_id="c1", content="reply")
|
||||
assert msg.reply_to_message_id is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 渠道管理端点 CRUD 测试
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app(monkeypatch):
|
||||
"""构造仅挂载 channels 路由的最小 FastAPI 应用。
|
||||
|
||||
使用确定性 master key 并清理模块级状态以保证测试隔离。
|
||||
注入伪造的 admin 用户中间件,使 SYSTEM_CONFIG 权限校验通过。
|
||||
"""
|
||||
monkeypatch.delenv("AGENTKIT_ENV", raising=False)
|
||||
monkeypatch.delenv("AGENTKIT_MASTER_KEY", raising=False)
|
||||
|
||||
# 注入确定性 store,避免依赖环境变量
|
||||
channels_routes._secrets_store = SecretsStore(master_key=b"\x01" * KEY_SIZE)
|
||||
channels_routes._channels.clear()
|
||||
|
||||
application = FastAPI()
|
||||
application.include_router(channels_routes.router, prefix="/api/v1")
|
||||
|
||||
# 伪造 admin 用户,使 require_permission(SYSTEM_CONFIG) 通过
|
||||
# (SYSTEM_CONFIG 为高风险权限,dev mode 下会 401,故需显式注入用户)
|
||||
@application.middleware("http")
|
||||
async def _fake_admin_auth(request, call_next):
|
||||
request.state.current_user = {
|
||||
"user_id": "u1",
|
||||
"username": "admin",
|
||||
"role": "admin",
|
||||
}
|
||||
return await call_next(request)
|
||||
|
||||
return application
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
class TestChannelRoutesCrud:
|
||||
"""渠道管理端点 CRUD。"""
|
||||
|
||||
def test_list_channels_empty(self, client):
|
||||
"""空状态下列出渠道返回空列表。"""
|
||||
resp = client.get("/api/v1/channels")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["total"] == 0
|
||||
assert data["channels"] == []
|
||||
|
||||
def test_create_and_get_channel(self, client):
|
||||
"""创建渠道后可读取配置(不含凭证明文)。"""
|
||||
resp = client.post(
|
||||
"/api/v1/channels",
|
||||
json={
|
||||
"channel_id": "feishu-prod",
|
||||
"channel_type": "feishu",
|
||||
"name": "飞书生产",
|
||||
"config": {"webhook_path": "/hook/feishu"},
|
||||
"secrets": {"app_id": "cli_xxx", "app_secret": "topsecret"},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201, resp.text
|
||||
created = resp.json()
|
||||
assert created["channel_id"] == "feishu-prod"
|
||||
assert created["channel_type"] == "feishu"
|
||||
assert created["name"] == "飞书生产"
|
||||
# secret_keys 返回字段名,不含明文
|
||||
assert set(created["secret_keys"]) == {"app_id", "app_secret"}
|
||||
assert "topsecret" not in resp.text
|
||||
|
||||
# GET 单个渠道
|
||||
resp = client.get("/api/v1/channels/feishu-prod")
|
||||
assert resp.status_code == 200
|
||||
got = resp.json()
|
||||
assert got["channel_id"] == "feishu-prod"
|
||||
assert set(got["secret_keys"]) == {"app_id", "app_secret"}
|
||||
assert "topsecret" not in resp.text
|
||||
|
||||
def test_create_duplicate_returns_409(self, client):
|
||||
"""重复创建同 ID 渠道返回 409。"""
|
||||
payload = {
|
||||
"channel_id": "dingtalk-1",
|
||||
"channel_type": "dingtalk",
|
||||
"name": "钉钉",
|
||||
}
|
||||
client.post("/api/v1/channels", json=payload)
|
||||
resp = client.post("/api/v1/channels", json=payload)
|
||||
assert resp.status_code == 409
|
||||
|
||||
def test_create_invalid_channel_id_returns_400(self, client):
|
||||
"""非法渠道 ID 返回 400。"""
|
||||
resp = client.post(
|
||||
"/api/v1/channels",
|
||||
json={
|
||||
"channel_id": "Invalid ID!",
|
||||
"channel_type": "feishu",
|
||||
"name": "x",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
def test_get_nonexistent_returns_404(self, client):
|
||||
"""获取不存在渠道返回 404。"""
|
||||
resp = client.get("/api/v1/channels/missing")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_update_channel(self, client):
|
||||
"""更新渠道配置与凭证。"""
|
||||
client.post(
|
||||
"/api/v1/channels",
|
||||
json={
|
||||
"channel_id": "wecom-1",
|
||||
"channel_type": "wecom",
|
||||
"name": "企微",
|
||||
"secrets": {"token": "old-token"},
|
||||
},
|
||||
)
|
||||
resp = client.put(
|
||||
"/api/v1/channels/wecom-1",
|
||||
json={
|
||||
"name": "企微更新",
|
||||
"config": {"enabled": True},
|
||||
"secrets": {"token": "new-token", "extra": "v"},
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200, resp.text
|
||||
updated = resp.json()
|
||||
assert updated["name"] == "企微更新"
|
||||
assert updated["config"] == {"enabled": True}
|
||||
assert set(updated["secret_keys"]) == {"token", "extra"}
|
||||
|
||||
# 凭证已在 secrets store 中加密更新(明文不可见)
|
||||
assert "new-token" not in resp.text
|
||||
|
||||
def test_delete_channel(self, client):
|
||||
"""删除渠道后不可再获取。"""
|
||||
client.post(
|
||||
"/api/v1/channels",
|
||||
json={
|
||||
"channel_id": "slack-1",
|
||||
"channel_type": "slack",
|
||||
"name": "Slack",
|
||||
},
|
||||
)
|
||||
resp = client.delete("/api/v1/channels/slack-1")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"deleted": "slack-1"}
|
||||
|
||||
# 再获取应 404
|
||||
assert client.get("/api/v1/channels/slack-1").status_code == 404
|
||||
|
||||
def test_delete_nonexistent_returns_404(self, client):
|
||||
"""删除不存在渠道返回 404。"""
|
||||
resp = client.delete("/api/v1/channels/missing")
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_list_after_create(self, client):
|
||||
"""创建后列出渠道包含新建项。"""
|
||||
client.post(
|
||||
"/api/v1/channels",
|
||||
json={
|
||||
"channel_id": "feishu-a",
|
||||
"channel_type": "feishu",
|
||||
"name": "A",
|
||||
},
|
||||
)
|
||||
client.post(
|
||||
"/api/v1/channels",
|
||||
json={
|
||||
"channel_id": "dingtalk-b",
|
||||
"channel_type": "dingtalk",
|
||||
"name": "B",
|
||||
},
|
||||
)
|
||||
resp = client.get("/api/v1/channels")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["total"] == 2
|
||||
ids = {c["channel_id"] for c in data["channels"]}
|
||||
assert ids == {"feishu-a", "dingtalk-b"}
|
||||
|
||||
def test_secrets_encrypted_in_store(self, client):
|
||||
"""创建渠道后 secrets store 内部存储为密文。"""
|
||||
client.post(
|
||||
"/api/v1/channels",
|
||||
json={
|
||||
"channel_id": "feishu-enc",
|
||||
"channel_type": "feishu",
|
||||
"name": "enc",
|
||||
"secrets": {"app_secret": "plaintext-value"},
|
||||
},
|
||||
)
|
||||
store = channels_routes._get_secrets_store()
|
||||
entry = store._store["feishu-enc:app_secret"]
|
||||
# 内部存储的 value 不含明文
|
||||
assert "plaintext-value" not in entry.value
|
||||
# 但可正确解密(store 为纯内存对象,可在新事件循环中调用)
|
||||
import asyncio
|
||||
|
||||
decrypted = asyncio.run(store.get_secret("feishu-enc:app_secret"))
|
||||
assert decrypted == "plaintext-value"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助:具体子类桩
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _StubAdapter(MessageAdapter):
|
||||
"""用于测试的桩适配器。"""
|
||||
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
async def verify_signature(self, headers: dict[str, str], body: bytes) -> bool:
|
||||
return headers.get("X-Signature") == "valid"
|
||||
|
||||
async def receive_message(self, headers: dict[str, str], body: bytes) -> IncomingMessage:
|
||||
import json
|
||||
|
||||
data = json.loads(body)
|
||||
return IncomingMessage(
|
||||
channel=ChannelType.FEISHU,
|
||||
platform_message_id="m1",
|
||||
user_id="u1",
|
||||
chat_id="c1",
|
||||
content=data.get("msg", ""),
|
||||
raw_event=data,
|
||||
)
|
||||
|
||||
async def send_message(self, message: OutgoingMessage) -> bool:
|
||||
return True
|
||||
|
||||
async def close(self) -> None:
|
||||
self.closed = True
|
||||
|
|
@ -0,0 +1,245 @@
|
|||
"""SecretsStore 加密凭证存储测试 (U10 / KTD8)。
|
||||
|
||||
覆盖场景:
|
||||
- secrets 写入后加密存储(非明文)
|
||||
- secrets 读取时解密
|
||||
- 不同 master key / 不同次加密产生不同密文
|
||||
- 错误 master key 解密失败
|
||||
- set/get/delete/list CRUD 操作
|
||||
- 生产环境启动 guard
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from cryptography.exceptions import InvalidTag
|
||||
|
||||
from agentkit.channels.secrets import (
|
||||
KEY_SIZE,
|
||||
SecretsStore,
|
||||
assert_production_master_key,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def master_key() -> bytes:
|
||||
"""确定性的测试 master key(32 字节)。"""
|
||||
return b"\x01" * KEY_SIZE
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def store(master_key: bytes) -> SecretsStore:
|
||||
"""使用确定性 master key 的 SecretsStore。"""
|
||||
return SecretsStore(master_key=master_key)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_production_env(monkeypatch):
|
||||
"""确保测试不在生产模式下运行(移除 AGENTKIT_ENV=production)。"""
|
||||
monkeypatch.delenv("AGENTKIT_ENV", raising=False)
|
||||
monkeypatch.delenv("AGENTKIT_MASTER_KEY", raising=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 加密 / 解密往返
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEncryptDecrypt:
|
||||
"""加密与解密核心行为。"""
|
||||
|
||||
def test_roundtrip_returns_plaintext(self, store: SecretsStore):
|
||||
"""加密后解密应返回原始明文。"""
|
||||
plaintext = "feishu-app-secret-12345"
|
||||
entry = store.encrypt(plaintext)
|
||||
assert store.decrypt(entry) == plaintext
|
||||
|
||||
def test_encrypted_value_is_not_plaintext(self, store: SecretsStore):
|
||||
"""写入后密文不应包含明文(非明文存储)。"""
|
||||
plaintext = "super-secret-token"
|
||||
entry = store.encrypt(plaintext)
|
||||
# 密文 base64 不应包含明文
|
||||
assert plaintext not in entry.value
|
||||
assert plaintext not in entry.nonce
|
||||
assert plaintext not in entry.salt
|
||||
# 解码后的密文也不应包含明文字节
|
||||
ciphertext_bytes = base64.b64decode(entry.value)
|
||||
assert plaintext.encode() not in ciphertext_bytes
|
||||
|
||||
def test_each_encrypt_produces_different_ciphertext(self, store: SecretsStore):
|
||||
"""同一明文每次加密应产生不同密文(随机 nonce+salt)。"""
|
||||
plaintext = "same-secret"
|
||||
entry1 = store.encrypt(plaintext)
|
||||
entry2 = store.encrypt(plaintext)
|
||||
assert entry1.value != entry2.value
|
||||
assert entry1.nonce != entry2.nonce
|
||||
assert entry1.salt != entry2.salt
|
||||
# 两者都能正确解密
|
||||
assert store.decrypt(entry1) == plaintext
|
||||
assert store.decrypt(entry2) == plaintext
|
||||
|
||||
def test_different_master_keys_produce_different_ciphertext(self):
|
||||
"""不同 master key 派生不同 per-row 密钥。"""
|
||||
plaintext = "shared-secret"
|
||||
store_a = SecretsStore(master_key=b"\x01" * KEY_SIZE)
|
||||
store_b = SecretsStore(master_key=b"\x02" * KEY_SIZE)
|
||||
entry_a = store_a.encrypt(plaintext)
|
||||
# store_b 无法解密 store_a 加密的条目
|
||||
with pytest.raises(InvalidTag):
|
||||
store_b.decrypt(entry_a)
|
||||
# store_a 自身可解密
|
||||
assert store_a.decrypt(entry_a) == plaintext
|
||||
|
||||
def test_decrypt_with_wrong_key_fails(self, store: SecretsStore):
|
||||
"""错误 master key 解密应抛出 InvalidTag。"""
|
||||
entry = store.encrypt("secret-data")
|
||||
wrong_store = SecretsStore(master_key=b"\xab" * KEY_SIZE)
|
||||
with pytest.raises(InvalidTag):
|
||||
wrong_store.decrypt(entry)
|
||||
|
||||
def test_entry_has_correct_metadata(self, store: SecretsStore):
|
||||
"""加密条目应包含 nonce、salt、key_id 等元数据。"""
|
||||
entry = store.encrypt("data")
|
||||
assert entry.key == ""
|
||||
assert entry.nonce
|
||||
assert entry.salt
|
||||
assert entry.key_id == "default"
|
||||
# nonce 解码后为 12 字节
|
||||
assert len(base64.b64decode(entry.nonce)) == 12
|
||||
# salt 解码后为 16 字节
|
||||
assert len(base64.b64decode(entry.salt)) == 16
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CRUD 异步操作
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSecretsCrud:
|
||||
"""set/get/delete/list 异步操作。"""
|
||||
|
||||
async def test_set_and_get_secret(self, store: SecretsStore):
|
||||
"""写入后读取应返回明文。"""
|
||||
await store.set_secret("feishu:app_id", "cli_xxx")
|
||||
assert await store.get_secret("feishu:app_id") == "cli_xxx"
|
||||
|
||||
async def test_get_nonexistent_secret_returns_none(self, store: SecretsStore):
|
||||
"""读取不存在的 key 返回 None。"""
|
||||
assert await store.get_secret("missing") is None
|
||||
|
||||
async def test_set_overwrites_existing(self, store: SecretsStore):
|
||||
"""同名 key 写入应覆盖旧值。"""
|
||||
await store.set_secret("k", "old")
|
||||
await store.set_secret("k", "new")
|
||||
assert await store.get_secret("k") == "new"
|
||||
|
||||
async def test_delete_secret(self, store: SecretsStore):
|
||||
"""删除凭证后不可再读取。"""
|
||||
await store.set_secret("k", "v")
|
||||
assert await store.delete_secret("k") is True
|
||||
assert await store.get_secret("k") is None
|
||||
|
||||
async def test_delete_nonexistent_returns_false(self, store: SecretsStore):
|
||||
"""删除不存在的 key 返回 False。"""
|
||||
assert await store.delete_secret("missing") is False
|
||||
|
||||
async def test_list_keys(self, store: SecretsStore):
|
||||
"""列出全部凭证键。"""
|
||||
await store.set_secret("feishu:a", "1")
|
||||
await store.set_secret("feishu:b", "2")
|
||||
await store.set_secret("dingtalk:c", "3")
|
||||
keys = await store.list_keys()
|
||||
assert set(keys) == {"feishu:a", "feishu:b", "dingtalk:c"}
|
||||
|
||||
async def test_list_keys_with_prefix(self, store: SecretsStore):
|
||||
"""按前缀过滤凭证键。"""
|
||||
await store.set_secret("feishu:a", "1")
|
||||
await store.set_secret("feishu:b", "2")
|
||||
await store.set_secret("dingtalk:c", "3")
|
||||
keys = await store.list_keys(prefix="feishu:")
|
||||
assert set(keys) == {"feishu:a", "feishu:b"}
|
||||
|
||||
async def test_stored_value_is_encrypted_not_plaintext(self, store: SecretsStore):
|
||||
"""写入后内部存储的 value 字段应为密文,非明文。"""
|
||||
plaintext = "plaintext-token"
|
||||
await store.set_secret("k", plaintext)
|
||||
# 直接检查内存中的 entry.value,确认非明文
|
||||
entry = store._store["k"]
|
||||
assert plaintext not in entry.value
|
||||
assert plaintext not in base64.b64decode(entry.value).decode(errors="ignore")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 生产环境 guard
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProductionGuard:
|
||||
"""生产环境 master key 启动 guard。"""
|
||||
|
||||
def test_dev_mode_allows_env_source(self, master_key: bytes, monkeypatch):
|
||||
"""开发模式下 env 来源 master key 允许构造。"""
|
||||
monkeypatch.delenv("AGENTKIT_ENV", raising=False)
|
||||
# 不应抛异常
|
||||
store = SecretsStore(master_key=master_key, key_source="env")
|
||||
assert store is not None
|
||||
|
||||
def test_production_mode_rejects_env_source(self, master_key: bytes, monkeypatch):
|
||||
"""生产模式下 env 来源 master key 应拒绝启动。"""
|
||||
monkeypatch.setenv("AGENTKIT_ENV", "production")
|
||||
with pytest.raises(RuntimeError, match="云 KMS"):
|
||||
SecretsStore(master_key=master_key, key_source="env")
|
||||
|
||||
def test_production_mode_rejects_missing_key(self, monkeypatch):
|
||||
"""生产模式下无 master key 应拒绝启动。"""
|
||||
monkeypatch.setenv("AGENTKIT_ENV", "production")
|
||||
with pytest.raises(RuntimeError, match="云 KMS"):
|
||||
SecretsStore(master_key=None, key_source="env")
|
||||
|
||||
def test_production_mode_allows_kms_source(self, master_key: bytes, monkeypatch):
|
||||
"""生产模式下 KMS 来源 master key 允许构造。"""
|
||||
monkeypatch.setenv("AGENTKIT_ENV", "production")
|
||||
store = SecretsStore(master_key=master_key, key_source="kms")
|
||||
assert store is not None
|
||||
|
||||
def test_assert_production_master_key_helper(self, master_key: bytes, monkeypatch):
|
||||
"""直接测试 guard 辅助函数。"""
|
||||
monkeypatch.setenv("AGENTKIT_ENV", "production")
|
||||
# env 来源拒绝
|
||||
with pytest.raises(RuntimeError):
|
||||
assert_production_master_key(master_key, source="env")
|
||||
# kms 来源通过
|
||||
assert_production_master_key(master_key, source="kms")
|
||||
# 空密钥拒绝
|
||||
with pytest.raises(RuntimeError):
|
||||
assert_production_master_key(None, source="kms")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 环境变量加载 fallback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEnvLoading:
|
||||
"""开发环境从环境变量加载 master key。"""
|
||||
|
||||
def test_loads_master_key_from_env(self, monkeypatch):
|
||||
"""AGENTKIT_MASTER_KEY 环境变量应被加载。"""
|
||||
key = os.urandom(KEY_SIZE)
|
||||
monkeypatch.setenv("AGENTKIT_MASTER_KEY", base64.b64encode(key).decode())
|
||||
store = SecretsStore() # 不显式传入 master_key
|
||||
assert store._master_key == key
|
||||
|
||||
def test_ephemeral_key_when_no_env(self, monkeypatch):
|
||||
"""无环境变量时生成临时 key(开发 fallback)。"""
|
||||
monkeypatch.delenv("AGENTKIT_MASTER_KEY", raising=False)
|
||||
store = SecretsStore()
|
||||
assert len(store._master_key) == KEY_SIZE
|
||||
Loading…
Reference in New Issue