181 lines
6.0 KiB
Python
181 lines
6.0 KiB
Python
"""GEO优化Agent - SEO/GEO双重优化"""
|
||
|
||
import json
|
||
import logging
|
||
import time
|
||
from datetime import datetime, timezone
|
||
|
||
from app.agent_framework.base import BaseAgent
|
||
from app.agent_framework.prompts import GEO_OPTIMIZER_TEMPLATE
|
||
from app.agent_framework.protocol import (
|
||
AgentCapability,
|
||
AgentType,
|
||
TaskMessage,
|
||
TaskResult,
|
||
TaskStatus,
|
||
)
|
||
from app.services.llm import LLMFactory, LLMError
|
||
from app.utils.json_extractor import extract_json
|
||
|
||
# Module-level logger named after this module (standard `logging.getLogger(__name__)` convention).
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class GEOOptimizerAgent(BaseAgent):
    """GEO/SEO content optimization agent: improves visibility in AI search engines.

    Supported task types:
        - geo_optimize: apply GEO/SEO optimization to an article.
    """

    # The LLM output budget scales with the article length in characters.
    TOKENS_PER_CHAR: int = 3
    # Floor for the output budget. Without it, short articles get a budget
    # too small to hold the structured JSON reply, which then gets truncated
    # and falls back to the "non-standard output" path.
    MIN_MAX_TOKENS: int = 1024
    # Optional ceiling (int) for the output budget; None keeps it unbounded,
    # as before. Set to the provider's output limit if long articles fail.
    MAX_TOKENS_CAP = None

    def __init__(self):
        super().__init__(
            name="geo_optimizer",
            agent_type=AgentType.GEO_OPTIMIZER,
            version="1.0.0",
        )

    def get_capabilities(self) -> AgentCapability:
        """Describe this agent: identity, supported task types and concurrency limit."""
        return AgentCapability(
            agent_name=self.name,
            agent_type=self.agent_type,
            version=self.version,
            supported_tasks=["geo_optimize"],
            max_concurrency=2,
            description="GEO/SEO内容优化Agent:提升内容在AI搜索引擎中的可见性和引用率",
        )

    async def execute(self, task: TaskMessage) -> TaskResult:
        """Run a GEO optimization task and wrap the outcome in a TaskResult.

        Task-level failures are not propagated: an ``LLMError`` or any other
        ``Exception`` raised while optimizing (or while building the success
        result) is converted into a FAILED ``TaskResult``.
        """
        started_at = datetime.now(timezone.utc)
        start_time = time.monotonic()

        try:
            output = await self._optimize(task)
            # Built inside the try so a failure constructing the result is
            # still reported as a FAILED task rather than raised to the caller.
            return self._build_result(
                task,
                started_at=started_at,
                start_time=start_time,
                status=TaskStatus.COMPLETED,
                output_data=output,
            )
        except LLMError as e:
            logger.error("GEOOptimizer LLM error on task %s: %s", task.task_id, e)
            return self._build_result(
                task,
                started_at=started_at,
                start_time=start_time,
                status=TaskStatus.FAILED,
                error_message=f"LLM调用失败: {e}",
            )
        except Exception as e:
            # Unexpected failure: record the traceback so it can be diagnosed.
            logger.exception("GEOOptimizer task %s failed: %s", task.task_id, e)
            return self._build_result(
                task,
                started_at=started_at,
                start_time=start_time,
                status=TaskStatus.FAILED,
                error_message=str(e),
            )

    def _build_result(
        self,
        task: TaskMessage,
        *,
        started_at: datetime,
        start_time: float,
        status: TaskStatus,
        output_data=None,
        error_message=None,
    ) -> TaskResult:
        """Assemble a TaskResult for *task*, including timing metrics.

        Args:
            task: The task being reported on.
            started_at: Wall-clock UTC time at which execution started.
            start_time: ``time.monotonic()`` reading taken at start; used to
                compute ``elapsed_seconds``.
            status: Final task status.
            output_data: Optimization output (dict) on success, else None.
            error_message: Failure description, or None on success.
        """
        elapsed = time.monotonic() - start_time
        return TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=status,
            output_data=output_data,
            error_message=error_message,
            started_at=started_at,
            completed_at=datetime.now(timezone.utc),
            metrics={
                "elapsed_seconds": round(elapsed, 2),
                "task_type": task.task_type,
            },
        )

    def _token_budget(self, content: str) -> int:
        """Return the ``max_tokens`` budget for optimizing *content*.

        The budget is ``len(content) * TOKENS_PER_CHAR``. It is raised to at
        least ``MIN_MAX_TOKENS`` and, if ``MAX_TOKENS_CAP`` is set, lowered
        to at most that value.
        """
        budget = max(len(content) * self.TOKENS_PER_CHAR, self.MIN_MAX_TOKENS)
        if self.MAX_TOKENS_CAP is not None:
            budget = min(budget, self.MAX_TOKENS_CAP)
        return budget

    @staticmethod
    def _format_keywords(keywords) -> str:
        """Render target keywords as a comma-separated string for the prompt.

        Accepts an iterable of keywords or a single pre-formatted string
        (which previously got split into individual characters). Returns
        "未指定" when no keywords are given.
        """
        if not keywords:
            return "未指定"
        if isinstance(keywords, str):
            return keywords
        return ", ".join(str(keyword) for keyword in keywords)

    @staticmethod
    def _parse_output(raw):
        """Parse the LLM reply into a dict; return None if it is not a JSON object."""
        try:
            parsed = json.loads(extract_json(raw))
        except Exception as e:
            # extract_json's failure modes are not visible from here, so any
            # parsing failure is treated as "reply is not JSON".
            logger.warning("GEO优化结果JSON解析失败,返回原始内容: %s", e)
            return None
        if not isinstance(parsed, dict):
            logger.warning(
                "GEO优化结果JSON解析失败,返回原始内容: %s",
                f"expected a JSON object, got {type(parsed).__name__}",
            )
            return None
        return parsed

    async def _optimize(self, task: TaskMessage) -> dict:
        """Run GEO optimization for one task.

        input_data fields:
            - content: str (required; the article to optimize, must not be blank)
            - target_keywords: list[str] (expected; a single string is also
              accepted, and a missing/empty value is rendered as "未指定")
            - target_platform: str (optional; defaults to "通用")
            - optimization_level: str (optional: light/moderate/aggressive;
              defaults to "moderate")

        Returns:
            The JSON object parsed from the LLM reply plus a "usage" key. If
            the reply is not a JSON object, a fallback dict with keys
            "optimized_content" (raw reply), "seo_score" (None), "changes"
            and "usage".

        Raises:
            ValueError: If ``content`` is missing, not a string, or blank.
            Errors from the LLM provider (presumably ``LLMError``, which
            ``execute`` handles) propagate unchanged.
        """
        input_data = task.input_data or {}
        content = input_data.get("content", "")

        if not isinstance(content, str) or not content.strip():
            raise ValueError("input_data必须包含非空的'content'字段")

        # Progress: started.
        await self.report_progress(
            task_id=task.task_id,
            progress=0.1,
            message="开始GEO优化...",
        )

        variables = {
            "original_content": content,
            "target_keywords": self._format_keywords(input_data.get("target_keywords")),
            # `or` so that an explicit None/empty value also falls back to the default.
            "target_platform": input_data.get("target_platform") or "通用",
            "optimization_level": input_data.get("optimization_level") or "moderate",
        }
        messages = GEO_OPTIMIZER_TEMPLATE.render(variables)

        # Progress: calling the LLM.
        await self.report_progress(
            task_id=task.task_id,
            progress=0.3,
            message="正在调用LLM进行GEO优化...",
        )

        provider = LLMFactory.get_default()
        response = await provider.chat(
            messages,
            temperature=0.5,
            max_tokens=self._token_budget(content),
        )

        # Progress: LLM call finished, about to parse its output.
        await self.report_progress(
            task_id=task.task_id,
            progress=0.9,
            message="LLM优化完成,解析输出...",
        )

        result = self._parse_output(response.content)
        if result is None:
            # The LLM did not return a JSON object: wrap the raw reply in the basic format.
            await self.report_progress(
                task_id=task.task_id,
                progress=1.0,
                message="GEO优化完成(输出非标准JSON格式)",
            )
            return {
                "optimized_content": response.content,
                "seo_score": None,
                "changes": ["LLM输出非标准格式,已返回原始优化结果"],
                "usage": response.usage,
            }

        result["usage"] = response.usage
        # Progress: done.
        await self.report_progress(
            task_id=task.task_id,
            progress=1.0,
            message="GEO优化完成",
        )
        return result
|