geo/backend/app/agent_framework/agents/trend_agent.py

167 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import time
from datetime import datetime, timezone
from app.agent_framework.base import BaseAgent
from app.agent_framework.protocol import (
AgentCapability,
AgentType,
TaskMessage,
TaskResult,
TaskStatus,
)
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class TrendAgent(BaseAgent):
    """Agent that runs brand trend analyses via ``TrendAnalyzerService``.

    Two task types are handled: ``"trend_insight"`` and ``"trend_hotspot"``.
    Both require a ``brand_id`` entry in the task's ``input_data``.
    """

    def __init__(self):
        """Register the agent under its fixed name, type and version."""
        super().__init__(
            name="trend_agent",
            agent_type=AgentType.TREND_AGENT,
            version="1.0.0",
        )

    def get_capabilities(self) -> AgentCapability:
        """Return the capability descriptor advertised by this agent."""
        return AgentCapability(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["trend_insight", "trend_hotspot"],
            max_concurrency=2,
            description="趋势洞察Agent分析品牌引用趋势、识别热点话题、推断变化原因并生成建议",
        )

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Run *task* and wrap the outcome in a ``TaskResult``.

        Dispatches on ``task.task_type``. Any ``Exception`` raised while
        handling the task (including an unsupported task type) is logged with
        its traceback and converted into a ``TaskStatus.FAILED`` result
        instead of propagating to the caller.

        Args:
            task: The task to run; ``task_type`` selects the handler.

        Returns:
            A ``TaskResult`` with status ``COMPLETED`` and the handler output,
            or status ``FAILED`` and a non-empty ``error_message``.
        """
        started_at = datetime.now(timezone.utc)
        # Monotonic clock for the duration metric: unaffected by wall-clock
        # adjustments, unlike ``started_at``.
        start_time = time.monotonic()
        try:
            if task.task_type == "trend_insight":
                output = await self._trend_insight(task)
            elif task.task_type == "trend_hotspot":
                output = await self._trend_hotspot(task)
            else:
                raise ValueError(f"不支持的任务类型: {task.task_type}")
            # Built inside the ``try`` so that a failure while constructing
            # the success result is also reported as a FAILED result.
            return self._build_result(
                task,
                started_at,
                start_time,
                status=TaskStatus.COMPLETED,
                output_data=output,
            )
        except Exception as e:
            # ``logger.exception`` logs at ERROR level and keeps the traceback.
            logger.exception("TrendAgent task %s failed: %s", task.task_id, e)
            return self._build_result(
                task,
                started_at,
                start_time,
                status=TaskStatus.FAILED,
                # Some exceptions (e.g. a bare TimeoutError) stringify to "";
                # fall back to the class name so the failure is identifiable.
                error_message=str(e) or type(e).__name__,
            )

    def _build_result(
        self,
        task: TaskMessage,
        started_at: datetime,
        start_time: float,
        *,
        status,
        output_data=None,
        error_message=None,
    ) -> TaskResult:
        """Assemble the ``TaskResult`` shared by the success and failure paths.

        Args:
            task: The task the result belongs to.
            started_at: Wall-clock start timestamp recorded by ``execute``.
            start_time: ``time.monotonic()`` reading taken at start.
            status: The ``TaskStatus`` to report.
            output_data: Handler output, or ``None`` on failure.
            error_message: Failure description, or ``None`` on success.
        """
        elapsed = time.monotonic() - start_time
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=status,
            output_data=output_data,
            error_message=error_message,
            started_at=started_at,
            completed_at=datetime.now(timezone.utc),
            metrics={
                "elapsed_seconds": round(elapsed, 2),
                "task_type": task.task_type,
            },
        )

    @staticmethod
    def _require_brand_input(task: TaskMessage) -> dict:
        """Return the task's input mapping after checking ``brand_id``.

        A missing (``None``) ``input_data`` is treated as empty so that the
        caller gets the explicit validation error rather than an
        ``AttributeError``.

        Raises:
            ValueError: If ``brand_id`` is absent or falsy.
        """
        input_data = task.input_data or {}
        if not input_data.get("brand_id"):
            raise ValueError("input_data必须包含'brand_id'字段")
        return input_data

    async def _trend_insight(self, task: TaskMessage) -> dict:
        """Handle a ``"trend_insight"`` task.

        Reads ``brand_id`` (required), ``days`` (default 30), ``platforms``
        and ``keywords`` from ``task.input_data``, reports progress, and
        returns the result of ``TrendAnalyzerService.analyze_trends``.

        Raises:
            ValueError: If ``brand_id`` is absent or falsy.
        """
        input_data = self._require_brand_input(task)
        brand_id = input_data["brand_id"]
        days = input_data.get("days", 30)
        platforms = input_data.get("platforms")
        keywords = input_data.get("keywords")

        await self.report_progress(
            task_id=task.task_id,
            progress=0.1,
            message="开始趋势洞察分析...",
        )

        # Imported at call time rather than module import time
        # (presumably to avoid an import cycle - TODO confirm).
        from app.database import AsyncSessionLocal
        from app.services.trend.trend_analyzer_service import TrendAnalyzerService

        async with AsyncSessionLocal() as db:
            service = TrendAnalyzerService(db)
            await self.report_progress(
                task_id=task.task_id,
                progress=0.3,
                message="获取历史引用数据...",
            )
            await self.report_progress(
                task_id=task.task_id,
                progress=0.5,
                message="执行时间序列分析...",
            )
            result = await service.analyze_trends(
                brand_id=brand_id,
                days=days,
                platforms=platforms,
                keywords=keywords,
            )

        await self.report_progress(
            task_id=task.task_id,
            progress=1.0,
            message="趋势洞察分析完成",
        )
        return result

    async def _trend_hotspot(self, task: TaskMessage) -> dict:
        """Handle a ``"trend_hotspot"`` task.

        Reads ``brand_id`` (required) and ``days`` (default 30) from
        ``task.input_data``, reports progress, and returns the result of
        ``TrendAnalyzerService.get_hotspots``.

        Raises:
            ValueError: If ``brand_id`` is absent or falsy.
        """
        input_data = self._require_brand_input(task)
        brand_id = input_data["brand_id"]
        days = input_data.get("days", 30)

        await self.report_progress(
            task_id=task.task_id,
            progress=0.1,
            message="开始热点话题分析...",
        )

        # Imported at call time rather than module import time
        # (presumably to avoid an import cycle - TODO confirm).
        from app.database import AsyncSessionLocal
        from app.services.trend.trend_analyzer_service import TrendAnalyzerService

        async with AsyncSessionLocal() as db:
            service = TrendAnalyzerService(db)
            await self.report_progress(
                task_id=task.task_id,
                progress=0.5,
                message="检测引用量突增的关键词/话题...",
            )
            result = await service.get_hotspots(
                brand_id=brand_id,
                days=days,
            )

        await self.report_progress(
            task_id=task.task_id,
            progress=1.0,
            message="热点话题分析完成",
        )
        return result