"""Onboarding API endpoints - guided setup flow for new users."""

import logging
import uuid
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, Field
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.deps import get_current_user
from app.api.competitors import (
    INDUSTRY_COMPETITORS,
    _get_rule_based_recommendations,
    _get_llm_recommendations,
    CompetitorRecommendationItem,
)
from app.config import settings
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.competitor import Competitor
from app.models.query import Query as QueryModel
from app.schemas.brand import BrandCreate, BrandResponse
from app.services.diagnosis.data_collector import DataCollectorService
from app.services.diagnosis.geo_diagnosis import GEODiagnosisService
from app.utils.health import get_health_level, get_health_level_label

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/onboarding", tags=["onboarding"])

# Dimension names that free-plan users are allowed to see in the health report.
_FREE_TIER_DIMENSIONS = {"内容可提取性", "E-E-A-T信号", "引用就绪度"}


def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
    """Return *value* as a ``uuid.UUID``, parsing it from its string form if needed.

    Raises:
        ValueError: if ``str(value)`` is not a valid UUID string.
    """
    if isinstance(value, uuid.UUID):
        return value
    return uuid.UUID(str(value))


def _is_paid_user(user: User) -> bool:
    """Return True when the user's ``plan`` attribute is set to anything but "free".

    A missing, ``None`` or empty ``plan`` is treated as the free plan.
    """
    return (getattr(user, "plan", None) or "free") != "free"


def _platform_citation_rates(citations) -> dict[str, float]:
    """Compute the percentage of cited records per platform.

    Each record is read through its ``platform`` and ``cited`` attributes;
    records with a falsy platform are grouped under ``"unknown"``. Platforms
    appear in the result in first-seen order, rates are rounded to 2 decimals.
    """
    totals: dict[str, int] = {}
    cited: dict[str, int] = {}
    for record in citations:
        platform = record.platform or "unknown"
        totals[platform] = totals.get(platform, 0) + 1
        if record.cited:
            cited[platform] = cited.get(platform, 0) + 1
    return {
        platform: round(cited.get(platform, 0) / total * 100, 2)
        for platform, total in totals.items()
    }


def _summarize_dimensions(dimensions: list[dict]) -> tuple[list[str], list[str]]:
    """Split dimension dicts into (strengths, weaknesses) display strings.

    A dimension with ``percentage`` >= 60 is a strength, one with
    0 < ``percentage`` < 60 is a weakness, and one at exactly 0 (or below) is
    listed in neither. Each list gets a placeholder entry when it would
    otherwise be empty.
    """
    strengths: list[str] = []
    weaknesses: list[str] = []
    for dim in dimensions:
        pct = dim.get("percentage", 0)
        name = dim.get("name", "")
        if pct >= 60:
            strengths.append(f"{name}表现良好 ({round(pct, 1)}%)")
        elif pct > 0:
            weaknesses.append(f"{name}有待提升 ({round(pct, 1)}%)")
    if not strengths:
        strengths.append("已有初步诊断数据")
    if not weaknesses:
        weaknesses.append("暂无明显短板")
    return strengths, weaknesses


# ------------------------------------------------------------------
# Request / Response schemas
# ------------------------------------------------------------------

class OnboardingBrandCreate(BaseModel):
    """Simplified brand-creation request used during onboarding."""

    name: str = Field(..., min_length=2, max_length=50, description="品牌名称")
    description: Optional[str] = Field(None, max_length=500, description="品牌描述")
    industry: Optional[str] = Field(None, max_length=50, description="行业")


class OnboardingStatusResponse(BaseModel):
    """Onboarding status of the current user."""

    completed: bool
    brand_id: Optional[str] = None
    current_step: int


class CompetitorRecommendationSimple(BaseModel):
    """A single simplified competitor recommendation."""

    name: str
    description: str
    confidence: float


class CompetitorRecommendationSimpleResponse(BaseModel):
    """Response wrapper for simplified competitor recommendations."""

    recommendations: list[CompetitorRecommendationSimple]


class HealthDimensionItem(BaseModel):
    """One diagnosis dimension as shown in the health report."""

    name: str
    score: float
    max_score: float
    percentage: float
    status: str


class HealthRecommendationItem(BaseModel):
    """One diagnosis recommendation as shown in the health report."""

    priority: str
    dimension: str
    title: str
    description: str


class HealthReportResponse(BaseModel):
    """Brand health report returned by the onboarding flow."""

    brand_id: str
    brand_name: str
    overall_score: float
    health_level: str
    health_level_label: str
    platform_scores: dict
    strengths: list[str]
    weaknesses: list[str]
    competitor_scores: list[dict]
    dimensions: list[HealthDimensionItem] = Field(default_factory=list)
    recommendations: list[HealthRecommendationItem] = Field(default_factory=list)
    is_full_report: bool = False


class ActionSuggestion(BaseModel):
    """A suggested next action derived from the health report."""

    title: str
    description: str
    priority: str
    action_type: str
    is_paid_action: bool = False
    action_button_text: str = ""


class ActionSuggestionsResponse(BaseModel):
    """Response wrapper for action suggestions."""

    suggestions: list[ActionSuggestion]


class OnboardingCompleteResponse(BaseModel):
    """Response for the onboarding-complete call."""

    success: bool


# ------------------------------------------------------------------
# Endpoints
# ------------------------------------------------------------------

@router.get("/status", response_model=OnboardingStatusResponse)
async def get_onboarding_status(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Check the onboarding status of the current user.

    Onboarding counts as completed once the user owns at least one brand:
    - completed=True with a brand_id -> already completed (current_step=4)
    - completed=False, current_step=1 -> the user still has to create a brand
    """
    # A user may own several brands (POST /onboarding/brand does not enforce
    # uniqueness), so fetch at most one row instead of scalar_one_or_none(),
    # which raises MultipleResultsFound in that case. Ordering by id keeps
    # the returned brand deterministic.
    stmt = (
        select(Brand)
        .where(Brand.user_id == _to_uuid(current_user.id))
        .order_by(Brand.id)
        .limit(1)
    )
    result = await db.execute(stmt)
    brand = result.scalars().first()

    if brand is None:
        return OnboardingStatusResponse(
            completed=False,
            brand_id=None,
            current_step=1,
        )

    return OnboardingStatusResponse(
        completed=True,
        brand_id=str(brand.id),
        current_step=4,
    )


@router.post("/brand", response_model=BrandResponse, status_code=status.HTTP_201_CREATED)
async def create_onboarding_brand(
    brand_data: OnboardingBrandCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Create a brand as part of the onboarding flow.

    The simplified request is mapped onto a full BrandCreate (so its
    validation applies) with default platforms and frequency, then persisted.
    Afterwards a default query "<brand name> 推荐" is created on a
    best-effort basis if the user is below their query limit; any failure in
    that step is logged and does not fail the request.
    """
    # NOTE(review): brand_data.description is accepted but not stored here —
    # confirm whether the Brand model is supposed to carry it.
    full_brand_data = BrandCreate(
        name=brand_data.name,
        aliases=[],
        website=None,
        industry=brand_data.industry,
        platforms=["wenxin", "kimi"],
        frequency="weekly",
    )

    brand = Brand(
        user_id=_to_uuid(current_user.id),
        name=full_brand_data.name,
        aliases=full_brand_data.aliases,
        website=full_brand_data.website,
        industry=full_brand_data.industry,
        platforms=full_brand_data.platforms,
        frequency=full_brand_data.frequency,
    )
    db.add(brand)
    await db.commit()
    await db.refresh(brand)

    # Read the name once while the instance is freshly loaded, so the
    # best-effort block below (including its error logging) never has to
    # touch ORM attributes after another commit or a failed flush.
    brand_name = brand.name

    # Auto-create a default query, respecting the user's max_queries limit.
    try:
        count_stmt = select(func.count()).select_from(QueryModel).where(
            QueryModel.user_id == current_user.id
        )
        count_result = await db.execute(count_stmt)
        current_query_count = count_result.scalar_one()
        # Users without a max_queries attribute get the free-plan limit of 3.
        max_queries = getattr(current_user, "max_queries", 3)

        if current_query_count < max_queries:
            default_query = QueryModel(
                user_id=current_user.id,
                keyword=f"{brand_name} 推荐",
                target_brand=brand_name,
                brand_aliases=brand.aliases or [],
                platforms=brand.platforms or ["wenxin", "kimi"],
                frequency=brand.frequency or "weekly",
                status="active",
            )
            db.add(default_query)
            await db.commit()
            await db.refresh(default_query)
            logger.info("Auto-created default query for brand '%s'", brand_name)
        else:
            logger.info(
                "Skipped auto-creating default query for brand '%s': "
                "query limit reached (%s/%s)",
                brand_name,
                current_query_count,
                max_queries,
            )
    except Exception as e:
        # Deliberately best-effort: the brand itself is already committed.
        logger.warning(
            "Failed to auto-create default query for brand '%s': %s", brand_name, e
        )

    return brand


@router.get("/competitor-recommendations", response_model=CompetitorRecommendationSimpleResponse)
async def get_onboarding_competitor_recommendations(
    brand_id: uuid.UUID = Query(..., description="品牌ID"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Recommend competitors for a brand owned by the current user.

    Reuses the recommendation helpers from the competitors API. The LLM
    strategy is used when it is enabled and an API key is configured; on any
    LLM failure, or when it is disabled, the rule-based strategy is used.

    Raises:
        HTTPException: 404 if the brand does not exist or is not owned by the
            current user.
    """
    stmt = select(Brand).where(
        Brand.id == brand_id,
        Brand.user_id == _to_uuid(current_user.id),
    )
    result = await db.execute(stmt)
    brand = result.scalar_one_or_none()

    if not brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )

    # Names of competitors already attached to the brand, passed on to the
    # recommendation helpers.
    existing_stmt = select(Competitor.name).where(Competitor.brand_id == brand_id)
    existing_result = await db.execute(existing_stmt)
    existing_names = [row[0] for row in existing_result.all()]

    # Pick the recommendation strategy.
    rec_items = None
    if settings.ENABLE_LLM and settings.DEEPSEEK_API_KEY:
        try:
            rec_items = await _get_llm_recommendations(
                brand_name=brand.name,
                industry=brand.industry,
                existing_names=existing_names,
            )
        except Exception as e:
            logger.warning("Onboarding LLM竞品推荐失败,回退规则推荐: %s", e)

    if rec_items is None:
        rec_items = _get_rule_based_recommendations(
            brand_name=brand.name,
            industry=brand.industry,
            existing_names=existing_names,
        )

    # Confidence depends only on the brand's industry, so compute it once.
    confidence = 0.8 if brand.industry and brand.industry in INDUSTRY_COMPETITORS else 0.5

    # Convert to the simplified response format.
    recommendations = [
        CompetitorRecommendationSimple(
            name=item.name,
            description=item.reason,
            confidence=confidence,
        )
        for item in rec_items
    ]

    return CompetitorRecommendationSimpleResponse(recommendations=recommendations)


@router.get("/health-report/{brand_id}", response_model=HealthReportResponse)
async def get_onboarding_health_report(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Build the health report for a brand owned by the current user.

    Runs data collection and diagnosis, then assembles scores, strengths and
    weaknesses, per-platform citation rates and the competitor list. Free
    users only see the dimensions in _FREE_TIER_DIMENSIONS and recommendations
    with priority "P0"; paid users see everything. If collection or diagnosis
    raises, a zero-score placeholder report is returned instead of an error.

    Raises:
        HTTPException: 404 if the brand does not exist or is not owned by the
            current user.
    """
    user_uuid = _to_uuid(current_user.id)
    stmt = select(Brand).where(Brand.id == brand_id, Brand.user_id == user_uuid)
    result = await db.execute(stmt)
    brand = result.scalar_one_or_none()

    if not brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )

    is_paid = _is_paid_user(current_user)

    try:
        collector = DataCollectorService(db)
        collection = await collector.collect(
            brand_name=brand.name,
            brand_aliases=brand.aliases or [],
            website=brand.website,
            industry=brand.industry,
        )
        geo_service = GEODiagnosisService()
        diagnosis_result = geo_service.diagnose(collection.diagnosis_input)
    except Exception as e:
        logger.error(
            "Onboarding健康报告采集失败: brand_id=%s, error=%s",
            brand_id,
            e,
            exc_info=True,
        )
        # Degrade to a placeholder report so the onboarding UI can still render.
        return HealthReportResponse(
            brand_id=str(brand.id),
            brand_name=brand.name,
            overall_score=0.0,
            health_level="danger",
            health_level_label=get_health_level_label("danger"),
            platform_scores={},
            strengths=["品牌已创建,等待数据采集"],
            weaknesses=["数据采集失败,请稍后重试"],
            competitor_scores=[],
            dimensions=[
                HealthDimensionItem(
                    name=d, score=0.0, max_score=0.0, percentage=0.0, status="fail"
                )
                for d in sorted(_FREE_TIER_DIMENSIONS)
            ],
            recommendations=[],
            is_full_report=is_paid,
        )

    result_dict = diagnosis_result.to_dict()
    all_dimensions = result_dict.get("dimensions", [])
    all_recommendations = result_dict.get("recommendations", [])

    if is_paid:
        filtered_dimensions = all_dimensions
        filtered_recommendations = all_recommendations
    else:
        filtered_dimensions = [
            d for d in all_dimensions if d.get("name") in _FREE_TIER_DIMENSIONS
        ]
        filtered_recommendations = [
            r for r in all_recommendations if r.get("priority") == "P0"
        ]

    overall_score = round(diagnosis_result.overall_score, 2)
    health_level = get_health_level(overall_score)

    # Per-platform citation rate across all of the user's queries for this brand.
    platform_scores: dict[str, float] = {}
    queries_stmt = select(QueryModel).where(
        QueryModel.user_id == current_user.id,
        QueryModel.target_brand == brand.name,
    )
    queries_result = await db.execute(queries_stmt)
    queries = list(queries_result.scalars().all())

    if queries:
        # Kept as a function-local import, as in the original module layout.
        from app.models.citation_record import CitationRecord

        query_ids = [q.id for q in queries]
        citations_stmt = select(CitationRecord).where(
            CitationRecord.query_id.in_(query_ids),
        )
        citations_result = await db.execute(citations_stmt)
        platform_scores = _platform_citation_rates(citations_result.scalars().all())

    strengths, weaknesses = _summarize_dimensions(filtered_dimensions)

    competitor_stmt = select(Competitor).where(Competitor.brand_id == brand_id)
    competitor_result = await db.execute(competitor_stmt)
    # NOTE(review): competitor scores are emitted as fixed placeholders
    # (0.0 / not leading); no competitor scoring happens in this endpoint.
    competitor_scores = [
        {"name": comp.name, "score": 0.0, "is_leading": False}
        for comp in competitor_result.scalars().all()
    ]

    dimension_items = [
        HealthDimensionItem(
            name=d.get("name", ""),
            score=round(d.get("score", 0.0), 2),
            max_score=d.get("max_score", 0.0),
            percentage=round(d.get("percentage", 0.0), 2),
            status=d.get("status", "fail"),
        )
        for d in filtered_dimensions
    ]

    recommendation_items = [
        HealthRecommendationItem(
            priority=r.get("priority", "P0"),
            dimension=r.get("dimension", ""),
            title=r.get("title", ""),
            description=r.get("description", ""),
        )
        for r in filtered_recommendations
    ]

    return HealthReportResponse(
        brand_id=str(brand.id),
        brand_name=brand.name,
        overall_score=overall_score,
        health_level=health_level,
        health_level_label=get_health_level_label(health_level),
        platform_scores=platform_scores,
        strengths=strengths,
        weaknesses=weaknesses,
        competitor_scores=competitor_scores,
        dimensions=dimension_items,
        recommendations=recommendation_items,
        is_full_report=is_paid,
    )


@router.get("/action-suggestions/{brand_id}", response_model=ActionSuggestionsResponse)
async def get_onboarding_action_suggestions(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Derive action suggestions from the brand's health report.

    The health report is built by calling get_onboarding_health_report
    directly, so its 404 for unknown or foreign brands propagates. A report
    with an overall score of exactly 0 gets a fixed "getting started" list
    instead of the score-based suggestions. Free users always receive exactly
    one upgrade suggestion; paid users receive none.
    """
    report = await get_onboarding_health_report(brand_id, current_user, db)
    is_paid = _is_paid_user(current_user)

    suggestions: list[ActionSuggestion] = []

    if report.overall_score < 20:
        suggestions.append(ActionSuggestion(
            title="提升 AI 平台覆盖率",
            description=f"当前综合评分仅 {report.overall_score},品牌在AI搜索中几乎未被提及。建议增加查询词覆盖面,让AI平台更频繁地引用品牌。",
            priority="high",
            action_type="coverage",
            is_paid_action=False,
            action_button_text="设置查询词",
        ))

    if report.overall_score < 50:
        suggestions.append(ActionSuggestion(
            title="优化核心关键词",
            description="品牌在关键查询词下的提及率偏低,建议调整查询关键词策略,聚焦行业核心术语。",
            priority="high",
            action_type="keyword",
            is_paid_action=False,
            action_button_text="优化关键词",
        ))

    for platform, score in report.platform_scores.items():
        if score < 30:
            suggestions.append(ActionSuggestion(
                title=f"提升 {platform} 平台覆盖率",
                description=f"品牌在 {platform} 平台的引用率仅为 {score}%,需要针对性优化该平台的内容策略。",
                priority="medium",
                action_type="platform",
                is_paid_action=False,
                action_button_text=f"优化{platform}",
            ))

    # NOTE(review): free-plan reports only contain free-tier dimensions, so
    # this branch can only fire for paid users, yet its button text reads
    # "upgrade to unlock" — confirm the intended wording/audience.
    for dim in report.dimensions:
        if dim.percentage < 40 and dim.name not in _FREE_TIER_DIMENSIONS:
            suggestions.append(ActionSuggestion(
                title=f"提升{dim.name}得分",
                description=f"{dim.name}当前得分仅 {round(dim.percentage, 1)}%,解锁详细诊断和优化方案。",
                priority="medium",
                action_type="dimension",
                is_paid_action=True,
                action_button_text="升级解锁",
            ))

    for comp in report.competitor_scores:
        if comp.get("score", 0) > report.overall_score:
            suggestions.append(ActionSuggestion(
                title=f"应对竞品 {comp['name']} 威胁",
                description=f"竞品 {comp['name']} 评分 ({comp['score']}) 高于本品牌 ({report.overall_score}),建议分析竞品优势领域并制定差异化策略。",
                priority="high",
                action_type="competitive",
                is_paid_action=True,
                action_button_text="查看竞品分析",
            ))

    if report.overall_score == 0:
        # No usable data yet: replace everything with a getting-started list.
        suggestions = [
            ActionSuggestion(
                title="设置核心查询词",
                description="品牌尚无查询数据,建议首先设置与品牌最相关的核心查询词,让系统开始数据采集。",
                priority="high",
                action_type="keyword",
                is_paid_action=False,
                action_button_text="设置查询词",
            ),
            ActionSuggestion(
                title="添加竞品对比",
                description="添加主要竞品以便进行对比分析,了解品牌在市场中的定位。",
                priority="medium",
                action_type="coverage",
                is_paid_action=False,
                action_button_text="添加竞品",
            ),
        ]
        # The upgrade prompt is only meaningful for users who are not on a
        # paid plan already.
        if not is_paid:
            suggestions.append(ActionSuggestion(
                title="解锁完整6维度诊断",
                description="免费版仅展示3个核心维度,升级后可查看完整6维度诊断报告和深度优化方案。",
                priority="medium",
                action_type="upgrade",
                is_paid_action=True,
                action_button_text="升级Pro",
            ))

    if not suggestions:
        suggestions.append(ActionSuggestion(
            title="持续监测品牌表现",
            description="品牌表现良好,建议持续监测并保持当前策略。",
            priority="low",
            action_type="monitor",
            is_paid_action=False,
            action_button_text="查看Dashboard",
        ))

    # Append the generic upgrade prompt for free users, unless the list
    # already carries an upgrade suggestion (zero-score case above).
    has_upgrade = any(s.action_type == "upgrade" for s in suggestions)
    if not is_paid and not has_upgrade:
        suggestions.append(ActionSuggestion(
            title="升级Pro解锁完整诊断",
            description="免费版仅展示3个核心维度和P0建议。升级Pro可获取完整6维度诊断、深度竞品分析和AI优化方案。",
            priority="low",
            action_type="upgrade",
            is_paid_action=True,
            action_button_text="升级Pro",
        ))

    return ActionSuggestionsResponse(suggestions=suggestions)


@router.post("/complete/{brand_id}", response_model=OnboardingCompleteResponse)
async def complete_onboarding(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Mark onboarding as completed for the current user.

    Nothing is written: the endpoint only verifies that the brand exists and
    belongs to the current user, since owning a brand is what
    get_onboarding_status treats as "completed".

    Raises:
        HTTPException: 404 if the brand does not exist or is not owned by the
            current user.
    """
    stmt = select(Brand).where(
        Brand.id == brand_id,
        Brand.user_id == _to_uuid(current_user.id),
    )
    result = await db.execute(stmt)
    brand = result.scalar_one_or_none()

    if not brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )

    logger.info("User %s completed onboarding with brand %s", current_user.id, brand_id)

    return OnboardingCompleteResponse(success=True)