470 lines
14 KiB
Python
470 lines
14 KiB
Python
"""平台规则管理 API 路由"""
|
||
|
||
import logging
|
||
from datetime import datetime
|
||
from typing import Any, Optional
|
||
|
||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||
from pydantic import ValidationError
|
||
from sqlalchemy import select, func
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.api.deps import get_current_user
|
||
from app.database import get_db
|
||
from app.models.platform_rule_version import PlatformRuleVersion
|
||
from app.models.user import User
|
||
from app.services.distribution.platform_rules import (
|
||
PLATFORM_RULES,
|
||
rule_engine,
|
||
)
|
||
from app.services.distribution.rule_service import platform_rule_service
|
||
from app.schemas.platform_rule import (
|
||
ContentValidationIssue,
|
||
ContentValidationResponse,
|
||
ContentValidateRequest,
|
||
DeAIContentRequest,
|
||
DeAIContentResponse,
|
||
PlatformBrief,
|
||
PlatformDetailResponse,
|
||
PlatformListResponse,
|
||
PlatformRuleUpdateRequest,
|
||
PlatformRuleUpdateResponse,
|
||
RuleChangeHistory,
|
||
RuleChangeHistoryResponse,
|
||
RuleDiff,
|
||
RuleDiffResponse,
|
||
# 内部 Schema
|
||
AISensitivity,
|
||
ContentLengthRule,
|
||
GEORule,
|
||
HTMLRule,
|
||
KeywordDensity,
|
||
PublishRule,
|
||
SEORule,
|
||
SensitiveWordsConfig,
|
||
StructurePreference,
|
||
TagRule,
|
||
TitleRule,
|
||
)
|
||
|
||
# Module-level logger named after this module so its records can be filtered per module.
logger = logging.getLogger(name=__name__)

# Every endpoint below is mounted under /api/v1/platforms and grouped under one OpenAPI tag.
router = APIRouter(
    tags=["平台规则管理"],
    prefix="/api/v1/platforms",
)
|
||
|
||
|
||
async def _get_rule_version(
    db: AsyncSession, rule_id: str, version: int
) -> PlatformRuleVersion | None:
    """Fetch one stored rule version.

    Args:
        db: Async database session used to run the query.
        rule_id: Value matched against ``PlatformRuleVersion.rule_id``.
        version: Value matched against ``PlatformRuleVersion.version``.

    Returns:
        The matching ``PlatformRuleVersion`` row, or ``None`` when no row
        matches (``scalar_one_or_none`` raises if several rows match).
    """
    conditions = (
        PlatformRuleVersion.rule_id == rule_id,
        PlatformRuleVersion.version == version,
    )
    query = select(PlatformRuleVersion).where(*conditions)
    rows = await db.execute(query)
    return rows.scalar_one_or_none()
|
||
|
||
|
||
def _compute_diff(
    old_data: dict, new_data: dict, prefix: str = ""
) -> list[RuleDiff]:
    """Recursively list the fields that differ between two rule dicts.

    Keys present in either dict are visited in sorted order. When both sides
    hold a dict for a key, the comparison descends into it and nested field
    names are joined with ``.``. Any other pair of values that compares
    unequal (a key missing on one side reads as ``None``) yields one
    ``RuleDiff``.

    Args:
        old_data: Rule data of the older version.
        new_data: Rule data of the newer version.
        prefix: Dotted path of the parent field; empty at the top level.

    Returns:
        The differing fields in depth-first, sorted-key order.
    """
    changes: list[RuleDiff] = []
    for name in sorted(old_data.keys() | new_data.keys()):
        path = f"{prefix}.{name}" if prefix else f"{name}"
        before = old_data.get(name)
        after = new_data.get(name)
        if isinstance(before, dict) and isinstance(after, dict):
            # Both sides are nested sections: report their leaf differences.
            changes += _compute_diff(before, after, path)
            continue
        if before != after:
            changes.append(RuleDiff(field=path, old_value=before, new_value=after))
    return changes
|
||
|
||
|
||
def _version_to_dict(v: PlatformRuleVersion) -> dict:
    """Serialize a ``PlatformRuleVersion`` row into a plain dict.

    All attributes are copied as-is except ``created_at``, which is rendered
    with ``isoformat()``; a missing (falsy) timestamp becomes ``None``.
    """
    copied_fields = (
        "id",
        "rule_id",
        "platform",
        "version",
        "rule_data",
        "change_summary",
        "created_by",
    )
    payload = {field_name: getattr(v, field_name) for field_name in copied_fields}
    created = v.created_at
    payload["created_at"] = created.isoformat() if created else None
    return payload
|
||
|
||
|
||
# Rule category key in a raw platform rule dict -> schema class that models it.
_RULE_CATEGORY_SCHEMAS = {
    "content_length": ContentLengthRule,
    "structure_preference": StructurePreference,
    "title_rules": TitleRule,
    "tag_rules": TagRule,
    "ai_sensitivity": AISensitivity,
    "sensitive_words": SensitiveWordsConfig,
    "seo_rules": SEORule,
    "geo_rules": GEORule,
    "html_rules": HTMLRule,
    "publish_rules": PublishRule,
}


def _convert_rule_to_schema(rules: dict) -> dict:
    """Convert a raw platform rule dict into per-category schema objects.

    Args:
        rules: Raw rule dict for one platform. Each category listed in
            ``_RULE_CATEGORY_SCHEMAS`` is expected to be a mapping of that
            schema's fields; a missing or null category falls back to the
            schema's defaults.

    Returns:
        ``{}`` when ``rules`` is empty; otherwise a dict with one schema
        instance per category key in ``_RULE_CATEGORY_SCHEMAS``.

    Raises:
        Whatever the schema constructors raise for invalid field values
        (presumably pydantic ``ValidationError`` — confirm against the schemas).
    """
    if not rules:
        return {}

    # ``or {}`` also covers categories stored with an explicit null value,
    # which ``rules.get(key, {})`` would pass through as ``**None`` (TypeError).
    return {
        category: schema(**(rules.get(category) or {}))
        for category, schema in _RULE_CATEGORY_SCHEMAS.items()
    }
|
||
|
||
|
||
@router.get("", response_model=PlatformListResponse)
async def list_platforms(
    enabled_only: bool = Query(True, description="是否只返回启用的平台"),
):
    """List the supported platforms.

    Args:
        enabled_only: Forwarded to ``platform_rule_service.get_all_platforms``;
            per the query description it restricts the result to enabled
            platforms (defaults to True).

    Returns:
        A ``PlatformListResponse`` with one ``PlatformBrief`` per platform.
        Missing optional fields default to ``""`` (type), ``"P2"`` (priority)
        and ``True`` (enabled).
    """

    def _to_brief(raw: dict) -> PlatformBrief:
        return PlatformBrief(
            id=raw["id"],
            name=raw["name"],
            platform_type=raw.get("platform_type", ""),
            priority=raw.get("priority", "P2"),
            enabled=raw.get("enabled", True),
        )

    source = platform_rule_service.get_all_platforms(enabled_only=enabled_only)
    briefs = list(map(_to_brief, source))
    return PlatformListResponse(platforms=briefs, total=len(briefs))
|
||
|
||
|
||
@router.get("/{platform_id}", response_model=PlatformDetailResponse)
async def get_platform_detail(platform_id: str):
    """Return the full rule set of a single platform.

    Args:
        platform_id: Platform identifier (e.g. zhihu, wechat, baijiahao).

    Raises:
        HTTPException: 404 when the rule service returns no rules for it.
    """
    rules = platform_rule_service.get_platform_detail(platform_id)
    if not rules:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"平台不存在: {platform_id}",
        )

    converted = _convert_rule_to_schema(rules)
    # Schema-typed rule sections, passed straight through to the response model.
    rule_sections = {
        section: converted[section]
        for section in (
            "content_length",
            "structure_preference",
            "title_rules",
            "tag_rules",
            "ai_sensitivity",
            "sensitive_words",
            "seo_rules",
            "geo_rules",
            "html_rules",
            "publish_rules",
        )
    }

    return PlatformDetailResponse(
        id=platform_id,
        name=rules.get("name", ""),
        platform_type=rules.get("platform_type", ""),
        priority=rules.get("priority", "P2"),
        enabled=rules.get("enabled", True),
        content_style=rules.get("content_style", ""),
        best_publish_times=rules.get("best_publish_times", []),
        best_publish_days=rules.get("best_publish_days", []),
        max_images=rules.get("max_images", 0),
        **rule_sections,
    )
|
||
|
||
|
||
@router.put("/{platform_id}/rules", response_model=PlatformRuleUpdateResponse)
async def update_platform_rules(
    platform_id: str,
    req: PlatformRuleUpdateRequest,
    current_user: User = Depends(get_current_user),
):
    """Update a platform's rules in memory.

    Only fields explicitly set in the request (``exclude_unset``) with a
    non-``None`` value are applied. A dict value is shallow-merged into an
    existing dict category, so sub-keys the client did not send are kept.
    Any other value, or a dict whose stored counterpart is missing or not a
    dict, replaces the stored value.

    NOTE: this mutates the module-level ``PLATFORM_RULES`` mapping only.
    Changes live in this process's memory and no ``PlatformRuleVersion`` row is
    written, so the ``/rules/diff`` and ``/rules/history`` endpoints (which
    read that table) do not see them.

    Args:
        platform_id: Platform identifier.
        req: Partial rule update.
        current_user: Authenticated user; its id is written to the log.

    Raises:
        HTTPException: 404 when ``platform_id`` is not in ``PLATFORM_RULES``.
    """
    if platform_id not in PLATFORM_RULES:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"平台不存在: {platform_id}",
        )

    current_rules = PLATFORM_RULES[platform_id]
    update_data = req.model_dump(exclude_unset=True)

    for key, value in update_data.items():
        if value is None:
            # An explicit null means "leave this field unchanged".
            continue
        existing = current_rules.get(key)
        if isinstance(value, dict) and isinstance(existing, dict):
            # Shallow merge keeps sub-keys not present in the request.
            current_rules[key] = {**existing, **value}
        else:
            # Scalars, lists, or dicts with no mergeable stored dict
            # (missing / None / other type) replace the stored value.
            current_rules[key] = value

    logger.info(
        "用户 %s 更新了平台 %s 的规则: %s",
        current_user.id,
        platform_id,
        list(update_data.keys()),
    )

    return PlatformRuleUpdateResponse(
        success=True,
        platform_id=platform_id,
        message="规则更新成功",
        updated_at=datetime.now(),
    )
|
||
|
||
|
||
@router.get("/{platform_id}/rules/diff", response_model=RuleDiffResponse)
async def compare_rule_changes(
    platform_id: str,
    from_version: int = Query(..., description="起始版本号"),
    to_version: int = Query(..., description="目标版本号"),
    db: AsyncSession = Depends(get_db),
):
    """Compare two stored versions of a platform's rules.

    Args:
        platform_id: Platform identifier; also used as the ``rule_id`` when
            looking up stored versions.
        from_version: Version number to diff from.
        to_version: Version number to diff to.

    Returns:
        A ``RuleDiffResponse`` listing every differing (dotted) field path and
        the number of such fields.

    Raises:
        HTTPException: 404 when the platform is not in ``PLATFORM_RULES`` or
            when either version is not stored.
    """
    if platform_id not in PLATFORM_RULES:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"平台不存在: {platform_id}",
        )

    current_rules = PLATFORM_RULES[platform_id]

    from_rule = await _get_rule_version(db, platform_id, from_version)
    to_rule = await _get_rule_version(db, platform_id, to_version)

    if from_rule is None or to_rule is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="版本不存在",
        )

    # A version stored without a rule payload is treated as an empty rule set,
    # so the diff reports every field of the other side instead of failing on
    # ``None.keys()`` inside ``_compute_diff``.
    diffs = _compute_diff(from_rule.rule_data or {}, to_rule.rule_data or {})
    return RuleDiffResponse(
        platform_id=platform_id,
        platform_name=current_rules.get("name", ""),
        diffs=diffs,
        total_changes=len(diffs),
    )
|
||
|
||
|
||
@router.get("/{platform_id}/rules/history", response_model=RuleChangeHistoryResponse)
async def get_rule_history(
    platform_id: str,
    limit: int = Query(20, ge=1, le=100, description="返回记录数"),
    db: AsyncSession = Depends(get_db),
):
    """Return a platform's stored rule versions, newest first.

    Args:
        platform_id: Platform identifier, matched against
            ``PlatformRuleVersion.rule_id``.
        limit: Maximum number of entries to return (1-100, default 20).

    Returns:
        A ``RuleChangeHistoryResponse`` with up to ``limit`` entries. ``total``
        is the count of all stored versions for the platform.

    Raises:
        HTTPException: 404 when ``platform_id`` is not in ``PLATFORM_RULES``.

    NOTE(review): each entry's ``id`` is filled with the version number rather
    than the row id, ``platform_name`` comes from the row's ``platform``
    column, ``previous_rules`` is always None and ``change_type`` is always
    "update". Confirm this is intended.
    """
    if platform_id not in PLATFORM_RULES:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"平台不存在: {platform_id}",
        )

    belongs_to_platform = PlatformRuleVersion.rule_id == platform_id

    count_query = (
        select(func.count())
        .select_from(PlatformRuleVersion)
        .where(belongs_to_platform)
    )
    total = (await db.execute(count_query)).scalar() or 0

    newest_first = (
        select(PlatformRuleVersion)
        .where(belongs_to_platform)
        .order_by(PlatformRuleVersion.version.desc())
        .limit(limit)
    )
    rows = (await db.execute(newest_first)).scalars().all()

    entries = []
    for row in rows:
        entries.append(
            RuleChangeHistory(
                id=row.version,
                version=row.version,
                platform_id=row.rule_id,
                platform_name=row.platform,
                changed_by=row.created_by or "",
                change_summary=row.change_summary or "",
                change_type="update",
                previous_rules=None,
                new_rules=row.rule_data,
                created_at=row.created_at,
            )
        )

    return RuleChangeHistoryResponse(history=entries, total=total)
|
||
|
||
|
||
@router.post("/{platform_id}/rules/validate", response_model=ContentValidationResponse)
async def validate_content_for_platform(
    platform_id: str,
    req: ContentValidateRequest,
):
    """Check a piece of content against a platform's rules.

    Args:
        platform_id: Platform identifier.
        req: Content, title and tags to validate.

    Returns:
        A ``ContentValidationResponse`` built from the rule service's result.
        Issues without a ``category`` are reported under ``"general"``.

    Raises:
        HTTPException: 404 when ``platform_id`` is not in ``PLATFORM_RULES``.
    """
    if platform_id not in PLATFORM_RULES:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"平台不存在: {platform_id}",
        )

    outcome = platform_rule_service.validate_content(
        content=req.content,
        title=req.title,
        platform_id=platform_id,
        tags=req.tags,
    )

    parsed_issues = []
    for raw_issue in outcome.get("issues", []):
        parsed_issues.append(
            ContentValidationIssue(
                severity=raw_issue["severity"],
                message=raw_issue["message"],
                category=raw_issue.get("category", "general"),
            )
        )

    return ContentValidationResponse(
        is_valid=outcome["is_valid"],
        score=outcome["score"],
        issues=parsed_issues,
        passed=outcome.get("passed", []),
    )
|
||
|
||
|
||
@router.get("/{platform_id}/rules/{rule_category}")
async def get_platform_rule_category(
    platform_id: str,
    rule_category: str,
):
    """Return one rule category of a platform.

    Args:
        platform_id: Platform identifier.
        rule_category: Rule category name (e.g. title_rules, tag_rules,
            ai_sensitivity); must be one of
            ``rule_engine.get_all_rule_categories()``.

    Raises:
        HTTPException: 400 for an unknown category (the message lists the valid
            ones); 404 when the rule service returns ``None`` for the pair.
    """
    known_categories = rule_engine.get_all_rule_categories()
    if rule_category not in known_categories:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无效的规则类别: {rule_category},有效值: {', '.join(known_categories)}",
        )

    category_rule = platform_rule_service.get_rule_category(platform_id, rule_category)
    if category_rule is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"平台不存在或规则类别无效: {platform_id}/{rule_category}",
        )

    payload = {
        "platform_id": platform_id,
        "rule_category": rule_category,
        "rule": category_rule,
    }
    return payload
|
||
|
||
|
||
@router.get("/{platform_id}/ai-config")
async def get_platform_ai_config(platform_id: str):
    """Return a platform's AI-sensitivity config (used for de-AI rewriting).

    Args:
        platform_id: Platform identifier.

    Raises:
        HTTPException: 404 when the rule service returns ``None`` for it.
    """
    ai_config = platform_rule_service.get_ai_humanization_config(platform_id)
    if ai_config is not None:
        return {
            "platform_id": platform_id,
            "ai_sensitivity": ai_config,
        }

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"平台不存在: {platform_id}",
    )
|
||
|
||
|
||
@router.post("/{platform_id}/detect-ai-patterns")
async def detect_ai_patterns_in_content(
    platform_id: str,
    content: str = Query(..., description="待检测内容"),
):
    """Detect AI-writing patterns in a piece of content.

    Args:
        platform_id: Platform identifier.
        content: Text to scan. NOTE(review): this is a query parameter, so long
            texts travel in the URL and may hit server/proxy URL length limits.

    Returns:
        A dict with the platform id, the content length in characters, the
        detected patterns as returned by the rule service, and their count.

    Raises:
        HTTPException: 404 when ``platform_id`` is not in ``PLATFORM_RULES``.
    """
    if platform_id not in PLATFORM_RULES:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"平台不存在: {platform_id}",
        )

    patterns = platform_rule_service.detect_ai_patterns(content, platform_id)
    summary = {
        "platform_id": platform_id,
        "content_length": len(content),
        "detected_patterns": patterns,
        "total_detected": len(patterns),
    }
    return summary
|
||
|
||
|
||
@router.get("/{platform_id}/tips")
async def get_platform_optimization_tips(platform_id: str):
    """Return optimization tips for a platform.

    Args:
        platform_id: Platform identifier.

    Returns:
        A dict with the platform id, its name and content style (both default
        to ``""``), and the tips produced by the rule service.

    Raises:
        HTTPException: 404 when ``platform_id`` is not in ``PLATFORM_RULES``.
    """
    if platform_id not in PLATFORM_RULES:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"平台不存在: {platform_id}",
        )

    tips = platform_rule_service.get_optimization_tips(platform_id)
    platform_rules = PLATFORM_RULES.get(platform_id, {})
    display_name = platform_rules.get("name", "")
    style = platform_rules.get("content_style", "")

    return {
        "platform_id": platform_id,
        "platform_name": display_name,
        "content_style": style,
        "tips": tips,
    }
|
||
|
||
|
||
@router.get("/categories")
async def list_rule_categories():
    """Return all rule categories known to the rule engine, with their count."""
    categories = rule_engine.get_all_rule_categories()
    return {
        "categories": categories,
        "total": len(categories),
    }


# Starlette matches routes in registration order. Left where it was declared,
# the literal "/categories" path would be shadowed by the earlier
# "/{platform_id}" route: GET .../platforms/categories would be dispatched to
# get_platform_detail("categories") and answer 404. Move this route to the
# front of the router. The stable sort keeps every other route in its original
# order. This must run before the router is included into the app, which
# module import order guarantees.
router.routes[:] = sorted(
    router.routes,
    key=lambda route: getattr(route, "endpoint", None) is not list_rule_categories,
)
|