"""API routes for generating, listing, inspecting and executing GEO plans."""

import logging
import uuid
from collections import defaultdict
from datetime import datetime

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.geo_plan import GeoPlan, GeoPlanAction
from app.schemas.geo_plan import (
    GeoPlanGenerateRequest,
    GeoPlanResponse,
    GeoPlanListResponse,
    GeoPlanActionResponse,
    GeoPlanActionUpdateStatus,
    GeoPlanActionExecuteResponse,
)
from app.services.scoring.brand_scoring_data_service import get_brand_scoring_data_service
from app.services.strategy.geo_plan_generator import generate_geo_plan
from app.services.content.content_generation_service import ContentGenerationService

logger = logging.getLogger(__name__)

router = APIRouter()

# Statuses accepted by ``update_action_status``. A tuple (not a set) so the
# list embedded in the 400 error message always has the same order.
_VALID_ACTION_STATUSES = ("pending", "in_progress", "completed", "skipped")

# Action types that ``execute_action`` can run automatically.
_EXECUTABLE_ACTION_TYPES = ("content_creation", "content_optimization")

# Target score used when the generate request does not supply one.
_DEFAULT_TARGET_SCORE = 75

# Fallbacks used by ``execute_action`` when the action carries no value.
_DEFAULT_PLATFORM = "通用"
_DEFAULT_CONTENT_STYLE = "专业严谨"
_DEFAULT_WORD_COUNT = 2000


def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
    """Return *value* as a ``uuid.UUID``, parsing it from its string form if needed."""
    if isinstance(value, uuid.UUID):
        return value
    return uuid.UUID(str(value))


async def _get_brand_with_access(
    brand_id: uuid.UUID,
    db: AsyncSession,
    current_user: User,
) -> Brand:
    """Load the brand *brand_id* if it belongs to *current_user*.

    Raises:
        HTTPException: 404 when the brand does not exist or is owned by
            another user (the two cases are deliberately indistinguishable).
    """
    stmt = select(Brand).where(
        Brand.id == brand_id,
        Brand.user_id == _to_uuid(current_user.id),
    )
    result = await db.execute(stmt)
    brand = result.scalar_one_or_none()
    if not brand:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )
    return brand


async def _get_brand_scoring_data(
    db: AsyncSession,
    user_id: uuid.UUID,
    brand: Brand,
) -> tuple:
    """Fetch scoring inputs for *brand* from the scoring data service.

    Returns:
        A 6-tuple ``(v2_result, competitor_data, sentiment_counts,
        platform_scores, total_queries, mentioned_count)``.
    """
    scoring_data_service = get_brand_scoring_data_service()
    scoring_data = await scoring_data_service.get_brand_scoring_data(db, user_id, brand)
    return (
        scoring_data.v2_result,
        scoring_data.competitor_data,
        scoring_data.sentiment_counts,
        scoring_data.platform_scores,
        scoring_data.total_queries,
        scoring_data.mentioned_count,
    )


async def _load_plan_actions(db: AsyncSession, plan_id: uuid.UUID) -> list[GeoPlanAction]:
    """Return all actions of plan *plan_id*, ordered by ``sort_order``.

    Loaded with an explicit query rather than via the ORM relationship, so
    no lazy load is triggered on the async session.
    """
    stmt = (
        select(GeoPlanAction)
        .where(GeoPlanAction.plan_id == plan_id)
        .order_by(GeoPlanAction.sort_order)
    )
    result = await db.execute(stmt)
    return list(result.scalars().all())


async def _get_action_with_access(
    action_id: uuid.UUID,
    db: AsyncSession,
    current_user: User,
) -> tuple[GeoPlanAction, GeoPlan, Brand]:
    """Load an action together with its plan and brand, enforcing ownership.

    Raises:
        HTTPException: 404 when the action or its plan does not exist, or
            when the plan's brand is not owned by *current_user*.
    """
    result = await db.execute(select(GeoPlanAction).where(GeoPlanAction.id == action_id))
    action = result.scalar_one_or_none()
    if not action:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="行动项不存在",
        )

    plan_result = await db.execute(select(GeoPlan).where(GeoPlan.id == action.plan_id))
    plan = plan_result.scalar_one_or_none()
    if not plan:
        # An orphaned action is reported as "not found" instead of surfacing
        # as an unhandled NoResultFound (HTTP 500).
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="方案不存在",
        )

    brand = await _get_brand_with_access(plan.brand_id, db, current_user)
    return action, plan, brand


def _action_to_response(action: GeoPlanAction) -> GeoPlanActionResponse:
    """Convert a ``GeoPlanAction`` ORM row into its response schema."""
    return GeoPlanActionResponse(
        id=action.id,
        plan_id=action.plan_id,
        action_type=action.action_type,
        title=action.title,
        description=action.description,
        reason=action.reason,
        priority=action.priority,
        status=action.status,
        target_keyword=action.target_keyword,
        target_platform=action.target_platform,
        content_style=action.content_style,
        estimated_impact=action.estimated_impact,
        difficulty=action.difficulty,
        execution_params=action.execution_params,
        sort_order=action.sort_order,
        completed_at=action.completed_at,
        created_at=action.created_at,
    )


def _plan_to_response(
    plan: GeoPlan,
    actions: list[GeoPlanAction] | None = None,
) -> GeoPlanResponse:
    """Convert a ``GeoPlan`` ORM row (plus its actions) into its response schema.

    Args:
        plan: The plan row.
        actions: The plan's actions. When ``None``, ``plan.actions`` is read
            instead. Callers on an async session should pass the actions
            explicitly to avoid touching an unloaded relationship.

    The actions are sorted by ``sort_order`` in the response.
    """
    source_actions = plan.actions if actions is None else actions
    action_responses = [
        _action_to_response(action)
        for action in sorted(source_actions, key=lambda a: a.sort_order)
    ]
    return GeoPlanResponse(
        id=plan.id,
        brand_id=plan.brand_id,
        title=plan.title,
        status=plan.status,
        diagnosis_score=plan.diagnosis_score,
        target_score=plan.target_score,
        estimated_weeks=plan.estimated_weeks,
        plan_data=plan.plan_data,
        source=plan.source,
        actions=action_responses,
        created_at=plan.created_at,
        updated_at=plan.updated_at,
    )


@router.post("/generate", response_model=GeoPlanResponse)
async def generate_geo_plan_endpoint(
    request: GeoPlanGenerateRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate a new draft GEO plan for a brand and persist it with its actions."""
    brand = await _get_brand_with_access(request.brand_id, db, current_user)

    (
        v2_result,
        competitor_data,
        _sentiment_data,
        platform_scores,
        total_queries,
        _mentioned_count,
    ) = await _get_brand_scoring_data(db, _to_uuid(current_user.id), brand)

    target_score = request.target_score or _DEFAULT_TARGET_SCORE

    plan_data = await generate_geo_plan(
        brand_name=brand.name,
        scoring_result=v2_result,
        target_score=target_score,
        total_queries=total_queries,
        platform_scores=platform_scores,
        competitor_data=competitor_data,
    )

    # Imported here, as in the original code, to keep module import cheap.
    from app.config import settings

    # Record whether the plan came from the LLM path or the rule-based fallback.
    source = "llm" if (settings.ENABLE_LLM and settings.DEEPSEEK_API_KEY) else "rule"

    db_plan = GeoPlan(
        # NOTE(review): the user id is stored as the organization id — confirm intended.
        organization_id=current_user.id,
        brand_id=brand.id,
        title=plan_data.title,
        status="draft",
        diagnosis_score=int(round(v2_result.overall_score)),
        target_score=target_score,
        estimated_weeks=plan_data.estimated_weeks,
        plan_data={
            "weekly_plan": plan_data.weekly_plan,
        },
        source=source,
        created_by=current_user.id,
    )
    db.add(db_plan)
    # Flush so db_plan.id is assigned before the child action rows reference it.
    await db.flush()

    for idx, action_item in enumerate(plan_data.actions):
        db.add(
            GeoPlanAction(
                plan_id=db_plan.id,
                action_type=action_item.action_type,
                title=action_item.title,
                description=action_item.description,
                reason=action_item.reason,
                priority=action_item.priority,
                status="pending",
                target_keyword=action_item.target_keyword,
                target_platform=action_item.target_platform,
                content_style=action_item.content_style,
                estimated_impact=action_item.estimated_impact,
                difficulty=action_item.difficulty,
                execution_params=action_item.execution_params,
                sort_order=idx,
            )
        )

    await db.commit()
    await db.refresh(db_plan)

    actions = await _load_plan_actions(db, db_plan.id)
    return _plan_to_response(db_plan, actions)


@router.get("/brand/{brand_id}", response_model=GeoPlanListResponse)
async def get_brand_plans(
    brand_id: uuid.UUID,
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List a brand's plans, newest first, paginated by *skip* / *limit*.

    ``total`` in the response counts all plans of the brand, not just this page.
    """
    await _get_brand_with_access(brand_id, db, current_user)

    count_stmt = select(func.count()).select_from(GeoPlan).where(
        GeoPlan.brand_id == brand_id,
    )
    count_result = await db.execute(count_stmt)
    total = count_result.scalar_one()

    stmt = (
        select(GeoPlan)
        .where(GeoPlan.brand_id == brand_id)
        .order_by(GeoPlan.created_at.desc())
        .offset(skip)
        .limit(limit)
    )
    result = await db.execute(stmt)
    plans = list(result.scalars().all())

    # Load the actions of every plan on this page in one query (instead of
    # one query per plan) and group them by plan id.
    actions_by_plan: dict[uuid.UUID, list[GeoPlanAction]] = defaultdict(list)
    if plans:
        action_stmt = (
            select(GeoPlanAction)
            .where(GeoPlanAction.plan_id.in_([plan.id for plan in plans]))
            .order_by(GeoPlanAction.sort_order)
        )
        action_result = await db.execute(action_stmt)
        for action in action_result.scalars().all():
            actions_by_plan[action.plan_id].append(action)

    plan_responses = [
        _plan_to_response(plan, actions_by_plan.get(plan.id, []))
        for plan in plans
    ]
    return GeoPlanListResponse(plans=plan_responses, total=total)


@router.get("/{plan_id}", response_model=GeoPlanResponse)
async def get_plan_detail(
    plan_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one plan with its actions, if its brand belongs to the caller."""
    result = await db.execute(select(GeoPlan).where(GeoPlan.id == plan_id))
    plan = result.scalar_one_or_none()
    if not plan:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="方案不存在",
        )

    # Ownership check only; the brand itself is not needed here.
    await _get_brand_with_access(plan.brand_id, db, current_user)

    actions = await _load_plan_actions(db, plan.id)
    return _plan_to_response(plan, actions)


@router.put("/actions/{action_id}/status", response_model=GeoPlanActionResponse)
async def update_action_status(
    action_id: uuid.UUID,
    status_update: GeoPlanActionUpdateStatus,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Set an action's status.

    ``completed_at`` is stamped when the new status is "completed" and
    cleared otherwise, so a reopened action carries no stale completion time.

    Raises:
        HTTPException: 400 for an unknown status; 404 when the action, its
            plan, or the caller's access to the brand is missing.
    """
    if status_update.status not in _VALID_ACTION_STATUSES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无效的状态值,支持: {', '.join(_VALID_ACTION_STATUSES)}",
        )

    action, _plan, _brand = await _get_action_with_access(action_id, db, current_user)

    action.status = status_update.status
    if status_update.status == "completed":
        # NOTE(review): naive local time, matching execute_action — confirm
        # against the column's timezone setting.
        action.completed_at = datetime.now()
    else:
        action.completed_at = None

    await db.commit()
    await db.refresh(action)

    return _action_to_response(action)


@router.post("/actions/{action_id}/execute", response_model=GeoPlanActionExecuteResponse)
async def execute_action(
    action_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Run a content action via ``ContentGenerationService`` and mark it completed.

    Only ``content_creation`` and ``content_optimization`` actions are supported.
    Generation parameters come from ``action.execution_params``, falling back
    to the action's own target fields and then to module defaults.

    Raises:
        HTTPException: 404 for a missing or inaccessible action; 400 for an
            unsupported action type; 500 when content generation fails.
    """
    action, plan, brand = await _get_action_with_access(action_id, db, current_user)

    if action.action_type not in _EXECUTABLE_ACTION_TYPES:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"行动类型 '{action.action_type}' 不支持一键执行,仅支持 content_creation 和 content_optimization",
        )

    params = action.execution_params or {}
    # ``or`` (rather than a ``.get`` default) so that keys explicitly stored
    # as null/empty still fall back instead of being passed through.
    keyword = params.get("keyword") or action.target_keyword or brand.name
    platform = params.get("platform") or action.target_platform or _DEFAULT_PLATFORM
    style = params.get("style") or action.content_style or _DEFAULT_CONTENT_STYLE
    word_count = params.get("word_count") or _DEFAULT_WORD_COUNT
    knowledge_base_ids = params.get("knowledge_base_ids")

    content_service = ContentGenerationService()
    try:
        gen_result = await content_service.generate_content(
            keyword=keyword,
            brand_name=brand.name,
            platform=platform,
            content_style=style,
            word_count=word_count,
            knowledge_base_ids=knowledge_base_ids,
            db=db,
            user_id=current_user.id,
            org_id=str(plan.organization_id),
            run_deai=True,
            run_geo=True,
        )
    except Exception as e:
        # Top-level boundary for an external service: keep the traceback in
        # the logs and discard any partial work the service left in the session.
        logger.exception("Content generation failed for GEO plan action %s", action_id)
        await db.rollback()
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"内容生成失败: {str(e)}",
        ) from e

    action.status = "completed"
    # NOTE(review): naive local time, matching update_action_status — confirm
    # against the column's timezone setting.
    action.completed_at = datetime.now()
    await db.commit()
    await db.refresh(action)

    content_id = gen_result.get("content_id")
    return GeoPlanActionExecuteResponse(
        action_id=action.id,
        content_id=content_id,
        message="内容生成成功" if content_id else "内容生成完成(未持久化)",
    )