geo/backend/app/services/auth.py

201 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import random
import secrets
import uuid
from datetime import datetime, timedelta, timezone

import bcrypt
from jose import jwt, JWTError
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.models.user import User
from app.schemas.auth import UserRegister, UpdateProfileRequest
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
def hash_password(password: str) -> str:
    """Hash a plaintext password with bcrypt and return the hash as text.

    A new salt is produced by ``bcrypt.gensalt()`` on every call.
    """
    raw = password.encode()
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(raw, salt)
    return digest.decode()
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return the result of checking ``plain_password`` against a bcrypt hash.

    Both arguments are encoded to bytes before being handed to
    ``bcrypt.checkpw``.
    """
    candidate = plain_password.encode()
    stored = hashed_password.encode()
    return bcrypt.checkpw(candidate, stored)
def create_access_token(data: dict) -> str:
    """Build a signed access JWT from ``data``.

    The claims are assembled in a new dict, so the caller's mapping is not
    mutated. Two claims are added on top of ``data``: ``exp`` (one hour from
    now) and ``type`` set to ``"access"``.

    Args:
        data: Claims to embed in the token.

    Returns:
        The token encoded with ``settings.JWT_SECRET`` using HS256.
    """
    # Access-token lifetime is fixed at 1 hour (replaces the former
    # JWT_EXPIRE_HOURS=24h setting).
    # A timezone-aware UTC timestamp is used because datetime.utcnow() is
    # deprecated and returns a naive value.
    expire = datetime.now(timezone.utc) + timedelta(hours=1)
    claims = {**data, "exp": expire, "type": "access"}
    return jwt.encode(claims, settings.JWT_SECRET, algorithm="HS256")
def create_refresh_token(data: dict) -> str:
    """Build a signed refresh JWT from ``data`` that expires in 7 days.

    The claims are assembled in a new dict, so the caller's mapping is not
    mutated. The ``type`` claim is set to ``"refresh"`` to distinguish the
    token from an access token.

    Args:
        data: Claims to embed in the token.

    Returns:
        The token encoded with ``settings.JWT_SECRET`` using HS256.
    """
    # A timezone-aware UTC timestamp is used because datetime.utcnow() is
    # deprecated and returns a naive value.
    expire = datetime.now(timezone.utc) + timedelta(days=7)
    claims = {**data, "exp": expire, "type": "refresh"}
    return jwt.encode(claims, settings.JWT_SECRET, algorithm="HS256")
def verify_refresh_token(token: str) -> dict:
    """Decode a refresh token and return its payload.

    Raises:
        ValueError: If decoding raises ``JWTError``, or if the payload's
            ``type`` claim is not ``"refresh"``.
    """
    try:
        claims = jwt.decode(token, settings.JWT_SECRET, algorithms=["HS256"])
    except JWTError:
        raise ValueError("刷新令牌无效")

    if claims.get("type") == "refresh":
        return claims
    raise ValueError("令牌类型错误")
def verify_token(token: str) -> dict:
    """Decode an access token and return its payload.

    Raises:
        ValueError: If decoding raises ``JWTError``, or if the payload's
            ``type`` claim is present and is not ``"access"``.
    """
    try:
        claims = jwt.decode(token, settings.JWT_SECRET, algorithms=["HS256"])
    except JWTError:
        raise ValueError("访问令牌无效")

    token_type = claims.get("type")
    # A missing type claim is accepted for compatibility with older tokens.
    if token_type is None or token_type == "access":
        return claims
    raise ValueError("令牌类型错误")
async def register_user(db: AsyncSession, user_data: UserRegister) -> User:
    """Create and persist a new user from the registration payload.

    Args:
        db: Async session used for the lookup and the insert.
        user_data: Registration data; ``email``, ``password`` and ``name``
            are read from it.

    Returns:
        The newly committed and refreshed ``User``.

    Raises:
        ValueError: If a user with the same email already exists.
    """
    lookup = select(User).where(User.email == user_data.email)
    duplicate = (await db.execute(lookup)).scalar_one_or_none()
    if duplicate:
        raise ValueError("邮箱已被注册")

    # Only the bcrypt hash is stored, never the plaintext password.
    new_user = User(
        email=user_data.email,
        password_hash=hash_password(user_data.password),
        name=user_data.name,
    )
    db.add(new_user)
    await db.commit()
    await db.refresh(new_user)
    return new_user
async def authenticate_user(
    db: AsyncSession, email: str, password: str
) -> User | None:
    """Look up a user by email and check the supplied password.

    Returns:
        The matching ``User`` when the password verifies against the stored
        hash; ``None`` when no user has that email or the password is wrong.
    """
    query = select(User).where(User.email == email)
    account = (await db.execute(query)).scalar_one_or_none()
    if account and verify_password(password, account.password_hash):
        return account
    return None
async def send_verification_code(db: AsyncSession, email: str) -> None:
    """Generate a 6-digit verification code and store it on the user record.

    The code and a 10-minute expiry are written to the user row and
    committed. No email is actually sent: the code is written to the log as a
    stand-in for mail delivery. If no user has the given email, the function
    returns without doing anything.

    Args:
        db: Async session used for the lookup and the update.
        email: Address of the user who should receive the code.
    """
    stmt = select(User).where(User.email == email)
    result = await db.execute(stmt)
    user = result.scalar_one_or_none()
    if not user:
        return

    # The code is a security credential, so draw it from the CSPRNG-backed
    # `secrets` module rather than `random`. Range stays 100000..999999.
    code = f"{secrets.randbelow(900_000) + 100_000}"
    user.verification_code = code
    # Naive UTC is kept here on purpose: the stored value is later compared
    # against datetime.utcnow() in verify_email.
    user.verification_code_expires = datetime.utcnow() + timedelta(minutes=10)
    await db.commit()

    # NOTE(review): simulated email; this writes the code to the log in
    # plaintext and should be replaced by real delivery before production.
    logger.info(f"[模拟邮件] 邮箱验证码发送到 {email}: {code}")
async def verify_email(db: AsyncSession, email: str, code: str) -> bool:
    """Check a verification code and mark the email as verified on success.

    Returns:
        ``True`` when the user exists, the code matches and has not expired;
        in that case ``email_verified`` is set and the stored code and expiry
        are cleared. ``False`` otherwise, with nothing modified.
    """
    query = select(User).where(User.email == email)
    account = (await db.execute(query)).scalar_one_or_none()
    if not account or account.verification_code != code:
        return False

    # A missing expiry is treated the same as an expired code.
    deadline = account.verification_code_expires
    if deadline is None or deadline < datetime.utcnow():
        return False

    account.email_verified = True
    # Clear the code so it cannot be reused.
    account.verification_code = None
    account.verification_code_expires = None
    await db.commit()
    return True
async def send_reset_link(db: AsyncSession, email: str) -> None:
    """Create a password-reset token for the user and log the reset link.

    A UUID4 string and a one-hour expiry are stored on the user record and
    committed. No email is sent; the link is written to the log as a
    stand-in. If no user has the given email, nothing happens.
    """
    query = select(User).where(User.email == email)
    account = (await db.execute(query)).scalar_one_or_none()
    if not account:
        return

    token = str(uuid.uuid4())
    account.reset_token = token
    account.reset_token_expires = datetime.utcnow() + timedelta(hours=1)
    await db.commit()

    # Simulated email: the link is only logged.
    logger.info(f"[模拟邮件] 密码重置链接: http://localhost:3000/reset-password?token={token}")
async def reset_password(db: AsyncSession, token: str, new_password: str) -> bool:
    """Validate a reset token and replace the user's password.

    Args:
        db: Async session used for the lookup and the update.
        token: Reset token previously stored on the user record.
        new_password: Plaintext password to hash and store.

    Returns:
        ``True`` when the token matched an unexpired reset request and the
        password was updated; ``False`` when the token is empty, unknown, or
        expired.
    """
    # Reject empty/None tokens up front. A None token would turn the filter
    # below into an IS NULL match on every user without a pending reset, and
    # scalar_one_or_none() raises when more than one row matches.
    if not token:
        return False

    stmt = select(User).where(User.reset_token == token)
    result = await db.execute(stmt)
    user = result.scalar_one_or_none()
    if not user:
        return False
    # A missing expiry is treated the same as an expired token.
    if user.reset_token_expires is None or user.reset_token_expires < datetime.utcnow():
        return False

    user.password_hash = hash_password(new_password)
    # Clear the token so it cannot be reused.
    user.reset_token = None
    user.reset_token_expires = None
    await db.commit()
    return True
async def change_password(db: AsyncSession, user_id: uuid.UUID, old_password: str, new_password: str) -> bool:
    """Replace a user's password after verifying the current one.

    Returns:
        ``True`` when the user exists, ``old_password`` verifies against the
        stored hash, and the new hash was committed; ``False`` otherwise.
    """
    query = select(User).where(User.id == user_id)
    account = (await db.execute(query)).scalar_one_or_none()
    if not account or not verify_password(old_password, account.password_hash):
        return False

    account.password_hash = hash_password(new_password)
    await db.commit()
    return True
async def update_profile(db: AsyncSession, user_id: uuid.UUID, data: UpdateProfileRequest) -> User | None:
    """Update a user's profile fields (``name`` and ``avatar_url``).

    Only fields whose value in ``data`` is not ``None`` are written.

    Returns:
        The committed and refreshed ``User``, or ``None`` when no user has
        the given id.
    """
    query = select(User).where(User.id == user_id)
    account = (await db.execute(query)).scalar_one_or_none()
    if not account:
        return None

    # None means "leave this field unchanged".
    for field in ("name", "avatar_url"):
        value = getattr(data, field)
        if value is not None:
            setattr(account, field, value)

    await db.commit()
    await db.refresh(account)
    return account