391 lines
13 KiB
Python
391 lines
13 KiB
Python
"""内容生产API - 串联Agent Pipeline
|
||
|
||
业务逻辑已委托给 ContentGenerationService,API 层仅负责:
|
||
1. 请求解析与参数校验
|
||
2. 调用服务层
|
||
3. 格式化响应
|
||
"""
|
||
import json
|
||
import logging
|
||
import re
|
||
import uuid
|
||
from typing import Optional
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException
|
||
from pydantic import BaseModel
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.api.deps import get_current_user
|
||
from app.database import get_db
|
||
from app.models.brand import Brand
|
||
from app.models.content import Content, ContentVersion
|
||
from app.models.diagnosis_record import DiagnosisRecord
|
||
from app.models.user import User
|
||
from app.services.content.content_generation_service import ContentGenerationService
|
||
from app.services.llm import LLMError
|
||
|
||
# Router that collects the content-production endpoints of this module.
router = APIRouter()

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class ContentGenerateRequest(BaseModel):
    """Request body for ``POST /generate``; also reused by ``POST /generate-topics``."""

    target_keyword: str  # keyword the generated article / topic list is built around
    target_platform: str = "通用"  # target publishing platform; default means "generic"
    knowledge_base_ids: list[str] = []  # knowledge bases forwarded to the generation service
    content_style: str = "专业严谨"  # writing style passed to the generator
    word_count: int = 2000  # requested length; no bounds are validated here
    brand_name: str = ""
    brand_description: str = ""  # only read by generate_topics; generate_content does not forward it
    run_deai: bool = True  # whether the service runs the DeAI rewrite stage
    run_geo: bool = True  # whether the service runs the GEO optimization stage
    use_agent_framework: bool = False  # forwarded to ContentGenerationService.generate_content
|
||
|
||
|
||
class ContentGenerateResponse(BaseModel):
    """Response body for ``POST /generate``."""

    status: str  # "success" when the pipeline completed
    content: str = ""  # text returned by the service as "content"
    optimized_content: str = ""  # text returned by the service as "optimized_content"
    seo_score: Optional[int] = None
    content_id: Optional[str] = None  # id of the stored content record, as returned by the service
    topics: list[dict] = []  # not populated by generate_content in this module
    pipeline_stages: list[dict] = []  # per-stage execution summary
|
||
|
||
|
||
@router.post("/generate", response_model=ContentGenerateResponse)
async def generate_content(
    req: ContentGenerateRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Generate one piece of content synchronously and store it in the database.

    Pipeline: ContentGenerator -> DeAI -> GEOOptimizer. The business logic
    lives in ``ContentGenerationService``. This handler only does three things:
    it checks that the caller belongs to an organization, forwards the request
    fields, and maps the service result onto ``ContentGenerateResponse``.

    Raises:
        HTTPException: 403 if the user has no organization; 502 if the LLM
            call fails; 500 for any other unexpected error. An HTTPException
            raised inside the ``try`` block is re-raised unchanged instead of
            being turned into a 500.
    """
    org_id = getattr(current_user, "organization_id", None)
    if not org_id:
        raise HTTPException(status_code=403, detail="用户未关联组织")

    try:
        service = ContentGenerationService()
        result = await service.generate_content(
            keyword=req.target_keyword,
            brand_name=req.brand_name,
            platform=req.target_platform,
            content_style=req.content_style,
            word_count=req.word_count,
            knowledge_base_ids=req.knowledge_base_ids,
            db=db,
            user_id=current_user.id,
            org_id=org_id,
            run_deai=req.run_deai,
            run_geo=req.run_geo,
            use_agent_framework=req.use_agent_framework,
        )

        return ContentGenerateResponse(
            status="success",
            content=result["content"],
            optimized_content=result["optimized_content"],
            seo_score=result["seo_score"],
            content_id=result["content_id"],
            pipeline_stages=result["pipeline_stages"],
        )

    except HTTPException:
        # Already carries the intended status code and detail; do not rewrap it as a 500.
        raise
    except LLMError as e:
        logger.warning("LLM call failed in generate_content: %s", e)
        raise HTTPException(status_code=502, detail=f"LLM调用失败: {str(e)}") from e
    except Exception as e:
        # Keep the traceback in the server log; the client only receives the message.
        logger.exception("Unexpected error in generate_content")
        raise HTTPException(status_code=500, detail=f"内容生成异常: {str(e)}") from e
|
||
|
||
|
||
class GEOContentGenerateRequest(BaseModel):
    """Request body for ``POST /generate-geo`` (brand-driven GEO content generation)."""

    brand_id: str  # UUID string of the brand; parsed with uuid.UUID by the endpoint (400 if malformed)
    target_keywords: list[str]  # joined with "、" into a single keyword string by the endpoint
    platform: str = "通用"  # target publishing platform; default means "generic"
    content_style: str = "专业严谨"
    word_count: int = 2000
    knowledge_base_ids: list[str] = []  # forwarded to the generation service
    run_deai: bool = True  # whether the service runs the DeAI rewrite stage
    run_geo: bool = True  # whether the service runs the GEO optimization stage
|
||
|
||
|
||
class GEOContentGenerateResponse(BaseModel):
    """Response body for ``POST /generate-geo``; fields are copied from the service result."""

    content_id: Optional[str] = None
    content: str = ""
    optimized_content: str = ""
    seo_score: Optional[int] = None
    pipeline_stages: list[dict] = []  # per-stage execution summary
|
||
|
||
|
||
def _diagnosis_focus_hint(result_json, threshold: float = 60) -> str:
    """Build a prompt hint that names the weak dimensions of a diagnosis result.

    A dimension counts as weak when ``result_json["dimensions"][name]["score"]``
    is a number below ``threshold``. A dimension without a score is treated as
    healthy, matching the previous default of 100. Malformed payloads are
    ignored rather than raising: a non-dict result, a non-dict ``dimensions``,
    or a non-numeric score.

    Returns:
        The hint sentence listing the weak dimensions in their original order,
        or ``""`` when there are none.
    """
    if not isinstance(result_json, dict):
        return ""
    dimensions = result_json.get("dimensions") or {}
    if not isinstance(dimensions, dict):
        return ""

    weak_dimensions = []
    for dim_name, dim_data in dimensions.items():
        if not isinstance(dim_data, dict):
            continue
        score = dim_data.get("score", 100)
        # bool is an int subclass; exclude it so True/False are not read as scores.
        if isinstance(score, (int, float)) and not isinstance(score, bool) and score < threshold:
            weak_dimensions.append(str(dim_name))

    if not weak_dimensions:
        return ""
    return f"基于诊断结果,以下维度需要重点优化:{', '.join(weak_dimensions)}。请围绕这些维度生成针对性内容。"


@router.post("/generate-geo", response_model=GEOContentGenerateResponse, status_code=201)
async def generate_geo_content(
    req: GEOContentGenerateRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Generate GEO content for a stored brand, steered by its latest diagnosis.

    The brand is looked up by ``req.brand_id``. If the brand has a completed
    diagnosis, the weak dimensions of the newest one are appended to the
    keyword string as an instruction for the generator. Generation itself is
    delegated to ``ContentGenerationService``.

    Raises:
        HTTPException: 403 if the user has no organization; 400 for a
            malformed ``brand_id``; 404 if the brand does not exist; 502 if
            the LLM call fails; 500 for any other unexpected error.
    """
    from sqlalchemy import select

    org_id = getattr(current_user, "organization_id", None)
    if not org_id:
        raise HTTPException(status_code=403, detail="用户未关联组织")

    try:
        brand_uuid = uuid.UUID(req.brand_id)
    except ValueError:
        raise HTTPException(status_code=400, detail=f"Invalid brand_id format: {req.brand_id}") from None

    # NOTE(review): the brand is matched by id only, not by org_id, so a user
    # appears able to target another organization's brand (and its diagnosis).
    # Confirm whether Brand has an organization column and, if so, filter on it.
    brand_result = await db.execute(select(Brand).where(Brand.id == brand_uuid))
    brand = brand_result.scalar_one_or_none()
    if not brand:
        raise HTTPException(status_code=404, detail=f"Brand not found: {req.brand_id}")

    # Only the newest completed diagnosis is used. LIMIT 1 also stops
    # scalar_one_or_none() from raising MultipleResultsFound when the brand
    # has been diagnosed more than once.
    diag_stmt = (
        select(DiagnosisRecord)
        .where(DiagnosisRecord.brand_id == brand_uuid, DiagnosisRecord.status == "completed")
        .order_by(DiagnosisRecord.created_at.desc())
        .limit(1)
    )
    diag_result = await db.execute(diag_stmt)
    diagnosis = diag_result.scalar_one_or_none()

    diagnosis_context = ""
    if diagnosis and diagnosis.result_json:
        diagnosis_context = _diagnosis_focus_hint(diagnosis.result_json)

    # The diagnosis hint travels inside the keyword string because the service
    # has no separate parameter for it.
    keyword = "、".join(req.target_keywords)
    if diagnosis_context:
        keyword = f"{keyword}({diagnosis_context})"

    try:
        service = ContentGenerationService()
        result = await service.generate_content(
            keyword=keyword,
            brand_name=brand.name,
            platform=req.platform,
            content_style=req.content_style,
            word_count=req.word_count,
            knowledge_base_ids=req.knowledge_base_ids,
            db=db,
            user_id=current_user.id,
            org_id=org_id,
            run_deai=req.run_deai,
            run_geo=req.run_geo,
        )

        return GEOContentGenerateResponse(
            content_id=result["content_id"],
            content=result["content"],
            optimized_content=result["optimized_content"],
            seo_score=result["seo_score"],
            pipeline_stages=result["pipeline_stages"],
        )

    except HTTPException:
        # Already carries the intended status code and detail; do not rewrap it as a 500.
        raise
    except LLMError as e:
        logger.warning("LLM call failed in generate_geo_content: %s", e)
        raise HTTPException(status_code=502, detail=f"LLM调用失败: {str(e)}") from e
    except Exception as e:
        logger.exception("Unexpected error in generate_geo_content")
        raise HTTPException(status_code=500, detail=f"内容生成异常: {str(e)}") from e
|
||
|
||
|
||
@router.post("/generate-topics")
async def generate_topics(
    req: ContentGenerateRequest,
    current_user: User = Depends(get_current_user),
):
    """Ask the default LLM for a list of candidate topics.

    ``TOPIC_SELECTOR_TEMPLATE`` is rendered with the request's keyword, brand
    name and description, and platform. The reply is expected to be JSON,
    optionally wrapped in a fenced code block. The reply is handled as follows:

    - A JSON list is returned as-is.
    - A single JSON object is wrapped in a one-element list.
    - Anything else yields one placeholder topic holding the first 100
      characters of the raw reply.

    Raises:
        HTTPException: 502 if the LLM call fails.
    """
    from app.agent_framework.prompts import TOPIC_SELECTOR_TEMPLATE
    from app.services.llm import LLMFactory

    try:
        provider = LLMFactory.get_default()
        variables = {
            "target_keyword": req.target_keyword,
            "brand_name": req.brand_name,
            "brand_description": req.brand_description,
            "target_platform": req.target_platform,
            "knowledge_context": "暂无",
            "published_topics": "暂无",
        }
        messages = TOPIC_SELECTOR_TEMPLATE.render(variables)
        response = await provider.chat(messages, temperature=0.8)
    except LLMError as e:
        raise HTTPException(status_code=502, detail=str(e)) from e

    # Guard against a None reply so re.search / slicing cannot raise TypeError.
    raw = response.content or ""

    # Prefer the body of a ```json fenced block when the model used one.
    fenced = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', raw, re.DOTALL)
    text = fenced.group(1).strip() if fenced else raw
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        parsed = None

    if isinstance(parsed, dict):
        parsed = [parsed]
    if isinstance(parsed, list):
        topics = parsed
    else:
        logger.warning("Could not parse a topic list from the LLM reply: %.200s", raw)
        topics = [{"title": raw[:100], "reason": "解析失败"}]

    return {"status": "success", "topics": topics}
|
||
|
||
|
||
# ==================== 母题库接口 ====================
|
||
|
||
class TopicGenerateRequest(BaseModel):
    """Request body for generating content from a topic template."""

    params: dict  # topic-template parameters; must contain every name in the template's required_params
    platform: str = "通用"  # target publishing platform; default means "generic"
    style: str = "专业严谨"  # recorded as content_style in the stored content's metadata
|
||
|
||
|
||
@router.get("/topics")
async def list_topics():
    """List every topic template as a lightweight summary.

    Returns one dict per template, in the order ``list_topic_templates``
    yields them. Each dict holds the template's id, name, description, icon,
    recommended platforms, word-count range (converted to a list), and the
    names of its required and optional parameters. The prompt text and SEO
    tips are not included.
    """
    from app.services.content.topic_templates import list_topic_templates

    summaries = []
    for tpl in list_topic_templates():
        summaries.append(
            {
                "id": tpl.id,
                "name": tpl.name,
                "description": tpl.description,
                "icon": tpl.icon,
                "recommended_platforms": tpl.recommended_platforms,
                "word_count_range": list(tpl.word_count_range),
                "required_params": tpl.required_params,
                "optional_params": tpl.optional_params,
            }
        )
    return summaries
|
||
|
||
|
||
@router.get("/topics/{topic_id}")
async def get_topic(topic_id: str):
    """Return the full definition of one topic template.

    The payload includes the template's ``prompt_template`` and ``seo_tips``
    in addition to its descriptive fields.

    Raises:
        HTTPException: 404 if no template is registered under ``topic_id``.
    """
    from app.services.content.topic_templates import get_topic_template

    tpl = get_topic_template(topic_id)
    if not tpl:
        raise HTTPException(status_code=404, detail="Topic not found")

    # Plain attributes first, in the same key order as the response schema.
    detail = {
        key: getattr(tpl, key)
        for key in (
            "id",
            "name",
            "description",
            "icon",
            "prompt_template",
            "seo_tips",
            "recommended_platforms",
        )
    }
    detail["word_count_range"] = list(tpl.word_count_range)
    detail["required_params"] = tpl.required_params
    detail["optional_params"] = tpl.optional_params
    return detail
|
||
|
||
|
||
# Lower bound for the max_tokens budget of a rewrite pass (DeAI / GEO).
_MIN_REWRITE_MAX_TOKENS = 1024


def _rewrite_max_tokens(text) -> int:
    """Return the max_tokens budget for an LLM pass that rewrites ``text``.

    Keeps the original heuristic of twice the input length in characters. It
    never goes below ``_MIN_REWRITE_MAX_TOKENS``, so an empty or very short
    draft cannot produce a budget of 0 or one too small to hold the rewrite.
    """
    return max(len(text or "") * 2, _MIN_REWRITE_MAX_TOKENS)


def _topic_keywords(raw) -> list[str]:
    """Normalize the free-form ``params["keywords"]`` value into a list of strings.

    Conversions:

    - ``None`` becomes an empty list.
    - A string becomes a one-element list (it is not split).
    - A list or tuple is used element-wise.
    - Any other value is stringified.

    ``None`` elements and blank entries are dropped.
    """
    if raw is None:
        return []
    items = raw if isinstance(raw, (list, tuple)) else [raw]
    return [str(item) for item in items if item is not None and str(item).strip()]


@router.post("/topics/{topic_id}/generate")
async def generate_with_topic(
    topic_id: str,
    request: TopicGenerateRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Generate, post-process and store an article from a topic template.

    Steps:

    1. Render the template prompt from ``request.params``.
    2. Draft with the default LLM.
    3. Apply the DeAI rewrite.
    4. Apply GEO optimization.
    5. Save a draft ``Content`` together with its first ``ContentVersion``.

    Returns:
        dict with the following keys:

        - ``topic_id``
        - ``content``: the de-AI'd draft
        - ``optimized_content``: the GEO output
        - ``content_id``: id of the stored content
        - ``seo_tips``: the template's SEO tips

    Raises:
        HTTPException:

        - 404 for an unknown topic.
        - 400 for a missing required parameter.
        - 403 if the user has no organization.
        - 502 if an LLM call fails.
        - 500 for any other unexpected error.
    """
    from app.agent_framework.prompts import DEAI_TEMPLATE, GEO_OPTIMIZER_TEMPLATE
    from app.services.content.topic_templates import get_topic_template, render_topic_prompt
    from app.services.llm import LLMFactory

    template = get_topic_template(topic_id)
    if not template:
        raise HTTPException(status_code=404, detail="Topic not found")

    # Every parameter the template declares as required must be present.
    for param in template.required_params:
        if param not in request.params:
            raise HTTPException(
                status_code=400,
                detail=f"Missing required parameter: {param}",
            )

    org_id = getattr(current_user, "organization_id", None)
    if not org_id:
        raise HTTPException(status_code=403, detail="用户未关联组织")

    # Normalized once and used both for the GEO prompt and for storage, so a
    # missing value stores [] rather than [""] and a list is not nested.
    keywords = _topic_keywords(request.params.get("keywords"))

    try:
        provider = LLMFactory.get_default()

        # 1) First draft from the rendered topic prompt.
        # NOTE(review): request.style is only recorded in metadata below and
        # never reaches the prompt; confirm whether that is intended.
        prompt = render_topic_prompt(topic_id, request.params)
        response = await provider.chat(
            [{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=4000,
        )
        content = response.content

        # 2) DeAI rewrite of the draft.
        deai_variables = {
            "original_content": content,
            "target_style": "自然流畅",
            "preserve_structure": "是",
        }
        messages = DEAI_TEMPLATE.render(deai_variables)
        response = await provider.chat(
            messages, temperature=0.9, max_tokens=_rewrite_max_tokens(content)
        )
        content = response.content

        # 3) GEO optimization of the de-AI'd text.
        geo_variables = {
            "original_content": content,
            "target_keywords": "、".join(keywords),
            "target_platform": request.platform,
            "optimization_level": "moderate",
        }
        messages = GEO_OPTIMIZER_TEMPLATE.render(geo_variables)
        response = await provider.chat(
            messages, temperature=0.5, max_tokens=_rewrite_max_tokens(content)
        )
        optimized = response.content

        # 4) Persist the result as a draft with an initial version.
        content_obj = Content(
            organization_id=org_id,
            title=request.params.get("product_name") or request.params.get("topic") or topic_id,
            content_type="article",
            body=optimized,
            status="draft",
            target_platforms=[request.platform],
            keywords=keywords,
            extra_metadata={
                # The de-AI'd draft, i.e. the text that was fed into GEO optimization.
                "original_content": content,
                "topic_id": topic_id,
                "topic_name": template.name,
                "brand_name": request.params.get("brand_name", ""),
                "content_style": request.style,
            },
            created_by=current_user.id,
            current_version=1,
        )
        db.add(content_obj)
        # Flush before reading content_obj.id for the version row.
        await db.flush()

        version = ContentVersion(
            content_id=content_obj.id,
            version_number=1,
            title=content_obj.title,
            body=optimized,
            change_summary="母题库自动生成",
            created_by=current_user.id,
        )
        db.add(version)
        await db.commit()
        await db.refresh(content_obj)

        return {
            "topic_id": topic_id,
            "content": content,
            "optimized_content": optimized,
            "content_id": str(content_obj.id),
            "seo_tips": template.seo_tips,
        }

    except HTTPException:
        # Already carries the intended status code and detail; do not rewrap it as a 500.
        raise
    except LLMError as e:
        logger.warning("LLM call failed in generate_with_topic(%s): %s", topic_id, e)
        raise HTTPException(status_code=502, detail=f"LLM调用失败: {str(e)}") from e
    except Exception as e:
        logger.exception("Unexpected error in generate_with_topic(%s)", topic_id)
        raise HTTPException(status_code=500, detail=f"内容生成异常: {str(e)}") from e