geo/backend/app/services/knowledge/chunker.py

179 lines
6.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
RecursiveChunker: 递归语义分块器
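按优先级分隔符(段落 → 句子 → 词)将文档切割为适合 embedding 的块。
Splits documents into embedding-sized chunks using separators in priority order (paragraph → sentence → word).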
"""
import re
from typing import Optional
class RecursiveChunker:
    """Recursive semantic chunker.

    Splits a document into embedding-sized chunks by trying separators in
    priority order (paragraph > line > sentence punctuation > word). It falls
    back to a fixed-width character split when no separator applies. All
    token counts are heuristic estimates (see ``_estimate_tokens``), not the
    output of a real tokenizer.
    """

    # CJK Unified Ideographs (U+4E00-U+9FFF) plus Extension A (U+3400-U+4DBF).
    _CJK_PATTERN = re.compile(r"[\u4e00-\u9fff\u3400-\u4dbf]")

    def __init__(
        self,
        chunk_size: int = 512,
        chunk_overlap: int = 50,
        min_chunk_size: int = 100,
    ):
        """
        Args:
            chunk_size: maximum estimated tokens for a piece produced by the
                recursive split. Merging and overlap may push final chunks
                above this.
            chunk_overlap: approximate estimated tokens copied from the end of
                one chunk to the start of the next (<= 0 disables overlap).
            min_chunk_size: chunks estimated below this many tokens are merged
                into the following chunk.
        """
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.min_chunk_size = min_chunk_size
        # Separator priority: paragraph > line > sentence > word.
        # Full-width CJK punctuation is written as escapes so it survives tools
        # that strip "ambiguous" Unicode. Stripping previously turned these
        # entries into "", and str.split("") raises ValueError.
        self.separators = [
            "\n\n",
            "\n",
            "\u3002",  # ideographic full stop
            ".",
            "\uff01",  # full-width exclamation mark
            "!",
            "\uff1f",  # full-width question mark
            "?",
            "\uff1b",  # full-width semicolon
            ";",
            " ",
        ]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
        """Split ``text`` into chunks.

        Args:
            text: document text. Empty or whitespace-only input yields ``[]``.
            metadata: optional dict attached to every chunk. Each chunk gets
                its own shallow copy, so mutating one chunk's metadata does not
                affect the other chunks or the caller's dict.

        Returns:
            list of dicts:
                {
                    "content": str,
                    "chunk_index": int,   # 0-based position in the output
                    "token_count": int,   # heuristic estimate
                    "metadata": dict,
                }
        """
        if not text or not text.strip():
            return []

        raw_chunks = self._split_recursive(text.strip(), self.separators)
        # Merge undersized chunks and prepend overlap from the previous chunk.
        merged = self._merge_small_chunks(raw_chunks)

        return [
            {
                "content": content,
                "chunk_index": idx,
                "token_count": self._estimate_tokens(content),
                "metadata": dict(metadata) if metadata else {},
            }
            for idx, content in enumerate(merged)
        ]

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _split_recursive(self, text: str, separators: list[str]) -> list[str]:
        """Split ``text`` on the first separator, recursing into the remaining
        separators for any piece still larger than ``chunk_size``.

        Adjacent pieces are greedily re-joined with the separator while the
        result fits in ``chunk_size``. The separator is only re-inserted
        *between* joined pieces. A trailing separator, or one at a buffer
        boundary, is dropped (e.g. a final "." disappears).
        """
        if not separators:
            # Last resort: fixed-width character split.
            return self._hard_split(text)

        sep, remaining_seps = separators[0], separators[1:]
        if not sep:
            # str.split("") raises ValueError; skip empty separators.
            return self._split_recursive(text, remaining_seps)

        # Drop empty / whitespace-only pieces; the separator is re-inserted
        # when pieces are joined back together below.
        splits = [s for s in text.split(sep) if s.strip()]
        if len(splits) <= 1:
            # This separator does not split the text; try the next one.
            return self._split_recursive(text, remaining_seps)

        chunks: list[str] = []
        current_buffer = ""
        for piece in splits:
            candidate = (current_buffer + sep + piece).strip() if current_buffer else piece.strip()
            if self._estimate_tokens(candidate) <= self.chunk_size:
                current_buffer = candidate
                continue

            # Adding this piece would exceed chunk_size: flush the buffer.
            if current_buffer:
                if self._estimate_tokens(current_buffer) > self.chunk_size:
                    chunks.extend(self._split_recursive(current_buffer, remaining_seps))
                else:
                    chunks.append(current_buffer)

            # Start a new buffer with this piece, or split it further if it
            # is too large on its own.
            if self._estimate_tokens(piece) > self.chunk_size:
                chunks.extend(self._split_recursive(piece, remaining_seps))
                current_buffer = ""
            else:
                current_buffer = piece.strip()

        if current_buffer:
            if self._estimate_tokens(current_buffer) > self.chunk_size:
                chunks.extend(self._split_recursive(current_buffer, remaining_seps))
            else:
                chunks.append(current_buffer)

        return [c for c in chunks if c.strip()]

    def _merge_small_chunks(self, chunks: list[str]) -> list[str]:
        """Merge undersized chunks and add overlap between neighbours.

        A chunk estimated below ``min_chunk_size`` tokens is joined (with
        "\\n") to the chunk that follows it. Every chunk after the first is
        prefixed with the tail of the previous emitted chunk (see
        ``_get_overlap_prefix``). A too-small *final* chunk is kept as is.
        """
        if not chunks:
            return []

        merged: list[str] = []
        buffer = chunks[0]
        for chunk in chunks[1:]:
            if self._estimate_tokens(buffer) < self.min_chunk_size:
                buffer = buffer + "\n" + chunk
            else:
                merged.append(buffer)
                # Overlap: reuse the end of the previous chunk as a prefix.
                overlap_text = self._get_overlap_prefix(buffer)
                buffer = (overlap_text + "\n" + chunk).strip() if overlap_text else chunk
        merged.append(buffer)
        return [c for c in merged if c.strip()]

    def _get_overlap_prefix(self, text: str) -> str:
        """Return the tail of ``text`` to repeat at the start of the next chunk.

        Takes the last ``chunk_overlap / 1.5`` whitespace-separated words (at
        least one). CJK text often contains no whitespace, so that single
        "word" can be the whole chunk. In that case the tail is trimmed
        character-wise until its estimated token count fits ``chunk_overlap``.
        """
        if self.chunk_overlap <= 0:
            return ""
        words = text.split()
        if not words:
            return ""
        # Assume ~1.5 tokens per word for mixed Chinese/English text.
        token_per_word = 1.5
        overlap_words = int(self.chunk_overlap / token_per_word)
        overlap_words = max(1, min(overlap_words, len(words)))
        prefix = " ".join(words[-overlap_words:])
        if self._estimate_tokens(prefix) > self.chunk_overlap:
            prefix = self._trim_to_token_budget(prefix, self.chunk_overlap)
        return prefix

    def _trim_to_token_budget(self, text: str, budget: int) -> str:
        """Return the longest suffix of ``text`` whose estimated token count
        is <= ``budget`` (leading whitespace stripped).

        Binary search is valid because ``_estimate_tokens`` never decreases
        when the suffix grows by one character.
        """
        lo, hi = 0, len(text)
        while lo < hi:
            mid = (lo + hi + 1) // 2
            if self._estimate_tokens(text[-mid:]) <= budget:
                lo = mid
            else:
                hi = mid - 1
        return text[len(text) - lo:].strip()

    def _hard_split(self, text: str) -> list[str]:
        """Fixed-width character split (last resort).

        Windows are ``chunk_size * 2`` characters wide (rough heuristic:
        1 token ~ 2 characters for Chinese). Consecutive windows overlap by
        ``chunk_overlap * 2`` characters when that still makes forward
        progress.
        """
        char_limit = max(1, self.chunk_size * 2)
        # A negative overlap would skip characters; treat it as no overlap.
        overlap_chars = max(0, self.chunk_overlap * 2)
        chunks: list[str] = []
        start = 0
        while start < len(text):
            end = start + char_limit
            chunks.append(text[start:end])
            if end >= len(text):
                # Reached the end; another overlapping window would only
                # repeat text that was already emitted.
                break
            next_start = end - overlap_chars
            # Guarantee progress even when the overlap >= window width
            # (previously this case looped forever).
            start = next_start if next_start > start else end
        return chunks

    def _estimate_tokens(self, text: str) -> int:
        """Estimate the token count of ``text``.

        Heuristic: each CJK character counts as 1 token. Each remaining
        whitespace-separated word counts as 1.3 tokens (BPE fragmentation
        factor). The sum is truncated to ``int``.
        """
        if not text:
            return 0
        # Replace CJK chars with spaces; subn also reports how many there were.
        non_chinese, chinese_chars = self._CJK_PATTERN.subn(" ", text)
        english_words = len(non_chinese.split())
        return int(chinese_chars + english_words * 1.3)