212 lines
6.9 KiB
Python
212 lines
6.9 KiB
Python
"""
|
||
Chunking strategies - supports several ways of splitting text into chunks.
|
||
"""
|
||
import re
|
||
from abc import ABC, abstractmethod
|
||
from dataclasses import dataclass
|
||
from typing import Optional
|
||
|
||
@dataclass
class ChunkStrategy:
    """Configuration record describing one chunking strategy.

    All sizes are expressed in characters.
    """

    # Short identifier of the strategy.
    name: str
    # Human-readable summary of how the strategy splits text.
    description: str
    # Target chunk length, in characters.
    chunk_size: int
    # Number of characters repeated between consecutive chunks.
    chunk_overlap: int
    # Length threshold that chunkers compare candidate chunks against.
    min_chunk_size: int
|
||
|
||
class BaseChunker(ABC):
    """Abstract base class for chunkers."""

    # Strategy configuration; None on the base class itself.
    STRATEGY: "Optional[ChunkStrategy]" = None

    @abstractmethod
    def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
        """Split ``text`` into chunks.

        Implementations return one dict per chunk; ``preview`` reads the
        ``"content"`` key of each of them.
        """

    def preview(self, text: str, max_chunks: int = 5) -> list[str]:
        """Return the first ``max_chunks`` chunk bodies, shortened for display.

        A body longer than 200 characters is cut to its first 200 characters
        and suffixed with ``"..."``; shorter bodies are returned unchanged.
        """
        previews = []
        for item in self.chunk(text)[:max_chunks]:
            body = item["content"]
            if len(body) > 200:
                body = body[:200] + "..."
            previews.append(body)
        return previews
|
||
|
||
class RecursiveChunker(BaseChunker):
    """Recursive chunker: packs paragraphs into chunks, splitting over-long ones.

    The text is first split on blank lines. A paragraph longer than
    ``STRATEGY.chunk_size`` is broken down with progressively finer
    ``SEPARATORS`` (newline, sentence end, clause break, whitespace) and, as a
    last resort, sliced at fixed positions. The resulting pieces are then
    packed greedily into chunks, with the tail of each emitted chunk repeated
    at the start of the next one.
    """

    STRATEGY = ChunkStrategy(
        name="recursive",
        description="优先按段落分割,过长时按句子分割",
        chunk_size=500,
        chunk_overlap=50,
        min_chunk_size=50,
    )

    # Split patterns, ordered from coarsest to finest.
    SEPARATORS = [
        r"\n\n+",  # blank line(s): paragraph break
        r"\n",  # single newline
        r"[。！？!?]\s*",  # end of sentence (full-width and ASCII marks)
        r"[，,；;]\s*",  # clause break (full-width and ASCII marks)
        r"\s+",  # whitespace
    ]

    def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
        """Split ``text`` into overlapping chunks.

        Args:
            text: The text to split.
            metadata: Optional mapping attached to every chunk; each chunk
                receives its own shallow copy.

        Returns:
            A list of dicts with the keys ``"content"`` (stripped chunk text),
            ``"chunk_index"`` (position in the returned list) and
            ``"metadata"``. A final remainder shorter than
            ``STRATEGY.min_chunk_size`` is not emitted.
        """
        strategy = self.STRATEGY
        overlap = strategy.chunk_overlap
        metadata = metadata or {}

        # Paragraphs first; anything longer than chunk_size is broken down
        # further so that no single piece exceeds the size limit.
        pieces = []
        for paragraph in re.split(self.SEPARATORS[0], text):
            if paragraph.strip():
                pieces.extend(self._split_oversized(paragraph, 1))

        chunks = []
        buffer = ""
        for piece in pieces:
            if len(buffer) + len(piece) <= strategy.chunk_size:
                buffer += piece + "\n\n"
                continue

            body = buffer.strip()
            if len(body) >= strategy.min_chunk_size:
                chunks.append({
                    "content": body,
                    "chunk_index": len(chunks),
                    "metadata": dict(metadata),
                })
                # buffer[-0:] would be the whole buffer, so a zero overlap
                # has to be handled explicitly.
                buffer = buffer[-overlap:] if overlap > 0 else ""
            # Otherwise the buffer is too short to emit on its own: keep all
            # of it so that its text ends up in the next chunk instead of
            # being dropped.
            buffer += piece + "\n\n"

        body = buffer.strip()
        if len(body) >= strategy.min_chunk_size:
            chunks.append({
                "content": body,
                "chunk_index": len(chunks),
                "metadata": dict(metadata),
            })

        return chunks

    def _split_oversized(self, text: str, level: int) -> list[str]:
        """Break ``text`` into pieces of at most ``chunk_size`` characters.

        ``SEPARATORS[level]`` is tried first and finer patterns are used
        recursively for fragments that are still too long. Separators stay
        attached to the fragment they terminate, so concatenating the returned
        pieces reproduces ``text`` exactly.
        """
        limit = max(1, self.STRATEGY.chunk_size)
        if len(text) <= limit:
            return [text]
        if level >= len(self.SEPARATORS):
            # No separator left to try: fall back to fixed-position slicing.
            return [text[i:i + limit] for i in range(0, len(text), limit)]

        # Cut right after every separator match.
        fragments = []
        start = 0
        for match in re.finditer(self.SEPARATORS[level], text):
            if match.end() > start:  # ignore empty matches
                fragments.append(text[start:match.end()])
                start = match.end()
        if start < len(text):
            fragments.append(text[start:])

        # Greedily merge neighbouring fragments back together up to the limit.
        pieces = []
        packed = ""
        for fragment in fragments:
            if len(packed) + len(fragment) <= limit:
                packed += fragment
                continue
            if packed:
                pieces.append(packed)
            if len(fragment) > limit:
                finer = self._split_oversized(fragment, level + 1)
                pieces.extend(finer[:-1])
                packed = finer[-1]  # the tail may still absorb later fragments
            else:
                packed = fragment
        if packed:
            pieces.append(packed)
        return pieces
|
||
|
||
class SemanticChunker(BaseChunker):
    """Semantic chunker: starts a new chunk at heading-like lines.

    The text is processed line by line. A line matching one of
    ``SEMANTIC_PATTERNS`` closes the chunk collected so far; a chunk is also
    closed once the buffered text reaches ``STRATEGY.chunk_size`` characters.
    The tail of every closed chunk is repeated at the start of the next one.
    """

    STRATEGY = ChunkStrategy(
        name="semantic",
        description="根据语义边界(标题、段落)自动分块",
        chunk_size=800,
        chunk_overlap=100,
        min_chunk_size=100,
    )

    # Patterns marking a semantic boundary, matched against the stripped line.
    SEMANTIC_PATTERNS = [
        (r"^#{1,6}\s+(.+)$", "heading"),  # Markdown heading
        (r"^【(.+?)】\s*$", "heading"),  # bracketed Chinese heading
        (r"^第[一二三四五六七八九十百]+[章节条]", "heading"),  # chapter / section / article heading
        (r"^(\d+\.)+\s+", "heading"),  # dotted numeric heading
    ]

    def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
        """Split ``text`` into chunks along heading lines and the size limit.

        Args:
            text: The text to split.
            metadata: Optional mapping attached to every chunk; each chunk
                receives its own shallow copy.

        Returns:
            A list of dicts with the keys ``"content"`` (stripped chunk text),
            ``"chunk_index"`` (position in the returned list), ``"section"``
            (the stripped heading line in effect for the chunk's text, or
            ``None`` before the first heading) and ``"metadata"``.
        """
        strategy = self.STRATEGY
        overlap = strategy.chunk_overlap
        metadata = metadata or {}

        chunks = []
        buffer = ""
        section = None
        # True once the buffer holds non-blank text beyond the carried-over
        # overlap; prevents emitting a chunk that only repeats overlap text.
        has_new_text = False

        def flush() -> None:
            nonlocal buffer, has_new_text
            chunks.append({
                "content": buffer.strip(),
                "chunk_index": len(chunks),
                "section": section,
                "metadata": dict(metadata),
            })
            # buffer[-0:] would be the whole buffer, so a zero overlap has to
            # be handled explicitly.
            buffer = buffer[-overlap:] if overlap > 0 else ""
            has_new_text = False

        for line in text.split("\n"):
            stripped = line.strip()
            is_boundary = any(
                re.match(pattern, stripped)
                for pattern, _ in self.SEMANTIC_PATTERNS
            )

            if is_boundary:
                # Close the previous chunk BEFORE switching sections so it is
                # labelled with the heading its text actually belongs to.
                if has_new_text:
                    flush()
                section = stripped

            buffer += line + "\n"
            if stripped:
                has_new_text = True

            # Size limit reached: close the chunk and keep only the overlap.
            if has_new_text and len(buffer) >= strategy.chunk_size:
                flush()

        # Emit whatever is left, unless it is nothing but carried overlap.
        if has_new_text:
            flush()

        return chunks
|
||
|
||
class FixedLengthChunker(BaseChunker):
    """Fixed-length chunker: slides a fixed-size window over the text."""

    STRATEGY = ChunkStrategy(
        name="fixed",
        description="按固定长度强制分块",
        chunk_size=300,
        chunk_overlap=30,
        min_chunk_size=50,
    )

    def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
        """Cut ``text`` into windows of ``chunk_size`` characters.

        Runs of whitespace are collapsed to a single space first. Consecutive
        windows start ``chunk_size - chunk_overlap`` characters apart.

        Args:
            text: The text to split.
            metadata: Optional mapping attached to every chunk; each chunk
                receives its own shallow copy.

        Returns:
            A list of dicts with the keys ``"content"`` (stripped window
            text), ``"chunk_index"`` (position in the returned list) and
            ``"metadata"``. Windows whose stripped text is shorter than
            ``STRATEGY.min_chunk_size`` are skipped.

        Raises:
            ValueError: If ``chunk_overlap`` is not smaller than
                ``chunk_size``, because the window could then never advance.
        """
        strategy = self.STRATEGY
        size = strategy.chunk_size
        step = size - strategy.chunk_overlap
        if step <= 0:
            raise ValueError("chunk_overlap must be smaller than chunk_size")

        metadata = metadata or {}

        # Collapse redundant whitespace.
        text = re.sub(r"\s+", " ", text)

        chunks = []
        for start in range(0, len(text), step):
            window = text[start:start + size].strip()
            if len(window) >= strategy.min_chunk_size:
                chunks.append({
                    "content": window,
                    "chunk_index": len(chunks),
                    "metadata": dict(metadata),
                })
            # This window already reached the end of the text; any further
            # window would contain nothing but overlap that was just emitted.
            if start + size >= len(text):
                break

        return chunks
|
||
|
||
class ChunkerFactory:
    """Factory that maps strategy names to chunker classes."""

    # Registry of the available chunkers, keyed by strategy name.
    STRATEGIES = {
        "recursive": RecursiveChunker,
        "semantic": SemanticChunker,
        "fixed": FixedLengthChunker,
    }

    @classmethod
    def create(cls, strategy: str = "recursive") -> BaseChunker:
        """Instantiate the chunker registered under ``strategy``.

        A name that is not in ``STRATEGIES`` yields a ``RecursiveChunker``.
        """
        if strategy in cls.STRATEGIES:
            return cls.STRATEGIES[strategy]()
        return RecursiveChunker()

    @classmethod
    def list_strategies(cls) -> list[ChunkStrategy]:
        """Return the ``STRATEGY`` of every registered chunker, in registry order."""
        configs = []
        for name in cls.STRATEGIES:
            configs.append(cls.STRATEGIES[name].STRATEGY)
        return configs