geo/backend/app/services/auth.py

176 lines
4.8 KiB
Python

import asyncio
import logging
import random
import secrets
import time
import uuid
from datetime import datetime, timedelta, timezone

import bcrypt
from jose import jwt, JWTError
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.config import settings
from app.models.user import User
from app.models.organization import Organization, OrgMember
from app.schemas.auth import UserRegister, UpdateProfileRequest
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)
def hash_password(password: str) -> str:
    """Hash *password* with bcrypt using a freshly generated salt.

    The password is UTF-8 encoded before hashing, and the resulting bcrypt
    hash (which embeds its own salt) is returned decoded as ``str``.
    """
    raw = password.encode()
    salt = bcrypt.gensalt()
    hashed = bcrypt.hashpw(raw, salt)
    return hashed.decode()
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against a bcrypt hash such as ``hash_password`` returns.

    Returns True on a match and False otherwise. If bcrypt rejects the input
    with ``ValueError`` (e.g. a malformed stored hash), the check is treated
    as a failure and logged, instead of raising into the caller.
    """
    try:
        return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
    except ValueError as exc:
        # A hash bcrypt cannot parse can never match; report it for ops rather
        # than turning a login attempt into an unhandled server error.
        logger.warning("bcrypt rejected password check: %s", exc)
        return False
def create_access_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Sign *data* as an HS256 access JWT.

    Args:
        data: Claims to embed. The dict is copied, so the caller's object is
            not modified.
        expires_delta: Token lifetime; defaults to one hour.

    Returns:
        The encoded JWT. ``exp`` and ``type="access"`` are added, overriding
        any same-named keys in *data*.
    """
    to_encode = data.copy()
    lifetime = expires_delta if expires_delta is not None else timedelta(hours=1)
    # Timezone-aware "now": datetime.utcnow() is deprecated and returns a naive value.
    expire = datetime.now(timezone.utc) + lifetime
    to_encode.update({"exp": expire, "type": "access"})
    return jwt.encode(to_encode, settings.JWT_SECRET, algorithm="HS256")
def create_refresh_token(data: dict, expires_delta: timedelta | None = None) -> str:
    """Sign *data* as an HS256 refresh JWT.

    Args:
        data: Claims to embed. The dict is copied, so the caller's object is
            not modified.
        expires_delta: Token lifetime; defaults to seven days.

    Returns:
        The encoded JWT. ``exp`` and ``type="refresh"`` are added, overriding
        any same-named keys in *data*.
    """
    to_encode = data.copy()
    lifetime = expires_delta if expires_delta is not None else timedelta(days=7)
    # Timezone-aware "now": datetime.utcnow() is deprecated and returns a naive value.
    expire = datetime.now(timezone.utc) + lifetime
    to_encode.update({"exp": expire, "type": "refresh"})
    return jwt.encode(to_encode, settings.JWT_SECRET, algorithm="HS256")
def verify_refresh_token(token: str) -> dict:
    """Decode an HS256 JWT and require it to be a refresh token.

    Returns:
        The decoded claims.

    Raises:
        ValueError: if ``jwt.decode`` raises ``JWTError`` for the token, or if
            its ``type`` claim is not ``"refresh"``.
    """
    try:
        payload = jwt.decode(token, settings.JWT_SECRET, algorithms=["HS256"])
    except JWTError as exc:
        # Chain the JWTError so tracebacks/logs keep the underlying reason.
        raise ValueError("刷新令牌无效") from exc
    if payload.get("type") != "refresh":
        raise ValueError("令牌类型错误")
    return payload
def verify_token(token: str) -> dict:
    """Decode an HS256 JWT and require it to be usable as an access token.

    Tokens whose ``type`` claim is ``"access"`` or absent are accepted. The
    absent case is presumably kept for tokens issued without a ``type``
    claim; TODO confirm whether that is still needed.

    Returns:
        The decoded claims.

    Raises:
        ValueError: if ``jwt.decode`` raises ``JWTError`` for the token, or if
            its ``type`` claim is present and not ``"access"`` (e.g. a
            refresh token).
    """
    try:
        payload = jwt.decode(token, settings.JWT_SECRET, algorithms=["HS256"])
    except JWTError as exc:
        # Chain the JWTError so tracebacks/logs keep the underlying reason.
        raise ValueError("访问令牌无效") from exc
    if payload.get("type") not in ("access", None):
        raise ValueError("令牌类型错误")
    return payload
async def register_user(db: AsyncSession, user_data: UserRegister) -> User:
    """Create a user account together with its personal organization.

    The new user gets ``plan="free"`` and ``max_queries=5``. A personal
    ``Organization`` is also created, whose id is the user's id parsed as a
    UUID and whose slug is derived from that id. The user is added to it as
    ``"owner"``, and ``user.organization_id`` points at it. Everything is
    committed in a single ``db.commit()``, and the refreshed user is returned.

    Raises:
        ValueError: if a user with ``user_data.email`` already exists.
    """
    lookup = select(User).where(User.email == user_data.email)
    if (await db.execute(lookup)).scalar_one_or_none():
        raise ValueError("邮箱已被注册")

    # bcrypt hashing is deliberately slow CPU work; run it in a worker thread so
    # the event loop keeps serving other requests meanwhile.
    password_hash = await asyncio.to_thread(hash_password, user_data.password)
    user = User(
        id=str(uuid.uuid4()),
        email=user_data.email,
        password=password_hash,
        username=user_data.name,
        plan="free",
        max_queries=5,
    )
    db.add(user)
    await db.flush()

    # Personal organization: NOTE(review) it intentionally reuses the user's id
    # (as a UUID) as its own primary key.
    personal_org = Organization(
        id=uuid.UUID(user.id),
        name=f"{user_data.name}的个人空间",
        slug=f"user-{user.id[:8]}",
        plan="free",
    )
    db.add(personal_org)
    await db.flush()

    db.add(OrgMember(organization_id=personal_org.id, user_id=user.id, role="owner"))
    user.organization_id = personal_org.id

    await db.commit()
    await db.refresh(user)
    return user
async def authenticate_user(
    db: AsyncSession, email: str, password: str
) -> User | None:
    """Return the user whose email and password match, otherwise ``None``.

    The bcrypt work runs in a worker thread so it does not block the event
    loop. When no user has *email*, a comparable bcrypt computation is still
    performed, so response time does not reveal whether the email is
    registered.
    """
    stmt = select(User).where(User.email == email)
    result = await db.execute(stmt)
    user = result.scalar_one_or_none()
    if not user:
        # Timing equalization: spend about as much bcrypt work as a real check.
        await asyncio.to_thread(hash_password, "timing-equalization-dummy")
        return None
    if not await asyncio.to_thread(verify_password, password, user.password):
        return None
    return user
# Pending email-verification codes, keyed by email address:
# email -> (code, time.monotonic() deadline, remaining wrong-code attempts).
# NOTE(review): this store is process-local. Codes are lost on restart and are
# not shared between worker processes; move it to Redis or the database before
# running more than one worker.
_VERIFICATION_CODE_TTL_SECONDS = 10 * 60
_VERIFICATION_CODE_MAX_ATTEMPTS = 5
_pending_verification_codes: dict[str, tuple[str, float, int]] = {}


async def send_verification_code(db: AsyncSession, email: str) -> None:
    """Issue a fresh 6-digit verification code for *email* and "send" it.

    Any code previously issued for the same address is replaced. Delivery is
    simulated: the code is written to the info log instead of being emailed.
    *db* is not used.
    """
    # secrets (not random): the code acts as a credential.
    code = f"{secrets.randbelow(1_000_000):06d}"
    _pending_verification_codes[email] = (
        code,
        time.monotonic() + _VERIFICATION_CODE_TTL_SECONDS,
        _VERIFICATION_CODE_MAX_ATTEMPTS,
    )
    # Simulated delivery: this log line stands in for the email. Stop logging
    # the code once a real mailer is wired in.
    logger.info("[模拟邮件] 邮箱验证码 %s 发送到 %s", code, email)


async def verify_email(db: AsyncSession, email: str, code: str) -> bool:
    """Mark the user with *email* as verified if *code* matches the pending code.

    Returns False when no code is pending for *email*, the code has expired,
    *code* does not match, or no user has *email*. A code is consumed on
    success, and it is discarded after _VERIFICATION_CODE_MAX_ATTEMPTS wrong
    guesses.
    """
    pending = _pending_verification_codes.get(email)
    if pending is None:
        return False
    expected, deadline, attempts_left = pending
    if time.monotonic() >= deadline:
        _pending_verification_codes.pop(email, None)
        return False
    # Constant-time comparison; compare bytes so non-ASCII input cannot raise.
    if not secrets.compare_digest(code.strip().encode(), expected.encode()):
        if attempts_left <= 1:
            _pending_verification_codes.pop(email, None)
        else:
            _pending_verification_codes[email] = (expected, deadline, attempts_left - 1)
        return False

    stmt = select(User).where(User.email == email)
    result = await db.execute(stmt)
    user = result.scalar_one_or_none()
    if not user:
        return False
    user.emailVerified = True
    await db.commit()
    # Consume the code only after a successful commit, so a failed commit can be retried.
    _pending_verification_codes.pop(email, None)
    return True
async def send_reset_link(db: AsyncSession, email: str) -> None:
    """Simulate sending a password-reset link to *email*.

    No link or token is generated and nothing is sent; the call only writes
    an info log line. *db* is not used.
    """
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info("[模拟邮件] 密码重置链接发送到 %s", email)
async def reset_password(db: AsyncSession, token: str, new_password: str) -> bool:
    """Reset a password from a reset token. Not implemented yet.

    Always returns False, so every reset attempt is rejected and nothing is
    changed.
    """
    del db, token, new_password  # Unused until reset tokens are implemented.
    return False
async def change_password(db: AsyncSession, user_id, old_password: str, new_password: str) -> bool:
    """Replace a user's password after confirming the current one.

    Returns False, and changes nothing, when no user has *user_id* or
    *old_password* does not match. Otherwise it stores a bcrypt hash of
    *new_password*, commits, and returns True.
    """
    stmt = select(User).where(User.id == user_id)
    result = await db.execute(stmt)
    user = result.scalar_one_or_none()
    if not user:
        return False
    # bcrypt is deliberately slow CPU work; run it in a worker thread so the
    # event loop is not blocked while it computes.
    if not await asyncio.to_thread(verify_password, old_password, user.password):
        return False
    user.password = await asyncio.to_thread(hash_password, new_password)
    await db.commit()
    return True
async def update_profile(db: AsyncSession, user_id, data: UpdateProfileRequest) -> User | None:
    """Copy the non-None profile fields of *data* onto the user with *user_id*.

    ``data.name`` is written to ``User.username`` and ``data.avatar_url`` to
    ``User.avatar``; fields that are None are left untouched. The change is
    committed and the refreshed user returned, or None if no user has
    *user_id*.
    """
    query = select(User).where(User.id == user_id)
    target = (await db.execute(query)).scalar_one_or_none()
    if not target:
        return None

    changes = (("username", data.name), ("avatar", data.avatar_url))
    for attribute, new_value in changes:
        if new_value is not None:
            setattr(target, attribute, new_value)

    await db.commit()
    await db.refresh(target)
    return target