170 lines
5.8 KiB
Python
170 lines
5.8 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, status
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.api.deps import get_current_user
|
||
from app.database import get_db
|
||
from app.models.user import User
|
||
from app.schemas.auth import (
|
||
AccessTokenResponse,
|
||
ChangePasswordRequest,
|
||
ForgotPasswordRequest,
|
||
RefreshTokenRequest,
|
||
ResetPasswordRequest,
|
||
TokenResponse,
|
||
UpdateProfileRequest,
|
||
UserLogin,
|
||
UserRegister,
|
||
UserResponse,
|
||
VerifyEmailRequest,
|
||
)
|
||
from app.services.auth import (
|
||
authenticate_user,
|
||
change_password as change_password_service,
|
||
create_access_token,
|
||
create_refresh_token,
|
||
register_user,
|
||
reset_password as reset_password_service,
|
||
send_reset_link,
|
||
send_verification_code,
|
||
update_profile as update_profile_service,
|
||
verify_email as verify_email_service,
|
||
verify_refresh_token,
|
||
)
|
||
from app.services.cache import get_cache_service, TTL_USER_PROFILE
|
||
|
||
# Router collecting the account/authentication endpoints of this module.
# NOTE(review): presumably included by the application with an auth prefix — confirm where it is mounted.
router = APIRouter()
|
||
|
||
|
||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def register(user_data: UserRegister, db: AsyncSession = Depends(get_db)):
    """Create a new account from ``user_data`` and return the created user (HTTP 201).

    Any ``ValueError`` raised by ``register_user`` is turned into a generic
    HTTP 400 whose message only asks the caller to check whether the input
    is already in use, without saying which field caused the failure.
    """
    try:
        created = await register_user(db, user_data)
    except ValueError:
        # Deliberately vague: do not leak the concrete reason, to prevent user enumeration.
        failure_detail = "注册失败,请检查输入信息是否已被使用"
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=failure_detail)
    else:
        return created
|
||
|
||
|
||
@router.post("/login", response_model=TokenResponse)
async def login(user_data: UserLogin, db: AsyncSession = Depends(get_db)):
    """Authenticate with email and password and issue an access/refresh token pair.

    Returns a mapping with ``access_token``, ``token_type`` (always
    ``"bearer"``), ``refresh_token`` and the authenticated ``user``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate: Bearer`` header when
            ``authenticate_user`` returns a falsy value.
    """
    authenticated = await authenticate_user(db, user_data.email, user_data.password)
    if not authenticated:
        # One message for every failure, so "no such user" and "wrong
        # password" are indistinguishable (prevents user enumeration).
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="邮箱或密码错误",
            headers={"WWW-Authenticate": "Bearer"},
        )

    subject = str(authenticated.id)
    # Each token helper gets its own claims dict, in case either mutates its input.
    return {
        "access_token": create_access_token(data={"sub": subject}),
        "token_type": "bearer",
        "refresh_token": create_refresh_token(data={"sub": subject}),
        "user": authenticated,
    }
|
||
|
||
|
||
def _invalid_refresh_token_error() -> HTTPException:
    """Build the 401 returned whenever a refresh token cannot be used.

    A single constructor keeps the status, message and ``WWW-Authenticate``
    header identical on every rejection path of ``refresh_token``.
    """
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="刷新令牌无效或已过期",
        headers={"WWW-Authenticate": "Bearer"},
    )


@router.post("/refresh", response_model=AccessTokenResponse)
async def refresh_token(req: RefreshTokenRequest):
    """Exchange a refresh token for a new access token and a new refresh token.

    A fresh refresh token is issued on every call (sliding expiration).

    Raises:
        HTTPException: 401 when ``verify_refresh_token`` raises ``ValueError``,
            returns an empty/``None`` payload, or the payload has no usable
            ``sub`` claim.
    """
    try:
        payload = verify_refresh_token(req.refresh_token)
    except ValueError:
        # The verifier's reason is not exposed to the client; ``from None``
        # drops the ValueError context from the raised HTTP error.
        raise _invalid_refresh_token_error() from None

    # Guard against a falsy payload so it yields a 401 instead of an
    # AttributeError (500) on ``.get``.
    user_id = payload.get("sub") if payload else None
    if not user_id:
        raise _invalid_refresh_token_error()

    # Normalize to str so re-issued tokens carry the same "sub" type that
    # the login endpoint issues (``str(user.id)``).
    subject = str(user_id)

    # NOTE(review): the user is not re-loaded here, so a deleted or disabled
    # account can presumably keep obtaining tokens via refresh, and the
    # presented refresh token is not revoked after rotation. Confirm whether
    # verify_refresh_token or another layer handles this.
    new_access_token = create_access_token(data={"sub": subject})
    new_refresh_token = create_refresh_token(data={"sub": subject})  # sliding expiration
    return {
        "access_token": new_access_token,
        "token_type": "bearer",
        "refresh_token": new_refresh_token,
    }
|
||
|
||
|
||
@router.get("/me", response_model=UserResponse)
async def read_current_user(
    current_user: User = Depends(get_current_user),
):
    """Return the authenticated user's profile, served from the cache when present.

    The JSON-mode dump of ``UserResponse`` is cached under
    ``user:profile:{id}`` for ``TTL_USER_PROFILE``. Both the cache-hit and
    cache-miss paths return that same serialized payload.
    """
    cache = get_cache_service()
    cache_key = f"user:profile:{current_user.id}"

    # NOTE(review): only the profile-update endpoint in this module deletes
    # this key; verify-email / reset-password / change-password do not, so a
    # hit may be older than ``current_user`` for up to TTL_USER_PROFILE.
    # Confirm whether UserResponse contains fields those endpoints change.
    cached = await cache.get_json(cache_key)
    if cached is not None:
        return cached

    user_data = UserResponse.model_validate(current_user).model_dump(mode="json")
    await cache.set_json(cache_key, user_data, expire=TTL_USER_PROFILE)
    # Return exactly what was cached so a miss and a later hit produce the
    # same input for response_model validation (a problem validating the
    # dict shape then surfaces on the first request, not only on hits).
    return user_data
|
||
|
||
|
||
@router.post("/forgot-password")
async def forgot_password(req: ForgotPasswordRequest, db: AsyncSession = Depends(get_db)):
    """Ask ``send_reset_link`` to send a password-reset link to ``req.email``.

    The reply text is fixed and does not say whether the address is
    registered. NOTE(review): this only hides account existence if
    ``send_reset_link`` also behaves the same (no error, similar timing)
    for unknown emails — not visible here.
    """
    email = req.email
    await send_reset_link(db, email)
    generic_reply = {"message": "如果该邮箱已注册,重置链接已发送"}
    return generic_reply
|
||
|
||
|
||
@router.post("/reset-password")
async def reset_password(req: ResetPasswordRequest, db: AsyncSession = Depends(get_db)):
    """Set a new password using a reset token.

    Raises:
        HTTPException: 400 when ``reset_password_service`` reports failure
            (the message says the token is invalid or expired).
    """
    reset_ok = await reset_password_service(db, req.token, req.new_password)
    if reset_ok:
        return {"message": "密码重置成功"}
    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="无效的令牌或令牌已过期")
|
||
|
||
|
||
@router.post("/verify-email")
async def verify_email(req: VerifyEmailRequest, db: AsyncSession = Depends(get_db)):
    """Confirm ownership of ``req.email`` with the submitted verification code.

    Raises:
        HTTPException: 400 when ``verify_email_service`` reports failure
            (the message says the code is invalid or expired).
    """
    verified = await verify_email_service(db, req.email, req.code)
    if verified:
        return {"message": "邮箱验证成功"}
    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="验证码无效或已过期")
|
||
|
||
|
||
@router.post("/resend-verification")
async def resend_verification(req: ForgotPasswordRequest, db: AsyncSession = Depends(get_db)):
    """Send a new email-verification code to ``req.email``.

    Reuses ``ForgotPasswordRequest`` as the request body, since only the
    email field is read here.
    """
    target_email = req.email
    await send_verification_code(db, target_email)
    return dict(message="验证码已重新发送")
|
||
|
||
|
||
@router.put("/change-password")
async def change_password(
    req: ChangePasswordRequest,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Change the authenticated user's password after checking the old one.

    Raises:
        HTTPException: 400 when ``change_password_service`` reports failure;
            the message always reports a wrong old password.
    """
    changed = await change_password_service(
        db,
        user.id,
        req.old_password,
        req.new_password,
    )
    if changed:
        return {"message": "密码修改成功"}
    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="旧密码错误")
|
||
|
||
|
||
@router.put("/profile", response_model=UserResponse)
async def update_profile(
    req: UpdateProfileRequest,
    user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Apply ``req`` to the authenticated user's profile and return the updated user.

    On success the cached profile entry (``user:profile:{id}``) is deleted so
    the next ``/me`` request rebuilds it.

    Raises:
        HTTPException: 404 when ``update_profile_service`` returns a falsy value.
    """
    refreshed = await update_profile_service(db, user.id, req)
    if not refreshed:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="用户不存在")

    # Invalidate the cached profile so stale data is not served afterwards.
    await get_cache_service().delete(f"user:profile:{user.id}")
    return refreshed
|