164 lines
5.4 KiB
Python
164 lines
5.4 KiB
Python
import logging
|
||
import time
|
||
from datetime import datetime, timezone
|
||
|
||
from app.agent_framework.base import BaseAgent
|
||
from app.agent_framework.protocol import (
|
||
AgentCapability,
|
||
AgentType,
|
||
TaskMessage,
|
||
TaskResult,
|
||
TaskStatus,
|
||
)
|
||
from app.services.llm import LLMFactory, LLMError
|
||
|
||
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class CompetitorAnalyzerAgent(BaseAgent):
    """Agent that runs competitor analysis tasks.

    Supports the ``competitor_analyze`` and ``competitor_gap_analysis`` task
    types. Both delegate to ``CompetitorAnalyzerService.analyze_competitor``
    and wrap the outcome in a ``TaskResult``.
    """

    # Analysis types requested for a ``competitor_gap_analysis`` task.
    _GAP_ANALYSIS_TYPES = ("citation_gap", "platform_coverage", "query_overlap")

    def __init__(self):
        """Register the agent under its fixed name, type and version."""
        super().__init__(
            name="competitor_analyzer",
            agent_type=AgentType.COMPETITOR_ANALYZER,
            version="1.0.0",
        )

    def get_capabilities(self) -> AgentCapability:
        """Return the capability descriptor advertised by this agent."""
        return AgentCapability(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["competitor_analyze", "competitor_gap_analysis"],
            max_concurrency=2,
            description="竞品策略分析Agent:对比品牌与竞品的引用数据,识别差距领域,发现机会点,生成策略建议",
        )

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Run *task* and return a ``TaskResult`` describing the outcome.

        ``competitor_gap_analysis`` tasks go to ``_gap_analysis``; every other
        task type (including unrecognised ones) goes to ``_analyze``.

        Exceptions derived from ``Exception`` are not propagated: they are
        logged and converted into a ``TaskResult`` with
        ``TaskStatus.FAILED``. ``LLMError`` gets a dedicated error-message
        prefix; any other exception is logged with its traceback.
        """
        started_at = datetime.now(timezone.utc)
        start_time = time.monotonic()

        try:
            if task.task_type == "competitor_gap_analysis":
                output = await self._gap_analysis(task)
            else:
                output = await self._analyze(task)

            # Built inside the ``try`` so that a failure while constructing
            # the success result is still reported as a FAILED result.
            return self._build_task_result(
                task,
                started_at,
                start_time,
                status=TaskStatus.COMPLETED,
                output_data=output,
            )

        except LLMError as e:
            logger.error(
                "CompetitorAnalyzer LLM error on task %s: %s", task.task_id, e
            )
            return self._build_task_result(
                task,
                started_at,
                start_time,
                status=TaskStatus.FAILED,
                error_message=f"LLM调用失败: {e}",
            )

        except Exception as e:
            # Unexpected failure: keep the traceback in the log for diagnosis.
            logger.exception(
                "CompetitorAnalyzer task %s failed: %s", task.task_id, e
            )
            return self._build_task_result(
                task,
                started_at,
                start_time,
                status=TaskStatus.FAILED,
                error_message=str(e),
            )

    def _build_task_result(
        self,
        task: TaskMessage,
        started_at: datetime,
        start_time: float,
        *,
        status,
        output_data=None,
        error_message=None,
    ) -> TaskResult:
        """Assemble the ``TaskResult`` for *task* with timing metrics.

        Args:
            task: The task the result belongs to.
            started_at: Wall-clock (UTC) time at which execution started.
            start_time: ``time.monotonic()`` reading taken at start, used to
                compute the elapsed duration.
            status: ``TaskStatus`` value to report.
            output_data: Analysis output on success, otherwise ``None``.
            error_message: Failure description, otherwise ``None``.
        """
        elapsed = time.monotonic() - start_time
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=status,
            output_data=output_data,
            error_message=error_message,
            started_at=started_at,
            completed_at=datetime.now(timezone.utc),
            metrics={
                "elapsed_seconds": round(elapsed, 2),
                "task_type": task.task_type,
            },
        )

    async def _run_competitor_analysis(
        self,
        task: TaskMessage,
        *,
        analysis_types,
        start_message: str,
        done_message: str,
    ) -> dict:
        """Validate the task input, run the service and report progress.

        Reads ``brand_id`` (required) and ``period_days`` (default 30) from
        ``task.input_data``. Reports progress 0.1 with *start_message* before
        the service call and 1.0 with *done_message* after it.

        Raises:
            ValueError: If ``brand_id`` is missing or falsy.
        """
        # Imported lazily, as in the original code (presumably to avoid an
        # import cycle or import-time cost -- confirm before hoisting).
        from app.services.competitor.competitor_analyzer_service import CompetitorAnalyzerService

        # A missing payload is treated like an empty one so the caller gets
        # the explicit "brand_id" error instead of an AttributeError.
        input_data = task.input_data or {}
        brand_id = input_data.get("brand_id")
        period_days = input_data.get("period_days", 30)

        if not brand_id:
            raise ValueError("input_data必须包含'brand_id'字段")

        await self.report_progress(
            task_id=task.task_id,
            progress=0.1,
            message=start_message,
        )

        service = CompetitorAnalyzerService()
        result = await service.analyze_competitor(
            brand_id=brand_id,
            analysis_types=analysis_types,
            period_days=period_days,
            # Returns the awaitable produced by report_progress; whether the
            # service awaits it is not visible from this file.
            progress_callback=lambda p, m: self.report_progress(
                task_id=task.task_id, progress=p, message=m,
            ),
        )

        await self.report_progress(
            task_id=task.task_id,
            progress=1.0,
            message=done_message,
        )

        return result

    async def _analyze(self, task: TaskMessage) -> dict:
        """Run the general competitor analysis.

        Uses the optional ``analysis_types`` entry of ``task.input_data``
        (``None`` when absent) to select what the service analyses.
        """
        input_data = task.input_data or {}
        return await self._run_competitor_analysis(
            task,
            analysis_types=input_data.get("analysis_types"),
            start_message="开始竞品策略分析...",
            done_message="竞品策略分析完成",
        )

    async def _gap_analysis(self, task: TaskMessage) -> dict:
        """Run the gap analysis with its fixed set of analysis types."""
        return await self._run_competitor_analysis(
            task,
            # Fresh list per call, matching the list literal used originally.
            analysis_types=list(self._GAP_ANALYSIS_TYPES),
            start_message="开始竞品差距分析...",
            done_message="竞品差距分析完成",
        )
|