393 lines
12 KiB
Python
393 lines
12 KiB
Python
import uuid
|
|
from datetime import datetime
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from app.api.deps import get_current_user
|
|
from app.database import get_db
|
|
from app.models.user import User
|
|
from app.models.brand import Brand
|
|
from app.models.geo_plan import GeoPlan, GeoPlanAction
|
|
from app.schemas.geo_plan import (
|
|
GeoPlanGenerateRequest,
|
|
GeoPlanResponse,
|
|
GeoPlanListResponse,
|
|
GeoPlanActionResponse,
|
|
GeoPlanActionUpdateStatus,
|
|
GeoPlanActionExecuteResponse,
|
|
)
|
|
from app.services.scoring.brand_scoring_data_service import get_brand_scoring_data_service
|
|
from app.services.strategy.geo_plan_generator import generate_geo_plan
|
|
from app.services.content.content_generation_service import ContentGenerationService
|
|
|
|
# Router for the GEO plan endpoints defined below; any URL prefix is applied
# where this router is included (not visible in this file).
router = APIRouter()
|
|
def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
|
|
if isinstance(value, uuid.UUID):
|
|
return value
|
|
return uuid.UUID(str(value))
|
|
|
|
|
|
async def _get_brand_with_access(
    brand_id: uuid.UUID,
    db: AsyncSession,
    current_user: User,
) -> Brand:
    """Fetch a brand that belongs to ``current_user``.

    Args:
        brand_id: Primary key of the brand to load.
        db: Async database session.
        current_user: Authenticated user; only brands whose ``user_id`` equals
            this user's id are visible.

    Returns:
        The matching ``Brand``.

    Raises:
        HTTPException: 404 when no brand with ``brand_id`` is owned by the
            user (a missing brand and someone else's brand look the same).
    """
    owner_filter = Brand.user_id == _to_uuid(current_user.id)
    query = select(Brand).where(Brand.id == brand_id, owner_filter)

    found = (await db.execute(query)).scalar_one_or_none()
    if not found:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="品牌不存在",
        )
    return found
|
|
|
|
|
|
async def _get_brand_scoring_data(
    db: AsyncSession,
    user_id: uuid.UUID,
    brand: Brand,
) -> tuple:
    """Load a brand's scoring data and unpack it into a plain tuple.

    Returns:
        A 6-tuple in this order: ``(v2_result, competitor_data,
        sentiment_counts, platform_scores, total_queries, mentioned_count)``,
        read from the object returned by the brand scoring data service.
    """
    service = get_brand_scoring_data_service()
    data = await service.get_brand_scoring_data(db, user_id, brand)

    # Order matters: callers unpack this tuple positionally.
    field_order = (
        "v2_result",
        "competitor_data",
        "sentiment_counts",
        "platform_scores",
        "total_queries",
        "mentioned_count",
    )
    return tuple(getattr(data, field) for field in field_order)
|
|
|
|
|
|
def _plan_to_response(plan: GeoPlan) -> GeoPlanResponse:
    """Convert a ``GeoPlan`` ORM object into its API response schema.

    ``plan.actions`` must already be populated; callers in this module assign
    it explicitly before calling. Actions are emitted in ascending
    ``sort_order``.
    """
    ordered = sorted(plan.actions, key=lambda item: item.sort_order)

    action_payloads = []
    for item in ordered:
        payload = GeoPlanActionResponse(
            id=item.id,
            plan_id=item.plan_id,
            action_type=item.action_type,
            title=item.title,
            description=item.description,
            reason=item.reason,
            priority=item.priority,
            status=item.status,
            target_keyword=item.target_keyword,
            target_platform=item.target_platform,
            content_style=item.content_style,
            estimated_impact=item.estimated_impact,
            difficulty=item.difficulty,
            execution_params=item.execution_params,
            sort_order=item.sort_order,
            completed_at=item.completed_at,
            created_at=item.created_at,
        )
        action_payloads.append(payload)

    response = GeoPlanResponse(
        id=plan.id,
        brand_id=plan.brand_id,
        title=plan.title,
        status=plan.status,
        diagnosis_score=plan.diagnosis_score,
        target_score=plan.target_score,
        estimated_weeks=plan.estimated_weeks,
        plan_data=plan.plan_data,
        source=plan.source,
        actions=action_payloads,
        created_at=plan.created_at,
        updated_at=plan.updated_at,
    )
    return response
|
|
|
|
|
|
@router.post("/generate", response_model=GeoPlanResponse)
async def generate_geo_plan_endpoint(
    request: GeoPlanGenerateRequest,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Generate a GEO plan for one of the caller's brands and persist it.

    The brand's current scoring data is passed to ``generate_geo_plan``. The
    result is stored as a ``"draft"`` plan. Its actions are created as
    ``"pending"``, with ``sort_order`` equal to their position in the
    generated list.

    Returns:
        The persisted plan, including its actions, as ``GeoPlanResponse``.

    Raises:
        HTTPException: 404 when the brand does not exist or is not owned by
            the current user.
    """
    brand = await _get_brand_with_access(request.brand_id, db, current_user)

    # Sentiment counts and the mentioned count are not used by the generator.
    (
        v2_result,
        competitor_data,
        _sentiment_data,
        platform_scores,
        total_queries,
        _mentioned_count,
    ) = await _get_brand_scoring_data(db, _to_uuid(current_user.id), brand)

    # Any falsy target (None or 0) falls back to 75.
    target_score = request.target_score or 75

    plan_data = await generate_geo_plan(
        brand_name=brand.name,
        scoring_result=v2_result,
        target_score=target_score,
        total_queries=total_queries,
        platform_scores=platform_scores,
        competitor_data=competitor_data,
    )

    from app.config import settings

    # NOTE(review): the label is derived from configuration only; it does not
    # reflect whether generate_geo_plan actually used the LLM — confirm.
    source = "llm" if (settings.ENABLE_LLM and settings.DEEPSEEK_API_KEY) else "rule"

    # NOTE(review): the plan's organization_id is the creating user's id
    # (execute_action later passes it to content generation as org_id) —
    # confirm this user-as-organization mapping is intended.
    db_plan = GeoPlan(
        organization_id=current_user.id,
        brand_id=brand.id,
        title=plan_data.title,
        status="draft",
        diagnosis_score=int(round(v2_result.overall_score)),
        target_score=target_score,
        estimated_weeks=plan_data.estimated_weeks,
        plan_data={
            "weekly_plan": plan_data.weekly_plan,
        },
        source=source,
        created_by=current_user.id,
    )
    db.add(db_plan)
    # Flush so db_plan.id is populated before the actions reference it.
    await db.flush()

    db.add_all(
        [
            GeoPlanAction(
                plan_id=db_plan.id,
                action_type=action_item.action_type,
                title=action_item.title,
                description=action_item.description,
                reason=action_item.reason,
                priority=action_item.priority,
                status="pending",
                target_keyword=action_item.target_keyword,
                target_platform=action_item.target_platform,
                content_style=action_item.content_style,
                estimated_impact=action_item.estimated_impact,
                difficulty=action_item.difficulty,
                execution_params=action_item.execution_params,
                sort_order=idx,
            )
            for idx, action_item in enumerate(plan_data.actions)
        ]
    )

    await db.commit()
    await db.refresh(db_plan)

    # Load the actions explicitly, the same way the read endpoints do.
    action_stmt = (
        select(GeoPlanAction)
        .where(GeoPlanAction.plan_id == db_plan.id)
        .order_by(GeoPlanAction.sort_order)
    )
    action_result = await db.execute(action_stmt)
    db_plan.actions = list(action_result.scalars().all())

    return _plan_to_response(db_plan)
|
|
|
|
|
|
@router.get("/brand/{brand_id}", response_model=GeoPlanListResponse)
async def get_brand_plans(
    brand_id: uuid.UUID,
    skip: int = Query(0, ge=0),
    limit: int = Query(20, ge=1, le=100),
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """List a brand's plans, newest first, with offset/limit pagination.

    Args:
        brand_id: Brand whose plans are listed; it must belong to the caller.
        skip: Number of plans to skip.
        limit: Maximum number of plans to return (1-100).

    Returns:
        ``GeoPlanListResponse`` holding the requested page of plans, each with
        its actions, and ``total``, the brand's overall plan count.

    Raises:
        HTTPException: 404 when the brand does not exist or is not owned by
            the current user.
    """
    await _get_brand_with_access(brand_id, db, current_user)

    count_stmt = select(func.count()).select_from(GeoPlan).where(
        GeoPlan.brand_id == brand_id,
    )
    total = (await db.execute(count_stmt)).scalar_one()

    plans_stmt = (
        select(GeoPlan)
        .where(GeoPlan.brand_id == brand_id)
        .order_by(GeoPlan.created_at.desc())
        .offset(skip)
        .limit(limit)
    )
    plans = list((await db.execute(plans_stmt)).scalars().all())

    # Fetch the actions for the whole page with a single IN query instead of
    # one query per plan. Rows come back ordered by sort_order, and appending
    # keeps that order within each plan's list.
    actions_by_plan = {plan.id: [] for plan in plans}
    if actions_by_plan:
        action_stmt = (
            select(GeoPlanAction)
            .where(GeoPlanAction.plan_id.in_(list(actions_by_plan)))
            .order_by(GeoPlanAction.sort_order)
        )
        action_result = await db.execute(action_stmt)
        for action in action_result.scalars().all():
            actions_by_plan[action.plan_id].append(action)

    plan_responses = []
    for plan in plans:
        plan.actions = actions_by_plan[plan.id]
        plan_responses.append(_plan_to_response(plan))

    return GeoPlanListResponse(plans=plan_responses, total=total)
|
|
|
|
|
|
@router.get("/{plan_id}", response_model=GeoPlanResponse)
async def get_plan_detail(
    plan_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Return a single plan, with its actions ordered by ``sort_order``.

    Raises:
        HTTPException: 404 "方案不存在" when the plan does not exist, or the
            404 from ``_get_brand_with_access`` when the plan's brand is not
            owned by the current user.
    """
    plan_stmt = select(GeoPlan).where(GeoPlan.id == plan_id)
    plan = (await db.execute(plan_stmt)).scalar_one_or_none()

    if not plan:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="方案不存在",
        )

    # Ownership check only; the brand object itself is not needed here.
    # NOTE(review): the two 404s carry different messages, so a caller can tell
    # "no such plan" apart from "plan on another user's brand" — confirm this
    # is acceptable.
    await _get_brand_with_access(plan.brand_id, db, current_user)

    action_stmt = (
        select(GeoPlanAction)
        .where(GeoPlanAction.plan_id == plan.id)
        .order_by(GeoPlanAction.sort_order)
    )
    action_result = await db.execute(action_stmt)
    plan.actions = list(action_result.scalars().all())

    return _plan_to_response(plan)
|
|
|
|
|
|
@router.put("/actions/{action_id}/status", response_model=GeoPlanActionResponse)
async def update_action_status(
    action_id: uuid.UUID,
    status_update: GeoPlanActionUpdateStatus,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Set the status of a plan action owned (via its brand) by the caller.

    Moving an action to ``"completed"`` stamps ``completed_at`` with the
    current local time. Moving it to any other status clears ``completed_at``,
    so it never describes an action that is no longer completed.

    Raises:
        HTTPException: 400 for an unsupported status; 404 when the action does
            not exist or its brand is not owned by the current user.
    """
    # A tuple, not a set, so the error message lists the options in a
    # deterministic order.
    valid_statuses = ("pending", "in_progress", "completed", "skipped")
    if status_update.status not in valid_statuses:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无效的状态值,支持: {', '.join(valid_statuses)}",
        )

    stmt = select(GeoPlanAction).where(GeoPlanAction.id == action_id)
    action = (await db.execute(stmt)).scalar_one_or_none()

    if not action:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="行动项不存在",
        )

    plan_stmt = select(GeoPlan).where(GeoPlan.id == action.plan_id)
    plan = (await db.execute(plan_stmt)).scalar_one()

    # Authorization: the action's plan must belong to a brand the user owns.
    await _get_brand_with_access(plan.brand_id, db, current_user)

    action.status = status_update.status
    if status_update.status == "completed":
        action.completed_at = datetime.now()
    else:
        # Re-opened or skipped actions are no longer completed.
        action.completed_at = None

    await db.commit()
    await db.refresh(action)

    return GeoPlanActionResponse(
        id=action.id,
        plan_id=action.plan_id,
        action_type=action.action_type,
        title=action.title,
        description=action.description,
        reason=action.reason,
        priority=action.priority,
        status=action.status,
        target_keyword=action.target_keyword,
        target_platform=action.target_platform,
        content_style=action.content_style,
        estimated_impact=action.estimated_impact,
        difficulty=action.difficulty,
        execution_params=action.execution_params,
        sort_order=action.sort_order,
        completed_at=action.completed_at,
        created_at=action.created_at,
    )
|
|
|
|
|
|
@router.post("/actions/{action_id}/execute", response_model=GeoPlanActionExecuteResponse)
async def execute_action(
    action_id: uuid.UUID,
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
):
    """Run a content action via ``ContentGenerationService`` and mark it completed.

    Only ``content_creation`` and ``content_optimization`` actions are
    supported. Each generation parameter is resolved in this order: the
    action's ``execution_params`` entry, then the matching action column,
    then a fixed default.

    Returns:
        ``GeoPlanActionExecuteResponse`` with the action id and the
        ``content_id`` reported by the service (``None`` if it reported none).

    Raises:
        HTTPException: 404 when the action does not exist or its brand is not
            owned by the caller; 400 for an unsupported action type; 500 when
            content generation fails.
    """
    stmt = select(GeoPlanAction).where(GeoPlanAction.id == action_id)
    action = (await db.execute(stmt)).scalar_one_or_none()

    if not action:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="行动项不存在",
        )

    plan_stmt = select(GeoPlan).where(GeoPlan.id == action.plan_id)
    plan = (await db.execute(plan_stmt)).scalar_one()

    brand = await _get_brand_with_access(plan.brand_id, db, current_user)

    if action.action_type not in ("content_creation", "content_optimization"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"行动类型 '{action.action_type}' 不支持一键执行,仅支持 content_creation 和 content_optimization",
        )

    params = action.execution_params or {}
    # Use `or` chains rather than a .get() default, so that a key stored with
    # a null/empty value still falls back instead of being passed through.
    keyword = params.get("keyword") or action.target_keyword or brand.name
    platform = params.get("platform") or action.target_platform or "通用"
    style = params.get("style") or action.content_style or "专业严谨"
    word_count = params.get("word_count") or 2000
    knowledge_base_ids = params.get("knowledge_base_ids")

    content_service = ContentGenerationService()

    try:
        gen_result = await content_service.generate_content(
            keyword=keyword,
            brand_name=brand.name,
            platform=platform,
            content_style=style,
            word_count=word_count,
            knowledge_base_ids=knowledge_base_ids,
            db=db,
            user_id=current_user.id,
            org_id=str(plan.organization_id),
            run_deai=True,
            run_geo=True,
        )
    except HTTPException:
        # Deliberate HTTP errors keep their original status and detail.
        raise
    except Exception as e:
        # NOTE(review): the raw exception text is returned to the client.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"内容生成失败: {str(e)}",
        ) from e

    action.status = "completed"
    action.completed_at = datetime.now()
    await db.commit()
    await db.refresh(action)

    # Assumes generate_content returns a dict-like result — TODO confirm.
    content_id = gen_result.get("content_id")

    return GeoPlanActionExecuteResponse(
        action_id=action.id,
        content_id=content_id,
        message="内容生成成功" if content_id else "内容生成完成(未持久化)",
    )
|