175 lines
5.2 KiB
Python
175 lines
5.2 KiB
Python
import logging
|
||
import random
|
||
import uuid
|
||
from datetime import datetime, timedelta
|
||
|
||
from jose import jwt, JWTError
|
||
from passlib.context import CryptContext
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.config import settings
|
||
from app.models.user import User
|
||
from app.schemas.auth import UserRegister, UpdateProfileRequest
|
||
|
||
# Module logger; the simulated e-mail functions below write their output here.
logger = logging.getLogger(name=__name__)

# Shared passlib hashing context. New hashes are produced with bcrypt;
# deprecated="auto" treats every scheme other than the default as deprecated.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
|
||
|
||
|
||
def hash_password(password: str) -> str:
    """Hash *password* with the module-wide passlib context and return the hash string."""
    hashed = pwd_context.hash(password)
    return hashed
|
||
|
||
|
||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plaintext password against a stored hash.

    Returns True when *plain_password* matches *hashed_password* according to
    the module-wide passlib context, False otherwise.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
|
||
|
||
|
||
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Build a signed HS256 JWT carrying *data* plus an ``exp`` claim.

    Args:
        data: Claims to embed. The dict is copied first, so the caller's
            object is never mutated. An existing ``exp`` key is overwritten.
        expires_delta: Optional token lifetime. When omitted, the lifetime is
            ``settings.JWT_EXPIRE_HOURS`` hours (the previous fixed behaviour).

    Returns:
        The encoded token string, signed with ``settings.JWT_SECRET``.
    """
    from datetime import timezone

    if expires_delta is None:
        expires_delta = timedelta(hours=settings.JWT_EXPIRE_HOURS)

    claims = data.copy()
    # Timezone-aware "now" replaces the deprecated datetime.utcnow().
    # python-jose converts datetime claims with utctimetuple(), so the
    # numeric exp written into the token is the same UTC instant as before.
    claims["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(claims, settings.JWT_SECRET, algorithm="HS256")
|
||
|
||
|
||
def verify_token(token: str) -> dict:
    """Decode an HS256 JWT signed with ``settings.JWT_SECRET`` and return its claims.

    Nothing is caught here: any error raised by ``jwt.decode`` (presumably a
    python-jose ``JWTError`` subclass for bad signatures or expiry; confirm
    against callers) propagates to the caller unchanged.
    """
    return jwt.decode(token, settings.JWT_SECRET, algorithms=["HS256"])
|
||
|
||
|
||
async def register_user(db: AsyncSession, user_data: UserRegister) -> User:
    """Create, commit and return a new user account.

    The password is stored only as a hash. The returned ``User`` is refreshed
    from the database after the commit.

    Raises:
        ValueError: if a user with the same e-mail already exists
            (message text means "email already registered").
    """
    lookup = await db.execute(select(User).where(User.email == user_data.email))
    # NOTE(review): check-then-insert is not atomic; two concurrent requests
    # can both pass this check. Only a DB unique constraint on email (not
    # visible here; confirm) would prevent the duplicate.
    if lookup.scalar_one_or_none():
        raise ValueError("邮箱已被注册")

    new_user = User(
        email=user_data.email,
        password_hash=hash_password(user_data.password),
        name=user_data.name,
    )
    db.add(new_user)
    await db.commit()
    await db.refresh(new_user)
    return new_user
|
||
|
||
|
||
async def authenticate_user(
    db: AsyncSession, email: str, password: str
) -> User | None:
    """Return the user matching *email* and *password*, or None.

    An unknown e-mail and a wrong password both yield None. Both paths also
    perform comparable hashing work, so response latency does not reveal
    which case occurred.
    """
    result = await db.execute(select(User).where(User.email == email))
    user = result.scalar_one_or_none()

    if not user:
        # Burn roughly one real hash-verification's worth of time. Otherwise
        # an unknown e-mail returns much faster than a wrong password, which
        # lets an attacker enumerate registered accounts by timing.
        pwd_context.dummy_verify()
        return None

    if not verify_password(password, user.password_hash):
        return None

    return user
|
||
|
||
|
||
async def send_verification_code(db: AsyncSession, email: str) -> None:
    """Generate a 6-digit verification code, store it on the user, and log it.

    The log line stands in for sending an e-mail (simulated delivery). The
    code expires 10 minutes after generation. An unknown e-mail is silently
    ignored.
    """
    import secrets

    stmt = select(User).where(User.email == email)
    result = await db.execute(stmt)
    user = result.scalar_one_or_none()
    if not user:
        return

    # Use the CSPRNG in `secrets`, not `random`: Mersenne Twister output is
    # predictable and must not back an authentication code.
    # randbelow(900000) + 100000 keeps the original range 100000..999999.
    code = f"{secrets.randbelow(900000) + 100000}"
    user.verification_code = code
    user.verification_code_expires = datetime.utcnow() + timedelta(minutes=10)
    await db.commit()
    # NOTE(review): this writes the live code to the INFO log; acceptable only
    # while e-mail delivery is simulated.
    logger.info(f"[模拟邮件] 邮箱验证码发送到 {email}: {code}")
|
||
|
||
|
||
async def verify_email(db: AsyncSession, email: str, code: str) -> bool:
    """Check an e-mail verification code and mark the e-mail verified on success.

    Returns False when:

    - the user does not exist,
    - no code is pending,
    - *code* is empty or does not match,
    - the stored code has expired or has no expiry recorded.

    On success the code and its expiry are cleared (single use),
    ``email_verified`` is set to True, the change is committed, and True is
    returned.
    """
    import secrets

    stmt = select(User).where(User.email == email)
    result = await db.execute(stmt)
    user = result.scalar_one_or_none()
    if not user:
        return False

    stored = user.verification_code
    # Constant-time comparison, so timing does not leak how much of the code
    # matched. Bytes are compared because compare_digest raises TypeError
    # for non-ASCII str arguments.
    if not stored or not code or not secrets.compare_digest(
        stored.encode("utf-8"), code.encode("utf-8")
    ):
        return False
    # NOTE(review): there is no failed-attempt counter, so a 6-digit code can
    # be brute-forced within its window unless rate limiting exists upstream.
    # Confirm.

    # Expiry is stored as naive UTC (see send_verification_code), so it is
    # compared against naive utcnow().
    if user.verification_code_expires is None or user.verification_code_expires < datetime.utcnow():
        return False

    user.email_verified = True
    user.verification_code = None
    user.verification_code_expires = None
    await db.commit()
    return True
|
||
|
||
|
||
async def send_reset_link(db: AsyncSession, email: str) -> None:
    """Issue a one-hour password-reset token for *email* and log the reset URL.

    The token is a random UUID4 string stored on the user row. The log line
    simulates e-mail delivery. An unknown e-mail returns without doing
    anything.
    """
    found = (await db.execute(select(User).where(User.email == email))).scalar_one_or_none()
    if not found:
        return

    token = str(uuid.uuid4())
    found.reset_token = token
    found.reset_token_expires = datetime.utcnow() + timedelta(hours=1)
    await db.commit()
    logger.info(f"[模拟邮件] 密码重置链接: http://localhost:3000/reset-password?token={token}")
|
||
|
||
|
||
async def reset_password(db: AsyncSession, token: str, new_password: str) -> bool:
    """Set a new password for the user holding a valid reset token.

    Returns False if no user holds *token*, or if the token has no recorded
    expiry or is past it. Otherwise hashes *new_password* and clears the
    token so it cannot be reused. It then commits and returns True.
    """
    query = select(User).where(User.reset_token == token)
    owner = (await db.execute(query)).scalar_one_or_none()
    if not owner:
        return False

    expires_at = owner.reset_token_expires
    still_valid = expires_at is not None and not expires_at < datetime.utcnow()
    if not still_valid:
        return False

    owner.password_hash = hash_password(new_password)
    owner.reset_token = None
    owner.reset_token_expires = None
    await db.commit()
    return True
|
||
|
||
|
||
async def change_password(db: AsyncSession, user_id: uuid.UUID, old_password: str, new_password: str) -> bool:
    """Replace a user's password after confirming the current one.

    Returns False when *user_id* does not exist or *old_password* fails
    verification. Otherwise stores the hash of *new_password*, commits, and
    returns True.
    """
    account = (await db.execute(select(User).where(User.id == user_id))).scalar_one_or_none()
    # Short-circuit: the password is only verified when the account exists.
    if not account or not verify_password(old_password, account.password_hash):
        return False

    account.password_hash = hash_password(new_password)
    await db.commit()
    return True
|
||
|
||
|
||
async def update_profile(db: AsyncSession, user_id: uuid.UUID, data: UpdateProfileRequest) -> User | None:
    """Apply a partial profile update (``name`` and/or ``avatar_url``).

    Fields that are None in *data* are left unchanged. Returns the user
    refreshed after the commit, or None when *user_id* does not exist.
    """
    result = await db.execute(select(User).where(User.id == user_id))
    profile_owner = result.scalar_one_or_none()
    if not profile_owner:
        return None

    # Copy only the fields the caller actually supplied, in a fixed order.
    for field_name in ("name", "avatar_url"):
        new_value = getattr(data, field_name)
        if new_value is not None:
            setattr(profile_owner, field_name, new_value)

    await db.commit()
    await db.refresh(profile_owner)
    return profile_owner
|