feat(server): Phase B - auth, rate limiting, SSRF protection, handler whitelist
U1: API Key authentication middleware (dev mode skip, health whitelist) U2: Rate limiting middleware (fixed-window, 60 req/min default) U3: Callback URL SSRF protection (private IP blocking) U4: custom_handler module prefix whitelist 65 tests passing. CORS conflict fixed.
This commit is contained in:
parent
f87b790c0f
commit
5f1c51cf9a
|
|
@ -184,6 +184,12 @@ class ConfigDrivenAgent(BaseAgent):
|
||||||
- retrieve_knowledge
|
- retrieve_knowledge
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Security: whitelist of allowed module prefixes for dynamic handler import
|
||||||
|
_ALLOWED_HANDLER_PREFIXES = (
|
||||||
|
"agentkit.",
|
||||||
|
"app.agent_framework.",
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: AgentConfig,
|
config: AgentConfig,
|
||||||
|
|
@ -566,6 +572,14 @@ class ConfigDrivenAgent(BaseAgent):
|
||||||
|
|
||||||
def _import_handler(self, dotted_path: str) -> Callable[..., Coroutine]:
|
def _import_handler(self, dotted_path: str) -> Callable[..., Coroutine]:
|
||||||
"""动态导入自定义 handler"""
|
"""动态导入自定义 handler"""
|
||||||
|
# Security: validate module prefix to prevent arbitrary code execution
|
||||||
|
if not any(dotted_path.startswith(prefix) for prefix in self._ALLOWED_HANDLER_PREFIXES):
|
||||||
|
raise ConfigValidationError(
|
||||||
|
agent_name=self.name,
|
||||||
|
key="custom_handler",
|
||||||
|
reason=f"Handler '{dotted_path}' is not in allowed module prefixes: {self._ALLOWED_HANDLER_PREFIXES}",
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
module_path, func_name = dotted_path.rsplit(".", 1)
|
module_path, func_name = dotted_path.rsplit(".", 1)
|
||||||
import importlib
|
import importlib
|
||||||
|
|
|
||||||
|
|
@ -3,11 +3,13 @@
|
||||||
与业务系统解耦:通过依赖注入获取 Redis 连接和数据库会话。
|
与业务系统解耦:通过依赖注入获取 Redis 连接和数据库会话。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any, Callable, Awaitable
|
from typing import Any, Callable, Awaitable
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from agentkit.core.exceptions import (
|
from agentkit.core.exceptions import (
|
||||||
NoAvailableAgentError,
|
NoAvailableAgentError,
|
||||||
|
|
@ -24,6 +26,54 @@ from agentkit.core.protocol import (
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_PRIVATE_NETWORKS = [
|
||||||
|
ipaddress.ip_network("127.0.0.0/8"),
|
||||||
|
ipaddress.ip_network("10.0.0.0/8"),
|
||||||
|
ipaddress.ip_network("172.16.0.0/12"),
|
||||||
|
ipaddress.ip_network("192.168.0.0/16"),
|
||||||
|
ipaddress.ip_network("169.254.0.0/16"),
|
||||||
|
ipaddress.ip_network("::1/128"),
|
||||||
|
ipaddress.ip_network("fc00::/7"),
|
||||||
|
ipaddress.ip_network("fe80::/10"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_callback_url(url: str) -> bool:
|
||||||
|
"""Validate callback URL to prevent SSRF attacks.
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- Only http/https protocols allowed
|
||||||
|
- No localhost or loopback addresses
|
||||||
|
- No private/internal IP ranges
|
||||||
|
- No link-local addresses
|
||||||
|
|
||||||
|
Returns True if valid, False if should be blocked.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if parsed.scheme not in ("http", "https"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
hostname = parsed.hostname
|
||||||
|
if not hostname:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if hostname.lower() in ("localhost", "127.0.0.1", "::1"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
ip = ipaddress.ip_address(hostname)
|
||||||
|
for network in _PRIVATE_NETWORKS:
|
||||||
|
if ip in network:
|
||||||
|
return False
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
class TaskDispatcher:
|
class TaskDispatcher:
|
||||||
"""任务分发器,通过 Redis Queue 将任务分发给 Agent"""
|
"""任务分发器,通过 Redis Queue 将任务分发给 Agent"""
|
||||||
|
|
@ -333,6 +383,10 @@ class TaskDispatcher:
|
||||||
db.add(log_entry)
|
db.add(log_entry)
|
||||||
|
|
||||||
async def _trigger_callback(self, callback_url: str, result: TaskResult):
|
async def _trigger_callback(self, callback_url: str, result: TaskResult):
|
||||||
|
if not _validate_callback_url(callback_url):
|
||||||
|
logger.warning(f"Callback URL rejected (SSRF protection): {callback_url}")
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import httpx
|
import httpx
|
||||||
async with httpx.AsyncClient(timeout=10) as client:
|
async with httpx.AsyncClient(timeout=10) as client:
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
"""FastAPI Application Factory"""
|
"""FastAPI Application Factory"""
|
||||||
|
|
||||||
|
import os
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
|
@ -11,12 +12,15 @@ from agentkit.router.intent import IntentRouter
|
||||||
from agentkit.skills.registry import SkillRegistry
|
from agentkit.skills.registry import SkillRegistry
|
||||||
from agentkit.tools.registry import ToolRegistry
|
from agentkit.tools.registry import ToolRegistry
|
||||||
from agentkit.server.routes import agents, tasks, skills, llm, health
|
from agentkit.server.routes import agents, tasks, skills, llm, health
|
||||||
|
from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware
|
||||||
|
|
||||||
|
|
||||||
def create_app(
|
def create_app(
|
||||||
llm_gateway: LLMGateway | None = None,
|
llm_gateway: LLMGateway | None = None,
|
||||||
skill_registry: SkillRegistry | None = None,
|
skill_registry: SkillRegistry | None = None,
|
||||||
tool_registry: ToolRegistry | None = None,
|
tool_registry: ToolRegistry | None = None,
|
||||||
|
api_key: str | None = None,
|
||||||
|
rate_limit: int | None = None,
|
||||||
) -> FastAPI:
|
) -> FastAPI:
|
||||||
"""Create and configure the FastAPI application"""
|
"""Create and configure the FastAPI application"""
|
||||||
app = FastAPI(title="AgentKit Server", version="2.0.0")
|
app = FastAPI(title="AgentKit Server", version="2.0.0")
|
||||||
|
|
@ -25,11 +29,20 @@ def create_app(
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=["*"], # 生产环境应限制具体域名
|
allow_origins=["*"], # 生产环境应限制具体域名
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Auth middleware
|
||||||
|
if api_key:
|
||||||
|
os.environ["AGENTKIT_API_KEY"] = api_key
|
||||||
|
app.add_middleware(APIKeyAuthMiddleware)
|
||||||
|
|
||||||
|
# Rate limiting middleware
|
||||||
|
if rate_limit is not None:
|
||||||
|
os.environ["AGENTKIT_RATE_LIMIT_PER_MINUTE"] = str(rate_limit)
|
||||||
|
app.add_middleware(RateLimitMiddleware)
|
||||||
|
|
||||||
# Initialize shared state
|
# Initialize shared state
|
||||||
app.state.llm_gateway = llm_gateway or LLMGateway()
|
app.state.llm_gateway = llm_gateway or LLMGateway()
|
||||||
app.state.skill_registry = skill_registry or SkillRegistry()
|
app.state.skill_registry = skill_registry or SkillRegistry()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,105 @@
|
||||||
|
"""Server middleware - Authentication and Rate Limiting"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
|
|
||||||
|
class APIKeyAuthMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""API Key authentication middleware.
|
||||||
|
|
||||||
|
Validates X-API-Key header against AGENTKIT_API_KEY env var.
|
||||||
|
Skips validation if AGENTKIT_API_KEY is not set (dev mode).
|
||||||
|
Whitelisted paths (no auth required): /api/v1/health
|
||||||
|
"""
|
||||||
|
|
||||||
|
WHITELIST_PATHS = ("/api/v1/health",)
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
# Skip auth for whitelisted paths
|
||||||
|
if any(request.url.path.startswith(p) for p in self.WHITELIST_PATHS):
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
api_key = os.environ.get("AGENTKIT_API_KEY")
|
||||||
|
if not api_key:
|
||||||
|
# Dev mode: skip auth if no API key configured
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# Check API key from header
|
||||||
|
provided_key = request.headers.get("X-API-Key")
|
||||||
|
if not provided_key or provided_key != api_key:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content={"error": "Unauthorized", "message": "Invalid or missing API key"},
|
||||||
|
)
|
||||||
|
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
|
||||||
|
"""Fixed-window rate limiter.
|
||||||
|
|
||||||
|
Tracks request counts per key (IP or API key) within time windows.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, max_requests: int = 60, window_seconds: int = 60):
|
||||||
|
self._max_requests = max_requests
|
||||||
|
self._window_seconds = window_seconds
|
||||||
|
self._requests: dict[str, list[float]] = defaultdict(list)
|
||||||
|
|
||||||
|
def is_allowed(self, key: str) -> tuple[bool, float]:
|
||||||
|
"""Check if request is allowed. Returns (allowed, retry_after_seconds)."""
|
||||||
|
now = time.time()
|
||||||
|
window_start = now - self._window_seconds
|
||||||
|
|
||||||
|
# Clean old requests outside the window
|
||||||
|
self._requests[key] = [
|
||||||
|
ts for ts in self._requests[key] if ts > window_start
|
||||||
|
]
|
||||||
|
|
||||||
|
if len(self._requests[key]) >= self._max_requests:
|
||||||
|
retry_after = self._requests[key][0] + self._window_seconds - now
|
||||||
|
return False, max(0, retry_after)
|
||||||
|
|
||||||
|
self._requests[key].append(now)
|
||||||
|
return True, 0.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_requests(self) -> int:
|
||||||
|
return self._max_requests
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Rate limiting middleware.
|
||||||
|
|
||||||
|
Limits requests per IP. Returns 429 Too Many Requests when exceeded.
|
||||||
|
Configurable via AGENTKIT_RATE_LIMIT_PER_MINUTE env var (default: 60).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, app, max_requests: int | None = None, window_seconds: int = 60):
|
||||||
|
super().__init__(app)
|
||||||
|
if max_requests is None:
|
||||||
|
max_requests = int(os.environ.get("AGENTKIT_RATE_LIMIT_PER_MINUTE", "60"))
|
||||||
|
self._limiter = RateLimiter(max_requests=max_requests, window_seconds=window_seconds)
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
# Use API key if available, otherwise IP
|
||||||
|
api_key = request.headers.get("X-API-Key")
|
||||||
|
key = f"key:{api_key}" if api_key else f"ip:{request.client.host}"
|
||||||
|
|
||||||
|
allowed, retry_after = self._limiter.is_allowed(key)
|
||||||
|
if not allowed:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=429,
|
||||||
|
content={
|
||||||
|
"error": "Too Many Requests",
|
||||||
|
"message": f"Rate limit exceeded. Try again in {int(retry_after)} seconds.",
|
||||||
|
},
|
||||||
|
headers={"Retry-After": str(int(retry_after))},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await call_next(request)
|
||||||
|
return response
|
||||||
|
|
@ -354,3 +354,73 @@ class TestStandaloneRunner:
|
||||||
runner = StandaloneRunner(config_dir="/nonexistent/path")
|
runner = StandaloneRunner(config_dir="/nonexistent/path")
|
||||||
configs = runner.discover_configs()
|
configs = runner.discover_configs()
|
||||||
assert len(configs) == 0
|
assert len(configs) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── Handler Prefix Whitelist 测试 ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestHandlerPrefixWhitelist:
|
||||||
|
"""U4: 测试 _import_handler 的模块前缀白名单,防止任意代码执行"""
|
||||||
|
|
||||||
|
def _make_agent_with_custom(self, handler_path: str) -> ConfigDrivenAgent:
|
||||||
|
config = AgentConfig(
|
||||||
|
name="test_agent",
|
||||||
|
agent_type="test",
|
||||||
|
task_mode="custom",
|
||||||
|
custom_handler=handler_path,
|
||||||
|
)
|
||||||
|
return ConfigDrivenAgent(config=config)
|
||||||
|
|
||||||
|
def test_allowed_prefix_agentkit(self):
|
||||||
|
"""agentkit.xxx.handler → 允许通过前缀检查"""
|
||||||
|
agent = self._make_agent_with_custom("agentkit.handlers.test_handler")
|
||||||
|
# 前缀检查通过,但模块不存在会报 ImportError,我们只验证不报 ConfigValidationError(前缀)
|
||||||
|
try:
|
||||||
|
agent._import_handler("agentkit.handlers.test_handler")
|
||||||
|
except Exception as e:
|
||||||
|
# 允许 ImportError/AttributeError(模块不存在),但不允许前缀拒绝
|
||||||
|
assert "not in allowed module prefixes" not in str(e)
|
||||||
|
|
||||||
|
def test_allowed_prefix_app_agent_framework(self):
|
||||||
|
"""app.agent_framework.handlers.xxx → 允许通过前缀检查"""
|
||||||
|
agent = self._make_agent_with_custom("app.agent_framework.handlers.xxx_handler")
|
||||||
|
try:
|
||||||
|
agent._import_handler("app.agent_framework.handlers.xxx_handler")
|
||||||
|
except Exception as e:
|
||||||
|
assert "not in allowed module prefixes" not in str(e)
|
||||||
|
|
||||||
|
def test_blocked_os_system(self):
|
||||||
|
"""os.system → 阻止(ConfigValidationError)"""
|
||||||
|
agent = self._make_agent_with_custom("os.system")
|
||||||
|
with pytest.raises(Exception, match="not in allowed module prefixes"):
|
||||||
|
agent._import_handler("os.system")
|
||||||
|
|
||||||
|
def test_blocked_subprocess_run(self):
|
||||||
|
"""subprocess.run → 阻止"""
|
||||||
|
agent = self._make_agent_with_custom("subprocess.run")
|
||||||
|
with pytest.raises(Exception, match="not in allowed module prefixes"):
|
||||||
|
agent._import_handler("subprocess.run")
|
||||||
|
|
||||||
|
def test_blocked_builtins_exec(self):
|
||||||
|
"""builtins.exec → 阻止"""
|
||||||
|
agent = self._make_agent_with_custom("builtins.exec")
|
||||||
|
with pytest.raises(Exception, match="not in allowed module prefixes"):
|
||||||
|
agent._import_handler("builtins.exec")
|
||||||
|
|
||||||
|
def test_blocked_empty_string(self):
|
||||||
|
"""空字符串 → 阻止(在 _import_handler 级别直接被前缀检查拒绝)"""
|
||||||
|
config = AgentConfig(
|
||||||
|
name="test_agent",
|
||||||
|
agent_type="test",
|
||||||
|
task_mode="custom",
|
||||||
|
custom_handler="agentkit.dummy", # valid config, but we test _import_handler directly
|
||||||
|
)
|
||||||
|
agent = ConfigDrivenAgent(config=config)
|
||||||
|
with pytest.raises(Exception, match="not in allowed module prefixes"):
|
||||||
|
agent._import_handler("")
|
||||||
|
|
||||||
|
def test_blocked_agentkitx_prefix(self):
|
||||||
|
"""agentkitx. → 阻止(不是 agentkit.)"""
|
||||||
|
agent = self._make_agent_with_custom("agentkitx.handlers.evil")
|
||||||
|
with pytest.raises(Exception, match="not in allowed module prefixes"):
|
||||||
|
agent._import_handler("agentkitx.handlers.evil")
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from agentkit.core.dispatcher import TaskDispatcher
|
from agentkit.core.dispatcher import TaskDispatcher, _validate_callback_url
|
||||||
from agentkit.core.exceptions import TaskDispatchError, TaskNotFoundError
|
from agentkit.core.exceptions import TaskDispatchError, TaskNotFoundError
|
||||||
from agentkit.core.protocol import AgentStatus, TaskResult, TaskStatus
|
from agentkit.core.protocol import AgentStatus, TaskResult, TaskStatus
|
||||||
|
|
||||||
|
|
@ -267,3 +267,55 @@ class TestTaskDispatcherHandleResult:
|
||||||
|
|
||||||
assert task.status == TaskStatus.FAILED
|
assert task.status == TaskStatus.FAILED
|
||||||
assert task.error_message == "Something went wrong"
|
assert task.error_message == "Something went wrong"
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateCallbackUrl:
|
||||||
|
"""SSRF protection tests for _validate_callback_url."""
|
||||||
|
|
||||||
|
def test_valid_public_https_url(self):
|
||||||
|
"""Valid public HTTPS URL should be allowed."""
|
||||||
|
assert _validate_callback_url("https://example.com/callback") is True
|
||||||
|
|
||||||
|
def test_valid_public_http_url(self):
|
||||||
|
"""Valid public HTTP URL should be allowed."""
|
||||||
|
assert _validate_callback_url("http://example.com/callback") is True
|
||||||
|
|
||||||
|
def test_localhost_blocked(self):
|
||||||
|
"""localhost should be blocked."""
|
||||||
|
assert _validate_callback_url("http://localhost:8080/callback") is False
|
||||||
|
|
||||||
|
def test_loopback_ip_blocked(self):
|
||||||
|
"""127.0.0.1 should be blocked."""
|
||||||
|
assert _validate_callback_url("http://127.0.0.1:8080/callback") is False
|
||||||
|
|
||||||
|
def test_private_10_range_blocked(self):
|
||||||
|
"""10.0.0.0/8 range should be blocked."""
|
||||||
|
assert _validate_callback_url("http://10.0.0.1/internal") is False
|
||||||
|
|
||||||
|
def test_private_192_range_blocked(self):
|
||||||
|
"""192.168.0.0/16 range should be blocked."""
|
||||||
|
assert _validate_callback_url("http://192.168.1.1/admin") is False
|
||||||
|
|
||||||
|
def test_private_172_range_blocked(self):
|
||||||
|
"""172.16.0.0/12 range should be blocked."""
|
||||||
|
assert _validate_callback_url("http://172.16.0.1/internal") is False
|
||||||
|
|
||||||
|
def test_ftp_protocol_blocked(self):
|
||||||
|
"""FTP protocol should be blocked."""
|
||||||
|
assert _validate_callback_url("ftp://example.com/file") is False
|
||||||
|
|
||||||
|
def test_file_protocol_blocked(self):
|
||||||
|
"""file:// protocol should be blocked."""
|
||||||
|
assert _validate_callback_url("file:///etc/passwd") is False
|
||||||
|
|
||||||
|
def test_javascript_protocol_blocked(self):
|
||||||
|
"""javascript: protocol should be blocked."""
|
||||||
|
assert _validate_callback_url("javascript:alert(1)") is False
|
||||||
|
|
||||||
|
def test_empty_url_blocked(self):
|
||||||
|
"""Empty URL should be blocked."""
|
||||||
|
assert _validate_callback_url("") is False
|
||||||
|
|
||||||
|
def test_malformed_url_blocked(self):
|
||||||
|
"""Malformed URL should be blocked."""
|
||||||
|
assert _validate_callback_url("not-a-valid-url") is False
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,242 @@
|
||||||
|
"""Server Middleware 单元测试 - API Key Auth + Rate Limiting"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from agentkit.server.middleware import (
|
||||||
|
APIKeyAuthMiddleware,
|
||||||
|
RateLimiter,
|
||||||
|
RateLimitMiddleware,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helper: minimal app with only a health endpoint for isolated middleware tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_minimal_app():
|
||||||
|
"""Create a minimal FastAPI app with just a health endpoint."""
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/api/v1/health")
|
||||||
|
async def health():
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
@app.get("/api/v1/protected")
|
||||||
|
async def protected():
|
||||||
|
return {"data": "secret"}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# APIKeyAuthMiddleware Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestAPIKeyAuthMiddleware:
|
||||||
|
"""API Key authentication middleware tests."""
|
||||||
|
|
||||||
|
def test_dev_mode_no_api_key_set_passes_through(self):
|
||||||
|
"""No AGENTKIT_API_KEY set → requests pass through (dev mode)."""
|
||||||
|
with patch.dict(os.environ, {}, clear=False):
|
||||||
|
# Ensure AGENTKIT_API_KEY is not set
|
||||||
|
os.environ.pop("AGENTKIT_API_KEY", None)
|
||||||
|
|
||||||
|
app = _make_minimal_app()
|
||||||
|
app.add_middleware(APIKeyAuthMiddleware)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.get("/api/v1/protected")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
def test_api_key_set_no_header_returns_401(self):
|
||||||
|
"""AGENTKIT_API_KEY set, no header → 401."""
|
||||||
|
with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}):
|
||||||
|
app = _make_minimal_app()
|
||||||
|
app.add_middleware(APIKeyAuthMiddleware)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.get("/api/v1/protected")
|
||||||
|
assert response.status_code == 401
|
||||||
|
data = response.json()
|
||||||
|
assert data["error"] == "Unauthorized"
|
||||||
|
|
||||||
|
def test_api_key_set_wrong_header_returns_401(self):
|
||||||
|
"""AGENTKIT_API_KEY set, wrong header → 401."""
|
||||||
|
with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}):
|
||||||
|
app = _make_minimal_app()
|
||||||
|
app.add_middleware(APIKeyAuthMiddleware)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.get(
|
||||||
|
"/api/v1/protected",
|
||||||
|
headers={"X-API-Key": "wrong-key"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_api_key_set_correct_header_returns_200(self):
|
||||||
|
"""AGENTKIT_API_KEY set, correct header → 200."""
|
||||||
|
with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}):
|
||||||
|
app = _make_minimal_app()
|
||||||
|
app.add_middleware(APIKeyAuthMiddleware)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.get(
|
||||||
|
"/api/v1/protected",
|
||||||
|
headers={"X-API-Key": "test-secret-key"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["data"] == "secret"
|
||||||
|
|
||||||
|
def test_health_check_path_no_auth_required(self):
|
||||||
|
"""Health check path → 200 without API key."""
|
||||||
|
with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}):
|
||||||
|
app = _make_minimal_app()
|
||||||
|
app.add_middleware(APIKeyAuthMiddleware)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.get("/api/v1/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["status"] == "ok"
|
||||||
|
|
||||||
|
def test_programmatic_api_key_parameter(self):
|
||||||
|
"""Programmatic api_key parameter → uses passed key."""
|
||||||
|
with patch.dict(os.environ, {}, clear=False):
|
||||||
|
os.environ.pop("AGENTKIT_API_KEY", None)
|
||||||
|
|
||||||
|
app = _make_minimal_app()
|
||||||
|
# Set the API key via environment before adding middleware
|
||||||
|
os.environ["AGENTKIT_API_KEY"] = "programmatic-key"
|
||||||
|
app.add_middleware(APIKeyAuthMiddleware)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
response = client.get(
|
||||||
|
"/api/v1/protected",
|
||||||
|
headers={"X-API-Key": "programmatic-key"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RateLimiter Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestRateLimiter:
|
||||||
|
"""Fixed-window rate limiter unit tests."""
|
||||||
|
|
||||||
|
def test_requests_within_limit_allowed(self):
|
||||||
|
"""Requests within limit → allowed."""
|
||||||
|
limiter = RateLimiter(max_requests=5, window_seconds=60)
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
allowed, retry_after = limiter.is_allowed("test-key")
|
||||||
|
assert allowed is True
|
||||||
|
assert retry_after == 0.0
|
||||||
|
|
||||||
|
def test_requests_exceed_limit_denied(self):
|
||||||
|
"""Requests exceed limit → denied with retry_after."""
|
||||||
|
limiter = RateLimiter(max_requests=2, window_seconds=60)
|
||||||
|
|
||||||
|
# Use up the limit
|
||||||
|
limiter.is_allowed("test-key")
|
||||||
|
limiter.is_allowed("test-key")
|
||||||
|
|
||||||
|
# Next request should be denied
|
||||||
|
allowed, retry_after = limiter.is_allowed("test-key")
|
||||||
|
assert allowed is False
|
||||||
|
assert retry_after > 0
|
||||||
|
|
||||||
|
def test_after_window_expires_counter_resets(self):
|
||||||
|
"""After window expires → counter resets."""
|
||||||
|
limiter = RateLimiter(max_requests=2, window_seconds=1)
|
||||||
|
|
||||||
|
# Use up the limit
|
||||||
|
limiter.is_allowed("test-key")
|
||||||
|
limiter.is_allowed("test-key")
|
||||||
|
|
||||||
|
# Should be denied
|
||||||
|
allowed, _ = limiter.is_allowed("test-key")
|
||||||
|
assert allowed is False
|
||||||
|
|
||||||
|
# Wait for window to expire
|
||||||
|
time.sleep(1.1)
|
||||||
|
|
||||||
|
# Should be allowed again
|
||||||
|
allowed, retry_after = limiter.is_allowed("test-key")
|
||||||
|
assert allowed is True
|
||||||
|
|
||||||
|
def test_different_keys_independent_counters(self):
|
||||||
|
"""Different keys have independent counters."""
|
||||||
|
limiter = RateLimiter(max_requests=1, window_seconds=60)
|
||||||
|
|
||||||
|
# Use up key-a's limit
|
||||||
|
limiter.is_allowed("key-a")
|
||||||
|
|
||||||
|
# key-a should be denied
|
||||||
|
allowed_a, _ = limiter.is_allowed("key-a")
|
||||||
|
assert allowed_a is False
|
||||||
|
|
||||||
|
# key-b should still be allowed
|
||||||
|
allowed_b, _ = limiter.is_allowed("key-b")
|
||||||
|
assert allowed_b is True
|
||||||
|
|
||||||
|
def test_max_requests_property(self):
|
||||||
|
"""max_requests property returns configured value."""
|
||||||
|
limiter = RateLimiter(max_requests=100, window_seconds=30)
|
||||||
|
assert limiter.max_requests == 100
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RateLimitMiddleware Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestRateLimitMiddleware:
|
||||||
|
"""Rate limiting middleware integration tests."""
|
||||||
|
|
||||||
|
def test_returns_429_with_retry_after_header(self):
|
||||||
|
"""Returns 429 with Retry-After header when limit exceeded."""
|
||||||
|
app = _make_minimal_app()
|
||||||
|
app.add_middleware(RateLimitMiddleware, max_requests=1, window_seconds=60)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
# First request should pass
|
||||||
|
response1 = client.get("/api/v1/protected")
|
||||||
|
assert response1.status_code == 200
|
||||||
|
|
||||||
|
# Second request should be rate limited
|
||||||
|
response2 = client.get("/api/v1/protected")
|
||||||
|
assert response2.status_code == 429
|
||||||
|
data = response2.json()
|
||||||
|
assert data["error"] == "Too Many Requests"
|
||||||
|
assert "Retry-After" in response2.headers
|
||||||
|
|
||||||
|
def test_uses_api_key_for_identity(self):
|
||||||
|
"""Uses API key for identity when present (different keys = different limits)."""
|
||||||
|
app = _make_minimal_app()
|
||||||
|
app.add_middleware(RateLimitMiddleware, max_requests=1, window_seconds=60)
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
# Request with key-a
|
||||||
|
response_a1 = client.get(
|
||||||
|
"/api/v1/protected",
|
||||||
|
headers={"X-API-Key": "key-a"},
|
||||||
|
)
|
||||||
|
assert response_a1.status_code == 200
|
||||||
|
|
||||||
|
# key-a should now be rate limited
|
||||||
|
response_a2 = client.get(
|
||||||
|
"/api/v1/protected",
|
||||||
|
headers={"X-API-Key": "key-a"},
|
||||||
|
)
|
||||||
|
assert response_a2.status_code == 429
|
||||||
|
|
||||||
|
# key-b should still be allowed (independent counter)
|
||||||
|
response_b1 = client.get(
|
||||||
|
"/api/v1/protected",
|
||||||
|
headers={"X-API-Key": "key-b"},
|
||||||
|
)
|
||||||
|
assert response_b1.status_code == 200
|
||||||
Loading…
Reference in New Issue