geo/backend/app/api/content.py

240 lines
8.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""内容生产API - 串联Agent Pipeline"""
import json
import logging
import re
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.content import Content, ContentVersion
from app.models.user import User
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Router on which the endpoints below register themselves via @router.post.
router = APIRouter()
class ContentGenerateRequest(BaseModel):
    # Request payload accepted by both the /generate and /generate-topics
    # endpoints defined in this module.

    # --- what to write about ---
    target_keyword: str  # required; /generate also stores it as the content title
    target_platform: str = "通用"
    knowledge_base_ids: list[str] = []  # an empty list skips knowledge-base retrieval

    # --- how to write it ---
    content_style: str = "专业严谨"
    word_count: int = 2000  # /generate passes twice this value as max_tokens

    # --- brand information forwarded to the prompt templates ---
    brand_name: str = ""
    brand_description: str = ""  # within this module only /generate-topics reads it

    # --- optional /generate pipeline stages ---
    run_deai: bool = True  # run the de-AI rewriting stage
    run_geo: bool = True  # run the GEO optimization stage
class ContentGenerateResponse(BaseModel):
    # Response body declared as the response_model of the /generate endpoint.

    status: str

    # Text before the GEO optimization stage (after de-AI rewriting if enabled).
    content: str = ""
    # Text after GEO optimization; /generate sets it equal to `content` when
    # that stage is skipped.
    optimized_content: str = ""

    seo_score: Optional[int] = None
    # String form of the id of the persisted Content row.
    content_id: Optional[str] = None

    topics: list[dict] = []
    # Summary of each pipeline stage's execution result.
    pipeline_stages: list[dict] = []
async def _get_knowledge_context(
    db: AsyncSession,
    brand_name: str,
    knowledge_base_ids: list[str],
    target_keyword: str,
) -> str:
    """Retrieve knowledge-base context relevant to the query.

    When knowledge-base ids are given, ``RAGService.search`` is called and the
    hits are joined into one newline-separated string of
    ``[document_title] content`` entries. Without ids, with no hits, or on any
    retrieval error, an empty string is returned so the caller's pipeline can
    carry on without context.

    Args:
        db: Async session handed to ``RAGService.search`` as ``session``.
        brand_name: Optional brand name; prefixed to the query when non-empty.
        knowledge_base_ids: Knowledge bases to search; empty means "skip".
        target_keyword: Keyword used as (or appended to) the search query.

    Returns:
        The joined context string, or ``""``.
    """
    # Guard clause: nothing to search.
    if not knowledge_base_ids:
        return ""

    try:
        from app.services.knowledge.rag_service import RAGService

        service = RAGService()
        hits = await service.search(
            session=db,
            query=target_keyword if not brand_name else f"{brand_name} {target_keyword}",
            knowledge_base_ids=knowledge_base_ids,
            top_k=3,
        )
        if not hits:
            return ""

        # Collect (content, title) pairs first, then keep only hits with content.
        pairs = [
            (hit.get("content", ""), hit.get("document_title", ""))
            for hit in hits
        ]
        return "\n".join(f"[{title}] {body}" for body, title in pairs if body)
    except Exception as e:
        # Deliberately best-effort: a retrieval failure must not abort generation.
        logger.warning(f"知识库检索失败,将不使用知识库上下文: {e}")
        return ""
@router.post("/generate", response_model=ContentGenerateResponse)
async def generate_content(
    req: ContentGenerateRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Generate content in one request: run the pipeline inline and store the result.

    Pipeline: ContentGenerator -> DeAI (optional) -> GEOOptimizer (optional).
    The final text is saved as a draft ``Content`` row together with an
    initial ``ContentVersion`` (version 1).

    Args:
        req: Generation parameters (keyword, platform, style, stage toggles).
        db: Async database session injected by FastAPI.
        current_user: Authenticated user; must carry an ``organization_id``.

    Returns:
        ContentGenerateResponse holding the pre-GEO text, the final text, the
        new content id and a per-stage summary.

    Raises:
        HTTPException: 403 when the user has no organization, 502 when the LLM
            layer raises ``LLMError``, 500 for any other failure.
    """
    from app.services.llm import LLMError, LLMFactory
    from app.agent_framework.prompts import (
        CONTENT_GENERATOR_TEMPLATE,
        DEAI_TEMPLATE,
        GEO_OPTIMIZER_TEMPLATE,
    )

    org_id = getattr(current_user, "organization_id", None)
    if not org_id:
        raise HTTPException(status_code=403, detail="用户未关联组织")

    stages = []
    try:
        provider = LLMFactory.get_default()

        # Fetch knowledge-base context (empty string when unavailable).
        knowledge_context = await _get_knowledge_context(
            db, req.brand_name, req.knowledge_base_ids, req.target_keyword
        )

        # Stage 1: content generation.
        gen_variables = {
            "topic_title": req.target_keyword,
            "target_keyword": req.target_keyword,
            "target_platform": req.target_platform,
            "content_angle": "综合分析",
            "content_style": req.content_style,
            "word_count": str(req.word_count),
            "brand_name": req.brand_name,
            "knowledge_context": knowledge_context,
        }
        messages = CONTENT_GENERATOR_TEMPLATE.render(gen_variables)
        response = await provider.chat(
            messages, temperature=0.7, max_tokens=req.word_count * 2
        )
        content = response.content
        # NOTE: "word_count" here is the character length of the generated text.
        stages.append(
            {"stage": "content_generation", "status": "success", "word_count": len(content)}
        )

        # Stage 2: de-AI rewriting (optional).
        if req.run_deai:
            deai_variables = {
                "original_content": content,
                "target_style": "自然流畅",
                "preserve_structure": "",
            }
            messages = DEAI_TEMPLATE.render(deai_variables)
            response = await provider.chat(
                messages, temperature=0.9, max_tokens=len(content) * 2
            )
            content = response.content
            stages.append({"stage": "deai", "status": "success"})

        # Stage 3: GEO optimization (optional).
        optimized = content
        # No stage in this handler computes a score, so this stays None.
        seo_score = None
        if req.run_geo:
            geo_variables = {
                "original_content": content,
                "target_keywords": req.target_keyword,
                "target_platform": req.target_platform,
                "optimization_level": "moderate",
            }
            messages = GEO_OPTIMIZER_TEMPLATE.render(geo_variables)
            response = await provider.chat(
                messages, temperature=0.5, max_tokens=len(content) * 2
            )
            optimized = response.content
            stages.append({"stage": "geo_optimization", "status": "success"})

        # ---- Persist to the database ----
        content_obj = Content(
            organization_id=org_id,
            title=req.target_keyword,
            content_type="article",
            body=optimized,
            status="draft",
            target_platforms=[req.target_platform] if req.target_platform else [],
            keywords=[req.target_keyword],
            extra_metadata={
                # Keep the pre-GEO text only when optimization actually changed it.
                "original_content": content if content != optimized else None,
                "pipeline_stages": stages,
                "seo_score": seo_score,
                "brand_name": req.brand_name,
                "content_style": req.content_style,
                "word_count_target": req.word_count,
            },
            created_by=current_user.id,
            current_version=1,
        )
        db.add(content_obj)
        await db.flush()  # populate content_obj.id for the version row below

        # Create the initial version record.
        version = ContentVersion(
            content_id=content_obj.id,
            version_number=1,
            title=req.target_keyword,
            body=optimized,
            change_summary="Pipeline自动生成",
            created_by=current_user.id,
        )
        db.add(version)
        await db.commit()
        await db.refresh(content_obj)

        return ContentGenerateResponse(
            status="success",
            content=content,
            optimized_content=optimized,
            seo_score=seo_score,
            content_id=str(content_obj.id),
            pipeline_stages=stages,
        )
    except LLMError as e:
        # Record the upstream failure server-side; the client gets a 502.
        logger.error("LLM call failed during content generation: %s", e)
        raise HTTPException(status_code=502, detail=f"LLM调用失败: {str(e)}") from e
    except Exception as e:
        # NOTE(review): no explicit rollback here; presumably the get_db
        # dependency cleans up the session - confirm.
        # Log the full traceback; otherwise it is lost when converted to a 500.
        logger.exception("Unexpected error during content generation")
        raise HTTPException(status_code=500, detail=f"内容生成异常: {str(e)}") from e
@router.post("/generate-topics")
async def generate_topics(
    req: ContentGenerateRequest,
    current_user: User = Depends(get_current_user),
):
    """Generate a list of topic suggestions for the requested keyword.

    Renders ``TOPIC_SELECTOR_TEMPLATE``, sends it to the default LLM provider
    and tries to parse the reply as JSON. A fenced code block (optionally
    tagged ``json``) is unwrapped first. If parsing fails, a single fallback
    entry containing the first 100 characters of the raw reply is returned.

    Args:
        req: Request payload; keyword, brand fields and platform are used.
        current_user: Authenticated user (required by the dependency; not
            otherwise read here).

    Returns:
        ``{"status": "success", "topics": ...}`` where ``topics`` is whatever
        JSON value the model produced, or the fallback list.

    Raises:
        HTTPException: 502 when the LLM layer raises ``LLMError``.
    """
    from app.services.llm import LLMError, LLMFactory
    from app.agent_framework.prompts import TOPIC_SELECTOR_TEMPLATE

    try:
        provider = LLMFactory.get_default()
        variables = {
            "target_keyword": req.target_keyword,
            "brand_name": req.brand_name,
            "brand_description": req.brand_description,
            "target_platform": req.target_platform,
            "knowledge_context": "暂无",
            "published_topics": "暂无",
        }
        messages = TOPIC_SELECTOR_TEMPLATE.render(variables)
        response = await provider.chat(messages, temperature=0.8)

        # Try to parse JSON, unwrapping a Markdown code fence if present.
        match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', response.content, re.DOTALL)
        text = match.group(1).strip() if match else response.content
        try:
            topics = json.loads(text)
        except json.JSONDecodeError:
            # Surface the degraded result in the logs instead of failing silently.
            logger.warning("Topic response was not valid JSON; returning raw-text fallback")
            topics = [{"title": response.content[:100], "reason": "解析失败"}]
        return {"status": "success", "topics": topics}
    except LLMError as e:
        # Record the upstream failure server-side; the client gets a 502.
        logger.error("LLM call failed during topic generation: %s", e)
        raise HTTPException(status_code=502, detail=str(e)) from e