geo/backend/app/services/auth.py

175 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import random
import secrets
import uuid
from datetime import datetime, timedelta

from jose import jwt, JWTError
from passlib.context import CryptContext
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.models.user import User
from app.schemas.auth import UserRegister, UpdateProfileRequest
# Module logger, named after this module so logging config can target it.
logger = logging.getLogger(__name__)
# Shared passlib context used by hash_password / verify_password.
# bcrypt is the only configured scheme. deprecated="auto" deprecates every configured
# scheme except the default: a no-op today, but it lets old hashes be flagged for
# upgrade if more schemes are added later.
pwd_context = CryptContext(
    schemes=["bcrypt"],
    deprecated="auto",
)
def hash_password(password: str) -> str:
    """Hash *password* with the module's bcrypt ``pwd_context`` and return the hash string.

    NOTE(review): bcrypt only considers the first 72 bytes of its input; consider
    enforcing a maximum password length upstream.
    """
    digest = pwd_context.hash(password)
    return digest
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if *plain_password* matches *hashed_password*, else False.

    Delegates to the module's ``pwd_context``. NOTE(review): passlib may raise
    (e.g. ValueError) rather than return False if the stored hash is in an
    unrecognised format.
    """
    matches = pwd_context.verify(plain_password, hashed_password)
    return matches
def create_access_token(data: dict) -> str:
    """Sign *data* as an HS256 JWT that expires ``settings.JWT_EXPIRE_HOURS`` from now.

    The caller's dict is not modified: an ``exp`` claim (naive UTC datetime) is added
    to a shallow copy, which is then signed with ``settings.JWT_SECRET``.
    """
    lifetime = timedelta(hours=settings.JWT_EXPIRE_HOURS)
    claims = data.copy()
    claims["exp"] = datetime.utcnow() + lifetime
    return jwt.encode(claims, settings.JWT_SECRET, algorithm="HS256")
def verify_token(token: str) -> dict:
    """Decode an HS256 JWT signed with ``settings.JWT_SECRET`` and return its claims.

    Errors are not caught here. python-jose raises ``JWTError`` (or a subclass such as
    ``ExpiredSignatureError``) for a bad signature or an expired token, so callers must
    handle it.
    """
    return jwt.decode(token, settings.JWT_SECRET, algorithms=["HS256"])
async def _email_registered(db: AsyncSession, email: str) -> bool:
    """Return True if a user row with *email* already exists."""
    result = await db.execute(select(User).where(User.email == email))
    return result.scalar_one_or_none() is not None


async def register_user(db: AsyncSession, user_data: UserRegister) -> User:
    """Create and persist a new user account.

    Args:
        db: Async SQLAlchemy session; this function commits it.
        user_data: Registration payload; ``email``, ``password`` and ``name`` are read.

    Returns:
        The newly stored ``User``, refreshed from the database.

    Raises:
        ValueError: "邮箱已被注册" ("email already registered"). This is raised when
            the email is already taken, either at the up-front check or because a
            concurrent registration committed the same email first.
    """
    if await _email_registered(db, user_data.email):
        raise ValueError("邮箱已被注册")
    user = User(
        email=user_data.email,
        password_hash=hash_password(user_data.password),
        name=user_data.name,
    )
    db.add(user)
    try:
        await db.commit()
    except IntegrityError as err:
        # The lookup above and this insert are not atomic: another request may have
        # registered the same email in between. Roll back the failed transaction and,
        # if that is what happened, report it exactly like the up-front check instead
        # of leaking a raw database error. Any other integrity violation is re-raised.
        # NOTE(review): this only triggers if the DB enforces a unique constraint on
        # email. Confirm the User model declares one, or duplicates can slip through.
        await db.rollback()
        if await _email_registered(db, user_data.email):
            raise ValueError("邮箱已被注册") from err
        raise
    await db.refresh(user)
    return user
async def authenticate_user(
    db: AsyncSession, email: str, password: str
) -> User | None:
    """Return the user whose email and password match, or None.

    Args:
        db: Async SQLAlchemy session (read-only use here).
        email: Login email, matched exactly against ``User.email``.
        password: Plain-text password to check against the stored bcrypt hash.

    Returns:
        The matching ``User``. Returns None if no user has *email* or the password
        is wrong; both cases look the same to the caller.
    """
    stmt = select(User).where(User.email == email)
    result = await db.execute(stmt)
    user = result.scalar_one_or_none()
    if not user:
        # Burn roughly one bcrypt verification's worth of time so the response
        # latency does not reveal whether the email is registered.
        pwd_context.dummy_verify()
        return None
    if not verify_password(password, user.password_hash):
        return None
    return user
async def send_verification_code(db: AsyncSession, email: str) -> None:
    """Generate a 6-digit verification code for *email* and simulate emailing it.

    The code is stored on the user record with a 10-minute expiry and the session is
    committed. The code is then written to the log as a simulated email; no real mail
    is sent. If no user has *email*, the function returns without doing anything, so
    the caller sees the same result either way.

    NOTE(review): the code is logged at INFO level. That is acceptable for a
    development mail stub, but it must not reach production logs.
    """
    stmt = select(User).where(User.email == email)
    result = await db.execute(stmt)
    user = result.scalar_one_or_none()
    if not user:
        return
    # Use a CSPRNG: `random` (Mersenne Twister) output is predictable and unsuitable
    # for security codes. The range stays 100000-999999, so codes are always 6 digits.
    code = str(100000 + secrets.randbelow(900000))
    user.verification_code = code
    user.verification_code_expires = datetime.utcnow() + timedelta(minutes=10)
    await db.commit()
    logger.info("[模拟邮件] 邮箱验证码发送到 %s: %s", email, code)
async def verify_email(db: AsyncSession, email: str, code: str) -> bool:
    """Mark the user's email as verified when *code* matches and has not expired.

    Returns True on success: ``email_verified`` is set, the stored code and its expiry
    are cleared, and the session is committed. Returns False, without writing anything,
    in any of these cases:
      - no user has *email*;
      - the code does not match;
      - no expiry is stored, or the expiry is in the past.

    NOTE(review): there is no attempt limit here. A 6-digit code can be brute-forced
    within its lifetime unless requests are rate-limited elsewhere; confirm.
    """
    lookup = await db.execute(select(User).where(User.email == email))
    account = lookup.scalar_one_or_none()
    if not account:
        return False
    if account.verification_code != code:
        return False
    expires_at = account.verification_code_expires
    if expires_at is None or expires_at < datetime.utcnow():
        return False
    account.email_verified = True
    account.verification_code = None
    account.verification_code_expires = None
    await db.commit()
    return True
async def send_reset_link(db: AsyncSession, email: str) -> None:
    """Issue a one-hour password-reset token for *email* and log the reset link.

    A random UUID4 string is stored as ``reset_token`` with an expiry one hour ahead,
    and the session is committed. The link is then logged as a simulated email. If no
    user has *email*, the function returns silently.

    NOTE(review): the link's host is hard-coded to http://localhost:3000, and the
    token is written to the log. Both are only suitable for development.
    """
    lookup = await db.execute(select(User).where(User.email == email))
    account = lookup.scalar_one_or_none()
    if not account:
        return
    token = str(uuid.uuid4())
    account.reset_token = token
    account.reset_token_expires = datetime.utcnow() + timedelta(hours=1)
    await db.commit()
    logger.info(f"[模拟邮件] 密码重置链接: http://localhost:3000/reset-password?token={token}")
async def reset_password(db: AsyncSession, token: str, new_password: str) -> bool:
    """Set a new password for the user holding *token*, if the token is still valid.

    Returns True after hashing and storing *new_password*, clearing the token and its
    expiry, and committing. Returns False, without writing anything, when no user holds
    *token*, or when the token has no expiry recorded or has expired.
    """
    lookup = await db.execute(select(User).where(User.reset_token == token))
    account = lookup.scalar_one_or_none()
    if not account:
        return False
    expires_at = account.reset_token_expires
    token_expired = expires_at is None or expires_at < datetime.utcnow()
    if token_expired:
        return False
    account.password_hash = hash_password(new_password)
    account.reset_token = None
    account.reset_token_expires = None
    await db.commit()
    return True
async def change_password(
    db: AsyncSession, user_id: uuid.UUID, old_password: str, new_password: str
) -> bool:
    """Replace a user's password after confirming the current one.

    Args:
        db: Async SQLAlchemy session; committed on success.
        user_id: Primary key of the user.
        old_password: Current plain-text password; must verify against the stored hash.
        new_password: New plain-text password; stored as a bcrypt hash.

    Returns:
        True if the password was changed. False, with nothing written, if the user
        does not exist or *old_password* is wrong.
    """
    stmt = select(User).where(User.id == user_id)
    result = await db.execute(stmt)
    user = result.scalar_one_or_none()
    if not user:
        return False
    if not verify_password(old_password, user.password_hash):
        return False
    user.password_hash = hash_password(new_password)
    # Invalidate any outstanding reset link. Otherwise a token issued before this
    # change could still be used, until it expires, to overwrite the password the
    # user just chose.
    user.reset_token = None
    user.reset_token_expires = None
    await db.commit()
    return True
async def update_profile(
    db: AsyncSession, user_id: uuid.UUID, data: UpdateProfileRequest
) -> User | None:
    """Apply the non-None profile fields of *data* (``name``, ``avatar_url``) to a user.

    Fields that are None in *data* are left untouched, so this cannot clear a value.
    Returns the user refreshed after commit, or None if no user has *user_id*.
    """
    query = select(User).where(User.id == user_id)
    account = (await db.execute(query)).scalar_one_or_none()
    if not account:
        return None
    for field_name in ("name", "avatar_url"):
        new_value = getattr(data, field_name)
        if new_value is not None:
            setattr(account, field_name, new_value)
    await db.commit()
    await db.refresh(account)
    return account