geo/tests/conftest.py

123 lines
3.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys
import uuid
from datetime import datetime
from pathlib import Path
from typing import AsyncIterator
from unittest.mock import AsyncMock, patch

# Make ``backend/`` importable so the ``app`` package imported below resolves.
sys.path.insert(0, str(Path(__file__).parent.parent / "backend"))

import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlalchemy.pool import StaticPool

from app.main import app
from app.api.deps import get_current_user
from app.database import Base, get_db
from app.services.auth import create_access_token
# ---------------------------------------------------------------------------
# Global mocks: keep the real scheduler / Playwright browser from starting
# ---------------------------------------------------------------------------
@pytest.fixture(scope="session", autouse=True)
def mock_scheduler():
    """Swap ``app.main.query_scheduler`` for a mock for the whole test session.

    Because the fixture is ``autouse`` and session-scoped, every test runs with
    the mock in place; the patch is undone once the session finishes.
    ``start()`` is replaced by a no-op callable and ``shutdown`` by an
    ``AsyncMock`` so that awaiting it works.

    NOTE(review): only the scheduler is patched here; presumably any Playwright
    browser is launched solely through it — confirm.
    """
    scheduler_patch = patch("app.main.query_scheduler")
    with scheduler_patch as fake_scheduler:
        # No-op start; awaitable shutdown.
        fake_scheduler.configure_mock(start=lambda: None, shutdown=AsyncMock())
        yield
# ---------------------------------------------------------------------------
# In-memory database fixtures for integration tests
# ---------------------------------------------------------------------------
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"
@pytest_asyncio.fixture
async def test_engine():
    """Yield a new in-memory SQLite async engine with every table already created.

    The fixture has the default (function) scope, so each test gets its own
    empty database built from ``Base.metadata``. ``StaticPool`` keeps a single
    underlying connection so the in-memory database is shared by everything
    using this engine. ``check_same_thread=False`` lets that sqlite3 connection
    be used from a thread other than the one that created it. The engine is
    disposed on teardown.
    """
    engine_options = {
        "echo": False,
        "future": True,
        "poolclass": StaticPool,
        "connect_args": {"check_same_thread": False},
    }
    engine = create_async_engine(TEST_DATABASE_URL, **engine_options)

    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield engine

    await engine.dispose()
@pytest_asyncio.fixture
async def test_session(test_engine) -> AsyncIterator[AsyncSession]:
    """Yield an ``AsyncSession`` bound to the per-test in-memory engine.

    This is an async-generator fixture, so its return type is an async iterator
    of sessions rather than a bare ``AsyncSession``.

    ``expire_on_commit=False`` keeps loaded attributes readable after a commit.
    This matters under asyncio, where refreshing an expired attribute on access
    would need implicit IO. ``autoflush=False`` means queries do not flush
    pending changes automatically; changes are sent on an explicit ``flush()``
    or ``commit()``. The session is closed when the ``async with`` block exits
    at teardown.
    """
    # ``autocommit=False`` was dropped: SQLAlchemy 2.x Sessions only accept
    # False for it (the default), so passing it was a no-op.
    session_factory = async_sessionmaker(
        test_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )
    async with session_factory() as session:
        yield session
@pytest_asyncio.fixture
async def override_get_db(test_session):
    """Point the app's ``get_db`` dependency at the per-test session.

    While this fixture is active, every request that depends on ``get_db``
    receives ``test_session``. The fixture also yields that same session to the
    test, and removes the override on teardown.
    """
    async def _yield_test_session():
        yield test_session

    app.dependency_overrides[get_db] = _yield_test_session
    yield test_session
    app.dependency_overrides.pop(get_db, None)
# ---------------------------------------------------------------------------
# Mock-user fixtures for unit tests
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_user():
    """Build a stand-in for an authenticated user with fixed profile attributes.

    The user has a fixed UUID, e-mail, display name, a "free" plan allowing
    5 queries, an active flag, and a ``created_at`` stamped at fixture time.

    NOTE(review): this is an ``AsyncMock``, so any attribute not configured here
    is auto-created as a child mock, and calling it returns a coroutine. If the
    code under test calls synchronous methods on the real user model, a
    ``MagicMock`` or a spec'd mock may be more faithful — confirm.
    NOTE(review): ``created_at`` is a naive local datetime; if the real model
    stores timezone-aware values this may need ``tz=`` — confirm.
    """
    user = AsyncMock()
    # configure_mock sets ``name`` as an ordinary attribute; passing name= to
    # the Mock constructor would instead rename the mock itself.
    user.configure_mock(
        id=uuid.UUID("12345678-1234-1234-1234-123456789abc"),
        email="test@example.com",
        name="Test User",
        plan="free",
        max_queries=5,
        is_active=True,
        created_at=datetime.now(),
    )
    return user
@pytest.fixture
def override_get_current_user(mock_user):
    """Make the ``get_current_user`` dependency resolve to ``mock_user``.

    Requests handled while this fixture is active skip the real authentication
    dependency. The override is removed again on teardown.
    """
    async def _return_mock_user():
        return mock_user

    app.dependency_overrides[get_current_user] = _return_mock_user
    yield
    app.dependency_overrides.pop(get_current_user, None)
@pytest.fixture
def auth_token(mock_user):
    """Return a JWT access token whose ``sub`` claim is the mock user's id string.

    NOTE(review): the mock user is not inserted into the test database; unless
    ``override_get_current_user`` is also active, the real dependency may not
    find this id — confirm against how ``get_current_user`` resolves ``sub``.
    """
    claims = {"sub": str(mock_user.id)}
    return create_access_token(data=claims)
@pytest.fixture
def auth_headers(auth_token):
    """Build a request-header dict carrying ``auth_token`` as a Bearer credential."""
    header_value = f"Bearer {auth_token}"
    return {"Authorization": header_value}
@pytest_asyncio.fixture
async def async_client():
    """Yield an ``httpx.AsyncClient`` that sends requests directly into the app.

    Requests are dispatched in-process through ``ASGITransport`` with base URL
    ``http://test``; the client is closed on teardown. This fixture installs no
    dependency overrides itself, so request ``override_get_db`` /
    ``override_get_current_user`` alongside it as needed.

    NOTE(review): ``ASGITransport`` is not expected to run the app's lifespan
    startup/shutdown hooks — confirm if a test relies on them.
    """
    client = AsyncClient(transport=ASGITransport(app=app), base_url="http://test")
    async with client:
        yield client