167 lines
4.9 KiB
Python
167 lines
4.9 KiB
Python
"""Brands API endpoints."""
|
||
import json
|
||
import uuid
|
||
from typing import Annotated
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, status
|
||
from sqlalchemy import select, func
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy.orm import selectinload
|
||
|
||
from app.api.deps import get_current_user
|
||
from app.api.competitors import router as competitors_router
|
||
from app.api.scoring import router as scoring_router
|
||
from app.database import get_db
|
||
from app.models.user import User
|
||
from app.models.brand import Brand
|
||
from app.schemas.brand import BrandCreate, BrandUpdate, BrandResponse, BrandListResponse
|
||
from app.services.cache import get_cache_service, TTL_BRANDS
|
||
|
||
router = APIRouter()

# Sub-routers mounted beneath the brands router, in this order:
#   competitors_router -> competitor endpoints nested under brands
#   scoring_router     -> /{brand_id}/score/ and /{brand_id}/score/history/
for _sub_router in (competitors_router, scoring_router):
    router.include_router(_sub_router)
del _sub_router  # keep the loop variable out of the module namespace
|
||
|
||
|
||
@router.get("/", response_model=BrandListResponse)
async def get_brands(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get all brands for the current user.

    The serialized list is cached per user under the key ``brands:{user_id}``
    with expiry ``TTL_BRANDS``. The create/update/delete endpoints in this
    module delete that key after writing.

    Returns:
        A dict ``{"items": [...], "total": int}`` shaped for
        ``BrandListResponse``. The same JSON-ready payload is returned on a
        cache hit and on a cache miss, so both paths respond identically.
    """
    cache = get_cache_service()
    cache_key = f"brands:{current_user.id}"

    # Serve from the cache first.
    cached = await cache.get_json(cache_key)
    if cached is not None:
        return cached

    # Fix N+1: load competitors and suggestions for all brands up front with
    # batched SELECT ... IN queries instead of one lazy load per brand.
    stmt = (
        select(Brand)
        .where(Brand.user_id == current_user.id)
        .options(
            selectinload(Brand.competitors),
            selectinload(Brand.suggestions),
        )
    )
    result = await db.execute(stmt)
    brands = result.scalars().all()

    response_data = {
        "items": [
            BrandResponse.model_validate(brand).model_dump(mode="json")
            for brand in brands
        ],
        "total": len(brands),
    }

    # Store the JSON-ready payload (expiry TTL_BRANDS; the original note says
    # 5 minutes — confirm against the cache service constants).
    await cache.set_json(cache_key, response_data, expire=TTL_BRANDS)

    # Return exactly what was cached. Returning raw ORM objects here made the
    # miss path differ from the hit path and serialized every brand twice.
    return response_data
|
||
|
||
|
||
@router.post("/", response_model=BrandResponse, status_code=status.HTTP_201_CREATED)
async def create_brand(
    brand_data: BrandCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a new brand owned by the current user.

    The brand is committed, reloaded from the database, and then the user's
    cached brand list (key ``brands:{user_id}``) is deleted so the next list
    request re-queries.
    """
    # Fields copied verbatim from the request body onto the new row.
    copied_fields = ("name", "aliases", "website", "industry", "platforms", "frequency")
    new_brand = Brand(
        user_id=current_user.id,
        **{field: getattr(brand_data, field) for field in copied_fields},
    )

    db.add(new_brand)
    await db.commit()
    await db.refresh(new_brand)

    # Invalidate this user's cached brand list.
    await get_cache_service().delete(f"brands:{current_user.id}")

    return new_brand
|
||
|
||
|
||
@router.get("/{brand_id}/", response_model=BrandResponse)
async def get_brand(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get a specific brand by ID.

    The query filters on both ``brand_id`` and the current user's id. A brand
    owned by another user is therefore reported exactly like a missing one.

    Raises:
        HTTPException: 404 when no brand with ``brand_id`` belongs to the
            current user.
    """
    # Eager-load the same relationships the list endpoint loads. If
    # BrandResponse serializes them (the list endpoint's N+1 note suggests it
    # does), this avoids implicit lazy loads on the async session.
    stmt = (
        select(Brand)
        .where(Brand.id == brand_id, Brand.user_id == current_user.id)
        .options(
            selectinload(Brand.competitors),
            selectinload(Brand.suggestions),
        )
    )
    result = await db.execute(stmt)
    brand = result.scalar_one_or_none()

    if not brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )
    return brand
|
||
|
||
|
||
@router.put("/{brand_id}/", response_model=BrandResponse)
async def update_brand(
    brand_id: uuid.UUID,
    brand_data: BrandUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Partially update a brand owned by the current user.

    Only fields explicitly present in the request body are written
    (``model_dump(exclude_unset=True)``); omitted fields keep their stored
    values. After committing, the user's cached brand list is deleted.

    Raises:
        HTTPException: 404 when no brand with ``brand_id`` belongs to the
            current user.
    """
    lookup = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
    brand = (await db.execute(lookup)).scalar_one_or_none()

    if not brand:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="品牌不存在")

    # Apply just the fields the client actually sent.
    changes = brand_data.model_dump(exclude_unset=True)
    for attr_name in changes:
        setattr(brand, attr_name, changes[attr_name])

    await db.commit()
    await db.refresh(brand)

    # Invalidate this user's cached brand list.
    await get_cache_service().delete(f"brands:{current_user.id}")

    return brand
|
||
|
||
|
||
@router.delete("/{brand_id}/", status_code=status.HTTP_204_NO_CONTENT)
async def delete_brand(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a brand owned by the current user.

    After the delete is committed, the user's cached brand list
    (key ``brands:{user_id}``) is deleted. Responds with 204 and no body.

    Raises:
        HTTPException: 404 when no brand with ``brand_id`` belongs to the
            current user.
    """
    lookup = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
    brand = (await db.execute(lookup)).scalar_one_or_none()

    if not brand:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="品牌不存在")

    await db.delete(brand)
    await db.commit()

    # Invalidate this user's cached brand list.
    await get_cache_service().delete(f"brands:{current_user.id}")

    return None