geo/backend/app/agent_framework/agents/monitor_agent.py

232 lines
7.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import time
import uuid
from datetime import datetime, timezone

from sqlalchemy import func, select

from app.agent_framework.base import BaseAgent
from app.agent_framework.protocol import (
    AgentCapability,
    AgentType,
    TaskMessage,
    TaskResult,
    TaskStatus,
)
from app.database import AsyncSessionLocal
from app.models.monitoring import MonitoringRecord
from app.models.query import Query
from app.models.brand import Brand
from app.services.monitoring.monitor_service import MonitorService
# Module-level logger; MonitorAgent.execute uses it to record task failures.
logger = logging.getLogger(__name__)
class MonitorAgent(BaseAgent):
    """Performance-tracking agent built on ``MonitorService``.

    Supported task types (advertised by ``get_capabilities``):

    * ``monitor_track`` -- run ``MonitorService.check_and_compare`` on every
      ``MonitoringRecord`` of a brand whose status is ``"active"`` and
      collect the resulting change reports.
    * ``monitor_check_single`` -- create one monitoring record for a
      keyword/platform pair, check it immediately and return its report.
    """

    def __init__(self):
        super().__init__(
            name="monitor",
            agent_type=AgentType.PERFORMANCE_TRACKER,
            version="1.0.0",
        )

    def get_capabilities(self) -> AgentCapability:
        """Describe the task types and concurrency limit of this agent."""
        return AgentCapability(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["monitor_track", "monitor_check_single"],
            max_concurrency=3,
            description="效果追踪Agent监测品牌引用量、情感、排名变化生成变化报告",
        )

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Dispatch *task* to its handler and wrap the outcome in a TaskResult.

        Handler errors are not propagated. Any exception, including an
        unsupported ``task_type``, is logged with its traceback. It is then
        returned as a ``TaskStatus.FAILED`` result whose ``error_message`` is
        ``str(exc)``.

        Args:
            task: the incoming task; ``task.task_type`` selects the handler.

        Returns:
            A ``TaskResult`` with ``elapsed_seconds`` and ``task_type`` in
            ``metrics``.
        """
        started_at = datetime.now(timezone.utc)
        start_time = time.monotonic()
        handlers = {
            "monitor_track": self._monitor_track,
            "monitor_check_single": self._monitor_check_single,
        }
        try:
            handler = handlers.get(task.task_type)
            if handler is None:
                raise ValueError(f"不支持的任务类型: {task.task_type}")
            output = await handler(task)
            # Built inside the try so that an error while constructing the
            # success result is also reported as FAILED.
            return self._build_result(
                task,
                TaskStatus.COMPLETED,
                started_at,
                start_time,
                output_data=output,
            )
        except Exception as e:
            # logger.exception keeps the traceback that logger.error dropped.
            logger.exception("MonitorAgent task %s failed: %s", task.task_id, e)
            return self._build_result(
                task,
                TaskStatus.FAILED,
                started_at,
                start_time,
                error_message=str(e),
            )

    def _build_result(
        self,
        task: TaskMessage,
        status: TaskStatus,
        started_at: datetime,
        start_time: float,
        *,
        output_data=None,
        error_message=None,
    ) -> TaskResult:
        """Build the TaskResult shared by the success and failure paths.

        Args:
            task: the task being reported on.
            status: final status to record.
            started_at: wall-clock (UTC) start time.
            start_time: ``time.monotonic()`` reading taken at start, used for
                ``elapsed_seconds``.
            output_data: handler output (success path only).
            error_message: failure description (failure path only).
        """
        elapsed = time.monotonic() - start_time
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=status,
            output_data=output_data,
            error_message=error_message,
            started_at=started_at,
            completed_at=datetime.now(timezone.utc),
            metrics={
                "elapsed_seconds": round(elapsed, 2),
                "task_type": task.task_type,
            },
        )

    @staticmethod
    def _require_brand_id(input_data) -> uuid.UUID:
        """Return ``input_data['brand_id']`` as a ``uuid.UUID``.

        Accepts either a ``uuid.UUID`` instance or its string form.
        Passing a UUID instance straight to ``uuid.UUID(...)`` raises
        AttributeError, so instances are returned unchanged.

        Raises:
            ValueError: if the key is missing/empty or the value is not a
                valid UUID.
        """
        raw_brand_id = input_data.get("brand_id")
        if not raw_brand_id:
            raise ValueError("input_data必须包含'brand_id'字段")
        if isinstance(raw_brand_id, uuid.UUID):
            return raw_brand_id
        return uuid.UUID(str(raw_brand_id))

    async def _monitor_track(self, task: TaskMessage) -> dict:
        """Check all active monitoring records of a brand and collect reports.

        Expects ``task.input_data['brand_id']`` (UUID or UUID string).

        Returns:
            dict with the following keys:

            * ``brand_id``
            * ``brand_name``
            * ``total_queries``: the number of ``Query`` rows whose
              ``target_brand`` equals the brand name.
            * ``checked_records``: how many active records were checked.
            * ``reports``: the truthy values returned by
              ``MonitorService.generate_change_report``.

        Raises:
            ValueError: if ``brand_id`` is missing/invalid or no Brand row
                has that id.
        """
        input_data = task.input_data or {}
        brand_id = self._require_brand_id(input_data)
        await self.report_progress(
            task_id=task.task_id,
            progress=0.1,
            message="开始效果追踪...",
        )
        async with AsyncSessionLocal() as db:
            brand_result = await db.execute(select(Brand).where(Brand.id == brand_id))
            brand = brand_result.scalar_one_or_none()
            if not brand:
                raise ValueError(f"品牌不存在: {brand_id}")

            # Only the number of matching queries is reported, so count in
            # SQL instead of loading every Query row into memory.
            count_stmt = (
                select(func.count())
                .select_from(Query)
                .where(Query.target_brand == brand.name)
            )
            total_queries = (await db.execute(count_stmt)).scalar_one()
            if total_queries == 0:
                # NOTE(review): monitoring records below are filtered by
                # brand_id, not by these queries, so this early return also
                # skips any active records the brand has -- confirm intended.
                await self.report_progress(
                    task_id=task.task_id,
                    progress=1.0,
                    message="品牌下无关联查询,效果追踪完成",
                )
                return {
                    "brand_id": str(brand_id),
                    "brand_name": brand.name,
                    "total_queries": 0,
                    # Present for a consistent shape with the normal path.
                    "checked_records": 0,
                    "reports": [],
                }

            service = MonitorService()
            monitoring_stmt = select(MonitoringRecord).where(
                MonitoringRecord.brand_id == brand_id,
                MonitoringRecord.status == "active",
            )
            monitoring_result = await db.execute(monitoring_stmt)
            monitoring_records = list(monitoring_result.scalars().all())
            record_count = len(monitoring_records)

            reports = []
            for idx, record in enumerate(monitoring_records):
                # Spread per-record progress over the 0.2 .. 0.9 range; the
                # loop only runs when record_count >= 1.
                await self.report_progress(
                    task_id=task.task_id,
                    progress=0.2 + 0.7 * idx / record_count,
                    message=f"正在检测第 {idx + 1}/{record_count} 条监测记录...",
                )
                updated_record = await service.check_and_compare(db, record.id)
                if not updated_record:
                    continue
                report = await service.generate_change_report(db, updated_record.id)
                if report:
                    reports.append(report)

            await self.report_progress(
                task_id=task.task_id,
                progress=1.0,
                message="效果追踪完成",
            )
            return {
                "brand_id": str(brand_id),
                "brand_name": brand.name,
                "total_queries": total_queries,
                "checked_records": record_count,
                "reports": reports,
            }

    async def _monitor_check_single(self, task: TaskMessage) -> dict:
        """Create a monitoring record for one keyword, check it, and report.

        Reads the following from ``task.input_data``:

        * ``brand_id`` (required; UUID or UUID string).
        * ``keyword`` and ``platform``: forwarded unchanged to
          ``MonitorService.create_monitoring_record``. They are not validated
          here; whether ``None`` is acceptable depends on that service.
        * ``check_interval_hours`` (defaults to 24).

        Returns:
            dict with the following keys:

            * ``record_id``
            * ``brand_id``
            * ``keyword``
            * ``platform``
            * ``change_type``: None when ``check_and_compare`` returned
              nothing.
            * ``report``: None when there was no updated record.

        Raises:
            ValueError: if ``brand_id`` is missing or invalid.
        """
        input_data = task.input_data or {}
        brand_id = self._require_brand_id(input_data)
        keyword = input_data.get("keyword")
        platform = input_data.get("platform")
        await self.report_progress(
            task_id=task.task_id,
            progress=0.1,
            message="开始单关键词检测...",
        )
        async with AsyncSessionLocal() as db:
            service = MonitorService()
            await self.report_progress(
                task_id=task.task_id,
                progress=0.3,
                message="正在创建监测记录...",
            )
            record = await service.create_monitoring_record(
                db=db,
                brand_id=brand_id,
                query_keywords=keyword,
                platform=platform,
                check_interval_hours=input_data.get("check_interval_hours", 24),
            )
            await self.report_progress(
                task_id=task.task_id,
                progress=0.6,
                message="正在执行检测对比...",
            )
            updated_record = await service.check_and_compare(db, record.id)
            await self.report_progress(
                task_id=task.task_id,
                progress=0.8,
                message="正在生成变化报告...",
            )
            report = (
                await service.generate_change_report(db, updated_record.id)
                if updated_record
                else None
            )
            await self.report_progress(
                task_id=task.task_id,
                progress=1.0,
                message="单关键词检测完成",
            )
            return {
                "record_id": str(record.id),
                "brand_id": str(brand_id),
                "keyword": keyword,
                "platform": platform,
                "change_type": updated_record.change_type if updated_record else None,
                "report": report,
            }