123 lines
4.0 KiB
Python
123 lines
4.0 KiB
Python
import uuid
|
||
from datetime import datetime, timezone, timedelta
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from app.models.query import Query
|
||
from app.models.user import User
|
||
from app.services.auth import hash_password
|
||
from app.workers.citation_engine import CitationEngine
|
||
from app.workers.scheduler import QueryScheduler
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 1. Scheduler startup / shutdown
|
||
# ---------------------------------------------------------------------------
|
||
@pytest.mark.asyncio
async def test_scheduler_start_stop():
    """Starting registers the ``check_queries`` job; shutdown closes the engine.

    The engine is replaced with an ``AsyncMock`` so that the awaited
    ``engine.close()`` call made during shutdown can be verified.
    """
    scheduler = QueryScheduler()
    scheduler.engine = AsyncMock()

    scheduler.start()
    try:
        # Verify the scheduled job was added
        job = scheduler.scheduler.get_job("check_queries")
        assert job is not None
        assert job.name == "检查并执行到期的查询任务"
    finally:
        # Always shut down, even when an assertion above fails, so a started
        # scheduler is never left running for the tests that follow.
        await scheduler.shutdown()

    # Verify engine.close was awaited
    scheduler.engine.close.assert_awaited_once()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 2. Query filtering: only queries with status=active and next_query_at <= now are selected
|
||
# ---------------------------------------------------------------------------
|
||
@pytest.mark.asyncio
async def test_scheduler_query_filtering(test_session):
    """Only the active, already-due query is handed to the engine.

    Three queries are stored for one user: active and overdue, active but
    due in the future, and paused but overdue. After one call to
    ``check_and_execute_queries`` the mocked engine must have received
    exactly one ``execute_query`` call, for the ``"overdue"`` keyword.
    """
    # The queries need an owning user row.
    owner = User(
        email="sched@test.com",
        password_hash=hash_password("pass"),
        name="Scheduler",
    )
    test_session.add(owner)
    await test_session.commit()
    await test_session.refresh(owner)

    now = datetime.now(timezone.utc)
    one_hour_ago = now - timedelta(hours=1)
    tomorrow = now + timedelta(days=1)

    # (keyword, target_brand, status, next_query_at)
    fixtures = [
        ("overdue", "B1", "active", one_hour_ago),  # should be picked
        ("future", "B2", "active", tomorrow),       # not yet due
        ("paused", "B3", "paused", one_hour_ago),   # overdue but paused
    ]
    test_session.add_all(
        [
            Query(
                user_id=owner.id,
                keyword=keyword,
                target_brand=brand,
                status=status,
                next_query_at=due_at,
            )
            for keyword, brand, status, due_at in fixtures
        ]
    )
    await test_session.commit()

    scheduler = QueryScheduler()
    scheduler.engine = AsyncMock()

    # Stand-in for AsyncSessionLocal: calling it returns an async context
    # manager that yields the test session, so the scheduler uses our data.
    session_factory = MagicMock()
    session_ctx = session_factory.return_value
    session_ctx.__aenter__ = AsyncMock(return_value=test_session)
    session_ctx.__aexit__ = AsyncMock(return_value=False)

    with patch("app.workers.scheduler.AsyncSessionLocal", session_factory):
        await scheduler.check_and_execute_queries()

    # Exactly one execution is expected, and it must be the overdue query.
    execute_query = scheduler.engine.execute_query
    execute_query.assert_called_once()
    dispatched = execute_query.call_args.args[0]
    assert dispatched.keyword == "overdue"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# 3. Frequency calculation: next_query_at is computed correctly for daily, weekly, and the default fallback
|
||
# ---------------------------------------------------------------------------
|
||
@pytest.mark.asyncio
async def test_scheduler_frequency_calculation_daily():
    """A ``"daily"`` frequency yields a next run about one day from now.

    The result must be within 5 seconds of ``now + 1 day``.
    """
    engine = CitationEngine()
    # Timezone-aware "now"; datetime.utcnow() is deprecated since Python 3.12.
    now = datetime.now(timezone.utc)
    result = engine._calculate_next_query_at("daily")
    if result.tzinfo is None:
        # The engine returned a naive timestamp (presumably UTC -- confirm).
        # Drop tzinfo so the subtraction below does not raise TypeError.
        now = now.replace(tzinfo=None)
    expected = now + timedelta(days=1)
    assert abs((result - expected).total_seconds()) < 5
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_scheduler_frequency_calculation_weekly():
    """A ``"weekly"`` frequency yields a next run about seven days from now.

    The result must be within 5 seconds of ``now + 7 days``.
    """
    engine = CitationEngine()
    # Timezone-aware "now"; datetime.utcnow() is deprecated since Python 3.12.
    now = datetime.now(timezone.utc)
    result = engine._calculate_next_query_at("weekly")
    if result.tzinfo is None:
        # The engine returned a naive timestamp (presumably UTC -- confirm).
        # Drop tzinfo so the subtraction below does not raise TypeError.
        now = now.replace(tzinfo=None)
    expected = now + timedelta(days=7)
    assert abs((result - expected).total_seconds()) < 5
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_scheduler_frequency_calculation_default():
    """A ``None`` frequency is expected to fall back to seven days from now.

    The result must be within 5 seconds of ``now + 7 days``.
    """
    engine = CitationEngine()
    # Timezone-aware "now"; datetime.utcnow() is deprecated since Python 3.12.
    now = datetime.now(timezone.utc)
    result = engine._calculate_next_query_at(None)
    if result.tzinfo is None:
        # The engine returned a naive timestamp (presumably UTC -- confirm).
        # Drop tzinfo so the subtraction below does not raise TypeError.
        now = now.replace(tzinfo=None)
    expected = now + timedelta(days=7)
    assert abs((result - expected).total_seconds()) < 5
|