import logging
import time
from datetime import datetime, timezone

from app.agent_framework.base import BaseAgent
from app.agent_framework.protocol import (
    AgentCapability,
    AgentType,
    TaskMessage,
    TaskResult,
    TaskStatus,
)

logger = logging.getLogger(__name__)


class TrendAgent(BaseAgent):
    """Agent that runs brand-citation trend analyses through ``TrendAnalyzerService``.

    Supported ``task_type`` values:

    * ``"trend_insight"`` -- delegates to ``TrendAnalyzerService.analyze_trends``.
    * ``"trend_hotspot"`` -- delegates to ``TrendAnalyzerService.get_hotspots``.
    """

    def __init__(self):
        super().__init__(
            name="trend_agent",
            agent_type=AgentType.TREND_AGENT,
            version="1.0.0",
        )

    def get_capabilities(self) -> AgentCapability:
        """Describe this agent: its supported task types and a concurrency limit of 2."""
        return AgentCapability(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["trend_insight", "trend_hotspot"],
            max_concurrency=2,
            description="趋势洞察Agent:分析品牌引用趋势、识别热点话题、推断变化原因并生成建议",
        )

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Run ``task`` with the handler registered for its ``task_type``.

        Any ``Exception`` raised while dispatching, running the handler, or
        building the success result is logged with its traceback. It is then
        converted into a ``TaskStatus.FAILED`` result whose ``error_message``
        is ``str(exception)``, so ordinary failures are returned, not raised.

        Returns:
            A ``TaskResult`` carrying wall-clock timestamps (UTC) and
            ``metrics`` with ``elapsed_seconds`` (monotonic clock, rounded to
            2 decimals) and the ``task_type``.
        """
        started_at = datetime.now(timezone.utc)
        # Monotonic clock for the duration so wall-clock adjustments cannot
        # produce negative or skewed elapsed times.
        start_time = time.monotonic()
        handlers = {
            "trend_insight": self._trend_insight,
            "trend_hotspot": self._trend_hotspot,
        }
        try:
            handler = handlers.get(task.task_type)
            if handler is None:
                raise ValueError(f"不支持的任务类型: {task.task_type}")
            output = await handler(task)
            # Built inside the try so a failure constructing the success
            # result is still reported as a FAILED TaskResult.
            return self._build_result(
                task,
                TaskStatus.COMPLETED,
                started_at,
                time.monotonic() - start_time,
                output_data=output,
            )
        except Exception as e:
            elapsed = time.monotonic() - start_time
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("TrendAgent task %s failed: %s", task.task_id, e)
            return self._build_result(
                task,
                TaskStatus.FAILED,
                started_at,
                elapsed,
                error_message=str(e),
            )

    def _build_result(
        self,
        task: TaskMessage,
        status: TaskStatus,
        started_at: datetime,
        elapsed: float,
        output_data=None,
        error_message=None,
    ) -> TaskResult:
        """Assemble the ``TaskResult`` shared by the success and failure paths."""
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=status,
            output_data=output_data,
            error_message=error_message,
            started_at=started_at,
            completed_at=datetime.now(timezone.utc),
            metrics={
                "elapsed_seconds": round(elapsed, 2),
                "task_type": task.task_type,
            },
        )

    @staticmethod
    def _require_brand_id(input_data) -> object:
        """Return ``input_data["brand_id"]``.

        Raises:
            ValueError: if the key is missing or its value is falsy.
        """
        brand_id = input_data.get("brand_id")
        if not brand_id:
            raise ValueError("input_data必须包含'brand_id'字段")
        return brand_id

    async def _trend_insight(self, task: TaskMessage) -> dict:
        """Run ``TrendAnalyzerService.analyze_trends`` for the task's brand.

        Reads from ``task.input_data``: ``brand_id`` (required), ``days``
        (default 30), and optional ``platforms`` / ``keywords`` filters, all
        forwarded unchanged to the service.

        Returns:
            Whatever ``analyze_trends`` returns.

        Raises:
            ValueError: if ``brand_id`` is missing or falsy.
        """
        # Treat a missing payload like an empty one so the caller gets the
        # explicit brand_id error instead of an AttributeError on None.
        input_data = task.input_data or {}
        brand_id = self._require_brand_id(input_data)
        days = input_data.get("days", 30)
        platforms = input_data.get("platforms")
        keywords = input_data.get("keywords")

        await self.report_progress(
            task_id=task.task_id,
            progress=0.1,
            message="开始趋势洞察分析...",
        )

        # Deferred imports, presumably to avoid import cycles or DB setup at
        # module import time -- confirm before hoisting.
        from app.database import AsyncSessionLocal
        from app.services.trend.trend_analyzer_service import TrendAnalyzerService

        async with AsyncSessionLocal() as db:
            service = TrendAnalyzerService(db)

            # NOTE(review): both milestones below are reported before the
            # single analyze_trends call; the service exposes no intermediate
            # progress hooks here, so they are approximate.
            await self.report_progress(
                task_id=task.task_id,
                progress=0.3,
                message="获取历史引用数据...",
            )
            await self.report_progress(
                task_id=task.task_id,
                progress=0.5,
                message="执行时间序列分析...",
            )

            result = await service.analyze_trends(
                brand_id=brand_id,
                days=days,
                platforms=platforms,
                keywords=keywords,
            )

            await self.report_progress(
                task_id=task.task_id,
                progress=1.0,
                message="趋势洞察分析完成",
            )
            return result

    async def _trend_hotspot(self, task: TaskMessage) -> dict:
        """Run ``TrendAnalyzerService.get_hotspots`` for the task's brand.

        Reads from ``task.input_data``: ``brand_id`` (required) and ``days``
        (default 30), forwarded unchanged to the service.

        Returns:
            Whatever ``get_hotspots`` returns.

        Raises:
            ValueError: if ``brand_id`` is missing or falsy.
        """
        input_data = task.input_data or {}
        brand_id = self._require_brand_id(input_data)
        days = input_data.get("days", 30)

        await self.report_progress(
            task_id=task.task_id,
            progress=0.1,
            message="开始热点话题分析...",
        )

        # Deferred imports, presumably to avoid import cycles or DB setup at
        # module import time -- confirm before hoisting.
        from app.database import AsyncSessionLocal
        from app.services.trend.trend_analyzer_service import TrendAnalyzerService

        async with AsyncSessionLocal() as db:
            service = TrendAnalyzerService(db)

            await self.report_progress(
                task_id=task.task_id,
                progress=0.5,
                message="检测引用量突增的关键词/话题...",
            )

            result = await service.get_hotspots(
                brand_id=brand_id,
                days=days,
            )

            await self.report_progress(
                task_id=task.task_id,
                progress=1.0,
                message="热点话题分析完成",
            )
            return result