geo/backend/app/api/onboarding.py

585 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Onboarding API endpoints - 新用户引导流程"""
import logging
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel, Field
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.api.competitors import (
INDUSTRY_COMPETITORS,
_get_rule_based_recommendations,
_get_llm_recommendations,
CompetitorRecommendationItem,
)
from app.config import settings
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.competitor import Competitor
from app.models.query import Query as QueryModel
from app.schemas.brand import BrandCreate, BrandResponse
from app.services.diagnosis.data_collector import DataCollectorService
from app.services.diagnosis.geo_diagnosis import GEODiagnosisService
from app.utils.health import get_health_level, get_health_level_label
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# All routes declared below share the "/onboarding" prefix and the "onboarding" tag.
router = APIRouter(prefix="/onboarding", tags=["onboarding"])
# Names of the diagnosis dimensions shown to non-paid users: the health-report
# endpoint in this module keeps only dimensions whose name is in this set when
# the caller is on the free plan, and uses the same names for its fallback report.
_FREE_TIER_DIMENSIONS = {"内容可提取性", "E-E-A-T信号", "引用就绪度"}
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
# ------------------------------------------------------------------
# Request / Response schemas
# ------------------------------------------------------------------
class OnboardingBrandCreate(BaseModel):
    """Simplified request body used to create a brand during onboarding."""

    # Required brand name, 2 to 50 characters.
    name: str = Field(..., min_length=2, max_length=50, description="品牌名称")
    # Optional free-text description, at most 500 characters.
    description: str | None = Field(default=None, max_length=500, description="品牌描述")
    # Optional industry label, at most 50 characters.
    industry: str | None = Field(default=None, max_length=50, description="行业")
class OnboardingStatusResponse(BaseModel):
    """Response body describing how far the user is through onboarding."""

    # True once onboarding counts as finished.
    completed: bool
    # String form of the user's brand id; left as None when there is no brand.
    brand_id: str | None = None
    # Step the client should show; the status endpoint in this module sends
    # 1 (create a brand) or 4 (finished).
    current_step: int
class CompetitorRecommendationSimple(BaseModel):
    """One competitor suggestion in the trimmed-down onboarding format."""
    # Name of the suggested competitor.
    name: str
    # Explanation text; the recommendations endpoint in this module copies it
    # from the source item's ``reason``.
    description: str
    # Confidence value; the recommendations endpoint in this module assigns
    # either 0.8 or 0.5.
    confidence: float
class CompetitorRecommendationSimpleResponse(BaseModel):
    """Response wrapper holding the list of simplified competitor suggestions."""
    # Suggestions in the order they were produced by the recommender.
    recommendations: list[CompetitorRecommendationSimple]
class HealthDimensionItem(BaseModel):
    """Score of a single diagnosis dimension inside a health report."""
    # Dimension name as reported by the diagnosis result.
    name: str
    # Points obtained for this dimension.
    score: float
    # Upper bound for ``score``.
    max_score: float
    # Percentage value for the dimension as reported by the diagnosis result.
    percentage: float
    # Status label; the health-report endpoint falls back to "fail" when the
    # diagnosis result does not provide one.
    status: str
class HealthRecommendationItem(BaseModel):
    """A single improvement recommendation attached to a health report."""
    # Priority label; free-plan reports only keep items whose priority is "P0".
    priority: str
    # Name of the diagnosis dimension the recommendation refers to.
    dimension: str
    # Short headline of the recommendation.
    title: str
    # Longer explanation of the recommendation.
    description: str
class HealthReportResponse(BaseModel):
    """Brand health report returned by the onboarding health-report endpoint."""
    # String form of the brand's id.
    brand_id: str
    brand_name: str
    # Overall diagnosis score, rounded to two decimals by the endpoint.
    overall_score: float
    # Level key derived from the overall score (e.g. "danger" in the fallback report).
    health_level: str
    # Human-readable label for ``health_level``.
    health_level_label: str
    # Maps platform name -> percentage of citation records marked as cited.
    platform_scores: dict
    # Human-readable strength / weakness sentences built from the dimensions.
    strengths: list[str]
    weaknesses: list[str]
    # One dict per competitor with the keys "name", "score" and "is_leading".
    competitor_scores: list[dict]
    # NOTE(review): the two ``[]`` defaults rely on pydantic copying mutable
    # defaults per instance - confirm for the pinned pydantic version.
    dimensions: list[HealthDimensionItem] = []
    recommendations: list[HealthRecommendationItem] = []
    # True when the report was produced for a paid plan (unfiltered content).
    is_full_report: bool = False
class ActionSuggestion(BaseModel):
    """One suggested next step shown to the user after the health report."""
    # Short headline of the suggestion.
    title: str
    # Longer explanation shown with the suggestion.
    description: str
    # Priority label; this module emits "high", "medium" or "low".
    priority: str
    # Category key; this module emits "coverage", "keyword", "platform",
    # "dimension", "competitive", "upgrade" and "monitor".
    action_type: str
    # Flag set to True by this module for upgrade / paid-feature suggestions.
    is_paid_action: bool = False
    # Label for the call-to-action button; empty when not provided.
    action_button_text: str = ""
class ActionSuggestionsResponse(BaseModel):
    """Response wrapper holding the list of action suggestions."""
    # Suggestions in the order they were generated.
    suggestions: list[ActionSuggestion]
class OnboardingCompleteResponse(BaseModel):
    """Response body of the "complete onboarding" endpoint."""
    # True when the completion request was accepted.
    success: bool
# ------------------------------------------------------------------
# Endpoints
# ------------------------------------------------------------------
@router.get("/status", response_model=OnboardingStatusResponse)
async def get_onboarding_status(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Report whether the current user has finished onboarding.

    Completion is inferred from the brands table: owning at least one brand
    means onboarding is done.

    Returns:
        ``completed=True``, the brand's id and ``current_step=4`` when a brand
        exists; otherwise ``completed=False``, ``brand_id=None`` and
        ``current_step=1`` (the user still has to create a brand).
    """
    # Fetch at most one row. Nothing in the onboarding brand-creation endpoint
    # prevents a user from owning several brands, and scalar_one_or_none()
    # would raise MultipleResultsFound in that case. Ordering by primary key
    # keeps the reported brand stable between calls.
    stmt = (
        select(Brand)
        .where(Brand.user_id == _to_uuid(current_user.id))
        .order_by(Brand.id)
        .limit(1)
    )
    result = await db.execute(stmt)
    brand = result.scalars().first()

    if brand:
        return OnboardingStatusResponse(
            completed=True,
            brand_id=str(brand.id),
            current_step=4,
        )
    return OnboardingStatusResponse(
        completed=False,
        brand_id=None,
        current_step=1,
    )
@router.post("/brand", response_model=BrandResponse, status_code=status.HTTP_201_CREATED)
async def create_onboarding_brand(
    brand_data: OnboardingBrandCreate,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Create a brand for the current user and try to seed one default query.

    The simplified onboarding payload is expanded into a full ``BrandCreate``
    (no aliases, no website, platforms ``wenxin``/``kimi``, weekly frequency),
    stored as a ``Brand`` and committed. Afterwards a default query with the
    keyword ``"<brand name> 推荐"`` is created on a best-effort basis, but only
    while the user owns fewer queries than ``current_user.max_queries``
    (3 when that attribute is absent).

    A failure while creating the default query never fails the request: the
    error is logged, the session is rolled back and the already-committed brand
    is reloaded so that it can still be serialized.

    Returns:
        The persisted ``Brand`` (serialized through ``BrandResponse``).
    """
    # Go through BrandCreate so the onboarding payload is checked by the same
    # schema as a regular brand.
    # NOTE(review): ``brand_data.description`` is accepted by the request schema
    # but is not forwarded anywhere below - confirm whether that is intended.
    full_brand_data = BrandCreate(
        name=brand_data.name,
        aliases=[],
        website=None,
        industry=brand_data.industry,
        platforms=["wenxin", "kimi"],
        frequency="weekly",
    )
    brand = Brand(
        user_id=_to_uuid(current_user.id),
        name=full_brand_data.name,
        aliases=full_brand_data.aliases,
        website=full_brand_data.website,
        industry=full_brand_data.industry,
        platforms=full_brand_data.platforms,
        frequency=full_brand_data.frequency,
    )
    db.add(brand)
    await db.commit()
    await db.refresh(brand)

    # Snapshot the name while the instance is freshly loaded: log messages
    # below must not touch ORM attributes that a later commit or rollback may
    # have expired.
    brand_name = brand.name

    # Best effort: auto-create a default query, respecting the max_queries limit.
    try:
        current_query_count_stmt = select(func.count()).select_from(QueryModel).where(
            QueryModel.user_id == current_user.id
        )
        current_query_count_result = await db.execute(current_query_count_stmt)
        current_query_count = current_query_count_result.scalar_one()
        max_queries = getattr(current_user, "max_queries", 3)  # free-plan default: 3
        if current_query_count < max_queries:
            default_query = QueryModel(
                user_id=current_user.id,
                keyword=f"{brand_name} 推荐",
                target_brand=brand_name,
                brand_aliases=brand.aliases or [],
                platforms=brand.platforms or ["wenxin", "kimi"],
                frequency=brand.frequency or "weekly",
                status="active",
            )
            db.add(default_query)
            await db.commit()
            await db.refresh(default_query)
            logger.info(f"Auto-created default query for brand '{brand_name}'")
        else:
            logger.info(
                f"Skipped auto-creating default query for brand '{brand_name}': "
                f"query limit reached ({current_query_count}/{max_queries})"
            )
    except Exception as e:
        logger.warning(
            f"Failed to auto-create default query for brand '{brand_name}': {e}",
            exc_info=True,
        )
        # Leave the session usable for the response: discard the failed
        # transaction, then reload the brand because a rollback expires the
        # state of loaded instances.
        try:
            await db.rollback()
            await db.refresh(brand)
        except Exception as recovery_error:
            logger.warning(
                f"Failed to reset session after default-query error for brand "
                f"'{brand_name}': {recovery_error}",
                exc_info=True,
            )
    return brand
@router.get("/competitor-recommendations", response_model=CompetitorRecommendationSimpleResponse)
async def get_onboarding_competitor_recommendations(
    brand_id: uuid.UUID = Query(..., description="品牌ID"),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Suggest competitors for one of the current user's brands.

    Reuses the recommendation helpers from the competitors API. The LLM
    recommender is used when ``settings.ENABLE_LLM`` and
    ``settings.DEEPSEEK_API_KEY`` are both set, and any error it raises falls
    back to the rule-based recommender; otherwise the rule-based recommender is
    used directly. Names of competitors already stored for the brand are passed
    to the recommender as ``existing_names``.

    Raises:
        HTTPException: 404 when the brand does not exist or belongs to
            another user.
    """
    owner_id = _to_uuid(current_user.id)
    brand_result = await db.execute(
        select(Brand).where(Brand.id == brand_id, Brand.user_id == owner_id)
    )
    brand = brand_result.scalar_one_or_none()
    if not brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )

    names_result = await db.execute(
        select(Competitor.name).where(Competitor.brand_id == brand_id)
    )
    existing_names = [name for (name,) in names_result.all()]

    # Both recommenders take the same keyword arguments.
    recommender_kwargs = {
        "brand_name": brand.name,
        "industry": brand.industry,
        "existing_names": existing_names,
    }

    # Pick the recommendation strategy.
    if settings.ENABLE_LLM and settings.DEEPSEEK_API_KEY:
        try:
            rec_items = await _get_llm_recommendations(**recommender_kwargs)
        except Exception as e:
            logger.warning(f"Onboarding LLM竞品推荐失败回退规则推荐: {e}")
            rec_items = _get_rule_based_recommendations(**recommender_kwargs)
    else:
        rec_items = _get_rule_based_recommendations(**recommender_kwargs)

    # Confidence depends only on the brand's industry, so compute it once.
    confidence = 0.8 if brand.industry and brand.industry in INDUSTRY_COMPETITORS else 0.5

    # Convert to the simplified response format.
    return CompetitorRecommendationSimpleResponse(
        recommendations=[
            CompetitorRecommendationSimple(
                name=item.name,
                description=item.reason,
                confidence=confidence,
            )
            for item in rec_items
        ]
    )
def _citation_rate_by_platform(citations) -> dict[str, float]:
    """Return, per platform, the percentage of citation records marked as cited.

    Records without a platform are grouped under ``"unknown"``. Percentages are
    rounded to two decimals; platforms appear in first-seen order.
    """
    tallies: dict[str, list[int]] = {}  # platform -> [total, cited]
    for record in citations:
        tally = tallies.setdefault(record.platform or "unknown", [0, 0])
        tally[0] += 1
        if record.cited:
            tally[1] += 1
    # Every tally has total >= 1 because it is created on first sight of a record.
    return {
        platform: round(cited / total * 100, 2)
        for platform, (total, cited) in tallies.items()
    }


def _describe_strengths_and_weaknesses(dimensions) -> tuple[list[str], list[str]]:
    """Turn dimension dicts into human-readable strength / weakness sentences.

    Dimensions at 60% or above count as strengths, dimensions strictly between
    0% and 60% as weaknesses. Each list gets a generic placeholder sentence
    when it would otherwise be empty.
    """
    strengths: list[str] = []
    weaknesses: list[str] = []
    for dim in dimensions:
        pct = dim.get("percentage", 0)
        # NOTE(review): a dimension at exactly 0% lands in neither list, so an
        # all-zero report ends up with the "no obvious weaknesses" placeholder -
        # confirm that this is intended.
        if pct >= 60:
            strengths.append(f"{dim.get('name', '')}表现良好 ({round(pct, 1)}%)")
        elif pct > 0:
            weaknesses.append(f"{dim.get('name', '')}有待提升 ({round(pct, 1)}%)")
    if not strengths:
        strengths.append("已有初步诊断数据")
    if not weaknesses:
        weaknesses.append("暂无明显短板")
    return strengths, weaknesses


@router.get("/health-report/{brand_id}", response_model=HealthReportResponse)
async def get_onboarding_health_report(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Build the onboarding health report for one of the user's brands.

    Runs data collection and the GEO diagnosis for the brand, then assembles
    the overall score, per-platform citation rates, strengths / weaknesses,
    competitor placeholders, dimensions and recommendations.

    Users whose ``plan`` is missing, empty or ``"free"`` only receive the
    dimensions listed in ``_FREE_TIER_DIMENSIONS`` and recommendations with
    priority ``"P0"``; every other plan receives the unfiltered lists.

    When collection or diagnosis raises, the error is logged and a zero-score
    fallback report with health level ``"danger"`` is returned instead of an
    error response.

    Raises:
        HTTPException: 404 when the brand does not exist or belongs to
            another user.
    """
    owner_id = _to_uuid(current_user.id)
    brand_result = await db.execute(
        select(Brand).where(Brand.id == brand_id, Brand.user_id == owner_id)
    )
    brand = brand_result.scalar_one_or_none()
    if not brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )

    user_plan = getattr(current_user, "plan", None) or "free"
    is_paid = user_plan not in ("free", None)

    try:
        collector = DataCollectorService(db)
        collection = await collector.collect(
            brand_name=brand.name,
            brand_aliases=brand.aliases or [],
            website=brand.website,
            industry=brand.industry,
        )
        diagnosis_result = GEODiagnosisService().diagnose(collection.diagnosis_input)
    except Exception as e:
        logger.error(f"Onboarding健康报告采集失败: brand_id={brand_id}, error={e}", exc_info=True)
        # Degrade to an all-zero report so the client still gets a response.
        placeholder_dimensions = [
            HealthDimensionItem(name=dim_name, score=0.0, max_score=0.0, percentage=0.0, status="fail")
            for dim_name in sorted(_FREE_TIER_DIMENSIONS)
        ]
        return HealthReportResponse(
            brand_id=str(brand.id),
            brand_name=brand.name,
            overall_score=0.0,
            health_level="danger",
            health_level_label=get_health_level_label("danger"),
            platform_scores={},
            strengths=["品牌已创建,等待数据采集"],
            weaknesses=["数据采集失败,请稍后重试"],
            competitor_scores=[],
            dimensions=placeholder_dimensions,
            recommendations=[],
            is_full_report=is_paid,
        )

    report_data = diagnosis_result.to_dict()
    visible_dimensions = report_data.get("dimensions", [])
    visible_recommendations = report_data.get("recommendations", [])
    if not is_paid:
        # Free plan: restrict to the free-tier dimensions and P0 recommendations.
        visible_dimensions = [
            d for d in visible_dimensions if d.get("name") in _FREE_TIER_DIMENSIONS
        ]
        visible_recommendations = [
            r for r in visible_recommendations if r.get("priority") == "P0"
        ]

    overall_score = round(diagnosis_result.overall_score, 2)
    health_level = get_health_level(overall_score)

    # Per-platform citation rates, based on the user's queries for this brand.
    platform_scores: dict[str, float] = {}
    query_result = await db.execute(
        select(QueryModel).where(
            QueryModel.user_id == current_user.id,
            QueryModel.target_brand == brand.name,
        )
    )
    queries = list(query_result.scalars().all())
    if queries:
        from app.models.citation_record import CitationRecord

        citation_result = await db.execute(
            select(CitationRecord).where(
                CitationRecord.query_id.in_([q.id for q in queries]),
            )
        )
        platform_scores = _citation_rate_by_platform(citation_result.scalars().all())

    strengths, weaknesses = _describe_strengths_and_weaknesses(visible_dimensions)

    # Competitors are listed with placeholder values; no score is computed here.
    competitor_result = await db.execute(
        select(Competitor).where(Competitor.brand_id == brand_id)
    )
    competitor_scores = [
        {"name": comp.name, "score": 0.0, "is_leading": False}
        for comp in competitor_result.scalars().all()
    ]

    dimension_items = [
        HealthDimensionItem(
            name=d.get("name", ""),
            score=round(d.get("score", 0.0), 2),
            max_score=d.get("max_score", 0.0),
            percentage=round(d.get("percentage", 0.0), 2),
            status=d.get("status", "fail"),
        )
        for d in visible_dimensions
    ]
    recommendation_items = [
        HealthRecommendationItem(
            priority=r.get("priority", "P0"),
            dimension=r.get("dimension", ""),
            title=r.get("title", ""),
            description=r.get("description", ""),
        )
        for r in visible_recommendations
    ]

    return HealthReportResponse(
        brand_id=str(brand.id),
        brand_name=brand.name,
        overall_score=overall_score,
        health_level=health_level,
        health_level_label=get_health_level_label(health_level),
        platform_scores=platform_scores,
        strengths=strengths,
        weaknesses=weaknesses,
        competitor_scores=competitor_scores,
        dimensions=dimension_items,
        recommendations=recommendation_items,
        is_full_report=is_paid,
    )
@router.get("/action-suggestions/{brand_id}", response_model=ActionSuggestionsResponse)
async def get_onboarding_action_suggestions(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Derive next-step suggestions from the brand's onboarding health report.

    The health report is obtained by calling ``get_onboarding_health_report``
    directly, so its 404 for unknown / foreign brands propagates from here.

    Rules, applied in order:
      * overall score below 20 -> coverage suggestion; below 50 -> keyword
        suggestion (both may apply);
      * every platform with a citation rate below 30 -> platform suggestion;
      * every non-free-tier dimension below 40% -> dimension suggestion;
      * every competitor scoring above the brand -> competitive suggestion;
      * overall score of exactly 0 -> everything above is replaced by a fixed
        "getting started" list;
      * nothing collected -> a single "keep monitoring" suggestion;
      * free-plan users always end up with exactly one upgrade suggestion.

    Users on a paid plan are never shown an upgrade suggestion.
    """
    report = await get_onboarding_health_report(brand_id, current_user, db)
    user_plan = getattr(current_user, "plan", None) or "free"
    is_paid = user_plan not in ("free", None)

    suggestions = []
    if report.overall_score < 20:
        suggestions.append(ActionSuggestion(
            title="提升 AI 平台覆盖率",
            description=f"当前综合评分仅 {report.overall_score}品牌在AI搜索中几乎未被提及。建议增加查询词覆盖面让AI平台更频繁地引用品牌。",
            priority="high",
            action_type="coverage",
            is_paid_action=False,
            action_button_text="设置查询词",
        ))
    if report.overall_score < 50:
        suggestions.append(ActionSuggestion(
            title="优化核心关键词",
            description="品牌在关键查询词下的提及率偏低,建议调整查询关键词策略,聚焦行业核心术语。",
            priority="high",
            action_type="keyword",
            is_paid_action=False,
            action_button_text="优化关键词",
        ))

    for platform, score in report.platform_scores.items():
        if score < 30:
            suggestions.append(ActionSuggestion(
                title=f"提升 {platform} 平台覆盖率",
                description=f"品牌在 {platform} 平台的引用率仅为 {score}%,需要针对性优化该平台的内容策略。",
                priority="medium",
                action_type="platform",
                is_paid_action=False,
                action_button_text=f"优化{platform}",
            ))

    # NOTE(review): free-plan reports only contain free-tier dimensions, so this
    # branch can only fire for paid users even though its copy reads like an
    # upsell - confirm the intended audience.
    for dim in report.dimensions:
        if dim.percentage < 40 and dim.name not in _FREE_TIER_DIMENSIONS:
            suggestions.append(ActionSuggestion(
                title=f"提升{dim.name}得分",
                description=f"{dim.name}当前得分仅 {round(dim.percentage, 1)}%,解锁详细诊断和优化方案。",
                priority="medium",
                action_type="dimension",
                is_paid_action=True,
                action_button_text="升级解锁",
            ))

    # NOTE(review): the health report in this module currently fills every
    # competitor score with 0.0, so this only fires for a negative overall score.
    for comp in report.competitor_scores:
        if comp.get("score", 0) > report.overall_score:
            suggestions.append(ActionSuggestion(
                title=f"应对竞品 {comp['name']} 威胁",
                description=f"竞品 {comp['name']} 评分 ({comp['score']}) 高于本品牌 ({report.overall_score}),建议分析竞品优势领域并制定差异化策略。",
                priority="high",
                action_type="competitive",
                is_paid_action=True,
                action_button_text="查看竞品分析",
            ))

    if report.overall_score == 0:
        # No usable data yet: replace the score-based items with a fixed
        # "getting started" list.
        suggestions = [
            ActionSuggestion(
                title="设置核心查询词",
                description="品牌尚无查询数据,建议首先设置与品牌最相关的核心查询词,让系统开始数据采集。",
                priority="high",
                action_type="keyword",
                is_paid_action=False,
                action_button_text="设置查询词",
            ),
            ActionSuggestion(
                title="添加竞品对比",
                description="添加主要竞品以便进行对比分析,了解品牌在市场中的定位。",
                priority="medium",
                action_type="coverage",
                is_paid_action=False,
                action_button_text="添加竞品",
            ),
        ]
        # The "free tier only shows 3 dimensions" upsell is wrong for users
        # who already pay, so only free-plan users get it.
        if not is_paid:
            suggestions.append(ActionSuggestion(
                title="解锁完整6维度诊断",
                description="免费版仅展示3个核心维度升级后可查看完整6维度诊断报告和深度优化方案。",
                priority="medium",
                action_type="upgrade",
                is_paid_action=True,
                action_button_text="升级Pro",
            ))

    if not suggestions:
        suggestions.append(ActionSuggestion(
            title="持续监测品牌表现",
            description="品牌表现良好,建议持续监测并保持当前策略。",
            priority="low",
            action_type="monitor",
            is_paid_action=False,
            action_button_text="查看Dashboard",
        ))

    # Trailing upsell for free-plan users, skipped when the list already
    # carries an upgrade item so the same call-to-action is not shown twice.
    already_has_upgrade = any(s.action_type == "upgrade" for s in suggestions)
    if not is_paid and not already_has_upgrade:
        suggestions.append(ActionSuggestion(
            title="升级Pro解锁完整诊断",
            description="免费版仅展示3个核心维度和P0建议。升级Pro可获取完整6维度诊断、深度竞品分析和AI优化方案。",
            priority="low",
            action_type="upgrade",
            is_paid_action=True,
            action_button_text="升级Pro",
        ))

    return ActionSuggestionsResponse(suggestions=suggestions)
@router.post("/complete/{brand_id}", response_model=OnboardingCompleteResponse)
async def complete_onboarding(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Acknowledge that the user has finished onboarding.

    Nothing is persisted here: the User model has no dedicated
    ``onboarding_completed`` field, and owning a brand is what counts as
    completed onboarding. The endpoint only verifies that the brand exists and
    belongs to the caller, then logs the event.

    Raises:
        HTTPException: 404 when the brand does not exist or belongs to
            another user.
    """
    owner_id = _to_uuid(current_user.id)
    lookup = select(Brand).where(Brand.id == brand_id, Brand.user_id == owner_id)
    brand = (await db.execute(lookup)).scalar_one_or_none()

    if brand:
        logger.info(f"User {current_user.id} completed onboarding with brand {brand_id}")
        return OnboardingCompleteResponse(success=True)

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="品牌不存在",
    )