"""U4 — unit tests for the Channels webhook Redis state migration.

Scenarios covered:
- nonce dedup: first use → True, duplicate → False, usable again after TTL expiry
- rate limit: 100 requests inside the window pass, the 101st → False, reset once the window rolls over
- backpressure: concurrency < 2x the limit → True, >= 2x the limit → False, recovers after release
- degraded mode: redis=None → the in-memory implementation still works

Does not depend on a real Redis — _FakeRedis emulates a subset of redis.asyncio.Redis.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import fnmatch
|
||
import time
|
||
|
||
import pytest
|
||
|
||
from agentkit.server.routes.channels import (
|
||
_BACKPRESSURE_KEY,
|
||
_RATE_LIMIT_MAX,
|
||
_WEBHOOK_MAX_CONCURRENT,
|
||
_acquire_backpressure_slot,
|
||
_check_nonce_dedup,
|
||
_check_nonce_dedup_redis,
|
||
_check_rate_limit,
|
||
_check_rate_limit_redis,
|
||
_release_backpressure_slot,
|
||
_reset_webhook_state,
|
||
)
|
||
|
||
|
||
# ── _FakeRedis — 支持 SecretsStore + webhook 用到的操作 ───
|
||
|
||
|
||
class _FakeRedis:
    """Minimal in-memory stand-in for the subset of ``redis.asyncio.Redis`` used here.

    Supports string, sorted-set and counter commands plus ``expire``,
    ``pipeline`` and ``scan_iter``.

    Internal state:
    - _strings: dict[str, str] -- string values (set/get/delete)
    - _zsets: dict[str, dict[str, float]] -- sorted sets stored as member -> score
    - _incr_values: dict[str, int] -- counters touched by incr/decr
    - _expires: dict[str, float] -- key -> deadline on the ``time.monotonic()`` clock
    """

    def __init__(self):
        self._strings: dict[str, str] = {}
        self._zsets: dict[str, dict[str, float]] = {}
        self._expires: dict[str, float] = {}
        self._incr_values: dict[str, int] = {}

    def _is_expired(self, key: str) -> bool:
        """Return True when ``key`` has a deadline and that deadline has passed."""
        exp = self._expires.get(key)
        return exp is not None and time.monotonic() >= exp

    def _exists(self, key: str) -> bool:
        """Return True when ``key`` is present in any of the value stores."""
        return key in self._strings or key in self._zsets or key in self._incr_values

    def _drop(self, key: str) -> None:
        """Remove ``key`` from every store together with its deadline."""
        self._strings.pop(key, None)
        self._zsets.pop(key, None)
        self._incr_values.pop(key, None)
        self._expires.pop(key, None)

    def _cleanup_if_expired(self, key: str) -> None:
        """Lazily evict ``key`` if its deadline has passed."""
        if self._is_expired(key):
            self._drop(key)

    # ── string commands ───────────────────────────────

    async def set(self, key: str, value: str, *, ex: int | None = None, nx: bool = False) -> bool:
        """Store ``value`` under ``key``.

        Args:
            key: key to write.
            value: string value to store.
            ex: optional time-to-live in seconds.
            nx: when True, only write if no string value is stored under ``key``.

        Returns:
            True when the value was written, False when ``nx`` blocked the write.
        """
        self._cleanup_if_expired(key)
        if nx and key in self._strings:
            return False
        self._strings[key] = value
        if ex is None:
            # A plain SET discards any TTL left over from an earlier write.
            self._expires.pop(key, None)
        else:
            self._expires[key] = time.monotonic() + ex
        return True

    async def get(self, key: str) -> str | None:
        """Return the string stored under ``key``, or None if absent or expired."""
        self._cleanup_if_expired(key)
        return self._strings.get(key)

    async def delete(self, key: str) -> int:
        """Delete ``key`` of any type; return 1 if it existed, otherwise 0."""
        # An expired key no longer exists, so evict it before counting.
        self._cleanup_if_expired(key)
        existed = self._exists(key)
        # Counters live in their own dict and must be removed as well.
        self._drop(key)
        return 1 if existed else 0

    # ── sorted-set commands ───────────────────────────

    async def zadd(self, key: str, mapping: dict[str, float]) -> int:
        """Insert or update members; return how many members were newly added."""
        self._cleanup_if_expired(key)
        zset = self._zsets.setdefault(key, {})
        added = sum(1 for member in mapping if member not in zset)
        zset.update(mapping)
        return added

    async def zcard(self, key: str) -> int:
        """Return the number of members in the sorted set at ``key``."""
        self._cleanup_if_expired(key)
        return len(self._zsets.get(key, {}))

    async def zremrangebyscore(self, key: str, min_score: float | str, max_score: float | str) -> int:
        """Remove members whose score lies in the inclusive range; return the count.

        Bounds may be numbers or the strings ``"-inf"`` / ``"+inf"``.
        Exclusive ``"("`` bounds are not supported.
        """
        self._cleanup_if_expired(key)
        zset = self._zsets.get(key, {})
        # float() maps "-inf"/"+inf" onto real infinities and leaves numbers alone.
        low, high = float(min_score), float(max_score)
        doomed = [member for member, score in zset.items() if low <= score <= high]
        for member in doomed:
            del zset[member]
        return len(doomed)

    # ── incr/decr ─────────────────────────────────────

    async def incr(self, key: str) -> int:
        """Increment the counter at ``key`` (starting from 0) and return the new value."""
        self._cleanup_if_expired(key)
        self._incr_values[key] = self._incr_values.get(key, 0) + 1
        return self._incr_values[key]

    async def decr(self, key: str) -> int:
        """Decrement the counter at ``key`` (starting from 0) and return the new value."""
        self._cleanup_if_expired(key)
        self._incr_values[key] = self._incr_values.get(key, 0) - 1
        return self._incr_values[key]

    # ── expire ────────────────────────────────────────

    async def expire(self, key: str, seconds: int) -> bool:
        """Give ``key`` a TTL of ``seconds``; return False when the key does not exist."""
        # Evict first so an already-expired key is not brought back to life.
        self._cleanup_if_expired(key)
        if not self._exists(key):
            return False
        self._expires[key] = time.monotonic() + seconds
        return True

    # ── pipeline ──────────────────────────────────────

    def pipeline(self, transaction: bool = True, shard_hint: str | None = None) -> "_FakePipeline":
        """Return a command-buffering pipeline bound to this instance.

        ``transaction`` and ``shard_hint`` are accepted for signature
        compatibility with redis-py and are otherwise ignored.
        """
        return _FakePipeline(self)

    # ── scan_iter ─────────────────────────────────────

    async def scan_iter(self, match: str = "*"):
        """Yield live string keys whose name matches the glob ``match``.

        Only keys in the string store are scanned; sorted sets and counters
        are not reported.
        """
        # Iterate over a snapshot because cleanup may delete keys on the way.
        for k in list(self._strings):
            self._cleanup_if_expired(k)
            if k in self._strings and fnmatch.fnmatch(k, match):
                yield k
|
||
|
||
|
||
class _FakePipeline:
    """Fake redis pipeline: queues commands and replays them on ``execute()``.

    Each queuing method returns the pipeline itself so calls can be chained.
    The queue is emptied by ``execute()``, so the same pipeline object can be
    reused without replaying earlier commands.
    """

    def __init__(self, redis: _FakeRedis):
        self._redis = redis
        # Pending commands as (method name, positional args, keyword args).
        self._commands: list[tuple[str, tuple, dict]] = []

    async def __aenter__(self) -> "_FakePipeline":
        """Allow ``async with redis.pipeline() as pipe:`` usage."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Discard anything still queued when the ``async with`` block exits."""
        self._commands.clear()

    def _queue(self, name: str, *args) -> "_FakePipeline":
        """Record one command for later execution and return ``self``."""
        self._commands.append((name, args, {}))
        return self

    def zremrangebyscore(self, key, min_score, max_score):
        """Queue a ``zremrangebyscore`` call."""
        return self._queue("zremrangebyscore", key, min_score, max_score)

    def zadd(self, key, mapping):
        """Queue a ``zadd`` call."""
        return self._queue("zadd", key, mapping)

    def zcard(self, key):
        """Queue a ``zcard`` call."""
        return self._queue("zcard", key)

    def expire(self, key, seconds):
        """Queue an ``expire`` call."""
        return self._queue("expire", key, seconds)

    def incr(self, key):
        """Queue an ``incr`` call."""
        return self._queue("incr", key)

    async def execute(self) -> list:
        """Run the queued commands in order and return their results as a list."""
        # Detach the queue up front so a second execute() starts empty.
        queued, self._commands = self._commands, []
        results = []
        for cmd, args, kwargs in queued:
            method = getattr(self._redis, cmd)
            results.append(await method(*args, **kwargs))
        return results
|
||
|
||
|
||
# ── Fixtures ─────────────────────────────────────────────
|
||
|
||
|
||
@pytest.fixture(autouse=True)
def _reset_state():
    """Reset the webhook state around every test.

    Calls ``_reset_webhook_state()`` once before the test body runs and once
    more after it has finished.
    """
    # Setup: start from a clean slate.
    _reset_webhook_state()
    yield
    # Teardown: clear whatever the test left behind.
    _reset_webhook_state()
|
||
|
||
|
||
# ── nonce dedup (Redis) ──────────────────────────────────
|
||
|
||
|
||
class TestNonceDedupRedis:
    """Nonce de-duplication with the Redis backend."""

    async def test_first_nonce_returns_true(self):
        """First use of a nonce -> True (new nonce, should be processed)."""
        redis = _FakeRedis()
        assert await _check_nonce_dedup_redis(redis, "nonce-001") is True

    async def test_duplicate_nonce_returns_false(self):
        """Repeated nonce -> False (processing is skipped)."""
        redis = _FakeRedis()
        assert await _check_nonce_dedup_redis(redis, "nonce-001") is True
        assert await _check_nonce_dedup_redis(redis, "nonce-001") is False

    async def test_different_nonces_both_return_true(self):
        """Distinct nonces each return True on their first use."""
        redis = _FakeRedis()
        assert await _check_nonce_dedup_redis(redis, "nonce-a") is True
        assert await _check_nonce_dedup_redis(redis, "nonce-b") is True

    async def test_nonce_ttl_expiry_allows_reuse(self):
        """Once the TTL has elapsed the same nonce is accepted again."""
        redis = _FakeRedis()
        # First write.
        assert await _check_nonce_dedup_redis(redis, "nonce-exp") is True
        # The marker must carry a TTL; without one it could never expire and
        # wiping the store by hand would hide that.
        assert redis._expires, "nonce key was stored without a TTL"
        # Simulate the TTL elapsing by backdating every deadline, so the fake's
        # own expiry handling is what frees the key.
        lapsed = time.monotonic() - 1
        redis._expires = dict.fromkeys(redis._expires, lapsed)
        # After expiry the same nonce can be used again.
        assert await _check_nonce_dedup_redis(redis, "nonce-exp") is True
|
||
|
||
|
||
# ── nonce dedup (内存降级) ───────────────────────────────
|
||
|
||
|
||
class TestNonceDedupFallback:
    """In-memory nonce de-duplication keeps working when redis=None."""

    def test_fallback_first_nonce_returns_true(self):
        """In-memory mode: first use of a nonce -> True."""
        accepted = _check_nonce_dedup("nonce-fb-1")
        assert accepted is True

    def test_fallback_duplicate_returns_false(self):
        """In-memory mode: repeated nonce -> False."""
        first_seen = _check_nonce_dedup("nonce-fb-2")
        assert first_seen is True
        seen_again = _check_nonce_dedup("nonce-fb-2")
        assert seen_again is False
|
||
|
||
|
||
# ── rate limit (Redis) ───────────────────────────────────
|
||
|
||
|
||
class TestRateLimitRedis:
    """Sliding-window rate limiting with the Redis backend."""

    async def test_under_limit_returns_true(self):
        """Requests within the window limit -> True."""
        fake = _FakeRedis()
        for _attempt in range(_RATE_LIMIT_MAX):
            allowed = await _check_rate_limit_redis(fake, "1.2.3.4")
            assert allowed is True

    async def test_over_limit_returns_false(self):
        """Exceeding the window limit -> False."""
        fake = _FakeRedis()
        for _attempt in range(_RATE_LIMIT_MAX):
            await _check_rate_limit_redis(fake, "1.2.3.4")
        # Request number _RATE_LIMIT_MAX + 1 must be rejected.
        overflow = await _check_rate_limit_redis(fake, "1.2.3.4")
        assert overflow is False

    async def test_different_ips_independent(self):
        """Each IP is rate limited independently."""
        fake = _FakeRedis()
        for _attempt in range(_RATE_LIMIT_MAX):
            await _check_rate_limit_redis(fake, "1.1.1.1")
        # 1.1.1.1 has used up its quota, but 2.2.2.2 still gets through.
        other_ip = await _check_rate_limit_redis(fake, "2.2.2.2")
        assert other_ip is True

    async def test_window_reset_after_expiry(self):
        """The count starts over once the window has expired."""
        fake = _FakeRedis()
        for _attempt in range(_RATE_LIMIT_MAX):
            await _check_rate_limit_redis(fake, "3.3.3.3")
        blocked = await _check_rate_limit_redis(fake, "3.3.3.3")
        assert blocked is False
        # Simulate the window expiring by dropping the sorted-set data.
        fake._zsets.clear()
        fake._expires.clear()
        # After expiry the same IP passes again.
        after_reset = await _check_rate_limit_redis(fake, "3.3.3.3")
        assert after_reset is True
|
||
|
||
|
||
# ── rate limit (内存降级) ────────────────────────────────
|
||
|
||
|
||
class TestRateLimitFallback:
    """In-memory rate limiting keeps working when redis=None."""

    def test_fallback_under_limit(self):
        """In-memory mode: within the limit -> True."""
        for _attempt in range(_RATE_LIMIT_MAX):
            allowed = _check_rate_limit("4.4.4.4")
            assert allowed is True

    def test_fallback_over_limit(self):
        """In-memory mode: over the limit -> False."""
        # Use up the whole quota first.
        for _attempt in range(_RATE_LIMIT_MAX):
            _check_rate_limit("5.5.5.5")
        overflow = _check_rate_limit("5.5.5.5")
        assert overflow is False
|
||
|
||
|
||
# ── backpressure (Redis) ─────────────────────────────────
|
||
|
||
|
||
class TestBackpressureRedis:
    """Shared concurrency counter with the Redis backend."""

    async def test_under_limit_returns_true(self):
        """Concurrency < 2x the limit -> True."""
        fake = _FakeRedis()
        capacity = _WEBHOOK_MAX_CONCURRENT * 2
        for _slot in range(capacity):
            granted = await _acquire_backpressure_slot(fake)
            assert granted is True

    async def test_over_limit_returns_false(self):
        """Concurrency >= 2x the limit -> False."""
        fake = _FakeRedis()
        capacity = _WEBHOOK_MAX_CONCURRENT * 2
        for _slot in range(capacity):
            await _acquire_backpressure_slot(fake)
        # Anything beyond 2x the limit must be refused.
        extra = await _acquire_backpressure_slot(fake)
        assert extra is False

    async def test_release_restores_slot(self):
        """Releasing a slot makes it available for acquisition again."""
        fake = _FakeRedis()
        capacity = _WEBHOOK_MAX_CONCURRENT * 2
        # Fill up to the limit.
        for _slot in range(capacity):
            await _acquire_backpressure_slot(fake)
        # Over the limit.
        refused = await _acquire_backpressure_slot(fake)
        assert refused is False
        # Release one slot.
        await _release_backpressure_slot(fake)
        # Acquisition succeeds again.
        regained = await _acquire_backpressure_slot(fake)
        assert regained is True

    async def test_release_decrements_counter(self):
        """The counter goes down after a release."""
        fake = _FakeRedis()
        await _acquire_backpressure_slot(fake)
        after_acquire = fake._incr_values[_BACKPRESSURE_KEY]
        assert after_acquire == 1
        await _release_backpressure_slot(fake)
        after_release = fake._incr_values[_BACKPRESSURE_KEY]
        assert after_release == 0
|