geo/backend/app/api/health_score.py

145 lines
4.6 KiB
Python

import hashlib
import logging

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.database import get_db
from app.models.brand import Brand
from app.schemas.health_score import (
    HealthScoreDimension,
    HealthScoreRecommendation,
    HealthScoreResponse,
)
from app.services.cache import get_cache_service
from app.services.diagnosis.data_collector import DataCollectorService
from app.services.diagnosis.geo_diagnosis import GEODiagnosisService
from app.utils.health import get_health_level, get_health_level_label
logger = logging.getLogger(__name__)
router = APIRouter()

# Diagnosis dimensions exposed by the public (free-tier) health score
# (content extractability, E-E-A-T signals, citation readiness). Dimensions
# and recommendations outside this allowlist are filtered out of the public
# response. A frozenset so the allowlist cannot be mutated at runtime.
_FREE_TIER_DIMENSIONS = frozenset({"内容可提取性", "E-E-A-T信号", "引用就绪度"})

# Lifetime of a successfully computed health score in the cache, in seconds
# (24 hours).
_CACHE_TTL = 24 * 60 * 60
def _build_default_response(brand_name: str) -> HealthScoreResponse:
    """Build the zero-score placeholder response for *brand_name*.

    Every free-tier dimension is listed (sorted by name, so the output order
    is deterministic) with all scores at ``0.0`` and status ``"fail"``. The
    overall level is ``"danger"``, there are no recommendations, and the
    response is marked as neither a full report nor a cache hit.
    """
    danger_label = get_health_level_label("danger")
    placeholder_dimensions = [
        HealthScoreDimension(
            name=dimension_name,
            score=0.0,
            max_score=0.0,
            percentage=0.0,
            status="fail",
        )
        for dimension_name in sorted(_FREE_TIER_DIMENSIONS)
    ]
    return HealthScoreResponse(
        brand_name=brand_name,
        overall_score=0.0,
        health_level="danger",
        health_level_label=danger_label,
        dimensions=placeholder_dimensions,
        recommendations=[],
        is_full_report=False,
        cached=False,
    )
# Cache lifetime, in seconds, for the placeholder response written when
# collection/diagnosis fails. Short, so a transient outage does not pin a
# zero score for the full success TTL, but non-zero so repeated requests for
# a failing brand do not hammer the collector.
_ERROR_CACHE_TTL = 300


def _health_score_cache_key(brand_name: str) -> str:
    """Return the cache key for an already-stripped brand name.

    The key is case-insensitive ("Foo" and "foo" share one entry). MD5 is
    used only to produce a compact, fixed-length key, not for security.
    """
    digest = hashlib.md5(brand_name.lower().encode()).hexdigest()
    return f"health_score:{digest}"


async def _load_brand_context(
    db: AsyncSession, brand_name: str
) -> "tuple[list[str], str | None, str | None]":
    """Return ``(aliases, website, industry)`` for *brand_name*.

    Looks up the first ``Brand`` row whose ``name`` equals *brand_name*
    exactly. If none is found, returns ``([], None, None)``. A missing
    ``aliases`` value is normalized to ``[]``.
    """
    # NOTE(review): this match is case-sensitive while the cache key is
    # case-insensitive, so "apple"/"Apple" share a cache entry but may
    # resolve differently here. Confirm which semantics are intended.
    stmt = select(Brand).where(Brand.name == brand_name).limit(1)
    result = await db.execute(stmt)
    brand_record = result.scalar_one_or_none()
    if brand_record is None:
        return [], None, None
    return brand_record.aliases or [], brand_record.website, brand_record.industry


def _build_free_tier_response(brand_name: str, diagnosis: dict) -> HealthScoreResponse:
    """Project a full diagnosis dict onto the public free-tier response.

    Keeps only dimensions in ``_FREE_TIER_DIMENSIONS``. The overall score is
    the sum of their scores, rounded to 2 decimals. Only recommendations
    with priority ``"P0"`` whose dimension is in the free tier are kept.
    """
    free_dimensions = [
        d for d in diagnosis.get("dimensions", [])
        if d.get("name") in _FREE_TIER_DIMENSIONS
    ]
    overall_score = round(sum(d.get("score", 0.0) for d in free_dimensions), 2)
    health_level = get_health_level(overall_score)

    dimension_responses = [
        HealthScoreDimension(
            name=d.get("name", ""),
            score=round(d.get("score", 0.0), 2),
            max_score=d.get("max_score", 0.0),
            percentage=round(d.get("percentage", 0.0), 2),
            status=d.get("status", "fail"),
        )
        for d in free_dimensions
    ]
    recommendation_responses = [
        HealthScoreRecommendation(
            priority=r.get("priority", "P0"),
            dimension=r.get("dimension", ""),
            title=r.get("title", ""),
            description=r.get("description", ""),
        )
        for r in diagnosis.get("recommendations", [])
        if r.get("priority") == "P0" and r.get("dimension") in _FREE_TIER_DIMENSIONS
    ]

    return HealthScoreResponse(
        brand_name=brand_name,
        overall_score=overall_score,
        health_level=health_level,
        health_level_label=get_health_level_label(health_level),
        dimensions=dimension_responses,
        recommendations=recommendation_responses,
        is_full_report=False,
        cached=False,
    )


@router.get("/health-score", response_model=HealthScoreResponse)
async def get_public_health_score(
    brand: str = Query(..., min_length=1),
    competitors: str = Query(default=""),
    db: AsyncSession = Depends(get_db),
):
    """Return the public (free-tier) GEO health score for a brand.

    Results are cached per brand, case-insensitively and ignoring
    surrounding whitespace. A cache hit is returned as-is with
    ``cached=True``. On a miss, the brand's aliases, website and industry
    are looked up, data is collected and diagnosed, and the free-tier subset
    is cached for ``_CACHE_TTL`` seconds. If collection or diagnosis raises,
    the error is logged and a zero-score placeholder is returned and cached
    for only ``_ERROR_CACHE_TTL`` seconds.

    Args:
        brand: Brand name; surrounding whitespace is ignored.
        competitors: Comma-separated competitor names. Accepted for API
            compatibility but currently unused, and not part of the cache key.
        db: Async database session (injected).

    Raises:
        HTTPException: 422 if ``brand`` is blank after stripping.
    """
    brand_name = brand.strip()
    if not brand_name:
        # min_length=1 admits whitespace-only input; reject it explicitly
        # instead of diagnosing an empty brand name.
        raise HTTPException(status_code=422, detail="brand must not be blank")

    cache = get_cache_service()
    # Key on the stripped name so " Foo" and "Foo" share one cache entry.
    cache_key = _health_score_cache_key(brand_name)
    cached_data = await cache.get_json(cache_key)
    if cached_data is not None:
        cached_data["cached"] = True
        return cached_data

    brand_aliases, website, industry = await _load_brand_context(db, brand_name)

    try:
        collector = DataCollectorService(db)
        collection = await collector.collect(
            brand_name=brand_name,
            brand_aliases=brand_aliases,
            website=website,
            industry=industry,
        )
        diagnosis_result = GEODiagnosisService().diagnose(collection.diagnosis_input)
    except Exception as e:
        logger.exception("健康分数据采集失败: brand=%s, error=%s", brand_name, e)
        default_resp = _build_default_response(brand_name)
        # Short negative-cache TTL: do not pin a failure for a whole day.
        await cache.set_json(
            cache_key, default_resp.model_dump(mode="json"), expire=_ERROR_CACHE_TTL
        )
        return default_resp

    response = _build_free_tier_response(brand_name, diagnosis_result.to_dict())
    await cache.set_json(cache_key, response.model_dump(mode="json"), expire=_CACHE_TTL)
    return response