# File: geo/backend/tests/test_api/test_auth_api.py
# NOTE: GitHub's viewer flags this file as containing ambiguous Unicode
# characters (presumably inside the Chinese docstrings/strings); review
# carefully when editing those literals.

"""认证API测试"""
import uuid
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.main import app
from app.models.user import User
from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password, create_access_token
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def async_engine():
    """Yield an in-memory SQLite async engine with every ORM table created.

    ``StaticPool`` makes the engine hand out one shared connection, so the
    in-memory database persists for the engine's lifetime. The engine is
    disposed on fixture teardown.
    """
    test_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )
    async with test_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)
    yield test_engine
    await test_engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield a single ``AsyncSession`` bound to the in-memory test engine.

    ``expire_on_commit=False`` leaves loaded attributes intact after
    ``commit()``, so committed ORM objects stay readable without a reload.
    ``autoflush=False`` disables the implicit flush before queries.
    """
    make_session = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    async with make_session() as session:
        yield session
@pytest_asyncio.fixture
async def test_user(async_session):
    """Insert and return the default test account.

    The user is ``test@example.com``, active and email-verified, on the
    ``free`` plan with ``max_queries=5``. Its plaintext password is
    ``Test@123456``.
    """
    plain_password = "Test@123456"
    account = User(
        id=uuid.uuid4(),
        email="test@example.com",
        password_hash=hash_password(plain_password),
        name="Test User",
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
    )
    async_session.add(account)
    await async_session.commit()
    # Re-read the committed row so the returned instance mirrors the database.
    await async_session.refresh(account)
    return account
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an ``AsyncClient`` that calls ``app`` in-process.

    ``get_db`` is overridden to yield ``async_session``, and
    ``get_current_user`` is overridden to return ``test_user``. Endpoints that
    depend on ``get_current_user`` therefore see ``test_user`` regardless of
    the headers sent. All dependency overrides are cleared after the client
    is closed.
    """
    async def fake_get_db():
        yield async_session

    async def fake_get_current_user():
        return test_user

    app.dependency_overrides.update(
        {get_db: fake_get_db, get_current_user: fake_get_current_user}
    )
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as client:
        yield client
    app.dependency_overrides.clear()
@pytest.fixture
def auth_headers(test_user):
    """Return an ``Authorization`` header with a bearer access token for ``test_user``."""
    claims = {"sub": str(test_user.id)}
    return {"Authorization": f"Bearer {create_access_token(data=claims)}"}
# ==================== 测试类 ====================
class TestAuthAPI:
    """Integration tests for the ``/api/v1/auth`` endpoints.

    Each test drives the FastAPI ``app`` in-process through ``ASGITransport``,
    backed by the in-memory SQLite session from the ``async_session`` fixture.
    """

    @pytest.fixture(autouse=True)
    def _reset_dependency_overrides(self):
        """Clear ``app.dependency_overrides`` after every test, pass or fail.

        The tests install overrides on the module-global ``app``. Clearing them
        only as the last statement of a test would leak them into later tests
        whenever an assertion failed first.
        """
        yield
        app.dependency_overrides.clear()

    @staticmethod
    def _client(session):
        """Route ``get_db`` to *session* and return an unopened test client.

        Use as ``async with self._client(session) as client: ...``. This helper
        overrides only ``get_db``; the real ``get_current_user`` dependency
        stays in effect unless a fixture such as ``async_client`` replaced it.
        """
        async def override_get_db():
            yield session

        app.dependency_overrides[get_db] = override_get_db
        return AsyncClient(transport=ASGITransport(app=app), base_url="http://test")

    @staticmethod
    async def _create_user(session, email, password):
        """Insert and commit an active, email-verified free-plan user."""
        user = User(
            id=uuid.uuid4(),
            email=email,
            password_hash=hash_password(password),
            name="Test User",
            plan="free",
            max_queries=5,
            is_active=True,
            email_verified=True,
        )
        session.add(user)
        await session.commit()
        return user

    @pytest.mark.asyncio
    async def test_register_success(self, async_session):
        """Registering a fresh email returns 201 and echoes the new profile."""
        email = f"test_{uuid.uuid4()}@example.com"
        async with self._client(async_session) as client:
            response = await client.post(
                "/api/v1/auth/register",
                json={
                    "email": email,
                    "name": "Test User",
                    "password": "Test@123456",
                },
            )
        assert response.status_code == 201
        data = response.json()
        assert "id" in data
        assert data["email"] == email
        assert data["name"] == "Test User"

    @pytest.mark.asyncio
    async def test_register_duplicate_email(self, async_session):
        """Registering the same email twice fails the second time with 400."""
        email = f"test_{uuid.uuid4()}@example.com"
        async with self._client(async_session) as client:
            # First registration succeeds.
            response1 = await client.post(
                "/api/v1/auth/register",
                json={
                    "email": email,
                    "name": "Test User 1",
                    "password": "Test@123456",
                },
            )
            assert response1.status_code == 201
            # Second registration with the same email is rejected.
            response2 = await client.post(
                "/api/v1/auth/register",
                json={
                    "email": email,
                    "name": "Test User 2",
                    "password": "Test@123456",
                },
            )
        assert response2.status_code == 400

    @pytest.mark.asyncio
    async def test_login_success(self, async_session):
        """Valid credentials return an access/refresh token pair and the user."""
        email = f"test_{uuid.uuid4()}@example.com"
        password = "Test@123456"
        await self._create_user(async_session, email, password)
        async with self._client(async_session) as client:
            response = await client.post(
                "/api/v1/auth/login",
                json={"email": email, "password": password},
            )
        assert response.status_code == 200
        data = response.json()
        assert "access_token" in data
        assert "refresh_token" in data
        assert data["token_type"] == "bearer"
        assert "user" in data

    @pytest.mark.asyncio
    async def test_login_wrong_password(self, async_session):
        """An incorrect password is rejected with 401."""
        email = f"test_{uuid.uuid4()}@example.com"
        await self._create_user(async_session, email, "Correct@123456")
        async with self._client(async_session) as client:
            response = await client.post(
                "/api/v1/auth/login",
                json={"email": email, "password": "WrongPassword"},
            )
        assert response.status_code == 401

    @pytest.mark.asyncio
    async def test_login_nonexistent_user(self, async_session):
        """Logging in with an unknown email is rejected."""
        async with self._client(async_session) as client:
            response = await client.post(
                "/api/v1/auth/login",
                json={
                    "email": "nonexistent@example.com",
                    "password": "Test@123456",
                },
            )
        # The API should not distinguish "unknown user" from "wrong password".
        # 401 (authentication failed) and 429 (rate limited) are both accepted.
        assert response.status_code in (401, 429)

    @pytest.mark.asyncio
    async def test_get_current_user(self, async_client, test_user):
        """``/me`` returns the authenticated user's id and email."""
        response = await async_client.get("/api/v1/auth/me")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == str(test_user.id)
        assert data["email"] == test_user.email

    @pytest.mark.asyncio
    async def test_change_password_success(self, async_client, auth_headers):
        """Changing the password with the correct old password succeeds."""
        # async_client already overrides get_db and get_current_user, so no
        # extra client or override is needed here.
        response = await async_client.put(
            "/api/v1/auth/change-password",
            headers=auth_headers,
            json={
                "old_password": "Test@123456",
                "new_password": "NewPass@123456",
            },
        )
        assert response.status_code == 200
        data = response.json()
        assert data["message"] == "密码修改成功"

    @pytest.mark.asyncio
    async def test_change_password_wrong_old(self, async_client, auth_headers):
        """Supplying the wrong old password is rejected with 400."""
        response = await async_client.put(
            "/api/v1/auth/change-password",
            headers=auth_headers,
            json={
                "old_password": "WrongOldPass",
                "new_password": "NewPass@123456",
            },
        )
        assert response.status_code == 400
        assert "旧密码错误" in response.json()["detail"]

    @pytest.mark.asyncio
    async def test_update_profile(self, async_client, auth_headers):
        """Updating the profile returns the new name."""
        response = await async_client.put(
            "/api/v1/auth/profile",
            headers=auth_headers,
            json={
                "name": "Updated Name",
                "avatar_url": "https://example.com/avatar.png",
            },
        )
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "Updated Name"

    @pytest.mark.asyncio
    async def test_refresh_token(self, async_session, test_user):
        """A valid refresh token yields a new access/refresh token pair."""
        from app.services.auth import create_refresh_token

        refresh_token = create_refresh_token(data={"sub": str(test_user.id)})
        async with self._client(async_session) as client:
            response = await client.post(
                "/api/v1/auth/refresh",
                json={"refresh_token": refresh_token},
            )
        assert response.status_code == 200
        data = response.json()
        assert "access_token" in data
        assert "refresh_token" in data

    @pytest.mark.asyncio
    async def test_forgot_password(self, async_session):
        """A forgot-password request is acknowledged with 200."""
        # test_user is not requested, so this email is not in the database.
        async with self._client(async_session) as client:
            response = await client.post(
                "/api/v1/auth/forgot-password",
                json={"email": "test@example.com"},
            )
        # Success is returned whether or not the email exists, to prevent
        # user enumeration.
        assert response.status_code == 200
        assert "message" in response.json()

    @pytest.mark.asyncio
    async def test_resend_verification(self, async_session):
        """A resend-verification request is acknowledged with 200."""
        # test_user is not requested, so this email is not in the database.
        async with self._client(async_session) as client:
            response = await client.post(
                "/api/v1/auth/resend-verification",
                json={"email": "test@example.com"},
            )
        assert response.status_code == 200
        assert "message" in response.json()

    @pytest.mark.asyncio
    async def test_unauthorized_access(self, async_session):
        """A protected endpoint without a token returns 401."""
        async with self._client(async_session) as client:
            response = await client.get("/api/v1/auth/me")
        assert response.status_code == 401

    @pytest.mark.asyncio
    async def test_invalid_token(self, async_session):
        """A malformed bearer token is rejected with 401."""
        async with self._client(async_session) as client:
            response = await client.get(
                "/api/v1/auth/me",
                headers={"Authorization": "Bearer invalid_token"},
            )
        assert response.status_code == 401