382 lines
14 KiB
Python
382 lines
14 KiB
Python
"""Performance tests: concurrent access, response time, and rate limiting."""
|
|
import asyncio
|
|
import time
|
|
import uuid
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from httpx import AsyncClient, ASGITransport
|
|
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from app.database import Base
|
|
from app.main import app
|
|
from app.models.user import User
|
|
from app.models.query import Query
|
|
from app.models.brand import Brand
|
|
from app.models.competitor import Competitor
|
|
from app.models.suggestion import Suggestion
|
|
from app.api.deps import get_current_user, get_db
|
|
from app.services.auth import create_access_token, hash_password
|
|
from tests.fixtures.auth import _to_uuid
|
|
|
|
# Tables created for the performance tests. Only these models are materialized
# so that PostgreSQL-only column types (e.g. JSONB) used elsewhere in the
# metadata never reach SQLite.
_TEST_TABLES = tuple(
    model.__table__ for model in (User, Query, Brand, Competitor, Suggestion)
)
|
|
|
|
|
|
# ─────────────────────── Fixtures ───────────────────────
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_engine():
    """Yield an async SQLite in-memory engine with the performance-test tables.

    Only the tables listed in ``_TEST_TABLES`` are created, avoiding
    PostgreSQL-only types (JSONB) that fail on SQLite.

    ``StaticPool`` makes every checkout reuse one connection. For an
    in-memory SQLite database this is what lets separate sessions see the
    same tables; each new connection would otherwise open an empty database.

    The engine is disposed in ``finally`` so it is released even when table
    creation fails during setup. Pytest skips post-``yield`` teardown code
    in that case.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(
                lambda sync_conn: Base.metadata.create_all(
                    sync_conn, tables=list(_TEST_TABLES)
                )
            )
        yield engine
    finally:
        await engine.dispose()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield one ``AsyncSession`` bound to the per-test engine.

    ``expire_on_commit=False`` keeps ORM instances readable after a commit,
    which the fixtures below rely on when they return committed objects.
    ``autoflush=False`` means pending objects are only written on an
    explicit flush or commit.
    """
    make_session = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )
    async with make_session() as db_session:
        yield db_session
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return the user the performance tests act as.

    The plaintext password is hashed with ``hash_password``. The same
    email/password pair is what the concurrent-login test submits.
    """
    perf_user = User(
        id=str(uuid.uuid4()),
        email="perf_test@example.com",
        password=hash_password("PerfTest123!"),
        firstName="Performance Test User",
        plan="free",
        max_queries=50,
        isActive=True,
        emailVerified=True,
    )
    async_session.add(perf_user)
    await async_session.commit()
    # Re-read the row so any database-populated attributes are loaded.
    await async_session.refresh(perf_user)
    return perf_user
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an HTTP client for ``app`` with DB and current-user overrides.

    ``get_db`` yields the per-test ``async_session``. ``get_current_user``
    returns ``test_user`` without inspecting the request.

    The pre-existing ``app.dependency_overrides`` mapping is snapshotted and
    restored in ``finally``. This guarantees cleanup even if client setup
    raises before ``yield``, where post-``yield`` code would be skipped. It
    also avoids wiping overrides that were installed by someone else.
    """

    async def override_get_db():
        # NOTE(review): every request, including concurrent ones, receives
        # this same AsyncSession instance.
        yield async_session

    async def override_get_current_user():
        return test_user

    saved_overrides = dict(app.dependency_overrides)
    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        app.dependency_overrides.clear()
        app.dependency_overrides.update(saved_overrides)
|
|
|
|
|
|
@pytest_asyncio.fixture
async def client_no_override(async_session):
    """Yield an HTTP client that overrides only ``get_db``.

    ``get_current_user`` is left in place so the real auth flow runs.

    The pre-existing ``app.dependency_overrides`` mapping is snapshotted and
    restored in ``finally``. This guarantees cleanup even if client setup
    raises before ``yield``, and avoids wiping overrides installed elsewhere.
    """

    async def override_get_db():
        # NOTE(review): every request, including concurrent ones, receives
        # this same AsyncSession instance.
        yield async_session

    saved_overrides = dict(app.dependency_overrides)
    app.dependency_overrides[get_db] = override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        app.dependency_overrides.clear()
        app.dependency_overrides.update(saved_overrides)
|
|
|
|
|
|
@pytest.fixture
def auth_headers(test_user):
    """Return a Bearer ``Authorization`` header whose token subject is ``test_user``'s id."""
    access_token = create_access_token(data={"sub": str(test_user.id)})
    return {"Authorization": "Bearer " + access_token}
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════
|
|
# API Response Time Tests
|
|
# ═══════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestAPIPerformance:
    """Test API response times for common endpoints.

    Latency is measured with ``time.perf_counter``, which is monotonic and
    high-resolution. ``time.time`` follows the wall clock: it can jump under
    clock adjustments and is coarse on some platforms.
    """

    @staticmethod
    async def _timed_get(client, url, **kwargs):
        """GET *url* via *client* and return ``(response, elapsed_seconds)``."""
        start = time.perf_counter()
        response = await client.get(url, **kwargs)
        return response, time.perf_counter() - start

    @pytest.mark.asyncio
    async def test_health_check_fast(self, async_client):
        """Health check endpoint should respond quickly (< 100ms)."""
        response, elapsed = await self._timed_get(async_client, "/health")

        assert response.status_code == 200
        assert elapsed < 0.1, f"Health check took {elapsed:.3f}s, expected < 0.1s"

    @pytest.mark.asyncio
    async def test_brand_list_performance(self, async_client, async_session, test_user, auth_headers):
        """Brand list API should respond within 500ms."""
        # Seed several brands so the endpoint has rows to return.
        # NOTE(review): _to_uuid presumably converts the str id to the
        # column's UUID type; confirm against the model.
        owner_id = _to_uuid(test_user.id)
        async_session.add_all([
            Brand(
                user_id=owner_id,
                name=f"Brand {i}",
                platforms=["wenxin"],
                status="active",
            )
            for i in range(10)
        ])
        await async_session.commit()

        response, elapsed = await self._timed_get(
            async_client, "/api/v1/brands/", headers=auth_headers
        )

        assert response.status_code == 200
        assert elapsed < 0.5, f"Brand list took {elapsed:.3f}s, expected < 0.5s"

    @pytest.mark.asyncio
    async def test_query_list_performance(self, async_client, async_session, test_user, auth_headers):
        """Query list API should respond within 500ms."""
        # Seed several queries so the endpoint has rows to return.
        owner_id = _to_uuid(test_user.id)
        async_session.add_all([
            Query(
                user_id=owner_id,
                keyword=f"query keyword {i}",
                target_brand=f"Brand {i}",
                platforms=["wenxin"],
                status="active",
            )
            for i in range(10)
        ])
        await async_session.commit()

        response, elapsed = await self._timed_get(
            async_client, "/api/v1/queries/", headers=auth_headers
        )

        assert response.status_code == 200
        assert elapsed < 0.5, f"Query list took {elapsed:.3f}s, expected < 0.5s"

    @pytest.mark.asyncio
    async def test_me_endpoint_performance(self, async_client, auth_headers):
        """Current user endpoint should respond within 200ms."""
        response, elapsed = await self._timed_get(
            async_client, "/api/v1/auth/me", headers=auth_headers
        )

        assert response.status_code == 200
        assert elapsed < 0.2, f"/auth/me took {elapsed:.3f}s, expected < 0.2s"
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════
|
|
# Concurrency Tests
|
|
# ═══════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestConcurrency:
    """Test concurrent access behavior.

    NOTE(review): the ``get_db`` overrides in this file hand the same
    ``AsyncSession`` to every request. SQLAlchemy documents ``AsyncSession``
    as unsafe for concurrent tasks, so some gathered requests may raise for
    that reason. ``httpx.ASGITransport`` re-raises app exceptions by default,
    so such failures show up here as exceptions, not as 500 responses.
    """

    @staticmethod
    def _count_ok(results):
        """Return how many gathered results are responses with status 200 (exceptions skipped)."""
        return sum(
            1
            for result in results
            if not isinstance(result, Exception) and result.status_code == 200
        )

    @pytest.mark.asyncio
    async def test_concurrent_login_no_crash(self, client_no_override, test_user):
        """50 concurrent login requests should not cause system crash.

        Note: Rate limiting will kick in after 5 attempts from same IP,
        so most requests will get 429. The key point is no 500 errors.
        """
        tasks = [
            client_no_override.post(
                "/api/v1/auth/login",
                json={
                    "email": "perf_test@example.com",
                    "password": "PerfTest123!",
                },
            )
            for _ in range(50)
        ]
        responses = await asyncio.gather(*tasks, return_exceptions=True)

        completed = 0
        for i, resp in enumerate(responses):
            if isinstance(resp, Exception):
                # Tolerated under heavy load (see the class NOTE: with
                # ASGITransport this may also be an app-side exception).
                continue
            completed += 1
            # Should NOT get 500 — rate limiting (429) or auth errors (401/422) are fine
            assert resp.status_code in (200, 401, 422, 429), (
                f"Concurrent login request {i} returned {resp.status_code}, "
                f"expected 200/401/422/429"
            )

        # Without this guard the test passes vacuously when every request raised.
        assert completed > 0, "All concurrent login requests raised exceptions"

    @pytest.mark.asyncio
    async def test_concurrent_brand_reads(self, async_client, async_session, test_user, auth_headers):
        """Concurrent brand list reads: at least one of 20 must return 200."""
        owner_id = _to_uuid(test_user.id)
        async_session.add_all([
            Brand(
                user_id=owner_id,
                name=f"Concurrent Brand {i}",
                platforms=["wenxin"],
                status="active",
            )
            for i in range(5)
        ])
        await async_session.commit()

        # 20 concurrent read requests
        tasks = [
            async_client.get("/api/v1/brands/", headers=auth_headers)
            for _ in range(20)
        ]
        responses = await asyncio.gather(*tasks, return_exceptions=True)

        assert self._count_ok(responses) > 0, "No concurrent brand reads succeeded"

    @pytest.mark.asyncio
    async def test_concurrent_query_reads(self, async_client, async_session, test_user, auth_headers):
        """Concurrent query list reads: at least one of 20 must return 200."""
        owner_id = _to_uuid(test_user.id)
        async_session.add_all([
            Query(
                user_id=owner_id,
                keyword=f"concurrent query {i}",
                target_brand=f"Brand {i}",
                platforms=["wenxin"],
                status="active",
            )
            for i in range(5)
        ])
        await async_session.commit()

        tasks = [
            async_client.get("/api/v1/queries/", headers=auth_headers)
            for _ in range(20)
        ]
        responses = await asyncio.gather(*tasks, return_exceptions=True)

        assert self._count_ok(responses) > 0, "No concurrent query reads succeeded"
|
|
|
|
|
|
# ═══════════════════════════════════════════════════════════
|
|
# Rate Limiting Tests
|
|
# ═══════════════════════════════════════════════════════════
|
|
|
|
|
|
class TestRateLimiting:
    """Test rate limiting enforcement.

    NOTE(review): per the comment in the login test, RateLimitMiddleware
    state is shared across the app instance, so these tests may see counts
    left over by earlier tests in the same window. Ordering can matter;
    confirm whether the middleware exposes a reset hook for tests.
    """

    @pytest.mark.asyncio
    async def test_login_rate_limit_enforcement(self, client_no_override):
        """Login endpoint should enforce rate limiting (5 req/min/IP).

        After 5 rapid login attempts, subsequent requests should get 429.
        """
        # Note: RateLimitMiddleware state is shared across the app instance.
        # Since other tests may have already sent requests, we need to
        # send enough to trigger the limit. The auth_strict rule allows
        # 5 requests per 60 seconds per IP.
        status_codes = []
        for _ in range(8):
            response = await client_no_override.post(
                "/api/v1/auth/login",
                json={
                    "email": "ratelimit@example.com",
                    "password": "somepassword123",
                },
            )
            status_codes.append(response.status_code)

        # After 5 requests, additional ones should get 429
        assert 429 in status_codes, (
            f"No rate limiting detected. Status codes: {status_codes}. "
            f"Expected at least one 429 after 8 rapid login attempts."
        )

    @pytest.mark.asyncio
    async def test_global_rate_limit_high_threshold(self, async_client, auth_headers):
        """Global rate limit (100 req/min) should allow normal usage patterns.

        Sending 10 rapid requests should all succeed (well within limit).
        """
        responses = [await async_client.get("/health") for _ in range(10)]

        # All should succeed — well under global limit
        success_count = sum(1 for r in responses if r.status_code == 200)
        assert success_count == 10, (
            f"Expected all 10 health checks to succeed, got {success_count}/10"
        )

    @pytest.mark.asyncio
    async def test_rate_limit_429_response_format(self, client_no_override):
        """Rate-limited responses should have proper 429 format."""
        payload = {"email": "rl@example.com", "password": "password123"}

        # Exhaust login rate limit
        for _ in range(8):
            await client_no_override.post("/api/v1/auth/login", json=payload)

        # This one should be rate-limited
        response = await client_no_override.post("/api/v1/auth/login", json=payload)

        # Assert the status explicitly. Checking the body only under
        # `if status == 429` would let the test pass without verifying
        # anything when rate limiting never triggers.
        assert response.status_code == 429, (
            f"Expected 429 after 9 rapid login attempts, got {response.status_code}"
        )
        data = response.json()
        # Should have a detail message
        assert "detail" in data or "message" in data, (
            "429 response should include a detail or message field"
        )