72 lines
1.9 KiB
Python
72 lines
1.9 KiB
Python
import logging
from typing import AsyncGenerator

from sqlalchemy import text, JSON
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlalchemy.orm import declarative_base
from sqlalchemy.types import TypeDecorator

from app.config import settings
|
|
|
|
# Module-level logger, named after this module so log records show their origin.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
class JSONType(TypeDecorator):
    """Portable JSON column type.

    Resolves to PostgreSQL's ``JSONB`` when the active dialect is
    ``postgresql``, and to SQLAlchemy's generic ``JSON`` type on every other
    backend.
    """

    impl = JSON
    # Declares this type safe for SQLAlchemy's compiled-statement cache; the
    # class defines no constructor arguments that would affect compilation.
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Return the dialect-specific type descriptor for this column."""
        is_postgres = dialect.name == "postgresql"
        backend_type = JSONB() if is_postgres else JSON()
        return dialect.type_descriptor(backend_type)
|
|
|
|
|
|
_db_url = settings.DATABASE_URL

# Extra keyword arguments forwarded to the DBAPI connect() call. Only the
# asyncpg driver receives any: TLS is explicitly switched off for it.
# NOTE(review): ssl=False disables encryption to the database server; confirm
# this is intended for every deployment environment.
_connect_args = (
    {"ssl": False} if _db_url.startswith("postgresql+asyncpg") else {}
)

# Process-wide async engine with a bounded connection pool.
engine = create_async_engine(
    _db_url,
    echo=False,
    connect_args=_connect_args,
    # Pool sizing: 10 persistent connections plus up to 20 overflow ones;
    # wait at most 30 s for a free connection, recycle connections after
    # 3600 s, and test each connection on checkout so stale ones are replaced.
    pool_size=10,
    max_overflow=20,
    pool_timeout=30,
    pool_recycle=3600,
    pool_pre_ping=True,
)
|
|
|
|
# Factory producing AsyncSession objects bound to ``engine``. Loaded ORM
# objects are not expired on commit, and queries do not trigger an implicit
# autoflush.
AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    autocommit=False,
    autoflush=False,
    expire_on_commit=False,
)

# Declarative base class for the application's ORM models.
Base = declarative_base()
|
|
|
|
|
|
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a fresh database session and close it afterwards.

    This is an async generator: the caller receives one ``AsyncSession``
    created by ``AsyncSessionLocal``. When the generator is resumed, closed,
    or has an exception thrown into it, the ``async with`` block exits and
    closes the session, which releases its connection back to the pool.

    Yields:
        A new ``AsyncSession``.
    """
    # AsyncSession's async context manager already calls close() on exit,
    # which also rolls back any uncommitted transaction, so an explicit
    # ``finally: await session.close()`` would only close it a second time.
    async with AsyncSessionLocal() as session:
        yield session
|
|
|
|
|
|
async def check_db_connection() -> bool:
    """Check that the database is reachable by executing ``SELECT 1``.

    Returns:
        ``True`` when the probe query succeeds. This function never returns
        ``False``; every failure is reported by raising.

    Raises:
        ConnectionError: if SQLAlchemy raises a ``SQLAlchemyError`` while
            opening the session or running the probe. The original error is
            chained as ``__cause__``.
        Exception: any other unexpected error is logged and re-raised unchanged.
    """
    try:
        # Closing the session on exit is part of the probe: an error raised
        # there is handled by the same except clauses below.
        async with AsyncSessionLocal() as session:
            await session.execute(text("SELECT 1"))
    except sqlalchemy_exc.SQLAlchemyError as e:
        # logger.exception logs at ERROR level with the traceback attached.
        logger.exception("Database connection error: %s", e)
        raise ConnectionError(f"Failed to connect to database: {e}") from e
    except Exception as e:
        logger.exception("Unexpected database error: %s", e)
        raise
    return True
|