# (File-viewer metadata from the page this source was copied from: 488 lines, 16 KiB, Python.)
"""Dashboard API endpoints."""

import logging
import uuid
from datetime import date, datetime, timedelta

from fastapi import APIRouter, Depends, Query
from sqlalchemy import Integer, case, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.deps import get_current_user
from app.database import get_db
from app.models.brand import Brand
from app.models.citation_record import CitationRecord
from app.models.competitor import Competitor
from app.models.query import Query as QueryModel
from app.models.user import User
from app.schemas.dashboard import (
    DashboardStatsResponse,
    DimensionScoreItem,
    PlatformScoreItem,
    RecentQueryItem,
)
from app.schemas.scoring import CitationResult
from app.services.cache import TTL_DASHBOARD, get_cache_service
from app.services.scoring_service import ScoringService, get_health_level
from app.services.sentiment_service import get_sentiment_service


router = APIRouter()

# Platform keys every dashboard response reports on, in display order:
#   wenxin   -> Wenxin Yiyan (Baidu ERNIE Bot)
#   kimi     -> Kimi
#   tongyi   -> Tongyi Qianwen
#   doubao   -> Doubao
#   xinghuo  -> iFlytek Spark (Xinghuo)
#   tiangong -> Tiangong AI
#   qingyan  -> Zhipu Qingyan
REQUIRED_PLATFORMS = list(
    ("wenxin", "kimi", "tongyi", "doubao", "xinghuo", "tiangong", "qingyan")
)


async def _get_brand_score_by_platform(
    db: AsyncSession,
    user_id: uuid.UUID,
    brand_id: uuid.UUID,
) -> dict[str, float]:
    """
    Calculate the brand's citation rate on each required platform.

    A platform's score is the percentage (0-100, rounded to 2 decimals) of
    this user's citation records on that platform that are marked ``cited``.
    Only records belonging to queries that target the brand's name count.

    Args:
        db: Async database session.
        user_id: Owner of the brand and its queries; scopes every lookup.
        brand_id: Brand whose queries are scored.

    Returns:
        A dict with exactly one entry per key in ``REQUIRED_PLATFORMS``.
        A platform with no citation records scores ``0.0``. Every platform
        scores ``0.0`` when the brand does not exist for this user or has no
        queries.
    """
    platform_scores = dict.fromkeys(REQUIRED_PLATFORMS, 0.0)

    brand_stmt = select(Brand).where(
        Brand.id == brand_id,
        Brand.user_id == user_id,
    )
    brand = (await db.execute(brand_stmt)).scalar_one_or_none()
    if brand is None:
        return platform_scores

    # Use a subquery so the database resolves the query ids itself, instead of
    # loading every Query row just to collect their ids.
    brand_query_ids = select(QueryModel.id).where(
        QueryModel.user_id == user_id,
        QueryModel.target_brand == brand.name,
    )

    # One grouped aggregate covers all platforms, replacing one round-trip per
    # platform. `case` builds a SQL CASE expression. The previous
    # `func.case(...)` rendered a call to a nonexistent SQL function "case".
    citation_stmt = (
        select(
            CitationRecord.platform,
            func.count().label("total"),
            func.sum(
                case((CitationRecord.cited.is_(True), 1), else_=0)
            ).label("cited"),
        )
        .where(
            CitationRecord.query_id.in_(brand_query_ids),
            CitationRecord.platform.in_(REQUIRED_PLATFORMS),
        )
        .group_by(CitationRecord.platform)
    )

    for row in (await db.execute(citation_stmt)).all():
        total = row.total or 0
        if total > 0:
            citation_rate = (row.cited or 0) / total
            platform_scores[row.platform] = round(citation_rate * 100, 2)

    return platform_scores


async def _get_competitor_scores_by_platform(
    db: AsyncSession,
    user_id: uuid.UUID,
    brand_id: uuid.UUID,
) -> dict[str, float]:
    """
    Calculate the pooled competitor citation rate on each required platform.

    All of this user's queries whose ``target_brand`` matches any competitor
    registered for the brand are pooled together. A platform's score is the
    percentage (0-100, rounded to 2 decimals) of the pooled citation records
    on that platform that are marked ``cited``. This is a pooled rate, not a
    per-competitor average, so competitors with more queries weigh more.

    Args:
        db: Async database session.
        user_id: Owner of the brand; scopes both the competitor lookup and
            the queries.
        brand_id: Brand whose competitors are scored.

    Returns:
        A dict with exactly one entry per key in ``REQUIRED_PLATFORMS``.
        A platform with no matching citation records scores ``0.0``. Every
        platform scores ``0.0`` when the brand is not this user's, has no
        competitors, or its competitors have no queries.
    """
    platform_scores = dict.fromkeys(REQUIRED_PLATFORMS, 0.0)

    # Join to Brand so competitors are read only for a brand this user owns.
    competitor_names = (
        select(Competitor.name)
        .join(Brand, Brand.id == Competitor.brand_id)
        .where(
            Competitor.brand_id == brand_id,
            Brand.user_id == user_id,
        )
    )
    competitor_query_ids = select(QueryModel.id).where(
        QueryModel.user_id == user_id,
        QueryModel.target_brand.in_(competitor_names),
    )

    # One grouped aggregate for every platform. `case` builds a SQL CASE
    # expression. The previous `func.case(...)` rendered a call to a
    # nonexistent SQL function "case".
    citation_stmt = (
        select(
            CitationRecord.platform,
            func.count().label("total"),
            func.sum(
                case((CitationRecord.cited.is_(True), 1), else_=0)
            ).label("cited"),
        )
        .where(
            CitationRecord.query_id.in_(competitor_query_ids),
            CitationRecord.platform.in_(REQUIRED_PLATFORMS),
        )
        .group_by(CitationRecord.platform)
    )

    for row in (await db.execute(citation_stmt)).all():
        total = row.total or 0
        if total > 0:
            citation_rate = (row.cited or 0) / total
            platform_scores[row.platform] = round(citation_rate * 100, 2)

    return platform_scores


def _local_date(moment: datetime | None) -> date | None:
    """Return the calendar date of *moment* on the server's local clock.

    Aware datetimes are converted to local time first, so they compare
    consistently with ``datetime.now().date()``. Naive datetimes are used
    as-is. ``None`` yields ``None``.
    """
    if moment is None:
        return None
    if moment.tzinfo is not None:
        moment = moment.astimezone()
    return moment.date()


def _citation_rate_on(citations: list[CitationRecord], day: date) -> float:
    """Return the percentage (0-100) of *citations* queried on *day* that are cited.

    Returns 0.0 when no record falls on *day*.
    """
    day_records = [c for c in citations if _local_date(c.queried_at) == day]
    if not day_records:
        return 0.0
    cited = sum(1 for c in day_records if c.cited)
    return cited / len(day_records) * 100


async def _count_sentiments(
    brand_name: str,
    citations: list[CitationRecord],
) -> dict[str, int]:
    """Tally positive/neutral/negative sentiment over *citations*.

    The persisted ``sentiment`` field is used when it holds one of the three
    known labels. Otherwise the sentiment service analyzes ``raw_response``,
    falling back to ``citation_text``. Records with no text, analysis failures,
    and labels outside the known three all count as neutral.
    """
    sentiment_service = get_sentiment_service()
    counts = {"positive": 0, "neutral": 0, "negative": 0}

    for citation in citations:
        label = citation.sentiment
        if label not in counts:
            label = "neutral"
            content = citation.raw_response or citation.citation_text or ""
            if content.strip():
                try:
                    result = await sentiment_service.analyze(
                        brand_name=brand_name,
                        content=content,
                    )
                except Exception:
                    # Best-effort: one failed analysis must not break the
                    # dashboard, but leave a trace instead of hiding it.
                    logging.getLogger(__name__).warning(
                        "Sentiment analysis failed for brand %s; counting as neutral",
                        brand_name,
                        exc_info=True,
                    )
                else:
                    if result.sentiment in counts:
                        label = result.sentiment
        counts[label] += 1

    return counts


async def _calculate_overall_score_v2(
    db: AsyncSession,
    user_id: uuid.UUID,
    brand_id: uuid.UUID,
) -> tuple[float, float, list[DimensionScoreItem], int, int]:
    """
    Calculate the V2 overall score and its supporting figures for a brand.

    Args:
        db: Async database session.
        user_id: Owner of the brand; scopes the brand and query lookups.
        brand_id: Brand to score.

    Returns:
        ``(overall_score, change_from_yesterday, dimensions, ahead_count,
        behind_count)``:

        - ``overall_score``: ``ScoringService.calculate_v2(...).overall_score``.
        - ``change_from_yesterday``: today's raw citation rate minus
          yesterday's, in percentage points rounded to 2 decimals. This is
          NOT a difference of V2 scores. Days are bucketed on the server's
          local clock.
        - ``dimensions``: the five V2 dimension scores.
        - ``ahead_count`` / ``behind_count``: how many of the brand's mentioned
          competitors have fewer / at least as many mentions as the brand
          itself. Competitors never mentioned are not counted.

        Returns ``(0.0, 0.0, [], 0, 0)`` when the brand does not belong to the
        user or has no queries.
    """
    brand_stmt = select(Brand).where(
        Brand.id == brand_id,
        Brand.user_id == user_id,
    )
    brand = (await db.execute(brand_stmt)).scalar_one_or_none()
    if brand is None:
        return 0.0, 0.0, [], 0, 0

    query_ids_stmt = select(QueryModel.id).where(
        QueryModel.user_id == user_id,
        QueryModel.target_brand == brand.name,
    )
    query_ids = list((await db.execute(query_ids_stmt)).scalars().all())
    if not query_ids:
        return 0.0, 0.0, [], 0, 0

    citations_stmt = select(CitationRecord).where(
        CitationRecord.query_id.in_(query_ids),
    )
    all_citations = list((await db.execute(citations_stmt)).scalars().all())
    brand_citations = [c for c in all_citations if c.cited]
    brand_mentions = len(brand_citations)

    competitor_names_stmt = select(Competitor.name).where(
        Competitor.brand_id == brand_id,
    )
    competitor_names = list(
        (await db.execute(competitor_names_stmt)).scalars().all()
    )

    # Count every response that names the competitor, not only responses that
    # also cite our brand. Restricting to cited records capped each
    # competitor's count at `brand_mentions`, so a competitor could never
    # out-rank the brand.
    competitor_mentions: dict[str, int] = {}
    for comp_name in competitor_names:
        count = sum(
            1
            for c in all_citations
            if c.competitor_brands and comp_name in c.competitor_brands
        )
        if count > 0:
            competitor_mentions[comp_name] = count

    sentiment_counts = await _count_sentiments(brand.name, brand_citations)

    citation_results = [
        CitationResult(
            cited=c.cited,
            position=c.citation_position,
            citation_text=c.citation_text,
            sentiment="neutral",  # simplified for dashboard
            confidence=c.confidence or 0.0,
        )
        for c in brand_citations
    ]

    # Records without a recorded position carry no ranking information.
    positions = [
        c.citation_position
        for c in brand_citations
        if c.citation_position is not None
    ]

    v2_result = ScoringService().calculate_v2(
        mentioned_count=brand_mentions,
        # One citation record per (query, platform) run, so this counts runs.
        total_queries=len(all_citations),
        positions=positions,
        sentiment_counts=sentiment_counts,
        citations=citation_results,
        brand_mentions=brand_mentions,
        competitor_mentions=competitor_mentions,
    )

    dimensions = [
        DimensionScoreItem(
            name=dim.name,
            score=round(dim.score, 2),
            max_score=dim.max_score,
            percentage=round(dim.percentage, 2),
        )
        for dim in (
            v2_result.mention_rate,
            v2_result.recommendation_rank,
            v2_result.sentiment_score,
            v2_result.citation_quality,
            v2_result.competitive_position,
        )
    ]

    # Ties count as "behind", matching the original comparison.
    ahead_count = sum(
        1 for count in competitor_mentions.values() if brand_mentions > count
    )
    behind_count = len(competitor_mentions) - ahead_count

    today = datetime.now().date()
    yesterday = today - timedelta(days=1)
    change = round(
        _citation_rate_on(all_citations, today)
        - _citation_rate_on(all_citations, yesterday),
        2,
    )

    return v2_result.overall_score, change, dimensions, ahead_count, behind_count


def _empty_stats_response() -> DashboardStatsResponse:
    """Return the all-zero dashboard payload used when no brand is available."""
    return DashboardStatsResponse(
        overall_score=0.0,
        health_level="danger",
        score_change=0.0,
        platform_scores=[],
        recent_queries=[],
        dimensions=[],
        competitors_ahead=0,
        competitors_behind=0,
        monitored_platforms=0,
        total_platforms=len(REQUIRED_PLATFORMS),
    )


@router.get("/stats", response_model=DashboardStatsResponse)
async def get_dashboard_stats(
    brand_id: uuid.UUID | None = Query(None),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Get dashboard statistics.

    Includes:
    - Overall score (V2) and change since yesterday
    - Health level
    - Five-dimension score breakdown
    - Per-platform scores (with competitor comparison)
    - Competitor position (number ahead / behind)
    - Recent query records
    """
    cache = get_cache_service()

    brand: Brand | None = None
    if brand_id is None:
        # No brand requested: fall back to one of the user's brands.
        # NOTE(review): there is no ORDER BY, so which brand is picked is up
        # to the database.
        first_brand_stmt = (
            select(Brand).where(Brand.user_id == current_user.id).limit(1)
        )
        brand = (await db.execute(first_brand_stmt)).scalar_one_or_none()
        if brand is None:
            return _empty_stats_response()
        brand_id = brand.id

    # Try the cache first (TTL: TTL_DASHBOARD). Entries are written only after
    # the ownership check below, and the key includes the user id, so a hit
    # is already authorized.
    cache_key = f"dashboard:stats:{current_user.id}:{brand_id}"
    cached = await cache.get_json(cache_key)
    if cached is not None:
        return cached

    if brand is None:
        # A caller-supplied brand_id must belong to the caller. Otherwise
        # another user's brand and competitor names would leak.
        owned_brand_stmt = select(Brand).where(
            Brand.id == brand_id,
            Brand.user_id == current_user.id,
        )
        brand = (await db.execute(owned_brand_stmt)).scalar_one_or_none()
        if brand is None:
            return _empty_stats_response()
    brand_name = brand.name

    overall_score, score_change, dimensions, ahead_count, behind_count = (
        await _calculate_overall_score_v2(db, current_user.id, brand_id)
    )

    platform_scores_dict = await _get_brand_score_by_platform(
        db, current_user.id, brand_id
    )
    competitor_scores_dict = await _get_competitor_scores_by_platform(
        db, current_user.id, brand_id
    )

    # One competitor's name labels the comparison column.
    # NOTE(review): competitor_score is pooled across ALL competitors, but is
    # labelled with a single (arbitrary, unordered) competitor's name.
    first_competitor_stmt = (
        select(Competitor).where(Competitor.brand_id == brand_id).limit(1)
    )
    first_competitor = (
        await db.execute(first_competitor_stmt)
    ).scalar_one_or_none()
    competitor_name = first_competitor.name if first_competitor else None

    platform_scores = [
        PlatformScoreItem(
            platform=platform,
            score=score,
            competitor_score=competitor_scores_dict.get(platform),
            competitor_name=(
                competitor_name
                if competitor_scores_dict.get(platform, 0) > 0
                else None
            ),
        )
        for platform, score in platform_scores_dict.items()
    ]

    # A platform counts as monitored once the brand has a non-zero score there.
    monitored = sum(1 for s in platform_scores_dict.values() if s > 0)

    # The user's 10 most recent queries, across all brands.
    recent_queries_stmt = (
        select(QueryModel)
        .where(QueryModel.user_id == current_user.id)
        .order_by(QueryModel.created_at.desc())
        .limit(10)
    )
    recent_queries_list = list(
        (await db.execute(recent_queries_stmt)).scalars().all()
    )

    # One grouped query for all cited counts, instead of one query per row.
    cited_counts: dict[uuid.UUID, int] = {}
    if recent_queries_list:
        cited_counts_stmt = (
            select(
                CitationRecord.query_id,
                func.sum(
                    case((CitationRecord.cited.is_(True), 1), else_=0)
                ).label("cited"),
            )
            .where(
                CitationRecord.query_id.in_([q.id for q in recent_queries_list])
            )
            .group_by(CitationRecord.query_id)
        )
        cited_counts = {
            row.query_id: row.cited or 0
            for row in (await db.execute(cited_counts_stmt)).all()
        }

    recent_queries = [
        RecentQueryItem(
            id=str(query.id),
            keyword=query.keyword,
            target_brand=query.target_brand,
            citation_count=cited_counts.get(query.id, 0),
            queried_at=query.last_queried_at or query.created_at,
        )
        for query in recent_queries_list
    ]

    response = DashboardStatsResponse(
        overall_score=round(overall_score, 2),
        health_level=get_health_level(overall_score),
        score_change=score_change,
        platform_scores=platform_scores,
        recent_queries=recent_queries,
        dimensions=dimensions,
        competitors_ahead=ahead_count,
        competitors_behind=behind_count,
        monitored_platforms=monitored,
        total_platforms=len(REQUIRED_PLATFORMS),
        brand_name=brand_name,
    )

    # Store the serialized result (TTL: TTL_DASHBOARD).
    await cache.set_json(
        cache_key,
        response.model_dump(mode="json"),
        expire=TTL_DASHBOARD,
    )

    return response