"""
RecursiveChunker: recursive semantic chunker.

Splits a document into embedding-sized chunks using separators in priority
order (paragraph -> line -> sentence punctuation -> word), falling back to a
character-level split when no separator applies.
"""
import re
from typing import Optional

# CJK Unified Ideographs (U+4E00-U+9FFF) and Extension A (U+3400-U+4DBF).
_CJK_CHARS = r"\u4e00-\u9fff\u3400-\u4dbf"
_CJK_RE = re.compile(rf"[{_CJK_CHARS}]")
# One unit as counted by _estimate_tokens: a single CJK character, or a
# maximal run of non-whitespace, non-CJK characters (an "English word").
_UNIT_RE = re.compile(rf"[{_CJK_CHARS}]|[^\s{_CJK_CHARS}]+")
# BPE fragmentation factor applied to each non-CJK word.
_TOKENS_PER_WORD = 1.3


class RecursiveChunker:
    """Recursive semantic chunker.

    Attributes:
        chunk_size: target size, in estimated tokens, for pieces produced by
            splitting. Merging small chunks and prepending overlap can push a
            final chunk above this value.
        chunk_overlap: approximate number of estimated tokens copied from the
            end of one emitted chunk to the start of the next (<= 0 disables).
        min_chunk_size: chunks with fewer estimated tokens are merged with the
            chunk that follows them.
        separators: split separators, tried in order.
    """

    def __init__(
        self,
        chunk_size: int = 512,
        chunk_overlap: int = 50,
        min_chunk_size: int = 100,
    ):
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.min_chunk_size = min_chunk_size
        # Separator priority: paragraph > line > sentence punctuation > word.
        # Full-width CJK punctuation sits next to its ASCII counterpart so that
        # Chinese text is split at sentence boundaries as well.
        self.separators = ["\n\n", "\n", "。", ".", "！", "!", "？", "?", "；", ";", " "]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
        """Split ``text`` recursively into chunks.

        Args:
            text: document text; leading/trailing whitespace is stripped.
            metadata: optional dict attached to every chunk. Each chunk gets
                its own shallow copy, so mutating one chunk's metadata does not
                affect the others or the caller's dict.

        Returns:
            A list of dicts with keys ``"content"`` (str), ``"chunk_index"``
            (int, 0-based), ``"token_count"`` (estimated tokens of content) and
            ``"metadata"`` (dict). Empty or whitespace-only text yields ``[]``.
        """
        if not text or not text.strip():
            return []

        raw_chunks = self._split_recursive(text.strip(), self.separators)
        # Merge chunks that are too short and add overlap between neighbours.
        merged = self._merge_small_chunks(raw_chunks)

        return [
            {
                "content": content,
                "chunk_index": idx,
                "token_count": self._estimate_tokens(content),
                "metadata": dict(metadata) if metadata else {},
            }
            for idx, content in enumerate(merged)
        ]

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _split_recursive(self, text: str, separators: list[str]) -> list[str]:
        """Split with the first separator, recursing with the rest for oversize pieces.

        Pieces are greedily packed into buffers of at most ``chunk_size``
        estimated tokens. Punctuation separators stay attached to the end of
        the preceding piece, so sentence terminators are not lost at chunk
        boundaries. Whitespace separators are re-inserted between pieces.
        """
        if not separators:
            # Last resort: character-level split.
            return self._hard_split(text)

        sep, remaining_seps = separators[0], separators[1:]
        pieces = text.split(sep)
        if sep.strip():
            # Punctuation: keep it on the piece it terminates and rejoin with "".
            pieces = [p + sep for p in pieces[:-1]] + pieces[-1:]
            joiner = ""
        else:
            joiner = sep
        pieces = [p for p in pieces if p.strip()]

        if len(pieces) <= 1:
            # This separator cannot split the text; try the next level.
            return self._split_recursive(text, remaining_seps)

        chunks: list[str] = []
        buffer = ""
        for piece in pieces:
            candidate = (buffer + joiner + piece).strip() if buffer else piece.strip()
            if self._estimate_tokens(candidate) <= self.chunk_size:
                buffer = candidate
                continue
            # Buffer is full. It only ever holds candidates that fit, so it
            # can be emitted as is.
            if buffer:
                chunks.append(buffer)
            if self._estimate_tokens(piece) > self.chunk_size:
                chunks.extend(self._split_recursive(piece, remaining_seps))
                buffer = ""
            else:
                buffer = piece.strip()

        if buffer:
            chunks.append(buffer)
        return [c for c in chunks if c.strip()]

    def _merge_small_chunks(self, chunks: list[str]) -> list[str]:
        """Merge chunks under ``min_chunk_size`` tokens into their successor and add overlap.

        Each emitted chunk's tail (see ``_get_overlap_prefix``) is prepended
        to the next chunk. Merging and overlap may make a chunk exceed
        ``chunk_size``. A short final chunk is emitted as is.
        """
        if not chunks:
            return []

        merged: list[str] = []
        buffer = chunks[0]
        for chunk in chunks[1:]:
            if self._estimate_tokens(buffer) < self.min_chunk_size:
                buffer = buffer + "\n" + chunk
                continue
            merged.append(buffer)
            # Overlap: the tail of the chunk just emitted becomes a prefix.
            overlap_text = self._get_overlap_prefix(buffer)
            buffer = (overlap_text + "\n" + chunk).strip() if overlap_text else chunk
        merged.append(buffer)
        return [c for c in merged if c.strip()]

    def _get_overlap_prefix(self, text: str) -> str:
        """Return the tail of ``text`` worth about ``chunk_overlap`` estimated tokens.

        The tail is built from the same units ``_estimate_tokens`` counts (one
        CJK character, or one non-CJK word), so Chinese text without spaces is
        no longer treated as a single huge "word". At least one unit is always
        returned when overlap is enabled. The original spacing is preserved.
        """
        if self.chunk_overlap <= 0:
            return ""
        units = list(_UNIT_RE.finditer(text))
        if not units:
            return ""

        start = None
        cjk_count = word_count = 0
        for match in reversed(units):
            if _CJK_RE.match(match.group()):
                cjk_count += 1
            else:
                word_count += 1
            cost = int(cjk_count + word_count * _TOKENS_PER_WORD)
            if start is not None and cost > self.chunk_overlap:
                break
            start = match.start()
        return text[start:]

    def _hard_split(self, text: str) -> list[str]:
        """Split ``text`` into character windows (last resort).

        Each window starts at ``chunk_size * 2`` characters (rough 2 chars per
        token). If a window's estimated tokens exceed ``chunk_size`` (e.g. CJK
        text, 1 token per character), it is shrunk until it fits, keeping at
        least one character. Consecutive windows overlap by ``chunk_overlap * 2``
        characters. The overlap is capped so each window advances by at least
        half its length, which guarantees termination. No window is emitted
        after one that reaches the end of the text.
        """
        max_chars = max(1, self.chunk_size * 2)
        overlap_chars = max(0, self.chunk_overlap * 2)
        length = len(text)

        chunks: list[str] = []
        start = 0
        while start < length:
            end = min(start + max_chars, length)
            if self._estimate_tokens(text[start:end]) > self.chunk_size:
                end = self._fit_window_end(text, start, end)
            chunks.append(text[start:end])
            if end >= length:
                break
            window = end - start
            start += max(window - overlap_chars, (window + 1) // 2)
        return chunks

    def _fit_window_end(self, text: str, start: int, end: int) -> int:
        """Largest ``e`` in ``[start + 1, end]`` with ``text[start:e]`` within ``chunk_size``.

        Returns ``start + 1`` if even one character exceeds the budget. Binary
        search is valid because the token estimate never decreases as the
        window grows.
        """
        lo, hi = start + 1, end
        while lo < hi:
            mid = (lo + hi + 1) // 2
            if self._estimate_tokens(text[start:mid]) <= self.chunk_size:
                lo = mid
            else:
                hi = mid - 1
        return lo

    def _estimate_tokens(self, text: str) -> int:
        """Estimate the token count of ``text``.

        Each CJK character counts as 1 token. Each whitespace-delimited
        non-CJK word counts as 1.3 tokens (BPE fragmentation factor). The sum
        is truncated to an int.
        """
        if not text:
            return 0
        chinese_chars = len(_CJK_RE.findall(text))
        # Replace CJK characters with spaces, then count the remaining words.
        english_words = len(_CJK_RE.sub(" ", text).split())
        return int(chinese_chars + english_words * _TOKENS_PER_WORD)