fischer-agentkit/src/agentkit/rag_platform/question_gen.py

193 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LLM-based 问题生成 — 为文档段落生成相关问题,提升检索召回率。
参考 memory/contextual_retrieval.py 的 ContextualChunker 模式:
- 使用 LLM gateway 生成问题
- Prompt 模板化
- 失败时降级(返回空列表,不抛异常)
生成的问题可作为 chunk 的 metadata.questions 字段,在索引时与 chunk 内容
一同嵌入,提升"问题→段落"的检索召回率HyDE 反向模式)。
"""
from __future__ import annotations
import hashlib
import logging
import re
from pydantic import BaseModel, ConfigDict
logger = logging.getLogger(__name__)
class GeneratedQuestion(BaseModel):
"""生成的问题条目。"""
model_config = ConfigDict()
question: str
chunk_id: str
document_id: str
# Prompt 模板 — 指示 LLM 为给定 chunk 生成可被该 chunk 回答的问题
QUESTION_GEN_PROMPT_TEMPLATE = """\
请为以下文本片段生成 {n} 个问题,要求:
1. 每个问题都能直接从该片段中找到答案
2. 问题应覆盖片段中的关键信息点
3. 问题应简洁、自然,符合用户提问习惯
4. 每行一个问题,不要编号,不要前缀
文本片段:
{chunk}
问题(每行一个):
"""
# ponytail: 解析 LLM 输出为问题列表 — 按行切分,过滤空行和编号前缀
# 升级路径:若 LLM 输出格式不稳定,可改用 JSON 模式(要求 LLM 输出 JSON 数组)
_NUMBER_PREFIX_RE = re.compile(r"^\s*\d+[\.\)、\]]\s*")
class QuestionGenerator:
"""问题生成器 — 为每个 chunk 生成相关问题。
Args:
llm_gateway: LLM Gateway 实例(需实现 async chat(messages, model) -> response
max_questions_per_chunk: 每个 chunk 生成的问题数上限
model: LLM 模型名(默认 "default"
cache: 是否启用缓存(避免对同一 chunk 重复调用 LLM
"""
def __init__(
self,
llm_gateway: object | None = None,
max_questions_per_chunk: int = 3,
model: str = "default",
cache: bool = True,
) -> None:
self._llm_gateway = llm_gateway
self._max_questions = max_questions_per_chunk
self._model = model
self._cache_enabled = cache
self._cache: dict[str, list[str]] = {}
async def generate(
self,
chunks: list[dict[str, object]],
document_context: str = "",
) -> list[GeneratedQuestion]:
"""为每个 chunk 生成相关问题。
Args:
chunks: chunk 字典列表,每个字典需包含 id、content、document_id 字段
document_context: 完整文档内容(可选,用于提供额外上下文)
Returns:
生成的问题列表GeneratedQuestion。无 LLM 或失败时返回空列表。
"""
if not chunks:
return []
if not self._llm_gateway:
logger.info("No LLM gateway configured, skipping question generation")
return []
results: list[GeneratedQuestion] = []
for chunk in chunks:
chunk_id = str(chunk.get("id", ""))
content = str(chunk.get("content", ""))
document_id = str(chunk.get("document_id", ""))
if not content.strip():
continue
questions = await self._generate_for_chunk(content, document_context)
for q in questions:
results.append(
GeneratedQuestion(
question=q,
chunk_id=chunk_id,
document_id=document_id,
)
)
return results
async def _generate_for_chunk(
self,
chunk_content: str,
document_context: str,
) -> list[str]:
"""为单个 chunk 生成问题。
Args:
chunk_content: chunk 文本内容
document_context: 完整文档内容(当前未使用,保留以供未来扩展)
Returns:
问题字符串列表。失败时返回空列表。
"""
# 缓存检查
cache_key = self._make_cache_key(chunk_content)
if self._cache_enabled and cache_key in self._cache:
return self._cache[cache_key]
# 截断超长 chunk — 避免 prompt 过长
chunk_preview = chunk_content[:2000] if len(chunk_content) > 2000 else chunk_content
prompt = QUESTION_GEN_PROMPT_TEMPLATE.format(
n=self._max_questions,
chunk=chunk_preview,
)
try:
response = await self._llm_gateway.chat(
messages=[{"role": "user", "content": prompt}],
model=self._model,
)
content = response.content.strip()
questions = self._parse_questions(content)
except Exception as e:
logger.warning("Question generation failed for chunk: %s", e)
questions = []
if self._cache_enabled:
self._cache[cache_key] = questions
return questions
@staticmethod
def _parse_questions(raw_output: str) -> list[str]:
"""解析 LLM 输出为问题列表 — 按行切分,过滤空行和编号前缀。
Args:
raw_output: LLM 原始输出文本
Returns:
问题字符串列表(最多 max_questions 个)。
"""
lines: list[str] = []
for line in raw_output.splitlines():
line = line.strip()
if not line:
continue
# 去除编号前缀(如 "1. "、"2) "、"3、 "
line = _NUMBER_PREFIX_RE.sub("", line).strip()
if line:
lines.append(line)
return lines
@staticmethod
def _make_cache_key(chunk_content: str) -> str:
"""生成缓存键 — 基于 chunk 内容的 SHA256 哈希。"""
return hashlib.sha256(chunk_content.encode()).hexdigest()[:16]
def clear_cache(self) -> None:
"""清除问题生成缓存。"""
self._cache.clear()
__all__ = ["GeneratedQuestion", "QuestionGenerator"]