# geo/backend/app/api/strategy.py
#
# 388 lines
# 12 KiB
# Python

import uuid
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.user import User
from app.models.brand import Brand
from app.models.geo_plan import GeoPlan, GeoPlanAction
from app.schemas.geo_plan import (
GeoPlanGenerateRequest,
GeoPlanResponse,
GeoPlanListResponse,
GeoPlanActionResponse,
GeoPlanActionUpdateStatus,
GeoPlanActionExecuteResponse,
)
from app.services.scoring.brand_scoring_data_service import get_brand_scoring_data_service
from app.services.strategy.geo_plan_generator import generate_geo_plan
from app.services.content.content_generation_service import ContentGenerationService
# Router shared by every GEO strategy-plan endpoint in this module (plan
# generation, per-brand listing, plan detail, action status updates and
# one-click action execution).
router = APIRouter()
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
if isinstance(value, uuid.UUID):
return value
return uuid.UUID(str(value))
async def _get_brand_with_access(
    brand_id: uuid.UUID,
    db: AsyncSession,
    current_user: User,
) -> Brand:
    """Load the brand ``brand_id`` only if it is owned by ``current_user``.

    The query filters on both the brand id and the owning user id, so a brand
    that belongs to another user is reported exactly like a missing one.

    Raises:
        HTTPException: 404 ("品牌不存在") when no brand matches both filters.
    """
    ownership_filter = (
        Brand.id == brand_id,
        Brand.user_id == _to_uuid(current_user.id),
    )
    rows = await db.execute(select(Brand).where(*ownership_filter))
    brand = rows.scalar_one_or_none()
    if brand:
        return brand
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="品牌不存在",
    )
async def _get_brand_scoring_data(
    db: AsyncSession,
    user_id: uuid.UUID,
    brand: Brand,
) -> tuple:
    """Fetch the scoring inputs for *brand* and flatten them into a tuple.

    The six elements are taken from the object returned by the brand scoring
    data service, in this order: ``v2_result``, ``competitor_data``,
    ``sentiment_counts``, ``platform_scores``, ``total_queries``,
    ``mentioned_count``. Callers unpack positionally, so this order is part
    of the contract.
    """
    service = get_brand_scoring_data_service()
    data = await service.get_brand_scoring_data(db, user_id, brand)
    field_order = (
        "v2_result",
        "competitor_data",
        "sentiment_counts",
        "platform_scores",
        "total_queries",
        "mentioned_count",
    )
    return tuple(getattr(data, field) for field in field_order)
def _plan_to_response(plan: GeoPlan) -> GeoPlanResponse:
    """Convert a ``GeoPlan`` row, with ``plan.actions`` populated, into its API schema.

    Actions are emitted in ascending ``sort_order``. ``sorted`` is stable, so
    actions that share a ``sort_order`` keep the order in which they were loaded.
    """

    def _action_payload(item: GeoPlanAction) -> GeoPlanActionResponse:
        # Field-for-field copy of one action row into its response schema.
        return GeoPlanActionResponse(
            id=item.id,
            plan_id=item.plan_id,
            action_type=item.action_type,
            title=item.title,
            description=item.description,
            reason=item.reason,
            priority=item.priority,
            status=item.status,
            target_keyword=item.target_keyword,
            target_platform=item.target_platform,
            content_style=item.content_style,
            estimated_impact=item.estimated_impact,
            difficulty=item.difficulty,
            execution_params=item.execution_params,
            sort_order=item.sort_order,
            completed_at=item.completed_at,
            created_at=item.created_at,
        )

    ordered_actions = sorted(plan.actions, key=lambda item: item.sort_order)
    action_payloads = [_action_payload(item) for item in ordered_actions]

    return GeoPlanResponse(
        id=plan.id,
        brand_id=plan.brand_id,
        title=plan.title,
        status=plan.status,
        diagnosis_score=plan.diagnosis_score,
        target_score=plan.target_score,
        estimated_weeks=plan.estimated_weeks,
        plan_data=plan.plan_data,
        source=plan.source,
        actions=action_payloads,
        created_at=plan.created_at,
        updated_at=plan.updated_at,
    )
@router.post("/generate", response_model=GeoPlanResponse)
async def generate_geo_plan_endpoint(
    request: GeoPlanGenerateRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate a GEO plan for a brand owned by the caller and store it as a draft.

    Steps:
        1. Verify that the caller owns ``request.brand_id``.
        2. Load the brand's scoring data.
        3. Ask ``generate_geo_plan`` for a plan toward ``target_score``
           (default 75).
        4. Persist one ``GeoPlan`` row (status "draft") plus one
           ``GeoPlanAction`` per generated action. Each action starts as
           "pending", and ``sort_order`` is its position in the generated list.

    Returns:
        GeoPlanResponse: the stored plan, with actions ordered by ``sort_order``.

    Raises:
        HTTPException: 404 when the brand does not exist or belongs to another user.
    """
    brand = await _get_brand_with_access(request.brand_id, db, current_user)
    (
        v2_result,
        competitor_data,
        _sentiment_counts,  # not used by the plan generator
        platform_scores,
        total_queries,
        _mentioned_count,  # not used by the plan generator
    ) = await _get_brand_scoring_data(db, _to_uuid(current_user.id), brand)

    # A missing (or falsy) target falls back to 75.
    target_score = request.target_score or 75
    plan_data = await generate_geo_plan(
        brand_name=brand.name,
        scoring_result=v2_result,
        target_score=target_score,
        total_queries=total_queries,
        platform_scores=platform_scores,
        competitor_data=competitor_data,
    )

    from app.config import settings

    # NOTE(review): `source` is derived from configuration, not from what
    # generate_geo_plan actually did. If the generator falls back to rules
    # when the LLM call fails, this would mislabel the plan. Confirm.
    source = "llm" if (settings.ENABLE_LLM and settings.DEEPSEEK_API_KEY) else "rule"
    # NOTE(review): the user id is stored as organization_id; presumably a
    # placeholder until organizations exist. Confirm.
    organization_id = current_user.id

    db_plan = GeoPlan(
        organization_id=organization_id,
        brand_id=brand.id,
        title=plan_data.title,
        status="draft",
        diagnosis_score=int(round(v2_result.overall_score)),
        target_score=target_score,
        estimated_weeks=plan_data.estimated_weeks,
        plan_data={
            "weekly_plan": plan_data.weekly_plan,
        },
        source=source,
        created_by=current_user.id,
    )
    db.add(db_plan)
    # Flush so db_plan.id is available for the action rows below.
    await db.flush()

    for idx, action_item in enumerate(plan_data.actions):
        db.add(
            GeoPlanAction(
                plan_id=db_plan.id,
                action_type=action_item.action_type,
                title=action_item.title,
                description=action_item.description,
                reason=action_item.reason,
                priority=action_item.priority,
                status="pending",
                target_keyword=action_item.target_keyword,
                target_platform=action_item.target_platform,
                content_style=action_item.content_style,
                estimated_impact=action_item.estimated_impact,
                difficulty=action_item.difficulty,
                execution_params=action_item.execution_params,
                sort_order=idx,
            )
        )
    await db.commit()
    await db.refresh(db_plan)

    # Re-load the plan and its actions explicitly before building the response.
    stmt = select(GeoPlan).where(GeoPlan.id == db_plan.id)
    result = await db.execute(stmt)
    db_plan = result.scalar_one()
    action_stmt = (
        select(GeoPlanAction)
        .where(GeoPlanAction.plan_id == db_plan.id)
        .order_by(GeoPlanAction.sort_order)
    )
    action_result = await db.execute(action_stmt)
    db_plan.actions = list(action_result.scalars().all())
    return _plan_to_response(db_plan)
@router.get("/brand/{brand_id}", response_model=GeoPlanListResponse)
async def get_brand_plans(
    brand_id: uuid.UUID,
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List one page of a brand's plans, newest first, each with its actions.

    Args:
        brand_id: brand whose plans are listed; it must be owned by the caller.
        skip: number of plans to skip (offset into ``created_at`` desc order).
        limit: page size, from 1 to 100.

    Returns:
        GeoPlanListResponse: the page of plans and ``total``, the count of
        all plans for the brand (ignoring paging).

    Raises:
        HTTPException: 404 when the brand does not exist or belongs to another user.
    """
    await _get_brand_with_access(brand_id, db, current_user)

    count_stmt = select(func.count()).select_from(GeoPlan).where(
        GeoPlan.brand_id == brand_id,
    )
    count_result = await db.execute(count_stmt)
    total = count_result.scalar_one()

    stmt = (
        select(GeoPlan)
        .where(GeoPlan.brand_id == brand_id)
        .order_by(GeoPlan.created_at.desc())
        .offset(skip)
        .limit(limit)
    )
    result = await db.execute(stmt)
    plans = list(result.scalars().all())

    # Fetch the actions for every plan on this page in ONE query, not one
    # query per plan (N+1). Rows come back in sort_order, so appending keeps
    # each plan's list in sort_order.
    actions_by_plan: dict[uuid.UUID, list[GeoPlanAction]] = {plan.id: [] for plan in plans}
    if actions_by_plan:
        action_stmt = (
            select(GeoPlanAction)
            .where(GeoPlanAction.plan_id.in_(list(actions_by_plan)))
            .order_by(GeoPlanAction.sort_order)
        )
        action_result = await db.execute(action_stmt)
        for action in action_result.scalars():
            actions_by_plan[action.plan_id].append(action)

    plan_responses = []
    for plan in plans:
        plan.actions = actions_by_plan[plan.id]
        plan_responses.append(_plan_to_response(plan))
    return GeoPlanListResponse(plans=plan_responses, total=total)
@router.get("/{plan_id}", response_model=GeoPlanResponse)
async def get_plan_detail(
    plan_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return one plan with its actions, provided the plan's brand belongs to the caller.

    Raises:
        HTTPException: 404 ("方案不存在") when the plan id is unknown, or the
            404 raised by ``_get_brand_with_access`` when the plan's brand is
            not owned by the caller.
    """
    plan_rows = await db.execute(select(GeoPlan).where(GeoPlan.id == plan_id))
    plan = plan_rows.scalar_one_or_none()
    if not plan:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="方案不存在",
        )

    # Ownership check only; the brand object itself is not needed here.
    await _get_brand_with_access(plan.brand_id, db, current_user)

    actions_query = (
        select(GeoPlanAction)
        .where(GeoPlanAction.plan_id == plan.id)
        .order_by(GeoPlanAction.sort_order)
    )
    action_rows = await db.execute(actions_query)
    plan.actions = list(action_rows.scalars().all())
    return _plan_to_response(plan)
@router.put("/actions/{action_id}/status", response_model=GeoPlanActionResponse)
async def update_action_status(
    action_id: uuid.UUID,
    status_update: GeoPlanActionUpdateStatus,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Set the status of one plan action whose brand is owned by the caller.

    Accepted statuses are pending, in_progress, completed and skipped.

    - Moving to "completed" stamps ``completed_at`` with ``datetime.now()``
      (naive local time).
    - Moving to any other status clears ``completed_at``, so a reopened or
      skipped action does not keep a stale completion timestamp.

    Raises:
        HTTPException: 400 for an unknown status. 404 when the action does
            not exist, or when its brand belongs to another user.
    """
    # A tuple rather than a set, so the order listed in the error message is
    # deterministic (iteration order of a set of str depends on hash seeding).
    valid_statuses = ("pending", "in_progress", "completed", "skipped")
    if status_update.status not in valid_statuses:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无效的状态值,支持: {', '.join(valid_statuses)}",
        )

    stmt = select(GeoPlanAction).where(GeoPlanAction.id == action_id)
    result = await db.execute(stmt)
    action = result.scalar_one_or_none()
    if not action:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="行动项不存在",
        )

    # Access is checked through the owning plan's brand.
    plan_stmt = select(GeoPlan).where(GeoPlan.id == action.plan_id)
    plan_result = await db.execute(plan_stmt)
    plan = plan_result.scalar_one()
    await _get_brand_with_access(plan.brand_id, db, current_user)

    action.status = status_update.status
    if status_update.status == "completed":
        action.completed_at = datetime.now()
    else:
        action.completed_at = None
    await db.commit()
    await db.refresh(action)

    return GeoPlanActionResponse(
        id=action.id,
        plan_id=action.plan_id,
        action_type=action.action_type,
        title=action.title,
        description=action.description,
        reason=action.reason,
        priority=action.priority,
        status=action.status,
        target_keyword=action.target_keyword,
        target_platform=action.target_platform,
        content_style=action.content_style,
        estimated_impact=action.estimated_impact,
        difficulty=action.difficulty,
        execution_params=action.execution_params,
        sort_order=action.sort_order,
        completed_at=action.completed_at,
        created_at=action.created_at,
    )
@router.post("/actions/{action_id}/execute", response_model=GeoPlanActionExecuteResponse)
async def execute_action(
    action_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Run a content action through ``ContentGenerationService``, then mark it completed.

    Only "content_creation" and "content_optimization" actions can be executed.
    Generation parameters come from ``action.execution_params`` first. A
    missing, null or empty value there falls back to the action's own target
    fields, then to fixed defaults.

    Returns:
        GeoPlanActionExecuteResponse: the action id and the generated
        ``content_id`` (``None`` when the service reports none).

    Raises:
        HTTPException: 404 for an unknown action or one whose brand belongs to
            another user. 400 for an unsupported action type. 500 when content
            generation fails. An ``HTTPException`` raised by the service is
            passed through unchanged.
    """
    stmt = select(GeoPlanAction).where(GeoPlanAction.id == action_id)
    result = await db.execute(stmt)
    action = result.scalar_one_or_none()
    if not action:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="行动项不存在",
        )

    plan_stmt = select(GeoPlan).where(GeoPlan.id == action.plan_id)
    plan_result = await db.execute(plan_stmt)
    plan = plan_result.scalar_one()
    brand = await _get_brand_with_access(plan.brand_id, db, current_user)

    if action.action_type not in ("content_creation", "content_optimization"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"行动类型 '{action.action_type}' 不支持一键执行,仅支持 content_creation 和 content_optimization",
        )

    params = action.execution_params or {}
    # Use `or` chains instead of a .get() default, so that explicit null/empty
    # entries stored in execution_params still fall back instead of being
    # forwarded to the generator.
    keyword = params.get("keyword") or action.target_keyword or brand.name
    platform = params.get("platform") or action.target_platform or "通用"
    style = params.get("style") or action.content_style or "专业严谨"
    word_count = params.get("word_count") or 2000
    knowledge_base_ids = params.get("knowledge_base_ids")

    content_service = ContentGenerationService()
    try:
        gen_result = await content_service.generate_content(
            keyword=keyword,
            brand_name=brand.name,
            platform=platform,
            content_style=style,
            word_count=word_count,
            knowledge_base_ids=knowledge_base_ids,
            db=db,
            user_id=current_user.id,
            org_id=str(plan.organization_id),
            run_deai=True,
            run_geo=True,
        )
    except HTTPException:
        # Already carries its own status code; do not mask it as a 500.
        raise
    except Exception as e:
        # NOTE(review): this exposes internal exception text to the client;
        # consider logging it and returning a generic message instead.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"内容生成失败: {str(e)}",
        ) from e

    action.status = "completed"
    action.completed_at = datetime.now()
    await db.commit()
    await db.refresh(action)

    content_id = gen_result.get("content_id")
    return GeoPlanActionExecuteResponse(
        action_id=action.id,
        content_id=content_id,
        message="内容生成成功" if content_id else "内容生成完成(未持久化)",
    )