155 lines
4.1 KiB
Python
155 lines
4.1 KiB
Python
import logging
|
|
import uuid
|
|
from datetime import date, timedelta
|
|
from typing import Optional
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.subscription import Subscription
|
|
from app.models.user import User
|
|
from app.schemas.subscription import PlanDetail, PlanFeature, SubscriptionResponse
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_PLAN_FEATURES = [
|
|
("基础查询监控", ["free", "starter", "pro", "enterprise"]),
|
|
("CSV导出", ["free", "starter", "pro", "enterprise"]),
|
|
("PDF报告", ["starter", "pro", "enterprise"]),
|
|
("定时查询", ["pro", "enterprise"]),
|
|
("竞品分析", ["pro", "enterprise"]),
|
|
("API访问", ["enterprise"]),
|
|
("专属支持", ["enterprise"]),
|
|
]
|
|
|
|
# Plan catalogue keyed by plan id. Each spec holds the display name, the price
# (logged in 元 when subscribing) and the query quota that is copied onto
# User.max_queries when the user switches to the plan.
PLANS = {
    plan_id: {"name": display_name, "price": price, "max_queries": query_quota}
    for plan_id, display_name, price, query_quota in (
        ("free", "免费版", 0, 5),
        ("starter", "入门版", 99, 20),
        ("pro", "专业版", 299, 100),
        ("enterprise", "企业版", 999, 500),
    )
}
|
|
|
|
|
|
def _build_features(plan_id: str) -> list[PlanFeature]:
    """Return the feature checklist for *plan_id*.

    One ``PlanFeature`` is produced per entry of ``_PLAN_FEATURES``, in that
    order, with ``included`` set when *plan_id* is among the entry's plan ids.
    An unknown *plan_id* yields a checklist with every feature excluded.
    """
    return [
        PlanFeature(
            name=feature_name,
            included=(plan_id in plans_with_feature),
        )
        for feature_name, plans_with_feature in _PLAN_FEATURES
    ]
|
|
|
|
|
|
def get_plans() -> list[PlanDetail]:
    """Describe every plan in ``PLANS`` as a ``PlanDetail``, in definition order.

    Each detail carries the plan id, display name, price, query quota and the
    feature checklist produced by ``_build_features``.
    """
    return [
        PlanDetail(
            id=pid,
            name=spec["name"],
            price=spec["price"],
            max_queries=spec["max_queries"],
            features=_build_features(pid),
        )
        for pid, spec in PLANS.items()
    ]
|
|
|
|
|
|
async def get_current_subscription(
    db: AsyncSession, user_id: uuid.UUID
) -> Optional[SubscriptionResponse]:
    """Return the user's most recently created subscription, if any.

    The lookup orders by ``created_at`` descending and takes the first row.
    It does not filter on status, so a cancelled subscription can be returned.

    Args:
        db: Async session used for the query (not committed here).
        user_id: Owner of the subscriptions.

    Returns:
        The newest subscription as a ``SubscriptionResponse``, or ``None``
        when the user has never subscribed.
    """
    latest = (
        await db.execute(
            select(Subscription)
            .where(Subscription.user_id == user_id)
            .order_by(Subscription.created_at.desc())
            .limit(1)
        )
    ).scalar_one_or_none()
    return None if latest is None else SubscriptionResponse.model_validate(latest)
|
|
|
|
|
|
async def subscribe(
    db: AsyncSession, user_id: uuid.UUID, plan: str
) -> SubscriptionResponse:
    """Create a 30-day subscription to *plan* and apply the plan to the user.

    Payment is simulated: the row is stored with ``payment_method="模拟支付"``
    and the charge is only written to the log.

    Args:
        db: Async session; this function commits it.
        user_id: Id of the subscribing user.
        plan: Plan id; must be a key of ``PLANS``.

    Returns:
        The newly created subscription, refreshed after commit.

    Raises:
        ValueError: If *plan* is not a known plan id.
        sqlalchemy.exc.NoResultFound: If no user with *user_id* exists
            (raised by ``scalar_one``).
    """
    plan_data = PLANS.get(plan)
    if plan_data is None:
        raise ValueError(f"Invalid plan: {plan}")

    # Load the user before staging the new subscription. Otherwise this
    # SELECT would autoflush the pending INSERT first, and only then fail if
    # the user does not exist.
    user_result = await db.execute(select(User).where(User.id == user_id))
    user = user_result.scalar_one()

    today = date.today()
    subscription = Subscription(
        user_id=user_id,
        plan=plan,
        status="active",
        start_date=today,
        end_date=today + timedelta(days=30),
        amount=plan_data["price"],
        payment_method="模拟支付",
    )
    db.add(subscription)

    # NOTE(review): earlier "active" subscriptions of this user are left as
    # they are; confirm whether a new subscription should supersede them.
    user.plan = plan
    user.max_queries = plan_data["max_queries"]

    await db.commit()
    await db.refresh(subscription)

    # Lazy %-style arguments; the rendered message matches the previous f-string.
    logger.info("[模拟支付] 用户%s 订阅%s,金额%s元", user_id, plan, plan_data["price"])

    return SubscriptionResponse.model_validate(subscription)
|
|
|
|
|
|
async def cancel_subscription(db: AsyncSession, user_id: uuid.UUID) -> dict:
    """Cancel the user's active subscriptions and downgrade them to the free plan.

    ``subscribe`` never deactivates earlier subscriptions, so a user who
    subscribed more than once can own several rows with status "active".
    All of them are marked "cancelled" here. Previously only the newest one
    was cancelled, which left stale active rows behind a "free" user.

    The user is downgraded to the free plan (name and query quota from
    ``PLANS["free"]``) even when no active subscription exists.

    Args:
        db: Async session; this function commits it.
        user_id: Id of the user to downgrade.

    Returns:
        A dict with a fixed confirmation ``message``.

    Raises:
        sqlalchemy.exc.NoResultFound: If no user with *user_id* exists
            (raised by ``scalar_one``).
    """
    active_result = await db.execute(
        select(Subscription).where(
            Subscription.user_id == user_id, Subscription.status == "active"
        )
    )
    # Materialize the rows before mutating them.
    for sub in active_result.scalars().all():
        sub.status = "cancelled"

    user_result = await db.execute(select(User).where(User.id == user_id))
    user = user_result.scalar_one()
    user.plan = "free"
    user.max_queries = PLANS["free"]["max_queries"]

    await db.commit()

    return {"message": "订阅已取消,已降级到免费版"}
|
|
|
|
|
|
async def get_subscription_history(
    db: AsyncSession, user_id: uuid.UUID
) -> list[SubscriptionResponse]:
    """Return all of the user's subscriptions, newest first by ``created_at``.

    Subscriptions of every status are included.

    Args:
        db: Async session used for the query (not committed here).
        user_id: Owner of the subscriptions.

    Returns:
        A list of ``SubscriptionResponse``, which is empty when the user has none.
    """
    history = await db.execute(
        select(Subscription)
        .where(Subscription.user_id == user_id)
        .order_by(Subscription.created_at.desc())
    )
    return list(map(SubscriptionResponse.model_validate, history.scalars().all()))
|