"""Tests for the authentication API (``/api/v1/auth``)."""
import uuid

import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
from sqlalchemy.pool import StaticPool

from app.database import Base
from app.main import app
from app.models.user import User
from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password, create_access_token

# Password that satisfies the registration rules and is used for most fixtures.
TEST_PASSWORD = "Test@123456"


# ==================== Helpers ====================

def _unique_email() -> str:
    """Return a fresh address so users created by different tests never collide."""
    return f"test_{uuid.uuid4()}@example.com"


async def _create_user(
    session: AsyncSession,
    email: str,
    password: str,
    name: str = "Test User",
) -> User:
    """Persist an active, email-verified free-plan user and return it refreshed."""
    user = User(
        id=uuid.uuid4(),
        email=email,
        password_hash=hash_password(password),
        name=name,
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
    )
    session.add(user)
    await session.commit()
    await session.refresh(user)
    return user


# ==================== Fixtures ====================

@pytest_asyncio.fixture
async def async_engine():
    """Create an in-memory SQLite async engine with the full schema.

    StaticPool keeps a single connection alive so every session sees the same
    in-memory database.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()


@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an async session bound to the test engine."""
    async_session_maker = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )
    async with async_session_maker() as session:
        yield session


@pytest_asyncio.fixture
async def test_user(async_session):
    """Create the default test user (test@example.com / TEST_PASSWORD)."""
    return await _create_user(async_session, "test@example.com", TEST_PASSWORD)


@pytest_asyncio.fixture
async def client(async_session):
    """HTTP client whose ``get_db`` dependency uses the test session.

    Authentication is NOT overridden, so endpoints run their real auth
    dependency. Overrides are cleared in ``finally`` so a failing test cannot
    leak them into later tests.
    """
    async def override_get_db():
        yield async_session

    app.dependency_overrides[get_db] = override_get_db
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as http_client:
            yield http_client
    finally:
        app.dependency_overrides.clear()


@pytest_asyncio.fixture
async def async_client(client, test_user):
    """HTTP client for which ``get_current_user`` always resolves to ``test_user``."""
    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        yield client
    finally:
        app.dependency_overrides.pop(get_current_user, None)


@pytest.fixture
def auth_headers(test_user):
    """Build a Bearer ``Authorization`` header carrying an access token for ``test_user``."""
    token = create_access_token(data={"sub": str(test_user.id)})
    return {"Authorization": f"Bearer {token}"}


# ==================== Test class ====================

class TestAuthAPI:
    """End-to-end tests for the authentication endpoints."""

    @pytest.mark.asyncio
    async def test_register_success(self, client):
        """Registering a new email returns 201 and the created user."""
        response = await client.post(
            "/api/v1/auth/register",
            json={
                "email": _unique_email(),
                "name": "Test User",
                "password": TEST_PASSWORD,
            },
        )
        assert response.status_code == 201
        data = response.json()
        assert "id" in data
        assert data["email"] is not None
        assert data["name"] == "Test User"

    @pytest.mark.asyncio
    async def test_register_duplicate_email(self, client):
        """Registering the same email twice is rejected with 400."""
        email = _unique_email()

        # First registration succeeds.
        response1 = await client.post(
            "/api/v1/auth/register",
            json={"email": email, "name": "Test User 1", "password": TEST_PASSWORD},
        )
        assert response1.status_code == 201

        # Second registration with the same email must fail.
        response2 = await client.post(
            "/api/v1/auth/register",
            json={"email": email, "name": "Test User 2", "password": TEST_PASSWORD},
        )
        assert response2.status_code == 400

    @pytest.mark.asyncio
    async def test_login_success(self, client, async_session):
        """Valid credentials return an access/refresh token pair and the user."""
        email = _unique_email()
        await _create_user(async_session, email, TEST_PASSWORD)

        response = await client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": TEST_PASSWORD},
        )
        assert response.status_code == 200
        data = response.json()
        assert "access_token" in data
        assert "refresh_token" in data
        assert data["token_type"] == "bearer"
        assert "user" in data

    @pytest.mark.asyncio
    async def test_login_wrong_password(self, client, async_session):
        """A wrong password for an existing user yields 401."""
        email = _unique_email()
        await _create_user(async_session, email, "Correct@123456")

        response = await client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": "WrongPassword"},
        )
        assert response.status_code == 401

    @pytest.mark.asyncio
    async def test_login_nonexistent_user(self, client):
        """Logging in with an unknown email is rejected."""
        response = await client.post(
            "/api/v1/auth/login",
            json={"email": "nonexistent@example.com", "password": TEST_PASSWORD},
        )
        # The error must not reveal whether the user exists. The endpoint may
        # answer 401 (authentication failed) or 429 (rate limited); both are
        # accepted here.
        assert response.status_code in [401, 429]

    @pytest.mark.asyncio
    async def test_get_current_user(self, async_client, test_user):
        """/me returns the authenticated user's profile."""
        response = await async_client.get("/api/v1/auth/me")
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == str(test_user.id)
        assert data["email"] == test_user.email

    @pytest.mark.asyncio
    async def test_change_password_success(self, async_client, auth_headers):
        """Changing the password with the correct old password succeeds."""
        # The current user comes from the async_client override; the header is
        # sent only to mirror a real client request.
        response = await async_client.put(
            "/api/v1/auth/change-password",
            headers=auth_headers,
            json={"old_password": TEST_PASSWORD, "new_password": "NewPass@123456"},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["message"] == "密码修改成功"

    @pytest.mark.asyncio
    async def test_change_password_wrong_old(self, async_client, auth_headers):
        """A wrong old password is rejected with 400 and an explanatory detail."""
        response = await async_client.put(
            "/api/v1/auth/change-password",
            headers=auth_headers,
            json={"old_password": "WrongOldPass", "new_password": "NewPass@123456"},
        )
        assert response.status_code == 400
        assert "旧密码错误" in response.json()["detail"]

    @pytest.mark.asyncio
    async def test_update_profile(self, async_client, auth_headers):
        """Profile updates are persisted and echoed back."""
        response = await async_client.put(
            "/api/v1/auth/profile",
            headers=auth_headers,
            json={
                "name": "Updated Name",
                "avatar_url": "https://example.com/avatar.png",
            },
        )
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "Updated Name"

    @pytest.mark.asyncio
    async def test_refresh_token(self, client, test_user):
        """A valid refresh token is exchanged for a new token pair."""
        from app.services.auth import create_refresh_token

        refresh_token = create_refresh_token(data={"sub": str(test_user.id)})

        response = await client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": refresh_token},
        )
        assert response.status_code == 200
        data = response.json()
        assert "access_token" in data
        assert "refresh_token" in data

    @pytest.mark.asyncio
    async def test_forgot_password(self, client):
        """Forgot-password returns 200 with a message."""
        response = await client.post(
            "/api/v1/auth/forgot-password",
            json={"email": "test@example.com"},
        )
        # Succeeds whether or not the email exists, to prevent user enumeration.
        assert response.status_code == 200
        assert "message" in response.json()

    @pytest.mark.asyncio
    async def test_resend_verification(self, client):
        """Resending the verification code returns 200 with a message."""
        response = await client.post(
            "/api/v1/auth/resend-verification",
            json={"email": "test@example.com"},
        )
        assert response.status_code == 200
        assert "message" in response.json()

    @pytest.mark.asyncio
    async def test_unauthorized_access(self, client):
        """A protected endpoint without a token returns 401."""
        response = await client.get("/api/v1/auth/me")
        assert response.status_code == 401

    @pytest.mark.asyncio
    async def test_invalid_token(self, client):
        """A malformed bearer token is rejected with 401."""
        response = await client.get(
            "/api/v1/auth/me",
            headers={"Authorization": "Bearer invalid_token"},
        )
        assert response.status_code == 401