geo/backend/app/middleware/subscription_enforcement.py

58 lines
2.2 KiB
Python

from fastapi import Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.services.subscription import PLANS
class SubscriptionEnforcement:
    """FastAPI dependency factories that gate endpoints by subscription plan.

    Each public factory returns an async callable intended to be used with
    ``Depends(...)``.
    """

    @staticmethod
    def _plan_of(user) -> str:
        """Return the user's plan name, falling back to ``"free"`` when unset or empty."""
        return getattr(user, "plan", "free") or "free"

    @staticmethod
    def _remaining(resource: str, plan_config: dict, user) -> int:
        """Return the remaining quota of *resource* for *user* under *plan_config*.

        Args:
            resource: ``"queries"``, ``"brands"`` or ``"alerts"``.
            plan_config: The plan's settings mapping (limit keys as read below).
            user: The user object; only consulted for the queries branch.

        Returns:
            ``-1`` when the plan's limit is ``-1`` (unlimited), otherwise a
            non-negative count. Unknown resources yield ``0``.
        """
        if resource == "queries":
            limit = plan_config.get("max_queries", 3)
            if limit == -1:
                # Unlimited plan: report the sentinel, as brands/alerts do.
                return -1
            # NOTE(review): usage is read from an attribute named
            # ``max_queries``, which sounds like a limit rather than a usage
            # counter -- confirm the intended User column.
            usage = getattr(user, "max_queries", None)
            if usage is None:
                # Unknown usage fails closed (no remaining quota). A recorded
                # usage of 0 is kept as 0 instead of being treated as missing.
                usage = limit
            return max(0, limit - usage)
        if resource == "brands":
            limit = plan_config.get("max_brands", 1)
        elif resource == "alerts":
            limit = plan_config.get("max_alerts_per_month", 0)
        else:
            return 0
        # TODO(review): actual brand/alert usage is not subtracted here; the
        # plan limit itself is reported as the remaining amount.
        return limit if limit == -1 else max(0, limit)

    @staticmethod
    def require_plan(*allowed_plans: str):
        """Build a dependency that rejects users whose plan is not allowed.

        Args:
            *allowed_plans: Plan identifiers permitted to use the endpoint.
                The first one is reported to the client as the required plan.

        Returns:
            An async dependency that resolves to the current user when the
            user's plan is in *allowed_plans*.

        Raises:
            ValueError: At factory time, if no plans are given. Otherwise
                every request would be rejected and the error payload could
                not name a required plan.
            HTTPException: (from the returned dependency) 403 with an upgrade
                hint when the user's plan is not allowed.
        """
        if not allowed_plans:
            raise ValueError("require_plan() needs at least one allowed plan")
        required_plan = allowed_plans[0]
        allowed = frozenset(allowed_plans)

        async def _check(current_user: User = Depends(get_current_user)):
            user_plan = SubscriptionEnforcement._plan_of(current_user)
            if user_plan not in allowed:
                # Message: "This feature requires the <plan> plan or above".
                raise HTTPException(
                    status_code=403,
                    detail={
                        "message": f"此功能需要 {required_plan} 及以上套餐",
                        "required_plan": required_plan,
                        "current_plan": user_plan,
                        "upgrade_url": "/api/v1/subscriptions/plans",
                    },
                )
            return current_user

        return _check

    @staticmethod
    def check_quota(resource: str):
        """Build a dependency that reports the remaining quota for *resource*.

        Args:
            resource: ``"queries"``, ``"brands"`` or ``"alerts"``. Any other
                value reports zero remaining.

        Returns:
            An async dependency resolving to a dict with ``user_id``,
            ``plan``, ``resource``, ``remaining`` (``-1`` means unlimited) and
            ``unlimited``. The dependency does not itself reject requests.
        """

        async def _check(
            current_user: User = Depends(get_current_user),
            db: AsyncSession = Depends(get_db),
        ):
            # ``db`` is currently unused; it is kept so the dependency's
            # signature stays unchanged.
            user_plan = SubscriptionEnforcement._plan_of(current_user)
            plan_config = PLANS.get(user_plan, PLANS["free"])
            remaining = SubscriptionEnforcement._remaining(
                resource, plan_config, current_user
            )
            return {
                "user_id": current_user.id,
                "plan": user_plan,
                "resource": resource,
                "remaining": remaining,
                "unlimited": remaining == -1,
            }

        return _check