# geo/backend/tests/test_services/test_detection_scheduler.py
#
# 378 lines
# 14 KiB
# Python

import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.models.brand import Brand
from app.models.user import User
from app.services.auth import hash_password
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture
async def async_engine():
    """Yield an async engine backed by an in-memory SQLite database.

    Every table registered on ``Base.metadata`` is created before the engine
    is handed to the test. The engine is disposed once the test is finished.
    """
    engine_options = {
        "connect_args": {"check_same_thread": False},
        # StaticPool reuses a single connection, which keeps the in-memory
        # database alive and shared for the lifetime of the engine.
        "poolclass": StaticPool,
    }
    db_engine = create_async_engine("sqlite+aiosqlite:///:memory:", **engine_options)

    # Build the schema inside its own transaction, then release the connection.
    async with db_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield db_engine

    await db_engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the ``async_engine`` fixture.

    The session is opened through an ``async_sessionmaker`` context manager,
    so it is closed automatically when the test completes.
    """
    session_options = {
        "class_": AsyncSession,
        # Keep attribute values readable after commit without a reload.
        "expire_on_commit": False,
        "autoflush": False,
        "autocommit": False,
    }
    open_session = async_sessionmaker(async_engine, **session_options)

    async with open_session() as db:
        yield db
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist a ``User`` row and return the refreshed instance.

    The primary key is a freshly generated UUID stored as a string, and the
    password is passed through ``hash_password`` before being stored.
    """
    user_fields = {
        "id": str(uuid.uuid4()),
        "email": "test@example.com",
        "password": hash_password("Test@123456"),
        "firstName": "Test User",
        "plan": "free",
        "max_queries": 5,
        "isActive": True,
        "emailVerified": True,
    }
    account = User(**user_fields)

    async_session.add(account)
    await async_session.commit()
    await async_session.refresh(account)
    return account
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist a ``Brand`` owned by ``test_user`` and return the refreshed row.

    The owner's string id is converted with ``_to_uuid`` before it is stored
    in ``Brand.user_id``.
    """
    brand_fields = {
        "id": uuid.uuid4(),
        "user_id": _to_uuid(test_user.id),
        "name": "Test Brand",
        "aliases": ["TestBrand", "TB"],
        "website": "https://testbrand.com",
        "industry": "technology",
        "platforms": ["wenxin", "kimi"],
        "frequency": "weekly",
        "status": "active",
    }
    brand_row = Brand(**brand_fields)

    async_session.add(brand_row)
    await async_session.commit()
    await async_session.refresh(brand_row)
    return brand_row
class TestDetectionTaskModel:
    """Persistence checks for the ``DetectionTask`` ORM model."""

    @pytest.mark.asyncio
    async def test_create_detection_task(self, async_session, test_brand, test_user):
        """A fully specified task keeps its supplied values and gets defaults filled in."""
        from app.models.detection_task import DetectionTask

        supplied = {
            "brand_id": test_brand.id,
            "user_id": _to_uuid(test_user.id),
            "name": "每日品牌检测",
            "frequency": "daily",
            "engines": ["chatgpt", "perplexity"],
            "queries": ["最佳保险品牌", "保险推荐"],
            "competitor_names": ["竞品A", "竞品B"],
        }
        record = DetectionTask(**supplied)
        async_session.add(record)
        await async_session.commit()
        await async_session.refresh(record)

        assert record.id is not None

        # Every explicitly supplied column must come back unchanged.
        for column, expected in supplied.items():
            assert getattr(record, column) == expected, column

        # Columns that were not supplied are populated (or left empty) by the model.
        assert record.is_active is True
        assert record.last_run_at is None
        assert record.next_run_at is not None
        assert record.created_at is not None
        assert record.updated_at is not None

    @pytest.mark.asyncio
    async def test_detection_task_default_values(self, async_session, test_brand, test_user):
        """Omitting the optional fields leaves the model's defaults in place."""
        from app.models.detection_task import DetectionTask

        minimal = DetectionTask(
            brand_id=test_brand.id,
            user_id=_to_uuid(test_user.id),
            name="简单检测",
            frequency="weekly",
            engines=["chatgpt"],
            queries=["测试查询"],
        )
        async_session.add(minimal)
        await async_session.commit()
        await async_session.refresh(minimal)

        assert minimal.is_active is True
        assert minimal.competitor_names is None
        assert minimal.last_run_at is None
class TestDetectionSchedulerService:
    """Behavioural tests for ``DetectionSchedulerService``.

    Project modules are imported lazily inside the helpers and tests rather
    than at module import time, mirroring how each test imports them itself.
    """

    # ------------------------------------------------------------------
    # Private helpers (not collected by pytest: names do not start with "test")
    # ------------------------------------------------------------------

    @staticmethod
    def _make_service():
        """Import ``DetectionSchedulerService`` and return a new instance."""
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        return DetectionSchedulerService()

    @staticmethod
    async def _store_task(session, brand, user, **fields):
        """Insert a ``DetectionTask`` for *brand* / *user*, commit, refresh and return it."""
        from app.models.detection_task import DetectionTask

        row = DetectionTask(brand_id=brand.id, user_id=_to_uuid(user.id), **fields)
        session.add(row)
        await session.commit()
        await session.refresh(row)
        return row

    @staticmethod
    def _frequency_payload(name, frequency):
        """Build the ``create_task`` payload shared by the frequency tests."""
        return {
            "name": name,
            "frequency": frequency,
            "engines": ["chatgpt"],
            "queries": ["查询"],
        }

    async def _assert_frequency_accepted(self, session, brand, user, name, frequency):
        """Create a task with *frequency* and check it is stored and scheduled."""
        service = self._make_service()
        payload = self._frequency_payload(name, frequency)

        created = await service.create_task(payload, brand.id, _to_uuid(user.id), session)

        assert created.frequency == frequency
        assert created.next_run_at is not None

    # ------------------------------------------------------------------
    # CRUD
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_create_task(self, async_session, test_brand, test_user):
        """``create_task`` returns a persisted task tied to the brand and user."""
        service = self._make_service()
        payload = {
            "name": "每日品牌检测",
            "frequency": "daily",
            "engines": ["chatgpt", "perplexity"],
            "queries": ["最佳保险品牌", "保险推荐"],
            "competitor_names": ["竞品A"],
        }

        created = await service.create_task(
            payload, test_brand.id, _to_uuid(test_user.id), async_session
        )

        assert created.id is not None
        assert created.name == "每日品牌检测"
        assert created.frequency == "daily"
        assert created.brand_id == test_brand.id
        assert created.user_id == _to_uuid(test_user.id)

    @pytest.mark.asyncio
    async def test_update_task(self, async_session, test_brand, test_user):
        """``update_task`` applies the supplied fields to an existing task."""
        existing = await self._store_task(
            async_session,
            test_brand,
            test_user,
            name="旧名称",
            frequency="weekly",
            engines=["chatgpt"],
            queries=["查询1"],
        )
        service = self._make_service()
        changes = {
            "name": "新名称",
            "frequency": "daily",
            "engines": ["chatgpt", "perplexity"],
        }

        updated = await service.update_task(
            existing.id, changes, _to_uuid(test_user.id), async_session
        )

        assert updated.name == "新名称"
        assert updated.frequency == "daily"
        assert updated.engines == ["chatgpt", "perplexity"]

    @pytest.mark.asyncio
    async def test_delete_task(self, async_session, test_brand, test_user):
        """``delete_task`` returns ``True`` and the row is no longer selectable."""
        from app.models.detection_task import DetectionTask

        doomed = await self._store_task(
            async_session,
            test_brand,
            test_user,
            name="待删除",
            frequency="weekly",
            engines=["chatgpt"],
            queries=["查询1"],
        )
        service = self._make_service()

        outcome = await service.delete_task(doomed.id, _to_uuid(test_user.id), async_session)
        assert outcome is True

        # Query the table directly to confirm the row is gone.
        lookup = select(DetectionTask).where(DetectionTask.id == doomed.id)
        remaining = await async_session.execute(lookup)
        assert remaining.scalar_one_or_none() is None

    @pytest.mark.asyncio
    async def test_get_tasks(self, async_session, test_brand, test_user):
        """``get_tasks`` returns every task stored for the brand and user."""
        from app.models.detection_task import DetectionTask

        pending = [
            DetectionTask(
                brand_id=test_brand.id,
                user_id=_to_uuid(test_user.id),
                name=f"任务{i}",
                frequency="daily",
                engines=["chatgpt"],
                queries=[f"查询{i}"],
            )
            for i in range(3)
        ]
        async_session.add_all(pending)
        await async_session.commit()

        service = self._make_service()
        found = await service.get_tasks(test_brand.id, _to_uuid(test_user.id), async_session)

        assert len(found) == 3

    # ------------------------------------------------------------------
    # Manual trigger
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_trigger_task(self, async_session, test_brand, test_user):
        """``trigger_task`` calls ``execute_task`` once and returns its result."""
        stored = await self._store_task(
            async_session,
            test_brand,
            test_user,
            name="手动触发测试",
            frequency="daily",
            engines=["chatgpt"],
            queries=["测试查询"],
        )
        service = self._make_service()
        canned = {"status": "success", "results": []}

        # Replace execute_task so no real detection work is attempted.
        with patch.object(
            service, "execute_task", new_callable=AsyncMock, return_value=canned
        ) as mock_execute:
            result = await service.trigger_task(
                stored.id, _to_uuid(test_user.id), async_session
            )

        assert result["status"] == "success"
        mock_execute.assert_called_once()

    # ------------------------------------------------------------------
    # Frequency validation
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_frequency_validation_hourly(self, async_session, test_brand, test_user):
        """``hourly`` is accepted and a next run time is set."""
        await self._assert_frequency_accepted(
            async_session, test_brand, test_user, "每小时检测", "hourly"
        )

    @pytest.mark.asyncio
    async def test_frequency_validation_daily(self, async_session, test_brand, test_user):
        """``daily`` is accepted and a next run time is set."""
        await self._assert_frequency_accepted(
            async_session, test_brand, test_user, "每日检测", "daily"
        )

    @pytest.mark.asyncio
    async def test_frequency_validation_weekly(self, async_session, test_brand, test_user):
        """``weekly`` is accepted and a next run time is set."""
        await self._assert_frequency_accepted(
            async_session, test_brand, test_user, "每周检测", "weekly"
        )

    @pytest.mark.asyncio
    async def test_frequency_validation_invalid(self, async_session, test_brand, test_user):
        """An unsupported frequency makes ``create_task`` raise ``ValueError``."""
        service = self._make_service()
        payload = self._frequency_payload("无效频率", "monthly")

        with pytest.raises(ValueError, match="frequency"):
            await service.create_task(
                payload, test_brand.id, _to_uuid(test_user.id), async_session
            )

    # ------------------------------------------------------------------
    # Execution flow
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_execute_task_flow(self, async_session, test_brand, test_user):
        """``execute_task`` reports success and stamps the run timestamps."""
        stored = await self._store_task(
            async_session,
            test_brand,
            test_user,
            name="执行流程测试",
            frequency="daily",
            engines=["chatgpt"],
            queries=["测试查询"],
            competitor_names=["竞品A"],
        )
        service = self._make_service()

        # One fake engine result: brand cited, competitor not cited.
        engine_hit = MagicMock(
            engine_type=MagicMock(value="chatgpt"),
            has_brand_citation=True,
            has_competitor_citation=False,
        )
        batch_patch = patch.object(
            service, "_run_batch_query", new_callable=AsyncMock, return_value=[engine_hit]
        )
        alerts_patch = patch.object(
            service, "_generate_alerts_if_needed", new_callable=AsyncMock, return_value=[]
        )

        with batch_patch, alerts_patch:
            result = await service.execute_task(stored, async_session)

        assert result["status"] == "success"
        assert "results" in result

        await async_session.refresh(stored)
        assert stored.last_run_at is not None
        assert stored.next_run_at is not None

    # ------------------------------------------------------------------
    # Missing / empty cases
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_delete_task_not_found(self, async_session, test_user):
        """Deleting an unknown task id returns ``False``."""
        service = self._make_service()
        unknown_id = uuid.uuid4()

        outcome = await service.delete_task(unknown_id, _to_uuid(test_user.id), async_session)

        assert outcome is False

    @pytest.mark.asyncio
    async def test_update_task_not_found(self, async_session, test_user):
        """Updating an unknown task id raises ``TaskNotFoundError``."""
        from app.services.detection.detection_scheduler import TaskNotFoundError

        service = self._make_service()
        unknown_id = uuid.uuid4()

        with pytest.raises(TaskNotFoundError):
            await service.update_task(
                unknown_id, {"name": "新名称"}, _to_uuid(test_user.id), async_session
            )

    @pytest.mark.asyncio
    async def test_get_tasks_empty(self, async_session, test_brand, test_user):
        """``get_tasks`` returns an empty list when nothing has been stored."""
        service = self._make_service()

        found = await service.get_tasks(test_brand.id, _to_uuid(test_user.id), async_session)

        assert found == []