240 lines
8.1 KiB
Python
240 lines
8.1 KiB
Python
"""内容生产API - 串联Agent Pipeline"""
|
||
import json
import logging
import re
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.deps import get_current_user
from app.database import get_db
from app.models.content import Content, ContentVersion
from app.models.user import User
|
||
|
||
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Router that the endpoints in this module register on; presumably included by
# the application with some URL prefix — TODO confirm where it is mounted.
router = APIRouter()
|
||
|
||
|
||
class ContentGenerateRequest(BaseModel):
    """Request body shared by the ``/generate`` and ``/generate-topics`` endpoints."""

    # Main keyword; ``/generate`` also uses it as the article title.
    target_keyword: str
    # Target publishing platform; "通用" (general) by default.
    target_platform: str = "通用"
    # Knowledge bases to retrieve context from; empty means no retrieval.
    # default_factory makes the per-instance fresh list explicit.
    knowledge_base_ids: list[str] = Field(default_factory=list)
    # Style hint passed to the generator prompt.
    content_style: str = "专业严谨"
    # Target length. ``/generate`` uses ``word_count * 2`` as the generator's
    # max_tokens, so it must be positive: 0 or a negative value would request a
    # zero/negative token budget. Rejecting it here yields a clear 422 instead.
    word_count: int = Field(default=2000, gt=0)
    # Optional brand name, folded into the retrieval query and prompts.
    brand_name: str = ""
    # Optional brand description; only the topic-selection prompt uses it.
    brand_description: str = ""
    # Whether ``/generate`` runs the DeAI rewrite stage.
    run_deai: bool = True
    # Whether ``/generate`` runs the GEO optimization stage.
    run_geo: bool = True
|
||
|
||
|
||
class ContentGenerateResponse(BaseModel):
    """Response body of the ``/generate`` endpoint."""

    # Overall result; ``/generate`` sets "success" on its success path.
    status: str
    # Text before the GEO stage (the DeAI output when that stage ran).
    content: str = ""
    # Final text: the GEO output when that stage ran, otherwise the same as ``content``.
    optimized_content: str = ""
    # NOTE(review): no stage in ``/generate`` computes this yet; it is always None there.
    seo_score: Optional[int] = None
    # ID of the persisted ``Content`` row, stringified.
    content_id: Optional[str] = None
    # NOTE(review): not populated by ``/generate``; ``/generate-topics`` returns a plain dict instead.
    topics: list[dict] = []
    pipeline_stages: list[dict] = []  # Per-stage summary of the pipeline execution result
|
||
|
||
|
||
async def _get_knowledge_context(
|
||
db: AsyncSession,
|
||
brand_name: str,
|
||
knowledge_base_ids: list[str],
|
||
target_keyword: str,
|
||
) -> str:
|
||
"""
|
||
从知识库检索与查询相关的上下文。
|
||
|
||
如果有知识库ID,则调用 RAGService.search 获取相关内容;
|
||
否则返回空字符串,不影响后续流程。
|
||
"""
|
||
if not knowledge_base_ids:
|
||
return ""
|
||
|
||
try:
|
||
from app.services.knowledge.rag_service import RAGService
|
||
rag_service = RAGService()
|
||
results = await rag_service.search(
|
||
session=db,
|
||
query=f"{brand_name} {target_keyword}" if brand_name else target_keyword,
|
||
knowledge_base_ids=knowledge_base_ids,
|
||
top_k=3,
|
||
)
|
||
if results:
|
||
context_parts = []
|
||
for r in results:
|
||
content = r.get("content", "")
|
||
title = r.get("document_title", "")
|
||
if content:
|
||
context_parts.append(f"[{title}] {content}")
|
||
return "\n".join(context_parts)
|
||
return ""
|
||
except Exception as e:
|
||
logger.warning(f"知识库检索失败,将不使用知识库上下文: {e}")
|
||
return ""
|
||
|
||
|
||
async def _run_llm_stage(
    provider,
    messages,
    *,
    stage: str,
    temperature: float,
    max_tokens: int,
) -> str:
    """Send one pipeline stage's prompt to the LLM and return the reply text.

    Args:
        provider: Provider obtained from ``LLMFactory.get_default()``.
        messages: Rendered prompt messages for this stage.
        stage: Stage name, included in the error detail.
        temperature: Sampling temperature for this stage.
        max_tokens: Output token budget for this stage.

    Returns:
        The reply text, unmodified.

    Raises:
        HTTPException: 502 when the reply is missing or whitespace-only. Later
            stages derive ``max_tokens`` from the previous output's length, and
            the final text is saved to the database, so empty output must not
            flow onward. Errors raised by ``provider.chat`` propagate unchanged.
    """
    response = await provider.chat(messages, temperature=temperature, max_tokens=max_tokens)
    text = response.content or ""
    if not text.strip():
        raise HTTPException(status_code=502, detail=f"LLM调用失败: {stage} 阶段返回内容为空")
    return text


@router.post("/generate", response_model=ContentGenerateResponse)
async def generate_content(
    req: ContentGenerateRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Generate an article in one call (synchronous pipeline) and store it.

    Pipeline: ContentGenerator -> DeAI (if ``req.run_deai``) -> GEOOptimizer
    (if ``req.run_geo``). The final text is saved as a draft ``Content`` row
    plus an initial ``ContentVersion`` (version 1), then returned.

    Raises:
        HTTPException: 403 if the user has no organization; 502 if an LLM call
            fails or a stage returns empty text; 500 for any other error.
    """
    from app.services.llm import LLMError, LLMFactory
    from app.agent_framework.prompts import (
        CONTENT_GENERATOR_TEMPLATE,
        DEAI_TEMPLATE,
        GEO_OPTIMIZER_TEMPLATE,
    )

    org_id = getattr(current_user, "organization_id", None)
    if not org_id:
        raise HTTPException(status_code=403, detail="用户未关联组织")

    stages: list[dict] = []

    try:
        provider = LLMFactory.get_default()

        # Knowledge-base context (best effort; "" when unavailable).
        knowledge_context = await _get_knowledge_context(
            db, req.brand_name, req.knowledge_base_ids, req.target_keyword
        )

        # Stage 1: content generation.
        gen_variables = {
            "topic_title": req.target_keyword,
            "target_keyword": req.target_keyword,
            "target_platform": req.target_platform,
            "content_angle": "综合分析",
            "content_style": req.content_style,
            "word_count": str(req.word_count),
            "brand_name": req.brand_name,
            "knowledge_context": knowledge_context,
        }
        content = await _run_llm_stage(
            provider,
            CONTENT_GENERATOR_TEMPLATE.render(gen_variables),
            stage="content_generation",
            temperature=0.7,
            max_tokens=req.word_count * 2,
        )
        stages.append({"stage": "content_generation", "status": "success", "word_count": len(content)})

        # Stage 2: DeAI rewrite (optional).
        if req.run_deai:
            deai_variables = {
                "original_content": content,
                "target_style": "自然流畅",
                "preserve_structure": "是",
            }
            content = await _run_llm_stage(
                provider,
                DEAI_TEMPLATE.render(deai_variables),
                stage="deai",
                temperature=0.9,
                max_tokens=len(content) * 2,
            )
            stages.append({"stage": "deai", "status": "success"})

        # Stage 3: GEO optimization (optional).
        optimized = content
        # NOTE(review): nothing computes an SEO score yet; it stays None.
        seo_score = None
        if req.run_geo:
            geo_variables = {
                "original_content": content,
                "target_keywords": req.target_keyword,
                "target_platform": req.target_platform,
                "optimization_level": "moderate",
            }
            optimized = await _run_llm_stage(
                provider,
                GEO_OPTIMIZER_TEMPLATE.render(geo_variables),
                stage="geo_optimization",
                temperature=0.5,
                max_tokens=len(content) * 2,
            )
            stages.append({"stage": "geo_optimization", "status": "success"})

        # ---- Persist ----
        content_obj = Content(
            organization_id=org_id,
            title=req.target_keyword,
            content_type="article",
            body=optimized,
            status="draft",
            target_platforms=[req.target_platform] if req.target_platform else [],
            keywords=[req.target_keyword],
            extra_metadata={
                # Pre-GEO text, kept only when GEO actually changed it.
                "original_content": content if content != optimized else None,
                "pipeline_stages": stages,
                "seo_score": seo_score,
                "brand_name": req.brand_name,
                "content_style": req.content_style,
                "word_count_target": req.word_count,
            },
            created_by=current_user.id,
            current_version=1,
        )
        db.add(content_obj)
        await db.flush()  # Assigns content_obj.id, needed by the version row.

        # Initial version record.
        version = ContentVersion(
            content_id=content_obj.id,
            version_number=1,
            title=req.target_keyword,
            body=optimized,
            change_summary="Pipeline自动生成",
            created_by=current_user.id,
        )
        db.add(version)
        await db.commit()
        await db.refresh(content_obj)

        return ContentGenerateResponse(
            status="success",
            content=content,
            optimized_content=optimized,
            seo_score=seo_score,
            content_id=str(content_obj.id),
            pipeline_stages=stages,
        )

    except HTTPException:
        # Already carries the intended status/detail (e.g. empty LLM output);
        # must not be re-wrapped as a 500 by the generic handler below.
        raise
    except LLMError as e:
        logger.warning("LLM call failed during content generation: %s", e)
        raise HTTPException(status_code=502, detail=f"LLM调用失败: {str(e)}") from e
    except Exception as e:
        # Keep the traceback in the logs; previously it was discarded.
        logger.exception("Unexpected error during content generation")
        raise HTTPException(status_code=500, detail=f"内容生成异常: {str(e)}") from e
|
||
|
||
|
||
def _parse_topics(raw: str):
    """Extract the JSON topic payload from an LLM reply.

    Strategy, in order:
      1. If the reply contains a ```/```json fenced block, use its contents;
         otherwise use the whole reply. Try ``json.loads`` on that text.
      2. If that fails, decode the first JSON array/object that starts at the
         first ``[`` or ``{``. Models often wrap JSON in prose without a fence.
      3. Otherwise return a single placeholder topic holding the first 100
         characters of the raw reply, marked "解析失败".

    Returns:
        Whatever JSON value was decoded (normally a list of topic dicts), or
        the placeholder list.
    """
    match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', raw, re.DOTALL)
    text = match.group(1).strip() if match else raw
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    start = re.search(r"[\[{]", text)
    if start is not None:
        try:
            # raw_decode ignores any trailing prose after the JSON value.
            value, _end = json.JSONDecoder().raw_decode(text, start.start())
            return value
        except json.JSONDecodeError:
            pass

    return [{"title": raw[:100], "reason": "解析失败"}]


@router.post("/generate-topics")
async def generate_topics(
    req: ContentGenerateRequest,
    current_user: User = Depends(get_current_user),
):
    """Generate a list of candidate topics for ``req.target_keyword``.

    Returns:
        ``{"status": "success", "topics": ...}``. ``topics`` is the JSON parsed
        from the LLM reply, or a single placeholder topic if parsing fails.

    Raises:
        HTTPException: 502 when the LLM call fails.
    """
    from app.services.llm import LLMError, LLMFactory
    from app.agent_framework.prompts import TOPIC_SELECTOR_TEMPLATE

    try:
        provider = LLMFactory.get_default()
        # NOTE(review): req.knowledge_base_ids is not used here; the prompt always
        # receives "暂无" as knowledge context. Confirm whether that is intended.
        variables = {
            "target_keyword": req.target_keyword,
            "brand_name": req.brand_name,
            "brand_description": req.brand_description,
            "target_platform": req.target_platform,
            "knowledge_context": "暂无",
            "published_topics": "暂无",
        }
        messages = TOPIC_SELECTOR_TEMPLATE.render(variables)
        response = await provider.chat(messages, temperature=0.8)

        # A None reply would otherwise crash json.loads/slicing with a 500.
        topics = _parse_topics(response.content or "")

        return {"status": "success", "topics": topics}
    except LLMError as e:
        raise HTTPException(status_code=502, detail=str(e)) from e