feat(server): Phase B - auth, rate limiting, SSRF protection, handler whitelist
U1: API Key authentication middleware (dev mode skip, health whitelist) U2: Rate limiting middleware (fixed-window, 60 req/min default) U3: Callback URL SSRF protection (private IP blocking) U4: custom_handler module prefix whitelist 65 tests passing. CORS conflict fixed.
This commit is contained in:
parent
f87b790c0f
commit
5f1c51cf9a
|
|
@ -184,6 +184,12 @@ class ConfigDrivenAgent(BaseAgent):
|
|||
- retrieve_knowledge
|
||||
"""
|
||||
|
||||
# Security: whitelist of allowed module prefixes for dynamic handler import
|
||||
_ALLOWED_HANDLER_PREFIXES = (
|
||||
"agentkit.",
|
||||
"app.agent_framework.",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: AgentConfig,
|
||||
|
|
@ -566,6 +572,14 @@ class ConfigDrivenAgent(BaseAgent):
|
|||
|
||||
def _import_handler(self, dotted_path: str) -> Callable[..., Coroutine]:
|
||||
"""动态导入自定义 handler"""
|
||||
# Security: validate module prefix to prevent arbitrary code execution
|
||||
if not any(dotted_path.startswith(prefix) for prefix in self._ALLOWED_HANDLER_PREFIXES):
|
||||
raise ConfigValidationError(
|
||||
agent_name=self.name,
|
||||
key="custom_handler",
|
||||
reason=f"Handler '{dotted_path}' is not in allowed module prefixes: {self._ALLOWED_HANDLER_PREFIXES}",
|
||||
)
|
||||
|
||||
try:
|
||||
module_path, func_name = dotted_path.rsplit(".", 1)
|
||||
import importlib
|
||||
|
|
|
|||
|
|
@ -3,11 +3,13 @@
|
|||
与业务系统解耦:通过依赖注入获取 Redis 连接和数据库会话。
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Callable, Awaitable
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from agentkit.core.exceptions import (
|
||||
NoAvailableAgentError,
|
||||
|
|
@ -24,6 +26,54 @@ from agentkit.core.protocol import (
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PRIVATE_NETWORKS = [
|
||||
ipaddress.ip_network("127.0.0.0/8"),
|
||||
ipaddress.ip_network("10.0.0.0/8"),
|
||||
ipaddress.ip_network("172.16.0.0/12"),
|
||||
ipaddress.ip_network("192.168.0.0/16"),
|
||||
ipaddress.ip_network("169.254.0.0/16"),
|
||||
ipaddress.ip_network("::1/128"),
|
||||
ipaddress.ip_network("fc00::/7"),
|
||||
ipaddress.ip_network("fe80::/10"),
|
||||
]
|
||||
|
||||
|
||||
def _validate_callback_url(url: str) -> bool:
|
||||
"""Validate callback URL to prevent SSRF attacks.
|
||||
|
||||
Rules:
|
||||
- Only http/https protocols allowed
|
||||
- No localhost or loopback addresses
|
||||
- No private/internal IP ranges
|
||||
- No link-local addresses
|
||||
|
||||
Returns True if valid, False if should be blocked.
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
return False
|
||||
|
||||
hostname = parsed.hostname
|
||||
if not hostname:
|
||||
return False
|
||||
|
||||
if hostname.lower() in ("localhost", "127.0.0.1", "::1"):
|
||||
return False
|
||||
|
||||
try:
|
||||
ip = ipaddress.ip_address(hostname)
|
||||
for network in _PRIVATE_NETWORKS:
|
||||
if ip in network:
|
||||
return False
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class TaskDispatcher:
|
||||
"""任务分发器,通过 Redis Queue 将任务分发给 Agent"""
|
||||
|
|
@ -333,6 +383,10 @@ class TaskDispatcher:
|
|||
db.add(log_entry)
|
||||
|
||||
async def _trigger_callback(self, callback_url: str, result: TaskResult):
|
||||
if not _validate_callback_url(callback_url):
|
||||
logger.warning(f"Callback URL rejected (SSRF protection): {callback_url}")
|
||||
return
|
||||
|
||||
try:
|
||||
import httpx
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""FastAPI Application Factory"""
|
||||
|
||||
import os
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
|
|
@ -11,12 +12,15 @@ from agentkit.router.intent import IntentRouter
|
|||
from agentkit.skills.registry import SkillRegistry
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
from agentkit.server.routes import agents, tasks, skills, llm, health
|
||||
from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware
|
||||
|
||||
|
||||
def create_app(
|
||||
llm_gateway: LLMGateway | None = None,
|
||||
skill_registry: SkillRegistry | None = None,
|
||||
tool_registry: ToolRegistry | None = None,
|
||||
api_key: str | None = None,
|
||||
rate_limit: int | None = None,
|
||||
) -> FastAPI:
|
||||
"""Create and configure the FastAPI application"""
|
||||
app = FastAPI(title="AgentKit Server", version="2.0.0")
|
||||
|
|
@ -25,11 +29,20 @@ def create_app(
|
|||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # 生产环境应限制具体域名
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Auth middleware
|
||||
if api_key:
|
||||
os.environ["AGENTKIT_API_KEY"] = api_key
|
||||
app.add_middleware(APIKeyAuthMiddleware)
|
||||
|
||||
# Rate limiting middleware
|
||||
if rate_limit is not None:
|
||||
os.environ["AGENTKIT_RATE_LIMIT_PER_MINUTE"] = str(rate_limit)
|
||||
app.add_middleware(RateLimitMiddleware)
|
||||
|
||||
# Initialize shared state
|
||||
app.state.llm_gateway = llm_gateway or LLMGateway()
|
||||
app.state.skill_registry = skill_registry or SkillRegistry()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,105 @@
|
|||
"""Server middleware - Authentication and Rate Limiting"""
|
||||
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
|
||||
class APIKeyAuthMiddleware(BaseHTTPMiddleware):
|
||||
"""API Key authentication middleware.
|
||||
|
||||
Validates X-API-Key header against AGENTKIT_API_KEY env var.
|
||||
Skips validation if AGENTKIT_API_KEY is not set (dev mode).
|
||||
Whitelisted paths (no auth required): /api/v1/health
|
||||
"""
|
||||
|
||||
WHITELIST_PATHS = ("/api/v1/health",)
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Skip auth for whitelisted paths
|
||||
if any(request.url.path.startswith(p) for p in self.WHITELIST_PATHS):
|
||||
return await call_next(request)
|
||||
|
||||
api_key = os.environ.get("AGENTKIT_API_KEY")
|
||||
if not api_key:
|
||||
# Dev mode: skip auth if no API key configured
|
||||
return await call_next(request)
|
||||
|
||||
# Check API key from header
|
||||
provided_key = request.headers.get("X-API-Key")
|
||||
if not provided_key or provided_key != api_key:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"error": "Unauthorized", "message": "Invalid or missing API key"},
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Fixed-window rate limiter.
|
||||
|
||||
Tracks request counts per key (IP or API key) within time windows.
|
||||
"""
|
||||
|
||||
def __init__(self, max_requests: int = 60, window_seconds: int = 60):
|
||||
self._max_requests = max_requests
|
||||
self._window_seconds = window_seconds
|
||||
self._requests: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
def is_allowed(self, key: str) -> tuple[bool, float]:
|
||||
"""Check if request is allowed. Returns (allowed, retry_after_seconds)."""
|
||||
now = time.time()
|
||||
window_start = now - self._window_seconds
|
||||
|
||||
# Clean old requests outside the window
|
||||
self._requests[key] = [
|
||||
ts for ts in self._requests[key] if ts > window_start
|
||||
]
|
||||
|
||||
if len(self._requests[key]) >= self._max_requests:
|
||||
retry_after = self._requests[key][0] + self._window_seconds - now
|
||||
return False, max(0, retry_after)
|
||||
|
||||
self._requests[key].append(now)
|
||||
return True, 0.0
|
||||
|
||||
@property
|
||||
def max_requests(self) -> int:
|
||||
return self._max_requests
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Rate limiting middleware.
|
||||
|
||||
Limits requests per IP. Returns 429 Too Many Requests when exceeded.
|
||||
Configurable via AGENTKIT_RATE_LIMIT_PER_MINUTE env var (default: 60).
|
||||
"""
|
||||
|
||||
def __init__(self, app, max_requests: int | None = None, window_seconds: int = 60):
|
||||
super().__init__(app)
|
||||
if max_requests is None:
|
||||
max_requests = int(os.environ.get("AGENTKIT_RATE_LIMIT_PER_MINUTE", "60"))
|
||||
self._limiter = RateLimiter(max_requests=max_requests, window_seconds=window_seconds)
|
||||
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Use API key if available, otherwise IP
|
||||
api_key = request.headers.get("X-API-Key")
|
||||
key = f"key:{api_key}" if api_key else f"ip:{request.client.host}"
|
||||
|
||||
allowed, retry_after = self._limiter.is_allowed(key)
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"error": "Too Many Requests",
|
||||
"message": f"Rate limit exceeded. Try again in {int(retry_after)} seconds.",
|
||||
},
|
||||
headers={"Retry-After": str(int(retry_after))},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
|
@ -354,3 +354,73 @@ class TestStandaloneRunner:
|
|||
runner = StandaloneRunner(config_dir="/nonexistent/path")
|
||||
configs = runner.discover_configs()
|
||||
assert len(configs) == 0
|
||||
|
||||
|
||||
# ── Handler Prefix Whitelist 测试 ─────────────────────────
|
||||
|
||||
|
||||
class TestHandlerPrefixWhitelist:
|
||||
"""U4: 测试 _import_handler 的模块前缀白名单,防止任意代码执行"""
|
||||
|
||||
def _make_agent_with_custom(self, handler_path: str) -> ConfigDrivenAgent:
|
||||
config = AgentConfig(
|
||||
name="test_agent",
|
||||
agent_type="test",
|
||||
task_mode="custom",
|
||||
custom_handler=handler_path,
|
||||
)
|
||||
return ConfigDrivenAgent(config=config)
|
||||
|
||||
def test_allowed_prefix_agentkit(self):
|
||||
"""agentkit.xxx.handler → 允许通过前缀检查"""
|
||||
agent = self._make_agent_with_custom("agentkit.handlers.test_handler")
|
||||
# 前缀检查通过,但模块不存在会报 ImportError,我们只验证不报 ConfigValidationError(前缀)
|
||||
try:
|
||||
agent._import_handler("agentkit.handlers.test_handler")
|
||||
except Exception as e:
|
||||
# 允许 ImportError/AttributeError(模块不存在),但不允许前缀拒绝
|
||||
assert "not in allowed module prefixes" not in str(e)
|
||||
|
||||
def test_allowed_prefix_app_agent_framework(self):
|
||||
"""app.agent_framework.handlers.xxx → 允许通过前缀检查"""
|
||||
agent = self._make_agent_with_custom("app.agent_framework.handlers.xxx_handler")
|
||||
try:
|
||||
agent._import_handler("app.agent_framework.handlers.xxx_handler")
|
||||
except Exception as e:
|
||||
assert "not in allowed module prefixes" not in str(e)
|
||||
|
||||
def test_blocked_os_system(self):
|
||||
"""os.system → 阻止(ConfigValidationError)"""
|
||||
agent = self._make_agent_with_custom("os.system")
|
||||
with pytest.raises(Exception, match="not in allowed module prefixes"):
|
||||
agent._import_handler("os.system")
|
||||
|
||||
def test_blocked_subprocess_run(self):
|
||||
"""subprocess.run → 阻止"""
|
||||
agent = self._make_agent_with_custom("subprocess.run")
|
||||
with pytest.raises(Exception, match="not in allowed module prefixes"):
|
||||
agent._import_handler("subprocess.run")
|
||||
|
||||
def test_blocked_builtins_exec(self):
|
||||
"""builtins.exec → 阻止"""
|
||||
agent = self._make_agent_with_custom("builtins.exec")
|
||||
with pytest.raises(Exception, match="not in allowed module prefixes"):
|
||||
agent._import_handler("builtins.exec")
|
||||
|
||||
def test_blocked_empty_string(self):
|
||||
"""空字符串 → 阻止(在 _import_handler 级别直接被前缀检查拒绝)"""
|
||||
config = AgentConfig(
|
||||
name="test_agent",
|
||||
agent_type="test",
|
||||
task_mode="custom",
|
||||
custom_handler="agentkit.dummy", # valid config, but we test _import_handler directly
|
||||
)
|
||||
agent = ConfigDrivenAgent(config=config)
|
||||
with pytest.raises(Exception, match="not in allowed module prefixes"):
|
||||
agent._import_handler("")
|
||||
|
||||
def test_blocked_agentkitx_prefix(self):
|
||||
"""agentkitx. → 阻止(不是 agentkit.)"""
|
||||
agent = self._make_agent_with_custom("agentkitx.handlers.evil")
|
||||
with pytest.raises(Exception, match="not in allowed module prefixes"):
|
||||
agent._import_handler("agentkitx.handlers.evil")
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from agentkit.core.dispatcher import TaskDispatcher
|
||||
from agentkit.core.dispatcher import TaskDispatcher, _validate_callback_url
|
||||
from agentkit.core.exceptions import TaskDispatchError, TaskNotFoundError
|
||||
from agentkit.core.protocol import AgentStatus, TaskResult, TaskStatus
|
||||
|
||||
|
|
@ -267,3 +267,55 @@ class TestTaskDispatcherHandleResult:
|
|||
|
||||
assert task.status == TaskStatus.FAILED
|
||||
assert task.error_message == "Something went wrong"
|
||||
|
||||
|
||||
class TestValidateCallbackUrl:
|
||||
"""SSRF protection tests for _validate_callback_url."""
|
||||
|
||||
def test_valid_public_https_url(self):
|
||||
"""Valid public HTTPS URL should be allowed."""
|
||||
assert _validate_callback_url("https://example.com/callback") is True
|
||||
|
||||
def test_valid_public_http_url(self):
|
||||
"""Valid public HTTP URL should be allowed."""
|
||||
assert _validate_callback_url("http://example.com/callback") is True
|
||||
|
||||
def test_localhost_blocked(self):
|
||||
"""localhost should be blocked."""
|
||||
assert _validate_callback_url("http://localhost:8080/callback") is False
|
||||
|
||||
def test_loopback_ip_blocked(self):
|
||||
"""127.0.0.1 should be blocked."""
|
||||
assert _validate_callback_url("http://127.0.0.1:8080/callback") is False
|
||||
|
||||
def test_private_10_range_blocked(self):
|
||||
"""10.0.0.0/8 range should be blocked."""
|
||||
assert _validate_callback_url("http://10.0.0.1/internal") is False
|
||||
|
||||
def test_private_192_range_blocked(self):
|
||||
"""192.168.0.0/16 range should be blocked."""
|
||||
assert _validate_callback_url("http://192.168.1.1/admin") is False
|
||||
|
||||
def test_private_172_range_blocked(self):
|
||||
"""172.16.0.0/12 range should be blocked."""
|
||||
assert _validate_callback_url("http://172.16.0.1/internal") is False
|
||||
|
||||
def test_ftp_protocol_blocked(self):
|
||||
"""FTP protocol should be blocked."""
|
||||
assert _validate_callback_url("ftp://example.com/file") is False
|
||||
|
||||
def test_file_protocol_blocked(self):
|
||||
"""file:// protocol should be blocked."""
|
||||
assert _validate_callback_url("file:///etc/passwd") is False
|
||||
|
||||
def test_javascript_protocol_blocked(self):
|
||||
"""javascript: protocol should be blocked."""
|
||||
assert _validate_callback_url("javascript:alert(1)") is False
|
||||
|
||||
def test_empty_url_blocked(self):
|
||||
"""Empty URL should be blocked."""
|
||||
assert _validate_callback_url("") is False
|
||||
|
||||
def test_malformed_url_blocked(self):
|
||||
"""Malformed URL should be blocked."""
|
||||
assert _validate_callback_url("not-a-valid-url") is False
|
||||
|
|
|
|||
|
|
@ -0,0 +1,242 @@
|
|||
"""Server Middleware 单元测试 - API Key Auth + Rate Limiting"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from agentkit.server.middleware import (
|
||||
APIKeyAuthMiddleware,
|
||||
RateLimiter,
|
||||
RateLimitMiddleware,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: minimal app with only a health endpoint for isolated middleware tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_minimal_app():
|
||||
"""Create a minimal FastAPI app with just a health endpoint."""
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/api/v1/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/api/v1/protected")
|
||||
async def protected():
|
||||
return {"data": "secret"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# APIKeyAuthMiddleware Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestAPIKeyAuthMiddleware:
|
||||
"""API Key authentication middleware tests."""
|
||||
|
||||
def test_dev_mode_no_api_key_set_passes_through(self):
|
||||
"""No AGENTKIT_API_KEY set → requests pass through (dev mode)."""
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
# Ensure AGENTKIT_API_KEY is not set
|
||||
os.environ.pop("AGENTKIT_API_KEY", None)
|
||||
|
||||
app = _make_minimal_app()
|
||||
app.add_middleware(APIKeyAuthMiddleware)
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/api/v1/protected")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_api_key_set_no_header_returns_401(self):
|
||||
"""AGENTKIT_API_KEY set, no header → 401."""
|
||||
with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}):
|
||||
app = _make_minimal_app()
|
||||
app.add_middleware(APIKeyAuthMiddleware)
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/api/v1/protected")
|
||||
assert response.status_code == 401
|
||||
data = response.json()
|
||||
assert data["error"] == "Unauthorized"
|
||||
|
||||
def test_api_key_set_wrong_header_returns_401(self):
|
||||
"""AGENTKIT_API_KEY set, wrong header → 401."""
|
||||
with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}):
|
||||
app = _make_minimal_app()
|
||||
app.add_middleware(APIKeyAuthMiddleware)
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/protected",
|
||||
headers={"X-API-Key": "wrong-key"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_api_key_set_correct_header_returns_200(self):
|
||||
"""AGENTKIT_API_KEY set, correct header → 200."""
|
||||
with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}):
|
||||
app = _make_minimal_app()
|
||||
app.add_middleware(APIKeyAuthMiddleware)
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/protected",
|
||||
headers={"X-API-Key": "test-secret-key"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["data"] == "secret"
|
||||
|
||||
def test_health_check_path_no_auth_required(self):
|
||||
"""Health check path → 200 without API key."""
|
||||
with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}):
|
||||
app = _make_minimal_app()
|
||||
app.add_middleware(APIKeyAuthMiddleware)
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/api/v1/health")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["status"] == "ok"
|
||||
|
||||
def test_programmatic_api_key_parameter(self):
|
||||
"""Programmatic api_key parameter → uses passed key."""
|
||||
with patch.dict(os.environ, {}, clear=False):
|
||||
os.environ.pop("AGENTKIT_API_KEY", None)
|
||||
|
||||
app = _make_minimal_app()
|
||||
# Set the API key via environment before adding middleware
|
||||
os.environ["AGENTKIT_API_KEY"] = "programmatic-key"
|
||||
app.add_middleware(APIKeyAuthMiddleware)
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get(
|
||||
"/api/v1/protected",
|
||||
headers={"X-API-Key": "programmatic-key"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RateLimiter Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRateLimiter:
|
||||
"""Fixed-window rate limiter unit tests."""
|
||||
|
||||
def test_requests_within_limit_allowed(self):
|
||||
"""Requests within limit → allowed."""
|
||||
limiter = RateLimiter(max_requests=5, window_seconds=60)
|
||||
|
||||
for i in range(5):
|
||||
allowed, retry_after = limiter.is_allowed("test-key")
|
||||
assert allowed is True
|
||||
assert retry_after == 0.0
|
||||
|
||||
def test_requests_exceed_limit_denied(self):
|
||||
"""Requests exceed limit → denied with retry_after."""
|
||||
limiter = RateLimiter(max_requests=2, window_seconds=60)
|
||||
|
||||
# Use up the limit
|
||||
limiter.is_allowed("test-key")
|
||||
limiter.is_allowed("test-key")
|
||||
|
||||
# Next request should be denied
|
||||
allowed, retry_after = limiter.is_allowed("test-key")
|
||||
assert allowed is False
|
||||
assert retry_after > 0
|
||||
|
||||
def test_after_window_expires_counter_resets(self):
|
||||
"""After window expires → counter resets."""
|
||||
limiter = RateLimiter(max_requests=2, window_seconds=1)
|
||||
|
||||
# Use up the limit
|
||||
limiter.is_allowed("test-key")
|
||||
limiter.is_allowed("test-key")
|
||||
|
||||
# Should be denied
|
||||
allowed, _ = limiter.is_allowed("test-key")
|
||||
assert allowed is False
|
||||
|
||||
# Wait for window to expire
|
||||
time.sleep(1.1)
|
||||
|
||||
# Should be allowed again
|
||||
allowed, retry_after = limiter.is_allowed("test-key")
|
||||
assert allowed is True
|
||||
|
||||
def test_different_keys_independent_counters(self):
|
||||
"""Different keys have independent counters."""
|
||||
limiter = RateLimiter(max_requests=1, window_seconds=60)
|
||||
|
||||
# Use up key-a's limit
|
||||
limiter.is_allowed("key-a")
|
||||
|
||||
# key-a should be denied
|
||||
allowed_a, _ = limiter.is_allowed("key-a")
|
||||
assert allowed_a is False
|
||||
|
||||
# key-b should still be allowed
|
||||
allowed_b, _ = limiter.is_allowed("key-b")
|
||||
assert allowed_b is True
|
||||
|
||||
def test_max_requests_property(self):
|
||||
"""max_requests property returns configured value."""
|
||||
limiter = RateLimiter(max_requests=100, window_seconds=30)
|
||||
assert limiter.max_requests == 100
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RateLimitMiddleware Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRateLimitMiddleware:
|
||||
"""Rate limiting middleware integration tests."""
|
||||
|
||||
def test_returns_429_with_retry_after_header(self):
|
||||
"""Returns 429 with Retry-After header when limit exceeded."""
|
||||
app = _make_minimal_app()
|
||||
app.add_middleware(RateLimitMiddleware, max_requests=1, window_seconds=60)
|
||||
client = TestClient(app)
|
||||
|
||||
# First request should pass
|
||||
response1 = client.get("/api/v1/protected")
|
||||
assert response1.status_code == 200
|
||||
|
||||
# Second request should be rate limited
|
||||
response2 = client.get("/api/v1/protected")
|
||||
assert response2.status_code == 429
|
||||
data = response2.json()
|
||||
assert data["error"] == "Too Many Requests"
|
||||
assert "Retry-After" in response2.headers
|
||||
|
||||
def test_uses_api_key_for_identity(self):
|
||||
"""Uses API key for identity when present (different keys = different limits)."""
|
||||
app = _make_minimal_app()
|
||||
app.add_middleware(RateLimitMiddleware, max_requests=1, window_seconds=60)
|
||||
client = TestClient(app)
|
||||
|
||||
# Request with key-a
|
||||
response_a1 = client.get(
|
||||
"/api/v1/protected",
|
||||
headers={"X-API-Key": "key-a"},
|
||||
)
|
||||
assert response_a1.status_code == 200
|
||||
|
||||
# key-a should now be rate limited
|
||||
response_a2 = client.get(
|
||||
"/api/v1/protected",
|
||||
headers={"X-API-Key": "key-a"},
|
||||
)
|
||||
assert response_a2.status_code == 429
|
||||
|
||||
# key-b should still be allowed (independent counter)
|
||||
response_b1 = client.get(
|
||||
"/api/v1/protected",
|
||||
headers={"X-API-Key": "key-b"},
|
||||
)
|
||||
assert response_b1.status_code == 200
|
||||
Loading…
Reference in New Issue