193 lines
6.0 KiB
Python
193 lines
6.0 KiB
Python
"""LLM-based 问题生成 — 为文档段落生成相关问题,提升检索召回率。

参考 memory/contextual_retrieval.py 的 ContextualChunker 模式:
- 使用 LLM gateway 生成问题
- Prompt 模板化
- 失败时降级(返回空列表,不抛异常)

生成的问题可作为 chunk 的 metadata.questions 字段,在索引时与 chunk 内容
一同嵌入,提升"问题→段落"的检索召回率(HyDE 反向模式)。
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import hashlib
|
||
import logging
|
||
import re
|
||
|
||
from pydantic import BaseModel, ConfigDict
|
||
|
||
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class GeneratedQuestion(BaseModel):
    """One generated question, tagged with the chunk and document it came from."""

    # Explicit pydantic configuration object; no options are overridden here.
    model_config = ConfigDict()

    # Text of the generated question.
    question: str
    # Identifier of the chunk the question was generated for.
    chunk_id: str
    # Identifier of the document associated with that chunk.
    document_id: str
|
||
|
||
|
||
# Prompt template — instructs the LLM to produce questions that the given chunk
# can answer. Placeholders: ``{n}`` (number of questions) and ``{chunk}`` (the
# chunk text). Built from one literal per output line; the resulting string ends
# with a single trailing newline.
QUESTION_GEN_PROMPT_TEMPLATE = (
    "请为以下文本片段生成 {n} 个问题,要求:\n"
    "1. 每个问题都能直接从该片段中找到答案\n"
    "2. 问题应覆盖片段中的关键信息点\n"
    "3. 问题应简洁、自然,符合用户提问习惯\n"
    "4. 每行一个问题,不要编号,不要前缀\n"
    "\n"
    "文本片段:\n"
    "{chunk}\n"
    "\n"
    "问题(每行一个):\n"
)
|
||
|
||
|
||
# ponytail: LLM output is parsed into a question list by splitting on lines and
# dropping blank lines and numbering prefixes; this pattern matches one such
# prefix ("1. ", "2) ", "3、", "4] ") at the start of a line.
# A dot only counts as a numbering separator when it is NOT followed by a digit,
# so a question that starts with a decimal number (e.g. "3.5 ...") keeps its
# integer part instead of being truncated to "5 ...".
# Upgrade path: if the LLM output format proves unstable, switch to JSON mode
# (ask the LLM to emit a JSON array) instead of line-based parsing.
_NUMBER_PREFIX_RE = re.compile(r"^\s*\d+(?:\.(?!\d)|[\)、\]])\s*")
|
||
|
||
|
||
class QuestionGenerator:
    """Generate retrieval-oriented questions for document chunks via an LLM.

    Each chunk's text is inserted into ``QUESTION_GEN_PROMPT_TEMPLATE`` and sent
    to the gateway; the reply is split into one question per line. Failures are
    logged and degrade to an empty result instead of raising.

    Args:
        llm_gateway: Gateway object exposing
            ``async chat(messages=..., model=...)`` whose result has a string
            ``content`` attribute. When falsy/``None``, generation is skipped.
        max_questions_per_chunk: Upper bound on the questions kept per chunk.
        model: Model name forwarded to the gateway (defaults to ``"default"``).
        cache: Whether to memoise results per chunk content so the same text
            does not trigger a second LLM call.
    """

    # Chunks longer than this many characters are truncated before prompting,
    # to keep the prompt from growing too long.
    _MAX_PROMPT_CHARS = 2000

    def __init__(
        self,
        llm_gateway: object | None = None,
        max_questions_per_chunk: int = 3,
        model: str = "default",
        cache: bool = True,
    ) -> None:
        self._llm_gateway = llm_gateway
        self._max_questions = max_questions_per_chunk
        self._model = model
        self._cache_enabled = cache
        # Maps content hash -> parsed questions of a *successful* LLM call.
        self._cache: dict[str, list[str]] = {}

    async def generate(
        self,
        chunks: list[dict[str, object]],
        document_context: str = "",
    ) -> list[GeneratedQuestion]:
        """Generate questions for every chunk that has non-blank content.

        Args:
            chunks: Chunk dictionaries; the ``id``, ``content`` and
                ``document_id`` keys are read, each defaulting to ``""`` when
                missing or ``None``.
            document_context: Full document text (optional); it is passed
                through to the per-chunk step, which currently does not use it.

        Returns:
            The generated questions in chunk order. Empty when ``chunks`` is
            empty, when no gateway is configured, or when every call fails.
        """
        if not chunks:
            return []

        if not self._llm_gateway:
            logger.info("No LLM gateway configured, skipping question generation")
            return []

        results: list[GeneratedQuestion] = []
        for chunk in chunks:
            content = self._text_field(chunk, "content")
            if not content.strip():
                continue

            chunk_id = self._text_field(chunk, "id")
            document_id = self._text_field(chunk, "document_id")
            questions = await self._generate_for_chunk(content, document_context)
            results.extend(
                GeneratedQuestion(
                    question=question,
                    chunk_id=chunk_id,
                    document_id=document_id,
                )
                for question in questions
            )

        return results

    @staticmethod
    def _text_field(chunk: dict[str, object], key: str) -> str:
        """Return ``chunk[key]`` as text, mapping a missing or ``None`` value to ``""``."""
        value = chunk.get(key)
        # A plain str(None) would yield the literal text "None", which would then
        # be sent to the LLM or stored as an identifier.
        return "" if value is None else str(value)

    async def _generate_for_chunk(
        self,
        chunk_content: str,
        document_context: str,
    ) -> list[str]:
        """Generate questions for a single chunk.

        Args:
            chunk_content: Text of the chunk.
            document_context: Full document text (currently unused; kept for
                future extension).

        Returns:
            At most ``max_questions_per_chunk`` question strings. An empty list
            is returned when the limit is not positive or the LLM call fails.
        """
        limit = self._max_questions
        if limit <= 0:
            # Nothing may be kept, so do not spend an LLM call.
            return []

        cache_key = self._make_cache_key(chunk_content)
        if self._cache_enabled:
            cached = self._cache.get(cache_key)
            if cached is not None:
                # Hand out a copy so callers cannot mutate the cached entry.
                return list(cached)

        prompt = QUESTION_GEN_PROMPT_TEMPLATE.format(
            n=limit,
            chunk=chunk_content[: self._MAX_PROMPT_CHARS],
        )

        try:
            response = await self._llm_gateway.chat(
                messages=[{"role": "user", "content": prompt}],
                model=self._model,
            )
            parsed = self._parse_questions(response.content.strip())
        except Exception as e:
            # Best-effort degrade. The failure is deliberately NOT cached, so a
            # transient gateway error does not suppress later attempts.
            logger.warning("Question generation failed for chunk: %s", e)
            return []

        # The LLM may return more lines than requested; enforce the upper bound.
        questions = parsed[:limit]

        if self._cache_enabled:
            self._cache[cache_key] = list(questions)

        return questions

    @staticmethod
    def _parse_questions(raw_output: str) -> list[str]:
        """Split raw LLM output into questions, one per non-empty line.

        Surrounding whitespace and a leading numbering prefix (e.g. "1. ",
        "2) ", "3、") are removed from each line; lines that end up empty are
        dropped. No limit on the number of questions is applied here.

        Args:
            raw_output: Raw text returned by the LLM.

        Returns:
            The cleaned question strings in their original order.
        """
        questions: list[str] = []
        for raw_line in raw_output.splitlines():
            # Strip first so the start-anchored prefix pattern sees the marker.
            cleaned = _NUMBER_PREFIX_RE.sub("", raw_line.strip()).strip()
            if cleaned:
                questions.append(cleaned)
        return questions

    @staticmethod
    def _make_cache_key(chunk_content: str) -> str:
        """Return the cache key: first 16 hex digits of the content's SHA-256."""
        # "surrogatepass" gives the same bytes as strict UTF-8 for well-formed
        # text, but does not raise on lone surrogates in malformed input.
        encoded = chunk_content.encode("utf-8", "surrogatepass")
        return hashlib.sha256(encoded).hexdigest()[:16]

    def clear_cache(self) -> None:
        """Drop all cached question lists."""
        self._cache.clear()
|
||
|
||
|
||
# Public names exported by a star-import of this module.
__all__ = ["GeneratedQuestion", "QuestionGenerator"]
|