395 lines
12 KiB
Python
395 lines
12 KiB
Python
"""诊断API端点 - 提供SEO和GEO诊断功能"""
|
|
import asyncio
|
|
import logging
|
|
import uuid
|
|
from datetime import UTC, datetime
|
|
from typing import Annotated
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.api.deps import get_current_user
|
|
from app.database import get_db
|
|
from app.models.user import User
|
|
from app.models.brand import Brand
|
|
from app.models.diagnosis_record import DiagnosisRecord
|
|
from app.schemas.diagnosis import (
|
|
GEODiagnosisHistoryItem,
|
|
GEODiagnosisHistoryResponse,
|
|
GEODiagnosisResponse,
|
|
GEODiagnosisResultResponse,
|
|
GEODiagnosisTaskResponse,
|
|
GEODiagnosisTriggerRequest,
|
|
)
|
|
from app.services.diagnosis.data_collector import DataCollectorService
|
|
from app.services.diagnosis.seo_diagnosis import SEODiagnosisService
|
|
from app.services.diagnosis.geo_diagnosis import GEODiagnosisService, GEODiagnosisInput
|
|
from app.utils.health import get_health_level_label
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Router on which every endpoint defined in this module is registered.
router = APIRouter()

# Names of the diagnosis dimensions that remain in the report returned to
# users who are not on a paid plan (see _build_diagnosis_response).
_FREE_TIER_DIMENSIONS = {
    "内容可提取性",
    "E-E-A-T信号",
    "引用就绪度",
}
|
|
|
|
|
|
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
|
if isinstance(value, uuid.UUID):
|
|
return value
|
|
return uuid.UUID(str(value))
|
|
|
|
|
|
@router.get("/seo/{brand_id}")
async def get_seo_diagnosis(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Run the SEO diagnosis for a brand owned by the current user.

    Args:
        brand_id: Brand identifier taken from the URL path.
        current_user: Authenticated user resolved by ``get_current_user``.
        db: Async database session resolved by ``get_db``.

    Returns:
        The diagnosis result serialised with ``to_dict()``.

    Raises:
        HTTPException: 404 (from ``_get_brand_or_404``) when the brand is not
            found for this user; 500 when the diagnosis itself raises.
    """
    # The ownership check stays outside the try block so that its 404 is not
    # turned into a 500 by the generic handler below.
    brand = await _get_brand_or_404(brand_id, current_user, db)

    try:
        service = SEODiagnosisService()
        # NOTE(review): diagnose() is called without any brand data; the brand
        # is only used for logging here — confirm this is intended.
        result = service.diagnose()

        logger.info(f"SEO诊断完成: brand_id={brand_id}, brand={brand.name}, score={result.overall_score}")

        return result.to_dict()
    except Exception as e:
        logger.error(f"SEO诊断失败: brand_id={brand_id}, error={str(e)}", exc_info=True)
        # Chain the original exception so the cause stays attached to the
        # HTTP error for debugging tools and error reporters.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="SEO诊断服务异常,请稍后重试",
        ) from e
|
|
|
|
|
|
# Strong references to in-flight background diagnosis tasks. The event loop
# keeps only weak references to tasks, so a task that nothing else references
# may be garbage-collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


@router.post("/geo/{brand_id}", status_code=status.HTTP_202_ACCEPTED)
async def trigger_geo_diagnosis(
    brand_id: uuid.UUID,
    body: GEODiagnosisTriggerRequest = GEODiagnosisTriggerRequest(),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Start a GEO diagnosis for a brand and return a task handle.

    Unless ``body.force_refresh`` is set, a diagnosis completed within the
    last 24 hours is reused: its record id is returned with status
    ``"completed"`` and no new work is started. Otherwise a ``pending``
    ``DiagnosisRecord`` is committed and ``_run_geo_diagnosis`` is scheduled
    as a background asyncio task.

    Args:
        brand_id: Brand identifier taken from the URL path.
        body: Trigger options; only ``force_refresh`` is read here.
        current_user: Authenticated user resolved by ``get_current_user``.
        db: Async database session resolved by ``get_db``.

    Returns:
        A ``GEODiagnosisTaskResponse`` carrying the task id, the brand id and
        the status (``"completed"`` for a reused record, else ``"pending"``).

    Raises:
        HTTPException: 404 (from ``_get_brand_or_404``) when the brand is not
            found for this user.
    """
    brand = await _get_brand_or_404(brand_id, current_user, db)

    if not body.force_refresh:
        existing = await _find_recent_completed(db, brand_id, hours=24)
        if existing:
            return GEODiagnosisTaskResponse(
                task_id=str(existing.id),
                brand_id=str(brand_id),
                status="completed",
            )

    record = DiagnosisRecord(
        brand_id=brand_id,
        user_id=_to_uuid(current_user.id),
        diagnosis_type="geo",
        status="pending",
    )
    db.add(record)
    await db.commit()
    await db.refresh(record)

    # Plain values (not ORM objects) are handed to the background coroutine,
    # which opens its own session.
    task = asyncio.create_task(
        _run_geo_diagnosis(
            record_id=record.id,
            brand_id=brand_id,
            brand_name=brand.name,
            brand_aliases=brand.aliases or [],
            website=brand.website,
            industry=brand.industry,
            user_id=_to_uuid(current_user.id),
        )
    )
    # Hold a reference until the task completes, then drop it so the set does
    # not grow without bound.
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

    return GEODiagnosisTaskResponse(
        task_id=str(record.id),
        brand_id=str(brand_id),
        status="pending",
    )
|
|
|
|
|
|
@router.get("/geo/{brand_id}/result")
async def get_geo_diagnosis_result(
    brand_id: uuid.UUID,
    task_id: str | None = None,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return the state or result of a GEO diagnosis for a brand.

    With ``task_id`` the record with that id (for this brand) is looked up;
    without it the most recently created GEO record of the brand is used.
    Pending/running records are reported with their status only, failed
    records additionally carry the stored error message, and any other record
    is reported as completed with a report built by
    ``_build_diagnosis_response``.

    Raises:
        HTTPException: 404 when the brand or the record is not found; 400
            when ``task_id`` is not a valid UUID string.
    """
    await _get_brand_or_404(brand_id, current_user, db)

    if task_id:
        try:
            parsed_task_id = uuid.UUID(task_id)
        except ValueError:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="无效的task_id",
            )
        query = select(DiagnosisRecord).where(
            DiagnosisRecord.id == parsed_task_id,
            DiagnosisRecord.brand_id == brand_id,
        )
    else:
        # No explicit task: fall back to the newest GEO record of the brand.
        newest_first = DiagnosisRecord.created_at.desc()
        query = (
            select(DiagnosisRecord)
            .where(
                DiagnosisRecord.brand_id == brand_id,
                DiagnosisRecord.diagnosis_type == "geo",
            )
            .order_by(newest_first)
            .limit(1)
        )

    record = (await db.execute(query)).scalar_one_or_none()
    if not record:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="诊断记录不存在",
        )

    # Fields shared by every response variant below.
    identity = {"task_id": str(record.id), "brand_id": str(brand_id)}

    if record.status in ("pending", "running"):
        return GEODiagnosisResultResponse(status=record.status, **identity)

    if record.status == "failed":
        return GEODiagnosisResultResponse(
            status="failed",
            error=record.error_message,
            **identity,
        )

    # A missing or empty plan counts as the free tier.
    plan = getattr(current_user, "plan", None) or "free"
    report = _build_diagnosis_response(record.result_json, plan != "free")

    return GEODiagnosisResultResponse(
        status="completed",
        result=report,
        **identity,
    )
|
|
|
|
|
|
# Upper bound applied to the ``limit`` query parameter of the history endpoint.
_MAX_HISTORY_LIMIT = 100


@router.get("/geo/{brand_id}/history")
async def get_geo_diagnosis_history(
    brand_id: uuid.UUID,
    limit: int = 10,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List completed GEO diagnoses of a brand, newest first.

    Args:
        brand_id: Brand identifier taken from the URL path.
        limit: Maximum number of entries to return. Values are clamped to the
            range ``0.._MAX_HISTORY_LIMIT``.
        current_user: Authenticated user resolved by ``get_current_user``.
        db: Async database session resolved by ``get_db``.

    Returns:
        A ``GEODiagnosisHistoryResponse`` with one item per completed record.

    Raises:
        HTTPException: 404 (from ``_get_brand_or_404``) when the brand is not
            found for this user.
    """
    await _get_brand_or_404(brand_id, current_user, db)

    # ``limit`` is client-controlled: without clamping, a negative or very
    # large value would be passed straight into the SQL LIMIT clause.
    limit = max(0, min(limit, _MAX_HISTORY_LIMIT))

    stmt = (
        select(DiagnosisRecord)
        .where(
            DiagnosisRecord.brand_id == brand_id,
            DiagnosisRecord.diagnosis_type == "geo",
            DiagnosisRecord.status == "completed",
        )
        .order_by(DiagnosisRecord.created_at.desc())
        .limit(limit)
    )
    result = await db.execute(stmt)
    records = result.scalars().all()

    items = [
        GEODiagnosisHistoryItem(
            task_id=str(r.id),
            overall_score=r.overall_score or 0,
            # Records without a stored result fall back to the "danger" level.
            health_level=r.result_json.get("health_level", "danger") if r.result_json else "danger",
            created_at=r.created_at.isoformat(),
            completed_at=r.completed_at.isoformat() if r.completed_at else None,
        )
        for r in records
    ]

    return GEODiagnosisHistoryResponse(
        brand_id=str(brand_id),
        history=items,
    )
|
|
|
|
|
|
@router.get("/combined/{brand_id}")
async def get_combined_diagnosis(
    brand_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Run the SEO and GEO diagnoses for a brand and return both.

    The GEO diagnosis is computed inline (not as a background task) from data
    gathered by ``DataCollectorService``. The combined score is the mean of
    the two overall scores, rounded to two decimals.

    Args:
        brand_id: Brand identifier taken from the URL path.
        current_user: Authenticated user resolved by ``get_current_user``.
        db: Async database session resolved by ``get_db``.

    Returns:
        A dict with ``seo_score``, ``geo_score``, ``combined_score`` and the
        two full diagnoses serialised with ``to_dict()``.

    Raises:
        HTTPException: 404 (from ``_get_brand_or_404``) when the brand is not
            found for this user; 500 when any diagnosis step raises.
    """
    # The ownership check stays outside the try block so that its 404 is not
    # turned into a 500 by the generic handler below.
    brand = await _get_brand_or_404(brand_id, current_user, db)

    try:
        seo_service = SEODiagnosisService()
        seo_result = seo_service.diagnose()

        collector = DataCollectorService(db)
        collection = await collector.collect(
            brand_name=brand.name,
            brand_aliases=brand.aliases or [],
            website=brand.website,
            industry=brand.industry,
        )

        geo_service = GEODiagnosisService()
        geo_result = geo_service.diagnose(collection.diagnosis_input)

        combined_score = round((seo_result.overall_score + geo_result.overall_score) / 2, 2)

        logger.info(
            f"综合诊断完成: brand_id={brand_id}, brand={brand.name}, "
            f"seo_score={seo_result.overall_score}, "
            f"geo_score={geo_result.overall_score}, "
            f"combined_score={combined_score}"
        )

        return {
            "seo_score": seo_result.overall_score,
            "geo_score": geo_result.overall_score,
            "combined_score": combined_score,
            "seo_diagnosis": seo_result.to_dict(),
            "geo_diagnosis": geo_result.to_dict(),
        }
    except Exception as e:
        logger.error(f"综合诊断失败: brand_id={brand_id}, error={str(e)}", exc_info=True)
        # Chain the original exception so the cause stays attached to the
        # HTTP error for debugging tools and error reporters.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="综合诊断服务异常,请稍后重试",
        ) from e
|
|
|
|
|
|
async def _run_geo_diagnosis(
    record_id: uuid.UUID,
    brand_id: uuid.UUID,
    brand_name: str,
    brand_aliases: list[str],
    website: str | None,
    industry: str | None,
    user_id: uuid.UUID,
) -> None:
    """Execute a GEO diagnosis in the background and persist its outcome.

    Opens its own database session, marks the record ``running``, collects
    data, runs the diagnosis, and stores the result with status
    ``completed``. On any exception the record is marked ``failed`` with the
    exception text as its error message. Exceptions are logged, never raised.

    Args:
        record_id: Id of the ``DiagnosisRecord`` to update.
        brand_id: Brand id; used for logging only.
        brand_name: Brand name passed to the data collector.
        brand_aliases: Brand aliases passed to the data collector.
        website: Brand website passed to the data collector.
        industry: Brand industry passed to the data collector.
        user_id: Id of the triggering user; currently unused in this body.
    """
    from app.database import AsyncSessionLocal

    async with AsyncSessionLocal() as db:
        try:
            stmt = select(DiagnosisRecord).where(DiagnosisRecord.id == record_id)
            result = await db.execute(stmt)
            record = result.scalar_one_or_none()
            if not record:
                return

            # Commit the state change right away so pollers can see "running".
            record.status = "running"
            await db.commit()

            collector = DataCollectorService(db)
            collection = await collector.collect(
                brand_name=brand_name,
                brand_aliases=brand_aliases,
                website=website,
                industry=industry,
            )

            geo_service = GEODiagnosisService()
            diagnosis_result = geo_service.diagnose(collection.diagnosis_input)

            # Build the metadata on a copy: this avoids mutating the
            # collector's own dict and tolerates ``metadata`` being None when
            # errors have to be attached.
            metadata = collection.metadata
            if collection.errors:
                metadata = {**(metadata or {}), "errors": collection.errors}

            record.status = "completed"
            record.overall_score = diagnosis_result.overall_score
            record.result_json = diagnosis_result.to_dict()
            record.completed_at = datetime.now(UTC)
            record.collection_metadata = metadata

            await db.commit()

            logger.info(
                f"GEO诊断完成: brand_id={brand_id}, brand={brand_name}, "
                f"score={diagnosis_result.overall_score}"
            )

        except Exception as e:
            logger.error(f"GEO诊断任务失败: record_id={record_id}, error={e}", exc_info=True)
            try:
                # Roll back first: a failed flush/commit leaves the session
                # unusable until rollback, and any half-assigned "completed"
                # fields must not be flushed together with the failed status.
                await db.rollback()
                stmt = select(DiagnosisRecord).where(DiagnosisRecord.id == record_id)
                result = await db.execute(stmt)
                record = result.scalar_one_or_none()
                if record:
                    record.status = "failed"
                    record.error_message = str(e)
                    await db.commit()
            except Exception:
                # Best effort only; keep the traceback so the secondary
                # failure can be diagnosed.
                logger.error(f"更新失败状态也失败: record_id={record_id}", exc_info=True)
|
|
|
|
|
|
def _build_diagnosis_response(result_json: dict | None, is_paid: bool) -> GEODiagnosisResponse:
    """Build the API report from a stored diagnosis result.

    Args:
        result_json: The record's stored result dict, or ``None``/empty when
            nothing was stored.
        is_paid: Whether the requesting user gets the full report.

    Returns:
        A ``GEODiagnosisResponse``. Without a stored result an empty "danger"
        report with score 0 is returned. For non-paid users only dimensions
        named in ``_FREE_TIER_DIMENSIONS`` and recommendations with priority
        ``"P0"`` are kept. ``is_full_report`` mirrors ``is_paid``.
    """
    if not result_json:
        return GEODiagnosisResponse(
            overall_score=0,
            health_level="danger",
            health_level_label=get_health_level_label("danger"),
            dimensions=[],
            recommendations=[],
            is_full_report=is_paid,
        )

    # ``or []`` also covers keys that are present but hold None, which would
    # otherwise break the filters below.
    dimensions = result_json.get("dimensions") or []
    recommendations = result_json.get("recommendations") or []
    if not is_paid:
        dimensions = [d for d in dimensions if d.get("name") in _FREE_TIER_DIMENSIONS]
        recommendations = [r for r in recommendations if r.get("priority") == "P0"]

    health_level = result_json.get("health_level") or "danger"
    # When no label was stored, derive it from the actual health level rather
    # than always using the "danger" label, so level and label stay in sync.
    health_level_label = result_json.get("health_level_label")
    if health_level_label is None:
        health_level_label = get_health_level_label(health_level)

    return GEODiagnosisResponse(
        overall_score=result_json.get("overall_score", 0),
        health_level=health_level,
        health_level_label=health_level_label,
        dimensions=dimensions,
        recommendations=recommendations,
        is_full_report=is_paid,
    )
|
|
|
|
|
|
async def _find_recent_completed(
    db: AsyncSession,
    brand_id: uuid.UUID,
    hours: int = 24,
    diagnosis_type: str = "geo",
) -> DiagnosisRecord | None:
    """Return the brand's most recently completed diagnosis, if recent enough.

    Args:
        db: Async database session used for the query.
        brand_id: Brand whose records are searched.
        hours: Size of the look-back window, in hours, measured from now (UTC).
        diagnosis_type: Only records of this type are considered. Defaults to
            ``"geo"``, matching the other GEO queries in this module, so a
            completed record of another type is never returned as a GEO task.

    Returns:
        The record with the latest ``completed_at`` inside the window, or
        ``None`` when there is none.
    """
    from datetime import timedelta

    cutoff = datetime.now(UTC) - timedelta(hours=hours)
    stmt = (
        select(DiagnosisRecord)
        .where(
            DiagnosisRecord.brand_id == brand_id,
            DiagnosisRecord.diagnosis_type == diagnosis_type,
            DiagnosisRecord.status == "completed",
            DiagnosisRecord.completed_at >= cutoff,
        )
        .order_by(DiagnosisRecord.completed_at.desc())
        .limit(1)
    )
    result = await db.execute(stmt)
    return result.scalar_one_or_none()
|
|
|
|
|
|
async def _get_brand_or_404(
    brand_id: uuid.UUID,
    current_user: User,
    db: AsyncSession,
) -> Brand:
    """Load the brand with ``brand_id`` that belongs to ``current_user``.

    The query matches on both the brand id and the user's id, so a brand
    owned by someone else is treated the same as a missing one.

    Raises:
        HTTPException: 404 when no matching brand exists.
    """
    owner_id = _to_uuid(current_user.id)
    query = select(Brand).where(Brand.id == brand_id, Brand.user_id == owner_id)
    brand = (await db.execute(query)).scalar_one_or_none()

    if brand:
        return brand

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="品牌不存在",
    )
|