geo/backend/app/api/competitors.py

419 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Competitors API endpoints."""
import asyncio
import json
import logging
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.config import settings
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.competitor import Competitor
from app.schemas.competitor import (
CompetitorCreate,
CompetitorResponse,
CompetitorListResponse,
CompetitorRecommendationItem,
CompetitorRecommendationResponse,
)
from app.utils.json_extractor import extract_json
logger = logging.getLogger(__name__)
router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def get_brand_if_owned(
brand_id: uuid.UUID,
current_user: User,
db: AsyncSession,
) -> Brand:
"""Helper to get brand if owned by current user."""
stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == _to_uuid(current_user.id))
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
if not brand:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="品牌不存在",
)
return brand
@router.get("/{brand_id}/competitors/", response_model=CompetitorListResponse)
async def get_competitors(
brand_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Get all competitors for a brand."""
# Verify brand ownership
await get_brand_if_owned(brand_id, current_user, db)
stmt = select(Competitor).where(Competitor.brand_id == brand_id)
result = await db.execute(stmt)
competitors = result.scalars().all()
return {"items": competitors, "total": len(competitors)}
@router.post("/{brand_id}/competitors/", response_model=CompetitorResponse, status_code=status.HTTP_201_CREATED)
async def create_competitor(
brand_id: uuid.UUID,
competitor_data: CompetitorCreate,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Create a new competitor for a brand."""
# Verify brand ownership
await get_brand_if_owned(brand_id, current_user, db)
# Check competitor limit (max 5)
count_stmt = select(func.count()).select_from(Competitor).where(
Competitor.brand_id == brand_id
)
count_result = await db.execute(count_stmt)
current_count = count_result.scalar_one()
if current_count >= 5:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="竞品数量已达上限最多5个",
)
# Check for duplicate name
dup_stmt = select(Competitor).where(
Competitor.brand_id == brand_id,
Competitor.name == competitor_data.name,
)
dup_result = await db.execute(dup_stmt)
if dup_result.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail=f"竞品 '{competitor_data.name}' 已存在",
)
competitor = Competitor(
brand_id=brand_id,
name=competitor_data.name,
aliases=competitor_data.aliases,
)
db.add(competitor)
await db.commit()
await db.refresh(competitor)
return competitor
@router.delete("/{brand_id}/competitors/{competitor_id}/", status_code=status.HTTP_204_NO_CONTENT)
async def delete_competitor(
brand_id: uuid.UUID,
competitor_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""Delete a competitor."""
# Verify brand ownership
await get_brand_if_owned(brand_id, current_user, db)
stmt = select(Competitor).where(
Competitor.id == competitor_id,
Competitor.brand_id == brand_id,
)
result = await db.execute(stmt)
competitor = result.scalar_one_or_none()
if not competitor:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="竞品不存在",
)
await db.delete(competitor)
await db.commit()
return None
@router.get("/{brand_id}/competitors/recommendations/", response_model=CompetitorRecommendationResponse)
async def get_competitor_recommendations(
brand_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: AsyncSession = Depends(get_db),
):
"""
基于品牌行业和名称使用LLM推荐竞品。
当ENABLE_LLM=True时调用DeepSeek API进行智能推荐
否则返回基于行业规则的默认推荐。
"""
brand = await get_brand_if_owned(brand_id, current_user, db)
# Get existing competitor names to exclude
existing_stmt = select(Competitor.name).where(Competitor.brand_id == brand_id)
existing_result = await db.execute(existing_stmt)
existing_names = [row[0] for row in existing_result.all()]
if settings.ENABLE_LLM and settings.DEEPSEEK_API_KEY:
try:
recommendations = await _get_llm_recommendations(
brand_name=brand.name,
industry=brand.industry,
existing_names=existing_names,
)
except Exception as e:
logger.warning(f"LLM推荐竞品失败回退到规则推荐: {e}")
recommendations = _get_rule_based_recommendations(
brand_name=brand.name,
industry=brand.industry,
existing_names=existing_names,
)
else:
recommendations = _get_rule_based_recommendations(
brand_name=brand.name,
industry=brand.industry,
existing_names=existing_names,
)
return CompetitorRecommendationResponse(
brand_name=brand.name,
industry=brand.industry,
recommendations=recommendations,
)
# ============================================================
# 竞品推荐辅助函数
# ============================================================
# 行业竞品知识库
INDUSTRY_COMPETITORS = {
"technology": [
{"name": "华为", "reason": "国内科技巨头,覆盖云计算、终端设备等多个领域"},
{"name": "阿里巴巴", "reason": "电商与云计算领域的领导者"},
{"name": "腾讯", "reason": "社交与游戏领域的巨头,云服务快速增长"},
{"name": "字节跳动", "reason": "内容平台与AI领域的新兴力量"},
{"name": "百度", "reason": "AI与搜索领域的先行者"},
{"name": "小米", "reason": "智能硬件与IoT生态的领先者"},
],
"finance": [
{"name": "蚂蚁集团", "reason": "数字金融科技领域的领导者"},
{"name": "京东数科", "reason": "数字科技与金融服务的创新者"},
{"name": "平安科技", "reason": "金融科技综合解决方案提供商"},
{"name": "陆金所", "reason": "线上财富管理平台的先行者"},
],
"retail": [
{"name": "京东", "reason": "自营电商与物流体系的领导者"},
{"name": "拼多多", "reason": "社交电商模式的创新者"},
{"name": "美团", "reason": "本地生活服务与即时零售的领先者"},
{"name": "盒马鲜生", "reason": "新零售模式的代表企业"},
],
"healthcare": [
{"name": "平安好医生", "reason": "互联网医疗健康平台的领先者"},
{"name": "微医", "reason": "数字医疗与健康管理的先行者"},
{"name": "丁香园", "reason": "专业医疗知识与服务的平台"},
{"name": "京东健康", "reason": "医药电商与互联网医疗的综合平台"},
],
"education": [
{"name": "好未来", "reason": "K12教育领域的领先品牌"},
{"name": "新东方", "reason": "语言培训与教育服务的老牌企业"},
{"name": "网易有道", "reason": "智能学习硬件与在线教育平台"},
{"name": "作业帮", "reason": "K12在线教育工具的领先者"},
],
"automotive": [
{"name": "比亚迪", "reason": "新能源汽车领域的领导者"},
{"name": "蔚来", "reason": "高端智能电动汽车的代表"},
{"name": "小鹏汽车", "reason": "智能驾驶技术的创新者"},
{"name": "理想汽车", "reason": "增程式电动SUV的先行者"},
],
"entertainment": [
{"name": "哔哩哔哩", "reason": "年轻人文化与视频社区的领先者"},
{"name": "快手", "reason": "短视频与直播平台的代表"},
{"name": "爱奇艺", "reason": "在线视频与内容制作的领先者"},
{"name": "芒果TV", "reason": "综艺内容与流媒体平台的创新者"},
],
"food": [
{"name": "美团外卖", "reason": "外卖与本地生活服务的领导者"},
{"name": "饿了么", "reason": "即时配送与餐饮服务平台"},
{"name": "瑞幸咖啡", "reason": "新零售咖啡连锁的代表"},
{"name": "海底捞", "reason": "餐饮服务体验的标杆企业"},
],
}
def _get_rule_based_recommendations(
brand_name: str,
industry: str | None,
existing_names: list[str],
) -> list[CompetitorRecommendationItem]:
"""
基于行业规则的竞品推荐。
Args:
brand_name: 品牌名称
industry: 行业标识
existing_names: 已有竞品名称列表(用于排除)
Returns:
推荐竞品列表
"""
recommendations = []
# 从行业知识库中获取推荐
if industry and industry in INDUSTRY_COMPETITORS:
for comp in INDUSTRY_COMPETITORS[industry]:
if comp["name"] not in existing_names and comp["name"] != brand_name:
recommendations.append(CompetitorRecommendationItem(
name=comp["name"],
reason=comp["reason"],
industry=industry,
))
# 如果行业推荐不足3个添加通用推荐
if len(recommendations) < 3:
generic_recs = [
CompetitorRecommendationItem(
name=f"{brand_name}行业领先者",
reason="基于市场份额分析推荐",
industry=industry,
),
CompetitorRecommendationItem(
name=f"{brand_name}新兴竞争者",
reason="增长势头强劲的新兴品牌",
industry=industry,
),
CompetitorRecommendationItem(
name=f"{brand_name}直接竞争对手",
reason="产品与服务高度重叠",
industry=industry,
),
]
for rec in generic_recs:
if rec.name not in existing_names and rec.name != brand_name:
if not any(r.name == rec.name for r in recommendations):
recommendations.append(rec)
# 最多返回5个推荐
return recommendations[:5]
async def _get_llm_recommendations(
brand_name: str,
industry: str | None,
existing_names: list[str],
) -> list[CompetitorRecommendationItem]:
"""
使用DeepSeek LLM API推荐竞品。
Args:
brand_name: 品牌名称
industry: 行业标识
existing_names: 已有竞品名称列表(用于排除)
Returns:
推荐竞品列表
Raises:
Exception: API调用失败
"""
industry_label = _get_industry_label(industry)
existing_str = "".join(existing_names) if existing_names else ""
prompt = f"""你是一个专业的市场分析专家。请为以下品牌推荐3-5个竞品品牌。
品牌名称: {brand_name}
所属行业: {industry_label}
已有竞品: {existing_str}
请返回JSON格式不要包含其他文字:
{{
"recommendations": [
{{"name": "竞品名称", "reason": "推荐理由(简短说明为什么是竞品)"}}
]
}}
要求:
1. 推荐的竞品必须是真实存在的品牌
2. 不要推荐已有竞品列表中的品牌
3. 不要推荐品牌自身
4. 推荐理由要具体,说明竞争关系
5. 优先推荐同行业中知名度高、市场份额大的品牌"""
try:
from openai import OpenAI
client = OpenAI(
api_key=settings.DEEPSEEK_API_KEY,
base_url="https://api.deepseek.com",
)
response = await asyncio.to_thread(
client.chat.completions.create,
model="deepseek-chat",
messages=[{"role": "user", "content": prompt}],
temperature=0.3,
max_tokens=800,
)
content = response.choices[0].message.content
if not content:
raise ValueError("API返回空响应")
# 提取JSON
json_str = extract_json(content)
data = json.loads(json_str)
recommendations = []
for item in data.get("recommendations", []):
name = item.get("name", "").strip()
reason = item.get("reason", "").strip()
if (
name
and reason
and name != brand_name
and name not in existing_names
and 2 <= len(name) <= 50
):
recommendations.append(CompetitorRecommendationItem(
name=name,
reason=reason,
industry=industry,
))
return recommendations[:5]
except json.JSONDecodeError as e:
logger.error(f"LLM推荐竞品JSON解析失败: {e}")
raise
except Exception as e:
logger.error(f"LLM推荐竞品API调用失败: {e}")
raise
def _get_industry_label(industry: str | None) -> str:
"""获取行业中文标签。"""
labels = {
"technology": "科技",
"finance": "金融",
"retail": "零售",
"healthcare": "医疗健康",
"education": "教育",
"manufacturing": "制造业",
"entertainment": "娱乐",
"travel": "旅游",
"food": "餐饮",
"automotive": "汽车",
"real_estate": "房地产",
"other": "其他",
}
return labels.get(industry, "未知") if industry else "未知"