46 lines
1.9 KiB
Python
46 lines
1.9 KiB
Python
import pytest
|
|
from unittest.mock import patch, AsyncMock, MagicMock
|
|
from sqlalchemy.exc import SQLAlchemyError, OperationalError
|
|
|
|
from app.database import check_db_connection
|
|
|
|
|
|
class TestDatabaseExceptionHandling:
    """Tests for the exception-handling behavior of ``check_db_connection`` in app/database.py."""

    @pytest.fixture
    def mock_session(self):
        """Patch ``app.database.AsyncSessionLocal`` and yield the mocked session.

        The patched factory is wired for ``async with AsyncSessionLocal() as s``
        usage: calling it returns an object whose ``__aenter__`` yields the
        ``AsyncMock`` session returned here. The patch stays active for the
        duration of the test and is undone on fixture teardown.
        """
        with patch('app.database.AsyncSessionLocal') as mock_session_local:
            session = AsyncMock()
            mock_session_local.return_value.__aenter__.return_value = session
            yield session

    @pytest.mark.asyncio
    async def test_check_db_connection_handles_sqlalchemy_error(self, mock_session):
        """An OperationalError (a SQLAlchemyError) from the query surfaces as ConnectionError."""
        mock_session.execute.side_effect = OperationalError(
            statement="SELECT 1",
            params={},
            orig=Exception("Connection refused"),
        )

        with pytest.raises(ConnectionError, match="Failed to connect to database"):
            await check_db_connection()

    @pytest.mark.asyncio
    async def test_check_db_connection_handles_generic_exception(self, mock_session):
        """A non-SQLAlchemy error from the query propagates as RuntimeError."""
        mock_session.execute.side_effect = RuntimeError("Unexpected error")

        with pytest.raises(RuntimeError):
            await check_db_connection()

    @pytest.mark.asyncio
    async def test_check_db_connection_success(self, mock_session):
        """check_db_connection returns True when the query succeeds."""
        result = await check_db_connection()

        assert result is True
        # Guard against an implementation that reports success without
        # actually querying the database through the session.
        mock_session.execute.assert_awaited()
|