179 lines
6.3 KiB
Python
179 lines
6.3 KiB
Python
"""
RecursiveChunker: recursive semantic text chunker.

Splits a document into embedding-sized chunks using prioritized separators
(paragraph → sentence → word).
"""
|
||
import re
|
||
from typing import Optional
|
||
|
||
|
||
class RecursiveChunker:
    """Split text into embedding-sized chunks using prioritized separators.

    The text is first split on the coarsest separator (blank line), then on
    progressively finer ones (newline, sentence punctuation, space). Any piece
    whose estimated token count still exceeds ``chunk_size`` is split again
    with the next separator, and is cut by character count once no separator
    is left. Afterwards, undersized chunks are merged with their successor and
    every emitted chunk after the first is prefixed with the tail of the
    previous chunk as overlap.

    Token counts are estimates produced by ``_estimate_tokens``; no real
    tokenizer is involved.
    """

    # Characters counted as one token each by the estimator.
    _CJK_CHAR_RE = re.compile(r"[\u4e00-\u9fff\u3400-\u4dbf]")
    # One countable unit: a single CJK character, or a run of characters that
    # are neither whitespace nor CJK (what the estimator treats as a "word").
    _UNIT_RE = re.compile(
        r"(?P<cjk>[\u4e00-\u9fff\u3400-\u4dbf])|[^\s\u4e00-\u9fff\u3400-\u4dbf]+"
    )
    # Estimated tokens per non-CJK word.
    _TOKENS_PER_WORD = 1.3
    # Characters budgeted per token when cutting by raw character count.
    _CHARS_PER_TOKEN = 2

    def __init__(
        self,
        chunk_size: int = 512,
        chunk_overlap: int = 50,
        min_chunk_size: int = 100,
    ):
        """Store the chunking parameters.

        Args:
            chunk_size: Target upper bound, in estimated tokens, for a chunk
                before overlap is added.
            chunk_overlap: Estimated tokens taken from the end of a chunk and
                prepended to the next one. ``0`` or less disables overlap.
            min_chunk_size: Chunks estimated below this many tokens are merged
                with the chunk that follows them.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.min_chunk_size = min_chunk_size
        # Separator priority: paragraph > line > sentence punctuation > word.
        # Fullwidth forms are written as escapes so they survive any Unicode
        # normalisation of this source file; each one precedes its ASCII form.
        self.separators = [
            "\n\n",
            "\n",
            "\u3002",  # ideographic full stop
            ".",
            "\uff01",  # fullwidth exclamation mark
            "!",
            "\uff1f",  # fullwidth question mark
            "?",
            "\uff1b",  # fullwidth semicolon
            ";",
            " ",
        ]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
        """Split ``text`` into chunks and describe each one.

        Args:
            text: The document text. Leading/trailing whitespace is ignored.
            metadata: Optional mapping attached to every chunk.

        Returns:
            A list of dicts, one per chunk, in document order::

                {
                    "content": str,      # chunk text
                    "chunk_index": int,  # 0-based position
                    "token_count": int,  # estimated tokens of ``content``
                    "metadata": dict,    # shallow copy of ``metadata``
                }

            An empty list when ``text`` is empty or whitespace only.
        """
        if not text or not text.strip():
            return []

        raw_chunks = self._split_recursive(text.strip(), self.separators)
        merged = self._merge_small_chunks(raw_chunks)

        return [
            {
                "content": content,
                "chunk_index": idx,
                "token_count": self._estimate_tokens(content),
                # A fresh dict per chunk: sharing one object would let a
                # mutation on one chunk leak into the others and the caller.
                "metadata": dict(metadata) if metadata else {},
            }
            for idx, content in enumerate(merged)
        ]

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _split_recursive(self, text: str, separators: list[str]) -> list[str]:
        """Split ``text`` on the first separator, recursing on oversized pieces.

        Pieces are greedily packed into chunks of at most ``chunk_size``
        estimated tokens. Each separator stays attached to the piece that
        precedes it, so punctuation is not lost at chunk boundaries.

        Args:
            text: Text to split.
            separators: Separators still available, coarsest first.

        Returns:
            Non-blank chunks in document order.
        """
        if not separators:
            # Last resort: cut by character count.
            return self._hard_split(text)

        sep, remaining_seps = separators[0], separators[1:]

        parts = text.split(sep)
        # Re-attach the separator to the piece it terminated; the final part
        # had no separator after it.
        pieces = [part + sep for part in parts[:-1]]
        pieces.append(parts[-1])
        pieces = [piece for piece in pieces if piece.strip()]

        if len(pieces) <= 1:
            # This separator cannot divide the text; try the next finer one.
            return self._split_recursive(text, remaining_seps)

        chunks: list[str] = []
        # Invariant: ``buffer`` always fits within chunk_size, because it is
        # only ever assigned a candidate or a piece that was just measured.
        buffer = ""

        for piece in pieces:
            candidate = buffer + piece
            if self._estimate_tokens(candidate) <= self.chunk_size:
                buffer = candidate
                continue

            # Adding this piece would overflow: flush what we have.
            if buffer.strip():
                chunks.append(buffer.strip())

            if self._estimate_tokens(piece) > self.chunk_size:
                # The piece alone is too large; split it with finer separators.
                chunks.extend(self._split_recursive(piece.strip(), remaining_seps))
                buffer = ""
            else:
                buffer = piece

        if buffer.strip():
            chunks.append(buffer.strip())

        return [c for c in chunks if c.strip()]

    def _merge_small_chunks(self, chunks: list[str]) -> list[str]:
        """Merge undersized chunks forward and add overlap between chunks.

        A buffer estimated below ``min_chunk_size`` tokens absorbs the next
        chunk (joined with a newline). Otherwise the buffer is emitted and the
        next chunk starts with the overlap tail of the emitted one.

        Args:
            chunks: Chunks in document order.

        Returns:
            The merged, overlap-prefixed chunks with blank entries removed.
        """
        if not chunks:
            return []

        merged: list[str] = []
        buffer = chunks[0]

        for chunk in chunks[1:]:
            if self._estimate_tokens(buffer) < self.min_chunk_size:
                buffer = buffer + "\n" + chunk
            else:
                merged.append(buffer)
                # Overlap: prefix the next chunk with the tail of this one.
                overlap_text = self._get_overlap_prefix(buffer)
                buffer = (overlap_text + "\n" + chunk).strip() if overlap_text else chunk

        merged.append(buffer)
        return [c for c in merged if c.strip()]

    def _get_overlap_prefix(self, text: str) -> str:
        """Return the tail of ``text`` worth about ``chunk_overlap`` tokens.

        The tail is measured in the same units as ``_estimate_tokens`` (one
        token per CJK character, ``_TOKENS_PER_WORD`` per other word), walking
        backwards from the end. Counting whitespace-separated words instead
        would treat an unspaced CJK run as a single word and return the whole
        text. At least one unit is always returned when overlap is enabled.

        Args:
            text: The chunk whose tail becomes the next chunk's prefix.

        Returns:
            The stripped tail, or ``""`` when overlap is disabled or ``text``
            has no countable unit.
        """
        if self.chunk_overlap <= 0:
            return ""

        units = list(self._UNIT_RE.finditer(text))
        if not units:
            return ""

        tail_start = len(text)
        spent = 0.0
        for unit in reversed(units):
            cost = 1.0 if unit.group("cjk") is not None else self._TOKENS_PER_WORD
            has_taken_unit = tail_start < len(text)
            if has_taken_unit and spent + cost > self.chunk_overlap:
                break
            spent += cost
            tail_start = unit.start()

        return text[tail_start:].strip()

    def _hard_split(self, text: str) -> list[str]:
        """Cut ``text`` into fixed-width character windows (last resort).

        Each window is ``chunk_size * 2`` characters wide and consecutive
        windows share ``chunk_overlap * 2`` characters. If the overlap is not
        smaller than the window, windows are laid end to end instead so the
        scan always advances. Scanning stops as soon as a window reaches the
        end of the text, so no overlap-only trailing chunk is produced.

        NOTE(review): the 2-characters-per-token budget disagrees with
        ``_estimate_tokens``, which counts one token per CJK character, so
        these windows can be estimated above ``chunk_size``. The overlap added
        here is also in addition to the one added by ``_merge_small_chunks``.
        Both are kept as-is; confirm whether that is intended.

        Args:
            text: Text that no separator could divide.

        Returns:
            The windows in order; empty when ``text`` is empty.
        """
        char_limit = max(1, self.chunk_size * self._CHARS_PER_TOKEN)
        overlap_chars = max(0, self.chunk_overlap * self._CHARS_PER_TOKEN)

        step = char_limit - overlap_chars
        if step <= 0:
            # Overlap would swallow the whole window: drop it to keep moving.
            step = char_limit

        chunks: list[str] = []
        start = 0
        while start < len(text):
            end = start + char_limit
            chunks.append(text[start:end])
            if end >= len(text):
                break
            start += step
        return chunks

    def _estimate_tokens(self, text: str) -> int:
        """Estimate the token count of ``text``.

        Each CJK character (U+4E00-U+9FFF, U+3400-U+4DBF) counts as one token.
        The remaining text is split on whitespace and each word counts as
        1.3 tokens. The total is truncated to an int.

        Args:
            text: Text to measure.

        Returns:
            The estimated token count; ``0`` for empty text.
        """
        if not text:
            return 0

        cjk_chars = len(self._CJK_CHAR_RE.findall(text))
        # Blank out CJK characters so they act as word boundaries.
        non_cjk = self._CJK_CHAR_RE.sub(" ", text)
        words = len(non_cjk.split())

        return int(cjk_chars + words * self._TOKENS_PER_WORD)
|