"""GEO plan API routes: generate plans, list/read them, and update/execute plan actions."""

import logging
import uuid
from collections import defaultdict
from datetime import datetime

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.geo_plan import GeoPlan, GeoPlanAction
from app.schemas.geo_plan import (
    GeoPlanGenerateRequest,
    GeoPlanResponse,
    GeoPlanListResponse,
    GeoPlanActionResponse,
    GeoPlanActionUpdateStatus,
    GeoPlanActionExecuteResponse,
)
from app.services.scoring.brand_scoring_data_service import get_brand_scoring_data_service
from app.services.strategy.geo_plan_generator import generate_geo_plan
from app.services.content.content_generation_service import ContentGenerationService

logger = logging.getLogger(__name__)

router = APIRouter()

# Target score used when the generate request does not supply a (truthy) one.
_DEFAULT_TARGET_SCORE = 75

# Ordered tuples (not sets) so the error message listing them is deterministic.
_VALID_ACTION_STATUSES = ("pending", "in_progress", "completed", "skipped")

# Action types that execute_action is able to run.
_EXECUTABLE_ACTION_TYPES = ("content_creation", "content_optimization")


def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
    """Return *value* as a ``uuid.UUID``, parsing its string form if needed."""
    if isinstance(value, uuid.UUID):
        return value
    return uuid.UUID(str(value))


async def _get_brand_with_access(
    brand_id: uuid.UUID,
    db: AsyncSession,
    current_user: User,
) -> Brand:
    """Load the brand with ``brand_id`` that belongs to ``current_user``.

    Raises:
        HTTPException: 404 when no brand matches both the id and the user, so
            a brand owned by someone else is indistinguishable from a missing one.
    """
    stmt = select(Brand).where(
        Brand.id == brand_id,
        Brand.user_id == _to_uuid(current_user.id),
    )
    result = await db.execute(stmt)
    brand = result.scalar_one_or_none()
    if not brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )
    return brand


async def _get_brand_scoring_data(
    db: AsyncSession,
    user_id: uuid.UUID,
    brand: Brand,
) -> tuple:
    """Fetch the brand's scoring data and return it as a flat 6-tuple.

    Returns:
        ``(v2_result, competitor_data, sentiment_counts, platform_scores,
        total_queries, mentioned_count)`` read from the scoring data service result.
    """
    scoring_data_service = get_brand_scoring_data_service()
    scoring_data = await scoring_data_service.get_brand_scoring_data(db, user_id, brand)
    return (
        scoring_data.v2_result,
        scoring_data.competitor_data,
        scoring_data.sentiment_counts,
        scoring_data.platform_scores,
        scoring_data.total_queries,
        scoring_data.mentioned_count,
    )


def _action_to_response(action: GeoPlanAction) -> GeoPlanActionResponse:
    """Copy a ``GeoPlanAction`` row field-by-field into its response schema."""
    return GeoPlanActionResponse(
        id=action.id,
        plan_id=action.plan_id,
        action_type=action.action_type,
        title=action.title,
        description=action.description,
        reason=action.reason,
        priority=action.priority,
        status=action.status,
        target_keyword=action.target_keyword,
        target_platform=action.target_platform,
        content_style=action.content_style,
        estimated_impact=action.estimated_impact,
        difficulty=action.difficulty,
        execution_params=action.execution_params,
        sort_order=action.sort_order,
        completed_at=action.completed_at,
        created_at=action.created_at,
    )


def _plan_to_response(plan: GeoPlan) -> GeoPlanResponse:
    """Build the response schema for *plan*, with its actions ordered by ``sort_order``.

    Reads ``plan.actions``; callers in this module assign that attribute
    explicitly before calling.
    """
    actions = [
        _action_to_response(action)
        for action in sorted(plan.actions, key=lambda a: a.sort_order)
    ]
    return GeoPlanResponse(
        id=plan.id,
        brand_id=plan.brand_id,
        title=plan.title,
        status=plan.status,
        diagnosis_score=plan.diagnosis_score,
        target_score=plan.target_score,
        estimated_weeks=plan.estimated_weeks,
        plan_data=plan.plan_data,
        source=plan.source,
        actions=actions,
        created_at=plan.created_at,
        updated_at=plan.updated_at,
    )


async def _load_plan_actions(db: AsyncSession, plan_id: uuid.UUID) -> list[GeoPlanAction]:
    """Return the actions of the plan ``plan_id`` ordered by ``sort_order``."""
    action_stmt = (
        select(GeoPlanAction)
        .where(GeoPlanAction.plan_id == plan_id)
        .order_by(GeoPlanAction.sort_order)
    )
    action_result = await db.execute(action_stmt)
    return list(action_result.scalars().all())


async def _get_action_with_access(
    action_id: uuid.UUID,
    db: AsyncSession,
    current_user: User,
) -> tuple[GeoPlanAction, GeoPlan, Brand]:
    """Load an action, its plan and the plan's brand, enforcing brand ownership.

    Raises:
        HTTPException: 404 when the action does not exist, or when the brand of
            its plan does not belong to ``current_user``.
    """
    stmt = select(GeoPlanAction).where(GeoPlanAction.id == action_id)
    result = await db.execute(stmt)
    action = result.scalar_one_or_none()
    if not action:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="行动项不存在",
        )

    plan_stmt = select(GeoPlan).where(GeoPlan.id == action.plan_id)
    plan_result = await db.execute(plan_stmt)
    plan = plan_result.scalar_one()

    brand = await _get_brand_with_access(plan.brand_id, db, current_user)
    return action, plan, brand


@router.post("/generate", response_model=GeoPlanResponse)
async def generate_geo_plan_endpoint(
    request: GeoPlanGenerateRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate a GEO plan for a brand, persist it with its actions, and return it.

    The plan is stored with status ``"draft"`` and each action with status
    ``"pending"``; actions keep the generator's order via ``sort_order``.

    Raises:
        HTTPException: 404 when the brand is not found for the current user.
    """
    brand = await _get_brand_with_access(request.brand_id, db, current_user)

    (
        v2_result,
        competitor_data,
        _sentiment_data,
        platform_scores,
        total_queries,
        _mentioned_count,
    ) = await _get_brand_scoring_data(db, _to_uuid(current_user.id), brand)

    # `or` (not `is None`) is intentional to preserve behavior: a falsy
    # target score such as 0 also falls back to the default.
    target_score = request.target_score or _DEFAULT_TARGET_SCORE

    plan_data = await generate_geo_plan(
        brand_name=brand.name,
        scoring_result=v2_result,
        target_score=target_score,
        total_queries=total_queries,
        platform_scores=platform_scores,
        competitor_data=competitor_data,
    )

    # Kept as a function-scope import as in the original; presumably this
    # avoids an import cycle -- confirm before hoisting to module level.
    from app.config import settings

    source = "llm" if (settings.ENABLE_LLM and settings.DEEPSEEK_API_KEY) else "rule"

    # NOTE(review): the user's id is stored as the organization id -- confirm
    # this is intended rather than a placeholder.
    organization_id = current_user.id

    db_plan = GeoPlan(
        organization_id=organization_id,
        brand_id=brand.id,
        title=plan_data.title,
        status="draft",
        diagnosis_score=int(round(v2_result.overall_score)),
        target_score=target_score,
        estimated_weeks=plan_data.estimated_weeks,
        plan_data={
            "weekly_plan": plan_data.weekly_plan,
        },
        source=source,
        created_by=current_user.id,
    )
    db.add(db_plan)
    # Flush so db_plan.id is available for the action rows below.
    await db.flush()

    db.add_all(
        [
            GeoPlanAction(
                plan_id=db_plan.id,
                action_type=action_item.action_type,
                title=action_item.title,
                description=action_item.description,
                reason=action_item.reason,
                priority=action_item.priority,
                status="pending",
                target_keyword=action_item.target_keyword,
                target_platform=action_item.target_platform,
                content_style=action_item.content_style,
                estimated_impact=action_item.estimated_impact,
                difficulty=action_item.difficulty,
                execution_params=action_item.execution_params,
                sort_order=idx,
            )
            for idx, action_item in enumerate(plan_data.actions)
        ]
    )

    await db.commit()
    await db.refresh(db_plan)

    # Re-select after commit exactly as the original did; whether this is
    # required depends on the GeoPlan mapping, which is not visible here.
    stmt = select(GeoPlan).where(GeoPlan.id == db_plan.id)
    result = await db.execute(stmt)
    db_plan = result.scalar_one()

    db_plan.actions = await _load_plan_actions(db, db_plan.id)

    return _plan_to_response(db_plan)


@router.get("/brand/{brand_id}", response_model=GeoPlanListResponse)
async def get_brand_plans(
    brand_id: uuid.UUID,
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one page of a brand's plans, newest first, plus the total plan count.

    Raises:
        HTTPException: 404 when the brand is not found for the current user.
    """
    await _get_brand_with_access(brand_id, db, current_user)

    count_stmt = select(func.count()).select_from(GeoPlan).where(
        GeoPlan.brand_id == brand_id,
    )
    count_result = await db.execute(count_stmt)
    total = count_result.scalar_one()

    stmt = (
        select(GeoPlan)
        .where(GeoPlan.brand_id == brand_id)
        .order_by(GeoPlan.created_at.desc())
        .offset(skip)
        .limit(limit)
    )
    result = await db.execute(stmt)
    plans = list(result.scalars().all())

    # Fetch the actions of every plan on this page with a single query instead
    # of one query per plan, then group them by plan id.
    actions_by_plan: dict[uuid.UUID, list[GeoPlanAction]] = defaultdict(list)
    if plans:
        action_stmt = (
            select(GeoPlanAction)
            .where(GeoPlanAction.plan_id.in_([plan.id for plan in plans]))
            .order_by(GeoPlanAction.sort_order)
        )
        action_result = await db.execute(action_stmt)
        for action in action_result.scalars().all():
            actions_by_plan[action.plan_id].append(action)

    plan_responses = []
    for plan in plans:
        plan.actions = actions_by_plan.get(plan.id, [])
        plan_responses.append(_plan_to_response(plan))

    return GeoPlanListResponse(plans=plan_responses, total=total)


@router.get("/{plan_id}", response_model=GeoPlanResponse)
async def get_plan_detail(
    plan_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a single plan with its actions.

    Raises:
        HTTPException: 404 when the plan does not exist, or when its brand is
            not found for the current user.
    """
    stmt = select(GeoPlan).where(GeoPlan.id == plan_id)
    result = await db.execute(stmt)
    plan = result.scalar_one_or_none()
    if not plan:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="方案不存在",
        )

    # Called for its access check only; the brand itself is not needed here.
    await _get_brand_with_access(plan.brand_id, db, current_user)

    plan.actions = await _load_plan_actions(db, plan.id)

    return _plan_to_response(plan)


@router.put("/actions/{action_id}/status", response_model=GeoPlanActionResponse)
async def update_action_status(
    action_id: uuid.UUID,
    status_update: GeoPlanActionUpdateStatus,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Set the status of a plan action and return the updated action.

    ``completed_at`` is stamped when the new status is ``"completed"`` and
    cleared for any other status, so an action moved back out of
    ``"completed"`` does not keep a stale completion time.

    Raises:
        HTTPException: 400 for a status outside the supported set; 404 when
            the action does not exist or its brand is not found for the
            current user.
    """
    if status_update.status not in _VALID_ACTION_STATUSES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无效的状态值,支持: {', '.join(_VALID_ACTION_STATUSES)}",
        )

    action, _plan, _brand = await _get_action_with_access(action_id, db, current_user)

    action.status = status_update.status
    if status_update.status == "completed":
        # NOTE(review): naive local time -- confirm the column's timezone
        # expectations before switching to an aware datetime.
        action.completed_at = datetime.now()
    else:
        action.completed_at = None

    await db.commit()
    await db.refresh(action)

    return _action_to_response(action)


@router.post("/actions/{action_id}/execute", response_model=GeoPlanActionExecuteResponse)
async def execute_action(
    action_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Run content generation for an action and mark the action completed.

    Generation inputs come from ``action.execution_params`` when present,
    falling back to the action's own target fields and then to fixed defaults.

    Raises:
        HTTPException: 404 when the action does not exist or its brand is not
            found for the current user; 400 when the action type is not
            executable; 500 when content generation fails. An ``HTTPException``
            raised by the content service is propagated unchanged.
    """
    action, plan, brand = await _get_action_with_access(action_id, db, current_user)

    if action.action_type not in _EXECUTABLE_ACTION_TYPES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"行动类型 '{action.action_type}' 不支持一键执行,仅支持 content_creation 和 content_optimization",
        )

    params = action.execution_params or {}
    keyword = params.get("keyword", action.target_keyword or brand.name)
    platform = params.get("platform", action.target_platform or "通用")
    style = params.get("style", action.content_style or "专业严谨")
    word_count = params.get("word_count", 2000)
    knowledge_base_ids = params.get("knowledge_base_ids")

    content_service = ContentGenerationService()

    try:
        gen_result = await content_service.generate_content(
            keyword=keyword,
            brand_name=brand.name,
            platform=platform,
            content_style=style,
            word_count=word_count,
            knowledge_base_ids=knowledge_base_ids,
            db=db,
            user_id=current_user.id,
            org_id=str(plan.organization_id),
            run_deai=True,
            run_geo=True,
        )
    except HTTPException:
        # Keep the status code chosen by the service instead of masking it as 500.
        raise
    except Exception as e:
        # Log with the path parameter rather than an ORM attribute so no
        # attribute load is attempted after a possibly failed DB operation.
        logger.exception("Content generation failed for action %s", action_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"内容生成失败: {str(e)}",
        ) from e

    content_id = gen_result.get("content_id")

    action.status = "completed"
    # NOTE(review): naive local time -- see update_action_status.
    action.completed_at = datetime.now()
    await db.commit()
    await db.refresh(action)

    return GeoPlanActionExecuteResponse(
        action_id=action.id,
        content_id=content_id,
        message="内容生成成功" if content_id else "内容生成完成(未持久化)",
    )