452 lines
14 KiB
Python
452 lines
14 KiB
Python
"""认证API测试"""
|
||
import uuid
|
||
|
||
import pytest
|
||
import pytest_asyncio
|
||
from httpx import AsyncClient, ASGITransport
|
||
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
|
||
from sqlalchemy.pool import StaticPool
|
||
|
||
from app.database import Base
|
||
from app.main import app
|
||
from app.models.user import User
|
||
from app.api.deps import get_current_user, get_db
|
||
from app.services.auth import hash_password, create_access_token
|
||
|
||
|
||
# ==================== Fixtures ====================
|
||
|
||
@pytest_asyncio.fixture
async def async_engine():
    """Provide an in-memory SQLite async engine with the ORM schema created.

    ``StaticPool`` is used so that all sessions created from this engine share
    the same in-memory database. The engine is disposed once the test is done.
    """
    sqlite_url = "sqlite+aiosqlite:///:memory:"
    engine = create_async_engine(
        sqlite_url,
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )

    # Create every table registered on Base.metadata before handing out the engine.
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield engine

    await engine.dispose()
|
||
|
||
|
||
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield a single ``AsyncSession`` bound to the test engine.

    ``expire_on_commit=False`` keeps committed ORM instances readable after a
    commit, which the user fixtures rely on when returning committed objects.
    The session is closed when the ``async with`` block exits at teardown.
    """
    session_factory = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    async with session_factory() as session:
        yield session
|
||
|
||
|
||
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, email-verified user on the free plan.

    Login credentials: ``test@example.com`` / ``Test@123456``.
    """
    user = User(
        id=uuid.uuid4(),
        name="Test User",
        email="test@example.com",
        password_hash=hash_password("Test@123456"),
        plan="free",
        max_queries=5,
        email_verified=True,
        is_active=True,
    )
    async_session.add(user)
    await async_session.commit()
    # Reload the instance's attributes from the database after the commit.
    await async_session.refresh(user)
    return user
|
||
|
||
|
||
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an HTTP client whose requests are served in-process by ``app``.

    Two dependencies are overridden for the duration of the test:

    * ``get_db`` yields the shared test session, and
    * ``get_current_user`` always returns ``test_user``, so authenticated
      endpoints run without validating a real token.

    On teardown the dependency overrides are restored to the state they had
    before this fixture ran, instead of wiping every override (which would also
    remove overrides installed by other fixtures).
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    saved_overrides = dict(app.dependency_overrides)
    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # Runs even if client construction fails, so the overrides never leak.
        app.dependency_overrides.clear()
        app.dependency_overrides.update(saved_overrides)
|
||
|
||
|
||
@pytest.fixture
def auth_headers(test_user):
    """Return an ``Authorization`` header with a bearer access token for ``test_user``."""
    access_token = create_access_token(data={"sub": str(test_user.id)})
    headers = {"Authorization": f"Bearer {access_token}"}
    return headers
|
||
|
||
|
||
# ==================== 测试类 ====================
|
||
|
||
@pytest_asyncio.fixture
async def db_client(async_session):
    """Yield an HTTP client for ``app`` with only ``get_db`` overridden.

    ``get_current_user`` is NOT overridden, so requests go through the
    application's own authentication dependency. The override is removed in
    fixture teardown, not at the end of each test body. A failing assertion
    therefore can no longer leave it installed for later tests. The previous
    override state is restored exactly.
    """

    async def override_get_db():
        yield async_session

    saved_overrides = dict(app.dependency_overrides)
    app.dependency_overrides[get_db] = override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        app.dependency_overrides.clear()
        app.dependency_overrides.update(saved_overrides)


async def _create_user(session, email, password):
    """Insert and commit an active, email-verified free-plan user; return it."""
    user = User(
        id=uuid.uuid4(),
        email=email,
        password_hash=hash_password(password),
        name="Test User",
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
    )
    session.add(user)
    await session.commit()
    return user


class TestAuthAPI:
    """Auth API tests."""

    @pytest.mark.asyncio
    async def test_register_success(self, db_client):
        """Registering a new email returns 201 and echoes the created user."""
        email = f"test_{uuid.uuid4()}@example.com"

        response = await db_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "name": "Test User",
                "password": "Test@123456",
            },
        )

        assert response.status_code == 201
        data = response.json()
        assert "id" in data
        assert data["email"] == email
        assert data["name"] == "Test User"

    @pytest.mark.asyncio
    async def test_register_duplicate_email(self, db_client):
        """Registering the same email twice is rejected with 400."""
        email = f"test_{uuid.uuid4()}@example.com"

        # First registration succeeds.
        response1 = await db_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "name": "Test User 1",
                "password": "Test@123456",
            },
        )
        assert response1.status_code == 201

        # A second registration with the same email must fail.
        response2 = await db_client.post(
            "/api/v1/auth/register",
            json={
                "email": email,
                "name": "Test User 2",
                "password": "Test@123456",
            },
        )
        assert response2.status_code == 400

    @pytest.mark.asyncio
    async def test_login_success(self, async_session, db_client):
        """Valid credentials return an access/refresh token pair and the user."""
        email = f"test_{uuid.uuid4()}@example.com"
        password = "Test@123456"
        await _create_user(async_session, email, password)

        response = await db_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )

        assert response.status_code == 200
        data = response.json()
        assert "access_token" in data
        assert "refresh_token" in data
        assert data["token_type"] == "bearer"
        assert "user" in data

    @pytest.mark.asyncio
    async def test_login_wrong_password(self, async_session, db_client):
        """A wrong password is rejected with 401."""
        email = f"test_{uuid.uuid4()}@example.com"
        await _create_user(async_session, email, "Correct@123456")

        response = await db_client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": "WrongPassword"},
        )

        assert response.status_code == 401

    @pytest.mark.asyncio
    async def test_login_nonexistent_user(self, db_client):
        """Logging in with an unknown email is rejected."""
        response = await db_client.post(
            "/api/v1/auth/login",
            json={
                "email": "nonexistent@example.com",
                "password": "Test@123456",
            },
        )

        # Expected to fail the same way as a wrong password (no hint whether
        # the user exists). Either 401 (authentication failed) or 429 (rate
        # limited) is accepted as a valid response.
        assert response.status_code in [401, 429]

    @pytest.mark.asyncio
    async def test_get_current_user(self, async_client, test_user):
        """GET /auth/me returns the authenticated user's id and email."""
        response = await async_client.get("/api/v1/auth/me")

        assert response.status_code == 200
        data = response.json()
        assert data["id"] == str(test_user.id)
        assert data["email"] == test_user.email

    @pytest.mark.asyncio
    async def test_change_password_success(self, async_client, auth_headers):
        """Changing the password with the correct old password succeeds."""
        response = await async_client.put(
            "/api/v1/auth/change-password",
            headers=auth_headers,
            json={
                "old_password": "Test@123456",
                "new_password": "NewPass@123456",
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert data["message"] == "密码修改成功"

    @pytest.mark.asyncio
    async def test_change_password_wrong_old(self, async_client, auth_headers):
        """A wrong old password is rejected with 400 and an explanatory detail."""
        response = await async_client.put(
            "/api/v1/auth/change-password",
            headers=auth_headers,
            json={
                "old_password": "WrongOldPass",
                "new_password": "NewPass@123456",
            },
        )

        assert response.status_code == 400
        assert "旧密码错误" in response.json()["detail"]

    @pytest.mark.asyncio
    async def test_update_profile(self, async_client, auth_headers):
        """Updating the profile returns the new name."""
        response = await async_client.put(
            "/api/v1/auth/profile",
            headers=auth_headers,
            json={
                "name": "Updated Name",
                "avatar_url": "https://example.com/avatar.png",
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "Updated Name"

    @pytest.mark.asyncio
    async def test_refresh_token(self, db_client, test_user):
        """A valid refresh token yields a new access/refresh token pair."""
        from app.services.auth import create_refresh_token

        refresh_token = create_refresh_token(data={"sub": str(test_user.id)})

        response = await db_client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": refresh_token},
        )

        assert response.status_code == 200
        data = response.json()
        assert "access_token" in data
        assert "refresh_token" in data

    @pytest.mark.asyncio
    async def test_forgot_password(self, db_client):
        """A forgot-password request is acknowledged with a message."""
        response = await db_client.post(
            "/api/v1/auth/forgot-password",
            json={"email": "test@example.com"},
        )

        # Success is expected whether or not the email exists, to prevent
        # user enumeration.
        assert response.status_code == 200
        assert "message" in response.json()

    @pytest.mark.asyncio
    async def test_resend_verification(self, db_client):
        """A resend-verification request is acknowledged with a message."""
        response = await db_client.post(
            "/api/v1/auth/resend-verification",
            json={"email": "test@example.com"},
        )

        assert response.status_code == 200
        assert "message" in response.json()

    @pytest.mark.asyncio
    async def test_unauthorized_access(self, db_client):
        """A protected endpoint without a token returns 401."""
        # No Authorization header is sent.
        response = await db_client.get("/api/v1/auth/me")

        assert response.status_code == 401

    @pytest.mark.asyncio
    async def test_invalid_token(self, db_client):
        """A protected endpoint with a malformed bearer token returns 401."""
        response = await db_client.get(
            "/api/v1/auth/me",
            headers={"Authorization": "Bearer invalid_token"},
        )

        assert response.status_code == 401
|