import logging
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, Request, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.payment_order import PaymentOrder as PaymentOrderModel
from app.models.user import User
from app.services.payment import get_payment_gateway
from app.services.subscription import PLANS, subscribe
# Logger shared by every handler in this module.
logger = logging.getLogger(__name__)
# All routes below are mounted under /api/v1/payments and tagged "支付" (payments).
router = APIRouter(tags=["支付"], prefix="/api/v1/payments")
class CreateOrderRequest(BaseModel):
    """Request body for creating a payment order."""

    # Key into PLANS; checked by create_payment_order, not by this model.
    plan: str
    # Provider name handed to get_payment_gateway; defaults to WeChat.
    payment_provider: str = "wechat"
class CreateOrderResponse(BaseModel):
    """Response body returned after a payment order is created."""

    # String form of the order's UUID.
    order_id: str
    # URL obtained from the payment gateway for completing the payment.
    pay_url: str
    # Amount charged (create_payment_order fills in the plan's price).
    amount: float
    currency: str = "CNY"
    # Newly created orders are reported as "pending".
    status: str = "pending"
class OrderStatusResponse(BaseModel):
    """Snapshot of a stored payment order, returned by query_order_status."""

    order_id: str
    # Values written by this module: "pending", "paid", "failed", "refunded".
    status: str
    plan: str
    amount: float
    payment_provider: str
    # Provider-side payment id; written by the success-callback path, else None.
    payment_id: str | None = None
    # ISO-8601 strings produced via datetime.isoformat(); None when unset.
    created_at: str | None = None
    paid_at: str | None = None
class RefundRequest(BaseModel):
    """Optional request body for refund_order."""

    # Free-text reason forwarded to the gateway's refund call; may be empty.
    reason: str = ""
@router.post("/orders", response_model=CreateOrderResponse, status_code=status.HTTP_201_CREATED)
async def create_payment_order(
    request: CreateOrderRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a pending payment order for a paid subscription plan.

    The order is registered with the selected payment gateway first (to get
    a pay URL) and then persisted locally with status ``"pending"``.

    Raises:
        HTTPException: 422 if ``request.plan`` is not a key of ``PLANS``;
            400 if the plan's price is 0 (free plans need no payment).
    """
    plan_name = request.plan
    plan_info = PLANS.get(plan_name)
    if plan_info is None:
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=f"无效的套餐: {request.plan}",
        )

    price = plan_info["price"]
    if price == 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="免费套餐无需支付",
        )

    new_id = uuid.uuid4()
    new_id_str = str(new_id)
    provider = request.payment_provider

    # Ask the gateway for a payable order before storing anything locally.
    gateway_order = await get_payment_gateway(provider).create_order(
        order_id=new_id_str,
        amount=price,
        description=f"GEO平台{plan_info['name']}订阅",
        user_id=current_user.id,
        plan=plan_name,
    )
    pay_url = gateway_order.pay_url

    db.add(
        PaymentOrderModel(
            id=new_id,
            user_id=current_user.id,
            plan=plan_name,
            amount=price,
            payment_provider=provider,
            status="pending",
            pay_url=pay_url,
        )
    )
    await db.commit()

    return CreateOrderResponse(
        order_id=new_id_str,
        pay_url=pay_url,
        amount=price,
        status="pending",
    )
@router.post("/callback/wechat")
async def wechat_pay_callback(request: Request):
    """Receive a WeChat Pay asynchronous payment notification.

    The request body is parsed as form fields, verified by the WeChat gateway,
    and handed to ``_handle_payment_callback`` for processing.

    NOTE(review): WeChat Pay notifications are normally XML (API v2) or JSON
    (API v3), while ``request.form()`` only parses form-encoded bodies -
    confirm the gateway's ``verify_callback`` really expects form fields.
    """
    form_fields = dict(await request.form())
    verified = await get_payment_gateway("wechat").verify_callback(form_fields)
    return await _handle_payment_callback(form_fields, verified, "wechat")
@router.post("/callback/alipay")
async def alipay_callback(request: Request):
    """Receive an Alipay asynchronous payment notification.

    The form-encoded body is verified by the Alipay gateway and the verified
    result is handed to ``_handle_payment_callback`` for processing.
    """
    form_fields = dict(await request.form())
    verified = await get_payment_gateway("alipay").verify_callback(form_fields)
    return await _handle_payment_callback(form_fields, verified, "alipay")
async def _handle_payment_callback(request_data: dict, callback, provider: str):
    """Process a verified payment callback in its own DB session and build the ack.

    A dedicated ``AsyncSessionLocal`` session is opened because the callback
    routes do not depend on ``get_db``. The session is committed when
    ``_process_callback`` succeeds. On any exception it is rolled back, and a
    failure acknowledgement is returned so the provider can retry.

    Args:
        request_data: Raw callback fields. Unused here; kept so existing
            callers keep working.
        callback: Result of the gateway's ``verify_callback``.
        provider: ``"wechat"`` or ``"alipay"``; selects the ack format.

    Returns:
        A response object for the provider. Plain-string acknowledgements
        are wrapped in ``PlainTextResponse`` so the body is the bare word
        (``success`` / ``fail``) rather than a JSON-encoded string.
    """
    from fastapi.responses import PlainTextResponse

    from app.database import AsyncSessionLocal

    async with AsyncSessionLocal() as db:
        try:
            result = await _process_callback(db, callback, provider)
            await db.commit()
        except Exception as e:
            logger.exception("[PaymentCallback] 处理回调异常: %s", e)
            await db.rollback()
            if provider == "wechat":
                return _wechat_fail_response()
            return PlainTextResponse("fail")

    # FastAPI JSON-encodes a bare str return value ("success" with quotes);
    # Alipay only accepts the literal body `success`, so send plain text.
    if isinstance(result, str):
        return PlainTextResponse(result)
    return result
async def _process_callback(db: AsyncSession, callback, provider: str):
    """Apply a verified payment callback to its stored order.

    Loads the order referenced by ``callback.order_id`` and then:

    * success callback: marks the order ``"paid"``, stores the provider
      payment id, raw payload and payment time, and activates the plan via
      ``subscribe``;
    * any other callback status: marks the order ``"failed"``.

    Orders already ``"paid"`` or ``"refunded"`` are acknowledged without
    changes. Providers may deliver the same notification more than once, and
    re-processing would call ``subscribe`` again or roll a settled order back.

    The caller owns the transaction (commit / rollback).

    Args:
        db: Session used for the lookup and the updates.
        callback: Verified callback; reads ``order_id``, ``status``,
            ``payment_id`` and ``raw_data``.
        provider: ``"wechat"`` or ``"alipay"``; selects the ack format.

    Returns:
        For ``"wechat"`` the XML success/fail ``Response``; otherwise the
        string ``"success"`` or ``"fail"``.

    Raises:
        ValueError: ``callback.order_id`` is not a valid UUID string.
    """

    def _ack(ok: bool):
        # WeChat gets an XML Response object, other providers a bare status word.
        if provider == "wechat":
            return _wechat_success_response() if ok else _wechat_fail_response()
        return "success" if ok else "fail"

    # FOR UPDATE (on databases that support row locks) serializes concurrent
    # deliveries of the same notification so the status check below can't race.
    stmt = (
        select(PaymentOrderModel)
        .where(PaymentOrderModel.id == uuid.UUID(callback.order_id))
        .with_for_update()
    )
    result = await db.execute(stmt)
    order = result.scalar_one_or_none()
    if order is None:
        logger.warning(f"[PaymentCallback] 订单不存在: order_id={callback.order_id}")
        return _ack(False)

    # Settled orders: ignore the re-delivery but still acknowledge it so the
    # provider stops retrying.
    if order.status in ("paid", "refunded"):
        logger.info(
            f"[PaymentCallback] 重复回调已忽略: order_id={callback.order_id}, "
            f"status={order.status}, provider={provider}"
        )
        return _ack(True)

    if callback.status == "success":
        order.status = "paid"
        order.payment_id = callback.payment_id
        order.callback_data = callback.raw_data
        order.paid_at = datetime.now(timezone.utc)
        await subscribe(db, order.user_id, order.plan)
        logger.info(
            f"[PaymentCallback] 支付成功: order_id={callback.order_id}, "
            f"plan={order.plan}, provider={provider}"
        )
    else:
        order.status = "failed"
        order.callback_data = callback.raw_data
    return _ack(True)
def _wechat_success_response():
    """Return the WeChat Pay (API v2, XML) acknowledgement for a handled notification.

    WeChat re-sends a notification until it receives ``return_code`` SUCCESS.
    An empty body is not an acknowledgement and cannot be told apart from
    the failure response.
    """
    from fastapi.responses import Response

    return Response(
        content=(
            "<xml>"
            "<return_code><![CDATA[SUCCESS]]></return_code>"
            "<return_msg><![CDATA[OK]]></return_msg>"
            "</xml>"
        ),
        media_type="application/xml",
    )
def _wechat_fail_response():
    """Return the WeChat Pay (API v2, XML) failure reply so the notification is retried.

    ``return_code`` FAIL tells WeChat the notification was not processed.
    """
    from fastapi.responses import Response

    return Response(
        content=(
            "<xml>"
            "<return_code><![CDATA[FAIL]]></return_code>"
            "<return_msg><![CDATA[FAIL]]></return_msg>"
            "</xml>"
        ),
        media_type="application/xml",
    )
@router.get("/orders/{order_id}", response_model=OrderStatusResponse)
async def query_order_status(
    order_id: str,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the status of one of the current user's payment orders.

    Raises:
        HTTPException: 400 if ``order_id`` is not a UUID; 404 if no such
            order exists; 403 if the order belongs to another user.
    """
    try:
        parsed_id = uuid.UUID(order_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="无效的订单ID",
        )

    rows = await db.execute(
        select(PaymentOrderModel).where(PaymentOrderModel.id == parsed_id)
    )
    found = rows.scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="订单不存在")
    if found.user_id != current_user.id:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="无权查看此订单")

    def _iso(moment):
        # Timestamps are optional on the row; serialize only when present.
        return moment.isoformat() if moment else None

    return OrderStatusResponse(
        order_id=str(found.id),
        status=found.status,
        plan=found.plan,
        amount=found.amount,
        payment_provider=found.payment_provider,
        payment_id=found.payment_id,
        created_at=_iso(found.created_at),
        paid_at=_iso(found.paid_at),
    )
@router.post("/refund/{order_id}")
async def refund_order(
    order_id: str,
    body: RefundRequest = RefundRequest(),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Refund a paid order through the gateway it was paid with.

    Only callers whose ``plan`` is ``"enterprise"`` may issue refunds.

    NOTE(review): this checks the caller's subscription plan, not an admin
    role, and does not check that the order belongs to the caller. As a
    result, any enterprise subscriber can refund any order. Confirm this is
    the intended authorization model.

    Raises:
        HTTPException: 403 if the caller is not on the enterprise plan;
            400 if ``order_id`` is not a UUID or the order is not ``"paid"``;
            404 if the order does not exist; 500 if the gateway reports the
            refund as unsuccessful.
    """
    user_plan = getattr(current_user, "plan", "free") or "free"
    if user_plan != "enterprise":
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="仅企业管理员可执行退款操作",
        )
    try:
        oid = uuid.UUID(order_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="无效的订单ID",
        )

    # Lock the row (where the database supports it) so two concurrent refund
    # requests cannot both see "paid" and refund the same payment twice.
    stmt = (
        select(PaymentOrderModel)
        .where(PaymentOrderModel.id == oid)
        .with_for_update()
    )
    result = await db.execute(stmt)
    order = result.scalar_one_or_none()
    if order is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="订单不存在",
        )
    if order.status != "paid":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="仅已支付订单可退款",
        )

    # Use the canonical str(uuid) form, i.e. the id create_payment_order sent
    # to the gateway. uuid.UUID() also accepts upper-case, braced or
    # un-hyphenated spellings the gateway would not recognize.
    canonical_id = str(order.id)
    gateway = get_payment_gateway(order.payment_provider)
    success = await gateway.refund(canonical_id, order.amount, body.reason)
    if not success:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="退款失败",
        )

    # NOTE(review): the subscription granted by the payment callback is not
    # revoked here - confirm whether that should happen on refund.
    order.status = "refunded"
    await db.commit()
    return {"message": "退款成功", "order_id": canonical_id}