import logging
import uuid
from datetime import date, timedelta
from typing import Optional

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.subscription import Subscription
from app.models.user import User
from app.schemas.subscription import PlanDetail, PlanFeature, SubscriptionResponse

logger = logging.getLogger(__name__)

# Length, in days, of the subscription period created by `subscribe`.
_BILLING_PERIOD_DAYS = 30

# (feature display name, ids of the plans that include it), in display order.
_PLAN_FEATURES = [
    ("基础查询监控", ["free", "starter", "pro", "enterprise"]),
    ("CSV导出", ["free", "starter", "pro", "enterprise"]),
    ("7天数据保留", ["free"]),
    ("30天数据保留", ["starter", "pro", "enterprise"]),
    ("1个品牌监控", ["free", "starter"]),
    ("3个品牌监控", ["pro"]),
    ("无限品牌监控", ["enterprise"]),
    ("PDF报告", ["starter", "pro", "enterprise"]),
    ("基础竞品对比", ["starter"]),
    ("完整竞品对比+雷达图", ["pro", "enterprise"]),
    ("5条告警/月", ["starter"]),
    ("无限告警", ["pro", "enterprise"]),
    ("规则优化建议", ["starter", "pro", "enterprise"]),
    ("AI个性化建议(DeepSeek)", ["pro", "enterprise"]),
    ("情感分析", ["pro", "enterprise"]),
    ("定时查询", ["pro", "enterprise"]),
    ("API访问", ["enterprise"]),
    ("白标报告", ["enterprise"]),
    ("专属客户成功经理", ["enterprise"]),
]

# Plan id -> display name, price and limits. A limit of -1 means "unlimited".
# `price` is the amount charged per period (logged in 元 by `subscribe`).
PLANS = {
    "free": {
        "name": "免费版",
        "price": 0,
        "max_queries": 3,
        "max_brands": 1,
        "max_alerts_per_month": 0,
        "data_retention_days": 7,
    },
    "starter": {
        "name": "入门版",
        "price": 199,
        "max_queries": 15,
        "max_brands": 1,
        "max_alerts_per_month": 5,
        "data_retention_days": 30,
    },
    "pro": {
        "name": "专业版",
        "price": 599,
        "max_queries": 50,
        "max_brands": 3,
        "max_alerts_per_month": -1,  # unlimited
        "data_retention_days": 30,
    },
    "enterprise": {
        "name": "企业版",
        "price": 1999,
        "max_queries": 200,
        "max_brands": -1,  # unlimited
        "max_alerts_per_month": -1,  # unlimited
        "data_retention_days": 30,
    },
}


def _build_features(plan_id: str) -> list[PlanFeature]:
    """Return every known feature, flagged as included or not for ``plan_id``."""
    return [
        PlanFeature(name=name, included=plan_id in allowed)
        for name, allowed in _PLAN_FEATURES
    ]


def get_plans() -> list[PlanDetail]:
    """Return the details of every plan in ``PLANS``, in definition order."""
    return [
        PlanDetail(
            id=plan_id,
            name=data["name"],
            price=data["price"],
            max_queries=data["max_queries"],
            max_brands=data["max_brands"],
            max_alerts_per_month=data["max_alerts_per_month"],
            data_retention_days=data["data_retention_days"],
            features=_build_features(plan_id),
        )
        for plan_id, data in PLANS.items()
    ]


async def get_current_subscription(
    db: AsyncSession, user_id: uuid.UUID
) -> Optional[SubscriptionResponse]:
    """Return the user's most recently created subscription, or ``None``.

    The newest row is returned regardless of its status, so this may be a
    cancelled subscription.
    """
    stmt = (
        select(Subscription)
        .where(Subscription.user_id == user_id)
        .order_by(Subscription.created_at.desc())
        .limit(1)
    )
    result = await db.execute(stmt)
    sub = result.scalar_one_or_none()
    if sub is None:
        return None
    return SubscriptionResponse.model_validate(sub)


async def subscribe(
    db: AsyncSession, user_id: uuid.UUID, plan: str
) -> SubscriptionResponse:
    """Create an active subscription to ``plan`` and apply its limits to the user.

    Payment is simulated: the subscription is recorded with the plan's price
    and payment method "模拟支付" without charging anything.

    Args:
        db: Async session; this function commits it.
        user_id: Id of the subscribing user.
        plan: A key of ``PLANS``.

    Returns:
        The newly created subscription.

    Raises:
        ValueError: If ``plan`` is not a key of ``PLANS``.
        sqlalchemy.exc.NoResultFound: If no user with ``user_id`` exists
            (raised by ``scalar_one``).
    """
    plan_data = PLANS.get(plan)
    if plan_data is None:
        raise ValueError(f"Invalid plan: {plan}")

    today = date.today()
    subscription = Subscription(
        user_id=user_id,
        plan=plan,
        status="active",
        start_date=today,
        end_date=today + timedelta(days=_BILLING_PERIOD_DAYS),
        amount=plan_data["price"],
        payment_method="模拟支付",
    )
    db.add(subscription)

    user_result = await db.execute(select(User).where(User.id == user_id))
    user = user_result.scalar_one()
    # Mirror the plan and its query quota onto the user row.
    user.plan = plan
    user.max_queries = plan_data["max_queries"]

    await db.commit()
    await db.refresh(subscription)

    # Lazy %-formatting; the rendered message is unchanged.
    logger.info(
        "[模拟支付] 用户%s 订阅%s,金额%s元", user_id, plan, plan_data["price"]
    )
    return SubscriptionResponse.model_validate(subscription)


async def cancel_subscription(db: AsyncSession, user_id: uuid.UUID) -> dict:
    """Cancel the user's active subscription(s) and downgrade them to free.

    Every subscription row with status "active" is marked "cancelled". The
    user's plan and query quota are reset to the free plan even if no active
    subscription existed.

    Raises:
        sqlalchemy.exc.NoResultFound: If no user with ``user_id`` exists
            (raised by ``scalar_one``).

    Returns:
        A confirmation message dict.
    """
    stmt = select(Subscription).where(
        Subscription.user_id == user_id, Subscription.status == "active"
    )
    result = await db.execute(stmt)
    # Cancel all active rows, not only the newest: `subscribe` never closes
    # earlier subscriptions, so a user who changed plans can have several, and
    # leaving any "active" after a downgrade to free is inconsistent.
    for sub in result.scalars().all():
        sub.status = "cancelled"

    user_result = await db.execute(select(User).where(User.id == user_id))
    user = user_result.scalar_one()
    user.plan = "free"
    user.max_queries = PLANS["free"]["max_queries"]

    await db.commit()
    return {"message": "订阅已取消,已降级到免费版"}


async def get_subscription_history(
    db: AsyncSession, user_id: uuid.UUID
) -> list[SubscriptionResponse]:
    """Return all of the user's subscriptions, newest first."""
    stmt = (
        select(Subscription)
        .where(Subscription.user_id == user_id)
        .order_by(Subscription.created_at.desc())
    )
    result = await db.execute(stmt)
    return [SubscriptionResponse.model_validate(sub) for sub in result.scalars().all()]