238 lines
7.7 KiB
Python
238 lines
7.7 KiB
Python
"""内容分发 API 路由"""
|
|
|
|
import uuid
|
|
from datetime import datetime
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.api.deps import get_current_user
|
|
from app.database import get_db
|
|
from app.models.distribution import DistributionSchedule
|
|
from app.models.user import User
|
|
from app.schemas.distribution import (
|
|
ContentFormatRequest,
|
|
ContentFormatResponse,
|
|
ContentValidateRequest,
|
|
ContentValidateResponse,
|
|
PlatformInfo,
|
|
PlatformListResponse,
|
|
PlatformSchedule,
|
|
PublishStrategyRequest,
|
|
PublishStrategyResponse,
|
|
ScheduleCreateRequest,
|
|
ScheduleCreateResponse,
|
|
)
|
|
from app.services.distribution.formatter import ContentFormatter
|
|
from app.services.distribution.platform_rules import PLATFORM_RULES, PlatformRuleEngine
|
|
from app.services.distribution.publish_strategy import PublishStrategyService
|
|
|
|
router = APIRouter()

# Service instances: created once at import time and reused by every handler
# in this module.
# Rule engine: lists platforms and validates content against platform rules.
_rule_engine = PlatformRuleEngine()
# Builds multi-platform publish strategies (schedule, tags, tips).
_strategy_service = PublishStrategyService()
# Converts content into a platform-specific format.
_formatter = ContentFormatter()
|
|
|
|
|
|
@router.get("/platforms", response_model=PlatformListResponse)
async def list_platforms():
    """List every supported distribution platform.

    Basic metadata (id, name, icon, length limits, media, image limit) comes
    from ``_rule_engine.get_platforms()``. The matching entry in
    ``PLATFORM_RULES`` (keyed by the platform ``id``) supplies the best publish
    times/days, formatting capabilities, rule text and SEO tips. List fields
    missing from ``PLATFORM_RULES`` default to ``[]``. ``format_features`` is
    ``None`` unless the platform declares a ``format_features`` entry.
    """
    # (key, fallback) pairs used to normalize a declared format_features mapping.
    feature_fields = (
        ("supports_markdown", False),
        ("supports_html", False),
        ("max_heading_level", 0),
        ("supports_code_block", False),
        ("supports_emoji", False),
    )

    def to_info(meta) -> PlatformInfo:
        extra = PLATFORM_RULES.get(meta["id"], {})

        features = None
        if "format_features" in extra:
            declared = extra["format_features"]
            features = {
                key: declared.get(key, fallback) for key, fallback in feature_fields
            }

        return PlatformInfo(
            id=meta["id"],
            name=meta["name"],
            icon=meta["icon"],
            max_title_length=meta["max_title_length"],
            max_content_length=meta["max_content_length"],
            min_content_length=meta["min_content_length"],
            supported_media=meta["supported_media"],
            max_images=meta["max_images"],
            best_publish_times=extra.get("best_publish_times", []),
            best_publish_days=extra.get("best_publish_days", []),
            format_features=features,
            rules=extra.get("rules", []),
            seo_tips=extra.get("seo_tips", []),
        )

    return PlatformListResponse(
        platforms=[to_info(meta) for meta in _rule_engine.get_platforms()]
    )
|
|
|
|
|
|
@router.post("/validate", response_model=ContentValidateResponse)
async def validate_content(req: ContentValidateRequest):
    """Check a piece of content against one platform's publishing rules.

    Delegates to ``_rule_engine.validate_content`` and copies its ``is_valid``,
    ``score``, ``issues`` and ``passed`` entries into the response.

    Raises:
        HTTPException: 400 when ``req.platform`` is not a key of
            ``PLATFORM_RULES``; the detail lists the supported platform keys.
    """
    platform = req.platform
    if platform not in PLATFORM_RULES:
        supported = ", ".join(PLATFORM_RULES.keys())
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"不支持的平台: {platform},支持: {supported}",
        )

    report = _rule_engine.validate_content(
        content=req.content, title=req.title, platform=platform
    )
    return ContentValidateResponse(
        **{key: report[key] for key in ("is_valid", "score", "issues", "passed")}
    )
|
|
|
|
|
|
@router.post("/strategy", response_model=PublishStrategyResponse)
async def generate_strategy(req: PublishStrategyRequest):
    """Generate a publishing strategy covering several platforms.

    Every requested platform must be a key of ``PLATFORM_RULES``. The strategy
    itself (schedule, tags, tips) is produced by ``_strategy_service``.

    Raises:
        HTTPException: 400 listing every requested platform that is unsupported.
    """
    unknown = [name for name in req.platforms if name not in PLATFORM_RULES]
    if unknown:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"不支持的平台: {', '.join(unknown)}",
        )

    plan = _strategy_service.generate_strategy(
        content_title=req.content_title,
        platforms=req.platforms,
        industry=req.industry,
    )
    schedule, tags, tips = plan["schedule"], plan["tags"], plan["tips"]
    return PublishStrategyResponse(schedule=schedule, tags=tags, tips=tips)
|
|
|
|
|
|
@router.post("/format", response_model=ContentFormatResponse)
async def format_content(req: ContentFormatRequest):
    """Reformat content for a single target platform.

    The response reports the character length of the input and of the
    formatted output produced by ``_formatter.format_for_platform``.

    Raises:
        HTTPException: 400 when ``req.platform`` is not a key of
            ``PLATFORM_RULES``; the detail lists the supported platform keys.
    """
    target = req.platform
    if target not in PLATFORM_RULES:
        supported = ", ".join(PLATFORM_RULES.keys())
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"不支持的平台: {target},支持: {supported}",
        )

    source_text = req.content
    source_length = len(source_text)
    rendered = _formatter.format_for_platform(source_text, target)

    return ContentFormatResponse(
        platform=target,
        formatted_content=rendered,
        original_length=source_length,
        formatted_length=len(rendered),
    )
|
|
|
|
|
|
@router.post("/schedule", response_model=ScheduleCreateResponse)
async def create_schedule(
    req: ScheduleCreateRequest,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create a publishing schedule and persist it as a ``DistributionSchedule``.

    Per-platform times:
      * If ``req.scheduled_times`` is non-empty, each requested platform uses its
        entry from that mapping. A platform missing from the mapping falls back
        to the current local time, formatted ``%Y-%m-%d %H:%M``.
      * Otherwise the times are taken from the generated publish strategy's
        ``suggested_time`` values.
    The strategy's ``tips`` are stored and returned in both cases.

    Raises:
        HTTPException: 403 if the current user has no organization;
            400 if any requested platform is not a key of ``PLATFORM_RULES``.
    """
    org_id = current_user.organization_id
    if not org_id:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="用户未关联组织")

    invalid_platforms = [p for p in req.platforms if p not in PLATFORM_RULES]
    if invalid_platforms:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"不支持的平台: {', '.join(invalid_platforms)}",
        )

    # Generate the strategy exactly once. It always supplies the tips and, when
    # no custom times are given, the schedule too. A single call keeps the
    # persisted schedule and tips consistent and avoids redundant work.
    strategy = _strategy_service.generate_strategy(
        content_title=req.content_title,
        platforms=req.platforms,
        industry=req.industry,
    )
    tips = strategy["tips"]

    if req.scheduled_times:
        # Custom times: platforms without an explicit time default to "now"
        # (naive local time), computed once so they share one timestamp.
        fallback_time = datetime.now().strftime("%Y-%m-%d %H:%M")
        platforms_schedule: list[PlatformSchedule] = [
            PlatformSchedule(
                platform=platform_key,
                platform_name=PLATFORM_RULES[platform_key]["name"],
                scheduled_time=req.scheduled_times.get(platform_key, fallback_time),
                status="pending",
            )
            for platform_key in req.platforms
        ]
    else:
        platforms_schedule = [
            PlatformSchedule(
                platform=item["platform"],
                platform_name=item["platform_name"],
                scheduled_time=item["suggested_time"],
                status="pending",
            )
            for item in strategy["schedule"]
        ]

    # ---- Persist to the database ----
    platforms_data = [
        {
            "platform": ps.platform,
            "platform_name": ps.platform_name,
            "scheduled_time": ps.scheduled_time,
            "status": ps.status,
        }
        for ps in platforms_schedule
    ]

    # content_id is optional and handled best-effort: a value that is not a
    # valid UUID string is stored as NULL instead of rejecting the request.
    content_uuid = None
    if req.content_id:
        try:
            content_uuid = uuid.UUID(req.content_id)
        except ValueError:
            content_uuid = None

    schedule = DistributionSchedule(
        organization_id=org_id,
        content_title=req.content_title,
        content_id=content_uuid,
        platforms=platforms_data,
        tips=tips,
        status="pending",
        created_by=current_user.id,
    )
    db.add(schedule)
    await db.commit()
    # Reload so database-populated columns (id, created_at) are available.
    await db.refresh(schedule)

    return ScheduleCreateResponse(
        schedule_id=str(schedule.id),
        content_title=req.content_title,
        platforms=platforms_schedule,
        tips=tips,
        created_at=schedule.created_at.strftime("%Y-%m-%d %H:%M:%S"),
    )
|