84 lines
2.4 KiB
Python
84 lines
2.4 KiB
Python
import logging
|
|
from contextlib import asynccontextmanager
|
|
|
|
from fastapi import FastAPI
|
|
|
|
# Root logger setup for the process: INFO and above, each line prefixed with a
# timestamp, the logger name and the level. logging.basicConfig is a no-op if
# the root logger already has handlers. It runs before the app.* imports below,
# presumably so those modules' import-time logging is already formatted.
_LOG_FORMAT = "%(asctime)s [%(name)s] %(levelname)s: %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
from app.api.admin import router as admin_router
|
|
from app.api.auth import router as auth_router
|
|
from app.api.citations import router as citations_router
|
|
from app.api.queries import router as queries_router
|
|
from app.api.reports import router as reports_router
|
|
from app.api.subscriptions import router as subscription_router
|
|
from app.config import settings
|
|
from app.database import engine, Base
|
|
from app.middleware.rate_limit import RateLimitMiddleware
|
|
from app.middleware.logging_middleware import RequestLoggingMiddleware
|
|
from app.workers.scheduler import query_scheduler
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run application startup work, yield while serving, then clean up.

    Startup:
        1. Import ``app.models`` (presumably so the ORM model classes register
           themselves on ``Base.metadata`` before tables are created).
        2. Create any missing tables via ``Base.metadata.create_all``.
        3. Start the background ``query_scheduler``.

    Shutdown (always runs, even if serving ends with an exception):
        1. Await ``query_scheduler.shutdown()``.
        2. Dispose of the async engine's connection pool.

    Args:
        app: The FastAPI application instance; not used by this function but
            required by FastAPI's lifespan signature.
    """
    # Use ``from ... import`` rather than ``import app.models``: the latter
    # binds the local name ``app`` and would silently replace this function's
    # ``app`` parameter with the ``app`` package.
    from app import models as _models  # noqa: F401  (imported for side effects)

    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)

    query_scheduler.start()
    try:
        yield
    finally:
        # Nested so the engine is still disposed if scheduler shutdown fails.
        try:
            await query_scheduler.shutdown()
        finally:
            await engine.dispose()
|
|
|
|
|
|
# The ASGI application object. Startup/shutdown work (table creation and the
# query scheduler) is handled by the ``lifespan`` context manager.
app = FastAPI(
    lifespan=lifespan,
    title="GEO Platform API",
    version="1.0.0",
)
|
|
|
|
# settings.CORS_ORIGINS is split on commas, each entry is trimmed, and blank
# entries are dropped. If nothing remains, fall back to http://localhost:3000.
_allow_origins = [
    entry
    for entry in map(str.strip, settings.CORS_ORIGINS.split(","))
    if entry
] or ["http://localhost:3000"]
|
|
|
|
# Middleware registration order matters: Starlette inserts each newly added
# middleware at the FRONT of the stack, so the LAST one added runs OUTERMOST.
# The registrations below produce, from outermost to innermost:
#
#   RequestLoggingMiddleware -> add_security_headers -> CORSMiddleware
#   -> RateLimitMiddleware -> routes
#
# Why:
# - CORS must wrap the rate limiter. Otherwise any response the limiter returns
#   on its own (presumably a 429 -- confirm in app.middleware.rate_limit) lacks
#   CORS headers, and browsers report an opaque CORS error instead of it.
#   CORS preflight (OPTIONS) requests are answered here before the limiter
#   counts them.
# - Security headers wrap CORS, so preflight and limiter responses get them too.
# - Request logging stays outermost so it observes every request/response.

# Rate limiting: innermost of the custom middleware (registered first).
app.add_middleware(RateLimitMiddleware)

app.add_middleware(
    CORSMiddleware,
    allow_origins=_allow_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Security response headers.
@app.middleware("http")
async def add_security_headers(request, call_next):
    """Attach baseline security headers to every outgoing response.

    Args:
        request: The incoming request, forwarded unchanged.
        call_next: Awaitable that invokes the inner middleware stack / route.

    Returns:
        The downstream response, with the headers below set. Any value a route
        set for the same header is overwritten.
    """
    response = await call_next(request)
    response.headers["X-Content-Type-Options"] = "nosniff"
    response.headers["X-Frame-Options"] = "DENY"
    # NOTE(review): X-XSS-Protection is deprecated and ignored by current
    # browsers; OWASP now recommends "0" together with a Content-Security-Policy.
    response.headers["X-XSS-Protection"] = "1; mode=block"
    response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
    return response


# Request logging: registered last, so it is the outermost middleware.
app.add_middleware(RequestLoggingMiddleware)
|
|
|
|
# Versioned API routers, each mounted under an explicit /api/v1 prefix with an
# OpenAPI tag. The tags are Chinese labels: auth, queries, citations, reports.
# Registration order is kept as-is because route matching follows it.
_VERSIONED_ROUTERS = (
    (auth_router, "/api/v1/auth", "认证"),
    (queries_router, "/api/v1/queries", "查询词"),
    (citations_router, "/api/v1/citations", "引用数据"),
    (reports_router, "/api/v1/reports", "报告"),
)
for _router, _prefix, _tag in _VERSIONED_ROUTERS:
    app.include_router(_router, prefix=_prefix, tags=[_tag])

# Mounted without a prefix or tags here. Presumably these routers declare
# their own on the APIRouter -- confirm in app.api.subscriptions / app.api.admin.
for _router in (subscription_router, admin_router):
    app.include_router(_router)
|
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Report that the API process is up and serving requests.

    Always returns the same constant payload. It does not probe the database
    or the query scheduler.
    """
    status = "ok"
    return {"status": status}
|