514 lines
17 KiB
Python
514 lines
17 KiB
Python
"""Onboarding API endpoints - 新用户引导流程"""
|
||
import logging
|
||
import uuid
|
||
from typing import Optional
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||
from pydantic import BaseModel, Field
|
||
from sqlalchemy import select, func
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.api.deps import get_current_user
|
||
from app.api.competitors import (
|
||
INDUSTRY_COMPETITORS,
|
||
_get_rule_based_recommendations,
|
||
_get_llm_recommendations,
|
||
CompetitorRecommendationItem,
|
||
)
|
||
from app.config import settings
|
||
from app.database import get_db
|
||
from app.models.user import User
|
||
from app.models.brand import Brand
|
||
from app.models.competitor import Competitor
|
||
from app.models.citation_record import CitationRecord
|
||
from app.models.query import Query as QueryModel
|
||
from app.services.scoring_service import ScoringService
|
||
from app.schemas.brand import BrandCreate, BrandResponse
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
router = APIRouter(prefix="/onboarding", tags=["onboarding"])
|
||
|
||
|
||
# ------------------------------------------------------------------
|
||
# Request / Response schemas
|
||
# ------------------------------------------------------------------
|
||
|
||
class OnboardingBrandCreate(BaseModel):
    """Simplified brand-creation request body used by the onboarding flow.

    Only a name, an optional description and an optional industry are
    accepted here.
    """

    # Required brand name, 2-50 characters.
    name: str = Field(
        ...,
        min_length=2,
        max_length=50,
        description="品牌名称",
    )
    # Optional free-text brand description, at most 500 characters.
    description: Optional[str] = Field(
        default=None,
        max_length=500,
        description="品牌描述",
    )
    # Optional industry label, at most 50 characters.
    industry: Optional[str] = Field(
        default=None,
        max_length=50,
        description="行业",
    )
|
||
|
||
|
||
class OnboardingStatusResponse(BaseModel):
    """Response body describing how far the user is through onboarding."""

    # True once onboarding is considered finished.
    completed: bool

    # String form of the user's brand id; None when no brand exists yet.
    brand_id: Optional[str] = None

    # Step number the client should show next.
    current_step: int
|
||
|
||
|
||
class CompetitorRecommendationSimple(BaseModel):
    """A single competitor recommendation in its simplified form."""

    # Name of the recommended competitor.
    name: str

    # Human-readable explanation of why it is recommended.
    description: str

    # Confidence value attached to the recommendation.
    confidence: float
|
||
|
||
|
||
class CompetitorRecommendationSimpleResponse(BaseModel):
    """Response envelope for the simplified competitor recommendations."""

    # Recommended competitors, in the order they were produced.
    recommendations: list[CompetitorRecommendationSimple]
|
||
|
||
|
||
class HealthReportResponse(BaseModel):
    """Initial brand health-score report returned during onboarding."""

    # Identity of the brand the report describes.
    brand_id: str
    brand_name: str

    # Aggregate score for the brand.
    overall_score: float

    # Mapping of platform name -> score for that platform.
    platform_scores: dict

    # Human-readable findings, split into positives and negatives.
    strengths: list[str]
    weaknesses: list[str]

    # One dict per competitor; each carries "name" and "score" keys.
    competitor_scores: list[dict]
|
||
|
||
|
||
class ActionSuggestion(BaseModel):
    """One recommended follow-up action for the brand owner."""

    # Short headline for the suggestion.
    title: str

    # Longer explanation of what to do and why.
    description: str

    # Urgency label: high / medium / low.
    priority: str

    # Category label, e.g. coverage, keyword, sentiment, platform.
    action_type: str
|
||
|
||
|
||
class ActionSuggestionsResponse(BaseModel):
    """Response envelope carrying the list of action suggestions."""

    # Suggestions in the order they should be presented.
    suggestions: list[ActionSuggestion]
|
||
|
||
|
||
class OnboardingCompleteResponse(BaseModel):
    """Response body returned when onboarding is marked as complete."""

    # Whether the completion request was accepted.
    success: bool
|
||
|
||
|
||
# ------------------------------------------------------------------
|
||
# Endpoints
|
||
# ------------------------------------------------------------------
|
||
|
||
@router.get("/status", response_model=OnboardingStatusResponse)
async def get_onboarding_status(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Report the onboarding status of the current user.

    Onboarding is considered complete as soon as the user owns at least one
    brand, so the answer is derived from the brands table:

    - a brand exists  -> completed=True, brand_id set, current_step=4
    - no brand exists -> completed=False, brand_id=None, current_step=1

    Args:
        current_user: Authenticated user resolved by the auth dependency.
        db: Async database session.

    Returns:
        OnboardingStatusResponse describing the user's progress.
    """
    # Nothing in the onboarding brand endpoint stops a user from owning more
    # than one brand, and scalar_one_or_none() raises MultipleResultsFound in
    # that case. Fetch at most one row and take the first instead.
    stmt = select(Brand).where(Brand.user_id == current_user.id).limit(1)
    result = await db.execute(stmt)
    brand = result.scalars().first()

    if brand is None:
        return OnboardingStatusResponse(
            completed=False,
            brand_id=None,
            current_step=1,
        )

    return OnboardingStatusResponse(
        completed=True,
        brand_id=str(brand.id),
        current_step=4,
    )
|
||
|
||
|
||
@router.post("/brand", response_model=BrandResponse, status_code=status.HTTP_201_CREATED)
async def create_onboarding_brand(
    brand_data: OnboardingBrandCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create the user's brand as part of the onboarding flow.

    The simplified payload is expanded into a full ``BrandCreate`` using
    fixed defaults (no aliases, no website, platforms "wenxin" and "kimi",
    weekly frequency), and a ``Brand`` row owned by the current user is then
    stored.

    Args:
        brand_data: Simplified onboarding payload.
        current_user: Authenticated user who will own the brand.
        db: Async database session.

    Returns:
        The persisted ``Brand`` instance, refreshed after commit.
    """
    # NOTE(review): brand_data.description is accepted by the request model
    # but is never read here, so it is not persisted — confirm whether that
    # is intended.
    validated = BrandCreate(
        name=brand_data.name,
        aliases=[],
        website=None,
        industry=brand_data.industry,
        platforms=["wenxin", "kimi"],
        frequency="weekly",
    )

    # Copy the validated fields onto the ORM object one-to-one.
    copied_fields = ("name", "aliases", "website", "industry", "platforms", "frequency")
    column_values = {field: getattr(validated, field) for field in copied_fields}
    new_brand = Brand(user_id=current_user.id, **column_values)

    db.add(new_brand)
    await db.commit()
    # Reload so that database-generated values are present on the instance.
    await db.refresh(new_brand)
    return new_brand
|
||
|
||
|
||
@router.get("/competitor-recommendations", response_model=CompetitorRecommendationSimpleResponse)
async def get_onboarding_competitor_recommendations(
    brand_id: uuid.UUID = Query(..., description="品牌ID"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Recommend competitors for one of the current user's brands.

    Reuses the recommendation helpers from the competitors API. The LLM
    helper is tried when ``settings.ENABLE_LLM`` and
    ``settings.DEEPSEEK_API_KEY`` are both set; if it raises, or if LLM mode
    is off, the rule-based helper is used instead.

    Args:
        brand_id: Id of the brand to recommend competitors for.
        current_user: Authenticated user; must own the brand.
        db: Async database session.

    Returns:
        CompetitorRecommendationSimpleResponse with one entry per
        recommended competitor.

    Raises:
        HTTPException: 404 when the brand does not exist or belongs to
            another user.
    """
    # Ownership check: the brand must exist and belong to the caller.
    brand_lookup = await db.execute(
        select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
    )
    brand = brand_lookup.scalar_one_or_none()
    if not brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )

    # Names of competitors already attached to the brand, to be excluded.
    names_lookup = await db.execute(
        select(Competitor.name).where(Competitor.brand_id == brand_id)
    )
    existing_names = list(names_lookup.scalars().all())

    # Both strategies take the same keyword arguments.
    rec_kwargs = {
        "brand_name": brand.name,
        "industry": brand.industry,
        "existing_names": existing_names,
    }

    rec_items = None
    need_rule_based = True
    if settings.ENABLE_LLM and settings.DEEPSEEK_API_KEY:
        try:
            rec_items = await _get_llm_recommendations(**rec_kwargs)
        except Exception as e:
            # Best-effort: any LLM failure falls back to the rule-based path.
            logger.warning(f"Onboarding LLM竞品推荐失败,回退规则推荐: {e}")
        else:
            need_rule_based = False

    if need_rule_based:
        rec_items = _get_rule_based_recommendations(**rec_kwargs)

    # The confidence depends only on the brand's industry, so it is the same
    # for every item: 0.8 for an industry listed in INDUSTRY_COMPETITORS,
    # 0.5 otherwise.
    known_industry = brand.industry and brand.industry in INDUSTRY_COMPETITORS
    confidence = 0.8 if known_industry else 0.5

    simplified = [
        CompetitorRecommendationSimple(
            name=item.name,
            description=item.reason,
            confidence=confidence,
        )
        for item in rec_items
    ]
    return CompetitorRecommendationSimpleResponse(recommendations=simplified)
|
||
|
||
|
||
def _health_platform_rates(citations: list) -> dict[str, float]:
    """Return the citation rate (percent, 2 decimals) for each platform."""
    tallies: dict[str, dict] = {}
    for record in citations:
        platform = record.platform or "unknown"
        tally = tallies.setdefault(platform, {"total": 0, "cited": 0})
        tally["total"] += 1
        if record.cited:
            tally["cited"] += 1

    # Every tally has total >= 1 because it is created on first sighting.
    return {
        platform: round(tally["cited"] / tally["total"] * 100, 2)
        for platform, tally in tallies.items()
    }


def _health_sentiment_counts(cited: list) -> dict[str, int]:
    """Count positive / neutral / negative sentiment over cited records."""
    counts = {"positive": 0, "neutral": 0, "negative": 0}
    for record in cited:
        label = record.sentiment or "neutral"
        # Labels outside the three known buckets are ignored.
        if label in counts:
            counts[label] += 1
    return counts


def _health_competitor_mentions(cited: list, competitor_names: list) -> dict[str, int]:
    """Count, per competitor, the cited records that mention it (zeros omitted)."""
    mentions: dict[str, int] = {}
    for name in competitor_names:
        hits = sum(
            1 for record in cited
            if record.competitor_brands and name in record.competitor_brands
        )
        if hits > 0:
            mentions[name] = hits
    return mentions


def _health_strengths_weaknesses(
    total: int,
    cited_count: int,
    platform_scores: dict,
    sentiment_counts: dict,
) -> tuple:
    """Derive the (strengths, weaknesses) message lists for the report."""
    strengths: list[str] = []
    weaknesses: list[str] = []

    if total == 0:
        strengths.append("品牌已创建")
        weaknesses.append("尚无引用数据")
    else:
        rate = cited_count / total * 100
        if rate >= 50:
            strengths.append(f"提及率较高 ({round(rate, 1)}%)")
        else:
            weaknesses.append(f"提及率偏低 ({round(rate, 1)}%)")

        for platform, score in platform_scores.items():
            if score >= 60:
                strengths.append(f"{platform} 平台表现良好 ({score}%)")
            elif score > 0:
                weaknesses.append(f"{platform} 平台覆盖率不足 ({score}%)")

        # This exact weakness text is matched by the action-suggestions
        # endpoint, so it must not be reworded.
        if sentiment_counts["positive"] > sentiment_counts["negative"]:
            strengths.append("情感倾向正面")
        elif sentiment_counts["negative"] > sentiment_counts["positive"]:
            weaknesses.append("情感倾向偏负面")

    # Never return an empty list on either side.
    if not strengths:
        strengths.append("已有初步引用数据")
    if not weaknesses:
        weaknesses.append("暂无明显短板")
    return strengths, weaknesses


@router.get("/health-report/{brand_id}", response_model=HealthReportResponse)
async def get_onboarding_health_report(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Build the initial health-score report for a brand.

    The report is computed from the citation records of the user's queries
    whose ``target_brand`` equals the brand name. When the user has no such
    queries, an "initialised" report with ``overall_score`` 0 is returned.

    Args:
        brand_id: Id of the brand to report on.
        current_user: Authenticated user; must own the brand.
        db: Async database session.

    Returns:
        HealthReportResponse with the overall score, per-platform citation
        rates, strengths / weaknesses and competitor mention rates.

    Raises:
        HTTPException: 404 when the brand does not exist or belongs to
            another user.
    """
    # Ownership check: the brand must exist and belong to the caller.
    brand_stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == current_user.id)
    brand = (await db.execute(brand_stmt)).scalar_one_or_none()
    if not brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )

    # Queries are linked to the brand by name, not by foreign key.
    queries_stmt = select(QueryModel).where(
        QueryModel.user_id == current_user.id,
        QueryModel.target_brand == brand.name,
    )
    queries = list((await db.execute(queries_stmt)).scalars().all())

    # No queries yet -> report the initialised state.
    if not queries:
        return HealthReportResponse(
            brand_id=str(brand.id),
            brand_name=brand.name,
            overall_score=0.0,
            platform_scores={},
            strengths=["品牌已创建,等待数据采集"],
            weaknesses=["尚无AI平台引用数据,需等待查询执行"],
            competitor_scores=[],
        )

    citations_stmt = select(CitationRecord).where(
        CitationRecord.query_id.in_([q.id for q in queries]),
    )
    citations = list((await db.execute(citations_stmt)).scalars().all())

    total = len(citations)
    cited = [record for record in citations if record.cited]

    platform_scores = _health_platform_rates(citations)
    sentiment_counts = _health_sentiment_counts(cited)

    # Function-scope import kept where the original had it.
    # NOTE(review): presumably avoids an import cycle — confirm.
    from app.schemas.scoring import CitationResult

    citation_results = [
        CitationResult(
            cited=record.cited,
            position=record.citation_position,
            citation_text=record.citation_text,
            sentiment=record.sentiment or "neutral",
            confidence=record.confidence or 0.0,
        )
        for record in cited
    ]
    positions = [record.citation_position for record in cited]

    competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id)
    competitors = list((await db.execute(competitor_stmt)).scalars().all())
    competitor_mentions = _health_competitor_mentions(
        cited, [competitor.name for competitor in competitors]
    )

    # NOTE(review): when queries exist but no citation records do, this is
    # called with total_queries=0 — confirm calculate_v2 tolerates that.
    v2_result = ScoringService().calculate_v2(
        mentioned_count=len(cited),
        total_queries=total,
        positions=positions,
        sentiment_counts=sentiment_counts,
        citations=citation_results,
        brand_mentions=len(cited),
        competitor_mentions=competitor_mentions,
    )

    strengths, weaknesses = _health_strengths_weaknesses(
        total, len(cited), platform_scores, sentiment_counts
    )

    # A competitor only appears here with >= 1 mention, which implies
    # total >= 1, so the division is safe.
    competitor_scores = [
        {"name": name, "score": round(hits / total * 100, 2)}
        for name, hits in competitor_mentions.items()
    ]

    return HealthReportResponse(
        brand_id=str(brand.id),
        brand_name=brand.name,
        overall_score=round(v2_result.overall_score, 2),
        platform_scores=platform_scores,
        strengths=strengths,
        weaknesses=weaknesses,
        competitor_scores=competitor_scores,
    )
|
||
|
||
|
||
@router.get("/action-suggestions/{brand_id}", response_model=ActionSuggestionsResponse)
async def get_onboarding_action_suggestions(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Turn the brand's health report into a list of action suggestions.

    Purely rule-based; no LLM is involved. The health report is obtained by
    calling the health-report endpoint function directly, so its 404 for an
    unknown or foreign brand propagates from here as well.

    Args:
        brand_id: Id of the brand to generate suggestions for.
        current_user: Authenticated user; must own the brand.
        db: Async database session.

    Returns:
        ActionSuggestionsResponse containing at least one suggestion.
    """
    report = await get_onboarding_health_report(brand_id, current_user, db)

    # A zero overall score gets a fixed starter set of suggestions and
    # nothing else.
    if report.overall_score == 0:
        return ActionSuggestionsResponse(suggestions=[
            ActionSuggestion(
                title="设置核心查询词",
                description="品牌尚无查询数据,建议首先设置与品牌最相关的核心查询词,让系统开始数据采集。",
                priority="high",
                action_type="keyword",
            ),
            ActionSuggestion(
                title="添加竞品对比",
                description="添加主要竞品以便进行对比分析,了解品牌在市场中的定位。",
                priority="medium",
                action_type="coverage",
            ),
            ActionSuggestion(
                title="完善品牌信息",
                description="补充品牌别名、网站、行业等详细信息,有助于提升AI平台识别率。",
                priority="medium",
                action_type="brand_info",
            ),
        ])

    suggestions = []

    # Overall-score rules.
    if report.overall_score < 20:
        suggestions.append(ActionSuggestion(
            title="提升 AI 平台覆盖率",
            description=f"当前综合评分仅 {report.overall_score},品牌在AI搜索中几乎未被提及。建议增加查询词覆盖面,让AI平台更频繁地引用品牌。",
            priority="high",
            action_type="coverage",
        ))
    if report.overall_score < 50:
        suggestions.append(ActionSuggestion(
            title="优化核心关键词",
            description="品牌在关键查询词下的提及率偏低,建议调整查询关键词策略,聚焦行业核心术语。",
            priority="high",
            action_type="keyword",
        ))

    # One suggestion per platform whose score is below 30.
    suggestions.extend(
        ActionSuggestion(
            title=f"提升 {platform} 平台覆盖率",
            description=f"品牌在 {platform} 平台的引用率仅为 {score}%,需要针对性优化该平台的内容策略。",
            priority="medium",
            action_type="platform",
        )
        for platform, score in report.platform_scores.items()
        if score < 30
    )

    # Sentiment rule: keyed off the exact weakness text in the health report.
    if "情感倾向偏负面" in report.weaknesses:
        suggestions.append(ActionSuggestion(
            title="改善品牌情感倾向",
            description="AI平台对品牌的情感评价偏负面,建议发布正面品牌内容、优化品牌描述以改善情感得分。",
            priority="medium",
            action_type="sentiment",
        ))

    # One suggestion per competitor whose score exceeds the brand's overall score.
    suggestions.extend(
        ActionSuggestion(
            title=f"应对竞品 {rival['name']} 威胁",
            description=f"竞品 {rival['name']} 评分 ({rival['score']}) 高于本品牌 ({report.overall_score}),建议分析竞品优势领域并制定差异化策略。",
            priority="high",
            action_type="competitive",
        )
        for rival in report.competitor_scores
        if rival["score"] > report.overall_score
    )

    # Always return at least one suggestion.
    if not suggestions:
        suggestions.append(ActionSuggestion(
            title="持续监测品牌表现",
            description="品牌表现良好,建议持续监测并保持当前策略。",
            priority="low",
            action_type="monitor",
        ))

    return ActionSuggestionsResponse(suggestions=suggestions)
|
||
|
||
|
||
@router.post("/complete/{brand_id}", response_model=OnboardingCompleteResponse)
async def complete_onboarding(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Mark onboarding as complete for the current user.

    Nothing is written to the database: the User model currently has no
    dedicated onboarding flag, so owning the brand is itself the completion
    marker. The endpoint only checks that the brand exists and belongs to
    the caller, then logs the event.

    Args:
        brand_id: Id of the brand created during onboarding.
        current_user: Authenticated user; must own the brand.
        db: Async database session.

    Returns:
        OnboardingCompleteResponse with ``success=True``.

    Raises:
        HTTPException: 404 when the brand does not exist or belongs to
            another user.
    """
    ownership_query = select(Brand).where(
        Brand.id == brand_id,
        Brand.user_id == current_user.id,
    )
    owned_brand = (await db.execute(ownership_query)).scalar_one_or_none()

    if not owned_brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )

    # Brand existence already means onboarding is done; no field to update.
    # If a dedicated flag is needed later, add user.onboarding_completed via
    # an alembic migration.
    logger.info(f"User {current_user.id} completed onboarding with brand {brand_id}")

    return OnboardingCompleteResponse(success=True)