import logging
import time
import uuid
from datetime import datetime, timezone

from sqlalchemy import func, select

from app.agent_framework.base import BaseAgent
from app.agent_framework.protocol import (
    AgentCapability,
    AgentType,
    TaskMessage,
    TaskResult,
    TaskStatus,
)
from app.database import AsyncSessionLocal
from app.models.monitoring import MonitoringRecord
from app.models.query import Query
from app.models.brand import Brand
from app.services.monitoring.monitor_service import MonitorService

logger = logging.getLogger(__name__)


class MonitorAgent(BaseAgent):
    """Performance-tracking agent.

    Supported task types (see ``get_capabilities`` / ``execute``):

    * ``monitor_track`` -- re-checks every active ``MonitoringRecord`` of a
      brand through ``MonitorService`` and collects the change reports.
    * ``monitor_check_single`` -- creates one monitoring record for a
      keyword/platform pair, checks it immediately and returns its report.
    """

    def __init__(self):
        super().__init__(
            name="monitor",
            agent_type=AgentType.PERFORMANCE_TRACKER,
            version="1.0.0",
        )

    def get_capabilities(self) -> AgentCapability:
        """Describe the task types and concurrency limit this agent supports."""
        return AgentCapability(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["monitor_track", "monitor_check_single"],
            max_concurrency=3,
            description="效果追踪Agent:监测品牌引用量、情感、排名变化,生成变化报告",
        )

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Dispatch *task* to its handler and wrap the outcome in a ``TaskResult``.

        Exceptions raised while handling the task (including an unsupported
        ``task_type``) are logged with their traceback and converted into a
        ``TaskStatus.FAILED`` result whose ``error_message`` is ``str(exc)``.
        Both result kinds carry wall-clock ``elapsed_seconds`` in ``metrics``.
        """
        started_at = datetime.now(timezone.utc)
        start_time = time.monotonic()
        try:
            task_type = task.task_type
            if task_type == "monitor_track":
                output = await self._monitor_track(task)
            elif task_type == "monitor_check_single":
                output = await self._monitor_check_single(task)
            else:
                raise ValueError(f"不支持的任务类型: {task_type}")

            elapsed = time.monotonic() - start_time
            return TaskResult(
                task_id=task.task_id,
                agent_name=self.name,
                status=TaskStatus.COMPLETED,
                output_data=output,
                error_message=None,
                started_at=started_at,
                completed_at=datetime.now(timezone.utc),
                metrics={
                    "elapsed_seconds": round(elapsed, 2),
                    "task_type": task_type,
                },
            )
        except Exception as e:
            elapsed = time.monotonic() - start_time
            # logger.exception keeps the traceback; logging only str(e) made
            # failures of this agent very hard to diagnose.
            logger.exception("MonitorAgent task %s failed: %s", task.task_id, e)
            return TaskResult(
                task_id=task.task_id,
                agent_name=self.name,
                status=TaskStatus.FAILED,
                output_data=None,
                error_message=str(e),
                started_at=started_at,
                completed_at=datetime.now(timezone.utc),
                metrics={
                    "elapsed_seconds": round(elapsed, 2),
                    "task_type": task.task_type,
                },
            )

    @staticmethod
    def _parse_brand_id(input_data: dict) -> uuid.UUID:
        """Extract the required ``brand_id`` from task input as a ``uuid.UUID``.

        Accepts either a ``uuid.UUID`` instance or its string form.

        Raises:
            ValueError: if ``brand_id`` is missing/empty or is not a valid UUID.
        """
        brand_id = input_data.get("brand_id")
        if not brand_id:
            raise ValueError("input_data必须包含'brand_id'字段")
        if isinstance(brand_id, uuid.UUID):
            return brand_id
        return uuid.UUID(str(brand_id))

    async def _monitor_track(self, task: TaskMessage) -> dict:
        """Re-check all active monitoring records of a brand.

        Reads ``brand_id`` from ``task.input_data``. If no ``Query`` row has
        ``target_brand`` equal to the brand's name, returns immediately without
        checking any record. Otherwise every ``MonitoringRecord`` of the brand
        with ``status == "active"`` is passed to
        ``MonitorService.check_and_compare``; each truthy result gets a change
        report, and truthy reports are collected.

        Returns:
            dict with ``brand_id``, ``brand_name``, ``total_queries``,
            ``checked_records`` and ``reports``.

        Raises:
            ValueError: if ``brand_id`` is missing/invalid or the brand does
                not exist.
        """
        input_data = task.input_data
        brand_id = self._parse_brand_id(input_data)

        await self.report_progress(
            task_id=task.task_id,
            progress=0.1,
            message="开始效果追踪...",
        )

        async with AsyncSessionLocal() as db:
            brand_result = await db.execute(select(Brand).where(Brand.id == brand_id))
            brand = brand_result.scalar_one_or_none()
            if not brand:
                raise ValueError(f"品牌不存在: {brand_id}")

            # Only the number of linked queries is reported, so let the
            # database count them instead of materialising every Query row.
            count_stmt = (
                select(func.count())
                .select_from(Query)
                .where(Query.target_brand == brand.name)
            )
            total_queries = (await db.execute(count_stmt)).scalar_one()

            if not total_queries:
                await self.report_progress(
                    task_id=task.task_id,
                    progress=1.0,
                    message="品牌下无关联查询,效果追踪完成",
                )
                return {
                    "brand_id": str(brand_id),
                    "brand_name": brand.name,
                    "total_queries": 0,
                    # Same key as the normal return so callers see one shape.
                    "checked_records": 0,
                    "reports": [],
                }

            reports = []
            service = MonitorService()

            monitoring_stmt = select(MonitoringRecord).where(
                MonitoringRecord.brand_id == brand_id,
                MonitoringRecord.status == "active",
            )
            monitoring_result = await db.execute(monitoring_stmt)
            monitoring_records = list(monitoring_result.scalars().all())
            total_records = len(monitoring_records)

            for idx, record in enumerate(monitoring_records):
                # Record checks occupy the 0.2 .. 0.9 slice of the progress bar.
                progress = 0.2 + (0.7 * idx / total_records)
                await self.report_progress(
                    task_id=task.task_id,
                    progress=progress,
                    message=f"正在检测第 {idx + 1}/{len(monitoring_records)} 条监测记录...",
                )

                updated_record = await service.check_and_compare(db, record.id)
                if updated_record:
                    report = await service.generate_change_report(db, updated_record.id)
                    if report:
                        reports.append(report)

            await self.report_progress(
                task_id=task.task_id,
                progress=1.0,
                message="效果追踪完成",
            )

            return {
                "brand_id": str(brand_id),
                "brand_name": brand.name,
                "total_queries": total_queries,
                "checked_records": total_records,
                "reports": reports,
            }

    async def _monitor_check_single(self, task: TaskMessage) -> dict:
        """Create a monitoring record for one keyword/platform and check it now.

        Reads ``brand_id`` (required), ``keyword``, ``platform`` and
        ``check_interval_hours`` (default 24) from ``task.input_data``; the
        optional values are forwarded to ``MonitorService`` unvalidated.

        Returns:
            dict with ``record_id``, ``brand_id``, ``keyword``, ``platform``,
            ``change_type`` (``None`` when the check returned nothing) and
            ``report`` (``None`` when no report was generated).

        Raises:
            ValueError: if ``brand_id`` is missing or invalid.
        """
        input_data = task.input_data
        brand_id = self._parse_brand_id(input_data)
        keyword = input_data.get("keyword")
        platform = input_data.get("platform")

        await self.report_progress(
            task_id=task.task_id,
            progress=0.1,
            message="开始单关键词检测...",
        )

        async with AsyncSessionLocal() as db:
            service = MonitorService()

            await self.report_progress(
                task_id=task.task_id,
                progress=0.3,
                message="正在创建监测记录...",
            )

            record = await service.create_monitoring_record(
                db=db,
                brand_id=brand_id,
                query_keywords=keyword,
                platform=platform,
                check_interval_hours=input_data.get("check_interval_hours", 24),
            )

            await self.report_progress(
                task_id=task.task_id,
                progress=0.6,
                message="正在执行检测对比...",
            )

            updated_record = await service.check_and_compare(db, record.id)

            await self.report_progress(
                task_id=task.task_id,
                progress=0.8,
                message="正在生成变化报告...",
            )

            report = None
            if updated_record:
                report = await service.generate_change_report(db, updated_record.id)

            await self.report_progress(
                task_id=task.task_id,
                progress=1.0,
                message="单关键词检测完成",
            )

            return {
                "record_id": str(record.id),
                "brand_id": str(brand_id),
                "keyword": keyword,
                "platform": platform,
                "change_type": updated_record.change_type if updated_record else None,
                "report": report,
            }