# fischer-agentkit/tests/unit/channels/test_webhook_redis_state.py

"""U4 — Channels webhook Redis 状态迁移单元测试。
覆盖场景:
- nonce dedup首次 → True重复 → FalseTTL 过期 → 可再次使用
- rate limit窗口内 100 请求通过,第 101 → False窗口滚动后重置
- backpressure并发 < 2x 上限 → True>= 2x 上限 → False释放后恢复
- 降级模式redis=None → 内存实现仍工作
不依赖真实 Redis — 使用 _FakeRedis 模拟 redis.asyncio.Redis 的子集。
"""
from __future__ import annotations
import fnmatch
import time
import pytest
from agentkit.server.routes.channels import (
_BACKPRESSURE_KEY,
_RATE_LIMIT_MAX,
_WEBHOOK_MAX_CONCURRENT,
_acquire_backpressure_slot,
_check_nonce_dedup,
_check_nonce_dedup_redis,
_check_rate_limit,
_check_rate_limit_redis,
_release_backpressure_slot,
_reset_webhook_state,
)
# ── _FakeRedis — 支持 SecretsStore + webhook 用到的操作 ───
class _FakeRedis:
"""极简 Redis mock支持 string/zset/incr/pipeline/expire/scan_iter。
内部数据结构:
- _strings: dict[str, str] — string 类型set/get/delete
- _zsets: dict[str, dict[str, float]] — sorted setzadd/zcard/zremrangebyscore
- _expires: dict[str, float] — key -> 过期 monotonic 时间戳
"""
def __init__(self):
self._strings: dict[str, str] = {}
self._zsets: dict[str, dict[str, float]] = {}
self._expires: dict[str, float] = {}
self._incr_values: dict[str, int] = {}
def _is_expired(self, key: str) -> bool:
exp = self._expires.get(key)
return exp is not None and time.monotonic() >= exp
def _cleanup_if_expired(self, key: str) -> None:
if self._is_expired(key):
self._strings.pop(key, None)
self._zsets.pop(key, None)
self._incr_values.pop(key, None)
self._expires.pop(key, None)
# ── string 操作 ───────────────────────────────────
async def set(self, key: str, value: str, *, ex: int | None = None, nx: bool = False) -> bool:
self._cleanup_if_expired(key)
if nx and key in self._strings:
return False
self._strings[key] = value
if ex is not None:
self._expires[key] = time.monotonic() + ex
return True
async def get(self, key: str) -> str | None:
self._cleanup_if_expired(key)
return self._strings.get(key)
async def delete(self, key: str) -> int:
existed = key in self._strings or key in self._zsets
self._strings.pop(key, None)
self._zsets.pop(key, None)
self._expires.pop(key, None)
return 1 if existed else 0
# ── zset 操作 ─────────────────────────────────────
async def zadd(self, key: str, mapping: dict[str, float]) -> int:
self._cleanup_if_expired(key)
zset = self._zsets.setdefault(key, {})
added = 0
for member, score in mapping.items():
if member not in zset:
added += 1
zset[member] = score
return added
async def zcard(self, key: str) -> int:
self._cleanup_if_expired(key)
return len(self._zsets.get(key, {}))
async def zremrangebyscore(self, key: str, min_score: float, max_score: float) -> int:
self._cleanup_if_expired(key)
zset = self._zsets.get(key, {})
removed = 0
for member in list(zset):
if min_score <= zset[member] <= max_score:
del zset[member]
removed += 1
return removed
# ── incr/decr ─────────────────────────────────────
async def incr(self, key: str) -> int:
self._cleanup_if_expired(key)
self._incr_values[key] = self._incr_values.get(key, 0) + 1
return self._incr_values[key]
async def decr(self, key: str) -> int:
self._cleanup_if_expired(key)
self._incr_values[key] = self._incr_values.get(key, 0) - 1
return self._incr_values[key]
# ── expire ────────────────────────────────────────
async def expire(self, key: str, seconds: int) -> bool:
if key in self._strings or key in self._zsets or key in self._incr_values:
self._expires[key] = time.monotonic() + seconds
return True
return False
# ── pipeline ──────────────────────────────────────
def pipeline(self) -> "_FakePipeline":
return _FakePipeline(self)
# ── scan_iter ─────────────────────────────────────
async def scan_iter(self, match: str = "*"):
for k in list(self._strings):
self._cleanup_if_expired(k)
if k in self._strings and fnmatch.fnmatch(k, match):
yield k
class _FakePipeline:
    """Command-buffering stand-in for a ``redis.asyncio`` pipeline.

    Each command method only records ``(name, args, kwargs)`` and returns
    ``self`` so calls can be chained. ``execute()`` replays the queued commands
    in order against the backing fake client and returns their results as a
    list. It then empties the queue, as redis-py does, so executing twice never
    replays the same commands.

    The pipeline can also be used as ``async with redis.pipeline() as pipe:``.
    Leaving the block discards any commands that were queued but not executed.
    """

    def __init__(self, redis: "_FakeRedis"):
        self._redis = redis
        self._commands: list[tuple[str, tuple, dict]] = []

    def _queue(self, name: str, *args, **kwargs) -> "_FakePipeline":
        """Record one command for ``execute()`` and return self for chaining."""
        self._commands.append((name, args, kwargs))
        return self

    def reset(self) -> None:
        """Discard every queued command."""
        self._commands = []

    async def __aenter__(self) -> "_FakePipeline":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        self.reset()

    # ── queued commands ───────────────────────────────
    def set(self, key, value, *, ex=None, nx=False):
        return self._queue("set", key, value, ex=ex, nx=nx)

    def get(self, key):
        return self._queue("get", key)

    def delete(self, *keys):
        return self._queue("delete", *keys)

    def zremrangebyscore(self, key, min_score, max_score):
        return self._queue("zremrangebyscore", key, min_score, max_score)

    def zadd(self, key, mapping):
        return self._queue("zadd", key, mapping)

    def zcard(self, key):
        return self._queue("zcard", key)

    def expire(self, key, seconds):
        return self._queue("expire", key, seconds)

    def incr(self, key):
        return self._queue("incr", key)

    def decr(self, key):
        return self._queue("decr", key)

    async def execute(self) -> list:
        """Run the queued commands in order and return their results.

        The queue is detached before anything runs, so it is empty afterwards
        even if a command raises.
        """
        commands, self._commands = self._commands, []
        results = []
        for name, args, kwargs in commands:
            method = getattr(self._redis, name)
            results.append(await method(*args, **kwargs))
        return results
# ── Fixtures ─────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _reset_state():
    """Autouse fixture that isolates every test from in-memory webhook state.

    The state is wiped before the test body runs, so leftovers from an
    earlier test cannot leak in, and wiped again during teardown.
    """
    # Setup phase.
    _reset_webhook_state()
    yield
    # Teardown phase: pytest runs this after the test, whether it passed or failed.
    _reset_webhook_state()
# ── nonce dedup (Redis) ──────────────────────────────────
class TestNonceDedupRedis:
    """Nonce de-duplication backed by Redis."""

    async def test_first_nonce_returns_true(self):
        """A nonce seen for the first time → True (new nonce, should be processed)."""
        redis = _FakeRedis()
        assert await _check_nonce_dedup_redis(redis, "nonce-001") is True

    async def test_duplicate_nonce_returns_false(self):
        """A repeated nonce → False (processing is skipped)."""
        redis = _FakeRedis()
        assert await _check_nonce_dedup_redis(redis, "nonce-001") is True
        assert await _check_nonce_dedup_redis(redis, "nonce-001") is False

    async def test_different_nonces_both_return_true(self):
        """Distinct nonces are each accepted on first sight → True."""
        redis = _FakeRedis()
        assert await _check_nonce_dedup_redis(redis, "nonce-a") is True
        assert await _check_nonce_dedup_redis(redis, "nonce-b") is True

    async def test_nonce_ttl_expiry_allows_reuse(self):
        """Once its TTL has elapsed, the same nonce is accepted again."""
        redis = _FakeRedis()
        # First sighting stores the nonce.
        assert await _check_nonce_dedup_redis(redis, "nonce-exp") is True
        # The dedup record must carry a TTL; without one, nonces would be
        # remembered forever and expiry could never happen.
        assert redis._expires, "nonce dedup key was stored without a TTL"
        # While the TTL is still running, the nonce stays rejected.
        assert await _check_nonce_dedup_redis(redis, "nonce-exp") is False
        # Simulate TTL expiry by moving every deadline into the past, so the
        # fake's own lazy-expiry logic drops the key. Clearing the storage by
        # hand would also pass if no TTL had been set at all.
        expired_at = time.monotonic() - 1
        redis._expires = {key: expired_at for key in redis._expires}
        # After expiry the same nonce can be used again.
        assert await _check_nonce_dedup_redis(redis, "nonce-exp") is True
# ── nonce dedup (in-memory fallback) ─────────────────────
class TestNonceDedupFallback:
    """The in-memory nonce dedup (the fallback when redis=None) still works."""

    def test_fallback_first_nonce_returns_true(self):
        """In-memory mode: a first-seen nonce is accepted → True."""
        first_seen = _check_nonce_dedup("nonce-fb-1")
        assert first_seen is True

    def test_fallback_duplicate_returns_false(self):
        """In-memory mode: a nonce seen twice is rejected the second time → False."""
        nonce = "nonce-fb-2"
        first_seen = _check_nonce_dedup(nonce)
        assert first_seen is True
        seen_again = _check_nonce_dedup(nonce)
        assert seen_again is False
# ── rate limit (Redis) ───────────────────────────────────
class TestRateLimitRedis:
    """Sliding-window rate limiting backed by Redis."""

    async def test_under_limit_returns_true(self):
        """Every request up to the window cap → True."""
        redis = _FakeRedis()
        for _ in range(_RATE_LIMIT_MAX):
            assert await _check_rate_limit_redis(redis, "1.2.3.4") is True

    async def test_over_limit_returns_false(self):
        """The first request beyond the window cap → False."""
        redis = _FakeRedis()
        for _ in range(_RATE_LIMIT_MAX):
            await _check_rate_limit_redis(redis, "1.2.3.4")
        # Request number _RATE_LIMIT_MAX + 1 must be rejected.
        assert await _check_rate_limit_redis(redis, "1.2.3.4") is False

    async def test_different_ips_independent(self):
        """Limits are tracked per IP: exhausting one IP's budget leaves another IP unaffected."""
        redis = _FakeRedis()
        for _ in range(_RATE_LIMIT_MAX):
            await _check_rate_limit_redis(redis, "1.1.1.1")
        # First confirm 1.1.1.1 really is saturated, so the next assertion
        # proves isolation rather than passing against an ineffective limiter.
        assert await _check_rate_limit_redis(redis, "1.1.1.1") is False
        # 2.2.2.2 still has its full budget.
        assert await _check_rate_limit_redis(redis, "2.2.2.2") is True

    async def test_window_reset_after_expiry(self):
        """Once the window has expired, the count starts over."""
        redis = _FakeRedis()
        for _ in range(_RATE_LIMIT_MAX):
            await _check_rate_limit_redis(redis, "3.3.3.3")
        assert await _check_rate_limit_redis(redis, "3.3.3.3") is False
        # Simulate window expiry by discarding all sorted-set data and TTLs.
        # NOTE(review): this wipes state instead of advancing time, so it does
        # not exercise the limiter's own pruning of old entries.
        redis._zsets.clear()
        redis._expires.clear()
        # After expiry the IP is allowed again.
        assert await _check_rate_limit_redis(redis, "3.3.3.3") is True
# ── rate limit (in-memory fallback) ──────────────────────
class TestRateLimitFallback:
    """The in-memory rate limiter (the fallback when redis=None) still enforces the cap."""

    def test_fallback_under_limit(self):
        """In-memory mode: the first _RATE_LIMIT_MAX requests are all allowed → True."""
        client_ip = "4.4.4.4"
        for _attempt in range(_RATE_LIMIT_MAX):
            allowed = _check_rate_limit(client_ip)
            assert allowed is True

    def test_fallback_over_limit(self):
        """In-memory mode: the request after the cap is rejected → False."""
        client_ip = "5.5.5.5"
        for _attempt in range(_RATE_LIMIT_MAX):
            _check_rate_limit(client_ip)
        over_cap = _check_rate_limit(client_ip)
        assert over_cap is False
# ── backpressure (Redis) ─────────────────────────────────
class TestBackpressureRedis:
    """Concurrency counter shared through Redis."""

    async def test_under_limit_returns_true(self):
        """Every acquire made while fewer than 2x the cap are held → True."""
        redis = _FakeRedis()
        for _ in range(_WEBHOOK_MAX_CONCURRENT * 2):
            assert await _acquire_backpressure_slot(redis) is True

    async def test_over_limit_returns_false(self):
        """Once 2x the cap is held, a further acquire → False, without leaking a slot."""
        redis = _FakeRedis()
        for _ in range(_WEBHOOK_MAX_CONCURRENT * 2):
            await _acquire_backpressure_slot(redis)
        # Beyond 2x the cap the acquire must be rejected ...
        assert await _acquire_backpressure_slot(redis) is False
        # ... and the rejection must not leak a slot: the shared counter stays
        # at the cap, otherwise it would drift upward with every rejection.
        assert redis._incr_values[_BACKPRESSURE_KEY] == _WEBHOOK_MAX_CONCURRENT * 2

    async def test_release_restores_slot(self):
        """Releasing a slot makes room for exactly one more acquire."""
        redis = _FakeRedis()
        # Fill up to the limit.
        for _ in range(_WEBHOOK_MAX_CONCURRENT * 2):
            await _acquire_backpressure_slot(redis)
        # Over the limit.
        assert await _acquire_backpressure_slot(redis) is False
        # Release one slot ...
        await _release_backpressure_slot(redis)
        # ... and the next acquire succeeds again.
        assert await _acquire_backpressure_slot(redis) is True

    async def test_release_decrements_counter(self):
        """Releasing a slot decrements the shared counter."""
        redis = _FakeRedis()
        await _acquire_backpressure_slot(redis)
        assert redis._incr_values[_BACKPRESSURE_KEY] == 1
        await _release_backpressure_slot(redis)
        assert redis._incr_values[_BACKPRESSURE_KEY] == 0