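"""
Scheduled task runner for citation queries.

- Uses APScheduler's AsyncIOScheduler.
- Every hour, looks up rows in the ``queries`` table with status='active'
  and next_query_at <= now(), and runs each matching query through
  CitationEngine.
- Every minute, as a fallback, executes query tasks that are still
  'pending' more than one minute after their scheduled time.
"""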
|
||
|
||
import asyncio
|
||
import logging
|
||
from datetime import datetime, timedelta, timezone
|
||
|
||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||
from apscheduler.triggers.interval import IntervalTrigger
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.database import AsyncSessionLocal
|
||
from app.models.citation_record import CitationRecord
|
||
from app.models.query import Query
|
||
from app.models.query_task import QueryTask
|
||
from app.workers.citation_engine import CitationEngine
|
||
|
||
# Module-scoped logger; records are emitted under this module's dotted name.
logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
class QueryScheduler:
    """Periodic runner for due queries and leftover pending query tasks.

    ``start()`` registers two APScheduler interval jobs:

    * ``check_queries`` (hourly): runs every ``Query`` whose status is
      ``"active"`` and whose ``next_query_at`` has passed, via
      ``CitationEngine.execute_query``.
    * ``check_pending_tasks`` (every minute): a fallback that executes
      ``QueryTask`` rows still ``"pending"`` more than one minute after their
      ``scheduled_at``, one platform call per task.

    The job callbacks are synchronous wrappers that hand the real coroutine to
    the event loop captured in ``start()``.
    """

    def __init__(self):
        self.scheduler = AsyncIOScheduler()
        self.engine = CitationEngine()
        # Event loop the job coroutines are submitted to; set by start().
        self._loop = None
        # Keys of checks currently in progress. The sync wrappers return as
        # soon as the coroutine is submitted to the loop, so the scheduler's
        # per-job instance limit cannot stop a slow run from overlapping the
        # next trigger; this set does.
        self._running_checks = set()

    @staticmethod
    def _utcnow_naive():
        """Return the current UTC time as a naive ``datetime``.

        Same value as the deprecated ``datetime.utcnow()``, which this class
        previously used for ``QueryTask`` timestamps and comparisons.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def start(self):
        """Register both interval jobs and start the scheduler.

        Captures the event loop that job coroutines will be submitted to, so
        this should preferably be called from inside the running loop (for
        example an application startup hook).
        """
        try:
            self._loop = asyncio.get_running_loop()
        except RuntimeError:
            # Not inside a running loop: fall back to the previous lookup.
            self._loop = asyncio.get_event_loop()

        self.scheduler.add_job(
            self._run_check,
            trigger=IntervalTrigger(hours=1),
            id="check_queries",
            name="检查并执行到期的查询任务",
            replace_existing=True,
        )
        self.scheduler.add_job(
            self._run_pending_tasks_check,
            trigger=IntervalTrigger(minutes=1),
            id="check_pending_tasks",
            name="检查并执行遗留的pending查询任务",
            replace_existing=True,
        )
        self.scheduler.start()
        logger.info("查询调度器已启动,每小时检查一次待执行任务,每分钟检查一次遗留pending任务")

    def _submit(self, coro_factory):
        """Run the coroutine produced by ``coro_factory()`` on the captured loop.

        If the loop captured by start() is running, the coroutine is handed
        to it thread-safely and this returns immediately without waiting.
        Otherwise the coroutine runs to completion on a fresh loop via
        ``asyncio.run``, blocking the caller.
        """
        # NOTE(review): the asyncio.run fallback uses a brand-new event loop;
        # async DB pools created on another loop may not work there. Confirm
        # this path is only hit in tests or scripts.
        if self._loop and self._loop.is_running():
            asyncio.run_coroutine_threadsafe(coro_factory(), self._loop)
        else:
            asyncio.run(coro_factory())

    def _run_check(self):
        """Synchronous job wrapper for check_and_execute_queries()."""
        self._submit(self.check_and_execute_queries)

    def _run_pending_tasks_check(self):
        """Synchronous job wrapper for check_and_execute_pending_tasks()."""
        self._submit(self.check_and_execute_pending_tasks)

    async def _run_exclusive(self, key: str, coro_factory):
        """Await ``coro_factory()`` unless a run registered under *key* is active.

        A trigger that arrives while the previous run for *key* is still in
        progress is skipped (with a warning) rather than run concurrently.
        This prevents two runs from executing the same rows twice.
        """
        if key in self._running_checks:
            logger.warning(f"{key} 上一次执行尚未结束,跳过本次触发")
            return
        self._running_checks.add(key)
        try:
            await coro_factory()
        finally:
            self._running_checks.discard(key)

    async def check_and_execute_queries(self):
        """Execute every active query whose ``next_query_at`` is due.

        Errors are logged and never raised to the caller.
        """
        logger.info("开始检查待执行查询任务...")
        await self._run_exclusive("check_queries", self._check_and_execute_queries)

    async def _check_and_execute_queries(self):
        """Collect due query IDs, then run each query in its own session.

        A separate session per query means a failed query (and the rollback
        that closing its session performs) cannot poison the session used by
        the queries after it.
        """
        try:
            async with AsyncSessionLocal() as db:
                # NOTE(review): next_query_at is compared with an aware UTC
                # datetime, while QueryTask timestamps use naive UTC; confirm
                # both match their column types.
                now = datetime.now(timezone.utc)
                stmt = select(Query.id).where(
                    Query.status == "active",
                    Query.next_query_at <= now,
                )
                query_ids = (await db.execute(stmt)).scalars().all()
        except Exception as e:
            logger.exception(f"检查查询任务时出错: {e}")
            return

        logger.info(f"找到 {len(query_ids)} 个待执行查询")

        for query_id in query_ids:
            try:
                async with AsyncSessionLocal() as db:
                    query = await db.get(Query, query_id)
                    if query is None:
                        # Row was deleted between the scan and now.
                        continue
                    await self._execute_single_query(query, db)
            except Exception as e:
                logger.exception(f"执行查询 {query_id} 失败: {e}")

    async def _execute_single_query(self, query: Query, db: AsyncSession):
        """Run one query through ``CitationEngine.execute_query`` using *db*.

        Logs the start and the outcome; on failure logs and re-raises.
        """
        # Read the id up front while the instance is freshly loaded.
        query_id = query.id
        logger.info(f"开始执行查询: {query.keyword} (ID: {query_id})")
        try:
            await self.engine.execute_query(query, db)
        except Exception as e:
            logger.error(f"查询 {query_id} 执行失败: {e}")
            raise
        logger.info(f"查询 {query_id} 执行完成")

    async def check_and_execute_pending_tasks(self):
        """Fallback: execute ``QueryTask`` rows still pending after one minute.

        Tasks whose parent query is missing or not ``"active"`` are left
        untouched. Errors are logged and never raised to the caller.
        """
        logger.info("检查并执行遗留的 pending 查询任务...")
        await self._run_exclusive(
            "check_pending_tasks", self._check_and_execute_pending_tasks
        )

    async def _check_and_execute_pending_tasks(self):
        """Scan for overdue pending tasks and execute them grouped by query."""
        from collections import defaultdict

        async with AsyncSessionLocal() as db:
            try:
                cutoff = self._utcnow_naive() - timedelta(minutes=1)
                stmt = select(QueryTask).where(
                    QueryTask.status == "pending",
                    QueryTask.scheduled_at <= cutoff,
                )
                tasks = (await db.execute(stmt)).scalars().all()
                logger.info(f"找到 {len(tasks)} 个遗留的 pending 任务")

                # Keep only plain values from here on. A rollback expires
                # every ORM instance in the session, and touching an expired
                # attribute outside an awaited call fails under AsyncSession.
                task_ids_by_query = defaultdict(list)
                for task in tasks:
                    task_ids_by_query[task.query_id].append(task.id)

                query_params = {}
                if task_ids_by_query:
                    # One round-trip for all parent queries instead of one per query_id.
                    query_stmt = select(Query).where(
                        Query.id.in_(list(task_ids_by_query))
                    )
                    for query in (await db.execute(query_stmt)).scalars():
                        if query.status == "active":
                            query_params[query.id] = (
                                query.keyword,
                                query.target_brand,
                                query.brand_aliases,
                            )

                for query_id, task_ids in task_ids_by_query.items():
                    params = query_params.get(query_id)
                    if params is None:
                        # Parent query missing or no longer active.
                        continue
                    keyword, target_brand, brand_aliases = params
                    for task_id in task_ids:
                        await self._execute_pending_task(
                            db, query_id, task_id, keyword, target_brand, brand_aliases
                        )
            except Exception as e:
                logger.exception(f"检查遗留任务时出错: {e}")

    async def _execute_pending_task(
        self, db: AsyncSession, query_id, task_id, keyword, target_brand, brand_aliases
    ):
        """Claim, execute and record a single pending task.

        Steps:

        1. Mark the task ``running`` and commit.
        2. Call ``CitationEngine.execute_single_platform``.
        3. If a result is returned, store a ``CitationRecord``.
        4. Mark the task ``success``.

        On any error the partial work is rolled back and the task is marked
        ``failed`` with the error text.
        """
        # Re-read the row from the DB. The scan result may be stale, and an
        # earlier rollback may have expired the instance. Skip tasks that are
        # no longer pending, e.g. because another executor picked them up.
        task = await db.get(QueryTask, task_id, populate_existing=True)
        if task is None or task.status != "pending":
            return
        platform = task.platform

        try:
            task.status = "running"
            task.started_at = self._utcnow_naive()
            task.error_message = None
            await db.commit()

            citation_result = await self.engine.execute_single_platform(
                keyword=keyword,
                platform=platform,
                target_brand=target_brand,
                brand_aliases=brand_aliases or [],
            )

            if citation_result:
                record = CitationRecord(
                    query_id=query_id,
                    platform=platform,
                    cited=citation_result.get("cited", False),
                    citation_position=citation_result.get("position"),
                    citation_text=citation_result.get("citation_text"),
                    competitor_brands=citation_result.get("competitor_brands", []),
                    raw_response=citation_result.get("raw_response", ""),
                    confidence=citation_result.get("confidence"),
                    match_type=citation_result.get("match_type"),
                    # Citation-source analysis fields
                    data_source=citation_result.get("data_source"),
                    source_urls=citation_result.get("source_urls"),
                    source_titles=citation_result.get("source_titles"),
                    citation_contexts=citation_result.get("citation_contexts"),
                    ai_response_text=citation_result.get("ai_response_text", ""),
                )
                db.add(record)

            task.status = "success"
            task.completed_at = self._utcnow_naive()
            await db.commit()
        except Exception as e:
            # Log first, so the failure is visible even if recording it fails.
            logger.exception(f"执行遗留任务 {task_id} 失败: {e}")
            await self._mark_task_failed(db, task, task_id, e)

    async def _mark_task_failed(self, db: AsyncSession, task, task_id, error):
        """Roll back partial work and persist ``failed`` status on *task*.

        A failure while recording is logged instead of raised, so that one
        bad task does not abort the rest of the run.
        """
        try:
            await db.rollback()
            task.status = "failed"
            task.error_message = str(error)
            task.completed_at = self._utcnow_naive()
            await db.commit()
        except Exception:
            await db.rollback()
            logger.exception(f"记录遗留任务 {task_id} 失败状态时出错")

    async def shutdown(self):
        """Stop the scheduler (if it was started) and close the engine."""
        # shutdown() on a scheduler that never started raises, which would
        # also skip closing the engine.
        if self.scheduler.running:
            self.scheduler.shutdown(wait=False)
        await self.engine.close()
        logger.info("查询调度器已关闭")
|
||
|
||
|
||
# Process-wide QueryScheduler, created at import time; its jobs are only
# registered once start() is called on it.
query_scheduler = QueryScheduler()

# Alias of the underlying APScheduler object, exported so tests can import
# ``scheduler`` from this module directly.
scheduler = query_scheduler.scheduler
|
||
|
||
|
||
def run_job_now(job_id: str):
    """Invoke the callback of the registered job *job_id* immediately.

    The callback runs synchronously in the calling thread with the job's
    stored positional and keyword arguments. The job's schedule is left
    unchanged.

    Args:
        job_id: ID the job was registered under (e.g. ``"check_queries"``).

    Returns:
        ``True`` if the job exists and its callback was invoked, ``False`` if
        no job with that ID is registered.
    """
    job = query_scheduler.scheduler.get_job(job_id)
    if job is None:
        return False
    # Forward the job's configured arguments, as the scheduler itself does
    # when it runs the job; a bare job.func() breaks jobs registered with args.
    job.func(*job.args, **job.kwargs)
    return True
|