"""Async SQLAlchemy engine, session factory and declarative base for the app."""

import logging
from typing import AsyncIterator

from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlalchemy.orm import declarative_base
from sqlalchemy import text, JSON
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy.types import TypeDecorator
from sqlalchemy.dialects.postgresql import JSONB

from app.config import settings

logger = logging.getLogger(__name__)


class JSONType(TypeDecorator):
    """A JSON type that uses JSONB on PostgreSQL and JSON on other databases."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Pick the concrete column type for *dialect*: JSONB on PostgreSQL, JSON elsewhere."""
        if dialect.name == "postgresql":
            return dialect.type_descriptor(JSONB())
        return dialect.type_descriptor(JSON())


_db_url = settings.DATABASE_URL

# Extra DBAPI connect() arguments; only populated for the asyncpg driver.
_connect_args = {}
if _db_url.startswith("postgresql+asyncpg"):
    # Disables TLS for asyncpg connections.
    # NOTE(review): this is unconditional — confirm it is intended outside local dev.
    _connect_args = {"ssl": False}

# NOTE(review): pool_size/max_overflow/pool_timeout are rejected by pools that
# do not support sizing (e.g. the StaticPool SQLAlchemy may pick for an
# in-memory SQLite URL). Confirm DATABASE_URL never points at such a backend.
engine = create_async_engine(
    _db_url,
    pool_size=10,
    max_overflow=20,
    pool_timeout=30,
    pool_recycle=3600,  # recycle connections older than one hour
    pool_pre_ping=True,  # test each pooled connection before handing it out
    echo=False,
    connect_args=_connect_args,
)

AsyncSessionLocal = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,  # keep ORM objects usable after commit
    autoflush=False,
    autocommit=False,
)

Base = declarative_base()


async def get_db() -> AsyncIterator[AsyncSession]:
    """Yield one ``AsyncSession`` and close it when the consumer is done.

    This is written as an async generator, presumably to be used as a
    generator-style dependency (e.g. FastAPI ``Depends``) — confirm against
    callers. The ``async with`` block closes the session when the generator
    finishes, including when an exception is thrown into it. An explicit
    ``close()`` is therefore unnecessary.
    """
    async with AsyncSessionLocal() as session:
        yield session


async def check_db_connection() -> bool:
    """Check that the database is reachable by executing ``SELECT 1``.

    Returns:
        ``True`` when the probe query succeeds. Failure is signalled by
        raising, never by returning ``False``.

    Raises:
        ConnectionError: SQLAlchemy raised an error while connecting or
            executing the probe. The original error is chained as ``__cause__``.
        Exception: Any other unexpected error is logged and re-raised unchanged.
    """
    try:
        async with AsyncSessionLocal() as session:
            await session.execute(text("SELECT 1"))
            return True
    except sqlalchemy_exc.SQLAlchemyError as e:
        logger.exception("Database connection error: %s", e)
        raise ConnectionError(f"Failed to connect to database: {e}") from e
    except Exception as e:
        logger.exception("Unexpected database error: %s", e)
        raise