From 5f1c51cf9a9a52aa78f6a43286dbfdf5f423d25f Mon Sep 17 00:00:00 2001 From: chiguyong Date: Fri, 5 Jun 2026 23:37:36 +0800 Subject: [PATCH] feat(server): Phase B - auth, rate limiting, SSRF protection, handler whitelist U1: API Key authentication middleware (dev mode skip, health whitelist) U2: Rate limiting middleware (fixed-window, 60 req/min default) U3: Callback URL SSRF protection (private IP blocking) U4: custom_handler module prefix whitelist 65 tests passing. CORS conflict fixed. --- src/agentkit/core/config_driven.py | 14 ++ src/agentkit/core/dispatcher.py | 54 ++++++ src/agentkit/server/app.py | 15 +- src/agentkit/server/middleware.py | 105 ++++++++++++ tests/unit/test_config_driven.py | 70 ++++++++ tests/unit/test_dispatcher.py | 54 +++++- tests/unit/test_server_middleware.py | 242 +++++++++++++++++++++++++++ 7 files changed, 552 insertions(+), 2 deletions(-) create mode 100644 src/agentkit/server/middleware.py create mode 100644 tests/unit/test_server_middleware.py diff --git a/src/agentkit/core/config_driven.py b/src/agentkit/core/config_driven.py index 4727030..7de51d6 100644 --- a/src/agentkit/core/config_driven.py +++ b/src/agentkit/core/config_driven.py @@ -184,6 +184,12 @@ class ConfigDrivenAgent(BaseAgent): - retrieve_knowledge """ + # Security: whitelist of allowed module prefixes for dynamic handler import + _ALLOWED_HANDLER_PREFIXES = ( + "agentkit.", + "app.agent_framework.", + ) + def __init__( self, config: AgentConfig, @@ -566,6 +572,14 @@ class ConfigDrivenAgent(BaseAgent): def _import_handler(self, dotted_path: str) -> Callable[..., Coroutine]: """动态导入自定义 handler""" + # Security: validate module prefix to prevent arbitrary code execution + if not any(dotted_path.startswith(prefix) for prefix in self._ALLOWED_HANDLER_PREFIXES): + raise ConfigValidationError( + agent_name=self.name, + key="custom_handler", + reason=f"Handler '{dotted_path}' is not in allowed module prefixes: {self._ALLOWED_HANDLER_PREFIXES}", + ) + try: module_path, func_name = dotted_path.rsplit(".", 1) import importlib diff --git a/src/agentkit/core/dispatcher.py b/src/agentkit/core/dispatcher.py index f96a5d0..5463343 100644 --- a/src/agentkit/core/dispatcher.py +++ b/src/agentkit/core/dispatcher.py @@ -3,11 +3,13 @@ 与业务系统解耦:通过依赖注入获取 Redis 连接和数据库会话。 """ +import ipaddress import json import logging import uuid from datetime import datetime, timezone from typing import Any, Callable, Awaitable +from urllib.parse import urlparse from agentkit.core.exceptions import ( NoAvailableAgentError, @@ -24,6 +26,54 @@ from agentkit.core.protocol import ( logger = logging.getLogger(__name__) +_PRIVATE_NETWORKS = [ + ipaddress.ip_network("127.0.0.0/8"), + ipaddress.ip_network("10.0.0.0/8"), + ipaddress.ip_network("172.16.0.0/12"), + ipaddress.ip_network("192.168.0.0/16"), + ipaddress.ip_network("169.254.0.0/16"), + ipaddress.ip_network("::1/128"), + ipaddress.ip_network("fc00::/7"), + ipaddress.ip_network("fe80::/10"), +] + + +def _validate_callback_url(url: str) -> bool: + """Validate callback URL to prevent SSRF attacks. + + Rules: + - Only http/https protocols allowed + - No localhost or loopback addresses + - No private/internal IP ranges + - No link-local addresses + + Returns True if valid, False if should be blocked. + """ + try: + parsed = urlparse(url) + except Exception: + return False + + if parsed.scheme not in ("http", "https"): + return False + + hostname = parsed.hostname + if not hostname: + return False + + if hostname.lower() in ("localhost", "127.0.0.1", "::1"): + return False + + try: + ip = ipaddress.ip_address(hostname) + for network in _PRIVATE_NETWORKS: + if ip in network: + return False + except ValueError: + pass + + return True + class TaskDispatcher: """任务分发器,通过 Redis Queue 将任务分发给 Agent""" @@ -333,6 +383,10 @@ class TaskDispatcher: db.add(log_entry) async def _trigger_callback(self, callback_url: str, result: TaskResult): + if not _validate_callback_url(callback_url): + logger.warning(f"Callback URL rejected (SSRF protection): {callback_url}") + return + try: import httpx async with httpx.AsyncClient(timeout=10) as client: diff --git a/src/agentkit/server/app.py b/src/agentkit/server/app.py index 2d7df86..3e08ee3 100644 --- a/src/agentkit/server/app.py +++ b/src/agentkit/server/app.py @@ -1,5 +1,6 @@ """FastAPI Application Factory""" +import os from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware @@ -11,12 +12,15 @@ from agentkit.router.intent import IntentRouter from agentkit.skills.registry import SkillRegistry from agentkit.tools.registry import ToolRegistry from agentkit.server.routes import agents, tasks, skills, llm, health +from agentkit.server.middleware import APIKeyAuthMiddleware, RateLimitMiddleware def create_app( llm_gateway: LLMGateway | None = None, skill_registry: SkillRegistry | None = None, tool_registry: ToolRegistry | None = None, + api_key: str | None = None, + rate_limit: int | None = None, ) -> FastAPI: """Create and configure the FastAPI application""" app = FastAPI(title="AgentKit Server", version="2.0.0") @@ -25,11 +29,20 @@ def create_app( app.add_middleware( CORSMiddleware, allow_origins=["*"], # 生产环境应限制具体域名 - allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) + # Auth middleware + if api_key: + os.environ["AGENTKIT_API_KEY"] = api_key + app.add_middleware(APIKeyAuthMiddleware) + + # Rate limiting middleware + if rate_limit is not None: + os.environ["AGENTKIT_RATE_LIMIT_PER_MINUTE"] = str(rate_limit) + app.add_middleware(RateLimitMiddleware) + # Initialize shared state app.state.llm_gateway = llm_gateway or LLMGateway() app.state.skill_registry = skill_registry or SkillRegistry() diff --git a/src/agentkit/server/middleware.py b/src/agentkit/server/middleware.py new file mode 100644 index 0000000..2497d37 --- /dev/null +++ b/src/agentkit/server/middleware.py @@ -0,0 +1,105 @@ +"""Server middleware - Authentication and Rate Limiting""" + +import os +import time +from collections import defaultdict +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse + + +class APIKeyAuthMiddleware(BaseHTTPMiddleware): + """API Key authentication middleware. + + Validates X-API-Key header against AGENTKIT_API_KEY env var. + Skips validation if AGENTKIT_API_KEY is not set (dev mode). + Whitelisted paths (no auth required): /api/v1/health + """ + + WHITELIST_PATHS = ("/api/v1/health",) + + async def dispatch(self, request: Request, call_next): + # Skip auth for whitelisted paths + if any(request.url.path.startswith(p) for p in self.WHITELIST_PATHS): + return await call_next(request) + + api_key = os.environ.get("AGENTKIT_API_KEY") + if not api_key: + # Dev mode: skip auth if no API key configured + return await call_next(request) + + # Check API key from header + provided_key = request.headers.get("X-API-Key") + if not provided_key or provided_key != api_key: + return JSONResponse( + status_code=401, + content={"error": "Unauthorized", "message": "Invalid or missing API key"}, + ) + + return await call_next(request) + + +class RateLimiter: + """Fixed-window rate limiter. + + Tracks request counts per key (IP or API key) within time windows. + """ + + def __init__(self, max_requests: int = 60, window_seconds: int = 60): + self._max_requests = max_requests + self._window_seconds = window_seconds + self._requests: dict[str, list[float]] = defaultdict(list) + + def is_allowed(self, key: str) -> tuple[bool, float]: + """Check if request is allowed. Returns (allowed, retry_after_seconds).""" + now = time.time() + window_start = now - self._window_seconds + + # Clean old requests outside the window + self._requests[key] = [ + ts for ts in self._requests[key] if ts > window_start + ] + + if len(self._requests[key]) >= self._max_requests: + retry_after = self._requests[key][0] + self._window_seconds - now + return False, max(0, retry_after) + + self._requests[key].append(now) + return True, 0.0 + + @property + def max_requests(self) -> int: + return self._max_requests + + +class RateLimitMiddleware(BaseHTTPMiddleware): + """Rate limiting middleware. + + Limits requests per IP. Returns 429 Too Many Requests when exceeded. + Configurable via AGENTKIT_RATE_LIMIT_PER_MINUTE env var (default: 60). + """ + + def __init__(self, app, max_requests: int | None = None, window_seconds: int = 60): + super().__init__(app) + if max_requests is None: + max_requests = int(os.environ.get("AGENTKIT_RATE_LIMIT_PER_MINUTE", "60")) + self._limiter = RateLimiter(max_requests=max_requests, window_seconds=window_seconds) + + async def dispatch(self, request: Request, call_next): + # Use API key if available, otherwise IP + api_key = request.headers.get("X-API-Key") + key = f"key:{api_key}" if api_key else f"ip:{request.client.host}" + + allowed, retry_after = self._limiter.is_allowed(key) + if not allowed: + return JSONResponse( + status_code=429, + content={ + "error": "Too Many Requests", + "message": f"Rate limit exceeded. Try again in {int(retry_after)} seconds.", + }, + headers={"Retry-After": str(int(retry_after))}, + ) + + response = await call_next(request) + return response diff --git a/tests/unit/test_config_driven.py b/tests/unit/test_config_driven.py index 13b958f..1ba5f4b 100644 --- a/tests/unit/test_config_driven.py +++ b/tests/unit/test_config_driven.py @@ -354,3 +354,73 @@ class TestStandaloneRunner: runner = StandaloneRunner(config_dir="/nonexistent/path") configs = runner.discover_configs() assert len(configs) == 0 + + +# ── Handler Prefix Whitelist 测试 ───────────────────────── + + +class TestHandlerPrefixWhitelist: + """U4: 测试 _import_handler 的模块前缀白名单,防止任意代码执行""" + + def _make_agent_with_custom(self, handler_path: str) -> ConfigDrivenAgent: + config = AgentConfig( + name="test_agent", + agent_type="test", + task_mode="custom", + custom_handler=handler_path, + ) + return ConfigDrivenAgent(config=config) + + def test_allowed_prefix_agentkit(self): + """agentkit.xxx.handler → 允许通过前缀检查""" + agent = self._make_agent_with_custom("agentkit.handlers.test_handler") + # 前缀检查通过,但模块不存在会报 ImportError,我们只验证不报 ConfigValidationError(前缀) + try: + agent._import_handler("agentkit.handlers.test_handler") + except Exception as e: + # 允许 ImportError/AttributeError(模块不存在),但不允许前缀拒绝 + assert "not in allowed module prefixes" not in str(e) + + def test_allowed_prefix_app_agent_framework(self): + """app.agent_framework.handlers.xxx → 允许通过前缀检查""" + agent = self._make_agent_with_custom("app.agent_framework.handlers.xxx_handler") + try: + agent._import_handler("app.agent_framework.handlers.xxx_handler") + except Exception as e: + assert "not in allowed module prefixes" not in str(e) + + def test_blocked_os_system(self): + """os.system → 阻止(ConfigValidationError)""" + agent = self._make_agent_with_custom("os.system") + with pytest.raises(Exception, match="not in allowed module prefixes"): + agent._import_handler("os.system") + + def test_blocked_subprocess_run(self): + """subprocess.run → 阻止""" + agent = self._make_agent_with_custom("subprocess.run") + with pytest.raises(Exception, match="not in allowed module prefixes"): + agent._import_handler("subprocess.run") + + def test_blocked_builtins_exec(self): + """builtins.exec → 阻止""" + agent = self._make_agent_with_custom("builtins.exec") + with pytest.raises(Exception, match="not in allowed module prefixes"): + agent._import_handler("builtins.exec") + + def test_blocked_empty_string(self): + """空字符串 → 阻止(在 _import_handler 级别直接被前缀检查拒绝)""" + config = AgentConfig( + name="test_agent", + agent_type="test", + task_mode="custom", + custom_handler="agentkit.dummy", # valid config, but we test _import_handler directly + ) + agent = ConfigDrivenAgent(config=config) + with pytest.raises(Exception, match="not in allowed module prefixes"): + agent._import_handler("") + + def test_blocked_agentkitx_prefix(self): + """agentkitx. → 阻止(不是 agentkit.)""" + agent = self._make_agent_with_custom("agentkitx.handlers.evil") + with pytest.raises(Exception, match="not in allowed module prefixes"): + agent._import_handler("agentkitx.handlers.evil") diff --git a/tests/unit/test_dispatcher.py b/tests/unit/test_dispatcher.py index 9ee06be..0f03888 100644 --- a/tests/unit/test_dispatcher.py +++ b/tests/unit/test_dispatcher.py @@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest -from agentkit.core.dispatcher import TaskDispatcher +from agentkit.core.dispatcher import TaskDispatcher, _validate_callback_url from agentkit.core.exceptions import TaskDispatchError, TaskNotFoundError from agentkit.core.protocol import AgentStatus, TaskResult, TaskStatus @@ -267,3 +267,55 @@ class TestTaskDispatcherHandleResult: assert task.status == TaskStatus.FAILED assert task.error_message == "Something went wrong" + + +class TestValidateCallbackUrl: + """SSRF protection tests for _validate_callback_url.""" + + def test_valid_public_https_url(self): + """Valid public HTTPS URL should be allowed.""" + assert _validate_callback_url("https://example.com/callback") is True + + def test_valid_public_http_url(self): + """Valid public HTTP URL should be allowed.""" + assert _validate_callback_url("http://example.com/callback") is True + + def test_localhost_blocked(self): + """localhost should be blocked.""" + assert _validate_callback_url("http://localhost:8080/callback") is False + + def test_loopback_ip_blocked(self): + """127.0.0.1 should be blocked.""" + assert _validate_callback_url("http://127.0.0.1:8080/callback") is False + + def test_private_10_range_blocked(self): + """10.0.0.0/8 range should be blocked.""" + assert _validate_callback_url("http://10.0.0.1/internal") is False + + def test_private_192_range_blocked(self): + """192.168.0.0/16 range should be blocked.""" + assert _validate_callback_url("http://192.168.1.1/admin") is False + + def test_private_172_range_blocked(self): + """172.16.0.0/12 range should be blocked.""" + assert _validate_callback_url("http://172.16.0.1/internal") is False + + def test_ftp_protocol_blocked(self): + """FTP protocol should be blocked.""" + assert _validate_callback_url("ftp://example.com/file") is False + + def test_file_protocol_blocked(self): + """file:// protocol should be blocked.""" + assert _validate_callback_url("file:///etc/passwd") is False + + def test_javascript_protocol_blocked(self): + """javascript: protocol should be blocked.""" + assert _validate_callback_url("javascript:alert(1)") is False + + def test_empty_url_blocked(self): + """Empty URL should be blocked.""" + assert _validate_callback_url("") is False + + def test_malformed_url_blocked(self): + """Malformed URL should be blocked.""" + assert _validate_callback_url("not-a-valid-url") is False diff --git a/tests/unit/test_server_middleware.py b/tests/unit/test_server_middleware.py new file mode 100644 index 0000000..d4f7b25 --- /dev/null +++ b/tests/unit/test_server_middleware.py @@ -0,0 +1,242 @@ +"""Server Middleware 单元测试 - API Key Auth + Rate Limiting""" + +import os +import time +import pytest +from unittest.mock import patch +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from agentkit.server.middleware import ( + APIKeyAuthMiddleware, + RateLimiter, + RateLimitMiddleware, +) + + +# --------------------------------------------------------------------------- +# Helper: minimal app with only a health endpoint for isolated middleware tests +# --------------------------------------------------------------------------- + +def _make_minimal_app(): + """Create a minimal FastAPI app with just a health endpoint.""" + app = FastAPI() + + @app.get("/api/v1/health") + async def health(): + return {"status": "ok"} + + @app.get("/api/v1/protected") + async def protected(): + return {"data": "secret"} + + return app + + +# --------------------------------------------------------------------------- +# APIKeyAuthMiddleware Tests +# --------------------------------------------------------------------------- + +class TestAPIKeyAuthMiddleware: + """API Key authentication middleware tests.""" + + def test_dev_mode_no_api_key_set_passes_through(self): + """No AGENTKIT_API_KEY set → requests pass through (dev mode).""" + with patch.dict(os.environ, {}, clear=False): + # Ensure AGENTKIT_API_KEY is not set + os.environ.pop("AGENTKIT_API_KEY", None) + + app = _make_minimal_app() + app.add_middleware(APIKeyAuthMiddleware) + client = TestClient(app) + + response = client.get("/api/v1/protected") + assert response.status_code == 200 + + def test_api_key_set_no_header_returns_401(self): + """AGENTKIT_API_KEY set, no header → 401.""" + with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}): + app = _make_minimal_app() + app.add_middleware(APIKeyAuthMiddleware) + client = TestClient(app) + + response = client.get("/api/v1/protected") + assert response.status_code == 401 + data = response.json() + assert data["error"] == "Unauthorized" + + def test_api_key_set_wrong_header_returns_401(self): + """AGENTKIT_API_KEY set, wrong header → 401.""" + with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}): + app = _make_minimal_app() + app.add_middleware(APIKeyAuthMiddleware) + client = TestClient(app) + + response = client.get( + "/api/v1/protected", + headers={"X-API-Key": "wrong-key"}, + ) + assert response.status_code == 401 + + def test_api_key_set_correct_header_returns_200(self): + """AGENTKIT_API_KEY set, correct header → 200.""" + with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}): + app = _make_minimal_app() + app.add_middleware(APIKeyAuthMiddleware) + client = TestClient(app) + + response = client.get( + "/api/v1/protected", + headers={"X-API-Key": "test-secret-key"}, + ) + assert response.status_code == 200 + assert response.json()["data"] == "secret" + + def test_health_check_path_no_auth_required(self): + """Health check path → 200 without API key.""" + with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}): + app = _make_minimal_app() + app.add_middleware(APIKeyAuthMiddleware) + client = TestClient(app) + + response = client.get("/api/v1/health") + assert response.status_code == 200 + assert response.json()["status"] == "ok" + + def test_programmatic_api_key_parameter(self): + """Programmatic api_key parameter → uses passed key.""" + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("AGENTKIT_API_KEY", None) + + app = _make_minimal_app() + # Set the API key via environment before adding middleware + os.environ["AGENTKIT_API_KEY"] = "programmatic-key" + app.add_middleware(APIKeyAuthMiddleware) + client = TestClient(app) + + response = client.get( + "/api/v1/protected", + headers={"X-API-Key": "programmatic-key"}, + ) + assert response.status_code == 200 + + +# --------------------------------------------------------------------------- +# RateLimiter Tests +# --------------------------------------------------------------------------- + +class TestRateLimiter: + """Fixed-window rate limiter unit tests.""" + + def test_requests_within_limit_allowed(self): + """Requests within limit → allowed.""" + limiter = RateLimiter(max_requests=5, window_seconds=60) + + for i in range(5): + allowed, retry_after = limiter.is_allowed("test-key") + assert allowed is True + assert retry_after == 0.0 + + def test_requests_exceed_limit_denied(self): + """Requests exceed limit → denied with retry_after.""" + limiter = RateLimiter(max_requests=2, window_seconds=60) + + # Use up the limit + limiter.is_allowed("test-key") + limiter.is_allowed("test-key") + + # Next request should be denied + allowed, retry_after = limiter.is_allowed("test-key") + assert allowed is False + assert retry_after > 0 + + def test_after_window_expires_counter_resets(self): + """After window expires → counter resets.""" + limiter = RateLimiter(max_requests=2, window_seconds=1) + + # Use up the limit + limiter.is_allowed("test-key") + limiter.is_allowed("test-key") + + # Should be denied + allowed, _ = limiter.is_allowed("test-key") + assert allowed is False + + # Wait for window to expire + time.sleep(1.1) + + # Should be allowed again + allowed, retry_after = limiter.is_allowed("test-key") + assert allowed is True + + def test_different_keys_independent_counters(self): + """Different keys have independent counters.""" + limiter = RateLimiter(max_requests=1, window_seconds=60) + + # Use up key-a's limit + limiter.is_allowed("key-a") + + # key-a should be denied + allowed_a, _ = limiter.is_allowed("key-a") + assert allowed_a is False + + # key-b should still be allowed + allowed_b, _ = limiter.is_allowed("key-b") + assert allowed_b is True + + def test_max_requests_property(self): + """max_requests property returns configured value.""" + limiter = RateLimiter(max_requests=100, window_seconds=30) + assert limiter.max_requests == 100 + + +# --------------------------------------------------------------------------- +# RateLimitMiddleware Tests +# --------------------------------------------------------------------------- + +class TestRateLimitMiddleware: + """Rate limiting middleware integration tests.""" + + def test_returns_429_with_retry_after_header(self): + """Returns 429 with Retry-After header when limit exceeded.""" + app = _make_minimal_app() + app.add_middleware(RateLimitMiddleware, max_requests=1, window_seconds=60) + client = TestClient(app) + + # First request should pass + response1 = client.get("/api/v1/protected") + assert response1.status_code == 200 + + # Second request should be rate limited + response2 = client.get("/api/v1/protected") + assert response2.status_code == 429 + data = response2.json() + assert data["error"] == "Too Many Requests" + assert "Retry-After" in response2.headers + + def test_uses_api_key_for_identity(self): + """Uses API key for identity when present (different keys = different limits).""" + app = _make_minimal_app() + app.add_middleware(RateLimitMiddleware, max_requests=1, window_seconds=60) + client = TestClient(app) + + # Request with key-a + response_a1 = client.get( + "/api/v1/protected", + headers={"X-API-Key": "key-a"}, + ) + assert response_a1.status_code == 200 + + # key-a should now be rate limited + response_a2 = client.get( + "/api/v1/protected", + headers={"X-API-Key": "key-a"}, + ) + assert response_a2.status_code == 429 + + # key-b should still be allowed (independent counter) + response_b1 = client.get( + "/api/v1/protected", + headers={"X-API-Key": "key-b"}, + ) + assert response_b1.status_code == 200