585 lines
20 KiB
Python
585 lines
20 KiB
Python
"""Onboarding API endpoints - 新用户引导流程"""
|
||
import logging
|
||
import uuid
|
||
from typing import Optional
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||
from pydantic import BaseModel, Field
|
||
from sqlalchemy import select, func
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.api.deps import get_current_user
|
||
from app.api.competitors import (
|
||
INDUSTRY_COMPETITORS,
|
||
_get_rule_based_recommendations,
|
||
_get_llm_recommendations,
|
||
CompetitorRecommendationItem,
|
||
)
|
||
from app.config import settings
|
||
from app.database import get_db
|
||
from app.models.user import User
|
||
from app.models.brand import Brand
|
||
from app.models.competitor import Competitor
|
||
from app.models.query import Query as QueryModel
|
||
from app.schemas.brand import BrandCreate, BrandResponse
|
||
from app.services.diagnosis.data_collector import DataCollectorService
|
||
from app.services.diagnosis.geo_diagnosis import GEODiagnosisService
|
||
from app.utils.health import get_health_level, get_health_level_label
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Every route declared in this module is served under the /onboarding prefix.
router = APIRouter(prefix="/onboarding", tags=["onboarding"])

# Names of the diagnosis dimensions that stay visible to users on the "free"
# plan; the health report drops every other dimension for them.
_FREE_TIER_DIMENSIONS = {
    "内容可提取性",
    "E-E-A-T信号",
    "引用就绪度",
}
|
||
|
||
|
||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||
if isinstance(value, uuid.UUID):
|
||
return value
|
||
return uuid.UUID(str(value))
|
||
|
||
|
||
# ------------------------------------------------------------------
|
||
# Request / Response schemas
|
||
# ------------------------------------------------------------------
|
||
|
||
class OnboardingBrandCreate(BaseModel):
    """Request body for creating a brand during onboarding (simplified form)."""

    # Brand name; required, 2 to 50 characters.
    name: str = Field(..., min_length=2, max_length=50, description="品牌名称")

    # Optional free-text brand description, at most 500 characters.
    description: Optional[str] = Field(None, max_length=500, description="品牌描述")

    # Optional industry label, at most 50 characters.
    industry: Optional[str] = Field(None, max_length=50, description="行业")
|
||
|
||
|
||
class OnboardingStatusResponse(BaseModel):
    """Response describing how far the user is through onboarding."""

    # True once the user owns a brand.
    completed: bool

    # String form of the brand id when onboarding is completed, else None.
    brand_id: Optional[str] = None

    # Step the client should show; the status endpoint reports 1 when no
    # brand exists yet and 4 when one does.
    current_step: int
|
||
|
||
|
||
class CompetitorRecommendationSimple(BaseModel):
    """A single competitor suggestion in its simplified onboarding shape."""

    # Competitor name.
    name: str

    # Human-readable text; the recommendations endpoint fills this with the
    # recommendation's reason.
    description: str

    # Confidence value attached to the suggestion (the recommendations
    # endpoint uses 0.8 or 0.5).
    confidence: float
|
||
|
||
|
||
class CompetitorRecommendationSimpleResponse(BaseModel):
    """Response wrapper holding the simplified competitor suggestions."""

    # Suggested competitors, in the order the recommender produced them.
    recommendations: list[CompetitorRecommendationSimple]
|
||
|
||
|
||
class HealthDimensionItem(BaseModel):
    """One scored diagnosis dimension inside a health report."""

    # Dimension name.
    name: str

    # Score achieved on this dimension.
    score: float

    # Highest score attainable on this dimension.
    max_score: float

    # Score expressed as a percentage; the report endpoints compare it
    # against thresholds such as 40 and 60.
    percentage: float

    # Status label; the health-report endpoint falls back to "fail" when the
    # diagnosis output carries none.
    status: str
|
||
|
||
|
||
class HealthRecommendationItem(BaseModel):
    """One improvement recommendation inside a health report."""

    # Priority label such as "P0" (the only priority shown to free-plan users).
    priority: str

    # Name of the dimension the recommendation relates to.
    dimension: str

    # Short headline of the recommendation.
    title: str

    # Longer explanation of the recommendation.
    description: str
|
||
|
||
|
||
class HealthReportResponse(BaseModel):
    """Brand health report returned at the end of onboarding."""

    # String form of the brand id.
    brand_id: str

    # Brand name.
    brand_name: str

    # Overall diagnosis score, rounded to two decimals by the endpoint.
    overall_score: float

    # Machine-readable health level (for example "danger").
    health_level: str

    # Display label derived from health_level.
    health_level_label: str

    # Citation rate in percent, keyed by platform name.
    platform_scores: dict

    # Human-readable strength statements.
    strengths: list[str]

    # Human-readable weakness statements.
    weaknesses: list[str]

    # One dict per competitor with the keys "name", "score" and "is_leading".
    competitor_scores: list[dict]

    # Scored dimensions included in this report.
    dimensions: list[HealthDimensionItem] = []

    # Recommendations included in this report.
    recommendations: list[HealthRecommendationItem] = []

    # True when the report was built for a user whose plan is not "free".
    is_full_report: bool = False
|
||
|
||
|
||
class ActionSuggestion(BaseModel):
    """A single next-step suggestion shown to the user after onboarding."""

    # Short headline of the suggestion.
    title: str

    # Longer explanation of the suggestion.
    description: str

    # Priority label; this module emits "high", "medium" and "low".
    priority: str

    # Category tag; this module emits "coverage", "keyword", "platform",
    # "dimension", "competitive", "upgrade" and "monitor".
    action_type: str

    # True when acting on the suggestion requires a paid plan.
    is_paid_action: bool = False

    # Caption for the call-to-action button.
    action_button_text: str = ""
|
||
|
||
|
||
class ActionSuggestionsResponse(BaseModel):
    """Response wrapper holding the list of action suggestions."""

    # Suggestions in display order.
    suggestions: list[ActionSuggestion]
|
||
|
||
|
||
class OnboardingCompleteResponse(BaseModel):
    """Response returned when onboarding is marked as completed."""

    # True when the completion request was accepted.
    success: bool
|
||
|
||
|
||
# ------------------------------------------------------------------
|
||
# Endpoints
|
||
# ------------------------------------------------------------------
|
||
|
||
@router.get("/status", response_model=OnboardingStatusResponse)
async def get_onboarding_status(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Report the onboarding status of the current user.

    Onboarding counts as completed as soon as the user owns at least one
    brand:

    - a brand exists  -> ``completed=True``, ``brand_id`` set, ``current_step=4``
    - no brand exists -> ``completed=False``, ``brand_id=None``, ``current_step=1``

    Args:
        current_user: The authenticated user, injected by ``get_current_user``.
        db: The async database session, injected by ``get_db``.

    Returns:
        An ``OnboardingStatusResponse`` describing the state above.
    """
    # A user can own more than one brand (nothing stops a second
    # POST /onboarding/brand), and scalar_one_or_none() raises
    # MultipleResultsFound in that case. Fetch at most one row instead, and
    # order by id so the same brand is reported on every call.
    stmt = (
        select(Brand)
        .where(Brand.user_id == _to_uuid(current_user.id))
        .order_by(Brand.id)
        .limit(1)
    )
    result = await db.execute(stmt)
    brand = result.scalars().first()

    if brand:
        return OnboardingStatusResponse(
            completed=True,
            brand_id=str(brand.id),
            current_step=4,
        )

    return OnboardingStatusResponse(
        completed=False,
        brand_id=None,
        current_step=1,
    )
|
||
|
||
|
||
@router.post("/brand", response_model=BrandResponse, status_code=status.HTTP_201_CREATED)
async def create_onboarding_brand(
    brand_data: OnboardingBrandCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Create a brand as part of the onboarding flow.

    The simplified request is expanded into a full ``BrandCreate`` (no
    aliases, no website, platforms ``["wenxin", "kimi"]``, weekly frequency)
    and stored as a ``Brand`` owned by the current user. Afterwards a default
    query ``"<brand name> 推荐"`` is created on a best-effort basis, unless the
    user already has ``max_queries`` queries (3 when the user object carries
    no such attribute).

    Args:
        brand_data: Simplified brand payload from the onboarding form.
        current_user: The authenticated user, injected by ``get_current_user``.
        db: The async database session, injected by ``get_db``.

    Returns:
        The persisted ``Brand``, serialised through ``BrandResponse``.

    Note:
        ``brand_data.description`` is accepted by the request schema but is
        not forwarded to ``BrandCreate``/``Brand`` here.
        NOTE(review): confirm whether ``Brand`` has a description column that
        should receive it.
    """
    full_brand_data = BrandCreate(
        name=brand_data.name,
        aliases=[],
        website=None,
        industry=brand_data.industry,
        platforms=["wenxin", "kimi"],
        frequency="weekly",
    )

    brand = Brand(
        user_id=_to_uuid(current_user.id),
        name=full_brand_data.name,
        aliases=full_brand_data.aliases,
        website=full_brand_data.website,
        industry=full_brand_data.industry,
        platforms=full_brand_data.platforms,
        frequency=full_brand_data.frequency,
    )
    db.add(brand)
    await db.commit()
    await db.refresh(brand)

    # Copy everything the best-effort step needs into plain locals while the
    # ORM objects are known to be loaded. A failed flush/commit expires them,
    # and reading an expired attribute from an AsyncSession (for example
    # inside the except handler's log message) would raise and turn the
    # best-effort step into a 500.
    owner_id = current_user.id
    brand_name = brand.name
    brand_aliases = brand.aliases or []
    brand_platforms = brand.platforms or ["wenxin", "kimi"]
    brand_frequency = brand.frequency or "weekly"
    max_queries = getattr(current_user, "max_queries", 3)  # free tier default: 3

    # Auto-create a default query, respecting the user's max_queries limit.
    try:
        current_query_count_stmt = select(func.count()).select_from(QueryModel).where(
            QueryModel.user_id == owner_id
        )
        current_query_count_result = await db.execute(current_query_count_stmt)
        current_query_count = current_query_count_result.scalar_one()

        if current_query_count < max_queries:
            default_query = QueryModel(
                user_id=owner_id,
                keyword=f"{brand_name} 推荐",
                target_brand=brand_name,
                brand_aliases=brand_aliases,
                platforms=brand_platforms,
                frequency=brand_frequency,
                status="active",
            )
            db.add(default_query)
            await db.commit()
            await db.refresh(default_query)
            logger.info(f"Auto-created default query for brand '{brand_name}'")
        else:
            logger.info(
                f"Skipped auto-creating default query for brand '{brand_name}': "
                f"query limit reached ({current_query_count}/{max_queries})"
            )
    except Exception as e:
        # Best effort only: the brand itself is already committed, so log the
        # failure and carry on instead of failing the request.
        logger.warning(f"Failed to auto-create default query for brand '{brand_name}': {e}")
        # Leave the session usable again, then reload the brand because the
        # rollback expires its attributes and the response is built from it.
        await db.rollback()
        await db.refresh(brand)

    return brand
|
||
|
||
|
||
@router.get("/competitor-recommendations", response_model=CompetitorRecommendationSimpleResponse)
async def get_onboarding_competitor_recommendations(
    brand_id: uuid.UUID = Query(..., description="品牌ID"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Suggest competitors for one of the current user's brands.

    Reuses the recommendation helpers from ``app.api.competitors``. The LLM
    recommender is used when ``settings.ENABLE_LLM`` and
    ``settings.DEEPSEEK_API_KEY`` are both set; if it is disabled or raises,
    the rule-based recommender is used instead.

    Args:
        brand_id: Id of the brand to recommend competitors for.
        current_user: The authenticated user, injected by ``get_current_user``.
        db: The async database session, injected by ``get_db``.

    Returns:
        A ``CompetitorRecommendationSimpleResponse`` with one entry per
        recommended competitor.

    Raises:
        HTTPException: 404 when the brand does not exist or belongs to
            another user.
    """
    owner_id = _to_uuid(current_user.id)
    brand_lookup = await db.execute(
        select(Brand).where(Brand.id == brand_id, Brand.user_id == owner_id)
    )
    brand = brand_lookup.scalar_one_or_none()
    if not brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )

    # Names of competitors already attached to the brand; handed to the
    # recommenders together with the brand name and industry.
    name_rows = await db.execute(
        select(Competitor.name).where(Competitor.brand_id == brand_id)
    )
    existing_names = [name for (name,) in name_rows.all()]

    recommender_args = {
        "brand_name": brand.name,
        "industry": brand.industry,
        "existing_names": existing_names,
    }

    # Pick the strategy: try the LLM first when enabled, otherwise (or when
    # the LLM call fails) use the rule-based recommender.
    use_rules = True
    rec_items = None
    if settings.ENABLE_LLM and settings.DEEPSEEK_API_KEY:
        try:
            rec_items = await _get_llm_recommendations(**recommender_args)
            use_rules = False
        except Exception as e:
            logger.warning(f"Onboarding LLM竞品推荐失败,回退规则推荐: {e}")
    if use_rules:
        rec_items = _get_rule_based_recommendations(**recommender_args)

    # The confidence depends only on the brand's industry, so it is the same
    # for every item: 0.8 for an industry listed in INDUSTRY_COMPETITORS,
    # 0.5 otherwise.
    known_industry = brand.industry and brand.industry in INDUSTRY_COMPETITORS
    confidence = 0.8 if known_industry else 0.5

    # Convert to the simplified response shape.
    recommendations = [
        CompetitorRecommendationSimple(
            name=item.name,
            description=item.reason,
            confidence=confidence,
        )
        for item in rec_items
    ]
    return CompetitorRecommendationSimpleResponse(recommendations=recommendations)
|
||
|
||
|
||
@router.get("/health-report/{brand_id}", response_model=HealthReportResponse)
async def get_onboarding_health_report(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Build the onboarding health report for one of the current user's brands.

    Steps:

    1. Load the brand and check that it belongs to the current user.
    2. Collect data and run the GEO diagnosis; if either raises, log the
       error and return a zero-score placeholder report.
    3. For users on the "free" plan keep only the free-tier dimensions and
       the "P0" recommendations; other plans see everything.
    4. Compute the citation rate per platform from the citation records of
       the user's queries that target this brand.
    5. Derive strength / weakness statements from the visible dimensions and
       list the brand's competitors (their score is always reported as 0.0).

    Args:
        brand_id: Id of the brand to report on.
        current_user: The authenticated user, injected by ``get_current_user``.
        db: The async database session, injected by ``get_db``.

    Returns:
        The assembled ``HealthReportResponse``.

    Raises:
        HTTPException: 404 when the brand does not exist or belongs to
            another user.
    """
    owner_id = _to_uuid(current_user.id)
    brand_lookup = await db.execute(
        select(Brand).where(Brand.id == brand_id, Brand.user_id == owner_id)
    )
    brand = brand_lookup.scalar_one_or_none()
    if not brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )

    # A missing or empty plan is treated as "free".
    user_plan = getattr(current_user, "plan", None) or "free"
    is_paid = user_plan not in ("free", None)

    try:
        collection = await DataCollectorService(db).collect(
            brand_name=brand.name,
            brand_aliases=brand.aliases or [],
            website=brand.website,
            industry=brand.industry,
        )
        diagnosis_result = GEODiagnosisService().diagnose(collection.diagnosis_input)
    except Exception as e:
        logger.error(f"Onboarding健康报告采集失败: brand_id={brand_id}, error={e}", exc_info=True)
        # Degrade to a placeholder report rather than failing the request.
        placeholder_dimensions = [
            HealthDimensionItem(name=d, score=0.0, max_score=0.0, percentage=0.0, status="fail")
            for d in sorted(_FREE_TIER_DIMENSIONS)
        ]
        return HealthReportResponse(
            brand_id=str(brand.id),
            brand_name=brand.name,
            overall_score=0.0,
            health_level="danger",
            health_level_label=get_health_level_label("danger"),
            platform_scores={},
            strengths=["品牌已创建,等待数据采集"],
            weaknesses=["数据采集失败,请稍后重试"],
            competitor_scores=[],
            dimensions=placeholder_dimensions,
            recommendations=[],
            is_full_report=is_paid,
        )

    diagnosis = diagnosis_result.to_dict()
    filtered_dimensions = diagnosis.get("dimensions", [])
    filtered_recommendations = diagnosis.get("recommendations", [])
    if not is_paid:
        # Free plan: only the free-tier dimensions and the P0 recommendations.
        filtered_dimensions = [
            d for d in filtered_dimensions if d.get("name") in _FREE_TIER_DIMENSIONS
        ]
        filtered_recommendations = [
            r for r in filtered_recommendations if r.get("priority") == "P0"
        ]

    overall_score = round(diagnosis_result.overall_score, 2)
    health_level = get_health_level(overall_score)

    # Citation rate per platform, from the user's queries targeting this brand.
    platform_scores: dict[str, float] = {}
    query_rows = await db.execute(
        select(QueryModel).where(
            QueryModel.user_id == current_user.id,
            QueryModel.target_brand == brand.name,
        )
    )
    queries = list(query_rows.scalars().all())

    if queries:
        from app.models.citation_record import CitationRecord

        citation_rows = await db.execute(
            select(CitationRecord).where(
                CitationRecord.query_id.in_([q.id for q in queries]),
            )
        )

        # Two counters per platform: records seen, and records marked cited.
        seen_per_platform: dict[str, int] = {}
        cited_per_platform: dict[str, int] = {}
        for record in list(citation_rows.scalars().all()):
            platform_key = record.platform or "unknown"
            seen_per_platform[platform_key] = seen_per_platform.get(platform_key, 0) + 1
            if record.cited:
                cited_per_platform[platform_key] = cited_per_platform.get(platform_key, 0) + 1

        for platform_key, seen in seen_per_platform.items():
            cited = cited_per_platform.get(platform_key, 0)
            rate = (cited / seen * 100) if seen > 0 else 0.0
            platform_scores[platform_key] = round(rate, 2)

    # >= 60 % counts as a strength, anything between 0 and 60 (exclusive) as
    # a weakness; a dimension at exactly 0 % is listed as neither.
    strengths = []
    weaknesses = []
    for d in filtered_dimensions:
        pct = d.get("percentage", 0)
        if pct >= 60:
            strengths.append(f"{d.get('name', '')}表现良好 ({round(pct, 1)}%)")
            continue
        if pct > 0:
            weaknesses.append(f"{d.get('name', '')}有待提升 ({round(pct, 1)}%)")

    # Never return an empty list for either side.
    if not strengths:
        strengths.append("已有初步诊断数据")
    if not weaknesses:
        weaknesses.append("暂无明显短板")

    competitor_rows = await db.execute(
        select(Competitor).where(Competitor.brand_id == brand_id)
    )
    # Competitors are listed by name only; score and is_leading are fixed values.
    competitor_scores = [
        {"name": comp.name, "score": 0.0, "is_leading": False}
        for comp in list(competitor_rows.scalars().all())
    ]

    return HealthReportResponse(
        brand_id=str(brand.id),
        brand_name=brand.name,
        overall_score=overall_score,
        health_level=health_level,
        health_level_label=get_health_level_label(health_level),
        platform_scores=platform_scores,
        strengths=strengths,
        weaknesses=weaknesses,
        competitor_scores=competitor_scores,
        dimensions=[
            HealthDimensionItem(
                name=d.get("name", ""),
                score=round(d.get("score", 0.0), 2),
                max_score=d.get("max_score", 0.0),
                percentage=round(d.get("percentage", 0.0), 2),
                status=d.get("status", "fail"),
            )
            for d in filtered_dimensions
        ],
        recommendations=[
            HealthRecommendationItem(
                priority=r.get("priority", "P0"),
                dimension=r.get("dimension", ""),
                title=r.get("title", ""),
                description=r.get("description", ""),
            )
            for r in filtered_recommendations
        ],
        is_full_report=is_paid,
    )
|
||
|
||
|
||
@router.get("/action-suggestions/{brand_id}", response_model=ActionSuggestionsResponse)
async def get_onboarding_action_suggestions(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Turn a brand's health report into a list of next-step suggestions.

    The health report is obtained by calling ``get_onboarding_health_report``
    directly. An overall score of exactly 0 yields a fixed three-item starter
    list. Otherwise suggestions are derived from the overall score, the
    per-platform citation rates, the visible dimensions and the competitor
    scores, with a "keep monitoring" entry when nothing else applies. Users
    on the "free" plan always get an extra upgrade suggestion at the end.

    Args:
        brand_id: Id of the brand to produce suggestions for.
        current_user: The authenticated user, injected by ``get_current_user``.
        db: The async database session, injected by ``get_db``.

    Returns:
        An ``ActionSuggestionsResponse`` with the suggestions in display order.

    Raises:
        HTTPException: 404 propagated from the health-report lookup when the
            brand does not exist or belongs to another user.
    """
    report = await get_onboarding_health_report(brand_id, current_user, db)

    # A missing or empty plan is treated as "free".
    user_plan = getattr(current_user, "plan", None) or "free"
    is_paid = user_plan not in ("free", None)

    if report.overall_score == 0:
        # No score at all: replace every derived suggestion with a fixed
        # getting-started list.
        suggestions = [
            ActionSuggestion(
                title="设置核心查询词",
                description="品牌尚无查询数据,建议首先设置与品牌最相关的核心查询词,让系统开始数据采集。",
                priority="high",
                action_type="keyword",
                is_paid_action=False,
                action_button_text="设置查询词",
            ),
            ActionSuggestion(
                title="添加竞品对比",
                description="添加主要竞品以便进行对比分析,了解品牌在市场中的定位。",
                priority="medium",
                action_type="coverage",
                is_paid_action=False,
                action_button_text="添加竞品",
            ),
            ActionSuggestion(
                title="解锁完整6维度诊断",
                description="免费版仅展示3个核心维度,升级后可查看完整6维度诊断报告和深度优化方案。",
                priority="medium",
                action_type="upgrade",
                is_paid_action=True,
                action_button_text="升级Pro",
            ),
        ]
    else:
        suggestions = []

        # Overall-score thresholds; a score below 20 triggers both entries.
        if report.overall_score < 20:
            suggestions.append(ActionSuggestion(
                title="提升 AI 平台覆盖率",
                description=f"当前综合评分仅 {report.overall_score},品牌在AI搜索中几乎未被提及。建议增加查询词覆盖面,让AI平台更频繁地引用品牌。",
                priority="high",
                action_type="coverage",
                is_paid_action=False,
                action_button_text="设置查询词",
            ))
        if report.overall_score < 50:
            suggestions.append(ActionSuggestion(
                title="优化核心关键词",
                description="品牌在关键查询词下的提及率偏低,建议调整查询关键词策略,聚焦行业核心术语。",
                priority="high",
                action_type="keyword",
                is_paid_action=False,
                action_button_text="优化关键词",
            ))

        # One entry per platform whose citation rate is below 30 %.
        suggestions.extend(
            ActionSuggestion(
                title=f"提升 {platform} 平台覆盖率",
                description=f"品牌在 {platform} 平台的引用率仅为 {score}%,需要针对性优化该平台的内容策略。",
                priority="medium",
                action_type="platform",
                is_paid_action=False,
                action_button_text=f"优化{platform}",
            )
            for platform, score in report.platform_scores.items()
            if score < 30
        )

        # One paid entry per weak dimension outside the free-tier set.
        suggestions.extend(
            ActionSuggestion(
                title=f"提升{dim.name}得分",
                description=f"{dim.name}当前得分仅 {round(dim.percentage, 1)}%,解锁详细诊断和优化方案。",
                priority="medium",
                action_type="dimension",
                is_paid_action=True,
                action_button_text="升级解锁",
            )
            for dim in report.dimensions
            if dim.percentage < 40 and dim.name not in _FREE_TIER_DIMENSIONS
        )

        # One paid entry per competitor scoring above the brand.
        suggestions.extend(
            ActionSuggestion(
                title=f"应对竞品 {comp['name']} 威胁",
                description=f"竞品 {comp['name']} 评分 ({comp['score']}) 高于本品牌 ({report.overall_score}),建议分析竞品优势领域并制定差异化策略。",
                priority="high",
                action_type="competitive",
                is_paid_action=True,
                action_button_text="查看竞品分析",
            )
            for comp in report.competitor_scores
            if comp.get("score", 0) > report.overall_score
        )

        # Nothing to improve: tell the user to keep monitoring.
        if not suggestions:
            suggestions.append(ActionSuggestion(
                title="持续监测品牌表现",
                description="品牌表现良好,建议持续监测并保持当前策略。",
                priority="low",
                action_type="monitor",
                is_paid_action=False,
                action_button_text="查看Dashboard",
            ))

    # Free-plan users always see an upgrade entry at the end of the list.
    if not is_paid:
        suggestions.append(ActionSuggestion(
            title="升级Pro解锁完整诊断",
            description="免费版仅展示3个核心维度和P0建议。升级Pro可获取完整6维度诊断、深度竞品分析和AI优化方案。",
            priority="low",
            action_type="upgrade",
            is_paid_action=True,
            action_button_text="升级Pro",
        ))

    return ActionSuggestionsResponse(suggestions=suggestions)
|
||
|
||
|
||
@router.post("/complete/{brand_id}", response_model=OnboardingCompleteResponse)
async def complete_onboarding(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Mark onboarding as completed for the current user.

    Nothing is written: the User model currently has no dedicated
    ``onboarding_completed`` field, so owning the brand is what represents a
    finished onboarding. The handler only verifies that the brand exists and
    belongs to the current user, then logs the event.

    Args:
        brand_id: Id of the brand created during onboarding.
        current_user: The authenticated user, injected by ``get_current_user``.
        db: The async database session, injected by ``get_db``.

    Returns:
        ``OnboardingCompleteResponse(success=True)``.

    Raises:
        HTTPException: 404 when the brand does not exist or belongs to
            another user.
    """
    owner_id = _to_uuid(current_user.id)
    lookup = await db.execute(
        select(Brand).where(Brand.id == brand_id, Brand.user_id == owner_id)
    )
    brand = lookup.scalar_one_or_none()
    if not brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )

    logger.info(f"User {current_user.id} completed onboarding with brand {brand_id}")
    return OnboardingCompleteResponse(success=True)