geo/backend/app/api/platform_rules.py

470 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""平台规则管理 API 路由"""
import logging
from datetime import datetime
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import ValidationError
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.database import get_db
from app.models.platform_rule_version import PlatformRuleVersion
from app.models.user import User
from app.services.distribution.platform_rules import (
PLATFORM_RULES,
rule_engine,
)
from app.services.distribution.rule_service import platform_rule_service
from app.schemas.platform_rule import (
ContentValidationIssue,
ContentValidationResponse,
ContentValidateRequest,
DeAIContentRequest,
DeAIContentResponse,
PlatformBrief,
PlatformDetailResponse,
PlatformListResponse,
PlatformRuleUpdateRequest,
PlatformRuleUpdateResponse,
RuleChangeHistory,
RuleChangeHistoryResponse,
RuleDiff,
RuleDiffResponse,
# 内部 Schema
AISensitivity,
ContentLengthRule,
GEORule,
HTMLRule,
KeywordDensity,
PublishRule,
SEORule,
SensitiveWordsConfig,
StructurePreference,
TagRule,
TitleRule,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router that owns every endpoint declared in this file; all route paths
# below are relative to the "/api/v1/platforms" prefix.
router = APIRouter(prefix="/api/v1/platforms", tags=["平台规则管理"])
async def _get_rule_version(
    db: AsyncSession, rule_id: str, version: int
) -> PlatformRuleVersion | None:
    """Fetch a single stored rule version.

    Args:
        db: Async SQLAlchemy session used to run the query.
        rule_id: Value matched against ``PlatformRuleVersion.rule_id``.
        version: Value matched against ``PlatformRuleVersion.version``.

    Returns:
        The matching row, or ``None`` when no row matches.

    Raises:
        sqlalchemy.exc.MultipleResultsFound: If more than one row matches
            (raised by ``scalar_one_or_none``).
    """
    same_rule = PlatformRuleVersion.rule_id == rule_id
    same_version = PlatformRuleVersion.version == version
    query = select(PlatformRuleVersion).where(same_rule, same_version)
    return (await db.execute(query)).scalar_one_or_none()
def _compute_diff(
    old_data: dict, new_data: dict, prefix: str = ""
) -> list[RuleDiff]:
    """Recursively list the fields that differ between two rule dictionaries.

    Keys are visited in sorted order. When a key holds a dict on both sides
    the comparison descends into it; otherwise the two values are compared
    with ``!=``. A key missing on one side is read as ``None`` there.

    Args:
        old_data: Rule data before the change.
        new_data: Rule data after the change.
        prefix: Dotted path of the enclosing key, empty at the top level.

    Returns:
        One ``RuleDiff`` per differing leaf, with ``field`` set to the
        dot-separated path of the key.
    """
    changes: list[RuleDiff] = []
    for name in sorted({*old_data.keys(), *new_data.keys()}):
        # Nested keys are reported as "parent.child"; top-level keys stand alone.
        path = f"{prefix}.{name}" if prefix else f"{prefix}{name}"
        before = old_data.get(name)
        after = new_data.get(name)

        if isinstance(before, dict) and isinstance(after, dict):
            changes += _compute_diff(before, after, path)
            continue

        if before != after:
            changes.append(RuleDiff(field=path, old_value=before, new_value=after))
    return changes
def _version_to_dict(v: PlatformRuleVersion) -> dict:
    """Serialize a rule-version row into a plain dictionary.

    Args:
        v: The rule-version row to serialize.

    Returns:
        A dict with the keys ``id``, ``rule_id``, ``platform``, ``version``,
        ``rule_data``, ``change_summary``, ``created_by`` and ``created_at``.
        ``created_at`` is the ISO-8601 string of the timestamp, or ``None``
        when the row has no (truthy) timestamp.
    """
    copied_fields = (
        "id",
        "rule_id",
        "platform",
        "version",
        "rule_data",
        "change_summary",
        "created_by",
    )
    payload = {name: getattr(v, name) for name in copied_fields}
    created_at = v.created_at
    payload["created_at"] = created_at.isoformat() if created_at else None
    return payload
def _convert_rule_to_schema(rules: dict) -> dict:
    """Convert a raw rule dictionary into schema objects, one per section.

    Args:
        rules: Raw platform rules keyed by section name.

    Returns:
        An empty dict when ``rules`` is empty; otherwise a dict mapping each
        of the ten section names to its schema instance. A section missing
        from ``rules`` is built from an empty dict, i.e. with no arguments.
    """
    if not rules:
        return {}

    # Section name -> schema class used to build it, in output order.
    section_models = (
        ("content_length", ContentLengthRule),
        ("structure_preference", StructurePreference),
        ("title_rules", TitleRule),
        ("tag_rules", TagRule),
        ("ai_sensitivity", AISensitivity),
        ("sensitive_words", SensitiveWordsConfig),
        ("seo_rules", SEORule),
        ("geo_rules", GEORule),
        ("html_rules", HTMLRule),
        ("publish_rules", PublishRule),
    )
    return {
        section: model(**rules.get(section, {}))
        for section, model in section_models
    }
@router.get("", response_model=PlatformListResponse)
async def list_platforms(
    enabled_only: bool = Query(True, description="是否只返回启用的平台"),
):
    """List all supported platforms.

    Args:
        enabled_only: Passed through to the rule service as ``enabled_only``.

    Returns:
        A ``PlatformListResponse`` holding one ``PlatformBrief`` per platform
        returned by the service, plus their count.
    """
    briefs = []
    for entry in platform_rule_service.get_all_platforms(enabled_only=enabled_only):
        brief = PlatformBrief(
            id=entry["id"],
            name=entry["name"],
            platform_type=entry.get("platform_type", ""),
            priority=entry.get("priority", "P2"),
            enabled=entry.get("enabled", True),
        )
        briefs.append(brief)
    return PlatformListResponse(platforms=briefs, total=len(briefs))
@router.get("/{platform_id}", response_model=PlatformDetailResponse)
async def get_platform_detail(platform_id: str):
    """Return the full rule detail of one platform.

    Args:
        platform_id: Platform identifier (e.g. zhihu, wechat, baijiahao).

    Returns:
        A ``PlatformDetailResponse`` built from the platform's rules.

    Raises:
        HTTPException: 404 when the rule service returns nothing for
            ``platform_id``.
    """
    rules = platform_rule_service.get_platform_detail(platform_id)
    if not rules:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"平台不存在: {platform_id}",
        )

    # ``rules`` is non-empty here, so the converter yields all ten sections,
    # keyed exactly by the response field names they populate.
    sections = _convert_rule_to_schema(rules)
    return PlatformDetailResponse(
        id=platform_id,
        name=rules.get("name", ""),
        platform_type=rules.get("platform_type", ""),
        priority=rules.get("priority", "P2"),
        enabled=rules.get("enabled", True),
        content_style=rules.get("content_style", ""),
        **sections,
        best_publish_times=rules.get("best_publish_times", []),
        best_publish_days=rules.get("best_publish_days", []),
        max_images=rules.get("max_images", 0),
    )
@router.put("/{platform_id}/rules", response_model=PlatformRuleUpdateResponse)
async def update_platform_rules(
    platform_id: str,
    req: PlatformRuleUpdateRequest,
    current_user: User = Depends(get_current_user),
):
    """Update the rules of one platform.

    Note: the current implementation updates the in-memory ``PLATFORM_RULES``
    mapping only. Persisting the change requires a database together with the
    rule change-history table.

    Only fields that were explicitly set in the request and are not ``None``
    are applied. A dict value is shallow-merged into the existing section when
    that section is itself a dict; any other value replaces the stored one.

    Args:
        platform_id: Platform identifier.
        req: Partial rule update.
        current_user: Authenticated user, recorded in the log entry.

    Returns:
        A ``PlatformRuleUpdateResponse`` with ``success=True`` and the current
        local time as ``updated_at``.

    Raises:
        HTTPException: 404 when ``platform_id`` is not in ``PLATFORM_RULES``.
    """
    if platform_id not in PLATFORM_RULES:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"平台不存在: {platform_id}",
        )

    # Current rules; mutated in place below.
    current_rules = PLATFORM_RULES[platform_id]

    # Only the fields the client actually sent.
    update_data = req.model_dump(exclude_unset=True)

    # Apply the update (demo only: real use needs persistence to the database).
    for key, value in update_data.items():
        if value is None:
            continue
        existing = current_rules.get(key)
        if isinstance(value, dict) and isinstance(existing, dict):
            # Shallow merge so keys absent from the request keep their value.
            current_rules[key] = {**existing, **value}
        else:
            # Scalars, lists, and sections whose stored value is missing or
            # not a dict (merging into those would raise TypeError).
            current_rules[key] = value

    logger.info(
        "用户 %s 更新了平台 %s 的规则: %s",
        current_user.id,
        platform_id,
        list(update_data.keys()),
    )
    return PlatformRuleUpdateResponse(
        success=True,
        platform_id=platform_id,
        message="规则更新成功",
        updated_at=datetime.now(),
    )
@router.get("/{platform_id}/rules/diff", response_model=RuleDiffResponse)
async def compare_rule_changes(
    platform_id: str,
    from_version: int = Query(..., description="起始版本号"),
    to_version: int = Query(..., description="目标版本号"),
    db: AsyncSession = Depends(get_db),
):
    """Compare two stored rule versions of a platform.

    Args:
        platform_id: Platform identifier.
        from_version: Version number to compare from.
        to_version: Version number to compare to.
        db: Async database session.

    Returns:
        A ``RuleDiffResponse`` listing every differing field and the number
        of differences.

    Raises:
        HTTPException: 404 when the platform is unknown, or when either
            version cannot be found.
    """
    if platform_id not in PLATFORM_RULES:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"平台不存在: {platform_id}",
        )
    platform_rules = PLATFORM_RULES[platform_id]

    older = await _get_rule_version(db, platform_id, from_version)
    newer = await _get_rule_version(db, platform_id, to_version)
    if not (older and newer):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="版本不存在",
        )

    changes = _compute_diff(older.rule_data, newer.rule_data)
    return RuleDiffResponse(
        platform_id=platform_id,
        platform_name=platform_rules.get("name", ""),
        diffs=changes,
        total_changes=len(changes),
    )
@router.get("/{platform_id}/rules/history", response_model=RuleChangeHistoryResponse)
async def get_rule_history(
    platform_id: str,
    limit: int = Query(20, ge=1, le=100, description="返回记录数"),
    db: AsyncSession = Depends(get_db),
):
    """Return the rule change history of a platform, newest version first.

    Args:
        platform_id: Platform identifier.
        limit: Maximum number of records to return.
        db: Async database session.

    Returns:
        A ``RuleChangeHistoryResponse`` with at most ``limit`` entries and
        ``total`` set to the count of all stored versions for the platform.
        Every entry has ``change_type="update"`` and ``previous_rules=None``.

    Raises:
        HTTPException: 404 when ``platform_id`` is not in ``PLATFORM_RULES``.
    """
    if platform_id not in PLATFORM_RULES:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"平台不存在: {platform_id}",
        )

    # Total is counted over all versions, independent of ``limit``.
    count_query = (
        select(func.count())
        .select_from(PlatformRuleVersion)
        .where(PlatformRuleVersion.rule_id == platform_id)
    )
    count_result = await db.execute(count_query)
    total = count_result.scalar() or 0

    page_query = (
        select(PlatformRuleVersion)
        .where(PlatformRuleVersion.rule_id == platform_id)
        .order_by(PlatformRuleVersion.version.desc())
        .limit(limit)
    )
    rows = (await db.execute(page_query)).scalars().all()

    history = []
    for row in rows:
        # NOTE(review): ``id`` is filled from the version number rather than
        # from ``row.id`` -- confirm this is intended.
        entry = RuleChangeHistory(
            id=row.version,
            version=row.version,
            platform_id=row.rule_id,
            platform_name=row.platform,
            changed_by=row.created_by or "",
            change_summary=row.change_summary or "",
            change_type="update",
            previous_rules=None,
            new_rules=row.rule_data,
            created_at=row.created_at,
        )
        history.append(entry)

    return RuleChangeHistoryResponse(history=history, total=total)
@router.post("/{platform_id}/rules/validate", response_model=ContentValidationResponse)
async def validate_content_for_platform(
    platform_id: str,
    req: ContentValidateRequest,
):
    """Validate content against a platform's rules.

    Args:
        platform_id: Platform identifier.
        req: Validation request carrying the content, title and tags.

    Returns:
        A ``ContentValidationResponse`` built from the rule service result.
        An issue without a ``category`` is reported as ``"general"``.

    Raises:
        HTTPException: 404 when ``platform_id`` is not in ``PLATFORM_RULES``.
    """
    if platform_id not in PLATFORM_RULES:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"平台不存在: {platform_id}",
        )

    outcome = platform_rule_service.validate_content(
        content=req.content,
        title=req.title,
        platform_id=platform_id,
        tags=req.tags,
    )

    issues = []
    for raw in outcome.get("issues", []):
        issues.append(
            ContentValidationIssue(
                severity=raw["severity"],
                message=raw["message"],
                category=raw.get("category", "general"),
            )
        )

    return ContentValidationResponse(
        is_valid=outcome["is_valid"],
        score=outcome["score"],
        issues=issues,
        passed=outcome.get("passed", []),
    )
@router.get("/{platform_id}/rules/{rule_category}")
async def get_platform_rule_category(
    platform_id: str,
    rule_category: str,
):
    """Return one category of a platform's rules.

    Args:
        platform_id: Platform identifier.
        rule_category: Rule category (title_rules, tag_rules,
            ai_sensitivity, etc.).

    Returns:
        A dict with ``platform_id``, ``rule_category`` and the ``rule`` itself.

    Raises:
        HTTPException: 400 when ``rule_category`` is not one of the categories
            reported by the rule engine; 404 when the rule service returns
            ``None`` for the platform/category pair.
    """
    known_categories = rule_engine.get_all_rule_categories()
    if rule_category not in known_categories:
        allowed = ', '.join(known_categories)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"无效的规则类别: {rule_category},有效值: {allowed}",
        )

    rule = platform_rule_service.get_rule_category(platform_id, rule_category)
    if rule is not None:
        return {
            "platform_id": platform_id,
            "rule_category": rule_category,
            "rule": rule,
        }

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"平台不存在或规则类别无效: {platform_id}/{rule_category}",
    )
@router.get("/{platform_id}/ai-config")
async def get_platform_ai_config(platform_id: str):
    """Return a platform's AI-sensitivity configuration for de-AI processing.

    Args:
        platform_id: Platform identifier.

    Returns:
        A dict with ``platform_id`` and the configuration under
        ``ai_sensitivity``.

    Raises:
        HTTPException: 404 when the rule service returns ``None``.
    """
    ai_config = platform_rule_service.get_ai_humanization_config(platform_id)
    if ai_config is not None:
        return {
            "platform_id": platform_id,
            "ai_sensitivity": ai_config,
        }

    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"平台不存在: {platform_id}",
    )
@router.post("/{platform_id}/detect-ai-patterns")
async def detect_ai_patterns_in_content(
    platform_id: str,
    content: str = Query(..., description="待检测内容"),
):
    """Detect AI-writing patterns in a piece of content.

    Args:
        platform_id: Platform identifier.
        content: Content to inspect.

    Returns:
        A dict with ``platform_id``, the character length of ``content``,
        the detected patterns and how many were detected.

    Raises:
        HTTPException: 404 when ``platform_id`` is not in ``PLATFORM_RULES``.
    """
    # NOTE(review): ``content`` is declared with ``Query`` on a POST route, so
    # it travels in the URL rather than the request body -- confirm this is
    # intended for long texts.
    if platform_id not in PLATFORM_RULES:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"平台不存在: {platform_id}",
        )

    patterns = platform_rule_service.detect_ai_patterns(content, platform_id)
    report = {"platform_id": platform_id}
    report["content_length"] = len(content)
    report["detected_patterns"] = patterns
    report["total_detected"] = len(patterns)
    return report
@router.get("/{platform_id}/tips")
async def get_platform_optimization_tips(platform_id: str):
    """Return optimization tips for a platform.

    Args:
        platform_id: Platform identifier.

    Returns:
        A dict with ``platform_id``, the platform's name and content style
        (empty strings when absent from its rules) and the tips from the
        rule service.

    Raises:
        HTTPException: 404 when ``platform_id`` is not in ``PLATFORM_RULES``.
    """
    if platform_id not in PLATFORM_RULES:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"平台不存在: {platform_id}",
        )

    tips = platform_rule_service.get_optimization_tips(platform_id)
    platform_rules = PLATFORM_RULES.get(platform_id, {})
    platform_name = platform_rules.get("name", "")
    content_style = platform_rules.get("content_style", "")
    return {
        "platform_id": platform_id,
        "platform_name": platform_name,
        "content_style": content_style,
        "tips": tips,
    }
@router.get("/categories")
async def list_rule_categories():
    """Return every rule category known to the rule engine.

    Returns:
        A dict with the ``categories`` and their ``total`` count.
    """
    # NOTE(review): this route is registered after ``GET /{platform_id}`` in
    # this file. FastAPI matches routes in declaration order, so a request to
    # ``/categories`` is likely captured by that handler with
    # ``platform_id="categories"`` -- confirm, and move this route above it
    # if so.
    category_names = rule_engine.get_all_rule_categories()
    return dict(categories=category_names, total=len(category_names))