geo/backend/app/services/content/content_generation_service.py

494 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""ContentGenerationService - 内容生成服务
从 api/content.py 中提取的业务逻辑层,负责:
1. 三阶段内容生成流程(generate -> de-AI -> GEO optimize)
2. 知识库上下文检索
3. 生成结果持久化
4. Agent 框架集成(可选)
API 层只需负责请求解析和响应格式化,所有业务逻辑委托给此服务。
"""
import asyncio
import logging
import uuid
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from app.prompts import (
CONTENT_GENERATOR_TEMPLATE,
DEAI_TEMPLATE,
GEO_OPTIMIZER_TEMPLATE,
)
from app.models.content import Content, ContentVersion
from app.services.llm import LLMFactory, LLMError
# Module-level logger named after this module; used by ContentGenerationService below.
logger = logging.getLogger(__name__)
class ContentGenerationService:
    """Content generation service.

    Wraps the three-stage generation pipeline (generate -> de-AI -> GEO
    optimize), optional knowledge-base retrieval, optional persistence and
    an optional execution path through the agent framework.
    """

    def _get_provider(self):
        """Return the default LLM provider.

        Kept as a separate method so subclasses and tests can override it.
        """
        return LLMFactory.get_default()

    async def _get_knowledge_context(
        self,
        db: AsyncSession,
        brand_name: str,
        knowledge_base_ids: list[str],
        target_keyword: str,
    ) -> str:
        """Retrieve knowledge-base context relevant to the keyword.

        When ``knowledge_base_ids`` is non-empty, ``RAGService.search`` is
        queried (top 3 hits) and every hit with non-empty ``content`` is
        rendered as ``"[<document_title>] <content>"``, one per line.

        This is deliberately best-effort: any failure is logged as a warning
        and an empty string is returned so generation can proceed without
        knowledge context.

        Args:
            db: Session handed to ``RAGService.search``.
            brand_name: Brand name; prefixed to the query when non-empty.
            knowledge_base_ids: Knowledge bases to search.
            target_keyword: Keyword used as (part of) the search query.

        Returns:
            The joined context, or ``""`` when there is nothing to search,
            nothing was found, or the lookup failed.
        """
        if not knowledge_base_ids:
            return ""
        try:
            # Imported lazily, as in the rest of this class.
            from app.services.knowledge.rag_service import RAGService

            rag_service = RAGService()
            results = await rag_service.search(
                session=db,
                query=f"{brand_name} {target_keyword}" if brand_name else target_keyword,
                knowledge_base_ids=knowledge_base_ids,
                top_k=3,
            )
            if not results:
                return ""
            context_parts = []
            for hit in results:
                hit_content = hit.get("content", "")
                hit_title = hit.get("document_title", "")
                if hit_content:
                    context_parts.append(f"[{hit_title}] {hit_content}")
            return "\n".join(context_parts)
        except Exception as e:
            logger.warning(f"知识库检索失败,将不使用知识库上下文: {e}")
            return ""

    async def _poll_task_result(
        self,
        dispatcher,
        task_id: str,
        timeout: int = 300,
    ) -> dict:
        """Poll the agent framework until a task reaches a terminal state.

        The status is checked roughly once per second. The timeout is
        enforced against the event loop's monotonic clock, so time spent
        awaiting ``get_task_status`` counts toward it as well as the sleeps.

        Args:
            dispatcher: Object exposing an awaitable ``get_task_status(task_id)``
                that returns a dict with a ``"status"`` key.
            task_id: ID of the already-dispatched task.
            timeout: Maximum time to wait, in seconds.

        Returns:
            The task's ``output_data`` (``{}`` when missing or empty).

        Raises:
            TimeoutError: The task did not finish within ``timeout`` seconds.
            RuntimeError: The task failed or was cancelled.
        """
        from app.agent_framework.protocol import TaskStatus

        poll_interval = 1.0
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout

        while (remaining := deadline - loop.time()) > 0:
            # Never sleep past the deadline.
            await asyncio.sleep(min(poll_interval, remaining))
            task_status = await dispatcher.get_task_status(task_id)
            status = task_status.get("status")
            if status == TaskStatus.COMPLETED:
                # `or {}` also covers a present-but-null output_data value.
                return task_status.get("output_data") or {}
            if status == TaskStatus.FAILED:
                error_msg = task_status.get("error_message", "Unknown error")
                raise RuntimeError(f"Agent 任务执行失败: {error_msg}")
            if status == TaskStatus.CANCELLED:
                raise RuntimeError(f"Agent 任务被取消: {task_id}")

        raise TimeoutError(f"Agent 任务超时 ({timeout}s): {task_id}")

    async def _dispatch_and_wait(
        self,
        dispatcher,
        *,
        agent_name: str,
        task_type: str,
        input_data: dict,
        timeout: int,
        org_id: str | None,
        user_id: str | None,
    ) -> dict:
        """Dispatch one agent task and wait for its output.

        Builds a ``TaskMessage`` with a fresh UUID, dispatches it and polls
        the returned task ID. The same ``timeout`` is used for the message's
        ``timeout_seconds`` and for polling.

        Returns:
            The task's output data as returned by ``_poll_task_result``.

        Raises:
            Whatever ``dispatcher.dispatch`` or ``_poll_task_result`` raises.
        """
        from app.agent_framework.protocol import TaskMessage

        task_message = TaskMessage(
            task_id=str(uuid.uuid4()),
            agent_name=agent_name,
            task_type=task_type,
            priority=0,
            input_data=input_data,
            callback_url=None,
            created_at=datetime.now(timezone.utc),
            timeout_seconds=timeout,
        )
        dispatched_id = await dispatcher.dispatch(
            task_message,
            organization_id=org_id,
            created_by=user_id,
        )
        # Poll the ID returned by dispatch(), not the locally generated one.
        return await self._poll_task_result(
            dispatcher, dispatched_id, timeout=timeout
        )

    async def _persist_content(
        self,
        db: AsyncSession,
        *,
        keyword: str,
        platform: str,
        content: str,
        optimized: str,
        seo_score,
        stages: list,
        brand_name: str,
        content_style: str,
        word_count: int,
        user_id: str,
        org_id: str,
        change_summary: str,
        execution_mode: str | None = None,
    ) -> str:
        """Store the generated article as a draft ``Content`` plus version 1.

        Args:
            db: Session used for the writes; it is committed here.
            keyword: Used as title and sole keyword of the record.
            platform: Target platform; stored as a one-element list, or an
                empty list when falsy.
            content: Text before GEO optimization; stored in the metadata
                as ``original_content`` only when it differs from ``optimized``.
            optimized: Final text stored as the body.
            seo_score: Score stored in the metadata (may be ``None``).
            stages: Pipeline stage records stored in the metadata.
            brand_name: Stored in the metadata.
            content_style: Stored in the metadata.
            word_count: Stored in the metadata as ``word_count_target``.
            user_id: Recorded as creator of both rows.
            org_id: Owning organization.
            change_summary: Summary text of the initial version.
            execution_mode: When given, added to the metadata as
                ``execution_mode``.

        Returns:
            The new content record's ID as a string.

        Raises:
            Exception: Any error from flush/commit/refresh, re-raised after
                the session has been rolled back.
        """
        extra_metadata = {
            "original_content": content if content != optimized else None,
            "pipeline_stages": stages,
            "seo_score": seo_score,
            "brand_name": brand_name,
            "content_style": content_style,
            "word_count_target": word_count,
        }
        if execution_mode is not None:
            extra_metadata["execution_mode"] = execution_mode

        content_obj = Content(
            organization_id=org_id,
            title=keyword,
            content_type="article",
            body=optimized,
            status="draft",
            target_platforms=[platform] if platform else [],
            keywords=[keyword],
            extra_metadata=extra_metadata,
            created_by=user_id,
            current_version=1,
        )
        db.add(content_obj)
        try:
            # Flush before building the version row, which reads
            # content_obj.id (presumably assigned by the database).
            await db.flush()
            version = ContentVersion(
                content_id=content_obj.id,
                version_number=1,
                title=keyword,
                body=optimized,
                change_summary=change_summary,
                created_by=user_id,
            )
            db.add(version)
            await db.commit()
            await db.refresh(content_obj)
        except Exception:
            # Roll back so the session is usable again: generate_content's
            # fallback path reuses this same session after an agent-path
            # failure.
            await db.rollback()
            raise
        return str(content_obj.id)

    async def _execute_via_agent_framework(
        self,
        keyword: str,
        brand_name: str,
        platform: str,
        content_style: str,
        word_count: int,
        knowledge_context: str,
        knowledge_base_ids: list[str] | None,
        run_deai: bool,
        run_geo: bool,
        db: AsyncSession | None,
        user_id: str | None,
        org_id: str | None,
    ) -> dict:
        """Run the three-stage pipeline through the agent framework.

        Dispatches tasks to the ``content_generator``, ``deai_agent`` and
        ``geo_optimizer`` agents in turn, waiting for each result before
        starting the next stage. Failures propagate to the caller, which
        decides whether to fall back. The dispatcher is always closed.

        Returns:
            A dict with the same shape as ``generate_content`` returns.

        Raises:
            Exception: The agent framework is unavailable, a task fails,
                is cancelled or times out, or persistence fails.
        """
        from app.agent_framework.dispatcher import TaskDispatcher
        from app.config import settings

        dispatcher = TaskDispatcher(settings.REDIS_URL)
        stages = []
        try:
            # ---- Stage 1: content generation ----
            logger.info(f"通过 Agent 框架执行内容生成: keyword={keyword}")
            gen_result = await self._dispatch_and_wait(
                dispatcher,
                agent_name="content_generator",
                task_type="generate_article",
                input_data={
                    "target_keyword": keyword,
                    "brand_name": brand_name,
                    "target_platform": platform,
                    "knowledge_base_ids": knowledge_base_ids or [],
                    "word_count": word_count,
                    "content_style": content_style,
                    "knowledge_context": knowledge_context,
                },
                timeout=300,
                org_id=org_id,
                user_id=user_id,
            )
            content = gen_result.get("content", "")
            stages.append(
                {
                    "stage": "content_generation",
                    "status": "success",
                    "word_count": len(content),
                }
            )

            # ---- Stage 2: de-AI (optional) ----
            if run_deai:
                logger.info("通过 Agent 框架执行去AI化")
                deai_result = await self._dispatch_and_wait(
                    dispatcher,
                    agent_name="deai_agent",
                    task_type="deai_process",
                    input_data={
                        "content": content,
                        "platform": platform,
                        "style": "自然流畅",
                        "preserve_structure": True,
                    },
                    timeout=180,
                    org_id=org_id,
                    user_id=user_id,
                )
                # Keep the stage-1 text if the agent returned no content key.
                content = deai_result.get("content", content)
                stages.append({"stage": "deai", "status": "success"})

            # ---- Stage 3: GEO optimization (optional) ----
            optimized = content
            seo_score = None
            if run_geo:
                logger.info("通过 Agent 框架执行 GEO 优化")
                geo_result = await self._dispatch_and_wait(
                    dispatcher,
                    agent_name="geo_optimizer",
                    task_type="geo_optimize",
                    input_data={
                        "content": content,
                        "target_keywords": [keyword],
                        "target_platform": platform,
                        "optimization_level": "moderate",
                    },
                    timeout=180,
                    org_id=org_id,
                    user_id=user_id,
                )
                optimized = geo_result.get("optimized_content", content)
                seo_score = geo_result.get("seo_score")
                stages.append({"stage": "geo_optimization", "status": "success"})

            # ---- Persistence (optional) ----
            content_id = None
            if db and user_id and org_id:
                content_id = await self._persist_content(
                    db,
                    keyword=keyword,
                    platform=platform,
                    content=content,
                    optimized=optimized,
                    seo_score=seo_score,
                    stages=stages,
                    brand_name=brand_name,
                    content_style=content_style,
                    word_count=word_count,
                    user_id=user_id,
                    org_id=org_id,
                    change_summary="Agent框架Pipeline自动生成",
                    execution_mode="agent_framework",
                )

            logger.info("通过 Agent 框架执行内容生成完成")
            return {
                "content": content,
                "optimized_content": optimized,
                "seo_score": seo_score,
                "content_id": content_id,
                "pipeline_stages": stages,
            }
        finally:
            await dispatcher.close()

    async def generate_content(
        self,
        keyword: str,
        brand_name: str = "",
        platform: str = "通用",
        content_style: str = "专业严谨",
        word_count: int = 2000,
        knowledge_context: str = "",
        knowledge_base_ids: list[str] | None = None,
        db: AsyncSession | None = None,
        user_id: str | None = None,
        org_id: str | None = None,
        run_deai: bool = True,
        run_geo: bool = True,
        use_agent_framework: bool = False,
    ) -> dict:
        """Run the three-stage content generation pipeline.

        Stages:
            1. Content generation (``CONTENT_GENERATOR_TEMPLATE``)
            2. De-AI rewriting (``DEAI_TEMPLATE``, optional)
            3. GEO optimization (``GEO_OPTIMIZER_TEMPLATE``, optional)

        When ``db``, ``user_id`` and ``org_id`` are all provided, the result
        is persisted to the database.

        Args:
            keyword: Target keyword; also used as topic title and record title.
            brand_name: Brand name.
            platform: Target platform.
            content_style: Content style.
            word_count: Target length; stage 1 uses twice this value as
                ``max_tokens``.
            knowledge_context: Knowledge context passed in directly; takes
                precedence over retrieval.
            knowledge_base_ids: Knowledge-base IDs. On the direct path they
                are used for retrieval only when ``knowledge_context`` is
                empty and ``db`` is given.
            db: Database session (optional; required for persistence).
            user_id: User ID (optional; required for persistence).
            org_id: Organization ID (optional; required for persistence).
            run_deai: Whether to run the de-AI stage.
            run_geo: Whether to run the GEO optimization stage.
            use_agent_framework: When True, first try the agent framework.
                Any exception from that path, including a persistence
                failure, is logged and the direct-call path runs instead.

        Returns:
            dict: {
                "content": str,            # de-AI output, or raw generation
                "optimized_content": str,  # GEO output, or same as content
                "seo_score": ... | None,   # always None on the direct path
                "content_id": str | None,  # database record ID
                "pipeline_stages": list[dict],
            }

        Raises:
            LLMError: Presumably raised by the provider when an LLM call
                fails; nothing in this method catches it.
        """
        # ---- Agent framework path ----
        if use_agent_framework:
            try:
                logger.info("尝试通过 Agent 框架执行内容生成")
                return await self._execute_via_agent_framework(
                    keyword=keyword,
                    brand_name=brand_name,
                    platform=platform,
                    content_style=content_style,
                    word_count=word_count,
                    knowledge_context=knowledge_context,
                    knowledge_base_ids=knowledge_base_ids,
                    run_deai=run_deai,
                    run_geo=run_geo,
                    db=db,
                    user_id=user_id,
                    org_id=org_id,
                )
            except Exception as e:
                logger.warning(
                    f"Agent 框架执行失败,回退到直接调用模式: {e}"
                )
                # Fall through to the direct-call path below.

        # ---- Direct-call path ----
        provider = self._get_provider()
        stages = []

        # Retrieve context only when none was passed in and both the
        # knowledge-base IDs and a session are available.
        if not knowledge_context and knowledge_base_ids and db:
            knowledge_context = await self._get_knowledge_context(
                db, brand_name, knowledge_base_ids, keyword
            )

        # ---- Stage 1: content generation ----
        gen_variables = {
            "topic_title": keyword,
            "target_keyword": keyword,
            "target_platform": platform,
            "content_angle": "综合分析",
            "content_style": content_style,
            "word_count": str(word_count),
            "brand_name": brand_name,
            "knowledge_context": knowledge_context,
        }
        messages = CONTENT_GENERATOR_TEMPLATE.render(gen_variables)
        response = await provider.chat(
            messages, temperature=0.7, max_tokens=word_count * 2
        )
        content = response.content
        stages.append(
            {"stage": "content_generation", "status": "success", "word_count": len(content)}
        )

        # ---- Stage 2: de-AI (optional) ----
        if run_deai:
            deai_variables = {
                "original_content": content,
                "target_style": "自然流畅",
                "preserve_structure": "",
            }
            messages = DEAI_TEMPLATE.render(deai_variables)
            response = await provider.chat(
                messages, temperature=0.9, max_tokens=len(content) * 2
            )
            content = response.content
            stages.append({"stage": "deai", "status": "success"})

        # ---- Stage 3: GEO optimization (optional) ----
        optimized = content
        # No score is extracted from the LLM response on this path.
        seo_score = None
        if run_geo:
            geo_variables = {
                "original_content": content,
                "target_keywords": keyword,
                "target_platform": platform,
                "optimization_level": "moderate",
            }
            messages = GEO_OPTIMIZER_TEMPLATE.render(geo_variables)
            response = await provider.chat(
                messages, temperature=0.5, max_tokens=len(content) * 2
            )
            optimized = response.content
            stages.append({"stage": "geo_optimization", "status": "success"})

        # ---- Persistence (optional) ----
        content_id = None
        if db and user_id and org_id:
            content_id = await self._persist_content(
                db,
                keyword=keyword,
                platform=platform,
                content=content,
                optimized=optimized,
                seo_score=seo_score,
                stages=stages,
                brand_name=brand_name,
                content_style=content_style,
                word_count=word_count,
                user_id=user_id,
                org_id=org_id,
                change_summary="Pipeline自动生成",
            )

        return {
            "content": content,
            "optimized_content": optimized,
            "seo_score": seo_score,
            "content_id": content_id,
            "pipeline_stages": stages,
        }