# geo/backend/app/database.py
# Async SQLAlchemy engine, session factory and database helpers.

import logging
from typing import AsyncGenerator

from sqlalchemy import JSON, text
from sqlalchemy import exc as sqlalchemy_exc
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import declarative_base
from sqlalchemy.types import TypeDecorator

from app.config import settings
logger = logging.getLogger(__name__)
class JSONType(TypeDecorator):
    """Portable JSON column type.

    On the PostgreSQL dialect the column is rendered as the binary ``JSONB``
    type. Every other dialect gets SQLAlchemy's generic ``JSON`` type.
    """

    impl = JSON
    # The decorator holds no per-instance state, so compiled SQL for it can
    # be cached safely.
    cache_ok = True

    def load_dialect_impl(self, dialect):
        chosen = JSONB() if dialect.name == "postgresql" else JSON()
        return dialect.type_descriptor(chosen)
_db_url = settings.DATABASE_URL

# Extra keyword arguments handed to the DBAPI driver's connect() call.
_connect_args = {}
if _db_url.startswith("postgresql+asyncpg"):
    # NOTE(review): this turns TLS off for asyncpg unconditionally, whatever
    # DATABASE_URL requests. Acceptable on a local/private network; confirm it
    # is intended before pointing this at a remote database.
    _connect_args = {"ssl": False}

# Queue-pool sizing options. For an in-memory SQLite database SQLAlchemy
# selects StaticPool, which rejects pool_size / max_overflow / pool_timeout
# with a TypeError at engine creation. JSONType falls back to plain JSON on
# non-PostgreSQL dialects, so such backends are presumably meant to work.
# The sizing options are therefore skipped only for in-memory SQLite. The
# test below is the inverse of SQLAlchemy's own "is this a file database"
# rule, so PostgreSQL and file-based SQLite are configured exactly as before.
_parsed_url = make_url(_db_url)
_is_sqlite_memory_db = _parsed_url.get_backend_name() == "sqlite" and (
    not _parsed_url.database
    or _parsed_url.database == ":memory:"
    or _parsed_url.query.get("mode") == "memory"
)
_pool_sizing = (
    {}
    if _is_sqlite_memory_db
    else {"pool_size": 10, "max_overflow": 20, "pool_timeout": 30}
)

engine = create_async_engine(
    _db_url,
    pool_recycle=3600,  # replace pooled connections once they are 3600 s old
    pool_pre_ping=True,  # test connections on checkout and drop stale ones
    echo=False,
    connect_args=_connect_args,
    **_pool_sizing,
)
# Session factory bound to the module-level engine; used by get_db() and
# check_db_connection().
#   expire_on_commit=False: loaded ORM attributes stay readable after commit
#                           instead of being expired and re-fetched.
#   autoflush=False:        pending changes are not flushed automatically
#                           before each query; flush()/commit() do it.
AsyncSessionLocal = async_sessionmaker(
    bind=engine,
    class_=AsyncSession,
    autocommit=False,
    autoflush=False,
    expire_on_commit=False,
)

# Declarative base class that ORM models inherit from.
Base = declarative_base()
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield one ``AsyncSession`` for the duration of a unit of work.

    This is an async generator, so it is meant to be consumed as a
    generator-style dependency (presumably FastAPI ``Depends``; confirm
    against callers). When the consumer finishes, normally or with an
    exception, leaving the ``async with`` block closes the session.

    Yields:
        A session created by ``AsyncSessionLocal``.
    """
    async with AsyncSessionLocal() as session:
        # Exiting the ``async with`` closes the session, so no separate
        # ``finally: await session.close()`` is required.
        yield session
async def check_db_connection() -> bool:
    """Probe the database by executing ``SELECT 1`` in a fresh session.

    Returns:
        True when the probe succeeds. Failure is never reported as False;
        it is raised instead.

    Raises:
        ConnectionError: when SQLAlchemy raises any ``SQLAlchemyError`` while
            opening the session or running the probe. The original error is
            chained as ``__cause__``.
        Exception: any other unexpected error is logged and re-raised
            unchanged.
    """
    try:
        async with AsyncSessionLocal() as probe_session:
            await probe_session.execute(text("SELECT 1"))
    except sqlalchemy_exc.SQLAlchemyError as err:
        logger.exception(f"Database connection error: {err}")
        raise ConnectionError(f"Failed to connect to database: {err}") from err
    except Exception as err:
        logger.exception(f"Unexpected database error: {err}")
        raise
    return True