"""内容生产API - 串联Agent Pipeline""" import json import logging import re from typing import Optional from fastapi import APIRouter, Depends, HTTPException from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_user from app.database import get_db from app.models.content import Content, ContentVersion from app.models.user import User logger = logging.getLogger(__name__) router = APIRouter() class ContentGenerateRequest(BaseModel): target_keyword: str target_platform: str = "通用" knowledge_base_ids: list[str] = [] content_style: str = "专业严谨" word_count: int = 2000 brand_name: str = "" brand_description: str = "" run_deai: bool = True run_geo: bool = True class ContentGenerateResponse(BaseModel): status: str content: str = "" optimized_content: str = "" seo_score: Optional[int] = None content_id: Optional[str] = None topics: list[dict] = [] pipeline_stages: list[dict] = [] # 每个阶段的执行结果摘要 async def _get_knowledge_context( db: AsyncSession, brand_name: str, knowledge_base_ids: list[str], target_keyword: str, ) -> str: """ 从知识库检索与查询相关的上下文。 如果有知识库ID,则调用 RAGService.search 获取相关内容; 否则返回空字符串,不影响后续流程。 """ if not knowledge_base_ids: return "" try: from app.services.knowledge.rag_service import RAGService rag_service = RAGService() results = await rag_service.search( session=db, query=f"{brand_name} {target_keyword}" if brand_name else target_keyword, knowledge_base_ids=knowledge_base_ids, top_k=3, ) if results: context_parts = [] for r in results: content = r.get("content", "") title = r.get("document_title", "") if content: context_parts.append(f"[{title}] {content}") return "\n".join(context_parts) return "" except Exception as e: logger.warning(f"知识库检索失败,将不使用知识库上下文: {e}") return "" @router.post("/generate", response_model=ContentGenerateResponse) async def generate_content( req: ContentGenerateRequest, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): """ 一键生成内容(同步执行Pipeline),结果存入数据库 流程:ContentGenerator → DeAI → GEOOptimizer """ from app.services.llm import LLMError, LLMFactory from app.agent_framework.prompts import ( CONTENT_GENERATOR_TEMPLATE, DEAI_TEMPLATE, GEO_OPTIMIZER_TEMPLATE, ) org_id = getattr(current_user, "organization_id", None) if not org_id: raise HTTPException(status_code=403, detail="用户未关联组织") stages = [] try: provider = LLMFactory.get_default() # 获取知识库上下文 knowledge_context = await _get_knowledge_context( db, req.brand_name, req.knowledge_base_ids, req.target_keyword ) # Stage 1: 内容生成 gen_variables = { "topic_title": req.target_keyword, "target_keyword": req.target_keyword, "target_platform": req.target_platform, "content_angle": "综合分析", "content_style": req.content_style, "word_count": str(req.word_count), "brand_name": req.brand_name, "knowledge_context": knowledge_context, } messages = CONTENT_GENERATOR_TEMPLATE.render(gen_variables) response = await provider.chat(messages, temperature=0.7, max_tokens=req.word_count * 2) content = response.content stages.append({"stage": "content_generation", "status": "success", "word_count": len(content)}) # Stage 2: 去AI化(可选) if req.run_deai: deai_variables = { "original_content": content, "target_style": "自然流畅", "preserve_structure": "是", } messages = DEAI_TEMPLATE.render(deai_variables) response = await provider.chat(messages, temperature=0.9, max_tokens=len(content) * 2) content = response.content stages.append({"stage": "deai", "status": "success"}) # Stage 3: GEO优化(可选) optimized = content seo_score = None if req.run_geo: geo_variables = { "original_content": content, "target_keywords": req.target_keyword, "target_platform": req.target_platform, "optimization_level": "moderate", } messages = GEO_OPTIMIZER_TEMPLATE.render(geo_variables) response = await provider.chat(messages, temperature=0.5, max_tokens=len(content) * 2) optimized = response.content stages.append({"stage": "geo_optimization", "status": "success"}) # ---- 存入数据库 ---- content_obj = Content( organization_id=org_id, title=req.target_keyword, content_type="article", body=optimized, status="draft", target_platforms=[req.target_platform] if req.target_platform else [], keywords=[req.target_keyword], extra_metadata={ "original_content": content if content != optimized else None, "pipeline_stages": stages, "seo_score": seo_score, "brand_name": req.brand_name, "content_style": req.content_style, "word_count_target": req.word_count, }, created_by=current_user.id, current_version=1, ) db.add(content_obj) await db.flush() # get content_obj.id # 创建版本记录(初始版本) version = ContentVersion( content_id=content_obj.id, version_number=1, title=req.target_keyword, body=optimized, change_summary="Pipeline自动生成", created_by=current_user.id, ) db.add(version) await db.commit() await db.refresh(content_obj) return ContentGenerateResponse( status="success", content=content, optimized_content=optimized, seo_score=seo_score, content_id=str(content_obj.id), pipeline_stages=stages, ) except LLMError as e: raise HTTPException(status_code=502, detail=f"LLM调用失败: {str(e)}") except Exception as e: raise HTTPException(status_code=500, detail=f"内容生成异常: {str(e)}") @router.post("/generate-topics") async def generate_topics( req: ContentGenerateRequest, current_user: User = Depends(get_current_user), ): """生成选题列表""" from app.services.llm import LLMError, LLMFactory from app.agent_framework.prompts import TOPIC_SELECTOR_TEMPLATE try: provider = LLMFactory.get_default() variables = { "target_keyword": req.target_keyword, "brand_name": req.brand_name, "brand_description": req.brand_description, "target_platform": req.target_platform, "knowledge_context": "暂无", "published_topics": "暂无", } messages = TOPIC_SELECTOR_TEMPLATE.render(variables) response = await provider.chat(messages, temperature=0.8) # 尝试解析JSON match = re.search(r'```(?:json)?\s*\n?(.*?)\n?```', response.content, re.DOTALL) text = match.group(1).strip() if match else response.content try: topics = json.loads(text) except json.JSONDecodeError: topics = [{"title": response.content[:100], "reason": "解析失败"}] return {"status": "success", "topics": topics} except LLMError as e: raise HTTPException(status_code=502, detail=str(e)) # ==================== 母题库接口 ==================== class TopicGenerateRequest(BaseModel): """母题生成请求""" params: dict # 母题模板参数 platform: str = "通用" style: str = "专业严谨" @router.get("/topics") async def list_topics(): """获取所有母题库列表""" from app.services.content.topic_templates import list_topic_templates templates = list_topic_templates() return [ { "id": t.id, "name": t.name, "description": t.description, "icon": t.icon, "recommended_platforms": t.recommended_platforms, "word_count_range": list(t.word_count_range), "required_params": t.required_params, "optional_params": t.optional_params, } for t in templates ] @router.get("/topics/{topic_id}") async def get_topic(topic_id: str): """获取母题详情""" from app.services.content.topic_templates import get_topic_template template = get_topic_template(topic_id) if not template: raise HTTPException(status_code=404, detail="Topic not found") return { "id": template.id, "name": template.name, "description": template.description, "icon": template.icon, "prompt_template": template.prompt_template, "seo_tips": template.seo_tips, "recommended_platforms": template.recommended_platforms, "word_count_range": list(template.word_count_range), "required_params": template.required_params, "optional_params": template.optional_params, } @router.post("/topics/{topic_id}/generate") async def generate_with_topic( topic_id: str, request: TopicGenerateRequest, db: AsyncSession = Depends(get_db), current_user: User = Depends(get_current_user), ): """使用母题生成内容""" from app.services.content.topic_templates import get_topic_template, render_topic_prompt from app.services.llm import LLMError, LLMFactory from app.agent_framework.prompts import DEAI_TEMPLATE, GEO_OPTIMIZER_TEMPLATE template = get_topic_template(topic_id) if not template: raise HTTPException(status_code=404, detail="Topic not found") # 验证必填参数 for param in template.required_params: if param not in request.params: raise HTTPException( status_code=400, detail=f"Missing required parameter: {param}" ) org_id = getattr(current_user, "organization_id", None) if not org_id: raise HTTPException(status_code=403, detail="用户未关联组织") try: provider = LLMFactory.get_default() # 渲染Prompt prompt = render_topic_prompt(topic_id, request.params) # 调用内容生成 response = await provider.chat( [{"role": "user", "content": prompt}], temperature=0.7, max_tokens=4000 ) content = response.content # 去AI化处理 deai_variables = { "original_content": content, "target_style": "自然流畅", "preserve_structure": "是", } messages = DEAI_TEMPLATE.render(deai_variables) response = await provider.chat(messages, temperature=0.9, max_tokens=len(content) * 2) content = response.content # GEO优化 geo_variables = { "original_content": content, "target_keywords": request.params.get("keywords", ""), "target_platform": request.platform, "optimization_level": "moderate", } messages = GEO_OPTIMIZER_TEMPLATE.render(geo_variables) response = await provider.chat(messages, temperature=0.5, max_tokens=len(content) * 2) optimized = response.content # 存入数据库 content_obj = Content( organization_id=org_id, title=request.params.get("product_name") or request.params.get("topic") or topic_id, content_type="article", body=optimized, status="draft", target_platforms=[request.platform], keywords=[request.params.get("keywords", "")], extra_metadata={ "original_content": content, "topic_id": topic_id, "topic_name": template.name, "brand_name": request.params.get("brand_name", ""), "content_style": request.style, }, created_by=current_user.id, current_version=1, ) db.add(content_obj) await db.flush() version = ContentVersion( content_id=content_obj.id, version_number=1, title=content_obj.title, body=optimized, change_summary="母题库自动生成", created_by=current_user.id, ) db.add(version) await db.commit() await db.refresh(content_obj) return { "topic_id": topic_id, "content": content, "optimized_content": optimized, "content_id": str(content_obj.id), "seo_tips": template.seo_tips, } except LLMError as e: raise HTTPException(status_code=502, detail=f"LLM调用失败: {str(e)}") except Exception as e: raise HTTPException(status_code=500, detail=f"内容生成异常: {str(e)}")