geo/backend/app/services/subscription.py

182 lines
5.2 KiB
Python

import logging
import uuid
from datetime import date, timedelta
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.subscription import Subscription
from app.models.user import User
from app.schemas.subscription import PlanDetail, PlanFeature, SubscriptionResponse
# Module-level logger named after this module; subscribe() emits its
# simulated-payment audit line through it.
logger = logging.getLogger(name=__name__)
# Per-plan feature checklist: each entry is (feature label, plan ids that
# include it). _build_features() walks this list front to back, so the order
# here is the order of the checklist it returns.
_PLAN_FEATURES = [
    # Core monitoring and export -- every tier.
    ("基础查询监控", ["free", "starter", "pro", "enterprise"]),
    ("CSV导出", ["free", "starter", "pro", "enterprise"]),
    # Data retention; mirrors PLANS[...]["data_retention_days"].
    ("7天数据保留", ["free"]),
    ("30天数据保留", ["starter", "pro", "enterprise"]),
    # Brand quota; mirrors PLANS[...]["max_brands"].
    ("1个品牌监控", ["free", "starter"]),
    ("3个品牌监控", ["pro"]),
    ("无限品牌监控", ["enterprise"]),
    # Reports and competitor comparison.
    ("PDF报告", ["starter", "pro", "enterprise"]),
    ("基础竞品对比", ["starter"]),
    ("完整竞品对比+雷达图", ["pro", "enterprise"]),
    # Alert quota; mirrors PLANS[...]["max_alerts_per_month"].
    ("5条告警/月", ["starter"]),
    ("无限告警", ["pro", "enterprise"]),
    # Recommendations, analysis and scheduling.
    ("规则优化建议", ["starter", "pro", "enterprise"]),
    ("AI个性化建议(DeepSeek)", ["pro", "enterprise"]),
    ("情感分析", ["pro", "enterprise"]),
    ("定时查询", ["pro", "enterprise"]),
    # Enterprise-only extras.
    ("API访问", ["enterprise"]),
    ("白标报告", ["enterprise"]),
    ("专属客户成功经理", ["enterprise"]),
]
# Plan catalogue keyed by plan id. In the limit fields, -1 means "unlimited".
# Insertion order (free -> enterprise) is the order get_plans() returns.
PLANS = {
    plan_id: {
        "name": name,
        "price": price,
        "max_queries": max_queries,
        "max_brands": max_brands,
        "max_alerts_per_month": max_alerts_per_month,
        "data_retention_days": data_retention_days,
    }
    for (
        plan_id,
        name,
        price,
        max_queries,
        max_brands,
        max_alerts_per_month,
        data_retention_days,
    ) in (
        # id, display name, price, queries, brands, alerts/month, retention days
        ("free", "免费版", 0, 3, 1, 0, 7),
        ("starter", "入门版", 199, 15, 1, 5, 30),
        ("pro", "专业版", 599, 50, 3, -1, 30),
        ("enterprise", "企业版", 1999, 200, -1, -1, 30),
    )
}
def _build_features(plan_id: str) -> list[PlanFeature]:
    """Return the full feature checklist for *plan_id*.

    Every entry of ``_PLAN_FEATURES`` is emitted, in table order, with
    ``included`` set to whether *plan_id* appears in that feature's plan list.
    An unknown plan id simply yields a checklist with nothing included.
    """
    return [
        PlanFeature(name=feature_label, included=plan_id in entitled_plans)
        for feature_label, entitled_plans in _PLAN_FEATURES
    ]
def get_plans() -> list[PlanDetail]:
    """Return the plan catalogue as ``PlanDetail`` objects.

    One entry per key of ``PLANS``, in insertion order. Numeric limits are
    copied verbatim (including the ``-1`` "unlimited" marker), and each plan's
    feature checklist comes from ``_build_features``.
    """

    def to_detail(plan_id: str, spec: dict) -> PlanDetail:
        # Map one PLANS entry onto the response schema.
        return PlanDetail(
            id=plan_id,
            name=spec["name"],
            price=spec["price"],
            max_queries=spec["max_queries"],
            max_brands=spec["max_brands"],
            max_alerts_per_month=spec["max_alerts_per_month"],
            data_retention_days=spec["data_retention_days"],
            features=_build_features(plan_id),
        )

    return [to_detail(plan_id, spec) for plan_id, spec in PLANS.items()]
async def get_current_subscription(
    db: AsyncSession, user_id: uuid.UUID
) -> Optional[SubscriptionResponse]:
    """Return the user's most recently created subscription, or ``None``.

    Rows are ranked by ``created_at`` descending and only the first is used.
    No ``status`` filter is applied, so a cancelled subscription can be
    returned if it is the newest one.
    """
    newest_first = (
        select(Subscription)
        .where(Subscription.user_id == user_id)
        .order_by(Subscription.created_at.desc())
    )
    latest = (await db.execute(newest_first.limit(1))).scalar_one_or_none()
    return None if latest is None else SubscriptionResponse.model_validate(latest)
async def subscribe(
    db: AsyncSession, user_id: uuid.UUID, plan: str
) -> SubscriptionResponse:
    """Start a new 30-day subscription on *plan* for the user.

    Payment is simulated: the new row records ``payment_method="模拟支付"``
    (simulated payment) and ``amount`` equal to the plan's price. Any of the
    user's subscriptions still marked ``"active"`` are set to ``"cancelled"``
    first, so the user is left with at most one active subscription. The
    user's ``plan`` and ``max_queries`` are updated to match, and everything
    is committed together.

    Args:
        db: Async SQLAlchemy session; this function commits it.
        user_id: Id of the subscribing user.
        plan: Plan id; must be a key of ``PLANS``.

    Returns:
        The newly created subscription as a ``SubscriptionResponse``.

    Raises:
        ValueError: If *plan* is not a key of ``PLANS``.
        sqlalchemy.exc.NoResultFound: If no user with *user_id* exists
            (raised by ``scalar_one``).
    """
    plan_data = PLANS.get(plan)
    if plan_data is None:
        raise ValueError(f"Invalid plan: {plan}")

    # Resolve the user before mutating anything, so a missing user fails fast
    # without leaving staged changes in the session.
    user_result = await db.execute(select(User).where(User.id == user_id))
    user = user_result.scalar_one()

    # Retire previously active subscriptions so upgrades/renewals don't pile up
    # several "active" rows. This query runs BEFORE db.add(): if the session
    # autoflushes, a query issued afterwards would also match the new row.
    # "cancelled" is the only non-active status this module uses.
    active_stmt = select(Subscription).where(
        Subscription.user_id == user_id, Subscription.status == "active"
    )
    active_result = await db.execute(active_stmt)
    for previous in active_result.scalars().all():
        previous.status = "cancelled"

    today = date.today()
    subscription = Subscription(
        user_id=user_id,
        plan=plan,
        status="active",
        start_date=today,
        end_date=today + timedelta(days=30),
        amount=plan_data["price"],
        payment_method="模拟支付",
    )
    db.add(subscription)

    # Keep the denormalised plan info on the user in sync with the new plan.
    user.plan = plan
    user.max_queries = plan_data["max_queries"]

    await db.commit()
    await db.refresh(subscription)
    logger.info(
        "[模拟支付] 用户%s 订阅%s,金额%s", user_id, plan, plan_data["price"]
    )
    return SubscriptionResponse.model_validate(subscription)
async def cancel_subscription(db: AsyncSession, user_id: uuid.UUID) -> dict:
    """Cancel the user's active subscription(s) and downgrade them to free.

    Every subscription row for the user still marked ``"active"`` is set to
    ``"cancelled"``, not just the newest one, so stale duplicates cannot
    survive a cancellation. The downgrade to the ``"free"`` plan (and its
    ``max_queries``) runs even when no active subscription is found.

    Args:
        db: Async SQLAlchemy session; this function commits it.
        user_id: Id of the user whose subscription is cancelled.

    Returns:
        A dict with a user-facing ``"message"``.

    Raises:
        sqlalchemy.exc.NoResultFound: If no user with *user_id* exists
            (raised by ``scalar_one``).
    """
    active_stmt = select(Subscription).where(
        Subscription.user_id == user_id, Subscription.status == "active"
    )
    active_result = await db.execute(active_stmt)
    for sub in active_result.scalars().all():
        sub.status = "cancelled"

    user_result = await db.execute(select(User).where(User.id == user_id))
    user = user_result.scalar_one()
    user.plan = "free"
    user.max_queries = PLANS["free"]["max_queries"]

    await db.commit()
    return {"message": "订阅已取消,已降级到免费版"}
async def get_subscription_history(
    db: AsyncSession, user_id: uuid.UUID
) -> list[SubscriptionResponse]:
    """Return every subscription the user has ever had, newest first.

    All statuses are included; rows are ordered by ``created_at`` descending.
    """
    newest_first = Subscription.created_at.desc()
    query = (
        select(Subscription)
        .where(Subscription.user_id == user_id)
        .order_by(newest_first)
    )
    rows = (await db.execute(query)).scalars().all()
    return list(map(SubscriptionResponse.model_validate, rows))