geo/backend/app/api/dashboard.py

488 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Dashboard API endpoints."""
import logging
import uuid
from datetime import datetime, timedelta

from fastapi import APIRouter, Depends, Query
from sqlalchemy import Integer, case, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.competitor import Competitor
from app.models.query import Query as QueryModel
from app.models.citation_record import CitationRecord
from app.schemas.dashboard import (
    DashboardStatsResponse,
    DimensionScoreItem,
    PlatformScoreItem,
    RecentQueryItem,
)
from app.services.scoring_service import ScoringService, get_health_level
from app.services.sentiment_service import get_sentiment_service
from app.schemas.scoring import CitationResult
from app.services.cache import get_cache_service, TTL_DASHBOARD
# Router for the dashboard endpoints; ``get_dashboard_stats`` below is
# registered on it as ``GET /stats``.
router = APIRouter()
# Required platforms for the dashboard
REQUIRED_PLATFORMS = [
"wenxin", # 文心一言
"kimi", # Kimi
"tongyi", # 通义千问
"doubao", # 豆包
"xinghuo", # 讯飞星火
"tiangong", # 天工AI
"qingyan", # 智谱清言
]
async def _citation_rates_by_platform(
    db: AsyncSession,
    query_ids: list,
) -> dict[str, float]:
    """Return the citation rate (0-100, rounded to 2 dp) for every required platform.

    For each platform in ``REQUIRED_PLATFORMS`` the rate is
    ``cited / total * 100`` over the ``CitationRecord`` rows that belong to
    ``query_ids`` on that platform. Platforms without rows (or an empty
    ``query_ids``) get 0.0.

    All platforms are aggregated in a single grouped query instead of one
    round-trip per platform.
    """
    rates = {platform: 0.0 for platform in REQUIRED_PLATFORMS}
    if not query_ids:
        return rates

    stmt = (
        select(
            CitationRecord.platform,
            func.count().label("total"),
            # sqlalchemy.case (not func.case) renders a real CASE expression.
            func.sum(case((CitationRecord.cited.is_(True), 1), else_=0)).label("cited"),
        )
        .where(
            CitationRecord.query_id.in_(query_ids),
            CitationRecord.platform.in_(REQUIRED_PLATFORMS),
        )
        .group_by(CitationRecord.platform)
    )
    result = await db.execute(stmt)
    for row in result.all():
        # int(): some backends (e.g. MySQL) return SUM/COUNT as Decimal.
        total = int(row.total or 0)
        if total > 0 and row.platform in rates:
            rates[row.platform] = round(int(row.cited or 0) / total * 100, 2)
    return rates


async def _get_brand_score_by_platform(
    db: AsyncSession,
    user_id: uuid.UUID,
    brand_id: uuid.UUID,
) -> dict[str, float]:
    """
    Calculate the brand's citation rate per platform.

    Uses the user's queries whose ``target_brand`` equals the brand's name.
    Returns a dict mapping every key in ``REQUIRED_PLATFORMS`` to a score in
    0-100. All scores are 0.0 when the brand does not exist, is not owned by
    ``user_id``, or has no queries.
    """
    brand_stmt = select(Brand).where(
        Brand.id == brand_id,
        Brand.user_id == user_id,
    )
    brand = (await db.execute(brand_stmt)).scalar_one_or_none()
    if not brand:
        return {platform: 0.0 for platform in REQUIRED_PLATFORMS}

    # Only the ids are needed, so avoid loading full Query rows.
    query_ids_stmt = select(QueryModel.id).where(
        QueryModel.user_id == user_id,
        QueryModel.target_brand == brand.name,
    )
    query_ids = list((await db.execute(query_ids_stmt)).scalars().all())
    return await _citation_rates_by_platform(db, query_ids)


async def _get_competitor_scores_by_platform(
    db: AsyncSession,
    user_id: uuid.UUID,
    brand_id: uuid.UUID,
) -> dict[str, float]:
    """
    Calculate the competitors' combined citation rate per platform.

    All of the user's queries whose ``target_brand`` is one of the brand's
    competitors are pooled. The result is therefore the overall competitor
    citation rate, not a mean of per-competitor rates.

    Returns a dict mapping every key in ``REQUIRED_PLATFORMS`` to a score in
    0-100. All scores are 0.0 when the brand is not owned by ``user_id``, has
    no competitors, or the competitors have no queries.
    """
    zeros = {platform: 0.0 for platform in REQUIRED_PLATFORMS}

    # Same ownership check as _get_brand_score_by_platform, so a foreign
    # brand_id cannot be used to read another user's competitor list.
    owner_stmt = select(Brand.id).where(
        Brand.id == brand_id,
        Brand.user_id == user_id,
    )
    if (await db.execute(owner_stmt)).scalar_one_or_none() is None:
        return zeros

    names_stmt = select(Competitor.name).where(Competitor.brand_id == brand_id)
    competitor_names = list((await db.execute(names_stmt)).scalars().all())
    if not competitor_names:
        return zeros

    query_ids_stmt = select(QueryModel.id).where(
        QueryModel.user_id == user_id,
        QueryModel.target_brand.in_(competitor_names),
    )
    query_ids = list((await db.execute(query_ids_stmt)).scalars().all())
    return await _citation_rates_by_platform(db, query_ids)
# Attribute names of the five V2 dimension results on the object returned by
# ScoringService.calculate_v2, in the order the dashboard lists them.
_V2_DIMENSIONS = (
    "mention_rate",
    "recommendation_rank",
    "sentiment_score",
    "citation_quality",
    "competitive_position",
)


async def _count_sentiments(brand_name: str, cited_records: list) -> dict[str, int]:
    """Tally positive / neutral / negative sentiment over the brand's cited records.

    A record whose persisted ``sentiment`` is already one of the three labels
    is counted directly. Otherwise its ``raw_response`` (falling back to
    ``citation_text``) is sent to the sentiment service.

    Empty text, an unrecognised label, or any error from the service is counted
    as neutral. This is a best-effort step, so one failing call does not break
    the dashboard; service errors are logged rather than silently dropped.
    """
    service = get_sentiment_service()
    counts = {"positive": 0, "neutral": 0, "negative": 0}
    for record in cited_records:
        if record.sentiment and record.sentiment in ("positive", "neutral", "negative"):
            counts[record.sentiment] += 1
            continue

        label = "neutral"
        content = record.raw_response or record.citation_text or ""
        if content.strip():
            try:
                analysis = await service.analyze(
                    brand_name=brand_name,
                    content=content,
                )
                if analysis.sentiment in counts:
                    label = analysis.sentiment
            except Exception:
                logging.getLogger(__name__).warning(
                    "Sentiment analysis failed for brand %r; counting as neutral",
                    brand_name,
                    exc_info=True,
                )
        counts[label] += 1
    return counts


def _cited_percentage_on(records: list, day) -> float:
    """Return the percentage (0-100) of *records* queried on *day* that were cited.

    *day* is compared with ``record.queried_at.date()``. Records without a
    ``queried_at`` are ignored. Returns 0.0 when no record falls on *day*.
    """
    on_day = [
        record for record in records
        if record.queried_at is not None and record.queried_at.date() == day
    ]
    if not on_day:
        return 0.0
    return sum(1 for record in on_day if record.cited) / len(on_day) * 100


async def _calculate_overall_score_v2(
    db: AsyncSession,
    user_id: uuid.UUID,
    brand_id: uuid.UUID,
) -> tuple[float, float, list[DimensionScoreItem], int, int]:
    """
    Calculate the V2 overall score and its supporting figures for one brand.

    Only a brand owned by ``user_id`` is scored. An unknown or foreign brand,
    or a brand with no queries, yields ``(0.0, 0.0, [], 0, 0)``.

    Returns:
        ``(overall_score, change_from_yesterday, dimensions, ahead_count,
        behind_count)`` where:

        - ``change_from_yesterday``: today's cited percentage minus
          yesterday's, in percentage points, rounded to 2 dp. It compares raw
          citation rates of the brand's records, not V2 overall scores.
        - ``ahead_count`` / ``behind_count``: among competitors mentioned at
          least once, how many the brand out-mentions versus does not
          out-mention (ties count as behind).
    """
    brand_stmt = select(Brand).where(
        Brand.id == brand_id,
        Brand.user_id == user_id,
    )
    brand = (await db.execute(brand_stmt)).scalar_one_or_none()
    if not brand:
        return 0.0, 0.0, [], 0, 0

    query_ids_stmt = select(QueryModel.id).where(
        QueryModel.user_id == user_id,
        QueryModel.target_brand == brand.name,
    )
    query_ids = list((await db.execute(query_ids_stmt)).scalars().all())
    if not query_ids:
        return 0.0, 0.0, [], 0, 0

    citations_stmt = select(CitationRecord).where(
        CitationRecord.query_id.in_(query_ids),
    )
    all_citations = list((await db.execute(citations_stmt)).scalars().all())
    # Each CitationRecord is one query run; passed to scoring as total_queries.
    total_queries = len(all_citations)
    brand_citations = [c for c in all_citations if c.cited]
    brand_mentions = len(brand_citations)

    names_stmt = select(Competitor.name).where(Competitor.brand_id == brand_id)
    competitor_names = list((await db.execute(names_stmt)).scalars().all())

    # NOTE(review): only records that cite the brand are scanned, so a
    # competitor mentioned in a response that omits the brand is not counted.
    # Confirm this is intended.
    competitor_mentions: dict[str, int] = {}
    for competitor_name in competitor_names:
        mentions = sum(
            1 for c in brand_citations
            if c.competitor_brands and competitor_name in c.competitor_brands
        )
        if mentions > 0:
            competitor_mentions[competitor_name] = mentions

    sentiment_counts = await _count_sentiments(brand.name, brand_citations)

    citation_results = [
        CitationResult(
            cited=c.cited,
            position=c.citation_position,
            citation_text=c.citation_text,
            sentiment="neutral",  # simplified for the dashboard
            confidence=c.confidence or 0.0,
        )
        for c in brand_citations
    ]
    positions = [c.citation_position for c in brand_citations]

    v2_result = ScoringService().calculate_v2(
        mentioned_count=brand_mentions,
        total_queries=total_queries,
        positions=positions,
        sentiment_counts=sentiment_counts,
        citations=citation_results,
        brand_mentions=brand_mentions,
        competitor_mentions=competitor_mentions,
    )

    dimensions = [
        DimensionScoreItem(
            name=dimension.name,
            score=round(dimension.score, 2),
            max_score=dimension.max_score,
            percentage=round(dimension.percentage, 2),
        )
        for dimension in (getattr(v2_result, attr) for attr in _V2_DIMENSIONS)
    ]

    ahead_count = sum(
        1 for mentions in competitor_mentions.values() if brand_mentions > mentions
    )
    behind_count = len(competitor_mentions) - ahead_count

    # NOTE(review): "today" is the server's local, naive date. If queried_at
    # is stored in UTC or is timezone-aware, the day boundaries may not line
    # up. Confirm.
    today = datetime.now().date()
    yesterday = today - timedelta(days=1)
    change = round(
        _cited_percentage_on(all_citations, today)
        - _cited_percentage_on(all_citations, yesterday),
        2,
    )

    return v2_result.overall_score, change, dimensions, ahead_count, behind_count
def _empty_stats_response() -> DashboardStatsResponse:
    """Return the all-zero payload used when the user has no brand at all."""
    return DashboardStatsResponse(
        overall_score=0.0,
        health_level="danger",
        score_change=0.0,
        platform_scores=[],
        recent_queries=[],
        dimensions=[],
        competitors_ahead=0,
        competitors_behind=0,
        monitored_platforms=0,
        total_platforms=len(REQUIRED_PLATFORMS),
    )


async def _recent_query_items(
    db: AsyncSession,
    user_id: uuid.UUID,
    limit: int = 10,
) -> list[RecentQueryItem]:
    """Return the user's ``limit`` most recently created queries with cited counts.

    ``citation_count`` is the number of ``CitationRecord`` rows for the query
    whose ``cited`` flag is true. All counts come from one grouped query
    instead of one query per row.
    """
    recent_stmt = (
        select(QueryModel)
        .where(QueryModel.user_id == user_id)
        .order_by(QueryModel.created_at.desc())
        .limit(limit)
    )
    queries = list((await db.execute(recent_stmt)).scalars().all())
    if not queries:
        return []

    counts_stmt = (
        select(
            CitationRecord.query_id,
            # sqlalchemy.case (not func.case) renders a real CASE expression.
            func.sum(case((CitationRecord.cited.is_(True), 1), else_=0)).label("cited"),
        )
        .where(CitationRecord.query_id.in_([q.id for q in queries]))
        .group_by(CitationRecord.query_id)
    )
    counts_result = await db.execute(counts_stmt)
    # int(): some backends (e.g. MySQL) return SUM as Decimal.
    cited_by_query = {row.query_id: int(row.cited or 0) for row in counts_result.all()}

    return [
        RecentQueryItem(
            id=str(q.id),
            keyword=q.keyword,
            target_brand=q.target_brand,
            citation_count=cited_by_query.get(q.id, 0),
            queried_at=q.last_queried_at or q.created_at,
        )
        for q in queries
    ]


@router.get("/stats", response_model=DashboardStatsResponse)
async def get_dashboard_stats(
    brand_id: uuid.UUID | None = Query(None),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """
    Get dashboard statistics.

    Includes:
    - the V2 overall score and its change versus yesterday
    - the health level
    - the five dimension scores
    - per-platform scores (with competitor comparison)
    - competitive position (number of competitors ahead / behind)
    - the most recent queries
    """
    cache = get_cache_service()

    # Without an explicit brand_id, fall back to the user's first brand.
    # There is no ORDER BY, so which brand is "first" is database-defined.
    brand = None
    if brand_id is None:
        first_brand_stmt = select(Brand).where(Brand.user_id == current_user.id).limit(1)
        brand = (await db.execute(first_brand_stmt)).scalar_one_or_none()
        if brand is None:
            return _empty_stats_response()
        brand_id = brand.id

    # Serve from the cache when possible (TTL_DASHBOARD; noted as 2 minutes).
    cache_key = f"dashboard:stats:{current_user.id}:{brand_id}"
    cached = await cache.get_json(cache_key)
    if cached is not None:
        return cached

    # Scope the lookup to the current user so a foreign brand_id cannot
    # reveal another user's brand name.
    if brand is None:
        owned_brand_stmt = select(Brand).where(
            Brand.id == brand_id,
            Brand.user_id == current_user.id,
        )
        brand = (await db.execute(owned_brand_stmt)).scalar_one_or_none()
    brand_name = brand.name if brand else None

    overall_score, score_change, dimensions, ahead_count, behind_count = (
        await _calculate_overall_score_v2(db, current_user.id, brand_id)
    )
    platform_scores_dict = await _get_brand_score_by_platform(
        db, current_user.id, brand_id
    )
    competitor_scores_dict = await _get_competitor_scores_by_platform(
        db, current_user.id, brand_id
    )

    # The first competitor's name labels the comparison. It is only looked
    # up for a brand the user owns.
    competitor_name = None
    if brand is not None:
        competitor_stmt = (
            select(Competitor.name)
            .where(Competitor.brand_id == brand_id)
            .limit(1)
        )
        competitor_name = (await db.execute(competitor_stmt)).scalar_one_or_none()

    platform_scores = [
        PlatformScoreItem(
            platform=platform,
            score=score,
            competitor_score=competitor_scores_dict.get(platform),
            competitor_name=competitor_name if competitor_scores_dict.get(platform, 0) > 0 else None,
        )
        for platform, score in platform_scores_dict.items()
    ]

    # A platform counts as monitored once the brand has a non-zero score on it.
    monitored = sum(1 for s in platform_scores_dict.values() if s > 0)

    recent_queries = await _recent_query_items(db, current_user.id)

    response = DashboardStatsResponse(
        overall_score=round(overall_score, 2),
        health_level=get_health_level(overall_score),
        score_change=score_change,
        platform_scores=platform_scores,
        recent_queries=recent_queries,
        dimensions=dimensions,
        competitors_ahead=ahead_count,
        competitors_behind=behind_count,
        monitored_platforms=monitored,
        total_platforms=len(REQUIRED_PLATFORMS),
        brand_name=brand_name,
    )

    # Write the result to the cache (TTL_DASHBOARD; noted as 2 minutes).
    await cache.set_json(
        cache_key,
        response.model_dump(mode="json"),
        expire=TTL_DASHBOARD,
    )
    return response