378 lines
14 KiB
Python
378 lines
14 KiB
Python
import uuid
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from app.database import Base
|
|
from app.models.brand import Brand
|
|
from app.models.user import User
|
|
from app.services.auth import hash_password
|
|
from tests.fixtures.auth import _to_uuid
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_engine():
    """Yield an async engine backed by a fresh in-memory SQLite database.

    ``StaticPool`` reuses a single connection, so every session opened on this
    engine sees the same in-memory database. All tables registered on
    ``Base.metadata`` are created before the engine is yielded. The engine is
    disposed during fixture teardown.
    """
    sqlite_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )

    # Build the schema up front so tests can insert rows immediately.
    async with sqlite_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield sqlite_engine

    await sqlite_engine.dispose()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to ``async_engine`` for one test.

    Instances keep their loaded attributes after ``commit()``
    (``expire_on_commit=False``), and queries do not auto-flush pending
    changes (``autoflush=False``). The session is closed when the
    ``async with`` block exits after the test.
    """
    make_session = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        autoflush=False,
        expire_on_commit=False,
        # NOTE(review): redundant on SQLAlchemy 2.x, where False is the only
        # accepted value — confirm the installed version before removing.
        autocommit=False,
    )

    async with make_session() as db_session:
        yield db_session
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, email-verified ``User`` on the free plan.

    The user's ``id`` is stored as a string UUID. Tests that need a UUID
    value convert it with ``_to_uuid``. The plaintext password is
    ``"Test@123456"``, and ``max_queries`` is 5.
    """
    user_fields = dict(
        id=str(uuid.uuid4()),
        email="test@example.com",
        password=hash_password("Test@123456"),
        firstName="Test User",
        plan="free",
        max_queries=5,
        isActive=True,
        emailVerified=True,
    )
    new_user = User(**user_fields)

    async_session.add(new_user)
    await async_session.commit()
    # Re-read the row so the returned instance reflects what was stored.
    await async_session.refresh(new_user)
    return new_user
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist and return a ``Brand`` owned by ``test_user``.

    The brand is created with ``platforms=["wenxin", "kimi"]``,
    ``frequency="weekly"`` and ``status="active"``. Its ``user_id`` is the
    owner's id converted with ``_to_uuid``.
    """
    brand_id = uuid.uuid4()
    owner_uuid = _to_uuid(test_user.id)

    new_brand = Brand(
        id=brand_id,
        user_id=owner_uuid,
        name="Test Brand",
        aliases=["TestBrand", "TB"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )

    async_session.add(new_brand)
    await async_session.commit()
    # Re-read the row so the returned instance reflects what was stored.
    await async_session.refresh(new_brand)
    return new_brand
|
|
|
|
|
|
class TestDetectionTaskModel:
    """Persistence tests for the ``DetectionTask`` ORM model."""

    @staticmethod
    async def _save_and_reload(session, instance):
        """Add *instance*, commit, and refresh it from the database."""
        session.add(instance)
        await session.commit()
        await session.refresh(instance)

    @pytest.mark.asyncio
    async def test_create_detection_task(self, async_session, test_brand, test_user):
        """Explicitly supplied fields survive a round trip through the database."""
        from app.models.detection_task import DetectionTask

        owner_id = _to_uuid(test_user.id)
        task = DetectionTask(
            brand_id=test_brand.id,
            user_id=owner_id,
            name="每日品牌检测",
            frequency="daily",
            engines=["chatgpt", "perplexity"],
            queries=["最佳保险品牌", "保险推荐"],
            competitor_names=["竞品A", "竞品B"],
        )
        await self._save_and_reload(async_session, task)

        assert task.id is not None

        expected_values = {
            "brand_id": test_brand.id,
            "user_id": owner_id,
            "name": "每日品牌检测",
            "frequency": "daily",
            "engines": ["chatgpt", "perplexity"],
            "queries": ["最佳保险品牌", "保险推荐"],
            "competitor_names": ["竞品A", "竞品B"],
        }
        for attr_name, wanted in expected_values.items():
            assert getattr(task, attr_name) == wanted, attr_name

        # Identity checks (not ==) so that e.g. 1 is not accepted for True.
        assert task.is_active is True
        assert task.last_run_at is None

        for attr_name in ("next_run_at", "created_at", "updated_at"):
            assert getattr(task, attr_name) is not None, attr_name

    @pytest.mark.asyncio
    async def test_detection_task_default_values(self, async_session, test_brand, test_user):
        """Omitted optional fields take their defaults once persisted."""
        from app.models.detection_task import DetectionTask

        minimal_task = DetectionTask(
            brand_id=test_brand.id,
            user_id=_to_uuid(test_user.id),
            name="简单检测",
            frequency="weekly",
            engines=["chatgpt"],
            queries=["测试查询"],
        )
        await self._save_and_reload(async_session, minimal_task)

        assert minimal_task.is_active is True
        assert minimal_task.competitor_names is None
        assert minimal_task.last_run_at is None
|
|
|
|
|
|
class TestDetectionSchedulerService:
    """Service-level tests for ``DetectionSchedulerService``.

    Covers task CRUD, frequency validation, manual triggering and the
    execute flow (with the batch query and alert helpers patched out).
    """

    @staticmethod
    async def _persist_task(session, brand, user, **fields):
        """Insert a ``DetectionTask`` for *brand* owned by *user*; return it refreshed.

        *fields* are forwarded to the ``DetectionTask`` constructor
        (``name``, ``frequency``, ``engines``, ``queries``, ...).
        """
        from app.models.detection_task import DetectionTask

        task = DetectionTask(brand_id=brand.id, user_id=_to_uuid(user.id), **fields)
        session.add(task)
        await session.commit()
        await session.refresh(task)
        return task

    @staticmethod
    def _frequency_task_data(name, frequency):
        """Build a minimal ``create_task`` payload varying only *name* and *frequency*."""
        return {
            "name": name,
            "frequency": frequency,
            "engines": ["chatgpt"],
            "queries": ["查询"],
        }

    async def _assert_frequency_accepted(self, session, brand, user, name, frequency):
        """Create a task with *frequency* via the service and check the result.

        The returned task must keep *frequency* and have ``next_run_at`` set.
        """
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        service = DetectionSchedulerService()
        task = await service.create_task(
            self._frequency_task_data(name, frequency),
            brand.id,
            _to_uuid(user.id),
            session,
        )
        assert task.frequency == frequency
        assert task.next_run_at is not None

    @pytest.mark.asyncio
    async def test_create_task(self, async_session, test_brand, test_user):
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        service = DetectionSchedulerService()
        task_data = {
            "name": "每日品牌检测",
            "frequency": "daily",
            "engines": ["chatgpt", "perplexity"],
            "queries": ["最佳保险品牌", "保险推荐"],
            "competitor_names": ["竞品A"],
        }
        task = await service.create_task(task_data, test_brand.id, _to_uuid(test_user.id), async_session)

        assert task.id is not None
        assert task.name == "每日品牌检测"
        assert task.frequency == "daily"
        assert task.brand_id == test_brand.id
        assert task.user_id == _to_uuid(test_user.id)

    @pytest.mark.asyncio
    async def test_update_task(self, async_session, test_brand, test_user):
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        task = await self._persist_task(
            async_session,
            test_brand,
            test_user,
            name="旧名称",
            frequency="weekly",
            engines=["chatgpt"],
            queries=["查询1"],
        )

        service = DetectionSchedulerService()
        update_data = {
            "name": "新名称",
            "frequency": "daily",
            "engines": ["chatgpt", "perplexity"],
        }
        updated = await service.update_task(task.id, update_data, _to_uuid(test_user.id), async_session)

        assert updated.name == "新名称"
        assert updated.frequency == "daily"
        assert updated.engines == ["chatgpt", "perplexity"]

    @pytest.mark.asyncio
    async def test_delete_task(self, async_session, test_brand, test_user):
        from app.models.detection_task import DetectionTask
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        task = await self._persist_task(
            async_session,
            test_brand,
            test_user,
            name="待删除",
            frequency="weekly",
            engines=["chatgpt"],
            queries=["查询1"],
        )

        service = DetectionSchedulerService()
        result = await service.delete_task(task.id, _to_uuid(test_user.id), async_session)
        assert result is True

        # Query the database directly instead of relying on the return value.
        stmt = select(DetectionTask).where(DetectionTask.id == task.id)
        db_result = await async_session.execute(stmt)
        assert db_result.scalar_one_or_none() is None

    @pytest.mark.asyncio
    async def test_get_tasks(self, async_session, test_brand, test_user):
        from app.models.detection_task import DetectionTask
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        owner_id = _to_uuid(test_user.id)
        # Insert all three rows and commit once.
        async_session.add_all(
            [
                DetectionTask(
                    brand_id=test_brand.id,
                    user_id=owner_id,
                    name=f"任务{i}",
                    frequency="daily",
                    engines=["chatgpt"],
                    queries=[f"查询{i}"],
                )
                for i in range(3)
            ]
        )
        await async_session.commit()

        service = DetectionSchedulerService()
        tasks = await service.get_tasks(test_brand.id, owner_id, async_session)
        assert len(tasks) == 3

    @pytest.mark.asyncio
    async def test_trigger_task(self, async_session, test_brand, test_user):
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        task = await self._persist_task(
            async_session,
            test_brand,
            test_user,
            name="手动触发测试",
            frequency="daily",
            engines=["chatgpt"],
            queries=["测试查询"],
        )

        service = DetectionSchedulerService()
        with patch.object(service, "execute_task", new_callable=AsyncMock) as mock_execute:
            mock_execute.return_value = {"status": "success", "results": []}
            result = await service.trigger_task(task.id, _to_uuid(test_user.id), async_session)

        assert result["status"] == "success"
        # An AsyncMock records a call even if the coroutine it returns is never
        # awaited, so check the awaited count rather than the call count.
        mock_execute.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_frequency_validation_hourly(self, async_session, test_brand, test_user):
        await self._assert_frequency_accepted(async_session, test_brand, test_user, "每小时检测", "hourly")

    @pytest.mark.asyncio
    async def test_frequency_validation_daily(self, async_session, test_brand, test_user):
        await self._assert_frequency_accepted(async_session, test_brand, test_user, "每日检测", "daily")

    @pytest.mark.asyncio
    async def test_frequency_validation_weekly(self, async_session, test_brand, test_user):
        await self._assert_frequency_accepted(async_session, test_brand, test_user, "每周检测", "weekly")

    @pytest.mark.asyncio
    async def test_frequency_validation_invalid(self, async_session, test_brand, test_user):
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        service = DetectionSchedulerService()
        task_data = self._frequency_task_data("无效频率", "monthly")
        with pytest.raises(ValueError, match="frequency"):
            await service.create_task(task_data, test_brand.id, _to_uuid(test_user.id), async_session)

    @pytest.mark.asyncio
    async def test_execute_task_flow(self, async_session, test_brand, test_user):
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        task = await self._persist_task(
            async_session,
            test_brand,
            test_user,
            name="执行流程测试",
            frequency="daily",
            engines=["chatgpt"],
            queries=["测试查询"],
            competitor_names=["竞品A"],
        )

        service = DetectionSchedulerService()

        # NOTE(review): presumably mirrors the items `_run_batch_query` returns
        # (engine_type.value plus citation flags) — confirm against the service.
        mock_batch_result = [
            MagicMock(
                engine_type=MagicMock(value="chatgpt"),
                has_brand_citation=True,
                has_competitor_citation=False,
            )
        ]

        # Patch out the batch query and alert helpers so the test does not
        # depend on their real implementations.
        with patch.object(
            service, "_run_batch_query", new_callable=AsyncMock, return_value=mock_batch_result
        ), patch.object(
            service, "_generate_alerts_if_needed", new_callable=AsyncMock, return_value=[]
        ):
            result = await service.execute_task(task, async_session)

        assert result["status"] == "success"
        assert "results" in result

        await async_session.refresh(task)
        assert task.last_run_at is not None
        assert task.next_run_at is not None

    @pytest.mark.asyncio
    async def test_delete_task_not_found(self, async_session, test_user):
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        service = DetectionSchedulerService()
        result = await service.delete_task(uuid.uuid4(), _to_uuid(test_user.id), async_session)
        assert result is False

    @pytest.mark.asyncio
    async def test_update_task_not_found(self, async_session, test_user):
        from app.services.detection.detection_scheduler import DetectionSchedulerService, TaskNotFoundError

        service = DetectionSchedulerService()
        with pytest.raises(TaskNotFoundError):
            await service.update_task(uuid.uuid4(), {"name": "新名称"}, _to_uuid(test_user.id), async_session)

    @pytest.mark.asyncio
    async def test_get_tasks_empty(self, async_session, test_brand, test_user):
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        service = DetectionSchedulerService()
        tasks = await service.get_tasks(test_brand.id, _to_uuid(test_user.id), async_session)
        assert tasks == []
|