""" 分块策略 - 支持多种分块方式 """ import re from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Optional @dataclass class ChunkStrategy: """分块策略配置""" name: str description: str chunk_size: int # 字符数 chunk_overlap: int # 重叠字符数 min_chunk_size: int class BaseChunker(ABC): """分块器基类""" STRATEGY: ChunkStrategy = None @abstractmethod def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]: """执行分块""" pass def preview(self, text: str, max_chunks: int = 5) -> list[str]: """预览分块结果""" chunks = self.chunk(text) return [c["content"][:200] + "..." if len(c["content"]) > 200 else c["content"] for c in chunks[:max_chunks]] class RecursiveChunker(BaseChunker): """递归分块器(现有实现)""" STRATEGY = ChunkStrategy( name="recursive", description="优先按段落分割,过长时按句子分割", chunk_size=500, chunk_overlap=50, min_chunk_size=50, ) # 分割模式(按优先级) SEPARATORS = [ r"\n\n+", # 双换行(段落) r"\n", # 单换行 r"[。!?!?]\s*", # 句子结束 r"[,,;;]\s*", # 分句 r"\s+", # 空格 ] def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]: chunks = [] metadata = metadata or {} # 按段落分割 segments = re.split(r"\n\n+", text) current_chunk = "" for segment in segments: if len(current_chunk) + len(segment) <= self.STRATEGY.chunk_size: current_chunk += segment + "\n\n" else: # 当前块足够大,保存 if len(current_chunk.strip()) >= self.STRATEGY.min_chunk_size: chunks.append({ "content": current_chunk.strip(), "chunk_index": len(chunks), "metadata": metadata, }) # 处理过长段落 if len(segment) > self.STRATEGY.chunk_size: current_chunk = segment else: # 保留重叠 overlap = current_chunk[-self.STRATEGY.chunk_overlap:] current_chunk = overlap + segment + "\n\n" # 处理最后一个块 if len(current_chunk.strip()) >= self.STRATEGY.min_chunk_size: chunks.append({ "content": current_chunk.strip(), "chunk_index": len(chunks), "metadata": metadata, }) return chunks class SemanticChunker(BaseChunker): """语义分块器 - 按语义边界分割""" STRATEGY = ChunkStrategy( name="semantic", description="根据语义边界(标题、段落)自动分块", chunk_size=800, chunk_overlap=100, min_chunk_size=100, ) # 语义边界模式 SEMANTIC_PATTERNS = [ (r"^#{1,6}\s+(.+)$", "heading"), # Markdown标题 (r"^【(.+?)】\s*$", "heading"), # 中文标题 (r"^第[一二三四五六七八九十百]+[章节条]", "heading"), # 章节标题 (r"^(\d+\.)+\s+", "heading"), # 数字编号 ] def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]: chunks = [] metadata = metadata or {} lines = text.split("\n") current_chunk = "" current_section = None for line in lines: # 检查是否是语义边界 is_boundary = False for pattern, boundary_type in self.SEMANTIC_PATTERNS: if re.match(pattern, line.strip()): is_boundary = True current_section = line.strip() break # 如果是边界且当前块不为空,保存 if is_boundary and current_chunk.strip(): chunks.append({ "content": current_chunk.strip(), "chunk_index": len(chunks), "section": current_section, "metadata": metadata, }) # 保留重叠 overlap = current_chunk[-self.STRATEGY.chunk_overlap:] current_chunk = overlap + line + "\n" else: current_chunk += line + "\n" # 检查块大小 if len(current_chunk) >= self.STRATEGY.chunk_size: chunks.append({ "content": current_chunk.strip(), "chunk_index": len(chunks), "section": current_section, "metadata": metadata, }) overlap = current_chunk[-self.STRATEGY.chunk_overlap:] current_chunk = overlap # 处理最后一个块 if current_chunk.strip(): chunks.append({ "content": current_chunk.strip(), "chunk_index": len(chunks), "section": current_section, "metadata": metadata, }) return chunks class FixedLengthChunker(BaseChunker): """固定长度分块器""" STRATEGY = ChunkStrategy( name="fixed", description="按固定长度强制分块", chunk_size=300, chunk_overlap=30, min_chunk_size=50, ) def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]: chunks = [] metadata = metadata or {} # 移除多余空白 text = re.sub(r"\s+", " ", text) for i in range(0, len(text), self.STRATEGY.chunk_size - self.STRATEGY.chunk_overlap): chunk_text = text[i:i + self.STRATEGY.chunk_size] if len(chunk_text.strip()) >= self.STRATEGY.min_chunk_size: chunks.append({ "content": chunk_text.strip(), "chunk_index": len(chunks), "metadata": metadata, }) return chunks class ChunkerFactory: """分块策略工厂""" STRATEGIES = { "recursive": RecursiveChunker, "semantic": SemanticChunker, "fixed": FixedLengthChunker, } @classmethod def create(cls, strategy: str = "recursive") -> BaseChunker: """创建分块器""" chunker_cls = cls.STRATEGIES.get(strategy, RecursiveChunker) return chunker_cls() @classmethod def list_strategies(cls) -> list[ChunkStrategy]: """列出所有策略""" return [chunker_cls.STRATEGY for chunker_cls in cls.STRATEGIES.values()]