193 lines
5.9 KiB
Python
193 lines
5.9 KiB
Python
"""Brands API endpoints."""
|
||
import json
|
||
import logging
|
||
import uuid
|
||
from typing import Annotated
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, status
|
||
from sqlalchemy import select, func
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
from sqlalchemy.orm import selectinload
|
||
|
||
from app.api.deps import get_current_user
|
||
from app.api.competitors import router as competitors_router
|
||
from app.api.scoring import router as scoring_router
|
||
from app.database import get_db
|
||
from app.models.user import User
|
||
from app.models.brand import Brand
|
||
from app.models.query import Query as QueryModel
|
||
from app.schemas.brand import BrandCreate, BrandUpdate, BrandResponse, BrandListResponse
|
||
from app.services.cache import get_cache_service, TTL_BRANDS
|
||
|
||
logger = logging.getLogger(__name__)
router = APIRouter()

# Nested routers mounted under the brands router, included in this order:
#   1. competitors endpoints
#   2. scoring endpoints (/{brand_id}/score/, /{brand_id}/score/history/)
for _sub_router in (competitors_router, scoring_router):
    router.include_router(_sub_router)
del _sub_router  # keep the module namespace free of the loop variable
|
||
|
||
|
||
@router.get("/", response_model=BrandListResponse)
async def get_brands(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List every brand owned by the current user.

    The serialized list is cached per user under ``brands:{user_id}`` with
    ``expire=TTL_BRANDS``. The create/update/delete endpoints in this module
    delete that key after they commit.

    Returns:
        A ``{"items": [...], "total": n}`` payload matching ``BrandListResponse``.
        The same payload is returned on cache hits and cache misses.
    """
    cache = get_cache_service()
    cache_key = f"brands:{current_user.id}"

    # Serve from cache when a previous request already stored the list.
    cached = await cache.get_json(cache_key)
    if cached is not None:
        return cached

    # Avoid N+1: batch-load competitors and suggestions for all brands up front,
    # so building each BrandResponse does not trigger per-brand lazy loads.
    stmt = (
        select(Brand)
        .where(Brand.user_id == current_user.id)
        .options(
            selectinload(Brand.competitors),
            selectinload(Brand.suggestions),
        )
    )
    result = await db.execute(stmt)
    brands = result.scalars().all()

    # Serialize once, then return exactly what gets cached. This keeps cache
    # misses and cache hits on the same response path and avoids a second
    # serialization of the ORM objects.
    response_data = {
        "items": [
            BrandResponse.model_validate(brand).model_dump(mode="json")
            for brand in brands
        ],
        "total": len(brands),
    }

    # Cache the serialized list for TTL_BRANDS (noted as 5 minutes).
    await cache.set_json(cache_key, response_data, expire=TTL_BRANDS)

    return response_data
|
||
|
||
|
||
@router.post("/", response_model=BrandResponse, status_code=status.HTTP_201_CREATED)
async def create_brand(
    brand_data: BrandCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a brand owned by the current user and return it.

    These request fields are copied onto the new row: name, aliases, website,
    industry, platforms, frequency. After the commit, the user's cached brand
    list (``brands:{user_id}``) is deleted so the next listing is rebuilt.
    """
    copied_fields = ("name", "aliases", "website", "industry", "platforms", "frequency")
    new_brand = Brand(
        user_id=current_user.id,
        **{field: getattr(brand_data, field) for field in copied_fields},
    )

    db.add(new_brand)
    await db.commit()
    # Reload the row's attributes from the database after the commit.
    await db.refresh(new_brand)

    # The user's cached brand list no longer includes this brand; drop it.
    await get_cache_service().delete(f"brands:{current_user.id}")

    return new_brand
|
||
|
||
|
||
@router.get("/{brand_id}/", response_model=BrandResponse)
async def get_brand(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a single brand owned by the current user.

    Args:
        brand_id: ID of the brand to fetch.
        current_user: Authenticated user; only their own brands are visible.
        db: Async database session.

    Raises:
        HTTPException: 404 when no brand with ``brand_id`` belongs to the user.
    """
    stmt = (
        select(Brand)
        .where(Brand.id == brand_id, Brand.user_id == current_user.id)
        # BrandResponse presumably reads these relationships: the list endpoint
        # preloads them for the same schema. Load them eagerly here too, so
        # response serialization does not need an implicit lazy load.
        # AsyncSession does not support implicit lazy loads.
        .options(
            selectinload(Brand.competitors),
            selectinload(Brand.suggestions),
        )
    )
    result = await db.execute(stmt)
    brand = result.scalar_one_or_none()

    if not brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )
    return brand
|
||
|
||
|
||
async def _retarget_queries(db: AsyncSession, user_id, old_name, new_name) -> None:
    """Re-point the user's queries from brand name ``old_name`` to ``new_name``.

    Sets ``target_brand`` on every query of ``user_id`` whose ``target_brand``
    equals ``old_name``. If that query's ``brand_aliases`` contains
    ``old_name``, the alias is replaced as well. Changes stay pending on
    ``db``; the caller commits.
    """
    matching = (
        await db.execute(
            select(QueryModel).where(
                QueryModel.user_id == user_id,
                QueryModel.target_brand == old_name,
            )
        )
    ).scalars().all()

    for q in matching:
        q.target_brand = new_name
        aliases = q.brand_aliases
        if aliases and old_name in aliases:
            # Assign a new list (rather than mutating in place) so the ORM sees a change.
            q.brand_aliases = [new_name if alias == old_name else alias for alias in aliases]

    if matching:
        logger.info(f"Brand renamed from '{old_name}' to '{new_name}', synced {len(matching)} queries")


@router.put("/{brand_id}/", response_model=BrandResponse)
async def update_brand(
    brand_id: uuid.UUID,
    brand_data: BrandUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Partially update a brand owned by the current user.

    Only fields explicitly sent in the request body are applied. When the
    name changes, the user's queries that targeted the old name are
    re-pointed to the new one in the same transaction. The user's cached
    brand list is dropped after the commit.

    Raises:
        HTTPException: 404 when no brand with ``brand_id`` belongs to the user.
    """
    lookup = await db.execute(
        select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
    )
    brand = lookup.scalar_one_or_none()
    if not brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )

    # exclude_unset: only the fields the client actually sent.
    changes = brand_data.model_dump(exclude_unset=True)

    previous_name = brand.name
    requested_name = changes.get("name")
    if requested_name and requested_name != previous_name:
        await _retarget_queries(db, current_user.id, previous_name, requested_name)

    for attr, value in changes.items():
        setattr(brand, attr, value)

    await db.commit()
    await db.refresh(brand)

    # Invalidate this user's cached brand list.
    cache = get_cache_service()
    await cache.delete(f"brands:{current_user.id}")

    return brand
|
||
|
||
|
||
@router.delete("/{brand_id}/", status_code=status.HTTP_204_NO_CONTENT)
async def delete_brand(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a brand owned by the current user.

    Returns:
        None; the route responds with 204 No Content.

    Raises:
        HTTPException: 404 when the brand does not exist or belongs to
            another user.
    """
    lookup = await db.execute(
        select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
    )
    target = lookup.scalar_one_or_none()
    if not target:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )

    await db.delete(target)
    await db.commit()

    # The user's cached brand list is now stale; drop it.
    cache = get_cache_service()
    await cache.delete(f"brands:{current_user.id}")

    return None