geo/backend/app/services/knowledge/chunker.py

212 lines
6.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
分块策略 - 支持多种分块方式
"""
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional
@dataclass
class ChunkStrategy:
    """Configuration record describing one chunking strategy.

    All size fields are expressed in characters.
    """

    # Short identifier of the strategy.
    name: str
    # Human-readable summary of how the strategy splits text.
    description: str
    # Target chunk size, in characters.
    chunk_size: int
    # Number of characters shared between consecutive chunks.
    chunk_overlap: int
    # Size threshold, in characters, for keeping a chunk.
    min_chunk_size: int
class BaseChunker(ABC):
    """Abstract base class for chunkers.

    Concrete subclasses assign ``STRATEGY`` and implement :meth:`chunk`.
    """

    # Strategy configuration of the chunker. It is ``None`` on the abstract
    # base, hence the Optional annotation; the forward reference is a string
    # so the annotation is not evaluated when the class body runs.
    STRATEGY: Optional["ChunkStrategy"] = None

    @abstractmethod
    def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
        """Split *text* into chunks.

        Args:
            text: The text to split.
            metadata: Optional metadata supplied by the caller.

        Returns:
            A list of chunk dicts. :meth:`preview` reads the ``"content"``
            key of each one.
        """

    def preview(self, text: str, max_chunks: int = 5) -> list[str]:
        """Return shortened previews of the first chunks of *text*.

        Args:
            text: The text to chunk.
            max_chunks: Maximum number of previews to return. Zero or a
                negative value yields an empty list.

        Returns:
            One string per chunk; content longer than 200 characters is cut
            to 200 characters and suffixed with ``"..."``.
        """
        # Clamp so a negative value cannot turn into a "drop the last k" slice.
        limit = max(max_chunks, 0)
        previews = []
        for item in self.chunk(text)[:limit]:
            content = item["content"]
            if len(content) > 200:
                previews.append(content[:200] + "...")
            else:
                previews.append(content)
        return previews
class RecursiveChunker(BaseChunker):
    """Recursive chunker: split on paragraphs, then on finer separators.

    The text is first cut into paragraphs. A paragraph longer than
    ``STRATEGY.chunk_size`` is broken down further by walking ``SEPARATORS``
    from coarse to fine and, as a last resort, by hard slicing. The resulting
    pieces are packed greedily into chunks, and every chunk after the first
    starts with the tail of the text that preceded it (the overlap).
    """

    STRATEGY = ChunkStrategy(
        name="recursive",
        description="优先按段落分割,过长时按句子分割",
        chunk_size=500,
        chunk_overlap=50,
        min_chunk_size=50,
    )

    # Split patterns in priority order, coarsest first.
    SEPARATORS = [
        r"\n\n+",  # blank line(s): paragraph break
        r"\n",  # single newline
        r"[。!?!?]\s*",  # end of sentence
        r"[,;]\s*",  # clause break
        r"\s+",  # whitespace
    ]

    def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
        """Split *text* into overlapping chunks.

        Args:
            text: The text to split.
            metadata: Optional metadata; a shallow copy is attached to every
                chunk so that editing one chunk's metadata does not affect
                the others.

        Returns:
            A list of dicts with the keys ``"content"`` (stripped chunk
            text), ``"chunk_index"`` (position in the list) and
            ``"metadata"``. A text whose stripped length is below
            ``STRATEGY.min_chunk_size`` yields an empty list.
        """
        metadata = metadata or {}
        size = self.STRATEGY.chunk_size
        min_size = self.STRATEGY.min_chunk_size
        overlap_len = self.STRATEGY.chunk_overlap

        chunks: list[dict] = []
        current = ""
        for unit in self._split_into_units(text):
            # Trailing whitespace of the unit is stripped from the final
            # content, so it does not count against the size budget.
            if len(current) + len(unit.rstrip()) <= size:
                current += unit
                continue

            if len(current.strip()) >= min_size:
                chunks.append(self._build_chunk(current, len(chunks), metadata))
                # A plain [-0:] slice would return the whole string.
                carried = current[-overlap_len:] if overlap_len > 0 else ""
            else:
                # Too small to stand alone: keep all of it so nothing is lost.
                carried = current
            current = carried + unit

        remainder = current.strip()
        # Once earlier chunks exist, the remainder always holds text that no
        # chunk contains yet, so it is kept even when it is short.
        if remainder and (chunks or len(remainder) >= min_size):
            chunks.append(self._build_chunk(current, len(chunks), metadata))
        return chunks

    def _split_into_units(self, text: str) -> list[str]:
        """Return the pieces of *text* that are packed into chunks."""
        units: list[str] = []
        for paragraph in re.split(self.SEPARATORS[0], text):
            if not paragraph.strip():
                continue
            fragments = self._split_oversized(paragraph, 1)
            # Paragraph breaks are normalised to one blank line.
            fragments[-1] += "\n\n"
            units.extend(fragments)
        return units

    def _split_oversized(self, text: str, level: int) -> list[str]:
        """Split *text* with SEPARATORS[level:] until pieces fit chunk_size."""
        limit = max(1, self.STRATEGY.chunk_size)
        if len(text) <= limit:
            return [text]
        if level >= len(self.SEPARATORS):
            # No separator left: fall back to fixed-width slices.
            return [text[i:i + limit] for i in range(0, len(text), limit)]

        fragments: list[str] = []
        for piece in self._split_keep_separator(text, self.SEPARATORS[level]):
            fragments.extend(self._split_oversized(piece, level + 1))
        return fragments

    @staticmethod
    def _split_keep_separator(text: str, pattern: str) -> list[str]:
        """Cut *text* after each match of *pattern*, keeping the separator."""
        pieces: list[str] = []
        start = 0
        for match in re.finditer(pattern, text):
            # Skip zero-width matches; they would produce empty pieces.
            if match.end() > match.start() and match.end() > start:
                pieces.append(text[start:match.end()])
                start = match.end()
        if start < len(text):
            pieces.append(text[start:])
        return pieces

    @staticmethod
    def _build_chunk(raw: str, index: int, metadata: dict) -> dict:
        """Build the chunk dict for the raw buffer *raw*."""
        return {
            "content": raw.strip(),
            "chunk_index": index,
            "metadata": dict(metadata),
        }
class SemanticChunker(BaseChunker):
    """Semantic chunker: splits at heading-like lines.

    The text is scanned line by line. A line matching one of
    ``SEMANTIC_PATTERNS`` closes the chunk being built and starts a new one;
    a chunk is also closed once its buffer reaches ``STRATEGY.chunk_size``
    characters. Every new chunk starts with the tail of the previous buffer
    (the overlap). ``STRATEGY.min_chunk_size`` is not applied here.
    """

    STRATEGY = ChunkStrategy(
        name="semantic",
        description="根据语义边界(标题、段落)自动分块",
        chunk_size=800,
        chunk_overlap=100,
        min_chunk_size=100,
    )

    # Patterns that mark a semantic boundary, matched against stripped lines.
    SEMANTIC_PATTERNS = [
        (r"^#{1,6}\s+(.+)$", "heading"),  # Markdown heading
        (r"^【(.+?)】\s*$", "heading"),  # bracketed Chinese heading
        (r"^第[一二三四五六七八九十百]+[章节条]", "heading"),  # chapter/section/article heading
        (r"^(\d+\.)+\s+", "heading"),  # numbered heading
    ]

    def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
        """Split *text* into chunks at semantic boundaries.

        Args:
            text: The text to split.
            metadata: Optional metadata; a shallow copy is attached to every
                chunk so that editing one chunk's metadata does not affect
                the others.

        Returns:
            A list of dicts with the keys ``"content"`` (stripped chunk
            text), ``"chunk_index"`` (position in the list), ``"section"``
            (the most recent boundary line seen when the chunk's text was
            collected, or ``None`` before the first one) and ``"metadata"``.
        """
        metadata = metadata or {}
        size = self.STRATEGY.chunk_size

        chunks: list[dict] = []
        buffer = ""
        # True once the buffer holds a non-blank line that no chunk contains
        # yet. A buffer made only of carried-over overlap is never emitted,
        # because it would just duplicate the end of the previous chunk.
        has_new_text = False
        section = None

        for line in text.split("\n"):
            stripped = line.strip()

            if self._is_boundary(stripped):
                # Close the running chunk under the section it belongs to
                # BEFORE switching to the new heading.
                if has_new_text:
                    chunks.append(
                        self._build_chunk(buffer, len(chunks), section, metadata)
                    )
                    has_new_text = False
                buffer = self._overlap_tail(buffer)
                section = stripped

            buffer += line + "\n"
            if stripped:
                has_new_text = True

            # Size limit reached: close the chunk and carry the overlap.
            if len(buffer) >= size:
                if has_new_text:
                    chunks.append(
                        self._build_chunk(buffer, len(chunks), section, metadata)
                    )
                    has_new_text = False
                buffer = self._overlap_tail(buffer)

        # Emit whatever is left, unless it is overlap only.
        if has_new_text:
            chunks.append(self._build_chunk(buffer, len(chunks), section, metadata))
        return chunks

    def _is_boundary(self, stripped_line: str) -> bool:
        """Return True when the line matches one of SEMANTIC_PATTERNS."""
        return any(
            re.match(pattern, stripped_line)
            for pattern, _kind in self.SEMANTIC_PATTERNS
        )

    def _overlap_tail(self, buffer: str) -> str:
        """Return the part of *buffer* carried into the next chunk."""
        overlap_len = self.STRATEGY.chunk_overlap
        # A plain [-0:] slice would return the whole buffer.
        return buffer[-overlap_len:] if overlap_len > 0 else ""

    @staticmethod
    def _build_chunk(raw: str, index: int, section: Optional[str], metadata: dict) -> dict:
        """Build the chunk dict for the raw buffer *raw*."""
        return {
            "content": raw.strip(),
            "chunk_index": index,
            "section": section,
            "metadata": dict(metadata),
        }
class FixedLengthChunker(BaseChunker):
    """Fixed-length chunker: slides a window of fixed size over the text."""

    STRATEGY = ChunkStrategy(
        name="fixed",
        description="按固定长度强制分块",
        chunk_size=300,
        chunk_overlap=30,
        min_chunk_size=50,
    )

    def chunk(self, text: str, metadata: Optional[dict] = None) -> list[dict]:
        """Split *text* into windows of ``STRATEGY.chunk_size`` characters.

        Runs of whitespace are collapsed to a single space first. Consecutive
        windows start ``chunk_size - chunk_overlap`` characters apart.

        Args:
            text: The text to split.
            metadata: Optional metadata; a shallow copy is attached to every
                chunk so that editing one chunk's metadata does not affect
                the others.

        Returns:
            A list of dicts with the keys ``"content"`` (stripped window
            text), ``"chunk_index"`` (position in the list) and
            ``"metadata"``. A final window shorter than
            ``STRATEGY.min_chunk_size`` is folded into the previous chunk
            rather than dropped. A text shorter than ``min_chunk_size``
            yields an empty list.
        """
        metadata = metadata or {}
        size = self.STRATEGY.chunk_size
        min_size = self.STRATEGY.min_chunk_size
        # Keep the stride positive; range() rejects 0 and a negative stride
        # would iterate nothing.
        step = max(1, size - self.STRATEGY.chunk_overlap)

        # Collapse redundant whitespace.
        text = re.sub(r"\s+", " ", text)
        total = len(text)

        chunks: list[dict] = []
        last_start = 0  # start offset of the most recently emitted chunk
        covered_end = 0  # end offset (exclusive) of the text emitted so far
        for start in range(0, total, step):
            end = min(start + size, total)
            content = text[start:end].strip()
            if len(content) >= min_size:
                chunks.append({
                    "content": content,
                    "chunk_index": len(chunks),
                    "metadata": dict(metadata),
                })
                last_start, covered_end = start, end
            elif chunks and end > covered_end:
                # Short window holding text no chunk covers yet: extend the
                # previous chunk so that text is kept.
                chunks[-1]["content"] = text[last_start:end].strip()
                covered_end = end
            if end >= total:
                # Any further window would only repeat covered text.
                break
        return chunks
class ChunkerFactory:
    """Factory that maps strategy names to chunker classes."""

    # Registry of the available chunkers, keyed by strategy name.
    STRATEGIES = {
        "recursive": RecursiveChunker,
        "semantic": SemanticChunker,
        "fixed": FixedLengthChunker,
    }

    @classmethod
    def create(cls, strategy: str = "recursive") -> BaseChunker:
        """Instantiate the chunker registered under *strategy*.

        A name that is not registered falls back to ``RecursiveChunker``.
        """
        try:
            selected = cls.STRATEGIES[strategy]
        except KeyError:
            selected = RecursiveChunker
        return selected()

    @classmethod
    def list_strategies(cls) -> list[ChunkStrategy]:
        """Return the ``STRATEGY`` of every registered chunker class."""
        strategies = []
        for implementation in cls.STRATEGIES.values():
            strategies.append(implementation.STRATEGY)
        return strategies