import logging
import uuid
from datetime import datetime, timezone

from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.brand import Brand
from app.models.detection_task import VALID_FREQUENCIES, DetectionTask
from app.services.ai_engine.base import EngineType
from app.services.ai_engine.batch_query import BatchQueryService, _build_adapters

logger = logging.getLogger(__name__)


class TaskNotFoundError(Exception):
    """Raised by ``update_task`` when no task matches the given id and user."""


class DetectionSchedulerService:
    """CRUD and execution of ``DetectionTask`` rows, scoped to a user.

    Every method receives the ``AsyncSession`` to use; the service commits on
    that session itself after each mutating operation.
    """

    # Keys that ``update_task`` must never copy from a payload: they define
    # the row's identity and ownership, which the lookup by ``user_id`` guards.
    _PROTECTED_FIELDS = frozenset({"id", "user_id", "brand_id"})

    async def create_task(
        self,
        task_data: dict,
        brand_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> DetectionTask:
        """Create, persist and return a new detection task.

        Args:
            task_data: Must contain ``"name"``. Optional keys: ``"frequency"``
                (defaults to ``"daily"``), ``"engines"`` and ``"queries"``
                (default to empty lists), ``"competitor_names"``.
            brand_id: Brand the task belongs to.
            user_id: Owner of the task.
            db: Session used to add and commit the task.

        Raises:
            ValueError: If the frequency is not in ``VALID_FREQUENCIES``.
            KeyError: If ``task_data`` has no ``"name"`` key.
        """
        frequency = task_data.get("frequency", "daily")
        self._validate_frequency(frequency)

        task = DetectionTask(
            brand_id=brand_id,
            user_id=user_id,
            name=task_data["name"],
            frequency=frequency,
            engines=task_data.get("engines", []),
            queries=task_data.get("queries", []),
            competitor_names=task_data.get("competitor_names"),
        )
        task.next_run_at = task.compute_next_run_at()

        db.add(task)
        await db.commit()
        await db.refresh(task)

        logger.info(f"Created detection task: {task.id}, name={task.name}")
        return task

    async def update_task(
        self,
        task_id: uuid.UUID,
        task_data: dict,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> DetectionTask:
        """Apply ``task_data`` to an existing task owned by ``user_id``.

        Each key of ``task_data`` is assigned as an attribute on the task,
        except the identity/ownership keys in ``_PROTECTED_FIELDS``, which are
        skipped with a warning. ``next_run_at`` is recomputed after the update.

        Raises:
            TaskNotFoundError: If no task matches ``task_id`` and ``user_id``.
            ValueError: If the resulting frequency is not in
                ``VALID_FREQUENCIES``.
        """
        task = await self._get_task_by_id(task_id, user_id, db)
        if task is None:
            raise TaskNotFoundError(f"Detection task {task_id} not found")

        # Validate before mutating so an invalid payload leaves the task as-is.
        frequency = task_data.get("frequency", task.frequency)
        self._validate_frequency(frequency)

        for field, value in task_data.items():
            if field in self._PROTECTED_FIELDS:
                # Overwriting these would re-parent the task and bypass the
                # ownership check performed by the lookup above.
                logger.warning(
                    f"Ignoring protected field '{field}' in update of "
                    f"detection task: {task_id}"
                )
                continue
            setattr(task, field, value)

        task.next_run_at = task.compute_next_run_at()

        await db.commit()
        await db.refresh(task)

        logger.info(f"Updated detection task: {task.id}")
        return task

    async def delete_task(
        self,
        task_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> bool:
        """Delete the task if it exists for ``user_id``.

        Returns:
            ``True`` when a task was deleted, ``False`` when none was found.
        """
        task = await self._get_task_by_id(task_id, user_id, db)
        if task is None:
            return False

        await db.delete(task)
        await db.commit()

        logger.info(f"Deleted detection task: {task_id}")
        return True

    async def get_tasks(
        self,
        brand_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> list[DetectionTask]:
        """Return the user's tasks for a brand, newest ``created_at`` first."""
        stmt = (
            select(DetectionTask)
            .where(
                and_(
                    DetectionTask.brand_id == brand_id,
                    DetectionTask.user_id == user_id,
                )
            )
            .order_by(DetectionTask.created_at.desc())
        )
        result = await db.execute(stmt)
        return list(result.scalars().all())

    async def trigger_task(
        self,
        task_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> dict:
        """Look up a task owned by ``user_id`` and execute it immediately.

        Returns:
            The dict produced by ``execute_task``, or
            ``{"status": "error", "message": "Task not found"}`` when the task
            does not exist for this user.
        """
        task = await self._get_task_by_id(task_id, user_id, db)
        if task is None:
            return {"status": "error", "message": "Task not found"}

        return await self.execute_task(task, db)

    async def execute_task(
        self,
        task: DetectionTask,
        db: AsyncSession,
    ) -> dict:
        """Run the task's queries, raise alerts, and record the run.

        Failures are reported through the return value rather than raised.

        Returns:
            ``{"status": "success", "results": [...]}`` on success, otherwise
            ``{"status": "error", "message": <str(exception)>}``.
        """
        # Read the id once up front: ORM attributes may be expired after a
        # commit or rollback and cannot be lazily reloaded under asyncio.
        task_id = task.id

        try:
            results = await self._run_batch_query(task, db)
            await self._generate_alerts_if_needed(task, results, db)

            task.last_run_at = datetime.now(timezone.utc)
            task.next_run_at = task.compute_next_run_at()

            await db.commit()
            await db.refresh(task)

            logger.info(f"Detection task executed: {task_id}")
            return {"status": "success", "results": results}
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception(
                f"Detection task execution failed: {task_id}, error={e}"
            )
            error_message = str(e)

            # Roll back so the caller's session is usable again; a session
            # whose flush/commit failed rejects further work until then.
            try:
                await db.rollback()
            except Exception:
                # Keep the "always returns a dict" contract even if the
                # rollback itself fails (e.g. the connection is gone).
                logger.exception(
                    f"Rollback failed after detection task error: {task_id}"
                )

            return {"status": "error", "message": error_message}

    async def _run_batch_query(
        self,
        task: DetectionTask,
        db: AsyncSession,
    ) -> list:
        """Send every query of the task to its engines and collect results.

        Engine names in ``task.engines`` that are not ``EngineType`` values
        are dropped. When the brand row cannot be found, the brand name
        ``"Unknown"`` is passed to the batch service.

        Returns:
            One flat list with the results of all queries, in query order.
        """
        adapters = _build_adapters()
        batch_service = BatchQueryService(adapters)

        brand = await self._get_brand(task.brand_id, db)
        brand_name = brand.name if brand else "Unknown"

        # The engine list does not depend on the query, so resolve it once
        # instead of on every loop iteration.
        known_engines = EngineType._value2member_map_
        engines = [EngineType(e) for e in task.engines if e in known_engines]

        all_results = []
        for query_text in task.queries:
            results = await batch_service.query_batch(
                engines=engines,
                query=query_text,
                brand_name=brand_name,
                competitor_names=task.competitor_names,
            )
            all_results.extend(results)

        return all_results

    async def _generate_alerts_if_needed(
        self,
        task: DetectionTask,
        results: list,
        db: AsyncSession,
    ) -> list:
        """Create a "competitor_overtake" alert when only competitors are cited.

        An alert is requested when at least one result has a competitor
        citation and no result has a brand citation.

        Returns:
            The created alerts; empty when there are no results, the brand is
            missing, the condition is not met, or the alert engine returned
            nothing.
        """
        if not results:
            return []

        brand = await self._get_brand(task.brand_id, db)
        if not brand:
            return []

        brand_cited = sum(1 for r in results if r.has_brand_citation)
        competitor_cited = sum(1 for r in results if r.has_competitor_citation)

        alerts = []
        if competitor_cited > 0 and brand_cited == 0:
            # NOTE(review): imported locally, presumably to avoid a circular
            # import with the alert package -- confirm before moving it.
            from app.services.alert.alert_engine import AlertEngine

            alert_engine = AlertEngine(db)
            alert = await alert_engine._create_alert(
                brand_id=task.brand_id,
                user_id=task.user_id,
                alert_type="competitor_overtake",
                severity="warning",
                title=f"{brand.name} 未被引用但竞品被引用",
                message=f"检测任务「{task.name}」执行后发现:品牌未被AI引擎引用,但竞品被引用了 {competitor_cited} 次。",
                data={"task_id": str(task.id), "competitor_cited": competitor_cited},
            )
            if alert:
                alerts.append(alert)

        return alerts

    @staticmethod
    def _validate_frequency(frequency: str) -> None:
        """Raise ``ValueError`` unless ``frequency`` is in ``VALID_FREQUENCIES``."""
        if frequency not in VALID_FREQUENCIES:
            raise ValueError(
                f"Invalid frequency: {frequency}. Must be one of {VALID_FREQUENCIES}"
            )

    async def _get_task_by_id(
        self,
        task_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> DetectionTask | None:
        """Return the task with ``task_id`` owned by ``user_id``, or ``None``."""
        stmt = select(DetectionTask).where(
            and_(
                DetectionTask.id == task_id,
                DetectionTask.user_id == user_id,
            )
        )
        result = await db.execute(stmt)
        return result.scalar_one_or_none()

    async def _get_brand(
        self,
        brand_id: uuid.UUID,
        db: AsyncSession,
    ) -> Brand | None:
        """Return the brand with ``brand_id``, or ``None`` when it is missing."""
        stmt = select(Brand).where(Brand.id == brand_id)
        result = await db.execute(stmt)
        return result.scalar_one_or_none()