275 lines
7.9 KiB
Python
275 lines
7.9 KiB
Python
import logging
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
|
from pydantic import BaseModel
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.api.deps import get_current_user
|
|
from app.database import get_db
|
|
from app.models.payment_order import PaymentOrder as PaymentOrderModel
|
|
from app.models.user import User
|
|
from app.services.payment import get_payment_gateway
|
|
from app.services.subscription import PLANS, subscribe
|
|
|
|
# Module-level logger named after this module, so payment logging can be
# configured independently of other parts of the application.
logger = logging.getLogger(name=__name__)

# Every endpoint below is mounted under /api/v1/payments and grouped under
# the "支付" (payments) tag in the generated OpenAPI documentation.
router = APIRouter(tags=["支付"], prefix="/api/v1/payments")
|
|
|
|
|
|
class CreateOrderRequest(BaseModel):
    # Request body for POST /api/v1/payments/orders.

    # Plan key looked up in PLANS; unknown keys are rejected by the endpoint
    # with HTTP 422, free plans (price 0) with HTTP 400.
    plan: str

    # Gateway name handed to get_payment_gateway(); defaults to WeChat Pay.
    payment_provider: str = "wechat"
|
|
|
|
|
|
class CreateOrderResponse(BaseModel):
    # Response body for POST /api/v1/payments/orders (HTTP 201).

    # String form of the new order's UUID; the same string is sent to the
    # gateway as its order id.
    order_id: str

    # Payment URL returned by the gateway's create_order().
    pay_url: str

    # Price copied from PLANS[plan]["price"].
    # NOTE(review): monetary amount held as float; unit (yuan vs. fen) is not
    # visible here — confirm against PLANS.
    amount: float

    currency: str = "CNY"

    # Freshly created orders are always reported as "pending"; the state only
    # changes when a gateway callback is processed.
    status: str = "pending"
|
|
|
|
|
|
class OrderStatusResponse(BaseModel):
    # Response body for GET /api/v1/payments/orders/{order_id}.

    # String form of the order's UUID primary key.
    order_id: str

    # Lifecycle state; this module writes "pending", "paid", "failed" and
    # "refunded".
    status: str

    # Plan key (a key of PLANS) the order was created for.
    plan: str

    amount: float

    # Gateway the order was created with (e.g. "wechat", "alipay").
    payment_provider: str

    # Gateway transaction id; stays None until a success callback sets it.
    payment_id: str | None = None

    # Timestamps rendered with datetime.isoformat(); None when not set.
    created_at: str | None = None
    paid_at: str | None = None
|
|
|
|
|
|
class RefundRequest(BaseModel):
    # Optional request body for POST /api/v1/payments/refund/{order_id}.

    # Free-text reason forwarded to the gateway's refund() call; may be empty.
    reason: str = ""
|
|
|
|
|
|
@router.post(
    "/orders",
    response_model=CreateOrderResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_payment_order(
    request: CreateOrderRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a pending payment order for a paid subscription plan.

    Looks up the requested plan, asks the selected gateway to open a remote
    order, stores a matching local ``PaymentOrder`` row in the ``"pending"``
    state and returns the gateway's payment URL.

    Raises:
        HTTPException: 422 when the plan key is not in ``PLANS``;
            400 when the plan is free (price 0) and needs no payment.
    """
    plan_key = request.plan
    provider = request.payment_provider

    plan_data = PLANS.get(plan_key)
    if plan_data is None:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"无效的套餐: {plan_key}",
        )

    price = plan_data["price"]
    if price == 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="免费套餐无需支付",
        )

    new_id = uuid.uuid4()
    gateway = get_payment_gateway(provider)
    remote_order = await gateway.create_order(
        order_id=str(new_id),
        amount=price,
        description=f"GEO平台{plan_data['name']}订阅",
        user_id=current_user.id,
        plan=plan_key,
    )

    # NOTE(review): the remote order is created before the local row is
    # committed; if the commit fails the gateway holds an order this database
    # does not know about, and its callback will be answered as "not found".
    db.add(
        PaymentOrderModel(
            id=new_id,
            user_id=current_user.id,
            plan=plan_key,
            amount=price,
            payment_provider=provider,
            status="pending",
            pay_url=remote_order.pay_url,
        )
    )
    await db.commit()

    return CreateOrderResponse(
        order_id=str(new_id),
        pay_url=remote_order.pay_url,
        amount=price,
        status="pending",
    )
|
|
|
|
|
|
def _parse_wechat_xml(raw: bytes) -> dict:
    """Flatten a WeChat Pay ``<xml><key>value</key>...</xml>`` body into a dict.

    Documents that declare a DTD or entities are rejected up front. Genuine
    notifications never contain them, and refusing them keeps this public
    endpoint clear of entity-expansion tricks. Rejected or malformed input
    yields an empty dict, which is exactly what the previous form-based
    parsing produced for an XML body.
    """
    import xml.etree.ElementTree as ET

    upper = raw.upper()
    if b"<!DOCTYPE" in upper or b"<!ENTITY" in upper:
        return {}
    try:
        root = ET.fromstring(raw)
    except ET.ParseError:
        return {}
    return {child.tag: (child.text or "") for child in root}


@router.post("/callback/wechat")
async def wechat_pay_callback(request: Request):
    """Receive a WeChat Pay asynchronous payment notification.

    The acknowledgement this module sends (``<xml><return_code>...``) is the
    WeChat Pay v2 protocol, whose notifications arrive as an XML body.
    ``Request.form()`` only understands urlencoded/multipart bodies and
    returns an empty form for anything else. XML bodies are therefore parsed
    into a ``{tag: text}`` dict; all other bodies keep the original
    form-based parsing, so existing form-posting clients still work.

    If ``verify_callback`` raises (e.g. a bad signature), the FAIL
    acknowledgement is returned so WeChat retries, instead of an unhandled
    HTTP 500. Everything else is delegated to ``_handle_payment_callback``.
    """
    content_type = request.headers.get("content-type", "").lower()
    if "xml" in content_type:
        request_data = _parse_wechat_xml(await request.body())
    else:
        body = await request.form()
        request_data = dict(body)

    gateway = get_payment_gateway("wechat")
    try:
        callback = await gateway.verify_callback(request_data)
    except Exception:
        logger.exception("[PaymentCallback] 回调验签异常: provider=wechat")
        return _wechat_fail_response()

    return await _handle_payment_callback(request_data, callback, "wechat")
|
|
|
|
|
|
@router.post("/callback/alipay")
async def alipay_callback(request: Request):
    """Receive an Alipay asynchronous payment notification.

    Alipay posts the notification as an urlencoded form. It treats the
    notification as delivered only when the response body is exactly the
    plain text ``success``. FastAPI serializes a returned ``str`` as JSON
    (``"success"``, quotes included), which Alipay rejects and keeps
    re-sending. The string acknowledgement from ``_handle_payment_callback``
    is therefore wrapped in a ``PlainTextResponse``.

    If ``verify_callback`` raises, ``fail`` is returned (Alipay retries)
    instead of an unhandled HTTP 500.
    """
    from fastapi.responses import PlainTextResponse

    body = await request.form()
    request_data = dict(body)

    gateway = get_payment_gateway("alipay")
    try:
        callback = await gateway.verify_callback(request_data)
    except Exception:
        logger.exception("[PaymentCallback] 回调验签异常: provider=alipay")
        return PlainTextResponse("fail")

    result = await _handle_payment_callback(request_data, callback, "alipay")
    if isinstance(result, str):
        return PlainTextResponse(result)
    return result
|
|
|
|
|
|
async def _handle_payment_callback(request_data: dict, callback, provider: str):
    """Process a verified callback in a dedicated DB session and commit it.

    The callback endpoints do not receive a session via ``get_db``, so a
    fresh ``AsyncSessionLocal`` session is opened here. If processing or the
    commit raises, the error is logged with its traceback and the session is
    rolled back. The provider's failure acknowledgement is then returned:
    the WeChat FAIL XML response, or the string ``"fail"`` otherwise.

    ``request_data`` is accepted for interface parity but not used here.
    """
    from app.database import AsyncSessionLocal

    async with AsyncSessionLocal() as session:
        try:
            outcome = await _process_callback(session, callback, provider)
            await session.commit()
        except Exception as exc:
            logger.exception(f"[PaymentCallback] 处理回调异常: {exc}")
            await session.rollback()
            return _wechat_fail_response() if provider == "wechat" else "fail"
        return outcome
|
|
|
|
|
|
async def _process_callback(db: AsyncSession, callback, provider: str):
    """Apply a verified gateway notification to its ``PaymentOrder`` row.

    Gateways re-deliver a notification until they get a success
    acknowledgement, sometimes concurrently, so this function is idempotent:

    * The order row is selected ``FOR UPDATE``, which serialises concurrent
      deliveries for the same order. Backends without row locks (e.g. SQLite)
      ignore this.
    * Orders already ``"paid"`` or ``"refunded"`` are acknowledged without
      being modified. Previously a duplicate success callback called
      ``subscribe()`` again and overwrote ``paid_at``. A late failure
      callback downgraded a paid order to ``"failed"``, and a replayed
      success re-activated a refunded order.
    * ``"pending"`` and ``"failed"`` orders may still move to ``"paid"``.

    The caller owns the transaction: this function mutates the order and
    calls ``subscribe()`` but never commits.

    Args:
        db: Session the caller commits or rolls back.
        callback: Result of ``gateway.verify_callback``; its ``order_id``,
            ``status``, ``payment_id`` and ``raw_data`` are read.
        provider: ``"wechat"`` selects the XML acknowledgements; any other
            value gets the plain ``"success"`` / ``"fail"`` strings.

    Returns:
        The success acknowledgement once the notification is recorded (or
        had already been recorded). The fail acknowledgement when the order id
        is malformed or unknown.
    """

    def _ack(ok: bool):
        # Provider-specific acknowledgement body.
        if provider == "wechat":
            return _wechat_success_response() if ok else _wechat_fail_response()
        return "success" if ok else "fail"

    try:
        order_uuid = uuid.UUID(str(callback.order_id))
    except ValueError:
        # A malformed id cannot match any order: answer like "not found".
        logger.warning(f"[PaymentCallback] 订单不存在: order_id={callback.order_id}")
        return _ack(False)

    stmt = (
        select(PaymentOrderModel)
        .where(PaymentOrderModel.id == order_uuid)
        .with_for_update()
    )
    order = (await db.execute(stmt)).scalar_one_or_none()

    if order is None:
        logger.warning(f"[PaymentCallback] 订单不存在: order_id={callback.order_id}")
        return _ack(False)

    if order.status in ("paid", "refunded"):
        # Duplicate or stale notification: acknowledge it so the gateway stops
        # retrying, but never re-subscribe, downgrade or resurrect the order.
        logger.info(
            f"[PaymentCallback] 重复回调已忽略: order_id={callback.order_id}, "
            f"status={order.status}, provider={provider}"
        )
        return _ack(True)

    if callback.status == "success":
        # NOTE(review): the paid amount reported by the gateway is not compared
        # with order.amount; add that check if the callback object exposes it.
        order.status = "paid"
        order.payment_id = callback.payment_id
        order.callback_data = callback.raw_data
        order.paid_at = datetime.now(timezone.utc)

        await subscribe(db, order.user_id, order.plan)

        logger.info(
            f"[PaymentCallback] 支付成功: order_id={callback.order_id}, "
            f"plan={order.plan}, provider={provider}"
        )
    else:
        order.status = "failed"
        order.callback_data = callback.raw_data

    return _ack(True)
|
|
|
|
|
|
def _wechat_success_response():
    """Build the XML body that tells WeChat Pay the notification was accepted."""
    from fastapi.responses import Response as XmlResponse

    payload = "<xml><return_code><![CDATA[SUCCESS]]></return_code></xml>"
    return XmlResponse(content=payload, media_type="application/xml")
|
|
|
|
|
|
def _wechat_fail_response():
    """Build the XML body that tells WeChat Pay the notification was rejected."""
    from fastapi.responses import Response as XmlResponse

    payload = "<xml><return_code><![CDATA[FAIL]]></return_code></xml>"
    return XmlResponse(content=payload, media_type="application/xml")
|
|
|
|
|
|
@router.get("/orders/{order_id}", response_model=OrderStatusResponse)
async def query_order_status(
    order_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the current state of one of the caller's payment orders.

    Raises:
        HTTPException: 400 when ``order_id`` is not a UUID string; 404 when no
            such order exists; 403 when the order belongs to another user.
    """
    try:
        parsed_id = uuid.UUID(order_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="无效的订单ID",
        )

    lookup = await db.execute(
        select(PaymentOrderModel).where(PaymentOrderModel.id == parsed_id)
    )
    record = lookup.scalar_one_or_none()

    if record is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="订单不存在",
        )
    if record.user_id != current_user.id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="无权查看此订单",
        )

    def _iso(moment):
        # ISO-8601 text for a timestamp, or None when it is unset.
        return moment.isoformat() if moment else None

    return OrderStatusResponse(
        order_id=str(record.id),
        status=record.status,
        plan=record.plan,
        amount=record.amount,
        payment_provider=record.payment_provider,
        payment_id=record.payment_id,
        created_at=_iso(record.created_at),
        paid_at=_iso(record.paid_at),
    )
|
|
|
|
|
|
@router.post("/refund/{order_id}")
async def refund_order(
    order_id: str,
    body: RefundRequest = RefundRequest(),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Refund a paid order through the gateway it was paid with.

    Access: the caller must be on the ``"enterprise"`` plan and must own the
    order. Previously any enterprise-plan user could refund any other user's
    paid order just by knowing its id. Ownership is now enforced the same way
    ``query_order_status`` does it.

    The order row is locked ``FOR UPDATE`` while the gateway refund runs, so
    two concurrent requests cannot both see ``"paid"`` and refund twice
    (ignored by backends without row locks, e.g. SQLite).

    The gateway receives the canonical ``str(UUID)``, the same string
    ``create_payment_order`` sent when the order was created. The raw path
    segment is not used because ``uuid.UUID`` also accepts upper-case,
    brace-wrapped, hyphen-less and ``urn:uuid:`` spellings that would not
    match the gateway's order id. The canonical id is also echoed back.

    Raises:
        HTTPException: 403 when the caller is not on the enterprise plan or
            does not own the order. 400 for a malformed id or an order that
            is not ``"paid"``. 404 when the order does not exist. 500 when the
            gateway reports the refund as unsuccessful.
    """
    user_plan = getattr(current_user, "plan", "free") or "free"
    if user_plan != "enterprise":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="仅企业管理员可执行退款操作",
        )

    try:
        oid = uuid.UUID(order_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="无效的订单ID",
        )

    stmt = (
        select(PaymentOrderModel)
        .where(PaymentOrderModel.id == oid)
        .with_for_update()
    )
    order = (await db.execute(stmt)).scalar_one_or_none()

    if order is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="订单不存在",
        )

    # NOTE(review): if enterprise administrators are meant to refund other
    # users' orders, replace this with a real admin-role check rather than
    # removing it; a plan name is not an authorization role.
    if order.user_id != current_user.id:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="无权操作此订单",
        )

    if order.status != "paid":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="仅已支付订单可退款",
        )

    canonical_id = str(oid)
    gateway = get_payment_gateway(order.payment_provider)
    success = await gateway.refund(canonical_id, order.amount, body.reason)

    if success:
        # NOTE(review): the subscription activated by subscribe() when the
        # order was paid is not revoked here — confirm that is intended.
        order.status = "refunded"
        await db.commit()
        return {"message": "退款成功", "order_id": canonical_id}

    raise HTTPException(
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        detail="退款失败",
    )
|