diff --git a/backend/app/services/content/__init__.py b/backend/app/services/content/__init__.py new file mode 100644 index 0000000..e4d2ebe --- /dev/null +++ b/backend/app/services/content/__init__.py @@ -0,0 +1,3 @@ +from app.services.content.html_generator import HTMLGenerator + +__all__ = ["HTMLGenerator"] diff --git a/backend/app/services/content/content_pipeline.py b/backend/app/services/content/content_pipeline.py new file mode 100644 index 0000000..d0f2167 --- /dev/null +++ b/backend/app/services/content/content_pipeline.py @@ -0,0 +1,152 @@ +import time +from dataclasses import dataclass, field +from typing import Optional, Any, List + +from app.services.content.rule_validator import RuleValidator +from app.services.content.sensitive_filter import SensitiveFilter +from app.services.content.seo_optimizer import SEOOptimizer +from app.services.content.html_generator import HTMLGenerator + + +@dataclass +class PipelineStage: + name: str + passed: bool + result: Any = None + duration: float = 0.0 + error: Optional[str] = None + + +@dataclass +class PipelineOutput: + html: str = "" + markdown: str = "" + plain: str = "" + + +@dataclass +class PipelineResponse: + stages: List[PipelineStage] = field(default_factory=list) + outputs: Optional[PipelineOutput] = None + error: Optional[str] = None + + +class ContentPipeline: + def __init__(self): + self.validator = RuleValidator() + self.sensitive_filter = SensitiveFilter() + self.seo_optimizer = SEOOptimizer() + self.html_generator = HTMLGenerator() + + async def run(self, request: dict) -> PipelineResponse: + """执行完整内容处理Pipeline""" + stages = [] + content = request.get("content", "") + title = request.get("title", "") + platform = request.get("platform", "") + optimize_for = request.get("optimize_for", ["validation"]) + output_formats = request.get("output_formats", ["html", "markdown", "plain"]) + + current_content = content + + try: + # Stage 1: 规则校验 + if "validation" in optimize_for: + start = time.time() + try: + validation_result = self.validator.validate(current_content, title, platform) + duration = time.time() - start + + stages.append(PipelineStage( + name="validation", + passed=validation_result.is_valid, + result=validation_result, + duration=duration + )) + + # 如果校验失败(高严重级别问题),中断Pipeline + if not validation_result.is_valid: + return PipelineResponse( + stages=stages, + outputs=None, + error="内容校验未通过" + ) + except Exception as e: + stages.append(PipelineStage( + name="validation", + passed=False, + error=str(e), + duration=time.time() - start + )) + return PipelineResponse(stages=stages, error=str(e)) + + # Stage 2: 敏感词过滤 + if "sensitive" in optimize_for: + start = time.time() + try: + filter_result = self.sensitive_filter.filter(current_content, platform) + duration = time.time() - start + current_content = filter_result.filtered_content + + stages.append(PipelineStage( + name="sensitive_filter", + passed=True, + result=filter_result, + duration=duration + )) + except Exception as e: + stages.append(PipelineStage( + name="sensitive_filter", + passed=False, + error=str(e), + duration=time.time() - start + )) + + # Stage 3: SEO优化 + if "seo" in optimize_for: + start = time.time() + try: + keyword = request.get("keyword", "") + seo_result = self.seo_optimizer.optimize(current_content, title, platform, keyword) + duration = time.time() - start + + stages.append(PipelineStage( + name="seo_optimization", + passed=True, + result=seo_result, + duration=duration + )) + except Exception as e: + stages.append(PipelineStage( + name="seo_optimization", + passed=False, + error=str(e), + duration=time.time() - start + )) + + # Stage 4: HTML生成 + outputs = PipelineOutput() + + if "html" in output_formats or (not output_formats): + outputs.html = self.html_generator.generate(current_content, platform, "html") + + if "markdown" in output_formats: + outputs.markdown = self.html_generator.to_markdown(current_content) + + if "plain" in output_formats: + outputs.plain = self.html_generator.to_plain(current_content) + + stages.append(PipelineStage( + name="html_generation", + passed=True, + result=outputs + )) + + return PipelineResponse(stages=stages, outputs=outputs) + + except Exception as e: + return PipelineResponse(stages=stages, error=str(e)) + + async def validate_only(self, content: str, title: str, platform: str): + """仅执行校验,不处理""" + return self.validator.validate(content, title, platform) diff --git a/backend/app/services/content/html_generator.py b/backend/app/services/content/html_generator.py new file mode 100644 index 0000000..1140637 --- /dev/null +++ b/backend/app/services/content/html_generator.py @@ -0,0 +1,118 @@ +import re +from typing import Optional + + +class HTMLGenerator: + """HTML生成器 - 根据平台规则生成适配HTML""" + + def generate(self, content: str, platform: str, format: str = "html") -> str: + """根据平台规则生成HTML + + Args: + content: HTML内容 + platform: 平台标识 + format: 输出格式 (html/markdown/plain) + + Returns: + 处理后的内容 + """ + from app.services.distribution.platform_rules import PLATFORM_RULES + + rules = PLATFORM_RULES.get(platform, {}) + html_rules = rules.get("html_rules", {}) + + # 获取平台支持的标签和禁用标签 + banned_tags = html_rules.get("banned_tags", []) + + result = content + + # 移除禁用的标签及其内容 + for tag in banned_tags: + # 移除带内容的标签 + result = re.sub( + f"<{tag}[^>]*>.*?{tag}>", "", result, flags=re.DOTALL | re.IGNORECASE + ) + # 移除自闭合标签 + result = re.sub(f"<{tag}[^>]*/?>", "", result, flags=re.IGNORECASE) + + # 平台特定处理 + if platform == "wechat": + # 微信公众号:移除外部链接 + result = re.sub( + r"]*href=['\"]https?://(?!mp\.weixin\.qq\.com)[^'\"]*['\"][^>]*>", + "", + result, + flags=re.IGNORECASE, + ) + # 移除链接文本但保留内部内容 + result = re.sub( + r"", "", result, flags=re.IGNORECASE + ) + + if format == "markdown": + return self.to_markdown(result) + elif format == "plain": + return self.to_plain(result) + + return result + + def to_markdown(self, content: str) -> str: + """HTML转Markdown + + Args: + content: HTML内容 + + Returns: + Markdown格式内容 + """ + # h1 -> # + content = re.sub(r"
]*>(.*?)
", r"\1\n\n", content, flags=re.IGNORECASE) + # br -> 换行 + content = re.sub(r"]*>(.*?)", r"> \1", content, flags=re.IGNORECASE | re.DOTALL) + # code inline + content = re.sub(r"
]*>(.*?)", r"`\1`", content, flags=re.IGNORECASE)
+ # pre
+ content = re.sub(r"]*>(.*?)", r"```\n\1\n```", content, flags=re.IGNORECASE | re.DOTALL) + # 清理残留标签 + content = re.sub(r"<[^>]+>", "", content) + # 清理多余空行 + content = re.sub(r"\n{3,}", r"\n\n", content) + + return content.strip() + + def to_plain(self, content: str) -> str: + """HTML转纯文本 + + Args: + content: HTML内容 + + Returns: + 纯文本内容 + """ + # 移除所有HTML标签 + text = re.sub(r"<[^>]+>", "", content) + # 解码HTML实体 + text = text.replace(" ", " ") + text = text.replace("<", "<") + text = text.replace(">", ">") + text = text.replace("&", "&") + text = text.replace(""", '"') + text = text.replace("'", "'") + # 清理多余空格 + text = re.sub(r" {2,}", " ", text) + # 清理多余换行 + text = re.sub(r"\n{3,}", r"\n\n", text) + + return text.strip() diff --git a/backend/app/services/content/rule_validator.py b/backend/app/services/content/rule_validator.py new file mode 100644 index 0000000..009e0d2 --- /dev/null +++ b/backend/app/services/content/rule_validator.py @@ -0,0 +1,318 @@ +"""内容规则校验服务""" +import re +from dataclasses import dataclass +from typing import Optional + +from app.services.distribution.platform_rules import PLATFORM_RULES + + +@dataclass +class ValidationIssue: + """校验问题""" + severity: str # high, medium, low + message: str + category: str + + +@dataclass +class ValidationResult: + """校验结果""" + is_valid: bool + score: int + issues: list # list of ValidationIssue + passed: list # list of str + + +@dataclass +class AI_Pattern: + """AI写作特征""" + pattern: str + type: str # banned_word, banned_structure + severity: str # medium, high + + +class RuleValidator: + """内容规则校验器""" + + def validate(self, content: str, title: str, platform: str) -> ValidationResult: + """ + 校验内容是否符合平台规则 + + Args: + content: 内容正文 + title: 标题 + platform: 平台标识 + + Returns: + ValidationResult: 校验结果 + """ + rules = PLATFORM_RULES.get(platform) + if not rules: + raise ValueError(f"不支持的平台: {platform}") + + issues: list[ValidationIssue] = [] + passed: list[str] = [] + + # 标题长度校验 + title_len = len(title) + title_rules = rules.get("title_rules", {}) + max_title = title_rules.get("max_length", 30) + min_title = title_rules.get("min_length", 5) + + if title_len > max_title: + issues.append(ValidationIssue( + "high", + f"标题长度 {title_len} 超过限制 {max_title}", + "title_length" + )) + elif title_len < min_title: + issues.append(ValidationIssue( + "medium", + f"标题长度 {title_len} 低于最低要求 {min_title}", + "title_length" + )) + else: + passed.append(f"标题长度合规({title_len}/{max_title})") + + # 内容长度校验 + content_len = len(content) + content_rules = rules.get("content_length", {}) + max_content = content_rules.get("max", 20000) + min_content = content_rules.get("min", 0) + + if content_len > max_content: + issues.append(ValidationIssue( + "high", + f"内容长度 {content_len} 超过限制 {max_content}", + "content_length" + )) + elif min_content > 0 and content_len < min_content: + issues.append(ValidationIssue( + "medium", + f"内容长度 {content_len} 低于建议最低 {min_content}", + "content_length" + )) + else: + passed.append(f"内容长度合规({content_len}/{max_content})") + + # AI模式检测 + ai_sensitivity = rules.get("ai_sensitivity", {}) + if ai_sensitivity.get("humanization_required", False): + ai_results = self.detect_ai_patterns(content, platform) + for result in ai_results: + issues.append(ValidationIssue( + "medium", + f"发现AI写作特征: {result.pattern}", + "ai_pattern" + )) + + # 平台特定规则 + platform_issues, platform_passed = self._validate_platform_specific(content, title, platform) + issues.extend(platform_issues) + passed.extend(platform_passed) + + # 计算分数 + penalty = sum( + 15 if i.severity == "high" else 8 if i.severity == "medium" else 3 + for i in issues + ) + score = max(0, 100 - penalty) + + # 判断是否有效(无high级别问题) + is_valid = all(i.severity != "high" for i in issues) + + return ValidationResult(is_valid, score, issues, passed) + + def detect_ai_patterns(self, content: str, platform: str) -> list[AI_Pattern]: + """ + 检测AI写作模式 + + Args: + content: 内容正文 + platform: 平台标识 + + Returns: + list[AI_Pattern]: 检测到的AI特征列表 + """ + rules = PLATFORM_RULES.get(platform) + if not rules: + return [] + + results: list[AI_Pattern] = [] + ai_config = rules.get("ai_sensitivity", {}) + banned_patterns = ai_config.get("banned_patterns", []) + banned_structures = ai_config.get("banned_structures", []) + + # 检测禁用词汇 + for pattern in banned_patterns: + if pattern in content: + results.append(AI_Pattern(pattern, "banned_word", "medium")) + + # 检测禁用结构 + for structure in banned_structures: + if re.search(structure, content): + results.append(AI_Pattern(structure, "banned_structure", "medium")) + break + + return results + + def get_optimization_tips(self, platform: str) -> list[str]: + """ + 获取平台优化建议 + + Args: + platform: 平台标识 + + Returns: + list[str]: 优化建议列表 + """ + rules = PLATFORM_RULES.get(platform) + if not rules: + return [] + return rules.get("seo_tips", []) + + def _validate_platform_specific( + self, content: str, title: str, platform: str + ) -> tuple: + """平台特定规则校验""" + issues: list[ValidationIssue] = [] + passed: list[str] = [] + + # 诱导分享/关注检测 + inducing_patterns = re.compile( + r"(转发|分享|关注|点赞|收藏).{0,4}(领|获|得|拿|解锁|免费)", + re.IGNORECASE, + ) + + # 连续特殊符号 + consecutive_symbols = re.compile(r"[!!??]{3,}") + + # 外部链接(排除公众号和小程序链接) + external_link = re.compile( + r"https?://(?!mp\.weixin\.qq\.com|wx\.qq\.com|weixin://)[^\s<>))]+", + re.IGNORECASE, + ) + + # 标题党词汇 + clickbait_words = {"震惊", "惊呆", "吓死", "笑死", "疯传", "刷屏", "出大事", "不敢相信"} + + # 水印检测 + watermark_patterns = re.compile( + r"(抖音|快手|小红书|微博|B站|bilibili).*(水印|logo)", + re.IGNORECASE, + ) + + if platform == "wechat": + # 诱导分享/关注 + if inducing_patterns.search(title) or inducing_patterns.search(content): + issues.append(ValidationIssue( + "high", + "包含诱导分享/关注语句", + "platform_rule" + )) + else: + passed.append("无诱导分享/关注语句") + + # 连续特殊符号 + if consecutive_symbols.search(title): + issues.append(ValidationIssue( + "medium", + "标题包含连续特殊符号", + "title_format" + )) + else: + passed.append("标题无连续特殊符号") + + # 外部链接 + if external_link.search(content): + issues.append(ValidationIssue( + "high", + "正文包含外部链接(仅支持公众号链接和小程序)", + "platform_rule" + )) + else: + passed.append("无外部链接") + + # 营销用语检测 + marketing_words = ["购买", "下单", "优惠价", "限时折扣", "点击购买"] + found_marketing = [w for w in marketing_words if w in content] + if found_marketing: + issues.append(ValidationIssue( + "medium", + f"疑似营销用语: {', '.join(found_marketing)}", + "platform_rule" + )) + else: + passed.append("未检测到过度营销用语") + + elif platform == "zhihu": + # 营销内容检测 + marketing_words = ["购买", "下单", "优惠价", "限时折扣", "点击购买"] + found_marketing = [w for w in marketing_words if w in content] + if found_marketing: + issues.append(ValidationIssue( + "medium", + f"疑似营销用语: {', '.join(found_marketing)}", + "platform_rule" + )) + else: + passed.append("未检测到过度营销用语") + + elif platform == "xiaohongshu": + # 字数建议 + content_len = len(content) + if content_len > 800: + issues.append(ValidationIssue( + "medium", + f"正文建议300-800字,当前 {content_len} 字", + "content_length" + )) + elif content_len < 300: + issues.append(ValidationIssue( + "low", + f"正文建议300-800字,当前仅 {content_len} 字", + "content_length" + )) + else: + passed.append(f"正文字数适宜({content_len}字)") + + # 其他平台引流 + cross_platform_keywords = ["微信", "公众号", "抖音号", "微博"] + found_cross = [p for p in cross_platform_keywords if p in content] + if found_cross: + issues.append(ValidationIssue( + "high", + f"疑似其他平台引流: {', '.join(found_cross)}", + "platform_rule" + )) + else: + passed.append("未检测到其他平台引流信息") + + elif platform in ("baijiahao", "toutiao"): + # 标题党检测 + found_clickbait = clickbait_words & set(title) + if found_clickbait: + issues.append(ValidationIssue( + "high", + f"标题含标题党词汇: {', '.join(found_clickbait)}", + "title_content" + )) + else: + passed.append("标题无标题党词汇") + + elif platform == "douyin": + # 水印检测 + if watermark_patterns.search(content): + issues.append(ValidationIssue( + "high", + "内容包含其他平台水印信息", + "platform_rule" + )) + else: + passed.append("未检测到其他平台水印") + + return issues, passed + + +# 导出单例 +validator = RuleValidator() diff --git a/backend/app/services/content/sensitive_filter.py b/backend/app/services/content/sensitive_filter.py new file mode 100644 index 0000000..2290881 --- /dev/null +++ b/backend/app/services/content/sensitive_filter.py @@ -0,0 +1,129 @@ +"""敏感词过滤服务""" + +import re +from dataclasses import dataclass, field +from typing import Optional + +# 基础敏感词库 +SENSITIVE_WORDS = { + "politics": [ + "台湾", "西藏", "新疆", "香港", "澳门", + "分裂", "独立", "抗议", "游行", "示威", + "政治", "敏感词", + ], + "medical": [ + "药品", "治疗", "疗效", "治愈", + "处方", "医生", "医院", "手术", + "医疗", "敏感词", + ], + "finance": [ + "投资", "理财", "收益率", "回报", + "股票", "基金", "债券", "期货", + ], + "adult": [ + "色情", "赌博", "毒品", "暴力", + ], +} + +REPLACEMENT_CHAR = "*" + + +@dataclass +class FoundWord: + """发现的敏感词""" + word: str + category: str + position: int + replacement: str + + +@dataclass +class FilterResult: + """过滤结果""" + filtered_content: str + found_words: list = field(default_factory=list) + replacements: dict = field(default_factory=dict) + + +class SensitiveFilter: + """敏感词过滤器""" + + def __init__(self): + self.custom_words: dict = {} + self.replacement_char = REPLACEMENT_CHAR + + def filter(self, content: str, platform: str) -> FilterResult: + """过滤敏感词 + + Args: + content: 待过滤的内容 + platform: 平台标识 + + Returns: + FilterResult: 包含过滤后内容、发现的敏感词和替换映射 + """ + # 获取平台的敏感词配置 + from app.services.distribution.platform_rules import PLATFORM_RULES + + rules = PLATFORM_RULES.get(platform, {}) + sensitive_config = rules.get("sensitive_words", {}) + + check_required = sensitive_config.get("check_required", True) + if not check_required: + return FilterResult(content, [], {}) + + categories = sensitive_config.get("categories", ["politics"]) + max_tolerance = sensitive_config.get("max_tolerance", 0) + + # 合并基础词库和自定义词库 + all_words = {} + for cat in categories: + all_words[cat] = [] + if cat in SENSITIVE_WORDS: + all_words[cat].extend(SENSITIVE_WORDS[cat]) + if cat in self.custom_words: + all_words[cat].extend(self.custom_words[cat]) + + # 自定义分类的词也需要检查,将其合并到所有启用的分类中 + for custom_cat, custom_words_list in self.custom_words.items(): + if custom_cat not in categories: + # 自定义分类不在平台启用分类中,将其添加到第一个分类 + target_cat = categories[0] + all_words[target_cat].extend(custom_words_list) + + found_words = [] + filtered = content + replacements = {} + + for category, words in all_words.items(): + for word in words: + if word in filtered: + # 记录发现的敏感词 + position = filtered.find(word) + found_words.append(FoundWord( + word=word, + category=category, + position=position, + replacement=self.replacement_char * len(word) + )) + # 替换敏感词 + replacement = self.replacement_char * len(word) + filtered = filtered.replace(word, replacement) + replacements[word] = replacement + + return FilterResult( + filtered_content=filtered, + found_words=found_words, + replacements=replacements + ) + + def add_custom_words(self, category: str, words: list): + """添加自定义敏感词 + + Args: + category: 敏感词分类 + words: 敏感词列表 + """ + if category not in self.custom_words: + self.custom_words[category] = [] + self.custom_words[category].extend(words) diff --git a/backend/app/services/content/seo_optimizer.py b/backend/app/services/content/seo_optimizer.py new file mode 100644 index 0000000..5a830d4 --- /dev/null +++ b/backend/app/services/content/seo_optimizer.py @@ -0,0 +1,117 @@ +"""SEO优化服务""" + +from dataclasses import dataclass +from typing import Optional + +from app.services.distribution.platform_rules import PLATFORM_RULES + + +@dataclass +class OptimizationResult: + """SEO优化结果""" + optimized_content: str + density: float + suggestions: list + tips: list + + +class SEOOptimizer: + """SEO优化器""" + + def get_keyword_density(self, content: str, keyword: str) -> float: + """计算关键词密度 + + Args: + content: 内容文本 + keyword: 关键词 + + Returns: + 关键词密度百分比 + """ + if not keyword or not content: + return 0.0 + + content_len = len(content) + keyword_count = content.count(keyword) + + # 密度 = (关键词字符数 * 出现次数) / 总字符数 * 100 + density = (len(keyword) * keyword_count) / content_len * 100 + return round(density, 2) + + def optimize( + self, + content: str, + title: str, + platform: str, + keyword: str = "" + ) -> OptimizationResult: + """优化内容SEO + + Args: + content: 内容文本 + title: 标题 + platform: 平台标识 + keyword: 关键词 + + Returns: + OptimizationResult: 优化结果 + """ + rules = PLATFORM_RULES.get(platform, {}) + seo_rules = rules.get("seo_rules", {}) + + suggestions = [] + tips = [] + optimized = content + + # 获取推荐密度配置 + density_config = seo_rules.get("keyword_density", {"min": 1, "max": 3, "recommended": 2}) + min_density = density_config["min"] + max_density = density_config["max"] + recommended = density_config["recommended"] + + # 关键词位置 + keyword_positions = seo_rules.get("keyword_position", ["title", "first_para"]) + + # 计算当前密度 + if keyword: + current_density = self.get_keyword_density(content, keyword) + + # 密度调整建议 + if current_density < min_density: + suggestions.append( + f"关键词密度 {current_density}% 低于最低要求 {min_density}%,建议增加关键词出现次数" + ) + elif current_density > max_density: + suggestions.append( + f"关键词密度 {current_density}% 超过最高限制 {max_density}%,建议减少关键词堆砌" + ) + else: + suggestions.append(f"关键词密度 {current_density}% 在推荐范围内") + + # 关键词位置检查 + keyword_in_title = keyword in title if title else False + keyword_in_first = keyword in content[:100] if content else False + + if "title" in keyword_positions and not keyword_in_title: + suggestions.append(f"建议在标题中包含关键词「{keyword}」") + + if "first_para" in keyword_positions and not keyword_in_first: + suggestions.append(f"建议在前100字中包含关键词「{keyword}」") + + tips.extend(rules.get("seo_tips", [])) + + return OptimizationResult( + optimized_content=optimized, + density=current_density, + suggestions=suggestions, + tips=tips + ) + else: + # 无关键词时返回SEO建议 + tips.extend(rules.get("seo_tips", [])) + return OptimizationResult( + optimized_content=optimized, + density=0.0, + suggestions=["请指定要优化的关键词"], + tips=tips + ) diff --git a/backend/tests/test_content_pipeline.py b/backend/tests/test_content_pipeline.py new file mode 100644 index 0000000..57ae14f --- /dev/null +++ b/backend/tests/test_content_pipeline.py @@ -0,0 +1,89 @@ +# test_content_pipeline.py +import pytest + +# 导入实际的 ContentPipeline 实现 +from app.services.content.content_pipeline import ContentPipeline + +@pytest.mark.asyncio +async def test_pipeline_complete_run(): + """完整Pipeline执行""" + pipeline = ContentPipeline() + request = { + "content": "这是一篇测试文章内容", + "title": "测试标题", + "platform": "zhihu", + "optimize_for": ["validation", "sensitive", "seo"] + } + result = await pipeline.run(request) + + assert result.stages is not None + assert len(result.stages) > 0 + assert result.outputs is not None + +@pytest.mark.asyncio +async def test_pipeline_with_validation_fail(): + """校验失败中断""" + pipeline = ContentPipeline() + request = { + "content": "内容", + "title": "这个标题太长了超过了三十个字符的限制了哈哈哈啊", + "platform": "wechat", + "optimize_for": ["validation"] + } + result = await pipeline.run(request) + + # 校验失败时不应继续执行后续阶段 + validation_stage = next((s for s in result.stages if s.name == "validation"), None) + assert validation_stage is not None + assert validation_stage.passed == False + +@pytest.mark.asyncio +async def test_pipeline_multi_platform(): + """多平台适配""" + pipeline = ContentPipeline() + + zhihu_result = await pipeline.run({ + "content": "
测试内容
外部链接", + "title": "测试标题", + "platform": "zhihu" + }) + + wechat_result = await pipeline.run({ + "content": "测试内容
外部链接", + "title": "测试标题", + "platform": "wechat" + }) + + # 不同平台应产生不同的优化结果 + assert zhihu_result.outputs != wechat_result.outputs + +@pytest.mark.asyncio +async def test_pipeline_stage_results(): + """各阶段结果记录""" + pipeline = ContentPipeline() + result = await pipeline.run({ + "content": "内容", + "title": "标题", + "platform": "zhihu" + }) + + # 检查每个阶段的结果 + for stage in result.stages: + assert stage.name is not None + assert hasattr(stage, 'passed') or hasattr(stage, 'result') + +@pytest.mark.asyncio +async def test_pipeline_error_handling(): + """错误处理""" + pipeline = ContentPipeline() + + # 无效平台应返回错误 + try: + result = await pipeline.run({ + "content": "内容", + "title": "标题", + "platform": "invalid_platform" + }) + assert result.error is not None + except ValueError as e: + assert "不支持的平台" in str(e) diff --git a/backend/tests/test_html_generator.py b/backend/tests/test_html_generator.py new file mode 100644 index 0000000..fbbc1e0 --- /dev/null +++ b/backend/tests/test_html_generator.py @@ -0,0 +1,54 @@ +# test_html_generator.py +import pytest + +# 使用实际实现的 HTMLGenerator +from app.services.content.html_generator import HTMLGenerator + +def test_filter_banned_tags_zhihu(): + """知乎HTML标签过滤""" + generator = HTMLGenerator() + html = generator.generate( + content="这是内容
", + platform="zhihu" + ) + assert "