243 lines
8.4 KiB
Python
243 lines
8.4 KiB
Python
"""Server Middleware 单元测试 - API Key Auth + Rate Limiting"""
|
|
|
|
import os
|
|
import time
|
|
import pytest
|
|
from unittest.mock import patch
|
|
from fastapi import FastAPI
|
|
from fastapi.testclient import TestClient
|
|
|
|
from agentkit.server.middleware import (
|
|
APIKeyAuthMiddleware,
|
|
RateLimiter,
|
|
RateLimitMiddleware,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helper: minimal app with a health endpoint and a protected endpoint for isolated middleware tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _make_minimal_app():
    """Return a bare FastAPI app with two GET routes and no middleware.

    * ``/api/v1/health``    -> ``{"status": "ok"}``
    * ``/api/v1/protected`` -> ``{"data": "secret"}``

    Callers attach the middleware under test themselves, so every test starts
    from an identical, undecorated application.
    """
    minimal_app = FastAPI()

    @minimal_app.get("/api/v1/health")
    async def health():
        return {"status": "ok"}

    @minimal_app.get("/api/v1/protected")
    async def protected():
        return {"data": "secret"}

    return minimal_app
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# APIKeyAuthMiddleware Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestAPIKeyAuthMiddleware:
    """API key authentication middleware tests.

    Each test builds a fresh app from ``_make_minimal_app()``, registers
    ``APIKeyAuthMiddleware`` and drives it through ``TestClient``. The
    ``AGENTKIT_API_KEY`` environment variable is only changed inside
    ``patch.dict(os.environ, ...)``, which restores the original environment
    when the ``with`` block exits; every request is issued inside that block.
    """

    @staticmethod
    def _make_auth_client():
        """Return a TestClient for a minimal app guarded by APIKeyAuthMiddleware."""
        app = _make_minimal_app()
        app.add_middleware(APIKeyAuthMiddleware)
        return TestClient(app)

    def test_dev_mode_no_api_key_set_passes_through(self):
        """No AGENTKIT_API_KEY set -> requests pass through (dev mode)."""
        with patch.dict(os.environ):
            # Drop any key inherited from the outer environment; patch.dict
            # restores it when the block exits.
            os.environ.pop("AGENTKIT_API_KEY", None)
            client = self._make_auth_client()

            response = client.get("/api/v1/protected")
            assert response.status_code == 200

    def test_api_key_set_no_header_returns_401(self):
        """AGENTKIT_API_KEY set, no X-API-Key header -> 401."""
        with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}):
            client = self._make_auth_client()

            response = client.get("/api/v1/protected")
            assert response.status_code == 401
            assert response.json()["error"] == "Unauthorized"

    def test_api_key_set_wrong_header_returns_401(self):
        """AGENTKIT_API_KEY set, wrong X-API-Key header -> 401."""
        with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}):
            client = self._make_auth_client()

            response = client.get(
                "/api/v1/protected",
                headers={"X-API-Key": "wrong-key"},
            )
            assert response.status_code == 401

    def test_api_key_set_correct_header_returns_200(self):
        """AGENTKIT_API_KEY set, matching X-API-Key header -> 200."""
        with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}):
            client = self._make_auth_client()

            response = client.get(
                "/api/v1/protected",
                headers={"X-API-Key": "test-secret-key"},
            )
            assert response.status_code == 200
            assert response.json()["data"] == "secret"

    def test_health_check_path_no_auth_required(self):
        """Health check path -> 200 without an API key."""
        with patch.dict(os.environ, {"AGENTKIT_API_KEY": "test-secret-key"}):
            client = self._make_auth_client()

            response = client.get("/api/v1/health")
            assert response.status_code == 200
            assert response.json()["status"] == "ok"

    def test_programmatic_api_key_parameter(self):
        """Key exported after app creation, before middleware registration -> 200.

        NOTE(review): despite its name, this test passes no argument to
        ``APIKeyAuthMiddleware``; the key still reaches it through
        ``AGENTKIT_API_KEY``, so it overlaps with
        ``test_api_key_set_correct_header_returns_200``. TODO: if the
        middleware accepts the key as a constructor argument, pass it through
        ``add_middleware`` here instead -- confirm its signature first.
        """
        with patch.dict(os.environ):
            app = _make_minimal_app()
            # Export the key only after the app exists, right before the
            # middleware is registered (assignment overrides any inherited value).
            os.environ["AGENTKIT_API_KEY"] = "programmatic-key"
            app.add_middleware(APIKeyAuthMiddleware)
            client = TestClient(app)

            response = client.get(
                "/api/v1/protected",
                headers={"X-API-Key": "programmatic-key"},
            )
            assert response.status_code == 200
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RateLimiter Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestRateLimiter:
    """Fixed-window rate limiter unit tests.

    ``RateLimiter.is_allowed(key)`` is exercised as returning an
    ``(allowed, retry_after)`` pair: these tests expect ``retry_after == 0.0``
    for admitted requests and ``retry_after > 0`` for rejected ones.
    """

    def test_requests_within_limit_allowed(self):
        """Requests within limit -> allowed with no retry delay."""
        limiter = RateLimiter(max_requests=5, window_seconds=60)

        for _ in range(5):
            allowed, retry_after = limiter.is_allowed("test-key")
            assert allowed is True
            assert retry_after == 0.0

    def test_requests_exceed_limit_denied(self):
        """Requests exceed limit -> denied with a positive retry_after."""
        limiter = RateLimiter(max_requests=2, window_seconds=60)

        # Use up the limit, checking each call was admitted so the denial
        # below is provably caused by exhaustion.
        for _ in range(2):
            assert limiter.is_allowed("test-key")[0] is True

        # Next request should be denied
        allowed, retry_after = limiter.is_allowed("test-key")
        assert allowed is False
        assert retry_after > 0

    def test_after_window_expires_counter_resets(self):
        """After the window expires -> counter resets and requests are admitted."""
        limiter = RateLimiter(max_requests=2, window_seconds=1)

        # Use up the limit
        for _ in range(2):
            assert limiter.is_allowed("test-key")[0] is True

        # Should be denied
        allowed, _ = limiter.is_allowed("test-key")
        assert allowed is False

        # Real wall-clock wait: the limiter's clock source is not visible from
        # this file, so it is not mocked. NOTE(review): an injectable clock in
        # RateLimiter would let this test drop the ~1s delay.
        time.sleep(1.1)

        # Should be allowed again, with the same "no delay" contract as above
        allowed, retry_after = limiter.is_allowed("test-key")
        assert allowed is True
        assert retry_after == 0.0

    def test_different_keys_independent_counters(self):
        """Different keys have independent counters."""
        limiter = RateLimiter(max_requests=1, window_seconds=60)

        # Use up key-a's single slot
        assert limiter.is_allowed("key-a")[0] is True

        # key-a should be denied
        allowed_a, _ = limiter.is_allowed("key-a")
        assert allowed_a is False

        # key-b should still be allowed
        allowed_b, _ = limiter.is_allowed("key-b")
        assert allowed_b is True

    def test_max_requests_property(self):
        """max_requests property returns the configured value."""
        limiter = RateLimiter(max_requests=100, window_seconds=30)
        assert limiter.max_requests == 100
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# RateLimitMiddleware Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestRateLimitMiddleware:
    """Integration tests for RateLimitMiddleware mounted on a minimal app.

    Every test allows one request per 60-second window, so the second request
    from the same identity is expected to be rejected with 429.
    """

    @staticmethod
    def _make_limited_client():
        """Return a TestClient whose app allows one request per 60 s window."""
        app = _make_minimal_app()
        app.add_middleware(RateLimitMiddleware, max_requests=1, window_seconds=60)
        return TestClient(app)

    def test_returns_429_with_retry_after_header(self):
        """Exceeding the limit yields 429, an error body and a Retry-After header."""
        client = self._make_limited_client()

        first = client.get("/api/v1/protected")
        assert first.status_code == 200

        throttled = client.get("/api/v1/protected")
        assert throttled.status_code == 429
        assert throttled.json()["error"] == "Too Many Requests"
        assert "Retry-After" in throttled.headers

    def test_uses_api_key_for_identity(self):
        """The X-API-Key header selects the counter: distinct keys, distinct limits."""
        client = self._make_limited_client()

        def fetch(api_key):
            return client.get(
                "/api/v1/protected",
                headers={"X-API-Key": api_key},
            )

        # key-a spends its single slot and is then throttled; key-b still has
        # its own untouched counter.
        expected_sequence = [
            ("key-a", 200),
            ("key-a", 429),
            ("key-b", 200),
        ]
        for api_key, expected_status in expected_sequence:
            assert fetch(api_key).status_code == expected_status
|