geo/backend/app/api/brands.py

199 lines
6.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Brands API endpoints."""
import json
import logging
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.api.deps import get_current_user
from app.api.competitors import router as competitors_router
from app.api.scoring import router as scoring_router
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.query import Query as QueryModel
from app.schemas.brand import BrandCreate, BrandUpdate, BrandResponse, BrandListResponse
from app.services.cache import get_cache_service, TTL_BRANDS
# Module-level logger named after this module (used by update_brand's rename sync).
logger = logging.getLogger(__name__)
# Router holding the brand endpoints; the competitors and scoring routers are
# included into it right after the _to_uuid helper.
router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
# Mount the competitors router onto the brands router. No extra prefix is
# passed, so its routes are served under whatever prefix this router gets.
# Note these includes run before the brand routes below are declared, so the
# included routes are registered first.
router.include_router(competitors_router)
# Mount the scoring router the same way (/{brand_id}/score/, /{brand_id}/score/history/).
router.include_router(scoring_router)
@router.get("/", response_model=BrandListResponse)
async def get_brands(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get all brands for the current user.

    The serialised list is cached per user under ``brands:{user_id}`` with
    expiry ``TTL_BRANDS`` (documented as 5 minutes). The create, update and
    delete endpoints in this module delete that key, so the next read
    rebuilds it.

    Returns:
        A dict with ``items`` (each brand dumped through ``BrandResponse`` in
        JSON mode) and ``total`` (the number of items). A cache miss returns
        exactly the payload that was just cached, so misses and hits match.
    """
    cache = get_cache_service()
    cache_key = f"brands:{current_user.id}"

    # Serve from cache when possible.
    cached = await cache.get_json(cache_key)
    if cached is not None:
        return cached

    # Avoid N+1: eager-load competitors and suggestions with selectinload
    # instead of lazily loading them once per brand during serialisation.
    stmt = (
        select(Brand)
        .where(Brand.user_id == _to_uuid(current_user.id))
        .options(
            selectinload(Brand.competitors),
            selectinload(Brand.suggestions),
        )
    )
    result = await db.execute(stmt)
    brands = result.scalars().all()

    response_data = {
        "items": [
            BrandResponse.model_validate(brand).model_dump(mode="json")
            for brand in brands
        ],
        "total": len(brands),
    }

    # Cache the serialised payload (TTL: 5 minutes) and return that same
    # payload: previously the ORM objects were returned here and serialised a
    # second time, through a different path than the cache-hit branch.
    await cache.set_json(cache_key, response_data, expire=TTL_BRANDS)
    return response_data
@router.post("/", response_model=BrandResponse, status_code=status.HTTP_201_CREATED)
async def create_brand(
    brand_data: BrandCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a new brand owned by the current user.

    The brand is added, committed and refreshed from the database, then the
    user's cached brand list (``brands:{user_id}``) is deleted so the next
    list request rebuilds it.
    """
    owner_id = _to_uuid(current_user.id)
    # Copy the accepted fields from the request body onto the new row.
    copied_fields = {
        field_name: getattr(brand_data, field_name)
        for field_name in (
            "name",
            "aliases",
            "website",
            "industry",
            "platforms",
            "frequency",
        )
    }
    new_brand = Brand(user_id=owner_id, **copied_fields)

    db.add(new_brand)
    await db.commit()
    await db.refresh(new_brand)

    # The cached brand list for this user is now stale.
    await get_cache_service().delete(f"brands:{current_user.id}")
    return new_brand
@router.get("/{brand_id}/", response_model=BrandResponse)
async def get_brand(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a single brand by ID, scoped to the current user.

    Raises:
        HTTPException: 404 (detail "品牌不存在", i.e. "brand not found") when no
            brand with this ID belongs to the current user.
    """
    owned_brand_query = select(Brand).where(
        Brand.id == brand_id,
        Brand.user_id == _to_uuid(current_user.id),
    )
    found = (await db.execute(owned_brand_query)).scalar_one_or_none()
    if found:
        return found
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="品牌不存在",
    )
@router.put("/{brand_id}/", response_model=BrandResponse)
async def update_brand(
    brand_id: uuid.UUID,
    brand_data: BrandUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update a brand owned by the current user.

    Only fields explicitly present in the request body are applied. If the
    brand is renamed, the current user's queries whose ``target_brand`` equals
    the old name are re-pointed to the new name. The old name is also replaced
    inside their ``brand_aliases``. All of this happens in the same commit as
    the brand update. The user's cached brand list is invalidated afterwards.

    Raises:
        HTTPException: 404 when the brand does not exist or belongs to another user.
    """
    # Coerce once and reuse for every ownership filter below, so the brand
    # lookup and the query sync compare against the same UUID value.
    user_id = _to_uuid(current_user.id)

    stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == user_id)
    result = await db.execute(stmt)
    brand = result.scalar_one_or_none()
    if brand is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )

    # Only the fields the client actually sent.
    update_data = brand_data.model_dump(exclude_unset=True)

    # On rename, keep the user's queries that reference this brand by name in
    # sync: update target_brand and swap the old name inside brand_aliases.
    old_name = brand.name
    new_name = update_data.get("name")
    if new_name and new_name != old_name:
        queries_stmt = select(QueryModel).where(
            QueryModel.user_id == user_id,
            QueryModel.target_brand == old_name,
        )
        queries_result = await db.execute(queries_stmt)
        related_queries = queries_result.scalars().all()
        for query in related_queries:
            query.target_brand = new_name
            if query.brand_aliases and old_name in query.brand_aliases:
                # Assign a new list instead of mutating in place; presumably
                # needed so the ORM detects the change (column type not
                # visible here -- confirm against the Query model).
                query.brand_aliases = [
                    new_name if alias == old_name else alias
                    for alias in query.brand_aliases
                ]
        if related_queries:
            logger.info(
                "Brand renamed from '%s' to '%s', synced %d queries",
                old_name,
                new_name,
                len(related_queries),
            )

    for field, value in update_data.items():
        setattr(brand, field, value)
    await db.commit()
    await db.refresh(brand)

    # Invalidate this user's cached brand list.
    cache = get_cache_service()
    await cache.delete(f"brands:{current_user.id}")
    return brand
@router.delete("/{brand_id}/", status_code=status.HTTP_204_NO_CONTENT)
async def delete_brand(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a brand owned by the current user.

    Responds 204 with no body. After the delete is committed, the user's
    cached brand list (``brands:{user_id}``) is dropped.

    Raises:
        HTTPException: 404 when the brand does not exist or belongs to another user.
    """
    lookup = select(Brand).where(
        Brand.id == brand_id,
        Brand.user_id == _to_uuid(current_user.id),
    )
    target = (await db.execute(lookup)).scalar_one_or_none()
    if not target:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )

    await db.delete(target)
    await db.commit()

    # The cached brand list for this user no longer matches the database.
    await get_cache_service().delete(f"brands:{current_user.id}")
    return None