"""Optimization suggestion API endpoints.

Exposes routes to list, regenerate, update and browse the history of
per-brand optimization suggestions, plus the private helpers that gather
scoring data and persist newly generated suggestions.
"""

import logging
import uuid
from datetime import datetime, timedelta

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select, func, and_, case
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.competitor import Competitor
from app.models.query import Query as QueryModel
from app.models.citation_record import CitationRecord
from app.models.suggestion import Suggestion
from app.schemas.suggestion import (
    SuggestionResponse,
    SuggestionListResponse,
    SuggestionUpdateStatus,
    SuggestionHistoryItem,
    SuggestionHistoryResponse,
)
from app.schemas.scoring import CitationResult
from app.services.scoring.scoring_service import ScoringService
from app.services.analysis.sentiment_service import get_sentiment_service
from app.services.advisor.optimization_advisor import (
    generate_suggestions,
    build_context_from_scoring_result,
)

logger = logging.getLogger(__name__)

router = APIRouter()

# Display rank of each priority label; lower sorts first.
_PRIORITY_RANK = {"high": 0, "medium": 1, "low": 2}

# Statuses a suggestion may be moved to. A tuple (not a set) keeps the order
# stable in the validation error message.
_VALID_STATUSES = ("pending", "in_progress", "completed", "dismissed")

# Sentiment labels that are tallied for scoring.
_SENTIMENT_LABELS = ("positive", "neutral", "negative")


def _priority_rank_expr():
    """Return a SQL CASE expression mapping ``Suggestion.priority`` to its rank.

    Ordering by the raw string column sorts alphabetically
    ("high" < "low" < "medium"), which puts low above medium. This expression
    yields 0/1/2 for high/medium/low, and a larger value for any other label
    so unknown priorities sort last.
    """
    return case(
        _PRIORITY_RANK,
        value=Suggestion.priority,
        else_=len(_PRIORITY_RANK),
    )


def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
    """Coerce *value* to a ``uuid.UUID``.

    Args:
        value: A ``UUID`` instance, or anything whose ``str()`` is a valid
            UUID string.

    Returns:
        The same object when it is already a ``UUID``, otherwise a new
        ``UUID`` parsed from ``str(value)``.

    Raises:
        ValueError: If ``str(value)`` is not a valid UUID string.
    """
    if isinstance(value, uuid.UUID):
        return value
    return uuid.UUID(str(value))


async def _get_brand_with_access(
    brand_id: uuid.UUID,
    db: AsyncSession,
    current_user: User,
) -> Brand:
    """Load a brand, ensuring it exists and belongs to the current user.

    Args:
        brand_id: Primary key of the brand to load.
        db: Async database session.
        current_user: The authenticated user; only brands whose ``user_id``
            equals this user's id are returned.

    Returns:
        The matching ``Brand`` row.

    Raises:
        HTTPException: 404 when no brand matches both the id and the owner.
    """
    stmt = select(Brand).where(
        Brand.id == brand_id,
        Brand.user_id == _to_uuid(current_user.id),
    )
    result = await db.execute(stmt)
    brand = result.scalar_one_or_none()
    if not brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )
    return brand


async def _get_brand_scoring_data(
    db: AsyncSession,
    user_id: uuid.UUID,
    brand: Brand,
) -> tuple:
    """Collect the scoring data used to generate optimization suggestions.

    Args:
        db: Async database session.
        user_id: Id of the user whose queries are considered.
        brand: The brand being scored; queries are matched on ``brand.name``
            and competitors on ``brand.id``.

    Returns:
        A 6-tuple ``(scoring_result, competitor_data, sentiment_data,
        platform_scores, total_queries, mentioned_count)``. When the user has
        no queries for the brand, the scoring result is computed from
        all-zero inputs and the remaining items are ``{}, {}, {}, 0, 0``.
    """
    # Queries issued by this user that target the brand.
    # NOTE(review): user_id is compared as a string here, while Brand.user_id
    # is compared as a UUID elsewhere - confirm the column types.
    queries_stmt = select(QueryModel).where(
        QueryModel.user_id == str(user_id),
        QueryModel.target_brand == brand.name,
    )
    queries_result = await db.execute(queries_stmt)
    queries = list(queries_result.scalars().all())

    if not queries:
        # No query data: return an empty score.
        scoring_service = ScoringService()
        empty_result = scoring_service.calculate_v2(
            mentioned_count=0,
            total_queries=0,
            positions=[],
            sentiment_counts={"positive": 0, "neutral": 0, "negative": 0},
            citations=[],
            brand_mentions=0,
            competitor_mentions={},
        )
        return empty_result, {}, {}, {}, 0, 0

    query_ids = [q.id for q in queries]

    # All citation records attached to those queries.
    citations_stmt = select(CitationRecord).where(
        CitationRecord.query_id.in_(query_ids),
    )
    citations_result = await db.execute(citations_stmt)
    all_citations = list(citations_result.scalars().all())

    # The denominator is the number of citation records, not of queries.
    total_queries = len(all_citations)
    brand_citations = [c for c in all_citations if c.cited]

    # Competitors registered for the brand.
    competitor_stmt = select(Competitor).where(Competitor.brand_id == brand.id)
    competitor_result = await db.execute(competitor_stmt)
    competitors = list(competitor_result.scalars().all())
    competitor_names = [c.name for c in competitors]

    # Per competitor, count cited records that list it in competitor_brands.
    # Competitors with zero hits are omitted from the mapping.
    competitor_mentions: dict[str, int] = {}
    for comp_name in competitor_names:
        count = sum(
            1
            for c in all_citations
            if c.cited and c.competitor_brands and comp_name in c.competitor_brands
        )
        if count > 0:
            competitor_mentions[comp_name] = count

    # Sentiment tally: use the stored label when valid, otherwise analyze the
    # text on the fly; anything that cannot be classified counts as neutral.
    sentiment_service = get_sentiment_service()
    sentiment_counts = {"positive": 0, "neutral": 0, "negative": 0}
    for citation in brand_citations:
        if citation.sentiment and citation.sentiment in _SENTIMENT_LABELS:
            sentiment_counts[citation.sentiment] += 1
            continue

        content = citation.raw_response or citation.citation_text or ""
        if not content.strip():
            sentiment_counts["neutral"] += 1
            continue

        try:
            analysis = await sentiment_service.analyze(
                brand_name=brand.name,
                content=content,
            )
            sentiment_counts[analysis.sentiment] += 1
        except Exception:
            # Best effort: a failed analysis (or an unexpected label) must not
            # break suggestion generation, but it should not be invisible.
            logger.warning(
                "Sentiment analysis failed for brand %s; counting as neutral",
                brand.id,
                exc_info=True,
            )
            sentiment_counts["neutral"] += 1

    # Build the CitationResult list passed to the scorer.
    # NOTE(review): sentiment is hard-coded to "neutral" here rather than the
    # per-citation label - confirm whether that is intentional.
    citation_results = [
        CitationResult(
            cited=c.cited,
            position=c.citation_position,
            citation_text=c.citation_text,
            sentiment="neutral",
            confidence=c.confidence or 0.0,
        )
        for c in brand_citations
    ]

    # Positions of the cited records (brand_citations is already cited-only).
    positions = [c.citation_position for c in brand_citations]

    # Compute the V2 score.
    scoring_service = ScoringService()
    v2_result = scoring_service.calculate_v2(
        mentioned_count=len(brand_citations),
        total_queries=total_queries,
        positions=positions,
        sentiment_counts=sentiment_counts,
        citations=citation_results,
        brand_mentions=len(brand_citations),
        competitor_mentions=competitor_mentions,
    )

    # Per-platform score: percentage of that platform's records that are
    # cited, rounded to 2 decimals; 0.0 when the platform has no records.
    from app.services.scoring.brand_scoring_data_service import REQUIRED_PLATFORMS

    platform_scores: dict[str, float] = {}
    for platform in REQUIRED_PLATFORMS:
        platform_citations = [c for c in all_citations if c.platform == platform]
        total_p = len(platform_citations)
        cited_p = sum(1 for c in platform_citations if c.cited)
        platform_scores[platform] = round(
            (cited_p / total_p * 100) if total_p > 0 else 0.0, 2
        )

    # Competitor comparison: a tie counts as "behind".
    brand_mention_count = len(brand_citations)
    competitor_data = {
        "brand_mentions": brand_mention_count,
        "competitor_mentions": competitor_mentions,
        "ahead_count": sum(
            1 for count in competitor_mentions.values() if brand_mention_count > count
        ),
        "behind_count": sum(
            1 for count in competitor_mentions.values() if brand_mention_count <= count
        ),
    }

    return (
        v2_result,
        competitor_data,
        sentiment_counts,
        platform_scores,
        total_queries,
        brand_mention_count,
    )


@router.get("/{brand_id}/suggestions", response_model=SuggestionListResponse)
async def get_suggestions(
    brand_id: uuid.UUID,
    type: str | None = Query(None, description="按类型筛选"),
    priority: str | None = Query(None, description="按优先级筛选"),
    status_filter: str | None = Query(None, alias="status", description="按状态筛选"),
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    获取品牌的优化建议

    如果没有缓存建议,自动生成新的建议。
    支持按类型、优先级、状态筛选。
    """
    # The docstring above is kept verbatim because FastAPI publishes it as the
    # OpenAPI operation description. In English: list a brand's optimization
    # suggestions, generating a fresh batch when none is cached; supports
    # filtering by type, priority and status. The `type` parameter name
    # shadows the builtin but is part of the public query-string interface.
    brand = await _get_brand_with_access(brand_id, db, current_user)

    # "Cached" means: at least one pending/in_progress suggestion generated
    # within the last 7 days.
    # NOTE(review): datetime.now() is timezone-naive - confirm it matches how
    # Suggestion.generated_at is stored.
    seven_days_ago = datetime.now() - timedelta(days=7)
    cache_stmt = select(func.count()).select_from(Suggestion).where(
        Suggestion.brand_id == brand_id,
        Suggestion.generated_at >= seven_days_ago,
        Suggestion.status.in_(["pending", "in_progress"]),
    )
    cache_result = await db.execute(cache_stmt)
    cached_count = cache_result.scalar_one()

    # Nothing cached: generate (and commit) a new batch before listing.
    if cached_count == 0:
        await _generate_and_save_suggestions(db, brand, current_user)

    # Build the filter conditions.
    conditions = [Suggestion.brand_id == brand_id]
    if type:
        conditions.append(Suggestion.type == type)
    if priority:
        conditions.append(Suggestion.priority == priority)
    if status_filter:
        conditions.append(Suggestion.status == status_filter)

    # Total number of rows matching the filters (ignores pagination).
    count_stmt = select(func.count()).select_from(Suggestion).where(*conditions)
    count_result = await db.execute(count_stmt)
    total = count_result.scalar_one()

    # Page of suggestions, high -> medium -> low, newest first within a rank.
    suggestions_stmt = (
        select(Suggestion)
        .where(*conditions)
        .order_by(
            _priority_rank_expr().asc(),
            Suggestion.generated_at.desc(),
        )
        .offset(skip)
        .limit(limit)
    )
    suggestions_result = await db.execute(suggestions_stmt)
    suggestions = list(suggestions_result.scalars().all())

    return SuggestionListResponse(
        suggestions=[SuggestionResponse.model_validate(s) for s in suggestions],
        total=total,
    )


@router.post("/{brand_id}/suggestions/regenerate", response_model=SuggestionListResponse)
async def regenerate_suggestions(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    重新生成优化建议

    将现有的pending建议标记为dismissed,然后重新生成。
    """
    # Docstring kept verbatim (surfaced in OpenAPI). In English: regenerate
    # suggestions - mark the brand's existing pending suggestions as
    # dismissed, then generate a new batch.
    brand = await _get_brand_with_access(brand_id, db, current_user)

    # Dismiss every currently pending suggestion of the brand.
    dismiss_stmt = select(Suggestion).where(
        Suggestion.brand_id == brand_id,
        Suggestion.status == "pending",
    )
    dismiss_result = await db.execute(dismiss_stmt)
    for old_suggestion in dismiss_result.scalars().all():
        old_suggestion.status = "dismissed"
    # Flush only; the commit happens inside _generate_and_save_suggestions.
    await db.flush()

    # Generate the new batch.
    new_suggestions = await _generate_and_save_suggestions(db, brand, current_user)

    return SuggestionListResponse(
        suggestions=[SuggestionResponse.model_validate(s) for s in new_suggestions],
        total=len(new_suggestions),
    )


@router.put(
    "/{brand_id}/suggestions/{suggestion_id}/status",
    response_model=SuggestionResponse,
)
async def update_suggestion_status(
    brand_id: uuid.UUID,
    suggestion_id: uuid.UUID,
    status_update: SuggestionUpdateStatus,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    更新建议状态

    支持的状态: pending/in_progress/completed/dismissed
    """
    # Docstring kept verbatim (surfaced in OpenAPI). In English: update a
    # suggestion's status; supported values are
    # pending/in_progress/completed/dismissed.
    await _get_brand_with_access(brand_id, db, current_user)

    # Look the suggestion up within the brand, so ids from other brands 404.
    stmt = select(Suggestion).where(
        Suggestion.id == suggestion_id,
        Suggestion.brand_id == brand_id,
    )
    result = await db.execute(stmt)
    suggestion = result.scalar_one_or_none()
    if not suggestion:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="建议不存在",
        )

    # Validate the requested status value.
    if status_update.status not in _VALID_STATUSES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无效的状态值,支持: {', '.join(_VALID_STATUSES)}",
        )

    suggestion.status = status_update.status
    await db.commit()
    await db.refresh(suggestion)

    return SuggestionResponse.model_validate(suggestion)


@router.get("/{brand_id}/suggestions/history", response_model=SuggestionHistoryResponse)
async def get_suggestions_history(
    brand_id: uuid.UUID,
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=1, le=50),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    获取建议生成历史

    按批次ID分组,展示历次生成的建议。
    """
    # Docstring kept verbatim (surfaced in OpenAPI). In English: return the
    # suggestion generation history, grouped by batch id.
    await _get_brand_with_access(brand_id, db, current_user)

    # One row per batch. Grouping by batch_id alone keeps this listing
    # consistent with the total below (which also groups by batch_id) even if
    # rows of one batch carry slightly different generated_at values.
    batch_generated_at = func.min(Suggestion.generated_at)
    batch_stmt = (
        select(
            Suggestion.batch_id,
            batch_generated_at.label("generated_at"),
            func.min(Suggestion.source).label("source"),
            func.count().label("suggestion_count"),
        )
        .where(Suggestion.brand_id == brand_id)
        .group_by(Suggestion.batch_id)
        .order_by(batch_generated_at.desc())
        .offset(skip)
        .limit(limit)
    )
    batch_result = await db.execute(batch_stmt)
    batches = batch_result.all()

    # Total number of distinct batches for the brand (ignores pagination).
    total_stmt = select(func.count()).select_from(
        select(Suggestion.batch_id)
        .where(Suggestion.brand_id == brand_id)
        .group_by(Suggestion.batch_id)
        .subquery()
    )
    total_result = await db.execute(total_stmt)
    total = total_result.scalar_one()

    history = []
    for batch in batches:
        # Suggestions of this batch, high -> medium -> low.
        batch_suggestions_stmt = (
            select(Suggestion)
            .where(
                Suggestion.brand_id == brand_id,
                Suggestion.batch_id == batch.batch_id,
            )
            .order_by(_priority_rank_expr().asc())
        )
        batch_suggestions_result = await db.execute(batch_suggestions_stmt)
        batch_suggestions = list(batch_suggestions_result.scalars().all())

        history.append(
            SuggestionHistoryItem(
                batch_id=batch.batch_id,
                generated_at=batch.generated_at,
                source=batch.source,
                suggestion_count=batch.suggestion_count,
                suggestions=[
                    SuggestionResponse.model_validate(s) for s in batch_suggestions
                ],
            )
        )

    return SuggestionHistoryResponse(history=history, total=total)


async def _generate_and_save_suggestions(
    db: AsyncSession,
    brand: Brand,
    current_user: User,
) -> list[Suggestion]:
    """Generate optimization suggestions for a brand and persist them.

    All suggestions created by one call share a fresh ``batch_id`` and start
    in the ``"pending"`` status. The session is committed before returning.

    Args:
        db: Async database session.
        brand: The brand to generate suggestions for.
        current_user: The authenticated user whose query data is scored.

    Returns:
        The newly created ``Suggestion`` rows, refreshed after the commit.
    """
    # Gather the scoring data.
    (
        v2_result,
        competitor_data,
        sentiment_data,
        platform_scores,
        total_queries,
        mentioned_count,
    ) = await _get_brand_scoring_data(db, _to_uuid(current_user.id), brand)

    # Build the analysis context.
    ctx = build_context_from_scoring_result(
        brand_name=brand.name,
        scoring_result=v2_result,
        competitor_data=competitor_data,
        sentiment_data=sentiment_data,
        platform_scores=platform_scores,
        total_queries=total_queries,
        mentioned_count=mentioned_count,
    )

    # Generate the suggestions.
    suggestion_items = await generate_suggestions(ctx)

    # Record the source as "llm" only when the LLM is enabled and an API key
    # is configured; otherwise "rule".
    from app.config import settings

    source = "llm" if (settings.ENABLE_LLM and settings.DEEPSEEK_API_KEY) else "rule"

    # One batch id shared by every suggestion of this run.
    batch_id = uuid.uuid4()

    # Persist to the database.
    db_suggestions = []
    for item in suggestion_items:
        db_suggestion = Suggestion(
            brand_id=brand.id,
            type=item.type,
            priority=item.priority,
            title=item.title,
            description=item.description,
            action=item.action,
            expected_impact=item.expected_impact,
            difficulty=item.difficulty,
            status="pending",
            batch_id=batch_id,
            source=source,
        )
        db.add(db_suggestion)
        db_suggestions.append(db_suggestion)

    await db.commit()

    # Refresh each row to load values populated on insert.
    for s in db_suggestions:
        await db.refresh(s)

    return db_suggestions