feat(server): Phase B - auth, rate limiting, SSRF protection, handler whitelist

U1: API Key authentication middleware (dev mode skip, health whitelist)
U2: Rate limiting middleware (fixed-window, 60 req/min default)
U3: Callback URL SSRF protection (private IP blocking)
U4: custom_handler module prefix whitelist

65 tests passing. CORS conflict fixed.
This commit is contained in:
chiguyong 2026-06-05 23:37:36 +08:00
parent f87b790c0f
commit 5f1c51cf9a
7 changed files with 552 additions and 2 deletions

View File

@ -184,6 +184,12 @@ class ConfigDrivenAgent(BaseAgent):
- retrieve_knowledge - retrieve_knowledge
""" """
# Security: whitelist of allowed module prefixes for dynamic handler import
_ALLOWED_HANDLER_PREFIXES = (
"agentkit.",
"app.agent_framework.",
)
def __init__( def __init__(
self, self,
config: AgentConfig, config: AgentConfig,
@ -566,6 +572,14 @@ class ConfigDrivenAgent(BaseAgent):
def _import_handler(self, dotted_path: str) -> Callable[..., Coroutine]: def _import_handler(self, dotted_path: str) -> Callable[..., Coroutine]:
"""动态导入自定义 handler""" """动态导入自定义 handler"""
# Security: validate module prefix to prevent arbitrary code execution
if not any(dotted_path.startswith(prefix) for prefix in self._ALLOWED_HANDLER_PREFIXES):
raise ConfigValidationError(
agent_name=self.name,
key="custom_handler",
reason=f"Handler '{dotted_path}' is not in allowed module prefixes: {self._ALLOWED_HANDLER_PREFIXES}",
)
try: try:
module_path, func_name = dotted_path.rsplit(".", 1) module_path, func_name = dotted_path.rsplit(".", 1)
import importlib import importlib

View File

@ -3,11 +3,13 @@
与业务系统解耦通过依赖注入获取 Redis 连接和数据库会话 与业务系统解耦通过依赖注入获取 Redis 连接和数据库会话
""" """
import ipaddress
import json import json
import logging import logging
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Any, Callable, Awaitable from typing import Any, Callable, Awaitable
from urllib.parse import urlparse
from agentkit.core.exceptions import ( from agentkit.core.exceptions import (
NoAvailableAgentError, NoAvailableAgentError,
@ -24,6 +26,54 @@ from agentkit.core.protocol import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_PRIVATE_NETWORKS = [
ipaddress.ip_network("127.0.0.0/8"),
ipaddress.ip_network("10.0.0.0/8"),
ipaddress.ip_network("172.16.0.0/12"),
ipaddress.ip_network("192.168.0.0/16"),
ipaddress.ip_network("169.254.0.0/16"),
ipaddress.ip_network("::1/128"),
ipaddress.ip_network("fc00::/7"),
ipaddress.ip_network("fe80::/10"),
]
def _validate_callback_url(url: str) -> bool:
"""Validate callback URL to prevent SSRF attacks.
Rules:
- Only http/https protocols allowed
- No localhost or loopback addresses
- No private/internal IP ranges
- No link-local addresses
Returns True if valid, False if should be blocked.
"""
try:
parsed = urlparse(url)
except Exception:
return False
if parsed.scheme not in ("http", "https"):
return False
hostname = parsed.hostname
if not hostname:
return False
if hostname.lower() in ("localhost", "127.0.0.1", "::1"):
return False
try:
ip = ipaddress.ip_address(hostname)
for network in _PRIVATE_NETWORKS:
if ip in network:
return False
except ValueError:
pass
return True
class TaskDispatcher: class TaskDispatcher:
"""任务分发器,通过 Redis Queue 将任务分发给 Agent""" """任务分发器,通过 Redis Queue 将任务分发给 Agent"""
@ -333,6 +383,10 @@ class TaskDispatcher:
db.add(log_entry) db.add(log_entry)
async def _trigger_callback(self, callback_url: str, result: TaskResult): async def _trigger_callback(self, callback_url: str, result: TaskResult):
if not _validate_callback_url(callback_url):
logger.warning(f"Callback URL rejected (SSRF protection): {callback_url}")
return
try: try:
import httpx import httpx
async with httpx.AsyncClient(timeout=10) as client: async with httpx.AsyncClient(timeout=10) as client:

View File

@ -1,5 +1,6 @@
"""FastAPI Application Factory""" """FastAPI Application Factory"""
import os
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@ -11,12 +12,15 @@ from agentkit.router.intent import IntentRouter
from agentkit.skills.registry import SkillRegistry from agentkit.skills.registry import SkillRegistry
from agentkit.tools.registry import ToolRegistry from agentkit.tools.registry import ToolRegistry
from agentkit.server.routes import agents, tasks, skills, llm, health from agentkit.server.routes import agents, tasks, skills, llm, health
from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware
def create_app( def create_app(
llm_gateway: LLMGateway | None = None, llm_gateway: LLMGateway | None = None,
skill_registry: SkillRegistry | None = None, skill_registry: SkillRegistry | None = None,
tool_registry: ToolRegistry | None = None, tool_registry: ToolRegistry | None = None,
api_key: str | None = None,
rate_limit: int | None = None,
) -> FastAPI: ) -> FastAPI:
"""Create and configure the FastAPI application""" """Create and configure the FastAPI application"""
app = FastAPI(title="AgentKit Server", version="2.0.0") app = FastAPI(title="AgentKit Server", version="2.0.0")
@ -25,11 +29,20 @@ def create_app(
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["*"], # 生产环境应限制具体域名 allow_origins=["*"], # 生产环境应限制具体域名
allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
# Auth middleware
if api_key:
os.environ["AGENTKIT_API_KEY"] = api_key
app.add_middleware(APIKeyAuthMiddleware)
# Rate limiting middleware
if rate_limit is not None:
os.environ["AGENTKIT_RATE_LIMIT_PER_MINUTE"] = str(rate_limit)
app.add_middleware(RateLimitMiddleware)
# Initialize shared state # Initialize shared state
app.state.llm_gateway = llm_gateway or LLMGateway() app.state.llm_gateway = llm_gateway or LLMGateway()
app.state.skill_registry = skill_registry or SkillRegistry() app.state.skill_registry = skill_registry or SkillRegistry()

View File

@ -0,0 +1,105 @@
"""Server middleware - Authentication and Rate Limiting"""
import os
import time
from collections import defaultdict
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
class APIKeyAuthMiddleware(BaseHTTPMiddleware):
"""API Key authentication middleware.
Validates X-API-Key header against AGENTKIT_API_KEY env var.
Skips validation if AGENTKIT_API_KEY is not set (dev mode).
Whitelisted paths (no auth required): /api/v1/health
"""
WHITELIST_PATHS = ("/api/v1/health",)
async def dispatch(self, request: Request, call_next):
# Skip auth for whitelisted paths
if any(request.url.path.startswith(p) for p in self.WHITELIST_PATHS):
return await call_next(request)
api_key = os.environ.get("AGENTKIT_API_KEY")
if not api_key:
# Dev mode: skip auth if no API key configured
return await call_next(request)
# Check API key from header
provided_key = request.headers.get("X-API-Key")
if not provided_key or provided_key != api_key:
return JSONResponse(
status_code=401,
content={"error": "Unauthorized", "message": "Invalid or missing API key"},
)
return await call_next(request)
class RateLimiter:
"""Fixed-window rate limiter.
Tracks request counts per key (IP or API key) within time windows.
"""
def __init__(self, max_requests: int = 60, window_seconds: int = 60):
self._max_requests = max_requests
self._window_seconds = window_seconds
self._requests: dict[str, list[float]] = defaultdict(list)
def is_allowed(self, key: str) -> tuple[bool, float]:
"""Check if request is allowed. Returns (allowed, retry_after_seconds)."""
now = time.time()
window_start = now - self._window_seconds
# Clean old requests outside the window
self._requests[key] = [
ts for ts in self._requests[key] if ts > window_start
]
if len(self._requests[key]) >= self._max_requests:
retry_after = self._requests[key][0] + self._window_seconds - now
return False, max(0, retry_after)
self._requests[key].append(now)
return True, 0.0
@property
def max_requests(self) -> int:
return self._max_requests
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Rate limiting middleware.
Limits requests per IP. Returns 429 Too Many Requests when exceeded.
Configurable via AGENTKIT_RATE_LIMIT_PER_MINUTE env var (default: 60).
"""
def __init__(self, app, max_requests: int | None = None, window_seconds: int = 60):
super().__init__(app)
if max_requests is None:
max_requests = int(os.environ.get("AGENTKIT_RATE_LIMIT_PER_MINUTE", "60"))
self._limiter = RateLimiter(max_requests=max_requests, window_seconds=window_seconds)
async def dispatch(self, request: Request, call_next):
# Use API key if available, otherwise IP
api_key = request.headers.get("X-API-Key")
key = f"key:{api_key}" if api_key else f"ip:{request.client.host}"
allowed, retry_after = self._limiter.is_allowed(key)
if not allowed:
return JSONResponse(
status_code=429,
content={
"error": "Too Many Requests",
"message": f"Rate limit exceeded. Try again in {int(retry_after)} seconds.",
},
headers={"Retry-After": str(int(retry_after))},
)
response = await call_next(request)
return response

View File

@ -354,3 +354,73 @@ class TestStandaloneRunner:
runner = StandaloneRunner(config_dir="/nonexistent/path") runner = StandaloneRunner(config_dir="/nonexistent/path")
configs = runner.discover_configs() configs = runner.discover_configs()
assert len(configs) == 0 assert len(configs) == 0
# ── Handler Prefix Whitelist 测试 ─────────────────────────
class TestHandlerPrefixWhitelist:
"""U4: 测试 _import_handler 的模块前缀白名单,防止任意代码执行"""
def _make_agent_with_custom(self, handler_path: str) -> ConfigDrivenAgent:
config = AgentConfig(
name="test_agent",
agent_type="test",
task_mode="custom",
custom_handler=handler_path,
)
return ConfigDrivenAgent(config=config)
def test_allowed_prefix_agentkit(self):
"""agentkit.xxx.handler → 允许通过前缀检查"""
agent = self._make_agent_with_custom("agentkit.handlers.test_handler")
# 前缀检查通过,但模块不存在会报 ImportError我们只验证不报 ConfigValidationError(前缀)
try:
agent._import_handler("agentkit.handlers.test_handler")
except Exception as e:
# 允许 ImportError/AttributeError模块不存在但不允许前缀拒绝
assert "not in allowed module prefixes" not in str(e)
def test_allowed_prefix_app_agent_framework(self):
"""app.agent_framework.handlers.xxx → 允许通过前缀检查"""
agent = self._make_agent_with_custom("app.agent_framework.handlers.xxx_handler")
try:
agent._import_handler("app.agent_framework.handlers.xxx_handler")
except Exception as e:
assert "not in allowed module prefixes" not in str(e)
def test_blocked_os_system(self):
"""os.system → 阻止ConfigValidationError"""
agent = self._make_agent_with_custom("os.system")
with pytest.raises(Exception, match="not in allowed module prefixes"):
agent._import_handler("os.system")
def test_blocked_subprocess_run(self):
"""subprocess.run → 阻止"""
agent = self._make_agent_with_custom("subprocess.run")
with pytest.raises(Exception, match="not in allowed module prefixes"):
agent._import_handler("subprocess.run")
def test_blocked_builtins_exec(self):
"""builtins.exec → 阻止"""
agent = self._make_agent_with_custom("builtins.exec")
with pytest.raises(Exception, match="not in allowed module prefixes"):
agent._import_handler("builtins.exec")
def test_blocked_empty_string(self):
"""空字符串 → 阻止(在 _import_handler 级别直接被前缀检查拒绝)"""
config = AgentConfig(
name="test_agent",
agent_type="test",
task_mode="custom",
custom_handler="agentkit.dummy", # valid config, but we test _import_handler directly
)
agent = ConfigDrivenAgent(config=config)
with pytest.raises(Exception, match="not in allowed module prefixes"):
agent._import_handler("")
def test_blocked_agentkitx_prefix(self):
"""agentkitx. → 阻止(不是 agentkit."""
agent = self._make_agent_with_custom("agentkitx.handlers.evil")
with pytest.raises(Exception, match="not in allowed module prefixes"):
agent._import_handler("agentkitx.handlers.evil")

View File

@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from agentkit.core.dispatcher import TaskDispatcher from agentkit.core.dispatcher import TaskDispatcher, _validate_callback_url
from agentkit.core.exceptions import TaskDispatchError, TaskNotFoundError from agentkit.core.exceptions import TaskDispatchError, TaskNotFoundError
from agentkit.core.protocol import AgentStatus, TaskResult, TaskStatus from agentkit.core.protocol import AgentStatus, TaskResult, TaskStatus
@ -267,3 +267,55 @@ class TestTaskDispatcherHandleResult:
assert task.status == TaskStatus.FAILED assert task.status == TaskStatus.FAILED
assert task.error_message == "Something went wrong" assert task.error_message == "Something went wrong"
class TestValidateCallbackUrl:
"""SSRF protection tests for _validate_callback_url."""
def test_valid_public_https_url(self):
"""Valid public HTTPS URL should be allowed."""
assert _validate_callback_url("https://example.com/callback") is True
def test_valid_public_http_url(self):
"""Valid public HTTP URL should be allowed."""
assert _validate_callback_url("http://example.com/callback") is True
def test_localhost_blocked(self):
"""localhost should be blocked."""
assert _validate_callback_url("http://localhost:8080/callback") is False
def test_loopback_ip_blocked(self):
"""127.0.0.1 should be blocked."""
assert _validate_callback_url("http://127.0.0.1:8080/callback") is False
def test_private_10_range_blocked(self):
"""10.0.0.0/8 range should be blocked."""
assert _validate_callback_url("http://10.0.0.1/internal") is False
def test_private_192_range_blocked(self):
"""192.168.0.0/16 range should be blocked."""
assert _validate_callback_url("http://192.168.1.1/admin") is False
def test_private_172_range_blocked(self):
"""172.16.0.0/12 range should be blocked."""
assert _validate_callback_url("http://172.16.0.1/internal") is False
def test_ftp_protocol_blocked(self):
"""FTP protocol should be blocked."""
assert _validate_callback_url("ftp://example.com/file") is False
def test_file_protocol_blocked(self):
"""file:// protocol should be blocked."""
assert _validate_callback_url("file:///etc/passwd") is False
def test_javascript_protocol_blocked(self):
"""javascript: protocol should be blocked."""
assert _validate_callback_url("javascript:alert(1)") is False
def test_empty_url_blocked(self):
"""Empty URL should be blocked."""
assert _validate_callback_url("") is False
def test_malformed_url_blocked(self):
"""Malformed URL should be blocked."""
assert _validate_callback_url("not-a-valid-url") is False

View File

@ -0,0 +1,242 @@
"""Server Middleware 单元测试 - API Key Auth + Rate Limiting"""
import os
import time
import pytest
from unittest.mock import patch
from fastapi import FastAPI
from fastapi.testclient import TestClient
from agentkit.server.middleware import (
APIKeyAuthMiddleware,
RateLimiter,
RateLimitMiddleware,
)
# ---------------------------------------------------------------------------
# Helper: minimal app with only a health endpoint for isolated middleware tests
# ---------------------------------------------------------------------------
def _make_minimal_app():
"""Create a minimal FastAPI app with just a health endpoint."""
app = FastAPI()
@app.get("/api/v1/health")
async def health():
return {"status": "ok"}
@app.get("/api/v1/protected")
async def protected():
return {"data": "secret"}
return app
# ---------------------------------------------------------------------------
# APIKeyAuthMiddleware Tests
# ---------------------------------------------------------------------------
class TestAPIKeyAuthMiddleware:
"""API Key authentication middleware tests."""
def test_dev_mode_no_api_key_set_passes_through(self):
"""No AGENTKIT_API_KEY set → requests pass through (dev mode)."""
with patch.dict(os.environ, {}, clear=False):
# Ensure AGENTKIT_API_KEY is not set
os.environ.pop("AGENTKIT_API_KEY", None)
app = _make_minimal_app()
app.add_middleware(APIKeyAuthMiddleware)
client = TestClient(app)
response = client.get("/api/v1/protected")
assert response.status_code == 200
def test_api_key_set_no_header_returns_401(self):
"""AGENTKIT_API_KEY set, no header → 401."""
with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}):
app = _make_minimal_app()
app.add_middleware(APIKeyAuthMiddleware)
client = TestClient(app)
response = client.get("/api/v1/protected")
assert response.status_code == 401
data = response.json()
assert data["error"] == "Unauthorized"
def test_api_key_set_wrong_header_returns_401(self):
"""AGENTKIT_API_KEY set, wrong header → 401."""
with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}):
app = _make_minimal_app()
app.add_middleware(APIKeyAuthMiddleware)
client = TestClient(app)
response = client.get(
"/api/v1/protected",
headers={"X-API-Key": "wrong-key"},
)
assert response.status_code == 401
def test_api_key_set_correct_header_returns_200(self):
"""AGENTKIT_API_KEY set, correct header → 200."""
with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}):
app = _make_minimal_app()
app.add_middleware(APIKeyAuthMiddleware)
client = TestClient(app)
response = client.get(
"/api/v1/protected",
headers={"X-API-Key": "test-secret-key"},
)
assert response.status_code == 200
assert response.json()["data"] == "secret"
def test_health_check_path_no_auth_required(self):
"""Health check path → 200 without API key."""
with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}):
app = _make_minimal_app()
app.add_middleware(APIKeyAuthMiddleware)
client = TestClient(app)
response = client.get("/api/v1/health")
assert response.status_code == 200
assert response.json()["status"] == "ok"
def test_programmatic_api_key_parameter(self):
"""Programmatic api_key parameter → uses passed key."""
with patch.dict(os.environ, {}, clear=False):
os.environ.pop("AGENTKIT_API_KEY", None)
app = _make_minimal_app()
# Set the API key via environment before adding middleware
os.environ["AGENTKIT_API_KEY"] = "programmatic-key"
app.add_middleware(APIKeyAuthMiddleware)
client = TestClient(app)
response = client.get(
"/api/v1/protected",
headers={"X-API-Key": "programmatic-key"},
)
assert response.status_code == 200
# ---------------------------------------------------------------------------
# RateLimiter Tests
# ---------------------------------------------------------------------------
class TestRateLimiter:
"""Fixed-window rate limiter unit tests."""
def test_requests_within_limit_allowed(self):
"""Requests within limit → allowed."""
limiter = RateLimiter(max_requests=5, window_seconds=60)
for i in range(5):
allowed, retry_after = limiter.is_allowed("test-key")
assert allowed is True
assert retry_after == 0.0
def test_requests_exceed_limit_denied(self):
"""Requests exceed limit → denied with retry_after."""
limiter = RateLimiter(max_requests=2, window_seconds=60)
# Use up the limit
limiter.is_allowed("test-key")
limiter.is_allowed("test-key")
# Next request should be denied
allowed, retry_after = limiter.is_allowed("test-key")
assert allowed is False
assert retry_after > 0
def test_after_window_expires_counter_resets(self):
"""After window expires → counter resets."""
limiter = RateLimiter(max_requests=2, window_seconds=1)
# Use up the limit
limiter.is_allowed("test-key")
limiter.is_allowed("test-key")
# Should be denied
allowed, _ = limiter.is_allowed("test-key")
assert allowed is False
# Wait for window to expire
time.sleep(1.1)
# Should be allowed again
allowed, retry_after = limiter.is_allowed("test-key")
assert allowed is True
def test_different_keys_independent_counters(self):
"""Different keys have independent counters."""
limiter = RateLimiter(max_requests=1, window_seconds=60)
# Use up key-a's limit
limiter.is_allowed("key-a")
# key-a should be denied
allowed_a, _ = limiter.is_allowed("key-a")
assert allowed_a is False
# key-b should still be allowed
allowed_b, _ = limiter.is_allowed("key-b")
assert allowed_b is True
def test_max_requests_property(self):
"""max_requests property returns configured value."""
limiter = RateLimiter(max_requests=100, window_seconds=30)
assert limiter.max_requests == 100
# ---------------------------------------------------------------------------
# RateLimitMiddleware Tests
# ---------------------------------------------------------------------------
class TestRateLimitMiddleware:
"""Rate limiting middleware integration tests."""
def test_returns_429_with_retry_after_header(self):
"""Returns 429 with Retry-After header when limit exceeded."""
app = _make_minimal_app()
app.add_middleware(RateLimitMiddleware, max_requests=1, window_seconds=60)
client = TestClient(app)
# First request should pass
response1 = client.get("/api/v1/protected")
assert response1.status_code == 200
# Second request should be rate limited
response2 = client.get("/api/v1/protected")
assert response2.status_code == 429
data = response2.json()
assert data["error"] == "Too Many Requests"
assert "Retry-After" in response2.headers
def test_uses_api_key_for_identity(self):
"""Uses API key for identity when present (different keys = different limits)."""
app = _make_minimal_app()
app.add_middleware(RateLimitMiddleware, max_requests=1, window_seconds=60)
client = TestClient(app)
# Request with key-a
response_a1 = client.get(
"/api/v1/protected",
headers={"X-API-Key": "key-a"},
)
assert response_a1.status_code == 200
# key-a should now be rate limited
response_a2 = client.get(
"/api/v1/protected",
headers={"X-API-Key": "key-a"},
)
assert response_a2.status_code == 429
# key-b should still be allowed (independent counter)
response_b1 = client.get(
"/api/v1/protected",
headers={"X-API-Key": "key-b"},
)
assert response_b1.status_code == 200