geo/backend/app/api/diagnosis.py

395 lines
12 KiB
Python

"""诊断API端点 - 提供SEO和GEO诊断功能"""
import asyncio
import logging
import uuid
from datetime import UTC, datetime
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.diagnosis_record import DiagnosisRecord
from app.schemas.diagnosis import (
GEODiagnosisHistoryItem,
GEODiagnosisHistoryResponse,
GEODiagnosisResponse,
GEODiagnosisResultResponse,
GEODiagnosisTaskResponse,
GEODiagnosisTriggerRequest,
)
from app.services.diagnosis.data_collector import DataCollectorService
from app.services.diagnosis.seo_diagnosis import SEODiagnosisService
from app.services.diagnosis.geo_diagnosis import GEODiagnosisService, GEODiagnosisInput
from app.utils.health import get_health_level_label
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router on which the endpoints of this module register their routes.
router = APIRouter()
# Dimension names that stay in the report for non-paid users; every other
# dimension is filtered out by _build_diagnosis_response.
_FREE_TIER_DIMENSIONS = {"内容可提取性", "E-E-A-T信号", "引用就绪度"}
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
@router.get("/seo/{brand_id}")
async def get_seo_diagnosis(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Run the SEO diagnosis for one of the caller's brands.

    The brand lookup is scoped to the current user and happens before any
    diagnosis work, so an unknown or foreign brand yields a 404.

    Returns:
        The diagnosis result serialised through its ``to_dict()`` method.

    Raises:
        HTTPException: 404 from the brand lookup, or 500 when anything in the
            diagnosis step raises.
    """
    brand = await _get_brand_or_404(brand_id, current_user, db)

    try:
        result = SEODiagnosisService().diagnose()
        logger.info(f"SEO诊断完成: brand_id={brand_id}, brand={brand.name}, score={result.overall_score}")
        return result.to_dict()
    except Exception as e:
        # The traceback stays in the server log; the client only receives a
        # generic error message.
        logger.error(f"SEO诊断失败: brand_id={brand_id}, error={str(e)}", exc_info=True)
        failure = HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="SEO诊断服务异常,请稍后重试",
        )
        raise failure
# Strong references to in-flight background diagnosis tasks. The event loop
# only keeps weak references to tasks, so a task nobody else references can be
# garbage collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


@router.post("/geo/{brand_id}", status_code=status.HTTP_202_ACCEPTED)
async def trigger_geo_diagnosis(
    brand_id: uuid.UUID,
    body: GEODiagnosisTriggerRequest = GEODiagnosisTriggerRequest(),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Start an asynchronous GEO diagnosis for one of the caller's brands.

    Unless ``body.force_refresh`` is set, a diagnosis completed within the
    last 24 hours is reused and its record id is returned with status
    ``"completed"``. Otherwise a new ``pending`` record is persisted and the
    diagnosis runs as a background task in the current event loop.

    Returns:
        A ``GEODiagnosisTaskResponse`` carrying the record id to poll.

    Raises:
        HTTPException: 404 when the brand does not exist or is not owned by
            the current user.
    """
    brand = await _get_brand_or_404(brand_id, current_user, db)

    if not body.force_refresh:
        existing = await _find_recent_completed(db, brand_id, hours=24)
        if existing:
            return GEODiagnosisTaskResponse(
                task_id=str(existing.id),
                brand_id=str(brand_id),
                status="completed",
            )

    record = DiagnosisRecord(
        brand_id=brand_id,
        user_id=_to_uuid(current_user.id),
        diagnosis_type="geo",
        status="pending",
    )
    db.add(record)
    await db.commit()
    await db.refresh(record)

    # Plain values are passed instead of ORM objects: the request-scoped
    # session is not used by the background task, which opens its own.
    task = asyncio.create_task(
        _run_geo_diagnosis(
            record_id=record.id,
            brand_id=brand_id,
            brand_name=brand.name,
            brand_aliases=brand.aliases or [],
            website=brand.website,
            industry=brand.industry,
            user_id=_to_uuid(current_user.id),
        )
    )
    # Hold the task until it finishes, then drop the reference again.
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

    return GEODiagnosisTaskResponse(
        task_id=str(record.id),
        brand_id=str(brand_id),
        status="pending",
    )
@router.get("/geo/{brand_id}/result")
async def get_geo_diagnosis_result(
    brand_id: uuid.UUID,
    task_id: str | None = None,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the state, and when finished the report, of a GEO diagnosis.

    With ``task_id`` the matching record of this brand is looked up; without
    it the most recently created GEO record of the brand is used. Pending,
    running and failed records are reported without a result body. For any
    other status the stored result is returned, reduced to the free-tier view
    when the user's ``plan`` attribute is missing, empty or ``"free"``.

    Raises:
        HTTPException: 404 when the brand or the record does not exist,
            400 when ``task_id`` is not a valid UUID.
    """
    await _get_brand_or_404(brand_id, current_user, db)

    # Both lookups are limited to GEO records of this brand, so a task id
    # belonging to another diagnosis type is never served from this endpoint.
    stmt = select(DiagnosisRecord).where(
        DiagnosisRecord.brand_id == brand_id,
        DiagnosisRecord.diagnosis_type == "geo",
    )
    if task_id:
        try:
            tid = uuid.UUID(task_id)
        except ValueError as exc:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="无效的task_id",
            ) from exc
        stmt = stmt.where(DiagnosisRecord.id == tid)
    else:
        stmt = stmt.order_by(DiagnosisRecord.created_at.desc()).limit(1)

    result = await db.execute(stmt)
    record = result.scalar_one_or_none()
    if not record:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="诊断记录不存在",
        )

    if record.status in ("pending", "running"):
        return GEODiagnosisResultResponse(
            task_id=str(record.id),
            brand_id=str(brand_id),
            status=record.status,
        )

    if record.status == "failed":
        return GEODiagnosisResultResponse(
            task_id=str(record.id),
            brand_id=str(brand_id),
            status="failed",
            error=record.error_message,
        )

    # A missing or falsy plan counts as the free tier.
    user_plan = getattr(current_user, "plan", None) or "free"
    is_paid = user_plan != "free"
    diagnosis_resp = _build_diagnosis_response(record.result_json, is_paid)
    return GEODiagnosisResultResponse(
        task_id=str(record.id),
        brand_id=str(brand_id),
        status="completed",
        result=diagnosis_resp,
    )
# Upper bound for the number of history entries a single request may fetch.
_HISTORY_LIMIT_MAX = 100


@router.get("/geo/{brand_id}/history")
async def get_geo_diagnosis_history(
    brand_id: uuid.UUID,
    limit: int = 10,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List completed GEO diagnoses of a brand, newest first.

    Args:
        brand_id: Brand whose history is requested; must belong to the caller.
        limit: Maximum number of entries. Values outside
            ``1.._HISTORY_LIMIT_MAX`` are clamped into that range.

    Returns:
        A ``GEODiagnosisHistoryResponse`` with one item per completed record.

    Raises:
        HTTPException: 404 when the brand does not exist or is not owned by
            the current user.
    """
    await _get_brand_or_404(brand_id, current_user, db)

    # ``limit`` comes straight from the query string: a negative value is
    # rejected or treated as "no limit" depending on the database, and a huge
    # value would load an unbounded number of rows.
    limit = max(1, min(limit, _HISTORY_LIMIT_MAX))

    stmt = (
        select(DiagnosisRecord)
        .where(
            DiagnosisRecord.brand_id == brand_id,
            DiagnosisRecord.diagnosis_type == "geo",
            DiagnosisRecord.status == "completed",
        )
        .order_by(DiagnosisRecord.created_at.desc())
        .limit(limit)
    )
    result = await db.execute(stmt)
    records = result.scalars().all()

    items = [
        GEODiagnosisHistoryItem(
            task_id=str(r.id),
            overall_score=r.overall_score or 0,
            # Records without a stored result are reported as "danger".
            health_level=r.result_json.get("health_level", "danger") if r.result_json else "danger",
            created_at=r.created_at.isoformat(),
            completed_at=r.completed_at.isoformat() if r.completed_at else None,
        )
        for r in records
    ]
    return GEODiagnosisHistoryResponse(
        brand_id=str(brand_id),
        history=items,
    )
@router.get("/combined/{brand_id}")
async def get_combined_diagnosis(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Run the SEO and GEO diagnoses for a brand and return both with a mean score.

    The SEO diagnosis runs first; then data is collected for the brand and fed
    into the GEO diagnosis. ``combined_score`` is the arithmetic mean of the two
    overall scores, rounded to two decimals.

    Returns:
        A dict with ``seo_score``, ``geo_score``, ``combined_score`` and both
        diagnoses serialised through ``to_dict()``.

    Raises:
        HTTPException: 404 from the brand lookup, or 500 when any step of the
            diagnosis raises.
    """
    brand = await _get_brand_or_404(brand_id, current_user, db)

    try:
        seo_result = SEODiagnosisService().diagnose()

        gathered = await DataCollectorService(db).collect(
            brand_name=brand.name,
            brand_aliases=brand.aliases or [],
            website=brand.website,
            industry=brand.industry,
        )
        geo_result = GEODiagnosisService().diagnose(gathered.diagnosis_input)

        combined_score = round((seo_result.overall_score + geo_result.overall_score) / 2, 2)
        logger.info(
            f"综合诊断完成: brand_id={brand_id}, brand={brand.name}, "
            f"seo_score={seo_result.overall_score}, "
            f"geo_score={geo_result.overall_score}, "
            f"combined_score={combined_score}"
        )

        payload = {
            "seo_score": seo_result.overall_score,
            "geo_score": geo_result.overall_score,
            "combined_score": combined_score,
            "seo_diagnosis": seo_result.to_dict(),
            "geo_diagnosis": geo_result.to_dict(),
        }
        return payload
    except Exception as e:
        # The traceback stays in the server log; the client only receives a
        # generic error message.
        logger.error(f"综合诊断失败: brand_id={brand_id}, error={str(e)}", exc_info=True)
        failure = HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="综合诊断服务异常,请稍后重试",
        )
        raise failure
async def _run_geo_diagnosis(
    record_id: uuid.UUID,
    brand_id: uuid.UUID,
    brand_name: str,
    brand_aliases: list[str],
    website: str | None,
    industry: str | None,
    user_id: uuid.UUID,
) -> None:
    """Execute a GEO diagnosis in the background and persist its outcome.

    Opens its own database session, marks the record ``running``, collects
    data for the brand, runs the diagnosis and stores score, result and
    collection metadata with status ``completed``. On any exception the record
    is marked ``failed`` with the exception text; nothing is re-raised.

    Args:
        record_id: Id of the ``DiagnosisRecord`` to update.
        brand_id: Brand id, used for logging only.
        brand_name: Brand name passed to the data collector.
        brand_aliases: Brand aliases passed to the data collector.
        website: Brand website passed to the data collector.
        industry: Brand industry passed to the data collector.
        user_id: Accepted for the caller's convenience; not used here.
    """
    from app.database import AsyncSessionLocal

    async with AsyncSessionLocal() as db:
        try:
            stmt = select(DiagnosisRecord).where(DiagnosisRecord.id == record_id)
            result = await db.execute(stmt)
            record = result.scalar_one_or_none()
            if not record:
                return

            record.status = "running"
            await db.commit()

            collector = DataCollectorService(db)
            collection = await collector.collect(
                brand_name=brand_name,
                brand_aliases=brand_aliases,
                website=website,
                industry=industry,
            )

            geo_service = GEODiagnosisService()
            diagnosis_result = geo_service.diagnose(collection.diagnosis_input)

            # Build a fresh dict when errors have to be attached: this avoids
            # mutating the collector's own metadata and tolerates it being None.
            metadata = collection.metadata
            if collection.errors:
                metadata = {**(metadata or {}), "errors": collection.errors}

            record.status = "completed"
            record.overall_score = diagnosis_result.overall_score
            record.result_json = diagnosis_result.to_dict()
            record.completed_at = datetime.now(UTC)
            record.collection_metadata = metadata
            await db.commit()

            logger.info(
                f"GEO诊断完成: brand_id={brand_id}, brand={brand_name}, "
                f"score={diagnosis_result.overall_score}"
            )
        except Exception as e:
            logger.error(f"GEO诊断任务失败: record_id={record_id}, error={e}", exc_info=True)
            try:
                # Reset the session first: a failed flush/commit leaves the
                # transaction unusable, and half-applied "completed" fields
                # must not be committed together with the failed status.
                await db.rollback()

                stmt = select(DiagnosisRecord).where(DiagnosisRecord.id == record_id)
                result = await db.execute(stmt)
                record = result.scalar_one_or_none()
                if record:
                    record.status = "failed"
                    record.error_message = str(e)
                    await db.commit()
            except Exception:
                # Best effort only; keep the traceback so the stuck record can
                # be investigated.
                logger.error(f"更新失败状态也失败: record_id={record_id}", exc_info=True)
def _build_diagnosis_response(result_json: dict | None, is_paid: bool) -> GEODiagnosisResponse:
    """Convert a stored diagnosis result into the API response model.

    An empty or missing result yields a zero-score "danger" response. For
    non-paid users only the dimensions named in ``_FREE_TIER_DIMENSIONS`` and
    the recommendations with priority ``"P0"`` are kept.

    Args:
        result_json: The stored result dict, or ``None``.
        is_paid: Whether the requesting user gets the full report.

    Returns:
        The populated ``GEODiagnosisResponse``; ``is_full_report`` mirrors
        ``is_paid``.
    """
    if not result_json:
        return GEODiagnosisResponse(
            overall_score=0,
            health_level="danger",
            health_level_label=get_health_level_label("danger"),
            dimensions=[],
            recommendations=[],
            is_full_report=is_paid,
        )

    # ``or []`` also covers keys that are present but explicitly null.
    dimensions = result_json.get("dimensions") or []
    recommendations = result_json.get("recommendations") or []
    if not is_paid:
        dimensions = [d for d in dimensions if d.get("name") in _FREE_TIER_DIMENSIONS]
        recommendations = [r for r in recommendations if r.get("priority") == "P0"]

    health_level = result_json.get("health_level", "danger")
    health_level_label = result_json.get("health_level_label")
    if health_level_label is None:
        # Derive the label from the actual level so the two cannot disagree
        # when the stored result has no label of its own.
        health_level_label = get_health_level_label(health_level)

    return GEODiagnosisResponse(
        overall_score=result_json.get("overall_score", 0),
        health_level=health_level,
        health_level_label=health_level_label,
        dimensions=dimensions,
        recommendations=recommendations,
        is_full_report=is_paid,
    )
async def _find_recent_completed(
    db: AsyncSession,
    brand_id: uuid.UUID,
    hours: int = 24,
    diagnosis_type: str = "geo",
) -> DiagnosisRecord | None:
    """Return the brand's latest diagnosis completed within the last *hours*.

    Args:
        db: Session used for the query.
        brand_id: Brand whose records are searched.
        hours: Size of the look-back window, counted back from now (UTC).
        diagnosis_type: Only records of this type are considered, so the GEO
            cache never hands out a record of another diagnosis type.

    Returns:
        The most recently completed matching record, or ``None``.
    """
    from datetime import timedelta

    cutoff = datetime.now(UTC) - timedelta(hours=hours)
    stmt = (
        select(DiagnosisRecord)
        .where(
            DiagnosisRecord.brand_id == brand_id,
            DiagnosisRecord.diagnosis_type == diagnosis_type,
            DiagnosisRecord.status == "completed",
            DiagnosisRecord.completed_at >= cutoff,
        )
        .order_by(DiagnosisRecord.completed_at.desc())
        .limit(1)
    )
    result = await db.execute(stmt)
    return result.scalar_one_or_none()
async def _get_brand_or_404(
    brand_id: uuid.UUID,
    current_user: User,
    db: AsyncSession,
) -> Brand:
    """Fetch a brand by id, restricted to brands owned by *current_user*.

    Raises:
        HTTPException: 404 when no brand with this id belongs to the user.
    """
    owner_id = _to_uuid(current_user.id)
    query = select(Brand).where(Brand.id == brand_id, Brand.user_id == owner_id)
    brand = (await db.execute(query)).scalar_one_or_none()
    if brand:
        return brand
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="品牌不存在",
    )