geo/backend/app/agent_framework/agents/competitor_analyzer.py

164 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
import time
from datetime import datetime, timezone
from app.agent_framework.base import BaseAgent
from app.agent_framework.protocol import (
AgentCapability,
AgentType,
TaskMessage,
TaskResult,
TaskStatus,
)
from app.services.llm import LLMFactory, LLMError
# Module-level logger, named after this module via __name__.
logger: logging.Logger = logging.getLogger(name=__name__)
class CompetitorAnalyzerAgent(BaseAgent):
    """Agent that runs competitor-strategy analysis for a brand.

    Task types handled by ``execute``:

    * ``"competitor_gap_analysis"`` -- calls
      ``CompetitorAnalyzerService.analyze_competitor`` restricted to the
      analysis types in ``GAP_ANALYSIS_TYPES``.
    * any other type (advertised as ``"competitor_analyze"``) -- makes the same
      service call with the caller-supplied ``analysis_types``.

    ``execute`` converts every ``Exception`` raised while running a task into
    a FAILED ``TaskResult``; ``BaseException`` subclasses are not caught.
    """

    # Analysis types requested for "competitor_gap_analysis" tasks. Stored as
    # a tuple so the shared class attribute cannot be mutated; a fresh list is
    # passed to the service on every call.
    GAP_ANALYSIS_TYPES = ("citation_gap", "platform_coverage", "query_overlap")

    def __init__(self):
        super().__init__(
            name="competitor_analyzer",
            agent_type=AgentType.COMPETITOR_ANALYZER,
            version="1.0.0",
        )

    def get_capabilities(self) -> AgentCapability:
        """Describe the tasks this agent accepts and its concurrency limit."""
        return AgentCapability(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["competitor_analyze", "competitor_gap_analysis"],
            max_concurrency=2,
            description="竞品策略分析Agent对比品牌与竞品的引用数据识别差距领域发现机会点生成策略建议",
        )

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Run ``task`` and wrap the outcome in a ``TaskResult``.

        Dispatches on ``task.task_type``: ``"competitor_gap_analysis"`` goes to
        ``_gap_analysis``; every other value falls through to ``_analyze``.

        Returns:
            A COMPLETED result whose ``output_data`` is the service output, or
            a FAILED result with ``output_data=None``. For ``LLMError`` the
            ``error_message`` is prefixed with "LLM调用失败: "; for any other
            ``Exception`` it is ``str(e)``. Both carry ``elapsed_seconds`` and
            ``task_type`` metrics.
        """
        started_at = datetime.now(timezone.utc)
        # Monotonic clock for the duration metric so wall-clock adjustments
        # cannot skew it; started_at remains wall-clock for reporting.
        start_time = time.monotonic()
        try:
            if task.task_type == "competitor_gap_analysis":
                output = await self._gap_analysis(task)
            else:
                output = await self._analyze(task)
            # Built inside the try so an error while constructing the success
            # result is still reported as FAILED instead of escaping.
            return self._build_result(
                task,
                started_at,
                start_time,
                TaskStatus.COMPLETED,
                output_data=output,
            )
        except LLMError as e:
            logger.error(
                "CompetitorAnalyzer LLM error on task %s: %s", task.task_id, e
            )
            return self._build_result(
                task,
                started_at,
                start_time,
                TaskStatus.FAILED,
                error_message=f"LLM调用失败: {e}",
            )
        except Exception as e:
            # Top-level boundary: convert to a FAILED result, but keep the
            # traceback in the log (logger.error alone would drop it).
            logger.exception("CompetitorAnalyzer task %s failed: %s", task.task_id, e)
            return self._build_result(
                task,
                started_at,
                start_time,
                TaskStatus.FAILED,
                error_message=str(e),
            )

    def _build_result(
        self,
        task: TaskMessage,
        started_at: datetime,
        start_time: float,
        status: TaskStatus,
        *,
        output_data=None,
        error_message=None,
    ) -> TaskResult:
        """Assemble a ``TaskResult`` with the timing metrics shared by all outcomes.

        Args:
            task: The task being reported on (supplies id and type).
            started_at: Wall-clock UTC time the task started.
            start_time: ``time.monotonic()`` reading taken at the start.
            status: Final status to report.
            output_data: Service output on success, otherwise ``None``.
            error_message: Failure description, otherwise ``None``.
        """
        elapsed = time.monotonic() - start_time
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=status,
            output_data=output_data,
            error_message=error_message,
            started_at=started_at,
            completed_at=datetime.now(timezone.utc),
            metrics={
                "elapsed_seconds": round(elapsed, 2),
                "task_type": task.task_type,
            },
        )

    @staticmethod
    def _input_data(task: TaskMessage) -> dict:
        """Return ``task.input_data``, treating a missing (``None``/empty) payload as ``{}``.

        Without this a ``None`` payload failed with an opaque
        ``'NoneType' object has no attribute 'get'`` instead of the explicit
        ``brand_id`` validation error.
        """
        return task.input_data or {}

    async def _analyze(self, task: TaskMessage) -> dict:
        """Run the competitor analysis described by ``task.input_data``.

        Reads ``brand_id`` (required), ``analysis_types`` (passed through
        unchanged; ``None`` when absent) and ``period_days`` (default 30).

        Raises:
            ValueError: if ``brand_id`` is missing or falsy.
        """
        input_data = self._input_data(task)
        return await self._run_analysis(
            task,
            brand_id=input_data.get("brand_id"),
            analysis_types=input_data.get("analysis_types"),
            period_days=input_data.get("period_days", 30),
            start_message="开始竞品策略分析...",
            done_message="竞品策略分析完成",
        )

    async def _gap_analysis(self, task: TaskMessage) -> dict:
        """Run the gap-focused analysis (``GAP_ANALYSIS_TYPES``) for ``task``.

        Reads ``brand_id`` (required) and ``period_days`` (default 30); any
        ``analysis_types`` in the payload is ignored.

        Raises:
            ValueError: if ``brand_id`` is missing or falsy.
        """
        input_data = self._input_data(task)
        return await self._run_analysis(
            task,
            brand_id=input_data.get("brand_id"),
            analysis_types=list(self.GAP_ANALYSIS_TYPES),
            period_days=input_data.get("period_days", 30),
            start_message="开始竞品差距分析...",
            done_message="竞品差距分析完成",
        )

    async def _run_analysis(
        self,
        task: TaskMessage,
        *,
        brand_id,
        analysis_types,
        period_days,
        start_message: str,
        done_message: str,
    ) -> dict:
        """Validate inputs, call the analyzer service and bracket it with progress reports.

        Progress is reported at 0.1 with ``start_message``, then by the service
        through the callback, then at 1.0 with ``done_message``.

        Returns:
            Whatever ``CompetitorAnalyzerService.analyze_competitor`` returns.

        Raises:
            ValueError: if ``brand_id`` is missing or falsy.
        """
        # Kept as a function-local import as in the original code; presumably
        # to avoid an import cycle or import cost at module load -- TODO confirm.
        from app.services.competitor.competitor_analyzer_service import CompetitorAnalyzerService

        if not brand_id:
            raise ValueError("input_data必须包含'brand_id'字段")
        # NOTE(review): period_days is passed through unvalidated.

        await self.report_progress(
            task_id=task.task_id,
            progress=0.1,
            message=start_message,
        )
        service = CompetitorAnalyzerService()
        result = await service.analyze_competitor(
            brand_id=brand_id,
            analysis_types=analysis_types,
            period_days=period_days,
            # Called as (progress, message); it returns the un-awaited
            # report_progress coroutine, which the service presumably awaits.
            progress_callback=lambda p, m: self.report_progress(
                task_id=task.task_id, progress=p, message=m,
            ),
        )
        await self.report_progress(
            task_id=task.task_id,
            progress=1.0,
            message=done_message,
        )
        return result