fischer-agentkit/tests/unit/channels/test_wecom.py

338 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

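"""Unit tests for the WeCom (Enterprise WeChat) IM adapter (U12).

Covered scenarios:
- EncodingAESKey decoding (43 characters -> 32 bytes)
- Signature verification (valid / invalid / expired timestamp / missing params)
- AES-256-CBC encrypt/decrypt roundtrip
- URL verification flow (XML response)
- Message parsing (FromUserName / MsgId / Content)
- send_message success / failure
- access_token caching
- app_id mismatch -> raises an exception
- Composite chat_id key
- Stripping of the leading ":" prefix in group-chat messages
- Resource cleanup via close()
"""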
from __future__ import annotations
import base64
import hashlib
import time
from unittest.mock import AsyncMock, MagicMock
from xml.etree import ElementTree as ET
import pytest
from agentkit.channels.base import ChannelType, IncomingMessage, OutgoingMessage
from agentkit.channels.wecom import WeComError, WeComMessageAdapter, WeComURLVerification
# ---------------------------------------------------------------------------
# 辅助函数
# ---------------------------------------------------------------------------
def _make_aes_key() -> str:
    """Return a valid 43-character EncodingAESKey that decodes to 32 bytes of 0x11."""
    raw_key = bytes([0x11]) * 32
    encoded = base64.b64encode(raw_key).decode()
    # 32 bytes encode to 44 base64 characters ending in a single "="; the
    # EncodingAESKey form is that string with the padding removed.
    return encoded.rstrip("=")
def _make_adapter(
    *,
    corp_id: str = "corp_test",
    agent_id: int = 1000001,
    corp_secret: str = "secret_test",
    token: str = "token_test",
    encoding_aes_key: str | None = None,
) -> WeComMessageAdapter:
    """Build a WeComMessageAdapter populated with test defaults.

    Args:
        corp_id: Corp ID passed to the adapter.
        agent_id: Agent ID passed to the adapter.
        corp_secret: Corp secret passed to the adapter.
        token: Callback token passed to the adapter (used for signatures).
        encoding_aes_key: EncodingAESKey to use. ``None`` (the default) means a
            valid key from ``_make_aes_key()``. Any other value, including an
            empty string, is passed through unchanged so tests can exercise
            invalid-key handling.

    Returns:
        A freshly constructed adapter.
    """
    # Explicit None check: the previous `or` fallback silently replaced an
    # intentionally empty/falsy key with a valid one.
    if encoding_aes_key is None:
        encoding_aes_key = _make_aes_key()
    return WeComMessageAdapter(
        corp_id=corp_id,
        agent_id=agent_id,
        corp_secret=corp_secret,
        token=token,
        encoding_aes_key=encoding_aes_key,
    )
def _build_inner_xml(fields: dict[str, str]) -> str:
    """Build the inner (plaintext) message XML from a tag -> value mapping.

    Each entry becomes ``<tag>value</tag>`` inside a single ``<xml>`` root, in
    the mapping's iteration order. Values are XML-escaped (``&``, ``<``, ``>``)
    so message content containing markup characters still yields well-formed
    XML. Tag names are used as-is.

    Args:
        fields: Mapping of element name to text value.

    Returns:
        The XML document as a string (``"<xml></xml>"`` for an empty mapping).
    """
    from xml.sax.saxutils import escape

    elements = "".join(
        f"<{tag}>{escape(str(value))}</{tag}>" for tag, value in fields.items()
    )
    return f"<xml>{elements}</xml>"
def _build_body(adapter: WeComMessageAdapter, inner_xml: str) -> bytes:
    """Encrypt *inner_xml* with *adapter* and wrap it in the outer webhook envelope.

    The envelope carries ``ToUserName`` (the adapter's ``corp_id``) and the
    ciphertext in ``Encrypt``; the result is UTF-8 encoded bytes.
    """
    ciphertext = adapter._encrypt(inner_xml)
    envelope = "".join(
        (
            "<xml>",
            f"<ToUserName>{adapter.corp_id}</ToUserName>",
            f"<Encrypt>{ciphertext}</Encrypt>",
            "</xml>",
        )
    )
    return envelope.encode("utf-8")
def _signature(token: str, timestamp: str, nonce: str, encrypt: str) -> str:
    """Compute a WeCom msg_signature.

    The four inputs are sorted lexicographically, concatenated, UTF-8 encoded
    and hashed with SHA-1; the lowercase hex digest is returned.
    """
    digest = hashlib.sha1()
    # Feeding the sorted pieces one by one is equivalent to hashing their
    # concatenation, since UTF-8 encoding distributes over concatenation.
    for piece in sorted((token, timestamp, nonce, encrypt)):
        digest.update(piece.encode("utf-8"))
    return digest.hexdigest()
# ---------------------------------------------------------------------------
# AES 密钥 / 加解密
# ---------------------------------------------------------------------------
class TestAESKey:
    """EncodingAESKey decoding and AES encrypt/decrypt."""

    def test_decode_aes_key_to_32_bytes(self):
        """A 43-character EncodingAESKey decodes to the exact 32 raw key bytes."""
        adapter = _make_adapter(encoding_aes_key=_make_aes_key())
        key = adapter._decode_aes_key()
        assert len(key) == 32
        # _make_aes_key() encodes 32 bytes of 0x11; checking only the length
        # would let a decoder that returns the wrong bytes slip through.
        assert key == b"\x11" * 32

    def test_encrypt_decrypt_roundtrip(self):
        """Decrypting the ciphertext recovers the original (non-ASCII) text."""
        adapter = _make_adapter()
        plaintext = "hello 企微"
        encrypted = adapter._encrypt(plaintext)
        # Guard against a pass-through _encrypt, which would still roundtrip.
        assert encrypted != plaintext
        assert adapter._decrypt(encrypted) == plaintext
class TestSignatureVerification:
    """msg_signature verification via verify_signature()."""

    async def test_valid_signature(self):
        """A correct msg_signature is accepted (returns True)."""
        adapter = _make_adapter()
        payload = _build_body(adapter, _build_inner_xml({"MsgId": "m1", "Content": "hi"}))
        cipher = adapter._extract_encrypt(payload) or ""
        timestamp = str(int(time.time()))
        nonce = "n1"
        query = {
            "msg_signature": _signature(adapter.token, timestamp, nonce, cipher),
            "timestamp": timestamp,
            "nonce": nonce,
        }
        assert await adapter.verify_signature(query, payload) is True

    async def test_invalid_signature(self):
        """A tampered msg_signature is rejected (returns False)."""
        adapter = _make_adapter()
        payload = _build_body(adapter, _build_inner_xml({"MsgId": "m1", "Content": "hi"}))
        query = {
            "msg_signature": "tampered",
            "timestamp": str(int(time.time())),
            "nonce": "n1",
        }
        assert await adapter.verify_signature(query, payload) is False

    async def test_expired_timestamp_rejected(self):
        """A timestamp outside the 5-minute window is rejected (replay protection)."""
        adapter = _make_adapter()
        payload = _build_body(adapter, _build_inner_xml({"MsgId": "m1", "Content": "hi"}))
        cipher = adapter._extract_encrypt(payload) or ""
        # 600 s (10 minutes) in the past -- beyond the 300 s window.
        stale_ts = str(int(time.time()) - 600)
        nonce = "n1"
        query = {
            "msg_signature": _signature(adapter.token, stale_ts, nonce, cipher),
            "timestamp": stale_ts,
            "nonce": nonce,
        }
        assert await adapter.verify_signature(query, payload) is False

    async def test_missing_query_params(self):
        """Missing msg_signature/timestamp/nonce is rejected (returns False)."""
        adapter = _make_adapter()
        payload = b"<xml><Encrypt>x</Encrypt></xml>"
        assert await adapter.verify_signature({}, payload) is False
# ---------------------------------------------------------------------------
# URL 验证流程
# ---------------------------------------------------------------------------
class TestURLVerification:
    """WeCom callback-URL verification flow."""

    async def test_url_verification_raises_with_xml(self):
        """An encrypted event carrying EchoStr raises WeComURLVerification.

        The exception's ``response_xml`` must contain an ``Encrypt`` element
        whose text decrypts back to the original echo value.
        """
        adapter = _make_adapter()
        body = _build_body(adapter, _build_inner_xml({"EchoStr": "echo123"}))
        headers = {"timestamp": "1609459200", "nonce": "n1"}
        with pytest.raises(WeComURLVerification) as exc_info:
            await adapter.receive_message(headers, body)
        root = ET.fromstring(exc_info.value.response_xml)
        encrypt_node = root.find("Encrypt")
        # Fail with a clear message instead of an AttributeError on None.
        assert encrypt_node is not None, "response XML has no <Encrypt> element"
        assert encrypt_node.text is not None, "<Encrypt> element is empty"
        assert adapter._decrypt(encrypt_node.text) == "echo123"

    async def test_url_verification_missing_encrypt_raises(self):
        """A body without an Encrypt field raises WeComError."""
        adapter = _make_adapter()
        body = b"<xml><ToUserName>x</ToUserName></xml>"
        with pytest.raises(WeComError):
            await adapter.receive_message({}, body)
# ---------------------------------------------------------------------------
# 消息解析
# ---------------------------------------------------------------------------
class TestMessageParsing:
    """Parsing of inbound (decrypted) messages."""

    async def test_receive_message_fields(self):
        """A plain message exposes FromUserName, MsgId, Content and CreateTime."""
        adapter = _make_adapter()
        payload = _build_body(
            adapter,
            _build_inner_xml(
                {
                    "FromUserName": "user1",
                    "MsgId": "msg_001",
                    "Content": "hello wecom",
                    "CreateTime": "1609459200",
                    "MsgType": "text",
                }
            ),
        )
        incoming = await adapter.receive_message({}, payload)
        assert isinstance(incoming, IncomingMessage)
        assert incoming.channel == ChannelType.WECOM
        assert incoming.user_id == "user1"
        assert incoming.platform_message_id == "msg_001"
        assert incoming.content == "hello wecom"
        assert incoming.timestamp == "1609459200"

    async def test_chat_id_composition(self):
        """chat_id is the composite "<corp_id>:<FromUserName>"."""
        adapter = _make_adapter(corp_id="my_corp")
        fields = {"FromUserName": "userA", "MsgId": "m1", "Content": "hi"}
        payload = _build_body(adapter, _build_inner_xml(fields))
        incoming = await adapter.receive_message({}, payload)
        assert incoming.chat_id == "my_corp:userA"

    async def test_chat_room_mention_stripping(self):
        """A leading ":" in group-chat message content is stripped."""
        adapter = _make_adapter()
        fields = {"FromUserName": "u1", "MsgId": "m1", "Content": ":hello room"}
        payload = _build_body(adapter, _build_inner_xml(fields))
        incoming = await adapter.receive_message({}, payload)
        assert incoming.content == "hello room"
class TestAppIDMismatch:
    """app_id (corp_id) check on decrypted payloads."""

    async def test_app_id_mismatch_raises(self):
        """A payload whose decrypted app_id differs from corp_id raises WeComError."""
        receiver = _make_adapter(corp_id="corp_a")
        # Encrypt with an adapter for a different corp_id. Both use the same
        # default AES key from _make_aes_key(), so the failure should come from
        # the app_id check rather than from decryption.
        sender = _make_adapter(corp_id="corp_b")
        payload = _build_body(
            sender,
            _build_inner_xml({"FromUserName": "u1", "MsgId": "m1", "Content": "hi"}),
        )
        with pytest.raises(WeComError):
            await receiver.receive_message({}, payload)
# ---------------------------------------------------------------------------
# send_message
# ---------------------------------------------------------------------------
class TestSendMessage:
    """send_message behaviour against a mocked HTTP client."""

    @staticmethod
    def _install_mock_client(
        adapter: WeComMessageAdapter,
        token_payload: dict[str, object],
        send_payload: dict[str, object],
    ) -> AsyncMock:
        """Attach a mock HTTP client to *adapter* and return it.

        ``get`` (the access_token fetch) resolves to an HTTP 200 response whose
        ``json()`` returns *token_payload*; ``post`` (message/send) resolves to
        an HTTP 200 response whose ``json()`` returns *send_payload*. Each call
        builds fresh mocks, so tests never share call counters.
        """
        token_resp = MagicMock()
        token_resp.status_code = 200
        token_resp.json.return_value = token_payload
        send_resp = MagicMock()
        send_resp.status_code = 200
        send_resp.json.return_value = send_payload
        client = AsyncMock()
        client.get = AsyncMock(return_value=token_resp)
        client.post = AsyncMock(return_value=send_resp)
        adapter._client = client
        return client

    async def test_send_message_success(self):
        """HTTP 200 with errcode=0 returns True and posts to message/send."""
        adapter = _make_adapter()
        client = self._install_mock_client(
            adapter,
            {"errcode": 0, "access_token": "tok_123"},
            {"errcode": 0, "errmsg": "ok"},
        )
        out = OutgoingMessage(channel=ChannelType.WECOM, chat_id="c1", content="hi")
        assert await adapter.send_message(out) is True
        send_call = client.post.call_args
        assert "message/send" in send_call.args[0]
        assert send_call.kwargs["params"]["access_token"] == "tok_123"
        assert send_call.kwargs["json"]["agentid"] == 1000001

    async def test_send_message_business_failure(self):
        """HTTP 200 but errcode != 0 returns False."""
        adapter = _make_adapter()
        self._install_mock_client(
            adapter,
            {"errcode": 0, "access_token": "tok_x"},
            {"errcode": 40014, "errmsg": "invalid token"},
        )
        out = OutgoingMessage(channel=ChannelType.WECOM, chat_id="c1", content="hi")
        assert await adapter.send_message(out) is False

    async def test_access_token_caching(self):
        """Two send_message calls within the TTL fetch the access_token only once."""
        adapter = _make_adapter()
        client = self._install_mock_client(
            adapter,
            {"errcode": 0, "access_token": "cached_tok"},
            {"errcode": 0},
        )
        out = OutgoingMessage(channel=ChannelType.WECOM, chat_id="c1", content="hi")
        await adapter.send_message(out)
        await adapter.send_message(out)
        # Exactly one token fetch (GET) and two sends (POST).
        assert client.get.call_count == 1
        assert client.post.call_count == 2
# ---------------------------------------------------------------------------
# 资源释放
# ---------------------------------------------------------------------------
class TestClose:
    """Resource cleanup via close()."""

    async def test_close_no_client_is_noop(self):
        """close() does not raise when no httpx client has been created yet."""
        await _make_adapter().close()

    async def test_close_resets_client(self):
        """After close() the adapter's client reference is cleared."""
        adapter = _make_adapter()
        adapter._get_client()
        assert adapter._client is not None
        await adapter.close()
        assert adapter._client is None