geo/backend/app/services/subscription.py

155 lines
4.1 KiB
Python

import logging
import uuid
from datetime import date, timedelta
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.subscription import Subscription
from app.models.user import User
from app.schemas.subscription import PlanDetail, PlanFeature, SubscriptionResponse
logger = logging.getLogger(__name__)

# Feature matrix shown on the plans page. Each entry is
# (user-facing feature label, plan ids that include the feature);
# _build_features walks this list in order, so it also fixes display order.
_PLAN_FEATURES = [
    # basic query monitoring
    ("基础查询监控", ["free", "starter", "pro", "enterprise"]),
    # CSV export
    ("CSV导出", ["free", "starter", "pro", "enterprise"]),
    # PDF reports
    ("PDF报告", ["starter", "pro", "enterprise"]),
    # scheduled queries
    ("定时查询", ["pro", "enterprise"]),
    # competitor analysis
    ("竞品分析", ["pro", "enterprise"]),
    # API access
    ("API访问", ["enterprise"]),
    # dedicated support
    ("专属支持", ["enterprise"]),
]
# Plan catalogue keyed by plan id (also the display order used by get_plans).
#   name        - user-facing plan name (free / starter / professional / enterprise)
#   price       - stored as the subscription ``amount`` by subscribe(), which
#                 grants a 30-day period
#   max_queries - query quota copied onto the User row when the plan is applied
PLANS = {
    "free": dict(name="免费版", price=0, max_queries=5),
    "starter": dict(name="入门版", price=99, max_queries=20),
    "pro": dict(name="专业版", price=299, max_queries=100),
    "enterprise": dict(name="企业版", price=999, max_queries=500),
}
def _build_features(plan_id: str) -> list[PlanFeature]:
    """Return the complete feature checklist for *plan_id*.

    Every entry of ``_PLAN_FEATURES`` appears once, in order; ``included`` is
    True when *plan_id* is among that feature's plan ids. An unknown plan id
    therefore yields every feature with ``included=False``.
    """
    return [
        PlanFeature(name=label, included=(plan_id in tiers))
        for label, tiers in _PLAN_FEATURES
    ]
def get_plans() -> list[PlanDetail]:
    """Describe every plan in ``PLANS``, in catalogue order.

    Each PlanDetail carries the plan id, name, price, query quota, and the
    per-plan feature checklist from ``_build_features``.
    """
    return [
        PlanDetail(
            id=key,
            name=info["name"],
            price=info["price"],
            max_queries=info["max_queries"],
            features=_build_features(key),
        )
        for key, info in PLANS.items()
    ]
async def get_current_subscription(
    db: AsyncSession, user_id: uuid.UUID
) -> Optional[SubscriptionResponse]:
    """Return the user's most recently created subscription, or None.

    The newest row by ``created_at`` is returned whatever its status, so this
    may be a cancelled subscription.
    """
    newest_one = (
        select(Subscription)
        .where(Subscription.user_id == user_id)
        .order_by(Subscription.created_at.desc())
        .limit(1)
    )
    latest = (await db.execute(newest_one)).scalar_one_or_none()
    return None if latest is None else SubscriptionResponse.model_validate(latest)
async def subscribe(
    db: AsyncSession, user_id: uuid.UUID, plan: str
) -> SubscriptionResponse:
    """Subscribe the user to *plan* for 30 days and apply the plan to the user.

    Payment is simulated: the new Subscription row stores the plan price as
    ``amount`` and "模拟支付" (simulated payment) as ``payment_method``. The
    user's ``plan`` and ``max_queries`` are updated in the same commit.

    Args:
        db: Async session used for the lookup and the commit.
        user_id: Id of the subscribing user.
        plan: A key of ``PLANS``.

    Returns:
        The persisted subscription, refreshed after commit.

    Raises:
        ValueError: If *plan* is not a known plan id.
        sqlalchemy.exc.NoResultFound: If no user with *user_id* exists.
    """
    plan_data = PLANS.get(plan)
    if plan_data is None:
        raise ValueError(f"Invalid plan: {plan}")

    # Resolve the user BEFORE adding the subscription to the session. With
    # autoflush enabled (SQLAlchemy's default), a SELECT issued after db.add()
    # first flushes the pending INSERT. A missing user could then fail in the
    # flush (e.g. a foreign-key error) instead of raising NoResultFound here.
    user_result = await db.execute(select(User).where(User.id == user_id))
    user = user_result.scalar_one()

    today = date.today()
    subscription = Subscription(
        user_id=user_id,
        plan=plan,
        status="active",
        start_date=today,
        end_date=today + timedelta(days=30),
        amount=plan_data["price"],
        payment_method="模拟支付",
    )
    db.add(subscription)

    user.plan = plan
    user.max_queries = plan_data["max_queries"]

    await db.commit()
    await db.refresh(subscription)
    # Lazy %-style args; renders exactly the same message as an f-string would.
    logger.info("[模拟支付] 用户%s 订阅%s,金额%s", user_id, plan, plan_data["price"])
    return SubscriptionResponse.model_validate(subscription)
async def cancel_subscription(db: AsyncSession, user_id: uuid.UUID) -> dict:
    """Cancel all of the user's active subscriptions and downgrade them to free.

    Every subscription row for the user with status "active" is set to
    "cancelled". A user can hold several active rows, because subscribe() in
    this module does not retire earlier subscriptions. Cancelling only the
    newest would leave the older rows "active" even though the user has been
    downgraded.

    The user is reset to the free plan and its query quota even when no active
    subscription exists, so repeated calls leave the same end state.

    Args:
        db: Async session used for the queries and the commit.
        user_id: Id of the user whose subscriptions are cancelled.

    Returns:
        A confirmation payload with a user-facing ``message``.

    Raises:
        sqlalchemy.exc.NoResultFound: If no user with *user_id* exists.
    """
    active_stmt = select(Subscription).where(
        Subscription.user_id == user_id, Subscription.status == "active"
    )
    active_result = await db.execute(active_stmt)
    for sub in active_result.scalars().all():
        sub.status = "cancelled"

    user_result = await db.execute(select(User).where(User.id == user_id))
    user = user_result.scalar_one()
    user.plan = "free"
    user.max_queries = PLANS["free"]["max_queries"]

    await db.commit()
    return {"message": "订阅已取消,已降级到免费版"}
async def get_subscription_history(
    db: AsyncSession, user_id: uuid.UUID
) -> list[SubscriptionResponse]:
    """Return every subscription the user has had, newest first, in any status."""
    newest_first = Subscription.created_at.desc()
    query = (
        select(Subscription)
        .where(Subscription.user_id == user_id)
        .order_by(newest_first)
    )
    rows = (await db.execute(query)).scalars().all()
    return list(map(SubscriptionResponse.model_validate, rows))