106 lines
3.7 KiB
Python
106 lines
3.7 KiB
Python
"""Server middleware - Authentication and Rate Limiting"""
|
|
|
|
import hmac
import math
import os
import time
from collections import defaultdict, deque

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
|
|
|
|
|
|
class APIKeyAuthMiddleware(BaseHTTPMiddleware):
    """API Key authentication middleware.

    Validates the ``X-API-Key`` header against the ``AGENTKIT_API_KEY``
    environment variable, which is read on every request.
    Skips validation if ``AGENTKIT_API_KEY`` is unset or empty (dev mode).

    Whitelisted paths (no auth required): ``/api/v1/health`` and anything
    beneath it (``/api/v1/health/...``). Paths that merely share the text
    prefix, such as ``/api/v1/healthz``, still require authentication.
    Rejected requests get a 401 JSON response.
    """

    WHITELIST_PATHS = ("/api/v1/health",)

    @classmethod
    def _is_whitelisted(cls, path: str) -> bool:
        """Return True if *path* equals a whitelisted path or lies beneath it."""
        # Match on a path-segment boundary: a bare ``startswith`` would also
        # whitelist e.g. ``/api/v1/health-admin`` and bypass authentication.
        return any(
            path == prefix or path.startswith(prefix.rstrip("/") + "/")
            for prefix in cls.WHITELIST_PATHS
        )

    async def dispatch(self, request: Request, call_next):
        """Forward the request if it is whitelisted, auth is disabled, or it
        carries the configured API key; otherwise respond with 401."""
        if self._is_whitelisted(request.url.path):
            return await call_next(request)

        api_key = os.environ.get("AGENTKIT_API_KEY")
        if not api_key:
            # Dev mode: skip auth if no API key configured
            return await call_next(request)

        provided_key = request.headers.get("X-API-Key")
        # Constant-time comparison so response timing does not reveal how much
        # of a guessed key is correct. Compare bytes: hmac.compare_digest
        # raises TypeError for str arguments containing non-ASCII characters.
        if not provided_key or not hmac.compare_digest(
            provided_key.encode("utf-8"), api_key.encode("utf-8")
        ):
            return JSONResponse(
                status_code=401,
                content={"error": "Unauthorized", "message": "Invalid or missing API key"},
            )

        return await call_next(request)
|
|
|
|
|
|
class RateLimiter:
    """Sliding-window (timestamp log) rate limiter.

    For each key (e.g. a client IP or API key) it remembers the timestamps of
    the requests admitted during the last ``window_seconds``. It admits a new
    request only while fewer than ``max_requests`` of them remain.

    Timestamps come from ``time.monotonic()``, so wall-clock adjustments can
    neither lock clients out nor reset their quota. The idle-key check runs
    during ``is_allowed`` calls, at most once per window. It evicts keys whose
    admitted requests have all left the window, so memory tracks recently
    active keys rather than every key ever seen.
    """

    def __init__(self, max_requests: int = 60, window_seconds: float = 60):
        self._max_requests = max_requests
        self._window_seconds = window_seconds
        # key -> timestamps of admitted requests, oldest first
        self._requests: dict[str, deque[float]] = defaultdict(deque)
        self._last_sweep = time.monotonic()

    def is_allowed(self, key: str) -> tuple[bool, float]:
        """Check, and if permitted record, a request for *key*.

        Returns ``(allowed, retry_after_seconds)``. When allowed, the request
        is counted and ``retry_after_seconds`` is ``0.0``. When denied,
        nothing is recorded and ``retry_after_seconds`` is the time until the
        oldest counted request leaves the window. If ``max_requests <= 0``,
        every request is denied and the wait is one full window.
        """
        now = time.monotonic()
        window_start = now - self._window_seconds

        if now - self._last_sweep >= self._window_seconds:
            self._evict_idle_keys(window_start)
            self._last_sweep = now

        timestamps = self._requests[key]
        # Timestamps are appended in non-decreasing (monotonic) order, so the
        # expired ones are always at the left end.
        while timestamps and timestamps[0] <= window_start:
            timestamps.popleft()

        if len(timestamps) >= self._max_requests:
            if timestamps:
                retry_after = timestamps[0] + self._window_seconds - now
            else:
                # Only reachable when max_requests <= 0: nothing will ever
                # free up, so suggest waiting a full window.
                retry_after = float(self._window_seconds)
            return False, max(0.0, retry_after)

        timestamps.append(now)
        return True, 0.0

    def _evict_idle_keys(self, window_start: float) -> None:
        """Drop keys that have no admitted request newer than *window_start*."""
        idle = [
            k for k, stamps in self._requests.items()
            if not stamps or stamps[-1] <= window_start
        ]
        for k in idle:
            del self._requests[k]

    @property
    def max_requests(self) -> int:
        """Maximum number of requests admitted per key within one window."""
        return self._max_requests
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Rate limiting middleware.

    Limits requests per client and returns 429 Too Many Requests when the
    limit is exceeded. A request carrying the configured API key
    (``AGENTKIT_API_KEY``) is counted against that key's bucket. Every other
    request is counted against the client IP: no key, a wrong key, or auth
    disabled. Keying only on a *valid* key stops a client from escaping its
    IP limit, for example while guessing the API key, by sending a fresh
    ``X-API-Key`` value on each request.

    Configurable via AGENTKIT_RATE_LIMIT_PER_MINUTE env var (default: 60).
    """

    def __init__(self, app, max_requests: int | None = None, window_seconds: int = 60):
        super().__init__(app)
        if max_requests is None:
            max_requests = int(os.environ.get("AGENTKIT_RATE_LIMIT_PER_MINUTE", "60"))
        self._limiter = RateLimiter(max_requests=max_requests, window_seconds=window_seconds)

    @staticmethod
    def _client_key(request: Request) -> str:
        """Return the rate-limit bucket key for *request*."""
        provided_key = request.headers.get("X-API-Key")
        configured_key = os.environ.get("AGENTKIT_API_KEY")
        if provided_key and configured_key and hmac.compare_digest(
            provided_key.encode("utf-8"), configured_key.encode("utf-8")
        ):
            return f"key:{provided_key}"
        # request.client is None when the ASGI scope carries no client
        # address; fall back to a shared bucket instead of raising.
        host = request.client.host if request.client else "unknown"
        return f"ip:{host}"

    async def dispatch(self, request: Request, call_next):
        """Forward the request, or respond with 429 if its bucket is full."""
        allowed, retry_after = self._limiter.is_allowed(self._client_key(request))
        if not allowed:
            # Round up: truncating would advertise e.g. "0 seconds" for a
            # 0.4 s wait and invite an immediate retry that is rejected again.
            wait_seconds = math.ceil(retry_after)
            return JSONResponse(
                status_code=429,
                content={
                    "error": "Too Many Requests",
                    "message": f"Rate limit exceeded. Try again in {wait_seconds} seconds.",
                },
                headers={"Retry-After": str(wait_seconds)},
            )

        return await call_next(request)
|