232 lines
7.8 KiB
Python
232 lines
7.8 KiB
Python
import logging
|
||
import time
|
||
import uuid
|
||
from datetime import datetime, timezone
|
||
|
||
from sqlalchemy import select
|
||
|
||
from app.agent_framework.base import BaseAgent
|
||
from app.agent_framework.protocol import (
|
||
AgentCapability,
|
||
AgentType,
|
||
TaskMessage,
|
||
TaskResult,
|
||
TaskStatus,
|
||
)
|
||
from app.database import AsyncSessionLocal
|
||
from app.models.monitoring import MonitoringRecord
|
||
from app.models.query import Query
|
||
from app.models.brand import Brand
|
||
from app.services.monitoring.monitor_service import MonitorService
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
class MonitorAgent(BaseAgent):
    """Performance-tracking agent for brand monitoring.

    Supported task types:

    * ``monitor_track``: re-checks every active ``MonitoringRecord`` of a brand
      through ``MonitorService`` and collects the resulting change reports.
    * ``monitor_check_single``: creates one monitoring record for a
      keyword/platform pair, checks it immediately and returns its report.
    """

    def __init__(self):
        super().__init__(
            name="monitor",
            agent_type=AgentType.PERFORMANCE_TRACKER,
            version="1.0.0",
        )

    def get_capabilities(self) -> AgentCapability:
        """Describe the task types and concurrency limit of this agent."""
        return AgentCapability(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["monitor_track", "monitor_check_single"],
            max_concurrency=3,
            description="效果追踪Agent:监测品牌引用量、情感、排名变化,生成变化报告",
        )

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Dispatch *task* to its handler and wrap the outcome in a TaskResult.

        Any ``Exception`` raised by a handler (including an unsupported task
        type) is logged with its traceback and turned into a ``FAILED`` result
        instead of being propagated.
        """
        started_at = datetime.now(timezone.utc)
        start_time = time.monotonic()

        try:
            task_type = task.task_type
            if task_type == "monitor_track":
                output = await self._monitor_track(task)
            elif task_type == "monitor_check_single":
                output = await self._monitor_check_single(task)
            else:
                raise ValueError(f"不支持的任务类型: {task_type}")

            return self._build_result(
                task, TaskStatus.COMPLETED, started_at, start_time, output_data=output
            )
        except Exception as e:
            # logger.exception keeps the traceback; the rendered message is unchanged.
            logger.exception("MonitorAgent task %s failed: %s", task.task_id, e)
            return self._build_result(
                task, TaskStatus.FAILED, started_at, start_time, error_message=str(e)
            )

    def _build_result(
        self,
        task: TaskMessage,
        status,
        started_at: datetime,
        start_time: float,
        *,
        output_data=None,
        error_message=None,
    ) -> TaskResult:
        """Build a TaskResult with the timing metrics shared by every outcome.

        Args:
            task: the task being reported on.
            status: final ``TaskStatus`` of the task.
            started_at: wall-clock start time (UTC).
            start_time: ``time.monotonic()`` reading taken at start, used for
                the ``elapsed_seconds`` metric.
            output_data: handler output on success, ``None`` otherwise.
            error_message: error text on failure, ``None`` otherwise.
        """
        elapsed = time.monotonic() - start_time
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=status,
            output_data=output_data,
            error_message=error_message,
            started_at=started_at,
            completed_at=datetime.now(timezone.utc),
            metrics={
                "elapsed_seconds": round(elapsed, 2),
                "task_type": task.task_type,
            },
        )

    @staticmethod
    def _parse_brand_id(input_data: dict) -> uuid.UUID:
        """Extract and validate ``brand_id`` from a task's input payload.

        Accepts a ``uuid.UUID`` instance or any value whose string form
        ``uuid.UUID`` can parse.

        Raises:
            ValueError: if ``brand_id`` is missing/empty or not a valid UUID.
        """
        raw = input_data.get("brand_id")
        if not raw:
            raise ValueError("input_data必须包含'brand_id'字段")
        if isinstance(raw, uuid.UUID):
            # uuid.UUID(<UUID>) would fail with AttributeError.
            return raw
        try:
            return uuid.UUID(str(raw))
        except ValueError as err:
            raise ValueError(f"无效的brand_id: {raw}") from err

    async def _monitor_track(self, task: TaskMessage) -> dict:
        """Re-check all active monitoring records of one brand.

        Returns:
            A dict with ``brand_id``, ``brand_name``, ``total_queries`` (number
            of ``Query`` rows whose ``target_brand`` equals the brand name),
            ``checked_records`` and ``reports`` (the truthy change reports).

        Raises:
            ValueError: if ``brand_id`` is missing/invalid or the brand does
                not exist.
        """
        # Local import: the module header only imports `select` from sqlalchemy.
        from sqlalchemy import func

        input_data = task.input_data or {}
        brand_id = self._parse_brand_id(input_data)

        await self.report_progress(
            task_id=task.task_id,
            progress=0.1,
            message="开始效果追踪...",
        )

        async with AsyncSessionLocal() as db:
            brand_result = await db.execute(select(Brand).where(Brand.id == brand_id))
            brand = brand_result.scalar_one_or_none()
            if not brand:
                raise ValueError(f"品牌不存在: {brand_id}")

            # Only the number of queries is reported, so count in the database
            # instead of loading every matching Query row into memory.
            count_stmt = (
                select(func.count())
                .select_from(Query)
                .where(Query.target_brand == brand.name)
            )
            total_queries = (await db.execute(count_stmt)).scalar_one()

            if total_queries == 0:
                # NOTE(review): this returns before looking at MonitoringRecord
                # rows, which are matched by brand_id rather than via Query —
                # confirm that a brand's active records should be skipped here.
                await self.report_progress(
                    task_id=task.task_id,
                    progress=1.0,
                    message="品牌下无关联查询,效果追踪完成",
                )
                return {
                    "brand_id": str(brand_id),
                    "brand_name": brand.name,
                    "total_queries": 0,
                    "checked_records": 0,
                    "reports": [],
                }

            reports = []
            service = MonitorService()

            # Select primary keys up front rather than ORM objects. If
            # MonitorService commits and the session expires loaded objects,
            # reading `record.id` later would need an implicit lazy load,
            # which AsyncSession does not allow.
            ids_stmt = select(MonitoringRecord.id).where(
                MonitoringRecord.brand_id == brand_id,
                MonitoringRecord.status == "active",
            )
            record_ids = list((await db.execute(ids_stmt)).scalars().all())
            record_count = len(record_ids)

            for idx, record_id in enumerate(record_ids):
                # Spread per-record progress over the 0.2–0.9 range.
                await self.report_progress(
                    task_id=task.task_id,
                    progress=0.2 + 0.7 * idx / record_count,
                    message=f"正在检测第 {idx + 1}/{record_count} 条监测记录...",
                )

                updated_record = await service.check_and_compare(db, record_id)
                if updated_record:
                    report = await service.generate_change_report(db, updated_record.id)
                    if report:
                        reports.append(report)

            await self.report_progress(
                task_id=task.task_id,
                progress=1.0,
                message="效果追踪完成",
            )

            return {
                "brand_id": str(brand_id),
                "brand_name": brand.name,
                "total_queries": total_queries,
                "checked_records": record_count,
                "reports": reports,
            }

    async def _monitor_check_single(self, task: TaskMessage) -> dict:
        """Create one monitoring record for a keyword/platform and check it now.

        Reads ``brand_id`` (required), ``keyword``, ``platform`` and optional
        ``check_interval_hours`` (default 24) from ``task.input_data``.

        Returns:
            A dict with ``record_id``, ``brand_id``, ``keyword``, ``platform``,
            ``change_type`` (``None`` when the check returned nothing) and
            ``report`` (``None`` when no report was generated).

        Raises:
            ValueError: if ``brand_id`` is missing or not a valid UUID.
        """
        input_data = task.input_data or {}
        brand_id = self._parse_brand_id(input_data)
        # NOTE(review): keyword/platform are passed through unvalidated and may
        # be None — confirm MonitorService.create_monitoring_record accepts that.
        keyword = input_data.get("keyword")
        platform = input_data.get("platform")

        await self.report_progress(
            task_id=task.task_id,
            progress=0.1,
            message="开始单关键词检测...",
        )

        async with AsyncSessionLocal() as db:
            service = MonitorService()

            await self.report_progress(
                task_id=task.task_id,
                progress=0.3,
                message="正在创建监测记录...",
            )

            record = await service.create_monitoring_record(
                db=db,
                brand_id=brand_id,
                query_keywords=keyword,
                platform=platform,
                check_interval_hours=input_data.get("check_interval_hours", 24),
            )
            # Capture the id once; later service calls may commit and expire
            # `record`, which would make attribute access lazy-load.
            record_id = record.id

            await self.report_progress(
                task_id=task.task_id,
                progress=0.6,
                message="正在执行检测对比...",
            )

            updated_record = await service.check_and_compare(db, record_id)

            await self.report_progress(
                task_id=task.task_id,
                progress=0.8,
                message="正在生成变化报告...",
            )

            report = None
            if updated_record:
                report = await service.generate_change_report(db, updated_record.id)

            await self.report_progress(
                task_id=task.task_id,
                progress=1.0,
                message="单关键词检测完成",
            )

            return {
                "record_id": str(record_id),
                "brand_id": str(brand_id),
                "keyword": keyword,
                "platform": platform,
                "change_type": updated_record.change_type if updated_record else None,
                "report": report,
            }
|