feat(channels): U12 — DingTalk/WeCom/Slack adapters + multi-channel webhook dispatch
This commit is contained in:
parent
4b58e8f661
commit
8998f94c42
|
|
@ -0,0 +1,281 @@
|
|||
"""钉钉 IM 适配器 (U12)。
|
||||
|
||||
实现 :class:`MessageAdapter` 协议,对接钉钉企业内机器人 outgoing/webhook 回调。
|
||||
|
||||
关键设计决策:
|
||||
- 签名校验 fail-closed:``Sign`` + ``Timestamp`` 头缺失时仅当 token 校验通过才放行。
|
||||
- 时间戳窗口 3600s(钉钉官方窗口 1 小时)。
|
||||
- ``accessToken`` 简单 TTL 缓存(7200s);钉钉 token 有效期 2 小时。
|
||||
- httpx 客户端懒构造,避免未使用的适配器持有连接池。
|
||||
- 钉钉无 URL verification challenge 流程 — 合法签名请求直接 200。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from agentkit.channels.base import (
|
||||
ChannelType,
|
||||
IncomingMessage,
|
||||
MessageAdapter,
|
||||
OutgoingMessage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 签名时间戳允许的最大偏移(秒)— 钉钉官方窗口 1 小时
|
||||
_SIGNATURE_MAX_AGE_SECONDS = 3600
|
||||
# accessToken 缓存 TTL(秒)— 钉钉 token 有效期 2 小时
|
||||
_TOKEN_CACHE_TTL = 7200.0
|
||||
|
||||
# 钉钉 API 端点
|
||||
_ACCESS_TOKEN_URL = "https://api.dingtalk.com/v1.0/oauth2/accessToken"
|
||||
_SEND_MESSAGE_URL = "https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend"
|
||||
|
||||
# 钉钉 @ 机器人前缀剥离 — 内容形如 "@robotName 实际消息"
|
||||
# ponytail: 简单正则剥离首个 @token。天花板:无法区分 @ 机器人与 @ 用户;
|
||||
# 升级路径:从 config 注入 robot_name 精确匹配。
|
||||
_MENTION_PREFIX_RE = re.compile(r"^@\S+\s*")
|
||||
|
||||
|
||||
class DingTalkMessageAdapter(MessageAdapter):
|
||||
"""钉钉 IM 适配器。
|
||||
|
||||
生命周期:
|
||||
``__init__`` → :meth:`verify_signature` → :meth:`receive_message`
|
||||
→ :meth:`send_message` → :meth:`close`
|
||||
|
||||
Args:
|
||||
app_key: 钉钉应用 App Key。
|
||||
app_secret: 钉钉应用 App Secret(同时作为签名密钥)。
|
||||
robot_code: 机器人 robotCode(发送消息时使用)。
|
||||
token: 可选 token 校验值(部分机器人通过 ``Token`` 头校验)。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_key: str,
|
||||
app_secret: str,
|
||||
robot_code: str,
|
||||
token: str | None = None,
|
||||
) -> None:
|
||||
self.app_key = app_key
|
||||
self.app_secret = app_secret
|
||||
self.robot_code = robot_code
|
||||
self.token = token
|
||||
# 懒加载 httpx 客户端
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
# ponytail: 简单 TTL 缓存 (token, expiry)。天花板:单实例内存;
|
||||
# 升级路径:Redis 缓存共享给多实例。
|
||||
self._token_cache: tuple[str, float] | None = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# httpx 客户端懒加载
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_client(self) -> httpx.AsyncClient:
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(timeout=10.0)
|
||||
return self._client
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 签名验证
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def verify_signature(self, headers: dict[str, str], body: bytes) -> bool:
|
||||
"""验证钉钉 webhook 签名。
|
||||
|
||||
- 若配置了 ``token``:校验 ``Token`` 头,不匹配返回 False。
|
||||
- 若存在 ``Sign`` + ``Timestamp`` 头:校验 HMAC-SHA256 签名与时间戳新鲜度。
|
||||
- 两者皆无:仅当 token 已校验通过才放行(token-only 模式)。
|
||||
|
||||
Args:
|
||||
headers: HTTP 请求头(键大小写不敏感查找)。
|
||||
body: 原始请求体字节(钉钉签名不依赖 body)。
|
||||
|
||||
Returns:
|
||||
True 表示签名校验通过。
|
||||
"""
|
||||
# Token 校验(若配置)
|
||||
if self.token is not None:
|
||||
token_header = _header_get(headers, "Token")
|
||||
if token_header is None or not hmac.compare_digest(token_header, self.token):
|
||||
return False
|
||||
|
||||
sign = _header_get(headers, "Sign")
|
||||
timestamp_str = _header_get(headers, "Timestamp")
|
||||
|
||||
if sign is None and timestamp_str is None:
|
||||
# 无签名头:仅当 token 已校验通过才放行
|
||||
return self.token is not None
|
||||
|
||||
if sign is None or timestamp_str is None:
|
||||
return False
|
||||
|
||||
# 时间戳新鲜度(毫秒 → 秒)
|
||||
try:
|
||||
ts_ms = int(timestamp_str)
|
||||
except ValueError:
|
||||
return False
|
||||
ts_sec = ts_ms / 1000.0
|
||||
now = time.time()
|
||||
if abs(now - ts_sec) > _SIGNATURE_MAX_AGE_SECONDS:
|
||||
logger.warning("钉钉 webhook 时间戳超出 %ds 窗口 — 拒绝", _SIGNATURE_MAX_AGE_SECONDS)
|
||||
return False
|
||||
|
||||
# 计算签名:base64(hmac-sha256(key=app_secret, msg="{timestamp}\n{app_secret}"))
|
||||
string_to_sign = f"{timestamp_str}\n{self.app_secret}"
|
||||
expected = base64.b64encode(
|
||||
hmac.new(
|
||||
self.app_secret.encode("utf-8"),
|
||||
string_to_sign.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
).decode("utf-8")
|
||||
return hmac.compare_digest(sign, expected)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 消息接收 / 解析
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def receive_message(self, headers: dict[str, str], body: bytes) -> IncomingMessage:
|
||||
"""解析钉钉 webhook 事件为标准化 :class:`IncomingMessage`。
|
||||
|
||||
钉钉无 URL verification challenge 流程 — 合法请求直接解析为消息。
|
||||
|
||||
Raises:
|
||||
ValueError: 事件 body 不是合法 JSON。
|
||||
"""
|
||||
try:
|
||||
data: dict[str, Any] = json.loads(body)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError(f"钉钉事件 body 不是合法 JSON: {exc}") from exc
|
||||
|
||||
chat_id = data.get("conversationId", "")
|
||||
user_id = data.get("senderStaffId") or data.get("senderId") or data.get("staffId") or ""
|
||||
msg_id = data.get("msgId", "")
|
||||
session_expired = data.get("sessionWebhookExpiredTime", "")
|
||||
timestamp = str(session_expired) if session_expired else ""
|
||||
|
||||
content = self._extract_content(data)
|
||||
|
||||
return IncomingMessage(
|
||||
channel=ChannelType.DINGTALK,
|
||||
platform_message_id=msg_id,
|
||||
user_id=user_id,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
raw_event=data,
|
||||
timestamp=timestamp,
|
||||
)
|
||||
|
||||
def _extract_content(self, data: dict[str, Any]) -> str:
|
||||
"""从钉钉事件提取文本内容。
|
||||
|
||||
- text 类型:解析 ``text.content``,剥离 @ 机器人前缀。
|
||||
- 其他类型:返回 ``[unsupported message type: {type}]``。
|
||||
"""
|
||||
msgtype = data.get("msgtype", "")
|
||||
if msgtype == "text":
|
||||
text_obj = data.get("text", {})
|
||||
content = text_obj.get("content", "") if isinstance(text_obj, dict) else ""
|
||||
# 剥离 @ 机器人前缀
|
||||
return _MENTION_PREFIX_RE.sub("", content, count=1).strip()
|
||||
return f"[unsupported message type: {msgtype}]"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 消息发送
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send_message(self, message: OutgoingMessage) -> bool:
|
||||
"""向钉钉发送文本消息(oToMessages/batchSend 单聊)。
|
||||
|
||||
Returns:
|
||||
True 表示 HTTP 200。
|
||||
"""
|
||||
try:
|
||||
token = await self._get_access_token()
|
||||
if not token:
|
||||
return False
|
||||
|
||||
client = self._get_client()
|
||||
payload = {
|
||||
"robotCode": self.robot_code,
|
||||
"conversationId": message.chat_id,
|
||||
"msgKey": "sampleText",
|
||||
"msgParam": json.dumps({"content": message.content}),
|
||||
}
|
||||
resp = await client.post(
|
||||
_SEND_MESSAGE_URL,
|
||||
json=payload,
|
||||
headers={"x-acs-dingtalk-access-token": token},
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.error("钉钉 send_message HTTP %d: %s", resp.status_code, resp.text[:200])
|
||||
return False
|
||||
return True
|
||||
except httpx.HTTPError as exc:
|
||||
logger.error("钉钉 send_message 网络错误: %s", exc)
|
||||
return False
|
||||
|
||||
async def _get_access_token(self) -> str | None:
|
||||
"""获取并缓存钉钉 ``accessToken``。"""
|
||||
# 命中缓存
|
||||
if self._token_cache is not None:
|
||||
token, expiry = self._token_cache
|
||||
if time.monotonic() < expiry:
|
||||
return token
|
||||
|
||||
try:
|
||||
client = self._get_client()
|
||||
resp = await client.post(
|
||||
_ACCESS_TOKEN_URL,
|
||||
json={"appKey": self.app_key, "appSecret": self.app_secret},
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.error("钉钉 accessToken HTTP %d: %s", resp.status_code, resp.text[:200])
|
||||
return None
|
||||
data = resp.json()
|
||||
token = data.get("accessToken", "")
|
||||
if not token:
|
||||
return None
|
||||
self._token_cache = (token, time.monotonic() + _TOKEN_CACHE_TTL)
|
||||
return token
|
||||
except httpx.HTTPError as exc:
|
||||
logger.error("钉钉 accessToken 网络错误: %s", exc)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 资源释放
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭 httpx 客户端(如已创建)。"""
|
||||
if self._client is not None:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _header_get(headers: dict[str, str], name: str) -> str | None:
|
||||
"""大小写不敏感的 header 查找。"""
|
||||
if name in headers:
|
||||
return headers[name]
|
||||
lower = name.lower()
|
||||
for k, v in headers.items():
|
||||
if k.lower() == lower:
|
||||
return v
|
||||
return None
|
||||
|
|
@ -0,0 +1,287 @@
|
|||
"""Slack IM 适配器 (U12)。
|
||||
|
||||
实现 :class:`MessageAdapter` 协议,对接 Slack Events API 与 Slash Commands。
|
||||
|
||||
关键设计决策:
|
||||
- 签名:``v0:{timestamp}:{body}`` 的 HMAC-SHA256,与 ``X-Slack-Signature`` 比对。
|
||||
- 时间戳窗口 300s(Slack 官方推荐 5 分钟)。
|
||||
- URL 验证:``url_verification`` 事件抛 :class:`URLVerificationChallenge`。
|
||||
- Events API 与 Slash Commands 双流程:JSON body → Events API;form-encoded → Slash。
|
||||
- ``<@U12345>`` 提及标记剥离。
|
||||
- httpx 客户端懒构造。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
|
||||
from agentkit.channels.base import (
|
||||
ChannelType,
|
||||
IncomingMessage,
|
||||
MessageAdapter,
|
||||
OutgoingMessage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 签名时间戳允许的最大偏移(秒)— Slack 官方推荐 5 分钟
|
||||
_SIGNATURE_MAX_AGE_SECONDS = 300
|
||||
|
||||
# Slack API 端点
|
||||
_SEND_MESSAGE_URL = "https://slack.com/api/chat.postMessage"
|
||||
|
||||
# Slack <@U12345> / <@U12345|name> 提及标记剥离
|
||||
_MENTION_RE = re.compile(r"<@[^>]+>\s*")
|
||||
|
||||
|
||||
class URLVerificationChallenge(Exception):
|
||||
"""Slack URL 验证事件 — webhook 端点需返回 ``{"challenge": ...}`` 响应。
|
||||
|
||||
Slack 在配置 Events API 时发送 ``url_verification`` 事件,要求服务端
|
||||
原样返回 ``challenge`` 字段以验证 URL 可达。
|
||||
"""
|
||||
|
||||
def __init__(self, challenge: str) -> None:
|
||||
super().__init__(f"Slack URL verification challenge: {challenge}")
|
||||
self.challenge = challenge
|
||||
|
||||
|
||||
class SignatureVerificationError(Exception):
|
||||
"""``verification_token`` 校验失败 — 拒绝处理。"""
|
||||
|
||||
|
||||
class SlackMessageAdapter(MessageAdapter):
|
||||
"""Slack IM 适配器。
|
||||
|
||||
生命周期:
|
||||
``__init__`` → :meth:`verify_signature` → :meth:`receive_message`
|
||||
→ :meth:`send_message` → :meth:`close`
|
||||
|
||||
Args:
|
||||
bot_token: Slack Bot User OAuth Token(``xoxb-`` 开头)。
|
||||
signing_secret: Slack 应用 Signing Secret(签名校验)。
|
||||
verification_token: 可选旧版 Verification Token(URL 验证时校验)。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot_token: str,
|
||||
signing_secret: str,
|
||||
verification_token: str | None = None,
|
||||
) -> None:
|
||||
self.bot_token = bot_token
|
||||
self.signing_secret = signing_secret
|
||||
self.verification_token = verification_token
|
||||
# 懒加载 httpx 客户端
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# httpx 客户端懒加载
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_client(self) -> httpx.AsyncClient:
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(timeout=10.0)
|
||||
return self._client
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 签名验证
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def verify_signature(self, headers: dict[str, str], body: bytes) -> bool:
|
||||
"""验证 Slack webhook 签名。
|
||||
|
||||
签名 = ``v0=`` + hmac-sha256(key=signing_secret, msg="v0:{timestamp}:{body}"),
|
||||
与 ``X-Slack-Signature`` 头比对。时间戳超过 5 分钟视为重放攻击。
|
||||
|
||||
Args:
|
||||
headers: HTTP 请求头(键大小写不敏感查找)。
|
||||
body: 原始请求体字节。
|
||||
|
||||
Returns:
|
||||
True 表示签名校验通过。
|
||||
"""
|
||||
signature = _header_get(headers, "X-Slack-Signature")
|
||||
timestamp_str = _header_get(headers, "X-Slack-Request-Timestamp")
|
||||
if not signature or not timestamp_str:
|
||||
return False
|
||||
|
||||
# 时间戳重放保护
|
||||
try:
|
||||
ts = int(timestamp_str)
|
||||
except ValueError:
|
||||
return False
|
||||
now = time.time()
|
||||
if abs(now - ts) > _SIGNATURE_MAX_AGE_SECONDS:
|
||||
logger.warning("Slack webhook 时间戳超出 %ds 窗口 — 拒绝", _SIGNATURE_MAX_AGE_SECONDS)
|
||||
return False
|
||||
|
||||
# 计算签名:v0={timestamp}:{body}
|
||||
base = f"v0:{timestamp_str}:{body.decode('utf-8')}"
|
||||
expected = (
|
||||
"v0="
|
||||
+ hmac.new(
|
||||
self.signing_secret.encode("utf-8"),
|
||||
base.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
)
|
||||
return hmac.compare_digest(signature, expected)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 消息接收 / 解析
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def receive_message(self, headers: dict[str, str], body: bytes) -> IncomingMessage:
|
||||
"""解析 Slack webhook 事件为标准化 :class:`IncomingMessage`。
|
||||
|
||||
- JSON body:Events API 或 URL 验证。
|
||||
- form-encoded body:Slash Command。
|
||||
|
||||
Raises:
|
||||
URLVerificationChallenge: URL 验证事件。
|
||||
SignatureVerificationError: ``verification_token`` 不匹配。
|
||||
"""
|
||||
# 优先尝试 JSON 解析(Events API / URL 验证)
|
||||
try:
|
||||
data: dict[str, Any] = json.loads(body)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
data = {}
|
||||
|
||||
if data:
|
||||
return self._parse_event(data)
|
||||
|
||||
# 非 JSON — 视为 Slash Command(form-encoded)
|
||||
return self._parse_slash_command(body)
|
||||
|
||||
def _parse_event(self, data: dict[str, Any]) -> IncomingMessage:
|
||||
"""解析 Events API 事件。"""
|
||||
# URL 验证流程
|
||||
if data.get("type") == "url_verification":
|
||||
challenge = data.get("challenge", "")
|
||||
token = data.get("token", "")
|
||||
if self.verification_token is not None and not hmac.compare_digest(
|
||||
token, self.verification_token
|
||||
):
|
||||
raise SignatureVerificationError("verification_token 不匹配")
|
||||
raise URLVerificationChallenge(challenge)
|
||||
|
||||
event = data.get("event", {})
|
||||
event_id = data.get("event_id", "")
|
||||
event_type = event.get("type", "")
|
||||
user = str(event.get("user", ""))
|
||||
channel = event.get("channel", "")
|
||||
text = event.get("text", "")
|
||||
ts = event.get("ts", "")
|
||||
|
||||
# 仅支持 message / app_mention 事件
|
||||
if event_type not in ("message", "app_mention"):
|
||||
return IncomingMessage(
|
||||
channel=ChannelType.SLACK,
|
||||
platform_message_id=event_id,
|
||||
user_id=user,
|
||||
chat_id=channel,
|
||||
content=f"[unsupported event type: {event_type}]",
|
||||
raw_event=data,
|
||||
timestamp=ts,
|
||||
)
|
||||
|
||||
# 剥离 <@U12345> 提及标记
|
||||
text = _MENTION_RE.sub("", text).strip()
|
||||
|
||||
return IncomingMessage(
|
||||
channel=ChannelType.SLACK,
|
||||
platform_message_id=event_id,
|
||||
user_id=user,
|
||||
chat_id=channel,
|
||||
content=text,
|
||||
raw_event=data,
|
||||
timestamp=ts,
|
||||
)
|
||||
|
||||
def _parse_slash_command(self, body: bytes) -> IncomingMessage:
|
||||
"""解析 Slash Command(form-encoded body)。"""
|
||||
params = parse_qs(body.decode("utf-8"))
|
||||
text = params.get("text", [""])[0]
|
||||
user_id = params.get("user_id", [""])[0]
|
||||
channel_id = params.get("channel_id", [""])[0]
|
||||
command = params.get("command", [""])[0]
|
||||
|
||||
return IncomingMessage(
|
||||
channel=ChannelType.SLACK,
|
||||
platform_message_id=f"slash-{uuid4()}",
|
||||
user_id=user_id,
|
||||
chat_id=channel_id,
|
||||
content=text,
|
||||
raw_event={"command": command, "form": {k: v[0] for k, v in params.items()}},
|
||||
timestamp=str(time.time()),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 消息发送
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send_message(self, message: OutgoingMessage) -> bool:
|
||||
"""向 Slack 发送文本消息(chat.postMessage)。
|
||||
|
||||
Returns:
|
||||
True 表示 HTTP 200 且响应 ``ok == true``。
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
resp = await client.post(
|
||||
_SEND_MESSAGE_URL,
|
||||
headers={"Authorization": f"Bearer {self.bot_token}"},
|
||||
json={"channel": message.chat_id, "text": message.content},
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.error("Slack send_message HTTP %d: %s", resp.status_code, resp.text[:200])
|
||||
return False
|
||||
data = resp.json()
|
||||
if not data.get("ok"):
|
||||
logger.error(
|
||||
"Slack send_message 业务失败 ok=%s error=%s",
|
||||
data.get("ok"),
|
||||
data.get("error", "")[:200],
|
||||
)
|
||||
return False
|
||||
return True
|
||||
except httpx.HTTPError as exc:
|
||||
logger.error("Slack send_message 网络错误: %s", exc)
|
||||
return False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 资源释放
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭 httpx 客户端(如已创建)。"""
|
||||
if self._client is not None:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _header_get(headers: dict[str, str], name: str) -> str | None:
|
||||
"""大小写不敏感的 header 查找。"""
|
||||
if name in headers:
|
||||
return headers[name]
|
||||
lower = name.lower()
|
||||
for k, v in headers.items():
|
||||
if k.lower() == lower:
|
||||
return v
|
||||
return None
|
||||
|
|
@ -0,0 +1,390 @@
|
|||
"""企业微信 IM 适配器 (U12)。
|
||||
|
||||
实现 :class:`MessageAdapter` 协议,对接企业微信回调 webhook。
|
||||
|
||||
关键设计决策:
|
||||
- EncodingAESKey(43 字符 base64)→ 32 字节 AES-256-CBC 密钥。
|
||||
- 签名:sha1(sorted([token, timestamp, nonce, encrypt])) 与 ``msg_signature`` 比对。
|
||||
- 加密协议:AES-256-CBC,明文 = random(16) + msg_len(4, 大端) + msg + app_id。
|
||||
app_id 校验需匹配 ``corp_id``。
|
||||
- ``access_token`` 简单 TTL 缓存(7200s)。
|
||||
- XML 解析使用 stdlib ``xml.etree.ElementTree``。
|
||||
- httpx 客户端懒构造。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
import httpx
|
||||
|
||||
from agentkit.channels.base import (
|
||||
ChannelType,
|
||||
IncomingMessage,
|
||||
MessageAdapter,
|
||||
OutgoingMessage,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# access_token 缓存 TTL(秒)— 企微 token 有效期 2 小时
|
||||
_TOKEN_CACHE_TTL = 7200.0
|
||||
|
||||
# 企微 API 端点
|
||||
_GET_TOKEN_URL = "https://qyapi.weixin.qq.com/cgi-bin/gettoken"
|
||||
_SEND_MESSAGE_URL = "https://qyapi.weixin.qq.com/cgi-bin/message/send"
|
||||
|
||||
|
||||
class WeComURLVerification(Exception):
|
||||
"""企微 URL 验证事件 — webhook 端点需返回 XML 响应。
|
||||
|
||||
企微在配置回调时发送含 ``EchoStr`` 的加密事件,服务端需解密、
|
||||
重新加密后包装成 XML 响应返回。
|
||||
"""
|
||||
|
||||
def __init__(self, response_xml: str) -> None:
|
||||
super().__init__("WeCom URL verification")
|
||||
self.response_xml = response_xml
|
||||
|
||||
|
||||
class WeComError(Exception):
|
||||
"""企微消息处理错误(解密失败、app_id 不匹配等)。"""
|
||||
|
||||
|
||||
class WeComMessageAdapter(MessageAdapter):
|
||||
"""企业微信 IM 适配器。
|
||||
|
||||
生命周期:
|
||||
``__init__`` → :meth:`verify_signature` → :meth:`receive_message`
|
||||
→ :meth:`send_message` → :meth:`close`
|
||||
|
||||
Args:
|
||||
corp_id: 企业 ID(解密时校验 app_id)。
|
||||
agent_id: 应用 AgentID(发送消息时使用)。
|
||||
corp_secret: 应用凭证密钥(获取 access_token)。
|
||||
token: 回调配置的 Token(签名校验)。
|
||||
encoding_aes_key: 43 字符 EncodingAESKey(AES 密钥来源)。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
corp_id: str,
|
||||
agent_id: int,
|
||||
corp_secret: str,
|
||||
token: str,
|
||||
encoding_aes_key: str,
|
||||
) -> None:
|
||||
self.corp_id = corp_id
|
||||
self.agent_id = agent_id
|
||||
self.corp_secret = corp_secret
|
||||
self.token = token
|
||||
self.encoding_aes_key = encoding_aes_key
|
||||
# 懒加载 httpx 客户端
|
||||
self._client: httpx.AsyncClient | None = None
|
||||
# ponytail: 简单 TTL 缓存。天花板:单实例内存;升级路径:Redis 共享。
|
||||
self._token_cache: tuple[str, float] | None = None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# httpx 客户端懒加载
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _get_client(self) -> httpx.AsyncClient:
|
||||
if self._client is None:
|
||||
self._client = httpx.AsyncClient(timeout=10.0)
|
||||
return self._client
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# AES 密钥 / 加解密
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _decode_aes_key(self) -> bytes:
|
||||
"""EncodingAESKey(43 字符)→ 32 字节 AES 密钥。"""
|
||||
return base64.b64decode(self.encoding_aes_key + "=")
|
||||
|
||||
def _encrypt(self, plaintext_str: str) -> str:
|
||||
"""AES-256-CBC 加密 — 返回 base64(IV + 密文)。
|
||||
|
||||
明文结构:random(16) + msg_len(4, 大端) + msg + corp_id,PKCS7 填充。
|
||||
"""
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives.padding import PKCS7
|
||||
|
||||
key = self._decode_aes_key()
|
||||
iv = os.urandom(16)
|
||||
random_prefix = os.urandom(16)
|
||||
msg_bytes = plaintext_str.encode("utf-8")
|
||||
msg_len = len(msg_bytes).to_bytes(4, "big")
|
||||
app_id = self.corp_id.encode("utf-8")
|
||||
plaintext = random_prefix + msg_len + msg_bytes + app_id
|
||||
|
||||
padder = PKCS7(algorithms.AES.block_size).padder()
|
||||
padded = padder.update(plaintext) + padder.finalize()
|
||||
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
|
||||
encryptor = cipher.encryptor()
|
||||
ciphertext = encryptor.update(padded) + encryptor.finalize()
|
||||
|
||||
return base64.b64encode(iv + ciphertext).decode("utf-8")
|
||||
|
||||
def _decrypt(self, encrypt_b64: str) -> str:
|
||||
"""AES-256-CBC 解密 — 校验 app_id 后返回 msg 明文。
|
||||
|
||||
Raises:
|
||||
WeComError: app_id 不匹配 corp_id,或密文损坏。
|
||||
"""
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives.padding import PKCS7
|
||||
|
||||
key = self._decode_aes_key()
|
||||
ciphertext = base64.b64decode(encrypt_b64)
|
||||
if len(ciphertext) < 17: # IV(16) + 至少 1 字节密文
|
||||
raise WeComError("企微密文长度不足")
|
||||
|
||||
iv = ciphertext[:16]
|
||||
encrypted = ciphertext[16:]
|
||||
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv))
|
||||
decryptor = cipher.decryptor()
|
||||
padded = decryptor.update(encrypted) + decryptor.finalize()
|
||||
|
||||
unpadder = PKCS7(algorithms.AES.block_size).unpadder()
|
||||
plaintext = unpadder.update(padded) + unpadder.finalize()
|
||||
|
||||
# plaintext = random(16) + msg_len(4, 大端) + msg + app_id
|
||||
if len(plaintext) < 20:
|
||||
raise WeComError("企微解密后明文长度不足")
|
||||
msg_len = int.from_bytes(plaintext[16:20], "big")
|
||||
msg = plaintext[20 : 20 + msg_len]
|
||||
app_id = plaintext[20 + msg_len :].decode("utf-8", errors="replace")
|
||||
|
||||
if app_id != self.corp_id:
|
||||
raise WeComError(f"app_id 不匹配: {app_id} != {self.corp_id}")
|
||||
|
||||
return msg.decode("utf-8")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# XML 解析辅助
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _extract_encrypt(self, body: bytes) -> str | None:
|
||||
"""从外层 XML body 提取 ``Encrypt`` 字段。"""
|
||||
try:
|
||||
root = ET.fromstring(body.decode("utf-8"))
|
||||
except (ET.ParseError, UnicodeDecodeError):
|
||||
return None
|
||||
encrypt_elem = root.find("Encrypt")
|
||||
if encrypt_elem is None or encrypt_elem.text is None:
|
||||
return None
|
||||
return encrypt_elem.text
|
||||
|
||||
def _parse_xml(self, xml_str: str) -> dict[str, str]:
|
||||
"""解析 XML 为 ``{tag: text}`` 字典(无命名空间)。"""
|
||||
root = ET.fromstring(xml_str)
|
||||
result: dict[str, str] = {}
|
||||
for child in root:
|
||||
result[child.tag] = child.text or ""
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 签名验证
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def verify_signature(self, headers: dict[str, str], body: bytes) -> bool:
|
||||
"""验证企微 webhook 签名。
|
||||
|
||||
签名 = sha1(sorted([token, timestamp, nonce, encrypt]).join("")),
|
||||
与 query 参数 ``msg_signature`` 比对。timestamp/nonce/msg_signature
|
||||
由 webhook 端点从 query 参数合并到 headers dict 中。
|
||||
|
||||
Args:
|
||||
headers: HTTP 请求头(含已合并的 query 参数)。
|
||||
body: 原始请求体字节(XML)。
|
||||
|
||||
Returns:
|
||||
True 表示签名校验通过。
|
||||
"""
|
||||
msg_signature = _header_get(headers, "msg_signature")
|
||||
timestamp = _header_get(headers, "timestamp")
|
||||
nonce = _header_get(headers, "nonce")
|
||||
if not msg_signature or not timestamp or not nonce:
|
||||
return False
|
||||
|
||||
encrypt = self._extract_encrypt(body)
|
||||
if not encrypt:
|
||||
return False
|
||||
|
||||
# sha1(sorted([token, timestamp, nonce, encrypt]).join(""))
|
||||
parts = sorted([self.token, timestamp, nonce, encrypt])
|
||||
expected = hashlib.sha1("".join(parts).encode("utf-8")).hexdigest()
|
||||
return hmac.compare_digest(msg_signature, expected)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 消息接收 / 解析
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def receive_message(self, headers: dict[str, str], body: bytes) -> IncomingMessage:
|
||||
"""解析企微 webhook 事件为标准化 :class:`IncomingMessage`。
|
||||
|
||||
- URL 验证流程(解密后含 ``EchoStr``):抛 :class:`WeComURLVerification`。
|
||||
- 普通消息:解密后解析内部 XML,提取 FromUserName/MsgId/Content。
|
||||
|
||||
Raises:
|
||||
WeComURLVerification: URL 验证事件,携带 XML 响应。
|
||||
WeComError: 缺少 Encrypt 字段或解密失败。
|
||||
"""
|
||||
encrypt = self._extract_encrypt(body)
|
||||
if not encrypt:
|
||||
raise WeComError("企微 XML body 缺少 Encrypt 字段")
|
||||
|
||||
plaintext = self._decrypt(encrypt)
|
||||
inner = self._parse_xml(plaintext)
|
||||
|
||||
# URL 验证流程 — 内部 XML 包含 EchoStr
|
||||
if "EchoStr" in inner:
|
||||
timestamp = _header_get(headers, "timestamp") or ""
|
||||
nonce = _header_get(headers, "nonce") or ""
|
||||
encrypted_echo = self._encrypt(inner["EchoStr"])
|
||||
# 计算响应签名
|
||||
sig_parts = sorted([self.token, timestamp, nonce, encrypted_echo])
|
||||
sig = hashlib.sha1("".join(sig_parts).encode("utf-8")).hexdigest()
|
||||
response_xml = (
|
||||
f"<xml><Encrypt><![CDATA[{encrypted_echo}]]></Encrypt>"
|
||||
f"<MsgSignature><![CDATA[{sig}]]></MsgSignature>"
|
||||
f"<TimeStamp>{timestamp}</TimeStamp>"
|
||||
f"<Nonce><![CDATA[{nonce}]]></Nonce></xml>"
|
||||
)
|
||||
raise WeComURLVerification(response_xml)
|
||||
|
||||
# 普通消息
|
||||
from_user = inner.get("FromUserName", "")
|
||||
msg_id = inner.get("MsgId", "")
|
||||
content = inner.get("Content", "")
|
||||
create_time = inner.get("CreateTime", "")
|
||||
|
||||
# 群聊消息内容以 ":" 开头 — 剥离
|
||||
if content.startswith(":"):
|
||||
content = content[1:].strip()
|
||||
|
||||
# chat_id: 使用 corp_id + from_user 复合键
|
||||
# ponytail: 简单复合键。天花板:群聊场景需独立 ChatId;
|
||||
# 升级路径:从内部 XML 的 ChatId 字段提取(群聊事件携带)。
|
||||
chat_id = f"{self.corp_id}:{from_user}"
|
||||
|
||||
return IncomingMessage(
|
||||
channel=ChannelType.WECOM,
|
||||
platform_message_id=msg_id,
|
||||
user_id=from_user,
|
||||
chat_id=chat_id,
|
||||
content=content,
|
||||
raw_event=inner,
|
||||
timestamp=create_time,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 消息发送
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def send_message(self, message: OutgoingMessage) -> bool:
|
||||
"""向企微发送文本消息。
|
||||
|
||||
Returns:
|
||||
True 表示 HTTP 200 且 ``errcode == 0``。
|
||||
"""
|
||||
try:
|
||||
token = await self._get_access_token()
|
||||
if not token:
|
||||
return False
|
||||
|
||||
client = self._get_client()
|
||||
payload = {
|
||||
"touser": message.chat_id,
|
||||
"msgtype": "text",
|
||||
"agentid": self.agent_id,
|
||||
"text": {"content": message.content},
|
||||
}
|
||||
resp = await client.post(
|
||||
_SEND_MESSAGE_URL,
|
||||
params={"access_token": token},
|
||||
json=payload,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.error("企微 send_message HTTP %d: %s", resp.status_code, resp.text[:200])
|
||||
return False
|
||||
data = resp.json()
|
||||
if data.get("errcode") != 0:
|
||||
logger.error(
|
||||
"企微 send_message 业务失败 errcode=%s errmsg=%s",
|
||||
data.get("errcode"),
|
||||
data.get("errmsg", "")[:200],
|
||||
)
|
||||
return False
|
||||
return True
|
||||
except httpx.HTTPError as exc:
|
||||
logger.error("企微 send_message 网络错误: %s", exc)
|
||||
return False
|
||||
|
||||
async def _get_access_token(self) -> str | None:
|
||||
"""获取并缓存企微 ``access_token``。"""
|
||||
# 命中缓存
|
||||
if self._token_cache is not None:
|
||||
token, expiry = self._token_cache
|
||||
if time.monotonic() < expiry:
|
||||
return token
|
||||
|
||||
try:
|
||||
client = self._get_client()
|
||||
resp = await client.get(
|
||||
_GET_TOKEN_URL,
|
||||
params={"corpid": self.corp_id, "corpsecret": self.corp_secret},
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
logger.error("企微 access_token HTTP %d: %s", resp.status_code, resp.text[:200])
|
||||
return None
|
||||
data = resp.json()
|
||||
if data.get("errcode") != 0:
|
||||
logger.error(
|
||||
"企微 access_token 业务失败 errcode=%s errmsg=%s",
|
||||
data.get("errcode"),
|
||||
data.get("errmsg", "")[:200],
|
||||
)
|
||||
return None
|
||||
token = data.get("access_token", "")
|
||||
if not token:
|
||||
return None
|
||||
self._token_cache = (token, time.monotonic() + _TOKEN_CACHE_TTL)
|
||||
return token
|
||||
except httpx.HTTPError as exc:
|
||||
logger.error("企微 access_token 网络错误: %s", exc)
|
||||
return None
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 资源释放
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def close(self) -> None:
|
||||
"""关闭 httpx 客户端(如已创建)。"""
|
||||
if self._client is not None:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _header_get(headers: dict[str, str], name: str) -> str | None:
|
||||
"""大小写不敏感的 header 查找。"""
|
||||
if name in headers:
|
||||
return headers[name]
|
||||
lower = name.lower()
|
||||
for k, v in headers.items():
|
||||
if k.lower() == lower:
|
||||
return v
|
||||
return None
|
||||
|
|
@ -24,11 +24,18 @@ import time
|
|||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import Response
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from agentkit.channels.base import ChannelType, MessageAdapter, OutgoingMessage
|
||||
from agentkit.channels.dingtalk import DingTalkMessageAdapter
|
||||
from agentkit.channels.feishu import FeishuMessageAdapter, URLVerificationChallenge
|
||||
from agentkit.channels.secrets import SecretsStore
|
||||
from agentkit.channels.slack import (
|
||||
SlackMessageAdapter,
|
||||
URLVerificationChallenge as SlackURLVerificationChallenge,
|
||||
)
|
||||
from agentkit.channels.wecom import WeComMessageAdapter, WeComURLVerification
|
||||
from agentkit.server.auth.dependencies import require_permission
|
||||
from agentkit.server.auth.permissions import Permission
|
||||
|
||||
|
|
@ -315,16 +322,19 @@ async def delete_channel(
|
|||
async def _build_adapter(channel_id: str) -> MessageAdapter:
|
||||
"""根据渠道配置与 secrets 构造适配器实例。
|
||||
|
||||
支持飞书 / 钉钉 / 企微 / Slack 四种渠道类型,按 ``channel_type`` 分发。
|
||||
|
||||
Raises:
|
||||
HTTPException: 渠道不存在(404)、渠道类型非飞书(400)、缺少必要凭证(500)。
|
||||
HTTPException: 渠道不存在(404)、渠道类型不支持(400)、缺少必要凭证(500)。
|
||||
"""
|
||||
cfg = _channels.get(channel_id)
|
||||
if cfg is None:
|
||||
raise HTTPException(status_code=404, detail=f"渠道 '{channel_id}' 不存在")
|
||||
if cfg["channel_type"] != ChannelType.FEISHU.value:
|
||||
raise HTTPException(status_code=400, detail=f"渠道 '{channel_id}' 不是飞书渠道")
|
||||
|
||||
store = _get_secrets_store()
|
||||
channel_type = cfg["channel_type"]
|
||||
|
||||
if channel_type == ChannelType.FEISHU.value:
|
||||
app_id = await store.get_secret(f"{channel_id}:app_id")
|
||||
app_secret = await store.get_secret(f"{channel_id}:app_secret")
|
||||
encrypt_key = await store.get_secret(f"{channel_id}:encrypt_key")
|
||||
|
|
@ -340,14 +350,72 @@ async def _build_adapter(channel_id: str) -> MessageAdapter:
|
|||
verification_token=verification_token,
|
||||
)
|
||||
|
||||
if channel_type == ChannelType.DINGTALK.value:
|
||||
app_key = await store.get_secret(f"{channel_id}:app_key")
|
||||
app_secret = await store.get_secret(f"{channel_id}:app_secret")
|
||||
robot_code = await store.get_secret(f"{channel_id}:robot_code")
|
||||
token = await store.get_secret(f"{channel_id}:token")
|
||||
if not all([app_key, app_secret, robot_code]):
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"渠道 '{channel_id}' 缺少 dingtalk 凭证"
|
||||
)
|
||||
return DingTalkMessageAdapter(
|
||||
app_key=app_key,
|
||||
app_secret=app_secret,
|
||||
robot_code=robot_code,
|
||||
token=token,
|
||||
)
|
||||
|
||||
if channel_type == ChannelType.WECOM.value:
|
||||
corp_id = await store.get_secret(f"{channel_id}:corp_id")
|
||||
corp_secret = await store.get_secret(f"{channel_id}:corp_secret")
|
||||
token = await store.get_secret(f"{channel_id}:token")
|
||||
encoding_aes_key = await store.get_secret(f"{channel_id}:encoding_aes_key")
|
||||
agent_id_raw = await store.get_secret(f"{channel_id}:agent_id")
|
||||
if not all([corp_id, corp_secret, token, encoding_aes_key, agent_id_raw]):
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"渠道 '{channel_id}' 缺少 wecom 凭证"
|
||||
)
|
||||
try:
|
||||
agent_id = int(agent_id_raw)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"渠道 '{channel_id}' agent_id 不是合法整数",
|
||||
) from exc
|
||||
return WeComMessageAdapter(
|
||||
corp_id=corp_id,
|
||||
agent_id=agent_id,
|
||||
corp_secret=corp_secret,
|
||||
token=token,
|
||||
encoding_aes_key=encoding_aes_key,
|
||||
)
|
||||
|
||||
if channel_type == ChannelType.SLACK.value:
|
||||
bot_token = await store.get_secret(f"{channel_id}:bot_token")
|
||||
signing_secret = await store.get_secret(f"{channel_id}:signing_secret")
|
||||
verification_token = await store.get_secret(f"{channel_id}:verification_token")
|
||||
if not bot_token or not signing_secret:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"渠道 '{channel_id}' 缺少 slack 凭证"
|
||||
)
|
||||
return SlackMessageAdapter(
|
||||
bot_token=bot_token,
|
||||
signing_secret=signing_secret,
|
||||
verification_token=verification_token,
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=400, detail=f"不支持的渠道类型: {channel_type}")
|
||||
|
||||
|
||||
async def _process_inbound_message(
|
||||
app_state: Any, adapter: FeishuMessageAdapter, message: Any
|
||||
app_state: Any, adapter: MessageAdapter, message: Any
|
||||
) -> None:
|
||||
"""后台处理入站消息 — 调用 chat 链路并通过适配器回复。
|
||||
|
||||
整个流程 try/except 包裹,任何异常仅记录日志,不向上抛出
|
||||
(webhook 必须保持响应能力)。``adapter.close()`` 在 finally 中调用。
|
||||
适配器类型不限 — 出站消息的 ``channel`` 取自入站消息以匹配平台。
|
||||
"""
|
||||
try:
|
||||
request_preprocessor = getattr(app_state, "request_preprocessor", None)
|
||||
|
|
@ -395,13 +463,13 @@ async def _process_inbound_message(
|
|||
return
|
||||
|
||||
outgoing = OutgoingMessage(
|
||||
channel=ChannelType.FEISHU,
|
||||
channel=message.channel,
|
||||
chat_id=message.chat_id,
|
||||
content=final_content,
|
||||
)
|
||||
await adapter.send_message(outgoing)
|
||||
except Exception as exc: # noqa: BLE001 — webhook 必须保持响应能力
|
||||
logger.exception("处理飞书入站消息失败: %s", exc)
|
||||
logger.exception("处理入站消息失败: %s", exc)
|
||||
finally:
|
||||
try:
|
||||
await adapter.close()
|
||||
|
|
@ -421,16 +489,19 @@ def _get_client_ip(request: Request) -> str:
|
|||
|
||||
|
||||
@router.post("/channels/{channel_id}/webhook")
|
||||
async def channel_webhook(channel_id: str, request: Request) -> dict[str, Any]:
|
||||
"""飞书 webhook 端点 — 接收平台事件。
|
||||
async def channel_webhook(channel_id: str, request: Request) -> Any:
|
||||
"""渠道 webhook 端点 — 接收平台事件(飞书/钉钉/企微/Slack)。
|
||||
|
||||
安全流程(按顺序):
|
||||
1. Per-IP 限流(100 req/min)— 超限 429
|
||||
2. 读取原始 body(未解析)
|
||||
3. 签名验证 — 失败 401
|
||||
4. Nonce dedup — 重复返回 200(飞书要求 3s 内响应)
|
||||
5. URL verification challenge — 返回 ``{"challenge": ...}``
|
||||
5. URL verification — 飞书/Slack 返回 challenge;企微返回 XML
|
||||
6. 解析消息 → 后台异步处理 → 立即返回 200
|
||||
|
||||
企微通过 query 参数传递 ``msg_signature``/``timestamp``/``nonce``,
|
||||
合并到 headers dict 供适配器读取。
|
||||
"""
|
||||
client_ip = _get_client_ip(request)
|
||||
if not _check_rate_limit(client_ip):
|
||||
|
|
@ -441,21 +512,29 @@ async def channel_webhook(channel_id: str, request: Request) -> dict[str, Any]:
|
|||
adapter = await _build_adapter(channel_id)
|
||||
|
||||
headers_dict = dict(request.headers)
|
||||
# 企微等平台通过 query 参数传递签名信息 — 合并到 headers dict 供适配器读取
|
||||
for key, value in request.query_params.items():
|
||||
if key not in headers_dict:
|
||||
headers_dict[key] = value
|
||||
|
||||
if not await adapter.verify_signature(headers_dict, body):
|
||||
raise HTTPException(status_code=401, detail="签名校验失败")
|
||||
|
||||
# Nonce dedup(可选 — 若头不存在则跳过去重)
|
||||
# Nonce dedup(可选 — 若头不存在则跳过去重;仅飞书携带该头)
|
||||
nonce = request.headers.get("x-lark-request-nonce")
|
||||
if nonce and not _check_nonce_dedup(nonce):
|
||||
return {"code": 0, "msg": "duplicate"}
|
||||
|
||||
try:
|
||||
message = await adapter.receive_message(headers_dict, body)
|
||||
except URLVerificationChallenge as e:
|
||||
# URL 验证流程 — 飞书配置 webhook 时发送
|
||||
except (URLVerificationChallenge, SlackURLVerificationChallenge) as e:
|
||||
# URL 验证流程 — 飞书 / Slack 配置 webhook 时发送
|
||||
return {"challenge": e.challenge}
|
||||
except WeComURLVerification as e:
|
||||
# 企微 URL 验证 — 返回 XML 响应
|
||||
return Response(content=e.response_xml, media_type="application/xml")
|
||||
|
||||
# 异步处理 — 不阻塞 webhook 响应(飞书要求 3s 内返回 200)
|
||||
# 异步处理 — 不阻塞 webhook 响应(平台要求快速返回 200)
|
||||
asyncio.create_task(_process_inbound_message(request.app.state, adapter, message))
|
||||
|
||||
return {"code": 0}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,304 @@
|
|||
"""钉钉 IM 适配器单元测试 (U12)。
|
||||
|
||||
覆盖场景:
|
||||
- 签名校验(有效/无效/过期/缺签名头)
|
||||
- Token 校验(匹配/不匹配/未配置)
|
||||
- 文本消息解析(含 @ 提及剥离)
|
||||
- 不支持消息类型
|
||||
- send_message 成功/失败
|
||||
- accessToken 缓存
|
||||
- senderStaffId/senderId/staffId 回退
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from agentkit.channels.base import ChannelType, IncomingMessage, OutgoingMessage
|
||||
from agentkit.channels.dingtalk import DingTalkMessageAdapter
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sign(app_secret: str, timestamp_ms: int) -> str:
|
||||
"""构造钉钉 Sign 头值:base64(hmac-sha256(key=app_secret, msg="{ts}\n{secret}"))。"""
|
||||
string_to_sign = f"{timestamp_ms}\n{app_secret}"
|
||||
digest = hmac.new(
|
||||
app_secret.encode("utf-8"),
|
||||
string_to_sign.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
return base64.b64encode(digest).decode("utf-8")
|
||||
|
||||
|
||||
def _make_body(
|
||||
*,
|
||||
text: str = "hello",
|
||||
msgtype: str = "text",
|
||||
conversation_id: str = "cid1",
|
||||
sender_staff_id: str | None = "u1",
|
||||
sender_id: str | None = None,
|
||||
staff_id: str | None = None,
|
||||
msg_id: str = "m1",
|
||||
session_expired: int = 1700000000000,
|
||||
) -> dict[str, Any]:
|
||||
"""构造钉钉 webhook 事件 JSON 体。"""
|
||||
body: dict[str, Any] = {
|
||||
"conversationId": conversation_id,
|
||||
"msgId": msg_id,
|
||||
"msgtype": msgtype,
|
||||
"sessionWebhookExpiredTime": session_expired,
|
||||
}
|
||||
if sender_staff_id is not None:
|
||||
body["senderStaffId"] = sender_staff_id
|
||||
if sender_id is not None:
|
||||
body["senderId"] = sender_id
|
||||
if staff_id is not None:
|
||||
body["staffId"] = staff_id
|
||||
if msgtype == "text":
|
||||
body["text"] = {"content": text}
|
||||
else:
|
||||
body["richText"] = {"content": text}
|
||||
return body
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 签名校验
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSignatureVerification:
|
||||
"""钉钉签名校验。"""
|
||||
|
||||
async def test_valid_signature(self):
|
||||
"""正确 Sign + Timestamp 返回 True。"""
|
||||
app_secret = "secret123"
|
||||
adapter = DingTalkMessageAdapter(app_key="k", app_secret=app_secret, robot_code="r")
|
||||
ts_ms = int(time.time() * 1000)
|
||||
headers = {"Sign": _sign(app_secret, ts_ms), "Timestamp": str(ts_ms)}
|
||||
assert await adapter.verify_signature(headers, b"{}") is True
|
||||
|
||||
async def test_invalid_signature(self):
|
||||
"""篡改 Sign 返回 False。"""
|
||||
app_secret = "secret123"
|
||||
adapter = DingTalkMessageAdapter(app_key="k", app_secret=app_secret, robot_code="r")
|
||||
ts_ms = int(time.time() * 1000)
|
||||
headers = {"Sign": "tampered_signature", "Timestamp": str(ts_ms)}
|
||||
assert await adapter.verify_signature(headers, b"{}") is False
|
||||
|
||||
async def test_expired_timestamp(self):
|
||||
"""时间戳超过 1 小时返回 False。"""
|
||||
app_secret = "secret123"
|
||||
adapter = DingTalkMessageAdapter(app_key="k", app_secret=app_secret, robot_code="r")
|
||||
ts_ms = int((time.time() - 3700) * 1000)
|
||||
headers = {"Sign": _sign(app_secret, ts_ms), "Timestamp": str(ts_ms)}
|
||||
assert await adapter.verify_signature(headers, b"{}") is False
|
||||
|
||||
async def test_missing_signature_headers(self):
|
||||
"""缺 Sign + Timestamp 头且未配置 token 返回 False。"""
|
||||
adapter = DingTalkMessageAdapter(app_key="k", app_secret="s", robot_code="r")
|
||||
assert await adapter.verify_signature({}, b"{}") is False
|
||||
|
||||
|
||||
class TestTokenVerification:
|
||||
"""Token 校验。"""
|
||||
|
||||
async def test_token_mismatch(self):
|
||||
"""配置 token 后 Token 头不匹配返回 False。"""
|
||||
adapter = DingTalkMessageAdapter(
|
||||
app_key="k", app_secret="s", robot_code="r", token="abc"
|
||||
)
|
||||
headers = {"Token": "wrong"}
|
||||
assert await adapter.verify_signature(headers, b"{}") is False
|
||||
|
||||
async def test_token_match_without_sign(self):
|
||||
"""配置 token 且 Token 头匹配(无 Sign 头)返回 True。"""
|
||||
adapter = DingTalkMessageAdapter(
|
||||
app_key="k", app_secret="s", robot_code="r", token="abc"
|
||||
)
|
||||
headers = {"Token": "abc"}
|
||||
assert await adapter.verify_signature(headers, b"{}") is True
|
||||
|
||||
async def test_token_none_skips_token_check(self):
|
||||
"""token=None 时无需 Token 头,仅凭签名放行。"""
|
||||
app_secret = "secret123"
|
||||
adapter = DingTalkMessageAdapter(
|
||||
app_key="k", app_secret=app_secret, robot_code="r", token=None
|
||||
)
|
||||
ts_ms = int(time.time() * 1000)
|
||||
headers = {"Sign": _sign(app_secret, ts_ms), "Timestamp": str(ts_ms)}
|
||||
assert await adapter.verify_signature(headers, b"{}") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 消息解析
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageParsing:
|
||||
"""文本消息解析。"""
|
||||
|
||||
async def test_text_message_parsing(self):
|
||||
"""文本事件解析为 IncomingMessage。"""
|
||||
adapter = DingTalkMessageAdapter(app_key="k", app_secret="s", robot_code="r")
|
||||
body = json.dumps(_make_body(text="hello world")).encode("utf-8")
|
||||
msg = await adapter.receive_message({}, body)
|
||||
assert isinstance(msg, IncomingMessage)
|
||||
assert msg.channel == ChannelType.DINGTALK
|
||||
assert msg.content == "hello world"
|
||||
assert msg.chat_id == "cid1"
|
||||
assert msg.user_id == "u1"
|
||||
assert msg.platform_message_id == "m1"
|
||||
assert msg.timestamp == "1700000000000"
|
||||
|
||||
async def test_mention_stripping(self):
|
||||
"""@ 机器人前缀被剥离。"""
|
||||
adapter = DingTalkMessageAdapter(app_key="k", app_secret="s", robot_code="r")
|
||||
body = json.dumps(_make_body(text="@robotName hello")).encode("utf-8")
|
||||
msg = await adapter.receive_message({}, body)
|
||||
assert msg.content == "hello"
|
||||
|
||||
async def test_unsupported_message_type(self):
|
||||
"""非 text 类型返回 unsupported 占位内容。"""
|
||||
adapter = DingTalkMessageAdapter(app_key="k", app_secret="s", robot_code="r")
|
||||
body = json.dumps(_make_body(msgtype="image")).encode("utf-8")
|
||||
msg = await adapter.receive_message({}, body)
|
||||
assert msg.content.startswith("[unsupported message type: image]")
|
||||
|
||||
async def test_senderid_fallback(self):
|
||||
"""缺少 senderStaffId 时回退到 senderId。"""
|
||||
adapter = DingTalkMessageAdapter(app_key="k", app_secret="s", robot_code="r")
|
||||
body = json.dumps(
|
||||
_make_body(sender_staff_id=None, sender_id="fallback_user")
|
||||
).encode("utf-8")
|
||||
msg = await adapter.receive_message({}, body)
|
||||
assert msg.user_id == "fallback_user"
|
||||
|
||||
async def test_staffid_fallback(self):
|
||||
"""缺少 senderStaffId 与 senderId 时回退到 staffId。"""
|
||||
adapter = DingTalkMessageAdapter(app_key="k", app_secret="s", robot_code="r")
|
||||
body = json.dumps(
|
||||
_make_body(sender_staff_id=None, sender_id=None, staff_id="staff_user")
|
||||
).encode("utf-8")
|
||||
msg = await adapter.receive_message({}, body)
|
||||
assert msg.user_id == "staff_user"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send_message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendMessage:
|
||||
"""send_message 行为。"""
|
||||
|
||||
async def test_send_message_success(self):
|
||||
"""HTTP 200 返回 True,且 send 调用携带正确的 access token 头。"""
|
||||
adapter = DingTalkMessageAdapter(app_key="k", app_secret="s", robot_code="r")
|
||||
mock_token = MagicMock()
|
||||
mock_token.status_code = 200
|
||||
mock_token.json.return_value = {"accessToken": "tok_123"}
|
||||
|
||||
mock_send = MagicMock()
|
||||
mock_send.status_code = 200
|
||||
mock_send.json.return_value = {"processQueryKey": "x"}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(side_effect=[mock_token, mock_send])
|
||||
adapter._client = mock_client
|
||||
|
||||
out = OutgoingMessage(channel=ChannelType.DINGTALK, chat_id="c1", content="hi")
|
||||
assert await adapter.send_message(out) is True
|
||||
|
||||
send_call = mock_client.post.call_args_list[1]
|
||||
assert "robot/oToMessages/batchSend" in send_call.args[0]
|
||||
assert send_call.kwargs["headers"]["x-acs-dingtalk-access-token"] == "tok_123"
|
||||
assert send_call.kwargs["json"]["robotCode"] == "r"
|
||||
|
||||
async def test_send_message_failure(self):
|
||||
"""send 返回非 200 返回 False。"""
|
||||
adapter = DingTalkMessageAdapter(app_key="k", app_secret="s", robot_code="r")
|
||||
mock_token = MagicMock()
|
||||
mock_token.status_code = 200
|
||||
mock_token.json.return_value = {"accessToken": "tok_x"}
|
||||
|
||||
mock_send = MagicMock()
|
||||
mock_send.status_code = 400
|
||||
mock_send.text = "invalid request"
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(side_effect=[mock_token, mock_send])
|
||||
adapter._client = mock_client
|
||||
|
||||
out = OutgoingMessage(channel=ChannelType.DINGTALK, chat_id="c1", content="hi")
|
||||
assert await adapter.send_message(out) is False
|
||||
|
||||
async def test_send_message_token_fetch_failure(self):
|
||||
"""获取 accessToken 失败返回 False。"""
|
||||
adapter = DingTalkMessageAdapter(app_key="k", app_secret="s", robot_code="r")
|
||||
mock_token = MagicMock()
|
||||
mock_token.status_code = 401
|
||||
mock_token.text = "invalid"
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_token)
|
||||
adapter._client = mock_client
|
||||
|
||||
out = OutgoingMessage(channel=ChannelType.DINGTALK, chat_id="c1", content="hi")
|
||||
assert await adapter.send_message(out) is False
|
||||
|
||||
async def test_access_token_caching(self):
|
||||
"""同 TTL 内的两次 send_message 只拉取一次 accessToken。"""
|
||||
adapter = DingTalkMessageAdapter(app_key="k", app_secret="s", robot_code="r")
|
||||
mock_token = MagicMock()
|
||||
mock_token.status_code = 200
|
||||
mock_token.json.return_value = {"accessToken": "cached_tok"}
|
||||
|
||||
mock_send = MagicMock()
|
||||
mock_send.status_code = 200
|
||||
mock_send.json.return_value = {"processQueryKey": "x"}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
# 第一次:token + send;第二次:仅 send(token 走缓存)
|
||||
mock_client.post = AsyncMock(side_effect=[mock_token, mock_send, mock_send])
|
||||
adapter._client = mock_client
|
||||
|
||||
out = OutgoingMessage(channel=ChannelType.DINGTALK, chat_id="c1", content="hi")
|
||||
await adapter.send_message(out)
|
||||
await adapter.send_message(out)
|
||||
|
||||
# 仅 1 次 token 调用 + 2 次 send 调用 = 3 次(未缓存会是 4 次)
|
||||
assert mock_client.post.call_count == 3
|
||||
first_url = mock_client.post.call_args_list[0].args[0]
|
||||
assert "accessToken" in first_url
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 资源释放
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClose:
|
||||
"""资源释放。"""
|
||||
|
||||
async def test_close_no_client_is_noop(self):
|
||||
"""未创建 httpx 客户端时 close 不抛异常。"""
|
||||
adapter = DingTalkMessageAdapter(app_key="k", app_secret="s", robot_code="r")
|
||||
await adapter.close()
|
||||
|
||||
async def test_close_resets_client(self):
|
||||
"""close 后客户端引用清空。"""
|
||||
adapter = DingTalkMessageAdapter(app_key="k", app_secret="s", robot_code="r")
|
||||
adapter._get_client()
|
||||
assert adapter._client is not None
|
||||
await adapter.close()
|
||||
assert adapter._client is None
|
||||
|
|
@ -574,9 +574,9 @@ class TestWebhookErrors:
|
|||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
def test_non_feishu_channel_returns_400(self, webhook_client):
|
||||
"""POST 到非飞书渠道 webhook 返回 400。"""
|
||||
# 注册一个钉钉渠道
|
||||
def test_dingtalk_channel_without_secrets_returns_500(self, webhook_client):
|
||||
"""钉钉渠道未配置凭证时 webhook 返回 500(U12 后钉钉受支持)。"""
|
||||
# 注册一个钉钉渠道(不配置凭证)
|
||||
webhook_client.post(
|
||||
"/api/v1/channels",
|
||||
json={
|
||||
|
|
@ -590,7 +590,7 @@ class TestWebhookErrors:
|
|||
content=b"{}",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
assert resp.status_code == 500
|
||||
|
||||
|
||||
class TestWebhookImmediateResponse:
|
||||
|
|
|
|||
|
|
@ -0,0 +1,320 @@
|
|||
"""Slack IM 适配器单元测试 (U12)。
|
||||
|
||||
覆盖场景:
|
||||
- 签名校验(有效/无效/过期)
|
||||
- URL 验证(challenge / token 不匹配)
|
||||
- Events API 消息解析
|
||||
- Slash Command 解析
|
||||
- <@U12345> 提及剥离
|
||||
- 不支持事件类型
|
||||
- send_message 成功/失败
|
||||
- bot_token 构造器存储
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.channels.base import ChannelType, IncomingMessage, OutgoingMessage
|
||||
from agentkit.channels.slack import (
|
||||
SlackMessageAdapter,
|
||||
SignatureVerificationError,
|
||||
URLVerificationChallenge,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sign(signing_secret: str, timestamp: int, body: bytes) -> str:
|
||||
"""构造 Slack X-Slack-Signature 头值:v0= + hmac-sha256(signing_secret, "v0:{ts}:{body}")。"""
|
||||
base = f"v0:{timestamp}:{body.decode('utf-8')}"
|
||||
digest = hmac.new(
|
||||
signing_secret.encode("utf-8"),
|
||||
base.encode("utf-8"),
|
||||
hashlib.sha256,
|
||||
).hexdigest()
|
||||
return f"v0={digest}"
|
||||
|
||||
|
||||
def _make_event_body(
|
||||
*,
|
||||
text: str = "hello",
|
||||
event_type: str = "message",
|
||||
event_id: str = "evt_001",
|
||||
user: str = "U123",
|
||||
channel: str = "C456",
|
||||
ts: str = "1700000000.000123",
|
||||
) -> dict[str, Any]:
|
||||
"""构造 Slack Events API 事件 JSON 体。"""
|
||||
return {
|
||||
"event_id": event_id,
|
||||
"type": "event_callback",
|
||||
"event": {
|
||||
"type": event_type,
|
||||
"user": user,
|
||||
"channel": channel,
|
||||
"text": text,
|
||||
"ts": ts,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 签名校验
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSignatureVerification:
|
||||
"""Slack 签名校验。"""
|
||||
|
||||
async def test_valid_signature(self):
|
||||
"""正确 v0 签名返回 True。"""
|
||||
signing_secret = "sh_secret"
|
||||
adapter = SlackMessageAdapter(bot_token="xoxb-x", signing_secret=signing_secret)
|
||||
body = json.dumps(_make_event_body()).encode("utf-8")
|
||||
ts = int(time.time())
|
||||
headers = {
|
||||
"X-Slack-Signature": _sign(signing_secret, ts, body),
|
||||
"X-Slack-Request-Timestamp": str(ts),
|
||||
}
|
||||
assert await adapter.verify_signature(headers, body) is True
|
||||
|
||||
async def test_invalid_signature(self):
|
||||
"""篡改签名返回 False。"""
|
||||
signing_secret = "sh_secret"
|
||||
adapter = SlackMessageAdapter(bot_token="xoxb-x", signing_secret=signing_secret)
|
||||
body = json.dumps(_make_event_body()).encode("utf-8")
|
||||
ts = int(time.time())
|
||||
headers = {"X-Slack-Signature": "v0=tampered", "X-Slack-Request-Timestamp": str(ts)}
|
||||
assert await adapter.verify_signature(headers, body) is False
|
||||
|
||||
async def test_expired_timestamp(self):
|
||||
"""时间戳超过 5 分钟返回 False。"""
|
||||
signing_secret = "sh_secret"
|
||||
adapter = SlackMessageAdapter(bot_token="xoxb-x", signing_secret=signing_secret)
|
||||
body = json.dumps(_make_event_body()).encode("utf-8")
|
||||
old_ts = int(time.time()) - 600
|
||||
headers = {
|
||||
"X-Slack-Signature": _sign(signing_secret, old_ts, body),
|
||||
"X-Slack-Request-Timestamp": str(old_ts),
|
||||
}
|
||||
assert await adapter.verify_signature(headers, body) is False
|
||||
|
||||
async def test_missing_signature_headers(self):
|
||||
"""缺签名头返回 False。"""
|
||||
adapter = SlackMessageAdapter(bot_token="xoxb-x", signing_secret="s")
|
||||
assert await adapter.verify_signature({}, b"{}") is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# URL 验证
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestURLVerification:
|
||||
"""Slack URL 验证流程。"""
|
||||
|
||||
async def test_url_verification_raises_challenge(self):
|
||||
"""url_verification 事件抛 URLVerificationChallenge。"""
|
||||
adapter = SlackMessageAdapter(bot_token="xoxb-x", signing_secret="s")
|
||||
body = json.dumps(
|
||||
{"type": "url_verification", "challenge": "verify_abc", "token": "t"}
|
||||
).encode("utf-8")
|
||||
with pytest.raises(URLVerificationChallenge) as exc_info:
|
||||
await adapter.receive_message({}, body)
|
||||
assert exc_info.value.challenge == "verify_abc"
|
||||
|
||||
async def test_url_verification_token_mismatch_raises(self):
|
||||
"""verification_token 不匹配抛 SignatureVerificationError。"""
|
||||
adapter = SlackMessageAdapter(
|
||||
bot_token="xoxb-x",
|
||||
signing_secret="s",
|
||||
verification_token="right_token",
|
||||
)
|
||||
body = json.dumps(
|
||||
{"type": "url_verification", "challenge": "abc", "token": "wrong_token"}
|
||||
).encode("utf-8")
|
||||
with pytest.raises(SignatureVerificationError):
|
||||
await adapter.receive_message({}, body)
|
||||
|
||||
async def test_url_verification_token_match_passes(self):
|
||||
"""verification_token 匹配时正常抛出 challenge。"""
|
||||
adapter = SlackMessageAdapter(
|
||||
bot_token="xoxb-x",
|
||||
signing_secret="s",
|
||||
verification_token="right_token",
|
||||
)
|
||||
body = json.dumps(
|
||||
{"type": "url_verification", "challenge": "ok123", "token": "right_token"}
|
||||
).encode("utf-8")
|
||||
with pytest.raises(URLVerificationChallenge) as exc_info:
|
||||
await adapter.receive_message({}, body)
|
||||
assert exc_info.value.challenge == "ok123"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 消息解析
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestEventsAPIParsing:
|
||||
"""Events API 消息解析。"""
|
||||
|
||||
async def test_event_message_parsing(self):
|
||||
"""Events API 消息解析为 IncomingMessage。"""
|
||||
adapter = SlackMessageAdapter(bot_token="xoxb-x", signing_secret="s")
|
||||
body = json.dumps(_make_event_body(text="hello slack")).encode("utf-8")
|
||||
msg = await adapter.receive_message({}, body)
|
||||
assert isinstance(msg, IncomingMessage)
|
||||
assert msg.channel == ChannelType.SLACK
|
||||
assert msg.content == "hello slack"
|
||||
assert msg.platform_message_id == "evt_001"
|
||||
assert msg.user_id == "U123"
|
||||
assert msg.chat_id == "C456"
|
||||
assert msg.timestamp == "1700000000.000123"
|
||||
|
||||
async def test_mention_stripping(self):
|
||||
"""<@U12345> 提及标记被剥离。"""
|
||||
adapter = SlackMessageAdapter(bot_token="xoxb-x", signing_secret="s")
|
||||
body = json.dumps(_make_event_body(text="<@U12345> hello there")).encode("utf-8")
|
||||
msg = await adapter.receive_message({}, body)
|
||||
assert msg.content == "hello there"
|
||||
|
||||
async def test_mention_with_name_stripped(self):
|
||||
"""<@U12345|name> 形式的提及标记也被剥离。"""
|
||||
adapter = SlackMessageAdapter(bot_token="xoxb-x", signing_secret="s")
|
||||
body = json.dumps(_make_event_body(text="<@U999|bob> hi")).encode("utf-8")
|
||||
msg = await adapter.receive_message({}, body)
|
||||
assert msg.content == "hi"
|
||||
|
||||
async def test_unsupported_event_type(self):
|
||||
"""非 message/app_mention 事件返回 unsupported 占位内容。"""
|
||||
adapter = SlackMessageAdapter(bot_token="xoxb-x", signing_secret="s")
|
||||
body = json.dumps(_make_event_body(event_type="reaction_added")).encode("utf-8")
|
||||
msg = await adapter.receive_message({}, body)
|
||||
assert msg.content.startswith("[unsupported event type: reaction_added]")
|
||||
|
||||
|
||||
class TestSlashCommandParsing:
|
||||
"""Slash Command 解析。"""
|
||||
|
||||
async def test_slash_command_parsing(self):
|
||||
"""form-encoded slash command 解析为 IncomingMessage。"""
|
||||
adapter = SlackMessageAdapter(bot_token="xoxb-x", signing_secret="s")
|
||||
form = urlencode(
|
||||
{
|
||||
"text": "echo me",
|
||||
"user_id": "U_slash",
|
||||
"channel_id": "C_slash",
|
||||
"command": "/echo",
|
||||
"response_url": "https://hooks.slack.com/x",
|
||||
}
|
||||
)
|
||||
body = form.encode("utf-8")
|
||||
msg = await adapter.receive_message(
|
||||
{"Content-Type": "application/x-www-form-urlencoded"}, body
|
||||
)
|
||||
assert msg.channel == ChannelType.SLACK
|
||||
assert msg.content == "echo me"
|
||||
assert msg.user_id == "U_slash"
|
||||
assert msg.chat_id == "C_slash"
|
||||
assert msg.platform_message_id.startswith("slash-")
|
||||
assert msg.raw_event["command"] == "/echo"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send_message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendMessage:
|
||||
"""send_message 行为。"""
|
||||
|
||||
async def test_send_message_success(self):
|
||||
"""HTTP 200 + ok=true 返回 True。"""
|
||||
adapter = SlackMessageAdapter(bot_token="xoxb-123", signing_secret="s")
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {"ok": True, "channel": "C1", "ts": "1.0"}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
adapter._client = mock_client
|
||||
|
||||
out = OutgoingMessage(channel=ChannelType.SLACK, chat_id="C1", content="hi")
|
||||
assert await adapter.send_message(out) is True
|
||||
|
||||
call = mock_client.post.call_args
|
||||
assert "chat.postMessage" in call.args[0]
|
||||
assert call.kwargs["headers"]["Authorization"] == "Bearer xoxb-123"
|
||||
assert call.kwargs["json"]["channel"] == "C1"
|
||||
|
||||
async def test_send_message_failure(self):
|
||||
"""HTTP 200 但 ok=false 返回 False。"""
|
||||
adapter = SlackMessageAdapter(bot_token="xoxb-x", signing_secret="s")
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {"ok": False, "error": "channel_not_found"}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
adapter._client = mock_client
|
||||
|
||||
out = OutgoingMessage(channel=ChannelType.SLACK, chat_id="C1", content="hi")
|
||||
assert await adapter.send_message(out) is False
|
||||
|
||||
async def test_send_message_http_error(self):
|
||||
"""非 200 HTTP 状态返回 False。"""
|
||||
adapter = SlackMessageAdapter(bot_token="xoxb-x", signing_secret="s")
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 500
|
||||
mock_resp.text = "server error"
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post = AsyncMock(return_value=mock_resp)
|
||||
adapter._client = mock_client
|
||||
|
||||
out = OutgoingMessage(channel=ChannelType.SLACK, chat_id="C1", content="hi")
|
||||
assert await adapter.send_message(out) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 构造器 / 资源释放
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConstructor:
|
||||
"""构造器与资源释放。"""
|
||||
|
||||
def test_bot_token_stored(self):
|
||||
"""bot_token 由构造器存储。"""
|
||||
adapter = SlackMessageAdapter(
|
||||
bot_token="xoxb-secret-token", signing_secret="s", verification_token="vt"
|
||||
)
|
||||
assert adapter.bot_token == "xoxb-secret-token"
|
||||
assert adapter.signing_secret == "s"
|
||||
assert adapter.verification_token == "vt"
|
||||
|
||||
async def test_close_no_client_is_noop(self):
|
||||
"""未创建 httpx 客户端时 close 不抛异常。"""
|
||||
adapter = SlackMessageAdapter(bot_token="x", signing_secret="s")
|
||||
await adapter.close()
|
||||
|
||||
async def test_close_resets_client(self):
|
||||
"""close 后客户端引用清空。"""
|
||||
adapter = SlackMessageAdapter(bot_token="x", signing_secret="s")
|
||||
adapter._get_client()
|
||||
assert adapter._client is not None
|
||||
await adapter.close()
|
||||
assert adapter._client is None
|
||||
|
|
@ -0,0 +1,323 @@
|
|||
"""企业微信 IM 适配器单元测试 (U12)。
|
||||
|
||||
覆盖场景:
|
||||
- EncodingAESKey 解码(43 字符 → 32 字节)
|
||||
- 签名校验(有效/无效)
|
||||
- AES-256-CBC 加解密往返
|
||||
- URL 验证流程(XML 响应)
|
||||
- 消息解析(FromUserName/MsgId/Content)
|
||||
- send_message 成功/失败
|
||||
- access_token 缓存
|
||||
- app_id 不匹配 → 抛异常
|
||||
- chat_id 复合键
|
||||
- 群聊消息 ":" 前缀剥离
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from xml.etree import ElementTree as ET
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.channels.base import ChannelType, IncomingMessage, OutgoingMessage
|
||||
from agentkit.channels.wecom import WeComError, WeComMessageAdapter, WeComURLVerification
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 辅助函数
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_aes_key() -> str:
|
||||
"""生成合法的 43 字符 EncodingAESKey(解码为 32 字节)。"""
|
||||
return base64.b64encode(b"\x11" * 32).decode().rstrip("=")
|
||||
|
||||
|
||||
def _make_adapter(
|
||||
*,
|
||||
corp_id: str = "corp_test",
|
||||
agent_id: int = 1000001,
|
||||
corp_secret: str = "secret_test",
|
||||
token: str = "token_test",
|
||||
encoding_aes_key: str | None = None,
|
||||
) -> WeComMessageAdapter:
|
||||
"""构造测试用企微适配器。"""
|
||||
return WeComMessageAdapter(
|
||||
corp_id=corp_id,
|
||||
agent_id=agent_id,
|
||||
corp_secret=corp_secret,
|
||||
token=token,
|
||||
encoding_aes_key=encoding_aes_key or _make_aes_key(),
|
||||
)
|
||||
|
||||
|
||||
def _build_inner_xml(fields: dict[str, str]) -> str:
|
||||
"""由字段字典构造内部 XML 字符串。"""
|
||||
parts = ["<xml>"]
|
||||
for tag, value in fields.items():
|
||||
parts.append(f"<{tag}>{value}</{tag}>")
|
||||
parts.append("</xml>")
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def _build_body(adapter: WeComMessageAdapter, inner_xml: str) -> bytes:
|
||||
"""加密内部 XML 并包装为外层 webhook body。"""
|
||||
encrypt = adapter._encrypt(inner_xml)
|
||||
return (
|
||||
f"<xml><ToUserName>{adapter.corp_id}</ToUserName>"
|
||||
f"<Encrypt>{encrypt}</Encrypt></xml>"
|
||||
).encode("utf-8")
|
||||
|
||||
|
||||
def _signature(token: str, timestamp: str, nonce: str, encrypt: str) -> str:
|
||||
"""计算企微签名:sha1(sorted([token, ts, nonce, encrypt]).join(""))。"""
|
||||
parts = sorted([token, timestamp, nonce, encrypt])
|
||||
return hashlib.sha1("".join(parts).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AES 密钥 / 加解密
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAESKey:
|
||||
"""EncodingAESKey 解码。"""
|
||||
|
||||
def test_decode_aes_key_to_32_bytes(self):
|
||||
"""43 字符 EncodingAESKey 解码为 32 字节。"""
|
||||
adapter = _make_adapter(encoding_aes_key=_make_aes_key())
|
||||
key = adapter._decode_aes_key()
|
||||
assert len(key) == 32
|
||||
|
||||
def test_encrypt_decrypt_roundtrip(self):
|
||||
"""加密后解密能恢复原文。"""
|
||||
adapter = _make_adapter()
|
||||
encrypted = adapter._encrypt("hello 企微")
|
||||
assert adapter._decrypt(encrypted) == "hello 企微"
|
||||
|
||||
|
||||
class TestSignatureVerification:
|
||||
"""签名校验。"""
|
||||
|
||||
async def test_valid_signature(self):
|
||||
"""正确 msg_signature 返回 True。"""
|
||||
adapter = _make_adapter()
|
||||
inner_xml = _build_inner_xml({"MsgId": "m1", "Content": "hi"})
|
||||
body = _build_body(adapter, inner_xml)
|
||||
encrypt = adapter._extract_encrypt(body) or ""
|
||||
ts, nonce = "1609459200", "n1"
|
||||
sig = _signature(adapter.token, ts, nonce, encrypt)
|
||||
headers = {"msg_signature": sig, "timestamp": ts, "nonce": nonce}
|
||||
assert await adapter.verify_signature(headers, body) is True
|
||||
|
||||
async def test_invalid_signature(self):
|
||||
"""篡改 msg_signature 返回 False。"""
|
||||
adapter = _make_adapter()
|
||||
inner_xml = _build_inner_xml({"MsgId": "m1", "Content": "hi"})
|
||||
body = _build_body(adapter, inner_xml)
|
||||
ts, nonce = "1609459200", "n1"
|
||||
headers = {"msg_signature": "tampered", "timestamp": ts, "nonce": nonce}
|
||||
assert await adapter.verify_signature(headers, body) is False
|
||||
|
||||
async def test_missing_query_params(self):
|
||||
"""缺少 msg_signature/timestamp/nonce 返回 False。"""
|
||||
adapter = _make_adapter()
|
||||
body = b"<xml><Encrypt>x</Encrypt></xml>"
|
||||
assert await adapter.verify_signature({}, body) is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# URL 验证流程
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestURLVerification:
|
||||
"""企微 URL 验证流程。"""
|
||||
|
||||
async def test_url_verification_raises_with_xml(self):
|
||||
"""含 EchoStr 的加密事件抛 WeComURLVerification,响应 XML 可解密回 echo。"""
|
||||
adapter = _make_adapter()
|
||||
inner_xml = _build_inner_xml({"EchoStr": "echo123"})
|
||||
body = _build_body(adapter, inner_xml)
|
||||
ts, nonce = "1609459200", "n1"
|
||||
headers = {"timestamp": ts, "nonce": nonce}
|
||||
|
||||
with pytest.raises(WeComURLVerification) as exc_info:
|
||||
await adapter.receive_message(headers, body)
|
||||
|
||||
# 验证响应 XML 可解密回 echo 值
|
||||
response_xml = exc_info.value.response_xml
|
||||
root = ET.fromstring(response_xml)
|
||||
encrypted_echo = root.find("Encrypt").text # type: ignore[union-attr]
|
||||
assert adapter._decrypt(encrypted_echo) == "echo123"
|
||||
|
||||
async def test_url_verification_missing_encrypt_raises(self):
|
||||
"""body 缺少 Encrypt 字段抛 WeComError。"""
|
||||
adapter = _make_adapter()
|
||||
body = b"<xml><ToUserName>x</ToUserName></xml>"
|
||||
with pytest.raises(WeComError):
|
||||
await adapter.receive_message({}, body)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 消息解析
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMessageParsing:
|
||||
"""消息解析。"""
|
||||
|
||||
async def test_receive_message_fields(self):
|
||||
"""解析普通消息提取 FromUserName/MsgId/Content。"""
|
||||
adapter = _make_adapter()
|
||||
inner_xml = _build_inner_xml(
|
||||
{
|
||||
"FromUserName": "user1",
|
||||
"MsgId": "msg_001",
|
||||
"Content": "hello wecom",
|
||||
"CreateTime": "1609459200",
|
||||
"MsgType": "text",
|
||||
}
|
||||
)
|
||||
body = _build_body(adapter, inner_xml)
|
||||
msg = await adapter.receive_message({}, body)
|
||||
assert isinstance(msg, IncomingMessage)
|
||||
assert msg.channel == ChannelType.WECOM
|
||||
assert msg.user_id == "user1"
|
||||
assert msg.platform_message_id == "msg_001"
|
||||
assert msg.content == "hello wecom"
|
||||
assert msg.timestamp == "1609459200"
|
||||
|
||||
async def test_chat_id_composition(self):
|
||||
"""chat_id 由 corp_id + from_user 复合而成。"""
|
||||
adapter = _make_adapter(corp_id="my_corp")
|
||||
inner_xml = _build_inner_xml({"FromUserName": "userA", "MsgId": "m1", "Content": "hi"})
|
||||
body = _build_body(adapter, inner_xml)
|
||||
msg = await adapter.receive_message({}, body)
|
||||
assert msg.chat_id == "my_corp:userA"
|
||||
|
||||
async def test_chat_room_mention_stripping(self):
|
||||
"""群聊消息内容以 ":" 开头时被剥离。"""
|
||||
adapter = _make_adapter()
|
||||
inner_xml = _build_inner_xml(
|
||||
{"FromUserName": "u1", "MsgId": "m1", "Content": ":hello room"}
|
||||
)
|
||||
body = _build_body(adapter, inner_xml)
|
||||
msg = await adapter.receive_message({}, body)
|
||||
assert msg.content == "hello room"
|
||||
|
||||
|
||||
class TestAppIDMismatch:
|
||||
"""app_id 校验。"""
|
||||
|
||||
async def test_app_id_mismatch_raises(self):
|
||||
"""解密后 app_id 不匹配 corp_id 抛 WeComError。"""
|
||||
adapter = _make_adapter(corp_id="corp_a")
|
||||
# 用不同 corp_id 的适配器加密,AES 密钥相同
|
||||
other = _make_adapter(corp_id="corp_b")
|
||||
inner_xml = _build_inner_xml({"FromUserName": "u1", "MsgId": "m1", "Content": "hi"})
|
||||
body = _build_body(other, inner_xml)
|
||||
|
||||
with pytest.raises(WeComError):
|
||||
await adapter.receive_message({}, body)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# send_message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendMessage:
|
||||
"""send_message 行为。"""
|
||||
|
||||
async def test_send_message_success(self):
|
||||
"""HTTP 200 + errcode=0 返回 True。"""
|
||||
adapter = _make_adapter()
|
||||
mock_token = MagicMock()
|
||||
mock_token.status_code = 200
|
||||
mock_token.json.return_value = {"errcode": 0, "access_token": "tok_123"}
|
||||
|
||||
mock_send = MagicMock()
|
||||
mock_send.status_code = 200
|
||||
mock_send.json.return_value = {"errcode": 0, "errmsg": "ok"}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_token)
|
||||
mock_client.post = AsyncMock(return_value=mock_send)
|
||||
adapter._client = mock_client
|
||||
|
||||
out = OutgoingMessage(channel=ChannelType.WECOM, chat_id="c1", content="hi")
|
||||
assert await adapter.send_message(out) is True
|
||||
|
||||
send_call = mock_client.post.call_args
|
||||
assert "message/send" in send_call.args[0]
|
||||
assert send_call.kwargs["params"]["access_token"] == "tok_123"
|
||||
assert send_call.kwargs["json"]["agentid"] == 1000001
|
||||
|
||||
async def test_send_message_business_failure(self):
|
||||
"""HTTP 200 但 errcode != 0 返回 False。"""
|
||||
adapter = _make_adapter()
|
||||
mock_token = MagicMock()
|
||||
mock_token.status_code = 200
|
||||
mock_token.json.return_value = {"errcode": 0, "access_token": "tok_x"}
|
||||
|
||||
mock_send = MagicMock()
|
||||
mock_send.status_code = 200
|
||||
mock_send.json.return_value = {"errcode": 40014, "errmsg": "invalid token"}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_token)
|
||||
mock_client.post = AsyncMock(return_value=mock_send)
|
||||
adapter._client = mock_client
|
||||
|
||||
out = OutgoingMessage(channel=ChannelType.WECOM, chat_id="c1", content="hi")
|
||||
assert await adapter.send_message(out) is False
|
||||
|
||||
async def test_access_token_caching(self):
|
||||
"""同 TTL 内两次 send_message 只拉取一次 access_token。"""
|
||||
adapter = _make_adapter()
|
||||
mock_token = MagicMock()
|
||||
mock_token.status_code = 200
|
||||
mock_token.json.return_value = {"errcode": 0, "access_token": "cached_tok"}
|
||||
|
||||
mock_send = MagicMock()
|
||||
mock_send.status_code = 200
|
||||
mock_send.json.return_value = {"errcode": 0}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_token)
|
||||
mock_client.post = AsyncMock(return_value=mock_send)
|
||||
adapter._client = mock_client
|
||||
|
||||
out = OutgoingMessage(channel=ChannelType.WECOM, chat_id="c1", content="hi")
|
||||
await adapter.send_message(out)
|
||||
await adapter.send_message(out)
|
||||
|
||||
# 仅 1 次 token(GET)+ 2 次 send(POST)
|
||||
assert mock_client.get.call_count == 1
|
||||
assert mock_client.post.call_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 资源释放
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClose:
|
||||
"""资源释放。"""
|
||||
|
||||
async def test_close_no_client_is_noop(self):
|
||||
"""未创建 httpx 客户端时 close 不抛异常。"""
|
||||
adapter = _make_adapter()
|
||||
await adapter.close()
|
||||
|
||||
async def test_close_resets_client(self):
|
||||
"""close 后客户端引用清空。"""
|
||||
adapter = _make_adapter()
|
||||
adapter._get_client()
|
||||
assert adapter._client is not None
|
||||
await adapter.close()
|
||||
assert adapter._client is None
|
||||
Loading…
Reference in New Issue