"""Subscription plan catalogue and subscription lifecycle helpers."""

import logging
import uuid
from datetime import date, timedelta
from typing import Optional

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.subscription import Subscription
from app.models.user import User
from app.schemas.subscription import PlanDetail, PlanFeature, SubscriptionResponse

logger = logging.getLogger(__name__)

# Feature matrix: (display name, ids of the plans that include the feature).
_PLAN_FEATURES = [
    ("基础查询监控", ["free", "starter", "pro", "enterprise"]),  # basic query monitoring
    ("CSV导出", ["free", "starter", "pro", "enterprise"]),  # CSV export
    ("PDF报告", ["starter", "pro", "enterprise"]),  # PDF reports
    ("定时查询", ["pro", "enterprise"]),  # scheduled queries
    ("竞品分析", ["pro", "enterprise"]),  # competitor analysis
    ("API访问", ["enterprise"]),  # API access
    ("专属支持", ["enterprise"]),  # dedicated support
]

# Plan catalogue keyed by plan id. ``price`` is what gets stored as the
# subscription amount (logged in yuan); ``max_queries`` is copied onto the user.
PLANS = {
    "free": {
        "name": "免费版",
        "price": 0,
        "max_queries": 5,
    },
    "starter": {
        "name": "入门版",
        "price": 99,
        "max_queries": 20,
    },
    "pro": {
        "name": "专业版",
        "price": 299,
        "max_queries": 100,
    },
    "enterprise": {
        "name": "企业版",
        "price": 999,
        "max_queries": 500,
    },
}


def _build_features(plan_id: str) -> list[PlanFeature]:
    """Return one ``PlanFeature`` per known feature, flagged for *plan_id*.

    Every entry of ``_PLAN_FEATURES`` is represented; ``included`` is True only
    when *plan_id* appears in that feature's list of plans. An unknown plan id
    therefore yields the full list with every feature marked not included.
    """
    return [
        PlanFeature(name=name, included=plan_id in allowed)
        for name, allowed in _PLAN_FEATURES
    ]


def get_plans() -> list[PlanDetail]:
    """Return the details of every plan in ``PLANS``, in declaration order."""
    return [
        PlanDetail(
            id=plan_id,
            name=data["name"],
            price=data["price"],
            max_queries=data["max_queries"],
            features=_build_features(plan_id),
        )
        for plan_id, data in PLANS.items()
    ]


async def get_current_subscription(
    db: AsyncSession, user_id: uuid.UUID
) -> Optional[SubscriptionResponse]:
    """Return the user's most recently created subscription, if any.

    The query does not filter on ``status``, so the returned record may be a
    cancelled subscription.

    Returns:
        The newest subscription by ``created_at``, or ``None`` when the user
        has no subscription rows.
    """
    stmt = (
        select(Subscription)
        .where(Subscription.user_id == user_id)
        .order_by(Subscription.created_at.desc())
        .limit(1)
    )
    result = await db.execute(stmt)
    sub = result.scalar_one_or_none()
    if sub is None:
        return None
    return SubscriptionResponse.model_validate(sub)


async def subscribe(
    db: AsyncSession, user_id: uuid.UUID, plan: str
) -> SubscriptionResponse:
    """Create an active 30-day subscription and apply the plan to the user.

    Args:
        db: Session used for the lookup and the commit.
        user_id: Id of the subscribing user.
        plan: Plan id; must be a key of ``PLANS``.

    Returns:
        The newly created subscription.

    Raises:
        ValueError: If *plan* is not a key of ``PLANS``.
        sqlalchemy.exc.NoResultFound: If no user with *user_id* exists.
    """
    plan_data = PLANS.get(plan)
    if plan_data is None:
        raise ValueError(f"Invalid plan: {plan}")

    # Load the user before touching the session so that a missing user fails
    # cleanly, without a pending (or auto-flushed) subscription row left behind.
    user_result = await db.execute(select(User).where(User.id == user_id))
    user = user_result.scalar_one()

    today = date.today()
    end_date = today + timedelta(days=30)

    # NOTE(review): earlier "active" subscriptions of this user are left as-is
    # here; confirm whether they should be closed when a new plan is bought.
    subscription = Subscription(
        user_id=user_id,
        plan=plan,
        status="active",
        start_date=today,
        end_date=end_date,
        amount=plan_data["price"],
        # "Simulated payment": no payment provider is called in this function.
        payment_method="模拟支付",
    )
    db.add(subscription)

    user.plan = plan
    user.max_queries = plan_data["max_queries"]

    await db.commit()
    # Reload the row so the response is built from its committed state.
    await db.refresh(subscription)

    logger.info(
        "[模拟支付] 用户%s 订阅%s,金额%s元", user_id, plan, plan_data["price"]
    )
    return SubscriptionResponse.model_validate(subscription)


async def cancel_subscription(db: AsyncSession, user_id: uuid.UUID) -> dict:
    """Cancel the user's active subscriptions and downgrade them to the free plan.

    The user is reset to the free plan even when no active subscription exists.

    Returns:
        A dict with a single ``"message"`` key confirming the cancellation.

    Raises:
        sqlalchemy.exc.NoResultFound: If no user with *user_id* exists.
    """
    stmt = select(Subscription).where(
        Subscription.user_id == user_id,
        Subscription.status == "active",
    )
    result = await db.execute(stmt)
    # ``subscribe`` does not close earlier subscriptions, so several rows can be
    # active at once; cancel all of them so none stays active after the
    # downgrade below.
    for sub in result.scalars().all():
        sub.status = "cancelled"

    user_result = await db.execute(select(User).where(User.id == user_id))
    user = user_result.scalar_one()
    user.plan = "free"
    user.max_queries = PLANS["free"]["max_queries"]

    await db.commit()
    return {"message": "订阅已取消,已降级到免费版"}


async def get_subscription_history(
    db: AsyncSession, user_id: uuid.UUID
) -> list[SubscriptionResponse]:
    """Return all of the user's subscriptions, newest first by ``created_at``."""
    stmt = (
        select(Subscription)
        .where(Subscription.user_id == user_id)
        .order_by(Subscription.created_at.desc())
    )
    result = await db.execute(stmt)
    subs = result.scalars().all()
    return [SubscriptionResponse.model_validate(sub) for sub in subs]