geo/backend/app/api/suggestions.py

484 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""优化建议 API endpoints."""
import logging
import uuid
from datetime import datetime, timedelta

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select, func, and_, case
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.competitor import Competitor
from app.models.query import Query as QueryModel
from app.models.citation_record import CitationRecord
from app.models.suggestion import Suggestion
from app.schemas.suggestion import (
    SuggestionResponse,
    SuggestionListResponse,
    SuggestionUpdateStatus,
    SuggestionHistoryItem,
    SuggestionHistoryResponse,
)
from app.schemas.scoring import CitationResult
from app.services.scoring.scoring_service import ScoringService
from app.services.analysis.sentiment_service import get_sentiment_service
from app.services.advisor.optimization_advisor import (
    generate_suggestions,
    build_context_from_scoring_result,
)
# FastAPI router holding the brand optimization-suggestion endpoints below;
# the URL prefix it is mounted under is configured elsewhere.
router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def _get_brand_with_access(
    brand_id: uuid.UUID,
    db: AsyncSession,
    current_user: User,
) -> Brand:
    """Load brand *brand_id*, requiring that it belongs to *current_user*.

    Raises:
        HTTPException: 404 when no brand with that id is owned by the user.
            A brand owned by someone else is reported the same way.
    """
    owner_id = _to_uuid(current_user.id)
    lookup = select(Brand).where(Brand.id == brand_id, Brand.user_id == owner_id)
    brand = (await db.execute(lookup)).scalar_one_or_none()
    if not brand:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="品牌不存在")
    return brand
logger = logging.getLogger(__name__)

# Sentiment labels tallied for scoring; any other label is counted as "neutral".
_SENTIMENT_LABELS = ("positive", "neutral", "negative")


def _count_competitor_mentions(
    competitor_names: list[str],
    cited_records: list[CitationRecord],
) -> dict[str, int]:
    """Count, per competitor, the cited records whose ``competitor_brands`` contain it.

    Competitors with zero hits are omitted from the result.
    """
    mentions: dict[str, int] = {}
    for name in competitor_names:
        hits = sum(
            1
            for record in cited_records
            if record.competitor_brands and name in record.competitor_brands
        )
        if hits > 0:
            mentions[name] = hits
    return mentions


async def _resolve_citation_sentiment(
    sentiment_service,
    brand_name: str,
    citation: CitationRecord,
) -> str:
    """Return the sentiment label (one of ``_SENTIMENT_LABELS``) for one citation.

    A stored label that is already valid is used as-is. Otherwise the response
    text (``raw_response``, falling back to ``citation_text``) is analyzed by
    *sentiment_service*. Empty text, an unrecognised label, or a service error
    all yield ``"neutral"``.
    """
    if citation.sentiment in _SENTIMENT_LABELS:
        return citation.sentiment

    content = citation.raw_response or citation.citation_text or ""
    if not content.strip():
        return "neutral"

    try:
        result = await sentiment_service.analyze(brand_name=brand_name, content=content)
        label = result.sentiment
    except Exception:
        # Best-effort: a failing sentiment backend must not break suggestion
        # generation, but the failure should not vanish silently either.
        logger.warning(
            "Sentiment analysis failed for brand %r; counting as neutral",
            brand_name,
            exc_info=True,
        )
        return "neutral"
    return label if label in _SENTIMENT_LABELS else "neutral"


async def _get_brand_scoring_data(
    db: AsyncSession,
    user_id: uuid.UUID,
    brand: Brand,
) -> tuple:
    """Collect the scoring inputs used to generate optimization suggestions.

    Loads the citation records of every query *user_id* ran with
    ``target_brand == brand.name``. From them it derives competitor mentions,
    sentiment counts, per-platform citation rates and a V2 score.

    Returns:
        ``(scoring_result, competitor_data, sentiment_data, platform_scores,
        total_queries, mentioned_count)``. ``total_queries`` is the number of
        citation records and ``mentioned_count`` the number whose ``cited``
        flag is set. The dict elements always carry the same keys, including
        when the brand has no queries or no citations.
    """
    # Imported lazily, as in the original code (presumably to avoid an import cycle).
    from app.services.scoring.brand_scoring_data_service import REQUIRED_PLATFORMS

    # Only the query ids are needed to look up citation records.
    queries_stmt = select(QueryModel.id).where(
        QueryModel.user_id == user_id,
        QueryModel.target_brand == brand.name,
    )
    query_ids = list((await db.execute(queries_stmt)).scalars().all())

    all_citations: list[CitationRecord] = []
    if query_ids:
        # Skip the IN query entirely when there is nothing to match.
        citations_stmt = select(CitationRecord).where(
            CitationRecord.query_id.in_(query_ids),
        )
        all_citations = list((await db.execute(citations_stmt)).scalars().all())

    total_queries = len(all_citations)
    brand_citations = [c for c in all_citations if c.cited]
    mentioned_count = len(brand_citations)

    # Competitor mentions are only counted on cited records, so there is no
    # need to load competitors when nothing was cited.
    competitor_mentions: dict[str, int] = {}
    if brand_citations:
        competitor_stmt = select(Competitor.name).where(Competitor.brand_id == brand.id)
        competitor_names = list((await db.execute(competitor_stmt)).scalars().all())
        competitor_mentions = _count_competitor_mentions(competitor_names, brand_citations)

    # Resolve one sentiment per cited record. Feed it both into the aggregate
    # counts and into the per-citation results; previously the results were
    # hard-coded to "neutral", which discarded the computed sentiment.
    sentiment_service = get_sentiment_service()
    sentiment_counts = dict.fromkeys(_SENTIMENT_LABELS, 0)
    citation_results: list[CitationResult] = []
    for citation in brand_citations:
        label = await _resolve_citation_sentiment(sentiment_service, brand.name, citation)
        sentiment_counts[label] += 1
        citation_results.append(
            CitationResult(
                cited=citation.cited,
                position=citation.citation_position,
                citation_text=citation.citation_text,
                sentiment=label,
                confidence=citation.confidence or 0.0,
            )
        )

    # brand_citations is already filtered on `cited`.
    positions = [c.citation_position for c in brand_citations]

    v2_result = ScoringService().calculate_v2(
        mentioned_count=mentioned_count,
        total_queries=total_queries,
        positions=positions,
        sentiment_counts=sentiment_counts,
        citations=citation_results,
        brand_mentions=mentioned_count,
        competitor_mentions=competitor_mentions,
    )

    # Per-platform citation rate in percent (0.0 for platforms with no records).
    platform_scores: dict[str, float] = {}
    for platform in REQUIRED_PLATFORMS:
        on_platform = [c for c in all_citations if c.platform == platform]
        cited_on_platform = sum(1 for c in on_platform if c.cited)
        platform_scores[platform] = (
            round(cited_on_platform / len(on_platform) * 100, 2) if on_platform else 0.0
        )

    # Ties with a competitor count as "behind".
    competitor_data = {
        "brand_mentions": mentioned_count,
        "competitor_mentions": competitor_mentions,
        "ahead_count": sum(1 for n in competitor_mentions.values() if mentioned_count > n),
        "behind_count": sum(1 for n in competitor_mentions.values() if mentioned_count <= n),
    }

    return (
        v2_result,
        competitor_data,
        sentiment_counts,
        platform_scores,
        total_queries,
        mentioned_count,
    )
@router.get("/{brand_id}/suggestions", response_model=SuggestionListResponse)
async def get_suggestions(
    brand_id: uuid.UUID,
    type: str | None = Query(None, description="按类型筛选"),
    priority: str | None = Query(None, description="按优先级筛选"),
    status_filter: str | None = Query(None, alias="status", description="按状态筛选"),
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List the optimization suggestions of a brand.

    A fresh batch is generated and saved first if the brand has no
    ``pending``/``in_progress`` suggestion generated within the last 7 days.
    Results can be filtered by type, priority and status. They are ordered
    high -> medium -> low priority (unknown priorities last), newest first
    within the same priority, then paged with *skip*/*limit*.
    """
    brand = await _get_brand_with_access(brand_id, db, current_user)

    # Treat still-actionable suggestions from the last 7 days as a valid cache.
    # NOTE(review): datetime.now() is naive local time; confirm it matches how
    # Suggestion.generated_at is stored.
    seven_days_ago = datetime.now() - timedelta(days=7)
    cache_stmt = select(func.count()).select_from(Suggestion).where(
        Suggestion.brand_id == brand_id,
        Suggestion.generated_at >= seven_days_ago,
        Suggestion.status.in_(["pending", "in_progress"]),
    )
    cached_count = (await db.execute(cache_stmt)).scalar_one()
    if cached_count == 0:
        await _generate_and_save_suggestions(db, brand, current_user)

    conditions = [Suggestion.brand_id == brand_id]
    if type:
        conditions.append(Suggestion.type == type)
    if priority:
        conditions.append(Suggestion.priority == priority)
    if status_filter:
        conditions.append(Suggestion.status == status_filter)

    count_stmt = select(func.count()).select_from(Suggestion).where(*conditions)
    total = (await db.execute(count_stmt)).scalar_one()

    # Sorting the raw priority string would order it alphabetically
    # (high, low, medium); map it to an explicit rank instead.
    priority_rank = case(
        {"high": 0, "medium": 1, "low": 2},
        value=Suggestion.priority,
        else_=3,
    )
    suggestions_stmt = (
        select(Suggestion)
        .where(*conditions)
        .order_by(priority_rank.asc(), Suggestion.generated_at.desc())
        .offset(skip)
        .limit(limit)
    )
    suggestions = list((await db.execute(suggestions_stmt)).scalars().all())

    return SuggestionListResponse(
        suggestions=[SuggestionResponse.model_validate(s) for s in suggestions],
        total=total,
    )
@router.post("/{brand_id}/suggestions/regenerate", response_model=SuggestionListResponse)
async def regenerate_suggestions(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Dismiss the brand's pending suggestions and generate a new batch.

    Only ``pending`` suggestions are switched to ``dismissed``. The change is
    flushed here and committed by ``_generate_and_save_suggestions``.
    Suggestions in other statuses are left untouched. The response lists only
    the newly generated suggestions.
    """
    brand = await _get_brand_with_access(brand_id, db, current_user)

    pending_stmt = select(Suggestion).where(
        Suggestion.brand_id == brand_id,
        Suggestion.status == "pending",
    )
    for stale in (await db.execute(pending_stmt)).scalars().all():
        stale.status = "dismissed"
    await db.flush()

    fresh = await _generate_and_save_suggestions(db, brand, current_user)
    return SuggestionListResponse(
        suggestions=[SuggestionResponse.model_validate(item) for item in fresh],
        total=len(fresh),
    )
@router.put("/{brand_id}/suggestions/{suggestion_id}/status", response_model=SuggestionResponse)
async def update_suggestion_status(
    brand_id: uuid.UUID,
    suggestion_id: uuid.UUID,
    status_update: SuggestionUpdateStatus,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Set the status of one suggestion belonging to the brand.

    Accepted statuses: pending / in_progress / completed / dismissed.

    Raises:
        HTTPException: 404 if the brand or the suggestion (within that brand)
            is not found; 400 if the requested status is not accepted.
    """
    await _get_brand_with_access(brand_id, db, current_user)

    stmt = select(Suggestion).where(
        Suggestion.id == suggestion_id,
        Suggestion.brand_id == brand_id,
    )
    suggestion = (await db.execute(stmt)).scalar_one_or_none()
    if not suggestion:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="建议不存在",
        )

    # An ordered tuple rather than a set, so the error message lists the
    # statuses in the same order in every process (set order varies with
    # string-hash randomization).
    valid_statuses = ("pending", "in_progress", "completed", "dismissed")
    if status_update.status not in valid_statuses:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无效的状态值,支持: {', '.join(valid_statuses)}",
        )

    suggestion.status = status_update.status
    await db.commit()
    await db.refresh(suggestion)
    return SuggestionResponse.model_validate(suggestion)
@router.get("/{brand_id}/suggestions/history", response_model=SuggestionHistoryResponse)
async def get_suggestions_history(
    brand_id: uuid.UUID,
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=1, le=50),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the brand's suggestion history grouped by generation batch.

    Batches are ordered newest first and paged with *skip*/*limit*. ``total``
    is the number of distinct batches. Each entry carries every suggestion of
    its batch, ordered high -> medium -> low priority (unknown priorities last).
    """
    await _get_brand_with_access(brand_id, db, current_user)

    # One row per batch_id. generated_at and source are aggregated rather
    # than grouped on. Otherwise a batch whose rows carry slightly different
    # timestamps would be split into several entries, each listing the whole
    # batch, and the page would disagree with `total`.
    batch_generated_at = func.min(Suggestion.generated_at).label("generated_at")
    batch_stmt = (
        select(
            Suggestion.batch_id,
            batch_generated_at,
            func.max(Suggestion.source).label("source"),
            func.count().label("suggestion_count"),
        )
        .where(Suggestion.brand_id == brand_id)
        .group_by(Suggestion.batch_id)
        .order_by(batch_generated_at.desc())
        .offset(skip)
        .limit(limit)
    )
    batches = (await db.execute(batch_stmt)).all()

    total_stmt = (
        select(func.count())
        .select_from(
            select(Suggestion.batch_id)
            .where(Suggestion.brand_id == brand_id)
            .group_by(Suggestion.batch_id)
            .subquery()
        )
    )
    total = (await db.execute(total_stmt)).scalar_one()

    # Sorting the raw priority string would order it alphabetically
    # (high, low, medium); map it to an explicit rank instead.
    priority_rank = case(
        {"high": 0, "medium": 1, "low": 2},
        value=Suggestion.priority,
        else_=3,
    )

    history = []
    for batch in batches:
        members_stmt = (
            select(Suggestion)
            .where(
                Suggestion.brand_id == brand_id,
                Suggestion.batch_id == batch.batch_id,
            )
            .order_by(priority_rank.asc())
        )
        members = list((await db.execute(members_stmt)).scalars().all())
        history.append(SuggestionHistoryItem(
            batch_id=batch.batch_id,
            generated_at=batch.generated_at,
            source=batch.source,
            suggestion_count=batch.suggestion_count,
            suggestions=[SuggestionResponse.model_validate(s) for s in members],
        ))

    return SuggestionHistoryResponse(history=history, total=total)
async def _generate_and_save_suggestions(
    db: AsyncSession,
    brand: Brand,
    current_user: User,
) -> list[Suggestion]:
    """Generate a new batch of suggestions for *brand* and persist it.

    Every suggestion of the batch shares one freshly generated ``batch_id``
    and starts as ``pending``. ``source`` is ``"llm"`` when
    ``settings.ENABLE_LLM`` is truthy and ``settings.DEEPSEEK_API_KEY`` is
    set, otherwise ``"rule"``. NOTE(review): this reflects configuration,
    not which path ``generate_suggestions`` actually took.

    The session is committed, including anything the caller had pending, and
    each new row is refreshed so server-populated columns are loaded.

    Returns:
        The newly created ``Suggestion`` rows.
    """
    (
        v2_result,
        competitor_data,
        sentiment_data,
        platform_scores,
        total_queries,
        mentioned_count,
    ) = await _get_brand_scoring_data(db, _to_uuid(current_user.id), brand)

    context = build_context_from_scoring_result(
        brand_name=brand.name,
        scoring_result=v2_result,
        competitor_data=competitor_data,
        sentiment_data=sentiment_data,
        platform_scores=platform_scores,
        total_queries=total_queries,
        mentioned_count=mentioned_count,
    )
    items = await generate_suggestions(context)

    from app.config import settings

    llm_enabled = bool(settings.ENABLE_LLM and settings.DEEPSEEK_API_KEY)
    source = "llm" if llm_enabled else "rule"

    batch_id = uuid.uuid4()
    created = [
        Suggestion(
            brand_id=brand.id,
            type=item.type,
            priority=item.priority,
            title=item.title,
            description=item.description,
            action=item.action,
            expected_impact=item.expected_impact,
            difficulty=item.difficulty,
            status="pending",
            batch_id=batch_id,
            source=source,
        )
        for item in items
    ]
    db.add_all(created)
    await db.commit()

    # Reload ids/timestamps populated on insert (attributes expire on commit).
    for row in created:
        await db.refresh(row)
    return created