geo/backend/app/services/detection_scheduler.py

227 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import uuid
from datetime import datetime, timezone
from sqlalchemy import and_, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.brand import Brand
from app.models.detection_task import VALID_FREQUENCIES, DetectionTask
from app.services.ai_engine.base import EngineType
from app.services.ai_engine.batch_query import BatchQueryService, _build_adapters
# Module-level logger; records carry this module's dotted name.
logger = logging.getLogger(__name__)
class TaskNotFoundError(Exception):
    """Signals that a detection task lookup, scoped to the requesting user, found nothing.

    ``DetectionSchedulerService.update_task`` raises it when no task with the
    given id is owned by the given user.
    """
class DetectionSchedulerService:
    """Create, update, delete, list and execute scheduled detection tasks.

    Every task lookup that takes a ``user_id`` is filtered on that user, so a
    caller can only see, modify, delete or trigger tasks it owns.
    """

    # Keys of ``task_data`` that ``update_task`` refuses to write: identity and
    # ownership columns (``brand_id`` is not re-validated by ``update_task``)
    # plus timestamps this service maintains itself.
    _PROTECTED_UPDATE_FIELDS = frozenset(
        {"id", "user_id", "brand_id", "created_at", "last_run_at", "next_run_at"}
    )

    async def create_task(
        self,
        task_data: dict,
        brand_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> DetectionTask:
        """Create, schedule and persist a new detection task.

        Args:
            task_data: Task fields. ``name`` is required; ``frequency``
                defaults to ``"daily"``, ``engines`` and ``queries`` default to
                empty lists, and ``competitor_names`` defaults to ``None``.
            brand_id: Brand the task belongs to.
            user_id: Owner of the new task.
            db: Session used to insert and commit the task.

        Returns:
            The committed and refreshed task, with ``next_run_at`` set.

        Raises:
            KeyError: If ``task_data`` has no ``"name"`` key.
            ValueError: If the frequency is not in ``VALID_FREQUENCIES``.
        """
        frequency = task_data.get("frequency", "daily")
        self._validate_frequency(frequency)
        task = DetectionTask(
            brand_id=brand_id,
            user_id=user_id,
            name=task_data["name"],
            frequency=frequency,
            engines=task_data.get("engines", []),
            queries=task_data.get("queries", []),
            competitor_names=task_data.get("competitor_names"),
        )
        task.next_run_at = task.compute_next_run_at()
        db.add(task)
        await db.commit()
        await db.refresh(task)
        logger.info("Created detection task: %s, name=%s", task.id, task.name)
        return task

    async def update_task(
        self,
        task_id: uuid.UUID,
        task_data: dict,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> DetectionTask:
        """Apply ``task_data`` to one of the user's tasks and reschedule it.

        Keys in ``_PROTECTED_UPDATE_FIELDS`` are skipped with a warning, so a
        request payload cannot reassign the task's identity or ownership or
        overwrite scheduler-managed timestamps. ``next_run_at`` is always
        recomputed after the update.

        Returns:
            The committed and refreshed task.

        Raises:
            TaskNotFoundError: If no task with ``task_id`` is owned by ``user_id``.
            ValueError: If the resulting frequency is not in ``VALID_FREQUENCIES``.
        """
        task = await self._get_task_by_id(task_id, user_id, db)
        if not task:
            raise TaskNotFoundError(f"Detection task {task_id} not found")
        # Validate before mutating anything, so a bad frequency leaves the
        # task untouched.
        frequency = task_data.get("frequency", task.frequency)
        self._validate_frequency(frequency)
        for field, value in task_data.items():
            if field in self._PROTECTED_UPDATE_FIELDS:
                logger.warning(
                    "Ignoring protected field %r in update of detection task %s",
                    field,
                    task_id,
                )
                continue
            setattr(task, field, value)
        task.next_run_at = task.compute_next_run_at()
        await db.commit()
        await db.refresh(task)
        logger.info("Updated detection task: %s", task.id)
        return task

    async def delete_task(
        self,
        task_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> bool:
        """Delete one of the user's tasks.

        Returns:
            ``True`` if the task was deleted, or ``False`` if no task with
            ``task_id`` is owned by ``user_id``.
        """
        task = await self._get_task_by_id(task_id, user_id, db)
        if not task:
            return False
        await db.delete(task)
        await db.commit()
        logger.info("Deleted detection task: %s", task_id)
        return True

    async def get_tasks(
        self,
        brand_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> list[DetectionTask]:
        """Return the user's tasks for ``brand_id``, newest (by ``created_at``) first."""
        stmt = (
            select(DetectionTask)
            .where(
                and_(
                    DetectionTask.brand_id == brand_id,
                    DetectionTask.user_id == user_id,
                )
            )
            .order_by(DetectionTask.created_at.desc())
        )
        result = await db.execute(stmt)
        return list(result.scalars().all())

    async def trigger_task(
        self,
        task_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> dict:
        """Run one of the user's tasks immediately.

        Returns:
            The ``execute_task`` result, or
            ``{"status": "error", "message": "Task not found"}`` when no task
            with ``task_id`` is owned by ``user_id``.
        """
        task = await self._get_task_by_id(task_id, user_id, db)
        if not task:
            return {"status": "error", "message": "Task not found"}
        return await self.execute_task(task, db)

    async def execute_task(
        self,
        task: DetectionTask,
        db: AsyncSession,
    ) -> dict:
        """Run the task's queries, create alerts if warranted, and reschedule it.

        On success, ``last_run_at`` is set to the current UTC time,
        ``next_run_at`` is recomputed, and the session is committed.

        Returns:
            ``{"status": "success", "results": [...]}`` on success, or
            ``{"status": "error", "message": str(exc)}`` if any step raised
            an ``Exception``; such exceptions are logged, not propagated.
        """
        try:
            results = await self._run_batch_query(task, db)
            await self._generate_alerts_if_needed(task, results, db)
            task.last_run_at = datetime.now(timezone.utc)
            task.next_run_at = task.compute_next_run_at()
            await db.commit()
            await db.refresh(task)
            logger.info("Detection task executed: %s", task.id)
            return {"status": "success", "results": results}
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception(
                "Detection task execution failed: %s, error=%s", task.id, e
            )
            # A failed flush/commit leaves the session unusable until it is
            # rolled back. The caller receives a dict, not an exception, so it
            # may well keep using this session afterwards.
            await self._rollback_if_inactive(db)
            return {"status": "error", "message": str(e)}

    @staticmethod
    async def _rollback_if_inactive(db: AsyncSession) -> None:
        """Roll back ``db`` only when a failed flush left it needing a rollback.

        A healthy session is left alone, because a rollback would expire every
        object it holds. A failing rollback is logged rather than raised, so
        ``execute_task`` still returns its error dict.
        """
        if db.is_active:
            return
        try:
            await db.rollback()
        except Exception:
            logger.exception("Rollback after failed detection task execution failed")

    async def _run_batch_query(
        self,
        task: DetectionTask,
        db: AsyncSession,
    ) -> list:
        """Send every query of ``task`` to its configured engines.

        Engine names that are not ``EngineType`` values are skipped with a
        warning. Queries are sent one after another, and the result lists
        returned by ``BatchQueryService.query_batch`` are concatenated in
        query order.
        """
        adapters = _build_adapters()
        batch_service = BatchQueryService(adapters)
        brand = await self._get_brand(task.brand_id, db)
        brand_name = brand.name if brand else "Unknown"

        # The engine list does not depend on the query, so resolve it once.
        configured = task.engines or []
        known_values = {member.value for member in EngineType}
        engines = [EngineType(e) for e in configured if e in known_values]
        unknown = [e for e in configured if e not in known_values]
        if unknown:
            logger.warning(
                "Detection task %s: ignoring unknown engines %s", task.id, unknown
            )

        all_results = []
        for query_text in task.queries or []:
            results = await batch_service.query_batch(
                engines=engines,
                query=query_text,
                brand_name=brand_name,
                competitor_names=task.competitor_names,
            )
            all_results.extend(results)
        return all_results

    async def _generate_alerts_if_needed(
        self,
        task: DetectionTask,
        results: list,
        db: AsyncSession,
    ) -> list:
        """Create a "competitor_overtake" warning when only competitors were cited.

        An alert is created only if no result has a brand citation and at
        least one result has a competitor citation.

        Returns:
            The created alerts. The list is empty when there are no results,
            the brand no longer exists, the condition is not met, or
            ``_create_alert`` returned a falsy value.
        """
        if not results:
            return []
        brand = await self._get_brand(task.brand_id, db)
        if not brand:
            return []
        brand_cited = sum(1 for r in results if r.has_brand_citation)
        competitor_cited = sum(1 for r in results if r.has_competitor_citation)
        alerts = []
        if competitor_cited > 0 and brand_cited == 0:
            # Imported here rather than at module top; presumably to avoid a
            # circular import — confirm before hoisting.
            from app.services.alert_engine import AlertEngine

            alert_engine = AlertEngine(db)
            # Title/message (Chinese, user-facing): "<brand> not cited but
            # competitors cited" / "after running task <name>, the brand was
            # not cited by AI engines but competitors were cited N times."
            alert = await alert_engine._create_alert(
                brand_id=task.brand_id,
                user_id=task.user_id,
                alert_type="competitor_overtake",
                severity="warning",
                title=f"{brand.name} 未被引用但竞品被引用",
                message=f"检测任务「{task.name}」执行后发现品牌未被AI引擎引用但竞品被引用了 {competitor_cited} 次。",
                data={"task_id": str(task.id), "competitor_cited": competitor_cited},
            )
            if alert:
                alerts.append(alert)
        return alerts

    @staticmethod
    def _validate_frequency(frequency: str) -> None:
        """Raise ``ValueError`` unless ``frequency`` is in ``VALID_FREQUENCIES``."""
        if frequency not in VALID_FREQUENCIES:
            raise ValueError(
                f"Invalid frequency: {frequency}. Must be one of {VALID_FREQUENCIES}"
            )

    async def _get_task_by_id(
        self,
        task_id: uuid.UUID,
        user_id: uuid.UUID,
        db: AsyncSession,
    ) -> DetectionTask | None:
        """Return the task with ``task_id`` if it is owned by ``user_id``, else ``None``."""
        stmt = select(DetectionTask).where(
            and_(
                DetectionTask.id == task_id,
                DetectionTask.user_id == user_id,
            )
        )
        result = await db.execute(stmt)
        return result.scalar_one_or_none()

    async def _get_brand(
        self,
        brand_id: uuid.UUID,
        db: AsyncSession,
    ) -> Brand | None:
        """Return the brand with ``brand_id``, or ``None`` if it does not exist."""
        stmt = select(Brand).where(Brand.id == brand_id)
        result = await db.execute(stmt)
        return result.scalar_one_or_none()