"""Authentication helpers: password hashing, JWT handling and user account flows."""

import logging
import random  # no longer used below; kept in case other code relies on it
import secrets
import uuid
from datetime import datetime, timedelta, timezone

import bcrypt
from jose import jwt, JWTError
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.models.user import User
from app.schemas.auth import UserRegister, UpdateProfileRequest

logger = logging.getLogger(__name__)

_JWT_ALGORITHM = "HS256"
# Access tokens are valid for a fixed 1 hour (replacing the previous
# JWT_EXPIRE_HOURS=24h setting).
_ACCESS_TOKEN_LIFETIME = timedelta(hours=1)
_REFRESH_TOKEN_LIFETIME = timedelta(days=7)
_VERIFICATION_CODE_LIFETIME = timedelta(minutes=10)
_RESET_TOKEN_LIFETIME = timedelta(hours=1)


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Equivalent to the deprecated ``datetime.utcnow()``. The value stays naive
    because it is compared against the ``*_expires`` attributes of ``User``.
    NOTE(review): presumably those columns hold naive UTC values - confirm.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


def _encode_token(data: dict, lifetime: timedelta, token_type: str) -> str:
    """Sign a copy of *data* extended with ``exp`` and ``type`` claims."""
    claims = data.copy()
    claims.update({"exp": _utcnow() + lifetime, "type": token_type})
    return jwt.encode(claims, settings.JWT_SECRET, algorithm=_JWT_ALGORITHM)


def _decode_token(token: str, invalid_message: str) -> dict:
    """Decode *token* with the configured secret.

    Raises:
        ValueError: with *invalid_message* when jose rejects the token.
    """
    try:
        return jwt.decode(token, settings.JWT_SECRET, algorithms=[_JWT_ALGORITHM])
    except JWTError as exc:
        raise ValueError(invalid_message) from exc


async def _get_user_by_email(db: AsyncSession, email: str) -> User | None:
    """Return the user whose email equals *email*, or ``None``."""
    result = await db.execute(select(User).where(User.email == email))
    return result.scalar_one_or_none()


async def _get_user_by_id(db: AsyncSession, user_id: uuid.UUID) -> User | None:
    """Return the user whose id equals *user_id*, or ``None``."""
    result = await db.execute(select(User).where(User.id == user_id))
    return result.scalar_one_or_none()


def hash_password(password: str) -> str:
    """Return the bcrypt hash of *password* (fresh salt) as a string."""
    return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()


def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return whether *plain_password* matches the bcrypt *hashed_password*."""
    return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())


def create_access_token(data: dict) -> str:
    """Create a signed JWT from *data* with ``type: 'access'`` and a 1 hour expiry."""
    return _encode_token(data, _ACCESS_TOKEN_LIFETIME, "access")


def create_refresh_token(data: dict) -> str:
    """Create a signed JWT from *data* with ``type: 'refresh'`` and a 7 day expiry."""
    return _encode_token(data, _REFRESH_TOKEN_LIFETIME, "refresh")


def verify_refresh_token(token: str) -> dict:
    """Validate a refresh token and return its payload.

    Raises:
        ValueError: if the token cannot be decoded or its ``type`` claim is
            not ``'refresh'``.
    """
    payload = _decode_token(token, "刷新令牌无效")
    if payload.get("type") != "refresh":
        raise ValueError("令牌类型错误")
    return payload


def verify_token(token: str) -> dict:
    """Validate an access token and return its payload.

    Raises:
        ValueError: if the token cannot be decoded or its ``type`` claim is
            neither ``'access'`` nor absent.
    """
    payload = _decode_token(token, "访问令牌无效")
    # A missing "type" claim is accepted for compatibility with old tokens.
    if payload.get("type") not in ("access", None):
        raise ValueError("令牌类型错误")
    return payload


async def register_user(db: AsyncSession, user_data: UserRegister) -> User:
    """Create and persist a new user from *user_data*.

    Raises:
        ValueError: if a user with the same email already exists.
    """
    if await _get_user_by_email(db, user_data.email):
        raise ValueError("邮箱已被注册")

    user = User(
        email=user_data.email,
        password_hash=hash_password(user_data.password),
        name=user_data.name,
    )
    db.add(user)
    await db.commit()
    await db.refresh(user)
    return user


async def authenticate_user(
    db: AsyncSession, email: str, password: str
) -> User | None:
    """Return the user for *email* if *password* matches, otherwise ``None``."""
    user = await _get_user_by_email(db, email)
    if not user:
        return None
    if not verify_password(password, user.password_hash):
        return None
    return user


async def send_verification_code(db: AsyncSession, email: str) -> None:
    """Generate a 6-digit code, store it on the user and log it (simulated email).

    Does nothing when no user with *email* exists. The code expires after
    10 minutes.
    """
    user = await _get_user_by_email(db, email)
    if not user:
        return

    # The code gates email verification, so draw it from a CSPRNG rather than
    # the `random` module. Range is 100000..999999 inclusive, as before.
    code = f"{secrets.randbelow(900000) + 100000}"
    user.verification_code = code
    user.verification_code_expires = _utcnow() + _VERIFICATION_CODE_LIFETIME
    await db.commit()
    logger.info("[模拟邮件] 邮箱验证码发送到 %s: %s", email, code)


async def verify_email(db: AsyncSession, email: str, code: str) -> bool:
    """Check *code* for *email*; on success set ``email_verified`` to True.

    Returns ``False`` when the user does not exist, the code is empty, does not
    match, or has expired. On success the stored code and expiry are cleared.
    """
    user = await _get_user_by_email(db, email)
    if not user:
        return False

    stored_code = user.verification_code
    # Reject empty/missing codes up front so that a blank submission can never
    # match a blank stored value.
    if not code or not stored_code:
        return False
    # Constant-time comparison; bytes are used because compare_digest only
    # accepts ASCII-only str arguments.
    if not secrets.compare_digest(stored_code.encode(), code.encode()):
        return False

    expires_at = user.verification_code_expires
    if expires_at is None or expires_at < _utcnow():
        return False

    user.email_verified = True
    user.verification_code = None
    user.verification_code_expires = None
    await db.commit()
    return True


async def send_reset_link(db: AsyncSession, email: str) -> None:
    """Generate a UUID token, store it on the user and log the reset link.

    Does nothing when no user with *email* exists. The token expires after
    1 hour.
    """
    user = await _get_user_by_email(db, email)
    if not user:
        return

    token = str(uuid.uuid4())
    user.reset_token = token
    user.reset_token_expires = _utcnow() + _RESET_TOKEN_LIFETIME
    await db.commit()
    # NOTE(review): the frontend base URL is hard-coded; consider moving it to
    # settings.
    logger.info(
        "[模拟邮件] 密码重置链接: http://localhost:3000/reset-password?token=%s",
        token,
    )


async def reset_password(db: AsyncSession, token: str, new_password: str) -> bool:
    """Validate the reset *token* and store the hash of *new_password*.

    Returns ``False`` when the token is empty, unknown or expired. On success
    the stored token and expiry are cleared.
    """
    # An empty/None token would otherwise compile to `reset_token IS NULL` and
    # match every user without a pending reset.
    if not token:
        return False

    result = await db.execute(select(User).where(User.reset_token == token))
    user = result.scalar_one_or_none()
    if not user:
        return False

    expires_at = user.reset_token_expires
    if expires_at is None or expires_at < _utcnow():
        return False

    user.password_hash = hash_password(new_password)
    user.reset_token = None
    user.reset_token_expires = None
    await db.commit()
    return True


async def change_password(
    db: AsyncSession, user_id: uuid.UUID, old_password: str, new_password: str
) -> bool:
    """Update the password after verifying *old_password*.

    Returns ``False`` when the user does not exist or *old_password* does not
    match the stored hash.
    """
    user = await _get_user_by_id(db, user_id)
    if not user:
        return False
    if not verify_password(old_password, user.password_hash):
        return False

    user.password_hash = hash_password(new_password)
    await db.commit()
    return True


async def update_profile(
    db: AsyncSession, user_id: uuid.UUID, data: UpdateProfileRequest
) -> User | None:
    """Update the user's profile fields (name, avatar_url).

    Only fields of *data* that are not ``None`` are written. Returns the
    refreshed user, or ``None`` when no user with *user_id* exists.
    """
    user = await _get_user_by_id(db, user_id)
    if not user:
        return None

    if data.name is not None:
        user.name = data.name
    if data.avatar_url is not None:
        user.avatar_url = data.avatar_url

    await db.commit()
    await db.refresh(user)
    return user