from fastapi import Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.services.subscription import PLANS

# Value used in plan configs (and in the "remaining" result) to mean "no limit".
_UNLIMITED = -1


def _user_plan(user: User) -> str:
    """Return the user's plan name, falling back to "free" when missing or empty."""
    return getattr(user, "plan", "free") or "free"


def _remaining(limit: int, used: int = 0) -> int:
    """Return how much of ``limit`` is left after ``used``.

    A limit of ``-1`` means unlimited and is passed through unchanged, so
    callers can report ``unlimited=True``; otherwise the result is clamped
    at zero.
    """
    if limit == _UNLIMITED:
        return _UNLIMITED
    return max(0, limit - used)


class SubscriptionEnforcement:
    """Factories for FastAPI dependencies that gate endpoints on the user's plan."""

    @staticmethod
    def require_plan(*allowed_plans: str):
        """Build a dependency that only admits users on one of ``allowed_plans``.

        The first entry of ``allowed_plans`` is reported to the client as the
        required plan, so list the lowest qualifying plan first.

        Args:
            *allowed_plans: Plan names permitted to access the endpoint.

        Returns:
            An async dependency that resolves to the current ``User``.

        Raises:
            ValueError: Immediately, if no plans are given.
            HTTPException: 403, raised by the returned dependency when the
                user's plan is not in ``allowed_plans``.
        """
        if not allowed_plans:
            # Without this guard every request would hit IndexError on
            # allowed_plans[0] below and surface as a 500, not a config error.
            raise ValueError("require_plan() needs at least one allowed plan")

        async def _check(current_user: User = Depends(get_current_user)):
            user_plan = _user_plan(current_user)
            if user_plan not in allowed_plans:
                raise HTTPException(
                    status_code=403,
                    detail={
                        "message": f"此功能需要 {allowed_plans[0]} 及以上套餐",
                        "required_plan": allowed_plans[0],
                        "current_plan": user_plan,
                        "upgrade_url": "/api/v1/subscriptions/plans",
                    },
                )
            return current_user

        return _check

    @staticmethod
    def check_quota(resource: str):
        """Build a dependency that reports the user's remaining quota for ``resource``.

        Limits come from ``PLANS[<user plan>]``; unknown plans fall back to
        ``PLANS["free"]``. A limit of ``-1`` means unlimited and is reported as
        ``remaining == -1`` with ``unlimited == True``.

        Args:
            resource: One of ``"queries"``, ``"brands"`` or ``"alerts"``. Any
                other value is reported with ``remaining == 0``.

        Returns:
            An async dependency that resolves to a dict with keys ``user_id``,
            ``plan``, ``resource``, ``remaining`` and ``unlimited``.
        """

        async def _check(
            current_user: User = Depends(get_current_user),
            # NOTE(review): the session is injected but not used yet; the
            # brands/alerts branches presumably should count usage with it.
            db: AsyncSession = Depends(get_db),
        ):
            user_plan = _user_plan(current_user)
            plan_config = PLANS.get(user_plan, PLANS["free"])

            if resource == "queries":
                limit = plan_config.get("max_queries", 3)
                # NOTE(review): usage is read from the user's ``max_queries``
                # attribute, which reads like a limit rather than a usage
                # counter -- confirm against the User model.
                current_usage = getattr(current_user, "max_queries", None)
                if current_usage is None:
                    # Unknown usage: treat the quota as exhausted (fail closed).
                    # A usage of 0 is a real value and must not hit this branch.
                    current_usage = limit
                remaining = _remaining(limit, current_usage)
            elif resource == "brands":
                # TODO: subtract actual brand usage; currently the full limit.
                remaining = _remaining(plan_config.get("max_brands", 1))
            elif resource == "alerts":
                # TODO: subtract this month's alert usage; currently the full limit.
                remaining = _remaining(plan_config.get("max_alerts_per_month", 0))
            else:
                # Unrecognised resource names get no quota.
                remaining = 0

            return {
                "user_id": current_user.id,
                "plan": user_plan,
                "resource": resource,
                "remaining": remaining,
                "unlimited": remaining == _UNLIMITED,
            }

        return _check