338 lines
12 KiB
Python
338 lines
12 KiB
Python
"""企业微信 IM 适配器单元测试 (U12)。
|
||
|
||
覆盖场景:
|
||
- EncodingAESKey 解码(43 字符 → 32 字节)
|
||
- 签名校验(有效/无效)
|
||
- AES-256-CBC 加解密往返
|
||
- URL 验证流程(XML 响应)
|
||
- 消息解析(FromUserName/MsgId/Content)
|
||
- send_message 成功/失败
|
||
- access_token 缓存
|
||
- app_id 不匹配 → 抛异常
|
||
- chat_id 复合键
|
||
- 群聊消息 ":" 前缀剥离
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import base64
|
||
import hashlib
|
||
import time
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
from xml.etree import ElementTree as ET
|
||
|
||
import pytest
|
||
|
||
from agentkit.channels.base import ChannelType, IncomingMessage, OutgoingMessage
|
||
from agentkit.channels.wecom import WeComError, WeComMessageAdapter, WeComURLVerification
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _make_aes_key() -> str:
    """Return a valid 43-character EncodingAESKey.

    The key is the base64 encoding of 32 bytes of 0x11 with the trailing
    "=" padding stripped, so it decodes back to a 32-byte AES-256 key.
    """
    raw_key = bytes([0x11]) * 32
    padded = base64.b64encode(raw_key)
    return padded.rstrip(b"=").decode()
|
||
|
||
|
||
def _make_adapter(
    *,
    corp_id: str = "corp_test",
    agent_id: int = 1000001,
    corp_secret: str = "secret_test",
    token: str = "token_test",
    encoding_aes_key: str | None = None,
) -> WeComMessageAdapter:
    """Build a WeCom adapter for tests.

    All arguments are keyword-only and forwarded to ``WeComMessageAdapter``.

    Args:
        corp_id: Corporation id; also used as the outer ``ToUserName``.
        agent_id: Application agent id.
        corp_secret: Secret used to fetch the access token.
        token: Callback token used for msg_signature computation.
        encoding_aes_key: 43-character EncodingAESKey. When ``None`` a valid
            generated key is used. Any other value, including ``""``, is
            passed through unchanged so tests can exercise invalid keys.

    Returns:
        A freshly constructed ``WeComMessageAdapter``.
    """
    # Explicit None check: a falsy-but-provided key (e.g. "") must not be
    # silently replaced by the default.
    if encoding_aes_key is None:
        encoding_aes_key = _make_aes_key()
    return WeComMessageAdapter(
        corp_id=corp_id,
        agent_id=agent_id,
        corp_secret=corp_secret,
        token=token,
        encoding_aes_key=encoding_aes_key,
    )
|
||
|
||
|
||
def _build_inner_xml(fields: dict[str, str]) -> str:
    """Build the plaintext inner XML document from a tag -> text mapping.

    Each entry becomes ``<tag>text</tag>`` inside a ``<xml>`` root, in the
    mapping's insertion order. Text values are XML-escaped (``&``, ``<``,
    ``>``) so content such as ``"a < b"`` still yields well-formed XML.
    Tag names are used verbatim.

    Args:
        fields: Mapping of element tag to element text.

    Returns:
        The serialized XML string (no whitespace between elements).
    """
    from xml.sax.saxutils import escape

    elements = "".join(
        f"<{tag}>{escape(value)}</{tag}>" for tag, value in fields.items()
    )
    return f"<xml>{elements}</xml>"
|
||
|
||
|
||
def _build_body(adapter: WeComMessageAdapter, inner_xml: str) -> bytes:
    """Encrypt ``inner_xml`` with ``adapter`` and wrap it as a webhook body.

    The envelope carries the adapter's ``corp_id`` as ``ToUserName`` and the
    ciphertext as ``Encrypt``, with no whitespace between elements.

    Returns:
        The envelope encoded as UTF-8 bytes.
    """
    ciphertext = adapter._encrypt(inner_xml)
    envelope = "".join(
        (
            "<xml>",
            f"<ToUserName>{adapter.corp_id}</ToUserName>",
            f"<Encrypt>{ciphertext}</Encrypt>",
            "</xml>",
        )
    )
    return envelope.encode("utf-8")
|
||
|
||
|
||
def _signature(token: str, timestamp: str, nonce: str, encrypt: str) -> str:
    """Compute a WeCom ``msg_signature``.

    The four inputs are sorted lexicographically, concatenated, UTF-8
    encoded and hashed with SHA-1.

    Returns:
        The lowercase hex SHA-1 digest.
    """
    joined = "".join(sorted((token, timestamp, nonce, encrypt)))
    digest = hashlib.sha1(joined.encode("utf-8"))
    return digest.hexdigest()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# AES key / encryption and signature verification
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestAESKey:
    """EncodingAESKey decoding and AES encrypt/decrypt round trip."""

    def test_decode_aes_key_to_32_bytes(self):
        """A 43-character EncodingAESKey decodes to exactly the 32 source bytes."""
        adapter = _make_adapter(encoding_aes_key=_make_aes_key())
        key = adapter._decode_aes_key()
        assert len(key) == 32
        # _make_aes_key() base64-encodes 32 bytes of 0x11, so the decoded key
        # must equal those bytes, not merely have the right length.
        assert key == b"\x11" * 32

    def test_encrypt_decrypt_roundtrip(self):
        """Decrypting the ciphertext restores the original plaintext."""
        adapter = _make_adapter()
        plaintext = "hello 企微"
        encrypted = adapter._encrypt(plaintext)
        # Guard against a no-op "encryption" passing the round-trip check.
        assert encrypted != plaintext
        assert adapter._decrypt(encrypted) == plaintext
|
||
|
||
|
||
class TestSignatureVerification:
    """msg_signature verification of incoming webhook requests."""

    @staticmethod
    def _encrypted_body(adapter):
        """Return an encrypted webhook body carrying a minimal text message."""
        payload = _build_inner_xml({"MsgId": "m1", "Content": "hi"})
        return _build_body(adapter, payload)

    async def test_valid_signature(self):
        """A correctly computed msg_signature is accepted (returns True)."""
        adapter = _make_adapter()
        body = self._encrypted_body(adapter)
        ciphertext = adapter._extract_encrypt(body) or ""
        timestamp = str(int(time.time()))
        nonce = "n1"
        headers = {
            "msg_signature": _signature(adapter.token, timestamp, nonce, ciphertext),
            "timestamp": timestamp,
            "nonce": nonce,
        }
        assert await adapter.verify_signature(headers, body) is True

    async def test_invalid_signature(self):
        """A tampered msg_signature is rejected (returns False)."""
        adapter = _make_adapter()
        body = self._encrypted_body(adapter)
        timestamp = str(int(time.time()))
        nonce = "n1"
        headers = {"msg_signature": "tampered", "timestamp": timestamp, "nonce": nonce}
        assert await adapter.verify_signature(headers, body) is False

    async def test_expired_timestamp_rejected(self):
        """A timestamp outside the 5-minute window is rejected (replay guard)."""
        adapter = _make_adapter()
        body = self._encrypted_body(adapter)
        ciphertext = adapter._extract_encrypt(body) or ""
        # 600 s in the past: beyond the 300 s acceptance window.
        stale_ts = str(int(time.time()) - 600)
        nonce = "n1"
        headers = {
            "msg_signature": _signature(adapter.token, stale_ts, nonce, ciphertext),
            "timestamp": stale_ts,
            "nonce": nonce,
        }
        assert await adapter.verify_signature(headers, body) is False

    async def test_missing_query_params(self):
        """Missing msg_signature/timestamp/nonce is rejected (returns False)."""
        adapter = _make_adapter()
        assert await adapter.verify_signature({}, b"<xml><Encrypt>x</Encrypt></xml>") is False
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# URL verification flow
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestURLVerification:
    """WeCom URL-verification (echo) handshake."""

    async def test_url_verification_raises_with_xml(self):
        """An encrypted EchoStr event raises WeComURLVerification.

        The exception's response XML must decrypt back to the echo value.
        """
        adapter = _make_adapter()
        body = _build_body(adapter, _build_inner_xml({"EchoStr": "echo123"}))
        headers = {"timestamp": "1609459200", "nonce": "n1"}

        with pytest.raises(WeComURLVerification) as exc_info:
            await adapter.receive_message(headers, body)

        root = ET.fromstring(exc_info.value.response_xml)
        encrypt_node = root.find("Encrypt")
        # Fail with a clear assertion rather than an AttributeError on None
        # if the response lacks an <Encrypt> element or it is empty.
        assert encrypt_node is not None, "response XML has no <Encrypt> element"
        assert encrypt_node.text is not None, "<Encrypt> element is empty"
        assert adapter._decrypt(encrypt_node.text) == "echo123"

    async def test_url_verification_missing_encrypt_raises(self):
        """A body without an Encrypt field raises WeComError."""
        adapter = _make_adapter()
        body = b"<xml><ToUserName>x</ToUserName></xml>"
        with pytest.raises(WeComError):
            await adapter.receive_message({}, body)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Message parsing and app_id validation
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestMessageParsing:
    """Parsing of decrypted inbound messages into IncomingMessage."""

    async def test_receive_message_fields(self):
        """FromUserName / MsgId / Content / CreateTime map onto IncomingMessage."""
        adapter = _make_adapter()
        payload = {
            "FromUserName": "user1",
            "MsgId": "msg_001",
            "Content": "hello wecom",
            "CreateTime": "1609459200",
            "MsgType": "text",
        }
        encrypted = _build_body(adapter, _build_inner_xml(payload))

        msg = await adapter.receive_message({}, encrypted)

        assert isinstance(msg, IncomingMessage)
        assert msg.channel == ChannelType.WECOM
        assert msg.user_id == "user1"
        assert msg.platform_message_id == "msg_001"
        assert msg.content == "hello wecom"
        assert msg.timestamp == "1609459200"

    async def test_chat_id_composition(self):
        """chat_id is the composite key "<corp_id>:<from_user>"."""
        adapter = _make_adapter(corp_id="my_corp")
        payload = {"FromUserName": "userA", "MsgId": "m1", "Content": "hi"}

        msg = await adapter.receive_message({}, _build_body(adapter, _build_inner_xml(payload)))

        assert msg.chat_id == "my_corp:userA"

    async def test_chat_room_mention_stripping(self):
        """A leading ":" in group-chat content is stripped."""
        adapter = _make_adapter()
        payload = {"FromUserName": "u1", "MsgId": "m1", "Content": ":hello room"}

        msg = await adapter.receive_message({}, _build_body(adapter, _build_inner_xml(payload)))

        assert msg.content == "hello room"
|
||
|
||
|
||
class TestAppIDMismatch:
    """app_id validation after decryption."""

    async def test_app_id_mismatch_raises(self):
        """A decrypted app_id that differs from corp_id raises WeComError."""
        receiver = _make_adapter(corp_id="corp_a")
        # Encrypt with an adapter that has a different corp_id. Both use the
        # same default AES key generated by _make_aes_key().
        sender = _make_adapter(corp_id="corp_b")
        foreign_body = _build_body(
            sender,
            _build_inner_xml({"FromUserName": "u1", "MsgId": "m1", "Content": "hi"}),
        )

        with pytest.raises(WeComError):
            await receiver.receive_message({}, foreign_body)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# send_message
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestSendMessage:
    """send_message behaviour against a mocked HTTP client."""

    @staticmethod
    def _mock_response(payload):
        """Return a mock HTTP 200 response whose ``json()`` returns ``payload``."""
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = payload
        return response

    def _install_mock_client(self, adapter, *, token_payload, send_payload):
        """Attach a mock async HTTP client to ``adapter`` and return it.

        ``get`` (the access-token request) resolves to a 200 response carrying
        ``token_payload``; ``post`` (the send request) resolves to a 200
        response carrying ``send_payload``.
        """
        client = AsyncMock()
        client.get = AsyncMock(return_value=self._mock_response(token_payload))
        client.post = AsyncMock(return_value=self._mock_response(send_payload))
        adapter._client = client
        return client

    @staticmethod
    def _outgoing():
        """Return the minimal outgoing text message used by these tests."""
        return OutgoingMessage(channel=ChannelType.WECOM, chat_id="c1", content="hi")

    async def test_send_message_success(self):
        """HTTP 200 with errcode=0 returns True and posts the expected request."""
        adapter = _make_adapter()
        client = self._install_mock_client(
            adapter,
            token_payload={"errcode": 0, "access_token": "tok_123"},
            send_payload={"errcode": 0, "errmsg": "ok"},
        )

        assert await adapter.send_message(self._outgoing()) is True

        assert client.get.call_count == 1
        send_call = client.post.call_args
        assert "message/send" in send_call.args[0]
        assert send_call.kwargs["params"]["access_token"] == "tok_123"
        # 1000001 is _make_adapter()'s default agent_id.
        assert send_call.kwargs["json"]["agentid"] == 1000001

    async def test_send_message_business_failure(self):
        """HTTP 200 but errcode != 0 returns False."""
        adapter = _make_adapter()
        self._install_mock_client(
            adapter,
            token_payload={"errcode": 0, "access_token": "tok_x"},
            send_payload={"errcode": 40014, "errmsg": "invalid token"},
        )

        assert await adapter.send_message(self._outgoing()) is False

    async def test_access_token_caching(self):
        """Two sends within the token TTL fetch the access_token only once."""
        adapter = _make_adapter()
        client = self._install_mock_client(
            adapter,
            token_payload={"errcode": 0, "access_token": "cached_tok"},
            send_payload={"errcode": 0},
        )

        out = self._outgoing()
        await adapter.send_message(out)
        await adapter.send_message(out)

        # Exactly one token fetch (GET) and two sends (POST).
        assert client.get.call_count == 1
        assert client.post.call_count == 2
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Resource release
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestClose:
    """Release of the adapter's HTTP client."""

    async def test_close_no_client_is_noop(self):
        """close() does not raise when no httpx client has been created."""
        await _make_adapter().close()

    async def test_close_resets_client(self):
        """close() clears the cached client reference."""
        adapter = _make_adapter()
        adapter._get_client()
        assert adapter._client is not None

        await adapter.close()

        assert adapter._client is None
|