"""Server middleware - Authentication and Rate Limiting""" import os import time from collections import defaultdict from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import JSONResponse class APIKeyAuthMiddleware(BaseHTTPMiddleware): """API Key authentication middleware. Validates X-API-Key header against AGENTKIT_API_KEY env var. Skips validation if AGENTKIT_API_KEY is not set (dev mode). Whitelisted paths (no auth required): /api/v1/health """ WHITELIST_PATHS = ("/api/v1/health",) async def dispatch(self, request: Request, call_next): # Skip auth for whitelisted paths if any(request.url.path.startswith(p) for p in self.WHITELIST_PATHS): return await call_next(request) api_key = os.environ.get("AGENTKIT_API_KEY") if not api_key: # Dev mode: skip auth if no API key configured return await call_next(request) # Check API key from header provided_key = request.headers.get("X-API-Key") if not provided_key or provided_key != api_key: return JSONResponse( status_code=401, content={"error": "Unauthorized", "message": "Invalid or missing API key"}, ) return await call_next(request) class RateLimiter: """Fixed-window rate limiter. Tracks request counts per key (IP or API key) within time windows. """ def __init__(self, max_requests: int = 60, window_seconds: int = 60): self._max_requests = max_requests self._window_seconds = window_seconds self._requests: dict[str, list[float]] = defaultdict(list) def is_allowed(self, key: str) -> tuple[bool, float]: """Check if request is allowed. Returns (allowed, retry_after_seconds).""" now = time.time() window_start = now - self._window_seconds # Clean old requests outside the window self._requests[key] = [ ts for ts in self._requests[key] if ts > window_start ] if len(self._requests[key]) >= self._max_requests: retry_after = self._requests[key][0] + self._window_seconds - now return False, max(0, retry_after) self._requests[key].append(now) return True, 0.0 @property def max_requests(self) -> int: return self._max_requests class RateLimitMiddleware(BaseHTTPMiddleware): """Rate limiting middleware. Limits requests per IP. Returns 429 Too Many Requests when exceeded. Configurable via AGENTKIT_RATE_LIMIT_PER_MINUTE env var (default: 60). """ def __init__(self, app, max_requests: int | None = None, window_seconds: int = 60): super().__init__(app) if max_requests is None: max_requests = int(os.environ.get("AGENTKIT_RATE_LIMIT_PER_MINUTE", "60")) self._limiter = RateLimiter(max_requests=max_requests, window_seconds=window_seconds) async def dispatch(self, request: Request, call_next): # Use API key if available, otherwise IP api_key = request.headers.get("X-API-Key") key = f"key:{api_key}" if api_key else f"ip:{request.client.host}" allowed, retry_after = self._limiter.is_allowed(key) if not allowed: return JSONResponse( status_code=429, content={ "error": "Too Many Requests", "message": f"Rate limit exceeded. Try again in {int(retry_after)} seconds.", }, headers={"Retry-After": str(int(retry_after))}, ) response = await call_next(request) return response