494 lines
18 KiB
Python
494 lines
18 KiB
Python
"""ContentGenerationService - 内容生成服务
|
||
|
||
从 api/content.py 中提取的业务逻辑层,负责:
|
||
1. 三阶段内容生成流程(generate -> de-AI -> GEO optimize)
|
||
2. 知识库上下文检索
|
||
3. 生成结果持久化
|
||
4. Agent 框架集成(可选)
|
||
|
||
API 层只需负责请求解析和响应格式化,所有业务逻辑委托给此服务。
|
||
"""
|
||
|
||
import asyncio
|
||
import logging
|
||
import uuid
|
||
from datetime import datetime, timezone
|
||
from typing import Optional
|
||
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.agent_framework.prompts import (
|
||
CONTENT_GENERATOR_TEMPLATE,
|
||
DEAI_TEMPLATE,
|
||
GEO_OPTIMIZER_TEMPLATE,
|
||
)
|
||
from app.models.content import Content, ContentVersion
|
||
from app.services.llm import LLMFactory, LLMError
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class ContentGenerationService:
    """Content generation service: three-stage pipeline plus persistence.

    The pipeline stages are content generation, optional de-AI rewriting and
    optional GEO optimization. They can run either through direct LLM calls
    or through the agent framework (task dispatch + polling).
    """

    def _get_provider(self):
        """Return the default LLM provider. Subclasses or tests may override."""
        return LLMFactory.get_default()

    @staticmethod
    def _format_knowledge_results(results) -> str:
        """Format RAG search hits as newline-separated "[title] content" lines.

        Hits whose "content" is missing or empty are skipped. Returns an empty
        string when ``results`` is empty or None.
        """
        lines = []
        for hit in results or []:
            body = hit.get("content", "")
            if body:
                lines.append(f"[{hit.get('document_title', '')}] {body}")
        return "\n".join(lines)

    async def _get_knowledge_context(
        self,
        db: AsyncSession,
        brand_name: str,
        knowledge_base_ids: list[str],
        target_keyword: str,
    ) -> str:
        """Retrieve knowledge-base context relevant to the query.

        When knowledge base IDs are given, RAGService.search is called with
        "<brand_name> <target_keyword>" (or just the keyword when the brand
        name is empty) and the top 3 hits are formatted into a context string.
        Without IDs an empty string is returned.

        Retrieval is best-effort: any exception (including a failed import of
        RAGService) is logged as a warning and results in an empty string, so
        the rest of the pipeline is not affected.
        """
        if not knowledge_base_ids:
            return ""

        try:
            from app.services.knowledge.rag_service import RAGService

            query = f"{brand_name} {target_keyword}" if brand_name else target_keyword
            results = await RAGService().search(
                session=db,
                query=query,
                knowledge_base_ids=knowledge_base_ids,
                top_k=3,
            )
            return self._format_knowledge_results(results)
        except Exception as e:
            logger.warning(f"知识库检索失败,将不使用知识库上下文: {e}")
            return ""

    async def _poll_task_result(
        self,
        dispatcher,
        task_id: str,
        timeout: int = 300,
    ) -> dict:
        """Poll the agent framework until the task reaches a final state.

        Args:
            dispatcher: TaskDispatcher instance.
            task_id: ID of the already dispatched task.
            timeout: Timeout in seconds.

        Returns:
            dict: The task's output_data ({} when it is missing or None).

        Raises:
            TimeoutError: The task did not finish within ``timeout``.
            Exception: The task failed or was cancelled.
        """
        from app.agent_framework.protocol import TaskStatus

        # NOTE: ``elapsed`` only counts the time spent sleeping; the time taken
        # by get_task_status() itself is not included in the budget.
        elapsed = 0.0
        poll_interval = 1.0
        while elapsed < timeout:
            await asyncio.sleep(poll_interval)
            elapsed += poll_interval

            task_status = await dispatcher.get_task_status(task_id)
            status = task_status.get("status")

            if status == TaskStatus.COMPLETED:
                # ``or {}`` also covers an explicit None so callers can
                # always call .get() on the result.
                return task_status.get("output_data") or {}

            if status == TaskStatus.FAILED:
                error_msg = task_status.get("error_message", "Unknown error")
                raise Exception(f"Agent 任务执行失败: {error_msg}")

            if status == TaskStatus.CANCELLED:
                raise Exception(f"Agent 任务被取消: {task_id}")

        raise TimeoutError(f"Agent 任务超时 ({timeout}s): {task_id}")

    async def _run_agent_task(
        self,
        dispatcher,
        *,
        agent_name: str,
        task_type: str,
        input_data: dict,
        timeout: int,
        org_id: str | None,
        user_id: str | None,
    ) -> dict:
        """Dispatch one task to an agent and wait for its output_data.

        A fresh UUID is used as the task ID, and ``timeout`` is used both as
        the task's timeout_seconds and as the polling timeout.
        """
        from app.agent_framework.protocol import TaskMessage

        task_message = TaskMessage(
            task_id=str(uuid.uuid4()),
            agent_name=agent_name,
            task_type=task_type,
            priority=0,
            input_data=input_data,
            callback_url=None,
            created_at=datetime.now(timezone.utc),
            timeout_seconds=timeout,
        )
        dispatched_id = await dispatcher.dispatch(
            task_message,
            organization_id=org_id,
            created_by=user_id,
        )
        return await self._poll_task_result(dispatcher, dispatched_id, timeout=timeout)

    async def _persist_content(
        self,
        db: AsyncSession,
        *,
        keyword: str,
        brand_name: str,
        platform: str,
        content_style: str,
        word_count: int,
        content: str,
        optimized: str,
        seo_score,
        stages: list,
        user_id: str,
        org_id: str,
        change_summary: str,
        extra_metadata: dict | None = None,
    ) -> str:
        """Store the result as a draft Content row plus its first ContentVersion.

        The pre-optimization text is kept in extra_metadata["original_content"]
        only when it differs from the optimized body. ``extra_metadata`` entries
        are appended after the standard metadata keys.

        If any database step fails, the session is rolled back before the
        exception is re-raised, so the caller (including the direct-call
        fallback in generate_content) does not continue on a failed transaction.

        Returns:
            str: ID of the created Content record.
        """
        metadata = {
            "original_content": content if content != optimized else None,
            "pipeline_stages": stages,
            "seo_score": seo_score,
            "brand_name": brand_name,
            "content_style": content_style,
            "word_count_target": word_count,
        }
        if extra_metadata:
            metadata.update(extra_metadata)

        try:
            content_obj = Content(
                organization_id=org_id,
                title=keyword,
                content_type="article",
                body=optimized,
                status="draft",
                target_platforms=[platform] if platform else [],
                keywords=[keyword],
                extra_metadata=metadata,
                created_by=user_id,
                current_version=1,
            )
            db.add(content_obj)
            # Flush first so content_obj.id is available for the version row.
            await db.flush()

            version = ContentVersion(
                content_id=content_obj.id,
                version_number=1,
                title=keyword,
                body=optimized,
                change_summary=change_summary,
                created_by=user_id,
            )
            db.add(version)
            await db.commit()
            await db.refresh(content_obj)
        except Exception:
            await db.rollback()
            raise

        return str(content_obj.id)

    async def _execute_via_agent_framework(
        self,
        keyword: str,
        brand_name: str,
        platform: str,
        content_style: str,
        word_count: int,
        knowledge_context: str,
        knowledge_base_ids: list[str] | None,
        run_deai: bool,
        run_geo: bool,
        db: AsyncSession | None,
        user_id: str | None,
        org_id: str | None,
    ) -> dict:
        """Run the three-stage content generation flow through the agent framework.

        Tasks are dispatched in order to content_generator, deai_agent and
        geo_optimizer, polling for the result of each stage. ``knowledge_context``
        is forwarded as given; no knowledge-base retrieval happens in this
        method. Failures raise and the caller decides whether to fall back.
        The dispatcher is always closed on exit.

        Returns:
            dict: Result dictionary in the same format as generate_content.

        Raises:
            Exception: The agent framework is unavailable or a task failed.
        """
        from app.agent_framework.dispatcher import TaskDispatcher
        from app.config import settings

        dispatcher = TaskDispatcher(settings.REDIS_URL)
        stages = []

        try:
            # ---- Stage 1: content generation ----
            logger.info(f"通过 Agent 框架执行内容生成: keyword={keyword}")
            gen_result = await self._run_agent_task(
                dispatcher,
                agent_name="content_generator",
                task_type="generate_article",
                input_data={
                    "target_keyword": keyword,
                    "brand_name": brand_name,
                    "target_platform": platform,
                    "knowledge_base_ids": knowledge_base_ids or [],
                    "word_count": word_count,
                    "content_style": content_style,
                    "knowledge_context": knowledge_context,
                },
                timeout=300,
                org_id=org_id,
                user_id=user_id,
            )
            content = gen_result.get("content", "")
            stages.append(
                {
                    "stage": "content_generation",
                    "status": "success",
                    "word_count": len(content),
                }
            )

            # ---- Stage 2: de-AI rewriting (optional) ----
            if run_deai:
                logger.info("通过 Agent 框架执行去AI化")
                deai_result = await self._run_agent_task(
                    dispatcher,
                    agent_name="deai_agent",
                    task_type="deai_process",
                    input_data={
                        "content": content,
                        "platform": platform,
                        "style": "自然流畅",
                        "preserve_structure": True,
                    },
                    timeout=180,
                    org_id=org_id,
                    user_id=user_id,
                )
                # Keep the stage-1 text when the agent returns no "content".
                content = deai_result.get("content", content)
                stages.append({"stage": "deai", "status": "success"})

            # ---- Stage 3: GEO optimization (optional) ----
            optimized = content
            seo_score = None
            if run_geo:
                logger.info("通过 Agent 框架执行 GEO 优化")
                geo_result = await self._run_agent_task(
                    dispatcher,
                    agent_name="geo_optimizer",
                    task_type="geo_optimize",
                    input_data={
                        "content": content,
                        "target_keywords": [keyword],
                        "target_platform": platform,
                        "optimization_level": "moderate",
                    },
                    timeout=180,
                    org_id=org_id,
                    user_id=user_id,
                )
                optimized = geo_result.get("optimized_content", content)
                seo_score = geo_result.get("seo_score")
                stages.append({"stage": "geo_optimization", "status": "success"})

            # ---- Persistence (optional) ----
            content_id = None
            if db and user_id and org_id:
                content_id = await self._persist_content(
                    db,
                    keyword=keyword,
                    brand_name=brand_name,
                    platform=platform,
                    content_style=content_style,
                    word_count=word_count,
                    content=content,
                    optimized=optimized,
                    seo_score=seo_score,
                    stages=stages,
                    user_id=user_id,
                    org_id=org_id,
                    change_summary="Agent框架Pipeline自动生成",
                    extra_metadata={"execution_mode": "agent_framework"},
                )

            logger.info("通过 Agent 框架执行内容生成完成")
            return {
                "content": content,
                "optimized_content": optimized,
                "seo_score": seo_score,
                "content_id": content_id,
                "pipeline_stages": stages,
            }

        finally:
            await dispatcher.close()

    async def generate_content(
        self,
        keyword: str,
        brand_name: str = "",
        platform: str = "通用",
        content_style: str = "专业严谨",
        word_count: int = 2000,
        knowledge_context: str = "",
        knowledge_base_ids: list[str] | None = None,
        db: AsyncSession | None = None,
        user_id: str | None = None,
        org_id: str | None = None,
        run_deai: bool = True,
        run_geo: bool = True,
        use_agent_framework: bool = False,
    ) -> dict:
        """Run the three-stage content generation flow.

        Stages:
            1. Content generation (CONTENT_GENERATOR_TEMPLATE)
            2. De-AI rewriting (DEAI_TEMPLATE, optional)
            3. GEO optimization (GEO_OPTIMIZER_TEMPLATE, optional)

        When db, user_id and org_id are all provided, the result is persisted
        to the database.

        Args:
            keyword: Target keyword (also used as the title).
            brand_name: Brand name.
            platform: Target platform, default "通用".
            content_style: Content style, default "专业严谨".
            word_count: Target word count, default 2000.
            knowledge_context: Knowledge context passed in directly (takes
                precedence over retrieval).
            knowledge_base_ids: Knowledge base IDs used for RAG retrieval when
                no knowledge_context is given and db is available.
            db: Database session (optional; enables retrieval and persistence).
            user_id: User ID (optional; required for persistence).
            org_id: Organization ID (optional; required for persistence).
            run_deai: Whether to run the de-AI stage, default True.
            run_geo: Whether to run the GEO optimization stage, default True.
            use_agent_framework: Whether to execute via the agent framework,
                default False. When True, tasks are dispatched to agents
                through TaskDispatcher; if that path raises, execution falls
                back to the direct-call mode.

        Returns:
            dict: {
                "content": str,            # de-AI output (or the raw generation)
                "optimized_content": str,  # GEO output (or same as content)
                "seo_score": int | None,   # always None in direct-call mode
                "content_id": str | None,  # database record ID
                "pipeline_stages": list[dict],
            }

        Raises:
            LLMError: When an LLM call fails.
        """
        # ---- Agent framework path ----
        if use_agent_framework:
            try:
                logger.info("尝试通过 Agent 框架执行内容生成")
                return await self._execute_via_agent_framework(
                    keyword=keyword,
                    brand_name=brand_name,
                    platform=platform,
                    content_style=content_style,
                    word_count=word_count,
                    knowledge_context=knowledge_context,
                    knowledge_base_ids=knowledge_base_ids,
                    run_deai=run_deai,
                    run_geo=run_geo,
                    db=db,
                    user_id=user_id,
                    org_id=org_id,
                )
            except Exception as e:
                # Deliberate best-effort: log with traceback, then continue
                # with the direct-call logic below.
                logger.warning(
                    f"Agent 框架执行失败,回退到直接调用模式: {e}",
                    exc_info=True,
                )

        # ---- Direct-call path ----
        provider = self._get_provider()
        stages = []

        # Retrieve context only when none was passed in and both knowledge
        # base IDs and a session are available.
        if not knowledge_context and knowledge_base_ids and db:
            knowledge_context = await self._get_knowledge_context(
                db, brand_name, knowledge_base_ids, keyword
            )

        # ---- Stage 1: content generation ----
        gen_variables = {
            "topic_title": keyword,
            "target_keyword": keyword,
            "target_platform": platform,
            "content_angle": "综合分析",
            "content_style": content_style,
            "word_count": str(word_count),
            "brand_name": brand_name,
            "knowledge_context": knowledge_context,
        }
        messages = CONTENT_GENERATOR_TEMPLATE.render(gen_variables)
        response = await provider.chat(
            messages, temperature=0.7, max_tokens=word_count * 2
        )
        content = response.content
        stages.append(
            {
                "stage": "content_generation",
                "status": "success",
                "word_count": len(content),
            }
        )

        # ---- Stage 2: de-AI rewriting (optional) ----
        if run_deai:
            deai_variables = {
                "original_content": content,
                "target_style": "自然流畅",
                "preserve_structure": "是",
            }
            messages = DEAI_TEMPLATE.render(deai_variables)
            response = await provider.chat(
                messages, temperature=0.9, max_tokens=len(content) * 2
            )
            content = response.content
            stages.append({"stage": "deai", "status": "success"})

        # ---- Stage 3: GEO optimization (optional) ----
        optimized = content
        seo_score = None
        if run_geo:
            geo_variables = {
                "original_content": content,
                "target_keywords": keyword,
                "target_platform": platform,
                "optimization_level": "moderate",
            }
            messages = GEO_OPTIMIZER_TEMPLATE.render(geo_variables)
            response = await provider.chat(
                messages, temperature=0.5, max_tokens=len(content) * 2
            )
            optimized = response.content
            stages.append({"stage": "geo_optimization", "status": "success"})

        # ---- Persistence (optional) ----
        content_id = None
        if db and user_id and org_id:
            content_id = await self._persist_content(
                db,
                keyword=keyword,
                brand_name=brand_name,
                platform=platform,
                content_style=content_style,
                word_count=word_count,
                content=content,
                optimized=optimized,
                seo_score=seo_score,
                stages=stages,
                user_id=user_id,
                org_id=org_id,
                change_summary="Pipeline自动生成",
            )

        return {
            "content": content,
            "optimized_content": optimized,
            "seo_score": seo_score,
            "content_id": content_id,
            "pipeline_stages": stages,
        }
|