167 lines
5.1 KiB
Python
167 lines
5.1 KiB
Python
import logging
|
||
import time
|
||
from datetime import datetime, timezone
|
||
|
||
from app.agent_framework.base import BaseAgent
|
||
from app.agent_framework.protocol import (
|
||
AgentCapability,
|
||
AgentType,
|
||
TaskMessage,
|
||
TaskResult,
|
||
TaskStatus,
|
||
)
|
||
|
||
# Module-level logger for this agent; task failures caught in
# TrendAgent.execute are logged through it.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class TrendAgent(BaseAgent):
    """Trend-insight agent backed by ``TrendAnalyzerService``.

    Supported task types (dispatched on ``TaskMessage.task_type``):

    * ``"trend_insight"`` -- calls ``TrendAnalyzerService.analyze_trends``.
    * ``"trend_hotspot"`` -- calls ``TrendAnalyzerService.get_hotspots``.

    :meth:`execute` never lets an ``Exception`` escape: failures are logged
    (with traceback) and returned as a ``TaskStatus.FAILED`` result.
    """

    def __init__(self):
        super().__init__(
            name="trend_agent",
            agent_type=AgentType.TREND_AGENT,
            version="1.0.0",
        )

    def _task_handlers(self):
        """Return the ``task_type -> handler coroutine`` mapping.

        This is the single source of truth for both the advertised task list
        (``get_capabilities``) and the dispatch in ``execute``, so the two
        cannot drift apart. Insertion order is the advertised order.
        """
        return {
            "trend_insight": self._trend_insight,
            "trend_hotspot": self._trend_hotspot,
        }

    def get_capabilities(self) -> AgentCapability:
        """Describe this agent: identity, supported task types and concurrency limit."""
        return AgentCapability(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=list(self._task_handlers()),
            max_concurrency=2,
            description="趋势洞察Agent:分析品牌引用趋势、识别热点话题、推断变化原因并生成建议",
        )

    def _build_result(
        self,
        task: TaskMessage,
        started_at: datetime,
        start_time: float,
        status: TaskStatus,
        *,
        output_data=None,
        error_message=None,
    ) -> TaskResult:
        """Assemble a ``TaskResult`` with timing metrics.

        Shared by the success and failure paths of :meth:`execute`.

        Args:
            task: The task being reported on.
            started_at: Wall-clock (UTC) start time recorded by ``execute``.
            start_time: ``time.monotonic()`` value recorded at start; used for
                ``elapsed_seconds`` so wall-clock adjustments do not skew it.
            status: Final task status.
            output_data: Handler output on success, ``None`` otherwise.
            error_message: ``str(exc)`` on failure, ``None`` otherwise.
        """
        elapsed = time.monotonic() - start_time
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=status,
            output_data=output_data,
            error_message=error_message,
            started_at=started_at,
            completed_at=datetime.now(timezone.utc),
            metrics={
                "elapsed_seconds": round(elapsed, 2),
                "task_type": task.task_type,
            },
        )

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Run ``task`` and wrap its outcome in a :class:`TaskResult`.

        Any ``Exception`` raised while dispatching, running the handler or
        building the success result (including an unsupported ``task_type``
        or a missing ``brand_id``) is logged with its traceback and returned
        as a ``TaskStatus.FAILED`` result whose ``error_message`` is
        ``str(exc)``.
        """
        started_at = datetime.now(timezone.utc)
        start_time = time.monotonic()

        try:
            handler = self._task_handlers().get(task.task_type)
            if handler is None:
                raise ValueError(f"不支持的任务类型: {task.task_type}")
            output = await handler(task)
            # Built inside the try so that a failure constructing the result
            # (e.g. output rejected by TaskResult) is still reported as FAILED.
            return self._build_result(
                task, started_at, start_time, TaskStatus.COMPLETED, output_data=output
            )
        except Exception as e:
            # logger.exception records the traceback; lazy %-args avoid
            # formatting when the record is filtered out.
            logger.exception("TrendAgent task %s failed: %s", task.task_id, e)
            return self._build_result(
                task, started_at, start_time, TaskStatus.FAILED, error_message=str(e)
            )

    @staticmethod
    def _input_data(task: TaskMessage) -> dict:
        """Return ``task.input_data`` (``{}`` if unset), requiring a truthy ``brand_id``.

        Raises:
            ValueError: if ``brand_id`` is missing or falsy.
        """
        # A missing input_data would otherwise surface as an opaque
        # AttributeError on ``None.get``; report the real problem instead.
        input_data = task.input_data or {}
        if not input_data.get("brand_id"):
            raise ValueError("input_data必须包含'brand_id'字段")
        return input_data

    async def _trend_insight(self, task: TaskMessage) -> dict:
        """Handle ``trend_insight``: run ``TrendAnalyzerService.analyze_trends``.

        Reads ``brand_id`` (required), ``days`` (default 30), ``platforms``
        and ``keywords`` (default ``None``) from ``task.input_data`` and
        returns the service's result unchanged.

        Raises:
            ValueError: if ``brand_id`` is missing.
        """
        input_data = self._input_data(task)
        brand_id = input_data["brand_id"]
        days = input_data.get("days", 30)
        platforms = input_data.get("platforms")
        keywords = input_data.get("keywords")

        await self.report_progress(
            task_id=task.task_id,
            progress=0.1,
            message="开始趋势洞察分析...",
        )

        # Imported lazily; presumably to avoid import cycles or DB setup at
        # module import time -- NOTE(review): confirm before hoisting.
        from app.database import AsyncSessionLocal
        from app.services.trend.trend_analyzer_service import TrendAnalyzerService

        async with AsyncSessionLocal() as db:
            service = TrendAnalyzerService(db)

            # NOTE(review): the 0.3 and 0.5 updates are emitted back-to-back
            # before the single analyze_trends call, so they do not reflect
            # separate phases of work.
            await self.report_progress(
                task_id=task.task_id,
                progress=0.3,
                message="获取历史引用数据...",
            )
            await self.report_progress(
                task_id=task.task_id,
                progress=0.5,
                message="执行时间序列分析...",
            )

            result = await service.analyze_trends(
                brand_id=brand_id,
                days=days,
                platforms=platforms,
                keywords=keywords,
            )

            await self.report_progress(
                task_id=task.task_id,
                progress=1.0,
                message="趋势洞察分析完成",
            )

            return result

    async def _trend_hotspot(self, task: TaskMessage) -> dict:
        """Handle ``trend_hotspot``: run ``TrendAnalyzerService.get_hotspots``.

        Reads ``brand_id`` (required) and ``days`` (default 30) from
        ``task.input_data`` and returns the service's result unchanged.

        Raises:
            ValueError: if ``brand_id`` is missing.
        """
        input_data = self._input_data(task)
        brand_id = input_data["brand_id"]
        days = input_data.get("days", 30)

        await self.report_progress(
            task_id=task.task_id,
            progress=0.1,
            message="开始热点话题分析...",
        )

        # Imported lazily, mirroring _trend_insight.
        from app.database import AsyncSessionLocal
        from app.services.trend.trend_analyzer_service import TrendAnalyzerService

        async with AsyncSessionLocal() as db:
            service = TrendAnalyzerService(db)

            await self.report_progress(
                task_id=task.task_id,
                progress=0.5,
                message="检测引用量突增的关键词/话题...",
            )

            result = await service.get_hotspots(
                brand_id=brand_id,
                days=days,
            )

            await self.report_progress(
                task_id=task.task_id,
                progress=1.0,
                message="热点话题分析完成",
            )

            return result