484 lines
15 KiB
Python
484 lines
15 KiB
Python
"""优化建议 API endpoints."""
|
||
import uuid
|
||
from datetime import datetime, timedelta
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||
from sqlalchemy import and_, case, func, select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.api.deps import get_current_user
|
||
from app.database import get_db
|
||
from app.models.user import User
|
||
from app.models.brand import Brand
|
||
from app.models.competitor import Competitor
|
||
from app.models.query import Query as QueryModel
|
||
from app.models.citation_record import CitationRecord
|
||
from app.models.suggestion import Suggestion
|
||
from app.schemas.suggestion import (
|
||
SuggestionResponse,
|
||
SuggestionListResponse,
|
||
SuggestionUpdateStatus,
|
||
SuggestionHistoryItem,
|
||
SuggestionHistoryResponse,
|
||
)
|
||
from app.schemas.scoring import CitationResult
|
||
from app.services.scoring.scoring_service import ScoringService
|
||
from app.services.analysis.sentiment_service import get_sentiment_service
|
||
from app.services.advisor.optimization_advisor import (
|
||
generate_suggestions,
|
||
build_context_from_scoring_result,
|
||
)
|
||
|
||
# Router collecting this module's suggestion endpoints (all paths start with
# "/{brand_id}/suggestions"); presumably mounted under a prefix elsewhere — confirm.
router = APIRouter()
|
||
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
||
if isinstance(value, uuid.UUID):
|
||
return value
|
||
return uuid.UUID(str(value))
|
||
|
||
|
||
async def _get_brand_with_access(
    brand_id: uuid.UUID,
    db: AsyncSession,
    current_user: User,
) -> Brand:
    """Load brand *brand_id* if it is owned by *current_user*.

    Ownership is part of the lookup itself, so a brand that exists but belongs
    to another user is indistinguishable from a missing one.

    Raises:
        HTTPException: 404 when no brand with this id is owned by the user.
    """
    owner_id = _to_uuid(current_user.id)
    lookup = await db.execute(
        select(Brand).where(Brand.id == brand_id, Brand.user_id == owner_id)
    )
    found = lookup.scalar_one_or_none()
    if not found:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="品牌不存在")
    return found
|
||
|
||
|
||
def _count_competitor_mentions(
    citations: list[CitationRecord],
    competitor_names: list[str],
) -> dict[str, int]:
    """Map each competitor to the number of citation records listing it.

    A record counts for a competitor when the competitor's name appears in the
    record's ``competitor_brands``. Competitors with zero mentions are omitted.
    Every record is considered, not only those citing the target brand:
    restricting to cited records would cap each competitor's count at the
    brand's own mention count, so no competitor could ever be strictly ahead.
    """
    mentions: dict[str, int] = {}
    for name in competitor_names:
        count = sum(
            1 for record in citations
            if record.competitor_brands and name in record.competitor_brands
        )
        if count > 0:
            mentions[name] = count
    return mentions


async def _count_sentiments(
    citations: list[CitationRecord],
    brand_name: str,
) -> dict[str, int]:
    """Tally positive/neutral/negative sentiment over *citations*.

    A stored ``sentiment`` label is used when it is one of the three known
    values. Otherwise the sentiment service analyzes ``raw_response`` (falling
    back to ``citation_text``). Empty content, analyzer errors and unknown
    labels all count as neutral.
    """
    counts = {"positive": 0, "neutral": 0, "negative": 0}
    sentiment_service = get_sentiment_service()
    for citation in citations:
        if citation.sentiment and citation.sentiment in counts:
            counts[citation.sentiment] += 1
            continue
        content = citation.raw_response or citation.citation_text or ""
        if not content.strip():
            counts["neutral"] += 1
            continue
        try:
            analysis = await sentiment_service.analyze(
                brand_name=brand_name,
                content=content,
            )
        except Exception:
            # Best effort: an analyzer failure must not break suggestion generation.
            counts["neutral"] += 1
            continue
        label = analysis.sentiment
        counts[label if label in counts else "neutral"] += 1
    return counts


def _platform_cited_rates(citations: list[CitationRecord], platforms) -> dict[str, float]:
    """Return, for each platform in *platforms*, the cited percentage (0-100, 2 decimals).

    A platform with no citation records scores ``0.0``.
    """
    rates: dict[str, float] = {}
    for platform in platforms:
        on_platform = [c for c in citations if c.platform == platform]
        cited = sum(1 for c in on_platform if c.cited)
        rates[platform] = round(cited / len(on_platform) * 100, 2) if on_platform else 0.0
    return rates


async def _get_brand_scoring_data(
    db: AsyncSession,
    user_id: uuid.UUID,
    brand: Brand,
) -> tuple:
    """Collect the brand scoring inputs used to generate optimization suggestions.

    Loads the user's queries whose ``target_brand`` equals ``brand.name``,
    their citation records and the brand's competitors, then computes a V2
    score with ``ScoringService.calculate_v2``.

    Args:
        db: Active async session.
        user_id: Owner of the queries to aggregate.
        brand: Brand whose name is matched against ``QueryModel.target_brand``.

    Returns:
        ``(scoring_result, competitor_data, sentiment_data, platform_scores,
        total_queries, mentioned_count)``. The dict elements have the same
        keys whether or not any query data exists:

        * ``competitor_data``: ``brand_mentions``, ``competitor_mentions``,
          ``ahead_count``, ``behind_count``.
        * ``sentiment_data``: counts for ``positive``/``neutral``/``negative``.
        * ``platform_scores``: cited percentage for every ``REQUIRED_PLATFORMS``
          entry.

        ``total_queries`` is the number of citation records found for the
        brand's queries, not the number of distinct ``QueryModel`` rows.
    """
    # Kept as a function-level import as in the original layout
    # (presumably to avoid an import cycle — confirm).
    from app.services.scoring.brand_scoring_data_service import REQUIRED_PLATFORMS

    query_ids_result = await db.execute(
        select(QueryModel.id).where(
            QueryModel.user_id == user_id,
            QueryModel.target_brand == brand.name,
        )
    )
    query_ids = list(query_ids_result.scalars().all())

    # With no queries, skip the IN (...) lookup. The steps below then yield the
    # all-zero result with the same shape as the populated case.
    all_citations: list[CitationRecord] = []
    if query_ids:
        citations_result = await db.execute(
            select(CitationRecord).where(CitationRecord.query_id.in_(query_ids))
        )
        all_citations = list(citations_result.scalars().all())

    total_queries = len(all_citations)
    brand_citations = [c for c in all_citations if c.cited]
    brand_mentions = len(brand_citations)

    competitor_names_result = await db.execute(
        select(Competitor.name).where(Competitor.brand_id == brand.id)
    )
    competitor_names = list(competitor_names_result.scalars().all())
    competitor_mentions = _count_competitor_mentions(all_citations, competitor_names)

    sentiment_counts = await _count_sentiments(brand_citations, brand.name)

    # NOTE(review): per-citation sentiment is hard-coded to "neutral" here;
    # aggregate sentiment reaches the scorer via sentiment_counts. Confirm intent.
    citation_results = [
        CitationResult(
            cited=c.cited,
            position=c.citation_position,
            citation_text=c.citation_text,
            sentiment="neutral",
            confidence=c.confidence or 0.0,
        )
        for c in brand_citations
    ]

    v2_result = ScoringService().calculate_v2(
        mentioned_count=brand_mentions,
        total_queries=total_queries,
        positions=[c.citation_position for c in brand_citations],
        sentiment_counts=sentiment_counts,
        citations=citation_results,
        brand_mentions=brand_mentions,
        competitor_mentions=competitor_mentions,
    )

    platform_scores = _platform_cited_rates(all_citations, REQUIRED_PLATFORMS)

    # Ties count as "behind", matching the original comparison.
    competitor_data = {
        "brand_mentions": brand_mentions,
        "competitor_mentions": competitor_mentions,
        "ahead_count": sum(1 for n in competitor_mentions.values() if brand_mentions > n),
        "behind_count": sum(1 for n in competitor_mentions.values() if brand_mentions <= n),
    }

    return (
        v2_result,
        competitor_data,
        sentiment_counts,
        platform_scores,
        total_queries,
        brand_mentions,
    )
|
||
|
||
|
||
@router.get("/{brand_id}/suggestions", response_model=SuggestionListResponse)
async def get_suggestions(
    brand_id: uuid.UUID,
    type: str | None = Query(None, description="按类型筛选"),
    priority: str | None = Query(None, description="按优先级筛选"),
    status_filter: str | None = Query(None, alias="status", description="按状态筛选"),
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Get the brand's optimization suggestions.

    If no pending/in_progress suggestion was generated in the last 7 days, a
    new batch is generated (and committed) first.
    Supports filtering by type, priority and status; ``total`` is the filtered
    count before skip/limit pagination.
    Ordering: priority high -> medium -> low (unknown values last), then
    newest first.
    """
    brand = await _get_brand_with_access(brand_id, db, current_user)

    # Look for cached, still-actionable suggestions from the last 7 days.
    # NOTE(review): naive local time; if generated_at is stored in UTC or
    # timezone-aware, this window may be skewed — confirm.
    seven_days_ago = datetime.now() - timedelta(days=7)
    cache_stmt = select(func.count()).select_from(Suggestion).where(
        Suggestion.brand_id == brand_id,
        Suggestion.generated_at >= seven_days_ago,
        Suggestion.status.in_(["pending", "in_progress"]),
    )
    cached_count = (await db.execute(cache_stmt)).scalar_one()

    if cached_count == 0:
        await _generate_and_save_suggestions(db, brand, current_user)

    conditions = [Suggestion.brand_id == brand_id]
    if type:
        conditions.append(Suggestion.type == type)
    if priority:
        conditions.append(Suggestion.priority == priority)
    if status_filter:
        conditions.append(Suggestion.status == status_filter)

    count_stmt = select(func.count()).select_from(Suggestion).where(*conditions)
    total = (await db.execute(count_stmt)).scalar_one()

    # Ordering the raw strings would sort alphabetically (high, low, medium);
    # rank them explicitly instead, with unknown values last.
    priority_order = {"high": 0, "medium": 1, "low": 2}
    priority_rank = case(
        priority_order,
        value=Suggestion.priority,
        else_=len(priority_order),
    )
    suggestions_stmt = (
        select(Suggestion)
        .where(*conditions)
        .order_by(
            priority_rank,
            Suggestion.generated_at.desc(),
            # Deterministic tie-breaker so offset/limit pages do not overlap.
            Suggestion.id,
        )
        .offset(skip)
        .limit(limit)
    )
    suggestions = (await db.execute(suggestions_stmt)).scalars().all()

    return SuggestionListResponse(
        suggestions=[SuggestionResponse.model_validate(s) for s in suggestions],
        total=total,
    )
|
||
|
||
|
||
@router.post("/{brand_id}/suggestions/regenerate", response_model=SuggestionListResponse)
async def regenerate_suggestions(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Regenerate the brand's optimization suggestions.

    Every ``pending`` suggestion of the brand is marked ``dismissed`` (flushed
    here; committed together with the new batch). ``in_progress`` and other
    statuses are left untouched. Only the newly generated batch is returned.
    """
    brand = await _get_brand_with_access(brand_id, db, current_user)

    pending_lookup = await db.execute(
        select(Suggestion).where(
            Suggestion.brand_id == brand_id,
            Suggestion.status == "pending",
        )
    )
    for stale in pending_lookup.scalars().all():
        stale.status = "dismissed"
    await db.flush()

    fresh_batch = await _generate_and_save_suggestions(db, brand, current_user)
    return SuggestionListResponse(
        suggestions=[SuggestionResponse.model_validate(item) for item in fresh_batch],
        total=len(fresh_batch),
    )
|
||
|
||
|
||
@router.put("/{brand_id}/suggestions/{suggestion_id}/status", response_model=SuggestionResponse)
async def update_suggestion_status(
    brand_id: uuid.UUID,
    suggestion_id: uuid.UUID,
    status_update: SuggestionUpdateStatus,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Update the status of one suggestion belonging to the brand.

    Supported statuses: pending/in_progress/completed/dismissed

    Raises:
        HTTPException: 404 if the brand is not owned by the user or the
            suggestion does not belong to the brand; 400 for an unknown status.
    """
    await _get_brand_with_access(brand_id, db, current_user)

    stmt = select(Suggestion).where(
        Suggestion.id == suggestion_id,
        Suggestion.brand_id == brand_id,
    )
    suggestion = (await db.execute(stmt)).scalar_one_or_none()

    if not suggestion:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="建议不存在",
        )

    # A tuple (not a set) keeps the error message's status list in a stable
    # order; iteration order of a set of str varies between processes under
    # hash randomization.
    valid_statuses = ("pending", "in_progress", "completed", "dismissed")
    if status_update.status not in valid_statuses:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无效的状态值,支持: {', '.join(valid_statuses)}",
        )

    suggestion.status = status_update.status
    await db.commit()
    await db.refresh(suggestion)

    return SuggestionResponse.model_validate(suggestion)
|
||
|
||
|
||
@router.get("/{brand_id}/suggestions/history", response_model=SuggestionHistoryResponse)
async def get_suggestions_history(
    brand_id: uuid.UUID,
    skip: int = Query(0, ge=0),
    limit: int = Query(10, ge=1, le=50),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Get the brand's suggestion-generation history, one entry per batch ID.

    Batches are ordered newest first (by the latest ``generated_at`` in the
    batch) and paginated with skip/limit. ``total`` is the number of distinct
    ``batch_id`` values for the brand. Suggestions inside each batch are
    ordered high -> medium -> low priority (unknown values last).
    """
    await _get_brand_with_access(brand_id, db, current_user)

    # Group per batch (source is written once per batch by
    # _generate_and_save_suggestions) so page rows and ``total`` count the same
    # thing. Grouping by generated_at as well split one batch into several
    # history entries whenever its rows carried different timestamps.
    batch_generated_at = func.max(Suggestion.generated_at).label("generated_at")
    batch_stmt = (
        select(
            Suggestion.batch_id,
            batch_generated_at,
            Suggestion.source,
            func.count().label("suggestion_count"),
        )
        .where(Suggestion.brand_id == brand_id)
        .group_by(Suggestion.batch_id, Suggestion.source)
        .order_by(batch_generated_at.desc(), Suggestion.batch_id)
        .offset(skip)
        .limit(limit)
    )
    batches = (await db.execute(batch_stmt)).all()

    total_stmt = (
        select(func.count())
        .select_from(
            select(Suggestion.batch_id)
            .where(Suggestion.brand_id == brand_id)
            .group_by(Suggestion.batch_id)
            .subquery()
        )
    )
    total = (await db.execute(total_stmt)).scalar_one()

    # Ordering the raw strings would sort alphabetically (high, low, medium);
    # rank them explicitly instead.
    priority_order = {"high": 0, "medium": 1, "low": 2}
    priority_rank = case(
        priority_order,
        value=Suggestion.priority,
        else_=len(priority_order),
    )

    history = []
    for batch in batches:
        batch_suggestions_stmt = (
            select(Suggestion)
            .where(
                Suggestion.brand_id == brand_id,
                Suggestion.batch_id == batch.batch_id,
            )
            .order_by(priority_rank)
        )
        batch_suggestions = (await db.execute(batch_suggestions_stmt)).scalars().all()

        history.append(SuggestionHistoryItem(
            batch_id=batch.batch_id,
            generated_at=batch.generated_at,
            source=batch.source,
            suggestion_count=batch.suggestion_count,
            suggestions=[SuggestionResponse.model_validate(s) for s in batch_suggestions],
        ))

    return SuggestionHistoryResponse(history=history, total=total)
|
||
|
||
|
||
async def _generate_and_save_suggestions(
    db: AsyncSession,
    brand: Brand,
    current_user: User,
) -> list[Suggestion]:
    """Generate, persist and return a new suggestion batch for *brand*.

    Scoring inputs come from ``_get_brand_scoring_data`` for the current user
    and are turned into items by ``generate_suggestions``. Each item is stored
    as a ``pending`` row, and all rows share one fresh ``batch_id``. The session
    is committed and every row refreshed so that server-generated fields
    (ID, timestamps) are loaded.

    Returns:
        The newly created ``Suggestion`` rows, in generation order.
    """
    (
        scoring_result,
        competitor_data,
        sentiment_data,
        platform_scores,
        total_queries,
        mentioned_count,
    ) = await _get_brand_scoring_data(db, _to_uuid(current_user.id), brand)

    analysis_context = build_context_from_scoring_result(
        brand_name=brand.name,
        scoring_result=scoring_result,
        competitor_data=competitor_data,
        sentiment_data=sentiment_data,
        platform_scores=platform_scores,
        total_queries=total_queries,
        mentioned_count=mentioned_count,
    )
    generated_items = await generate_suggestions(analysis_context)

    from app.config import settings

    # NOTE(review): the label reflects configuration only; if
    # generate_suggestions falls back to rules at runtime, rows may still be
    # labelled "llm" — confirm.
    llm_configured = settings.ENABLE_LLM and settings.DEEPSEEK_API_KEY
    source = "llm" if llm_configured else "rule"

    shared_batch_id = uuid.uuid4()
    created = [
        Suggestion(
            brand_id=brand.id,
            type=entry.type,
            priority=entry.priority,
            title=entry.title,
            description=entry.description,
            action=entry.action,
            expected_impact=entry.expected_impact,
            difficulty=entry.difficulty,
            status="pending",
            batch_id=shared_batch_id,
            source=source,
        )
        for entry in generated_items
    ]
    db.add_all(created)
    await db.commit()

    # Commit expires loaded attributes; reload them so callers can serialize
    # the rows without lazy loads.
    for row in created:
        await db.refresh(row)

    return created
|