58 lines
2.2 KiB
Python
58 lines
2.2 KiB
Python
from fastapi import Depends, HTTPException
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.api.deps import get_current_user
|
|
from app.database import get_db
|
|
from app.models.user import User
|
|
from app.services.subscription import PLANS
|
|
|
|
|
|
class SubscriptionEnforcement:
    """Factories for FastAPI dependencies that enforce subscription plans.

    Each factory returns an ``async`` callable intended to be used with
    ``Depends(...)`` on a route. The caller's plan is read from the ``plan``
    attribute of the authenticated user and defaults to ``"free"`` when that
    attribute is missing or empty.
    """

    @staticmethod
    def _user_plan(user) -> str:
        """Return ``user.plan``, or ``"free"`` when it is missing or falsy."""
        return getattr(user, "plan", "free") or "free"

    @staticmethod
    def require_plan(*allowed_plans: str):
        """Build a dependency that only admits users on one of ``allowed_plans``.

        Membership is exact: the user's plan must equal one of
        ``allowed_plans``, and no plan ordering is applied here. The first
        entry is reported to the client as the required plan, so list the
        lowest qualifying plan first.

        Args:
            *allowed_plans: Plan names permitted to access the route.

        Returns:
            An async dependency that resolves to the current ``User``.

        Raises:
            ValueError: If no plans are given. Such a route would reject every
                user, and building the 403 detail would fail with IndexError.
        """
        if not allowed_plans:
            raise ValueError("require_plan() needs at least one allowed plan")
        required_plan = allowed_plans[0]

        async def _check(current_user: User = Depends(get_current_user)):
            """Return ``current_user``, or raise HTTP 403 if their plan is not allowed."""
            user_plan = SubscriptionEnforcement._user_plan(current_user)
            # A tuple (==) membership test is kept on purpose: str-based Enum
            # plan values compare equal to strings but may hash differently.
            if user_plan not in allowed_plans:
                raise HTTPException(
                    status_code=403,
                    detail={
                        "message": f"此功能需要 {required_plan} 及以上套餐",
                        "required_plan": required_plan,
                        "current_plan": user_plan,
                        "upgrade_url": "/api/v1/subscriptions/plans",
                    },
                )
            return current_user

        return _check

    @staticmethod
    def check_quota(resource: str):
        """Build a dependency that reports the caller's remaining quota.

        The dependency never rejects the request; it only returns a summary
        for the route to inspect. Limits come from ``PLANS[<plan>]``, falling
        back to the ``"free"`` entry for unknown plans. A limit of ``-1``
        means unlimited.

        Args:
            resource: ``"queries"``, ``"brands"`` or ``"alerts"``. Any other
                value reports ``remaining == 0``.

        Returns:
            An async dependency resolving to a dict with keys ``user_id``,
            ``plan``, ``resource``, ``remaining`` (``-1`` when unlimited) and
            ``unlimited``.
        """

        async def _check(
            current_user: User = Depends(get_current_user),
            db: AsyncSession = Depends(get_db),
        ):
            # NOTE(review): ``db`` is injected but not used yet, so the
            # brands/alerts figures below ignore actual usage — TODO confirm.
            user_plan = SubscriptionEnforcement._user_plan(current_user)
            plan_config = PLANS.get(user_plan, PLANS["free"])

            if resource == "queries":
                limit = plan_config.get("max_queries", 3)
                if limit == -1:
                    # Unlimited plans must report -1; the subtraction below
                    # would otherwise clamp them to 0.
                    remaining = -1
                else:
                    # NOTE(review): usage is read from ``User.max_queries``;
                    # the name suggests a cap rather than a counter — confirm
                    # against the User model.
                    used = getattr(current_user, "max_queries", None)
                    if used is None:
                        # Unknown usage is treated as exhausted, as before.
                        used = limit
                    # ``is None`` instead of ``or``: a usage of 0 is real data,
                    # not "unknown", and must leave the full quota available.
                    remaining = max(0, limit - used)
            elif resource == "brands":
                limit = plan_config.get("max_brands", 1)
                remaining = limit if limit == -1 else max(0, limit)
            elif resource == "alerts":
                limit = plan_config.get("max_alerts_per_month", 0)
                remaining = limit if limit == -1 else max(0, limit)
            else:
                remaining = 0

            return {
                "user_id": current_user.id,
                "plan": user_plan,
                "resource": resource,
                "remaining": remaining,
                "unlimited": remaining == -1,
            }

        return _check
|