"""Tests for the ``DetectionTask`` model and ``DetectionSchedulerService``.

Every test gets its own in-memory SQLite database (through aiosqlite) whose
schema is created from ``Base.metadata`` by the ``async_engine`` fixture.
"""

import uuid
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
import pytest_asyncio
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool

from app.database import Base
from app.models.brand import Brand
from app.models.user import User
from app.services.auth import hash_password
from tests.fixtures.auth import _to_uuid


@pytest_asyncio.fixture
async def async_engine():
    """Yield an async engine bound to a fresh in-memory SQLite database.

    ``StaticPool`` reuses a single connection, so every session shares the same
    ``:memory:`` database. Otherwise each new connection would open its own empty
    database. All tables on ``Base.metadata`` are created before the engine is
    yielded, and the engine is disposed on teardown.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()


@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to ``async_engine``.

    ``expire_on_commit=False`` keeps loaded attribute values after ``commit()``,
    so tests can read them without an implicit reload. ``autoflush=False``
    means pending changes are sent only on an explicit flush or commit.
    """
    session_maker = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )
    async with session_maker() as session:
        yield session


@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, email-verified ``User`` on the free plan.

    ``id`` is stored as a string here. The tests therefore pass
    ``_to_uuid(test_user.id)`` wherever a ``user_id`` is expected.
    """
    user = User(
        id=str(uuid.uuid4()),
        email="test@example.com",
        password=hash_password("Test@123456"),
        firstName="Test User",
        plan="free",
        max_queries=5,
        isActive=True,
        emailVerified=True,
    )
    async_session.add(user)
    await async_session.commit()
    await async_session.refresh(user)
    return user


@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist and return an active ``Brand`` owned by ``test_user``."""
    brand = Brand(
        id=uuid.uuid4(),
        user_id=_to_uuid(test_user.id),
        name="Test Brand",
        aliases=["TestBrand", "TB"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )
    async_session.add(brand)
    await async_session.commit()
    await async_session.refresh(brand)
    return brand


class TestDetectionTaskModel:
    """Persistence and default-value tests for the ``DetectionTask`` ORM model."""

    @pytest.mark.asyncio
    async def test_create_detection_task(self, async_session, test_brand, test_user):
        """A fully specified task round-trips every field and gets scheduling defaults."""
        from app.models.detection_task import DetectionTask

        task = DetectionTask(
            brand_id=test_brand.id,
            user_id=_to_uuid(test_user.id),
            name="每日品牌检测",
            frequency="daily",
            engines=["chatgpt", "perplexity"],
            queries=["最佳保险品牌", "保险推荐"],
            competitor_names=["竞品A", "竞品B"],
        )
        async_session.add(task)
        await async_session.commit()
        await async_session.refresh(task)

        assert task.id is not None
        assert task.brand_id == test_brand.id
        assert task.user_id == _to_uuid(test_user.id)
        assert task.name == "每日品牌检测"
        assert task.frequency == "daily"
        assert task.engines == ["chatgpt", "perplexity"]
        assert task.queries == ["最佳保险品牌", "保险推荐"]
        assert task.competitor_names == ["竞品A", "竞品B"]
        assert task.is_active is True
        assert task.last_run_at is None
        assert task.next_run_at is not None
        assert task.created_at is not None
        assert task.updated_at is not None

    @pytest.mark.asyncio
    async def test_detection_task_default_values(self, async_session, test_brand, test_user):
        """Omitted optional fields default to active, no competitors and no last run."""
        from app.models.detection_task import DetectionTask

        task = DetectionTask(
            brand_id=test_brand.id,
            user_id=_to_uuid(test_user.id),
            name="简单检测",
            frequency="weekly",
            engines=["chatgpt"],
            queries=["测试查询"],
        )
        async_session.add(task)
        await async_session.commit()
        await async_session.refresh(task)

        assert task.is_active is True
        assert task.competitor_names is None
        assert task.last_run_at is None


class TestDetectionSchedulerService:
    """CRUD, trigger, frequency-validation and execution tests for the scheduler service."""

    @pytest.mark.asyncio
    async def test_create_task(self, async_session, test_brand, test_user):
        """``create_task`` persists a task owned by the given brand and user."""
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        service = DetectionSchedulerService()
        task_data = {
            "name": "每日品牌检测",
            "frequency": "daily",
            "engines": ["chatgpt", "perplexity"],
            "queries": ["最佳保险品牌", "保险推荐"],
            "competitor_names": ["竞品A"],
        }
        task = await service.create_task(
            task_data, test_brand.id, _to_uuid(test_user.id), async_session
        )

        assert task.id is not None
        assert task.name == "每日品牌检测"
        assert task.frequency == "daily"
        assert task.brand_id == test_brand.id
        assert task.user_id == _to_uuid(test_user.id)

    @pytest.mark.asyncio
    async def test_update_task(self, async_session, test_brand, test_user):
        """``update_task`` applies the supplied fields to an existing task."""
        from app.models.detection_task import DetectionTask
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        task = DetectionTask(
            brand_id=test_brand.id,
            user_id=_to_uuid(test_user.id),
            name="旧名称",
            frequency="weekly",
            engines=["chatgpt"],
            queries=["查询1"],
        )
        async_session.add(task)
        await async_session.commit()
        await async_session.refresh(task)

        service = DetectionSchedulerService()
        update_data = {
            "name": "新名称",
            "frequency": "daily",
            "engines": ["chatgpt", "perplexity"],
        }
        updated = await service.update_task(
            task.id, update_data, _to_uuid(test_user.id), async_session
        )

        assert updated.name == "新名称"
        assert updated.frequency == "daily"
        assert updated.engines == ["chatgpt", "perplexity"]

    @pytest.mark.asyncio
    async def test_delete_task(self, async_session, test_brand, test_user):
        """``delete_task`` returns True and the row is gone from the database."""
        from app.models.detection_task import DetectionTask
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        task = DetectionTask(
            brand_id=test_brand.id,
            user_id=_to_uuid(test_user.id),
            name="待删除",
            frequency="weekly",
            engines=["chatgpt"],
            queries=["查询1"],
        )
        async_session.add(task)
        await async_session.commit()
        await async_session.refresh(task)

        service = DetectionSchedulerService()
        result = await service.delete_task(task.id, _to_uuid(test_user.id), async_session)

        assert result is True
        # Query the table directly to confirm the row was actually removed.
        stmt = select(DetectionTask).where(DetectionTask.id == task.id)
        db_result = await async_session.execute(stmt)
        assert db_result.scalar_one_or_none() is None

    @pytest.mark.asyncio
    async def test_get_tasks(self, async_session, test_brand, test_user):
        """``get_tasks`` returns every task stored for the brand and user."""
        from app.models.detection_task import DetectionTask
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        for i in range(3):
            task = DetectionTask(
                brand_id=test_brand.id,
                user_id=_to_uuid(test_user.id),
                name=f"任务{i}",
                frequency="daily",
                engines=["chatgpt"],
                queries=[f"查询{i}"],
            )
            async_session.add(task)
        await async_session.commit()

        service = DetectionSchedulerService()
        tasks = await service.get_tasks(test_brand.id, _to_uuid(test_user.id), async_session)

        assert len(tasks) == 3

    @pytest.mark.asyncio
    async def test_trigger_task(self, async_session, test_brand, test_user):
        """``trigger_task`` awaits ``execute_task`` exactly once and returns its result."""
        from app.models.detection_task import DetectionTask
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        task = DetectionTask(
            brand_id=test_brand.id,
            user_id=_to_uuid(test_user.id),
            name="手动触发测试",
            frequency="daily",
            engines=["chatgpt"],
            queries=["测试查询"],
        )
        async_session.add(task)
        await async_session.commit()
        await async_session.refresh(task)

        service = DetectionSchedulerService()
        with patch.object(service, "execute_task", new_callable=AsyncMock) as mock_execute:
            mock_execute.return_value = {"status": "success", "results": []}
            result = await service.trigger_task(task.id, _to_uuid(test_user.id), async_session)

        assert result["status"] == "success"
        # For an AsyncMock, assert_awaited_once also checks that the coroutine
        # was awaited; assert_called_once would only count the call.
        mock_execute.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_frequency_validation_hourly(self, async_session, test_brand, test_user):
        """An "hourly" frequency is accepted and a next run time is scheduled."""
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        service = DetectionSchedulerService()
        task_data = {
            "name": "每小时检测",
            "frequency": "hourly",
            "engines": ["chatgpt"],
            "queries": ["查询"],
        }
        task = await service.create_task(
            task_data, test_brand.id, _to_uuid(test_user.id), async_session
        )

        assert task.frequency == "hourly"
        assert task.next_run_at is not None

    @pytest.mark.asyncio
    async def test_frequency_validation_daily(self, async_session, test_brand, test_user):
        """A "daily" frequency is accepted and a next run time is scheduled."""
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        service = DetectionSchedulerService()
        task_data = {
            "name": "每日检测",
            "frequency": "daily",
            "engines": ["chatgpt"],
            "queries": ["查询"],
        }
        task = await service.create_task(
            task_data, test_brand.id, _to_uuid(test_user.id), async_session
        )

        assert task.frequency == "daily"
        assert task.next_run_at is not None

    @pytest.mark.asyncio
    async def test_frequency_validation_weekly(self, async_session, test_brand, test_user):
        """A "weekly" frequency is accepted and a next run time is scheduled."""
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        service = DetectionSchedulerService()
        task_data = {
            "name": "每周检测",
            "frequency": "weekly",
            "engines": ["chatgpt"],
            "queries": ["查询"],
        }
        task = await service.create_task(
            task_data, test_brand.id, _to_uuid(test_user.id), async_session
        )

        assert task.frequency == "weekly"
        assert task.next_run_at is not None

    @pytest.mark.asyncio
    async def test_frequency_validation_invalid(self, async_session, test_brand, test_user):
        """An unsupported frequency raises ValueError whose message mentions "frequency"."""
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        service = DetectionSchedulerService()
        task_data = {
            "name": "无效频率",
            "frequency": "monthly",
            "engines": ["chatgpt"],
            "queries": ["查询"],
        }
        with pytest.raises(ValueError, match="frequency"):
            await service.create_task(
                task_data, test_brand.id, _to_uuid(test_user.id), async_session
            )

    @pytest.mark.asyncio
    async def test_execute_task_flow(self, async_session, test_brand, test_user):
        """``execute_task`` reports success and updates the task's run timestamps.

        The batch query and alert generation are stubbed out, so only the
        orchestration and persistence inside ``execute_task`` are exercised.
        """
        from app.models.detection_task import DetectionTask
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        task = DetectionTask(
            brand_id=test_brand.id,
            user_id=_to_uuid(test_user.id),
            name="执行流程测试",
            frequency="daily",
            engines=["chatgpt"],
            queries=["测试查询"],
            competitor_names=["竞品A"],
        )
        async_session.add(task)
        await async_session.commit()
        await async_session.refresh(task)

        service = DetectionSchedulerService()
        mock_batch_result = [
            MagicMock(
                engine_type=MagicMock(value="chatgpt"),
                has_brand_citation=True,
                has_competitor_citation=False,
            )
        ]
        with patch.object(
            service, "_run_batch_query", new_callable=AsyncMock, return_value=mock_batch_result
        ), patch.object(
            service, "_generate_alerts_if_needed", new_callable=AsyncMock, return_value=[]
        ):
            result = await service.execute_task(task, async_session)

        assert result["status"] == "success"
        assert "results" in result
        # Reload from the database so the assertions see persisted values.
        await async_session.refresh(task)
        assert task.last_run_at is not None
        assert task.next_run_at is not None

    @pytest.mark.asyncio
    async def test_delete_task_not_found(self, async_session, test_user):
        """Deleting an unknown task id returns False instead of raising."""
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        service = DetectionSchedulerService()
        result = await service.delete_task(uuid.uuid4(), _to_uuid(test_user.id), async_session)

        assert result is False

    @pytest.mark.asyncio
    async def test_update_task_not_found(self, async_session, test_user):
        """Updating an unknown task id raises ``TaskNotFoundError``."""
        from app.services.detection.detection_scheduler import (
            DetectionSchedulerService,
            TaskNotFoundError,
        )

        service = DetectionSchedulerService()
        with pytest.raises(TaskNotFoundError):
            await service.update_task(
                uuid.uuid4(), {"name": "新名称"}, _to_uuid(test_user.id), async_session
            )

    @pytest.mark.asyncio
    async def test_get_tasks_empty(self, async_session, test_brand, test_user):
        """``get_tasks`` returns an empty list when the brand has no tasks."""
        from app.services.detection.detection_scheduler import DetectionSchedulerService

        service = DetectionSchedulerService()
        tasks = await service.get_tasks(test_brand.id, _to_uuid(test_user.id), async_session)

        assert tasks == []