fischer-agentkit/src/agentkit/server/middleware.py

106 lines
3.7 KiB
Python

"""Server middleware - Authentication and Rate Limiting"""
import hmac
import math
import os
import time
from collections import defaultdict, deque

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
class APIKeyAuthMiddleware(BaseHTTPMiddleware):
    """API Key authentication middleware.

    Validates the X-API-Key header against the AGENTKIT_API_KEY env var.
    Skips validation if AGENTKIT_API_KEY is not set (dev mode).
    Whitelisted paths (no auth required): /api/v1/health and its sub-paths
    (e.g. /api/v1/health/ready), but not merely similarly-prefixed routes.
    Requests failing validation get a 401 JSON response.
    """

    WHITELIST_PATHS = ("/api/v1/health",)

    @classmethod
    def _is_whitelisted(cls, path: str) -> bool:
        """Return True if *path* is a whitelisted path or one of its sub-paths."""
        # A bare startswith() would also exempt unrelated routes such as
        # "/api/v1/healthz-admin"; require an exact match or a "/" boundary.
        return any(
            path == prefix or path.startswith(prefix.rstrip("/") + "/")
            for prefix in cls.WHITELIST_PATHS
        )

    async def dispatch(self, request: Request, call_next):
        # Skip auth for whitelisted paths
        if self._is_whitelisted(request.url.path):
            return await call_next(request)

        api_key = os.environ.get("AGENTKIT_API_KEY")
        if not api_key:
            # Dev mode: skip auth if no API key configured
            return await call_next(request)

        provided_key = request.headers.get("X-API-Key")
        # compare_digest takes time independent of where the inputs differ, so
        # response latency does not reveal how much of the key was guessed.
        # Compare bytes: the str overload raises TypeError on non-ASCII input.
        if not provided_key or not hmac.compare_digest(
            provided_key.encode("utf-8"), api_key.encode("utf-8")
        ):
            return JSONResponse(
                status_code=401,
                content={"error": "Unauthorized", "message": "Invalid or missing API key"},
            )
        return await call_next(request)
class RateLimiter:
    """Sliding-window (timestamp log) rate limiter.

    Records the time of every accepted request per key (IP or API key) and
    admits a new request only if fewer than ``max_requests`` were accepted
    during the last ``window_seconds``. Keys with no accepted request inside
    the window are evicted at most once per window, so memory tracks the set
    of recently active keys rather than every key ever seen.

    Performs no locking; an instance shared across threads needs external
    synchronization.
    """

    def __init__(self, max_requests: int = 60, window_seconds: int = 60):
        self._max_requests = max_requests
        self._window_seconds = window_seconds
        # key -> time.monotonic() timestamps of accepted requests, oldest first.
        self._requests: dict[str, deque[float]] = {}
        # Last time idle keys were evicted (see _evict_idle_keys).
        self._last_sweep = time.monotonic()

    def is_allowed(self, key: str) -> tuple[bool, float]:
        """Check if a request for *key* is allowed, recording it if so.

        Returns:
            ``(allowed, retry_after_seconds)``. ``retry_after_seconds`` is
            ``0.0`` when allowed; otherwise it is the non-negative number of
            seconds until the oldest recorded request leaves the window (or
            ``window_seconds`` when ``max_requests <= 0`` and nothing can
            ever be admitted).
        """
        # Monotonic clock: wall-clock adjustments must not stretch or shrink
        # the window.
        now = time.monotonic()
        window_start = now - self._window_seconds
        self._evict_idle_keys(now, window_start)

        timestamps = self._requests.get(key)
        if timestamps is None:
            timestamps = deque()
        else:
            # Timestamps are appended in non-decreasing order, so expired
            # entries are always at the left end.
            while timestamps and timestamps[0] <= window_start:
                timestamps.popleft()

        if len(timestamps) >= self._max_requests:
            if not timestamps:
                # Only reachable when max_requests <= 0; don't keep an entry.
                self._requests.pop(key, None)
                return False, float(self._window_seconds)
            retry_after = timestamps[0] + self._window_seconds - now
            return False, max(0.0, retry_after)

        timestamps.append(now)
        self._requests[key] = timestamps
        return True, 0.0

    def _evict_idle_keys(self, now: float, window_start: float) -> None:
        """Drop keys whose newest request is outside the window (once per window)."""
        if now - self._last_sweep < self._window_seconds:
            return
        self._last_sweep = now
        idle = [
            k for k, ts in self._requests.items() if not ts or ts[-1] <= window_start
        ]
        for k in idle:
            del self._requests[k]

    @property
    def max_requests(self) -> int:
        """Maximum number of requests admitted per key per window."""
        return self._max_requests
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Rate limiting middleware.

    Requests presenting the configured API key (AGENTKIT_API_KEY) are
    bucketed by that key; all other requests, including those with an
    unrecognized X-API-Key or when no key is configured, are bucketed by
    client IP. Returns 429 Too Many Requests with a Retry-After header when
    the limit is exceeded.
    Configurable via AGENTKIT_RATE_LIMIT_PER_MINUTE env var (default: 60).
    """

    def __init__(self, app, max_requests: int | None = None, window_seconds: int = 60):
        super().__init__(app)
        if max_requests is None:
            max_requests = int(os.environ.get("AGENTKIT_RATE_LIMIT_PER_MINUTE", "60"))
        self._limiter = RateLimiter(max_requests=max_requests, window_seconds=window_seconds)

    @staticmethod
    def _bucket_key(request: Request) -> str:
        """Return the rate-limit bucket identifier for *request*."""
        provided_key = request.headers.get("X-API-Key")
        configured_key = os.environ.get("AGENTKIT_API_KEY")
        # Only a key matching the configured one selects a key bucket. Trusting
        # any header value would let a client evade the per-IP limit by
        # sending a fresh random X-API-Key with every request.
        if provided_key and configured_key and hmac.compare_digest(
            provided_key.encode("utf-8"), configured_key.encode("utf-8")
        ):
            return f"key:{provided_key}"
        # request.client is None when the ASGI server reports no peer address.
        host = request.client.host if request.client else "unknown"
        return f"ip:{host}"

    async def dispatch(self, request: Request, call_next):
        allowed, retry_after = self._limiter.is_allowed(self._bucket_key(request))
        if not allowed:
            # Round up and never advertise 0: truncating e.g. 0.4s to "0"
            # would invite an immediate retry that is rejected again.
            retry_seconds = max(1, math.ceil(retry_after))
            return JSONResponse(
                status_code=429,
                content={
                    "error": "Too Many Requests",
                    "message": f"Rate limit exceeded. Try again in {retry_seconds} seconds.",
                },
                headers={"Retry-After": str(retry_seconds)},
            )
        return await call_next(request)