geo/backend/app/main.py

202 lines
7.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
# 必须在其他模块 import 之前初始化 JSON 日志
from app.logging_config import setup_logging
setup_logging()
from fastapi.middleware.cors import CORSMiddleware
from app.api.admin import router as admin_router
from app.api.content import router as content_router
from app.api.contents import router as contents_router
from app.api.clients import router as clients_router
from app.api.agents import router as agents_router
from app.api.knowledge import router as knowledge_router
from app.api.distribution import router as distribution_router
from app.api.analytics import router as analytics_router
from app.api.lifecycle import router as lifecycle_router
from app.api.auth import router as auth_router
from app.api.citations import router as citations_router
from app.api.queries import router as queries_router
from app.api.reports import router as reports_router
from app.api.subscriptions import router as subscription_router
from app.api.alerts import router as alerts_router
from app.api.dashboard import router as dashboard_router
from app.api.brands import router as brands_router
from app.api.onboarding import router as onboarding_router
from app.config import settings
from app.database import engine, Base
from app.schemas.common import ErrorResponse, ErrorCode
from app.middleware.rate_limit import RateLimitMiddleware
from app.middleware.logging_middleware import RequestLoggingMiddleware
from app.middleware.request_id import RequestIdMiddleware
from app.middleware.metrics import MetricsMiddleware
from app.database import get_db
from app.workers.scheduler import query_scheduler
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: create the DB schema, run the query scheduler, stop it on exit.

    Startup:
      1. Import the ORM models package, then run ``Base.metadata.create_all``
         inside a transaction on ``engine``.
      2. Start ``query_scheduler``.

    Shutdown:
      The scheduler is shut down in a ``finally`` clause, so it is stopped even
      when the lifespan context exits because of an error or cancellation.
    """
    # Imported only for its side effects (presumably defining the ORM models on
    # Base.metadata so create_all sees every table). Aliased so the import does
    # not rebind the ``app`` parameter to the ``app`` package, which a bare
    # ``import app.models`` would do.
    import app.models as _models  # noqa: F401

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    query_scheduler.start()
    try:
        yield
    finally:
        await query_scheduler.shutdown()
# The ASGI application. Startup/shutdown work (schema creation, query scheduler)
# is delegated to the ``lifespan`` context manager defined above.
app = FastAPI(
    lifespan=lifespan,
    title="GEO Platform API",
    version="1.0.0",
)
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException) -> JSONResponse:
    """Render an ``HTTPException`` in the unified ``ErrorResponse`` envelope.

    The HTTP status is kept, mapped to an ``ErrorCode`` via
    ``ErrorCode.from_status``, and ``exc.detail`` is stringified into ``detail``.
    Any headers attached to the exception (e.g. ``WWW-Authenticate`` on a 401 or
    ``Retry-After`` on a 429) are forwarded to the client, as FastAPI's default
    handler does.

    NOTE(review): this is registered for ``fastapi.HTTPException``. Starlette's
    own ``HTTPException`` (its base class, raised e.g. for unmatched routes and
    405s) is not dispatched here. Register on ``starlette.exceptions.HTTPException``
    if those responses should use the same envelope.
    """
    code = ErrorCode.from_status(exc.status_code)
    return JSONResponse(
        status_code=exc.status_code,
        content=ErrorResponse(
            detail=str(exc.detail),
            code=code,
        ).model_dump(mode="json"),
        headers=exc.headers,
    )
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError) -> JSONResponse:
    """Render request-validation failures in the unified ``ErrorResponse`` envelope.

    Responds with HTTP 422, code ``ErrorCode.VALIDATION_ERROR``, a fixed
    user-facing message, and the individual validation errors under
    ``extra["errors"]``.
    """
    # exc.errors() may contain values that are not JSON-native: with pydantic v2
    # an error's "ctx" can hold the raised exception object, and "input" can be
    # raw bytes. Pass them through FastAPI's jsonable_encoder, as FastAPI's
    # default handler does. Otherwise model_dump(mode="json") can presumably fail
    # to serialize them and turn this 422 into a 500.
    from fastapi.encoders import jsonable_encoder

    return JSONResponse(
        status_code=422,
        content=ErrorResponse(
            detail="请求参数校验失败",
            code=ErrorCode.VALIDATION_ERROR,
            extra={"errors": jsonable_encoder(exc.errors())},
        ).model_dump(mode="json"),
    )
@app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    """Catch-all handler for exceptions nothing else handled.

    Logs the exception with its traceback, then returns a generic HTTP 500
    ``ErrorResponse`` whose message reveals nothing about the internal failure.
    """
    log = logging.getLogger(__name__)
    log.exception("Unhandled exception: %s", exc)

    body = ErrorResponse(
        code=ErrorCode.INTERNAL_ERROR,
        detail="服务器内部错误,请稍后重试",
    )
    return JSONResponse(content=body.model_dump(mode="json"), status_code=500)
# Allowed CORS origins come from the comma-separated CORS_ORIGINS setting.
# Surrounding whitespace is trimmed and empty entries are dropped. If nothing
# remains, fall back to http://localhost:3000.
_allow_origins = [
    origin
    for origin in (entry.strip() for entry in settings.CORS_ORIGINS.split(","))
    if origin
] or ["http://localhost:3000"]

# NOTE(review): allow_credentials=True combined with a "*" entry in CORS_ORIGINS
# presumably makes Starlette echo back any request Origin for credentialed
# requests. Confirm against the installed Starlette version and keep explicit
# origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_allow_origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Security response headers
@app.middleware("http")
async def add_security_headers(request: Request, call_next):
    """Set baseline security headers on every response produced by inner layers.

    - ``X-Content-Type-Options: nosniff`` disables MIME sniffing.
    - ``X-Frame-Options: DENY`` forbids framing the responses.
    - ``X-XSS-Protection: 0`` explicitly disables the legacy browser XSS
      auditor. The previous ``1; mode=block`` value is deprecated: modern
      browsers ignore it, and in browsers that still honour it the auditor can
      be abused to introduce XSS/XS-Leak issues. OWASP and MDN recommend ``0``
      (use Content-Security-Policy instead).
    - ``Referrer-Policy: strict-origin-when-cross-origin`` limits what the
      Referer header reveals cross-origin.
    """
    response = await call_next(request)
    response.headers["X-Content-Type-Options"] = "nosniff"
    response.headers["X-Frame-Options"] = "DENY"
    response.headers["X-XSS-Protection"] = "0"
    response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
    return response
# Middleware registration order matters. Each add_middleware call (and each
# @app.middleware("http") function) is inserted at the FRONT of Starlette's
# middleware list, so the LAST one registered is the OUTERMOST and sees the
# request first.
# Counting CORSMiddleware and the add_security_headers HTTP middleware, which
# are registered earlier in this file, a request flows through:
#   RequestId → Metrics → RateLimit → RequestLogging → SecurityHeaders → CORS → routes
# NOTE(review): RateLimit and RequestLogging sit outside CORS and SecurityHeaders,
# so any response they short-circuit themselves (presumably a rate-limit
# rejection) will not carry CORS or security headers. Confirm whether browsers
# need to read those responses.
app.add_middleware(RequestLoggingMiddleware)
app.add_middleware(RateLimitMiddleware)
app.add_middleware(MetricsMiddleware)
app.add_middleware(RequestIdMiddleware)
def _register_routers(application: FastAPI) -> None:
    """Mount every feature router on *application*.

    Each entry is ``(router, prefix, tags)``. An entry with prefix ``""`` and
    tags ``None`` passes FastAPI's own defaults, i.e. it adds no prefix or tags
    of its own (those routers presumably declare them internally). The list
    order is the registration order, which also decides route-matching
    precedence, so keep it stable.
    """
    mounts = (
        (auth_router, "/api/v1/auth", ["认证"]),
        (queries_router, "/api/v1/queries", ["查询词"]),
        (citations_router, "/api/v1/citations", ["引用数据"]),
        (reports_router, "/api/v1/reports", ["报告"]),
        (subscription_router, "", None),
        (admin_router, "", None),
        (agents_router, "/api/v1/agents", ["Agent管理"]),
        (lifecycle_router, "/api/v1/lifecycle", ["lifecycle"]),
        (knowledge_router, "/api/v1/knowledge", ["知识库"]),
        (content_router, "/api/v1/content", ["内容生产"]),
        (contents_router, "/api/v1/contents", ["内容管理"]),
        (clients_router, "/api/v1/clients", ["客户管理"]),
        (distribution_router, "/api/v1/distribution", ["内容分发"]),
        (analytics_router, "/api/v1/analytics", ["监测优化"]),
        (alerts_router, "/api/v1/alerts", ["告警通知"]),
        (dashboard_router, "/api/v1/dashboard", ["仪表盘"]),
        (brands_router, "/api/v1/brands", ["品牌管理"]),
        (onboarding_router, "/api/v1", None),
    )
    for router, prefix, tags in mounts:
        application.include_router(router, prefix=prefix, tags=tags)


_register_routers(app)
@app.get("/health", tags=["可观测性"])
async def health_check():
    """Liveness probe: report that the service process is up and serving.

    Touches no external dependency (no database, no Redis). It only returns a
    fixed status and the current UTC time in ISO-8601 form.
    """
    now_utc = datetime.now(timezone.utc)
    return {"status": "healthy", "timestamp": now_utc.isoformat()}
async def _probe_database(db: AsyncSession) -> bool:
    """Return True if the database answers ``SELECT 1`` on *db*; log and return False otherwise."""
    try:
        await db.execute(text("SELECT 1"))
    except Exception:
        logging.getLogger(__name__).warning(
            "Readiness check: database unavailable", exc_info=True
        )
        return False
    return True


async def _probe_redis() -> bool:
    """Return True if Redis at ``REDIS_URL`` answers PING within 2s; log and return False otherwise.

    The temporary client is always closed, including when ``ping`` fails, so a
    failing probe does not leak connections.
    """
    import redis.asyncio as aioredis  # type: ignore
    # Resolved at call time (as before) rather than via the module-level import.
    from app.config import settings as _settings

    log = logging.getLogger(__name__)
    redis_client = None
    try:
        # socket_timeout bounds the PING itself, not just the TCP connect, so an
        # unresponsive server cannot hang the probe.
        redis_client = aioredis.from_url(
            _settings.REDIS_URL,
            socket_connect_timeout=2,
            socket_timeout=2,
        )
        await redis_client.ping()
    except Exception:
        log.warning("Readiness check: redis unavailable", exc_info=True)
        return False
    finally:
        if redis_client is not None:
            try:
                await redis_client.aclose()
            except Exception:
                # Best-effort cleanup; the probe result is already decided.
                log.debug("Readiness check: failed to close redis client", exc_info=True)
    return True


@app.get("/ready", tags=["可观测性"])
async def readiness_check(db: AsyncSession = Depends(get_db)):
    """Readiness probe: report whether the database and Redis are usable.

    Intended for Kubernetes readinessProbe / Docker healthcheck; no
    authentication is required.

    Returns:
        HTTP 200 with ``status="ready"`` when both checks pass, otherwise HTTP
        503 with ``status="not_ready"``. Per-dependency results ("ok"/"error")
        are under ``checks``, alongside a UTC ISO-8601 ``timestamp``.
    """
    db_ok = await _probe_database(db)
    redis_ok = await _probe_redis()
    all_ok = db_ok and redis_ok
    return JSONResponse(
        status_code=200 if all_ok else 503,
        content={
            "status": "ready" if all_ok else "not_ready",
            "checks": {
                "database": "ok" if db_ok else "error",
                "redis": "ok" if redis_ok else "error",
            },
            "timestamp": datetime.now(timezone.utc).isoformat(),
        },
    )