66 lines
2.0 KiB
Python
66 lines
2.0 KiB
Python
import logging
from typing import AsyncIterator

from sqlalchemy import JSON, text
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import declarative_base
from sqlalchemy.types import TypeDecorator

from app.config import settings
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class JSONType(TypeDecorator):
    """Portable JSON column type.

    Uses PostgreSQL's binary ``JSONB`` type when the active dialect is
    PostgreSQL, and falls back to the generic ``JSON`` type on every other
    database.
    """

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Return the concrete type implementation to use for *dialect*."""
        concrete_type = JSONB() if dialect.name == "postgresql" else JSON()
        return dialect.type_descriptor(concrete_type)
|
||
|
||
|
||
# Application-wide async engine built from the configured database URL.
engine = create_async_engine(
    settings.DATABASE_URL,
    # --- connection pool sizing ---
    pool_size=10,  # number of connections kept in the pool
    max_overflow=20,  # extra connections allowed beyond pool_size
    # --- connection pool lifetime / waiting ---
    pool_timeout=30,  # seconds to wait for a free connection
    pool_recycle=3600,  # recycle connections after 1 hour (3600 s)
    pool_pre_ping=True,  # ping a connection before use to verify it is still valid
    # --- diagnostics ---
    echo=False,  # SQL echo disabled for production
)
|
||
|
||
# Factory for AsyncSession objects bound to ``engine``.
# - expire_on_commit=False: ORM objects keep their loaded attribute values
#   after commit instead of being expired.
# - autoflush=False: pending changes are not flushed automatically before
#   queries; callers flush/commit explicitly.
# The legacy ``autocommit=False`` argument was dropped. In SQLAlchemy 2.x it
# exists only for backwards compatibility and must stay at its default of
# False, so passing it had no effect.
AsyncSessionLocal = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
    autoflush=False,
)

# Declarative base class for ORM models.
Base = declarative_base()
|
||
|
||
|
||
async def get_db() -> AsyncIterator[AsyncSession]:
    """Yield one ``AsyncSession`` from ``AsyncSessionLocal``, then close it.

    This is an async generator, so its return type is an async iterator of
    sessions rather than a bare ``AsyncSession``. The session is closed when
    the ``async with`` block exits. That happens after the consumer resumes
    or closes the generator, including when an exception is thrown into it.

    NOTE(review): presumably consumed as a ``yield``-style request dependency;
    confirm against callers.
    """
    async with AsyncSessionLocal() as session:
        # Leaving the ``async with`` already closes the session, so the
        # previous explicit ``finally: await session.close()`` was redundant.
        yield session
|
||
|
||
|
||
async def check_db_connection() -> bool:
    """Check whether the database connection works by running ``SELECT 1``.

    Returns:
        True when the probe query succeeds. This function never returns
        False; failures are signalled by raising instead.

    Raises:
        ConnectionError: if SQLAlchemy raises any ``SQLAlchemyError`` while
            opening the session or executing the probe. The original error
            is chained as ``__cause__``.
        Exception: any other unexpected error is logged and re-raised as-is.
    """
    try:
        async with AsyncSessionLocal() as session:
            await session.execute(text("SELECT 1"))
    except sqlalchemy_exc.SQLAlchemyError as e:
        # logger.exception logs at ERROR level with the traceback attached,
        # equivalent to logger.error(..., exc_info=True), but formats lazily.
        logger.exception("Database connection error: %s", e)
        raise ConnectionError(f"Failed to connect to database: {e}") from e
    except Exception as e:
        logger.exception("Unexpected database error: %s", e)
        raise
    # The session has been closed cleanly (still inside the try above), so
    # the probe succeeded.
    return True
|