"""U4 — unit tests for migrating Channels webhook state to Redis.

Scenarios covered:

- nonce dedup: first use -> True, repeat -> False, after TTL expiry -> usable again
- rate limit: requests within the window pass (100 per window at the time of
  writing, i.e. ``_RATE_LIMIT_MAX``), the next one -> False, and the count
  resets once the window rolls over
- backpressure: in-flight count < 2x the limit -> True, >= 2x the limit -> False,
  releasing a slot restores capacity
- degraded mode: the in-memory implementations (used when redis is None) still work

No real Redis is required: ``_FakeRedis`` emulates the subset of
``redis.asyncio.Redis`` that these code paths use.
"""

from __future__ import annotations

import fnmatch
import time

import pytest

from agentkit.server.routes.channels import (
    _BACKPRESSURE_KEY,
    _RATE_LIMIT_MAX,
    _WEBHOOK_MAX_CONCURRENT,
    _acquire_backpressure_slot,
    _check_nonce_dedup,
    _check_nonce_dedup_redis,
    _check_rate_limit,
    _check_rate_limit_redis,
    _release_backpressure_slot,
    _reset_webhook_state,
)


# ── _FakeRedis — operations used by SecretsStore + webhook helpers ───


class _FakeRedis:
    """Minimal in-process stand-in for ``redis.asyncio.Redis``.

    Supports string / sorted-set / incr-decr / pipeline / expire / scan_iter.

    Internal storage:

    - ``_strings``: ``dict[str, str]`` — string values (set/get/delete)
    - ``_zsets``: ``dict[str, dict[str, float]]`` — sorted sets, member -> score
      (zadd/zcard/zremrangebyscore)
    - ``_incr_values``: ``dict[str, int]`` — counters driven by incr/decr
    - ``_expires``: ``dict[str, float]`` — key -> ``time.monotonic()`` deadline

    Expiry is lazy: an expired key is purged the next time a command touches it.
    """

    def __init__(self) -> None:
        self._strings: dict[str, str] = {}
        self._zsets: dict[str, dict[str, float]] = {}
        self._expires: dict[str, float] = {}
        self._incr_values: dict[str, int] = {}

    # ── internal helpers ──────────────────────────────

    def _is_expired(self, key: str) -> bool:
        exp = self._expires.get(key)
        return exp is not None and time.monotonic() >= exp

    def _cleanup_if_expired(self, key: str) -> None:
        if self._is_expired(key):
            self._strings.pop(key, None)
            self._zsets.pop(key, None)
            self._incr_values.pop(key, None)
            self._expires.pop(key, None)

    def _exists(self, key: str) -> bool:
        """Return True if *key* holds a live value of any type (purging it if expired)."""
        self._cleanup_if_expired(key)
        return key in self._strings or key in self._zsets or key in self._incr_values

    @staticmethod
    def _parse_bound(bound: float | str | bytes) -> tuple[float, bool]:
        """Parse a ZREMRANGEBYSCORE bound into ``(value, exclusive)``.

        Accepts plain numbers plus the Redis string forms ``"-inf"``,
        ``"+inf"`` and ``"(<score>"`` (exclusive bound).
        """
        if isinstance(bound, bytes):
            bound = bound.decode()
        if isinstance(bound, str) and bound.startswith("("):
            return float(bound[1:]), True
        return float(bound), False

    # ── string operations ─────────────────────────────

    async def set(self, key: str, value: str, *, ex: int | None = None, nx: bool = False) -> bool:
        """SET with optional EX / NX.

        Returns True when written and False when NX blocks the write. Note that
        redis-py returns None in the NX-blocked case; False is kept here for
        compatibility with the existing callers.
        """
        # NX must see keys of every type (e.g. incr counters), as real Redis does.
        if nx and self._exists(key):
            return False
        # SET overwrites whatever the key held before, including any previous TTL.
        self._zsets.pop(key, None)
        self._incr_values.pop(key, None)
        self._expires.pop(key, None)
        self._strings[key] = value
        if ex is not None:
            self._expires[key] = time.monotonic() + ex
        return True

    async def get(self, key: str) -> str | None:
        self._cleanup_if_expired(key)
        return self._strings.get(key)

    async def delete(self, *keys: str) -> int:
        """Delete *keys* of any type; return how many of them existed."""
        removed = 0
        for key in keys:
            if self._exists(key):
                removed += 1
            self._strings.pop(key, None)
            self._zsets.pop(key, None)
            self._incr_values.pop(key, None)
            self._expires.pop(key, None)
        return removed

    # ── zset operations ───────────────────────────────

    async def zadd(self, key: str, mapping: dict[str, float]) -> int:
        self._cleanup_if_expired(key)
        zset = self._zsets.setdefault(key, {})
        added = 0
        for member, score in mapping.items():
            if member not in zset:
                added += 1
            zset[member] = score
        return added

    async def zcard(self, key: str) -> int:
        self._cleanup_if_expired(key)
        return len(self._zsets.get(key, {}))

    async def zremrangebyscore(self, key: str, min_score: float | str, max_score: float | str) -> int:
        """Remove members whose score lies within ``[min_score, max_score]``.

        Either bound may use the Redis string syntax (``"-inf"``, ``"+inf"``,
        ``"(x"`` for an exclusive bound).
        """
        self._cleanup_if_expired(key)
        lo, lo_exclusive = self._parse_bound(min_score)
        hi, hi_exclusive = self._parse_bound(max_score)
        zset = self._zsets.get(key, {})
        removed = 0
        for member in list(zset):
            score = zset[member]
            above = score > lo if lo_exclusive else score >= lo
            below = score < hi if hi_exclusive else score <= hi
            if above and below:
                del zset[member]
                removed += 1
        return removed

    # ── incr/decr ─────────────────────────────────────

    async def incr(self, key: str) -> int:
        self._cleanup_if_expired(key)
        self._incr_values[key] = self._incr_values.get(key, 0) + 1
        return self._incr_values[key]

    async def decr(self, key: str) -> int:
        self._cleanup_if_expired(key)
        self._incr_values[key] = self._incr_values.get(key, 0) - 1
        return self._incr_values[key]

    # ── expire ────────────────────────────────────────

    async def expire(self, key: str, seconds: int) -> bool:
        """Set a TTL on a live key; return False if the key is missing or already expired."""
        # Purge first: an already-expired key must not be resurrected by a new TTL.
        if not self._exists(key):
            return False
        self._expires[key] = time.monotonic() + seconds
        return True

    # ── pipeline ──────────────────────────────────────

    def pipeline(self) -> _FakePipeline:
        return _FakePipeline(self)

    # ── scan_iter ─────────────────────────────────────

    async def scan_iter(self, match: str = "*"):
        """Yield live string keys matching the glob *match* (string keys only)."""
        for k in list(self._strings):
            self._cleanup_if_expired(k)
            if k in self._strings and fnmatch.fnmatch(k, match):
                yield k


class _FakePipeline:
    """Stand-in for a redis pipeline: queues commands, ``execute()`` runs them in order.

    As in redis-py, the queue is cleared by ``execute()`` (even if a command
    raises), so a reused pipeline never replays earlier commands.
    """

    def __init__(self, redis: _FakeRedis):
        self._redis = redis
        self._commands: list[tuple[str, tuple, dict]] = []

    def _queue(self, cmd: str, *args: object) -> _FakePipeline:
        """Record *cmd* for later execution and return self for chaining."""
        self._commands.append((cmd, args, {}))
        return self

    def zremrangebyscore(self, key, min_score, max_score):
        return self._queue("zremrangebyscore", key, min_score, max_score)

    def zadd(self, key, mapping):
        return self._queue("zadd", key, mapping)

    def zcard(self, key):
        return self._queue("zcard", key)

    def expire(self, key, seconds):
        return self._queue("expire", key, seconds)

    def incr(self, key):
        return self._queue("incr", key)

    def decr(self, key):
        return self._queue("decr", key)

    async def __aenter__(self) -> _FakePipeline:
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        self._commands = []

    async def execute(self) -> list:
        """Run every queued command against the fake and return their results in order."""
        # Detach the queue up front so it is reset even if a command raises.
        commands, self._commands = self._commands, []
        results = []
        for cmd, args, kwargs in commands:
            method = getattr(self._redis, cmd)
            results.append(await method(*args, **kwargs))
        return results


# ── Fixtures ─────────────────────────────────────────────


@pytest.fixture(autouse=True)
def _reset_state():
    """Reset the module's in-memory webhook state before and after every test."""
    _reset_webhook_state()
    yield
    _reset_webhook_state()


# ── nonce dedup (Redis) ──────────────────────────────────


class TestNonceDedupRedis:
    """Nonce deduplication backed by Redis."""

    async def test_first_nonce_returns_true(self):
        """A nonce seen for the first time -> True (new, should be processed)."""
        redis = _FakeRedis()
        assert await _check_nonce_dedup_redis(redis, "nonce-001") is True

    async def test_duplicate_nonce_returns_false(self):
        """A repeated nonce -> False (processing is skipped)."""
        redis = _FakeRedis()
        assert await _check_nonce_dedup_redis(redis, "nonce-001") is True
        assert await _check_nonce_dedup_redis(redis, "nonce-001") is False

    async def test_different_nonces_both_return_true(self):
        """Distinct nonces are each accepted on first use."""
        redis = _FakeRedis()
        assert await _check_nonce_dedup_redis(redis, "nonce-a") is True
        assert await _check_nonce_dedup_redis(redis, "nonce-b") is True

    async def test_nonce_ttl_expiry_allows_reuse(self):
        """Once the TTL lapses, the same nonce is accepted again."""
        redis = _FakeRedis()
        # First write.
        assert await _check_nonce_dedup_redis(redis, "nonce-exp") is True
        # The nonce key must carry a TTL; otherwise dedup keys accumulate forever.
        assert redis._expires, "nonce key was written without a TTL"
        # Simulate TTL expiry by moving every deadline into the past; the fake
        # then purges the keys lazily on next access, like a real expiry.
        past = time.monotonic() - 1
        for key in redis._expires:
            redis._expires[key] = past
        # After expiry the same nonce can be used again.
        assert await _check_nonce_dedup_redis(redis, "nonce-exp") is True


# ── nonce dedup (in-memory fallback) ─────────────────────


class TestNonceDedupFallback:
    """The in-memory nonce dedup (used when redis is None) still works."""

    def test_fallback_first_nonce_returns_true(self):
        """In-memory mode: first nonce -> True."""
        assert _check_nonce_dedup("nonce-fb-1") is True

    def test_fallback_duplicate_returns_false(self):
        """In-memory mode: repeated nonce -> False."""
        assert _check_nonce_dedup("nonce-fb-2") is True
        assert _check_nonce_dedup("nonce-fb-2") is False


# ── rate limit (Redis) ───────────────────────────────────


class TestRateLimitRedis:
    """Sliding-window rate limiting backed by Redis."""

    async def test_under_limit_returns_true(self):
        """Within the window and under the limit -> True."""
        redis = _FakeRedis()
        for _ in range(_RATE_LIMIT_MAX):
            assert await _check_rate_limit_redis(redis, "1.2.3.4") is True

    async def test_over_limit_returns_false(self):
        """Past the per-window limit -> False."""
        redis = _FakeRedis()
        for _ in range(_RATE_LIMIT_MAX):
            await _check_rate_limit_redis(redis, "1.2.3.4")
        # Request number _RATE_LIMIT_MAX + 1 must be rejected.
        assert await _check_rate_limit_redis(redis, "1.2.3.4") is False

    async def test_different_ips_independent(self):
        """Each IP has its own independent limit."""
        redis = _FakeRedis()
        for _ in range(_RATE_LIMIT_MAX):
            await _check_rate_limit_redis(redis, "1.1.1.1")
        # 1.1.1.1 is exhausted, but 2.2.2.2 still gets through.
        assert await _check_rate_limit_redis(redis, "2.2.2.2") is True

    async def test_window_reset_after_expiry(self):
        """The count resets once the window has expired."""
        redis = _FakeRedis()
        for _ in range(_RATE_LIMIT_MAX):
            await _check_rate_limit_redis(redis, "3.3.3.3")
        assert await _check_rate_limit_redis(redis, "3.3.3.3") is False
        # Simulate window expiry by dropping the stored zset data.
        redis._zsets.clear()
        redis._expires.clear()
        # Requests are accepted again afterwards.
        assert await _check_rate_limit_redis(redis, "3.3.3.3") is True


# ── rate limit (in-memory fallback) ──────────────────────


class TestRateLimitFallback:
    """The in-memory rate limiter (used when redis is None) still works."""

    def test_fallback_under_limit(self):
        """In-memory mode: under the limit -> True."""
        for _ in range(_RATE_LIMIT_MAX):
            assert _check_rate_limit("4.4.4.4") is True

    def test_fallback_over_limit(self):
        """In-memory mode: over the limit -> False."""
        for _ in range(_RATE_LIMIT_MAX):
            _check_rate_limit("5.5.5.5")
        assert _check_rate_limit("5.5.5.5") is False


# ── backpressure (Redis) ─────────────────────────────────


class TestBackpressureRedis:
    """Shared in-flight counter backed by Redis."""

    async def test_under_limit_returns_true(self):
        """In-flight count below 2x the limit -> True."""
        redis = _FakeRedis()
        for _ in range(_WEBHOOK_MAX_CONCURRENT * 2):
            assert await _acquire_backpressure_slot(redis) is True

    async def test_over_limit_returns_false(self):
        """In-flight count at 2x the limit -> False."""
        redis = _FakeRedis()
        for _ in range(_WEBHOOK_MAX_CONCURRENT * 2):
            await _acquire_backpressure_slot(redis)
        # Going past 2x the limit must be rejected.
        assert await _acquire_backpressure_slot(redis) is False

    async def test_release_restores_slot(self):
        """Releasing a slot makes it available again."""
        redis = _FakeRedis()
        # Fill up to the cap.
        for _ in range(_WEBHOOK_MAX_CONCURRENT * 2):
            await _acquire_backpressure_slot(redis)
        # Over the cap.
        assert await _acquire_backpressure_slot(redis) is False
        # Release one slot.
        await _release_backpressure_slot(redis)
        # It can be acquired again.
        assert await _acquire_backpressure_slot(redis) is True

    async def test_release_decrements_counter(self):
        """release decrements the shared counter."""
        redis = _FakeRedis()
        await _acquire_backpressure_slot(redis)
        assert redis._incr_values[_BACKPRESSURE_KEY] == 1
        await _release_backpressure_slot(redis)
        assert redis._incr_values[_BACKPRESSURE_KEY] == 0