"""Brands API endpoints."""

import json
import uuid
from typing import Annotated

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.api.deps import get_current_user
from app.api.competitors import router as competitors_router
from app.api.scoring import router as scoring_router
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.schemas.brand import BrandCreate, BrandUpdate, BrandResponse, BrandListResponse
from app.services.cache import get_cache_service, TTL_BRANDS

router = APIRouter()

# Include competitors router under brands
router.include_router(competitors_router)
# Include scoring router under brands (/{brand_id}/score/, /{brand_id}/score/history/)
router.include_router(scoring_router)


def _brands_cache_key(user_id) -> str:
    """Return the cache key under which one user's serialized brand list is stored."""
    return f"brands:{user_id}"


async def _invalidate_brands_cache(user: User) -> None:
    """Delete the cached brand list of ``user`` so the next list request re-reads the DB."""
    cache = get_cache_service()
    await cache.delete(_brands_cache_key(user.id))


async def _get_owned_brand(
    db: AsyncSession,
    brand_id: uuid.UUID,
    user: User,
    *,
    with_relations: bool = False,
) -> Brand:
    """Load the brand ``brand_id`` belonging to ``user``.

    Args:
        db: Async database session.
        brand_id: Primary key of the brand.
        user: Owner; brands of other users are treated as non-existent.
        with_relations: When True, eagerly load ``competitors`` and
            ``suggestions`` (same options as the list endpoint).

    Returns:
        The matching ``Brand``.

    Raises:
        HTTPException: 404 when no brand with that id is owned by ``user``.
    """
    stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == user.id)
    if with_relations:
        stmt = stmt.options(
            selectinload(Brand.competitors),
            selectinload(Brand.suggestions),
        )
    result = await db.execute(stmt)
    brand = result.scalar_one_or_none()
    if brand is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )
    return brand


@router.get("/", response_model=BrandListResponse)
async def get_brands(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get all brands for the current user.

    The serialized list is cached per user for ``TTL_BRANDS``. The create,
    update and delete endpoints in this module invalidate that entry.
    """
    cache = get_cache_service()
    cache_key = _brands_cache_key(current_user.id)

    # Serve from cache first.
    cached = await cache.get_json(cache_key)
    if cached is not None:
        return cached

    # Avoid N+1: load competitors and suggestions for all brands in bulk.
    stmt = (
        select(Brand)
        .where(Brand.user_id == current_user.id)
        .options(
            selectinload(Brand.competitors),
            selectinload(Brand.suggestions),
        )
    )
    result = await db.execute(stmt)
    items = list(result.scalars().all())

    response_data = {
        "items": [BrandResponse.model_validate(b).model_dump(mode="json") for b in items],
        "total": len(items),
    }

    # Write to cache (TTL_BRANDS; previously documented as 5 minutes).
    await cache.set_json(cache_key, response_data, expire=TTL_BRANDS)

    # Return exactly what was cached so cache hits and misses go through the
    # same response validation, and the ORM objects are not serialized twice.
    return response_data


@router.post("/", response_model=BrandResponse, status_code=status.HTTP_201_CREATED)
async def create_brand(
    brand_data: BrandCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a new brand owned by the current user."""
    brand = Brand(
        user_id=current_user.id,
        name=brand_data.name,
        aliases=brand_data.aliases,
        website=brand_data.website,
        industry=brand_data.industry,
        platforms=brand_data.platforms,
        frequency=brand_data.frequency,
    )
    db.add(brand)
    await db.commit()
    await db.refresh(brand)

    # The user's cached brand list is now stale.
    await _invalidate_brands_cache(current_user)
    return brand


@router.get("/{brand_id}/", response_model=BrandResponse)
async def get_brand(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Get a specific brand by ID (404 if it does not belong to the current user)."""
    # Eager-load the same relationships the list endpoint loads before
    # validating into BrandResponse. NOTE(review): presumably BrandResponse
    # reads them; this avoids implicit lazy loads on the AsyncSession.
    return await _get_owned_brand(db, brand_id, current_user, with_relations=True)


@router.put("/{brand_id}/", response_model=BrandResponse)
async def update_brand(
    brand_id: uuid.UUID,
    brand_data: BrandUpdate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Update a brand; only fields explicitly provided in the request are changed."""
    brand = await _get_owned_brand(db, brand_id, current_user)

    # exclude_unset keeps omitted fields untouched (partial update semantics).
    update_data = brand_data.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(brand, field, value)

    await db.commit()
    await db.refresh(brand)

    # The user's cached brand list is now stale.
    await _invalidate_brands_cache(current_user)
    return brand


@router.delete("/{brand_id}/", status_code=status.HTTP_204_NO_CONTENT)
async def delete_brand(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Delete a brand owned by the current user."""
    brand = await _get_owned_brand(db, brand_id, current_user)

    await db.delete(brand)
    await db.commit()

    # The user's cached brand list is now stale.
    await _invalidate_brands_cache(current_user)
    return None