fischer-agentkit/tests/unit/rag_platform/test_question_gen.py

342 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U5 测试 — LLM-based 问题生成。
测试场景:
1. 无 LLM gateway 时返回空列表
2. 空 chunks 列表返回空列表
3. LLM 生成的问题被正确解析(按行切分,去除编号前缀)
4. 每个 chunk 生成的问题关联正确的 chunk_id 和 document_id
5. LLM 失败时该 chunk 返回空问题列表(不影响其他 chunk)
6. 缓存生效(同一 chunk 不重复调用 LLM)
7. 跳过空内容 chunk
8. _parse_questions 正确处理编号前缀和空行
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
from agentkit.rag_platform.question_gen import (
GeneratedQuestion,
QuestionGenerator,
_NUMBER_PREFIX_RE,
)
# ---------------------------------------------------------------------------
# 测试辅助函数
# ---------------------------------------------------------------------------
def _make_llm_response(content: str) -> MagicMock:
    """Build a mock LLM response object.

    Args:
        content: Text to expose as the response's ``content`` attribute.

    Returns:
        A ``MagicMock`` whose ``content`` attribute equals ``content``.
    """
    response = MagicMock()
    response.content = content
    return response
def _make_mock_llm_gateway(
    responses: list[str] | None = None,
) -> tuple[MagicMock, AsyncMock]:
    """Build a mock LLM gateway whose ``chat`` method is an ``AsyncMock``.

    Args:
        responses: Response contents returned by successive ``chat`` calls,
            in call order. When ``None``, ``chat`` always returns the same
            default three-line response.

    Returns:
        ``(mock_gateway, chat_mock)`` — the gateway mock and its ``chat``
        ``AsyncMock`` (the very same object as ``mock_gateway.chat``).
    """
    mock_gateway = MagicMock()
    mock_gateway.chat = AsyncMock()
    if responses is not None:
        # Queue one response per expected call, consumed in order.
        mock_gateway.chat.side_effect = [_make_llm_response(r) for r in responses]
    else:
        mock_gateway.chat.return_value = _make_llm_response("问题一\n问题二\n问题三")
    return mock_gateway, mock_gateway.chat
# ---------------------------------------------------------------------------
# GeneratedQuestion 模型测试
# ---------------------------------------------------------------------------
class TestGeneratedQuestion:
    """Tests for the ``GeneratedQuestion`` model."""

    def test_fields(self):
        """Constructor keyword arguments are exposed as same-named attributes."""
        q = GeneratedQuestion(
            question="什么是 RAG",
            chunk_id="c1",
            document_id="d1",
        )
        assert q.question == "什么是 RAG"
        assert q.chunk_id == "c1"
        assert q.document_id == "d1"
# ---------------------------------------------------------------------------
# _parse_questions 测试
# ---------------------------------------------------------------------------
class TestParseQuestions:
    """Tests for the ``QuestionGenerator._parse_questions`` static method."""

    def test_plain_lines(self):
        """Plain text is split into one question per line."""
        result = QuestionGenerator._parse_questions("问题一\n问题二\n问题三")
        assert result == ["问题一", "问题二", "问题三"]

    def test_numbered_lines(self):
        """``N.`` numbering prefixes are stripped."""
        result = QuestionGenerator._parse_questions("1. 问题一\n2. 问题二\n3. 问题三")
        assert result == ["问题一", "问题二", "问题三"]

    def test_paren_numbered_lines(self):
        """``N)`` numbering prefixes are stripped."""
        result = QuestionGenerator._parse_questions("1) 问题一\n2) 问题二")
        assert result == ["问题一", "问题二"]

    def test_chinese_numbered_lines(self):
        """Numbering prefixes using the Chinese enumeration comma (、) are stripped."""
        result = QuestionGenerator._parse_questions("1、 问题一\n2、 问题二")
        assert result == ["问题一", "问题二"]

    def test_empty_lines_filtered(self):
        """Empty and whitespace-only lines are dropped."""
        result = QuestionGenerator._parse_questions("问题一\n\n \n问题二")
        assert result == ["问题一", "问题二"]

    def test_empty_input(self):
        """An empty string yields an empty list."""
        assert QuestionGenerator._parse_questions("") == []

    def test_whitespace_only(self):
        """Whitespace-only input yields an empty list."""
        assert QuestionGenerator._parse_questions(" \n \n") == []
# ---------------------------------------------------------------------------
# QuestionGenerator 测试
# ---------------------------------------------------------------------------
class TestQuestionGeneratorNoLLM:
    """Tests for inputs where ``generate`` is expected to produce nothing."""

    async def test_no_llm_returns_empty(self):
        """Without an LLM gateway, ``generate`` returns an empty list."""
        gen = QuestionGenerator(llm_gateway=None)
        chunks = [{"id": "c1", "content": "内容", "document_id": "d1"}]
        result = await gen.generate(chunks)
        assert result == []

    async def test_empty_chunks_returns_empty(self):
        """An empty chunk list returns an empty list without awaiting the LLM."""
        mock_gw, _ = _make_mock_llm_gateway()
        gen = QuestionGenerator(llm_gateway=mock_gw)
        result = await gen.generate([])
        assert result == []
        mock_gw.chat.assert_not_awaited()
class TestQuestionGeneratorWithLLM:
    """Tests for ``QuestionGenerator`` when a (mocked) LLM gateway is present."""

    @staticmethod
    def _prompt_content(chat_mock):
        """Return the first message's content from the most recent ``chat`` await.

        ``messages`` is read from the first positional argument when one was
        passed, otherwise from the ``messages`` keyword, so the tests do not
        depend on how the generator passes it.
        """
        call_args = chat_mock.await_args
        messages = call_args.args[0] if call_args.args else call_args.kwargs["messages"]
        return messages[0]["content"]

    async def test_generates_questions_for_chunks(self):
        """Questions are generated for every chunk and tagged with its ids."""
        mock_gw, chat_mock = _make_mock_llm_gateway(["问题A1\n问题A2", "问题B1\n问题B2"])
        gen = QuestionGenerator(llm_gateway=mock_gw, max_questions_per_chunk=2)
        chunks = [
            {"id": "c1", "content": "内容A", "document_id": "d1"},
            {"id": "c2", "content": "内容B", "document_id": "d1"},
        ]
        result = await gen.generate(chunks)
        assert len(result) == 4
        # Questions from the first chunk.
        assert result[0].question == "问题A1"
        assert result[0].chunk_id == "c1"
        assert result[0].document_id == "d1"
        assert result[1].question == "问题A2"
        assert result[1].chunk_id == "c1"
        # Questions from the second chunk.
        assert result[2].question == "问题B1"
        assert result[2].chunk_id == "c2"
        assert result[3].question == "问题B2"
        assert result[3].chunk_id == "c2"
        # The LLM is awaited twice: once per chunk.
        assert chat_mock.await_count == 2

    async def test_questions_relate_to_chunk_content(self):
        """The prompt sent to the LLM embeds the chunk's content."""
        mock_gw, chat_mock = _make_mock_llm_gateway(["什么是 RAG"])
        gen = QuestionGenerator(llm_gateway=mock_gw, max_questions_per_chunk=1)
        chunks = [
            {
                "id": "c1",
                "content": "RAG 是检索增强生成的缩写,结合了检索和生成模型。",
                "document_id": "d1",
            }
        ]
        result = await gen.generate(chunks)
        assert len(result) == 1
        assert result[0].question == "什么是 RAG"
        # The prompt must contain the chunk's text.
        prompt_content = self._prompt_content(chat_mock)
        assert "RAG" in prompt_content
        assert "检索增强生成" in prompt_content

    async def test_llm_failure_returns_empty_for_chunk(self):
        """An LLM failure yields no questions for that chunk only."""
        mock_gw = MagicMock()
        mock_gw.chat = AsyncMock(
            side_effect=[
                RuntimeError("LLM error"),  # first chunk fails
                _make_llm_response("问题B1"),  # second chunk succeeds
            ]
        )
        gen = QuestionGenerator(llm_gateway=mock_gw, cache=False)
        chunks = [
            {"id": "c1", "content": "内容A", "document_id": "d1"},
            {"id": "c2", "content": "内容B", "document_id": "d1"},
        ]
        result = await gen.generate(chunks)
        # First chunk failed (no questions); second chunk produced one question.
        assert len(result) == 1
        assert result[0].question == "问题B1"
        assert result[0].chunk_id == "c2"

    async def test_skips_empty_content_chunks(self):
        """Chunks with empty or whitespace-only content never reach the LLM."""
        mock_gw, chat_mock = _make_mock_llm_gateway(["问题A"])
        gen = QuestionGenerator(llm_gateway=mock_gw)
        chunks = [
            {"id": "c1", "content": "", "document_id": "d1"},  # empty content
            {"id": "c2", "content": " ", "document_id": "d1"},  # whitespace only
            {"id": "c3", "content": "内容C", "document_id": "d1"},  # valid
        ]
        result = await gen.generate(chunks)
        # Only c3 produces a question.
        assert len(result) == 1
        assert result[0].chunk_id == "c3"
        # The LLM is awaited exactly once.
        assert chat_mock.await_count == 1

    async def test_cache_avoids_duplicate_calls(self):
        """With caching on, identical chunk content does not re-invoke the LLM."""
        mock_gw, chat_mock = _make_mock_llm_gateway(["问题A"])
        gen = QuestionGenerator(llm_gateway=mock_gw, cache=True)
        chunks = [{"id": "c1", "content": "相同内容", "document_id": "d1"}]
        # First call.
        result1 = await gen.generate(chunks)
        assert len(result1) == 1
        # Second call with the same content.
        result2 = await gen.generate(chunks)
        assert len(result2) == 1
        # The LLM is awaited only once (cache hit on the second call).
        assert chat_mock.await_count == 1

    async def test_no_cache_calls_each_time(self):
        """With caching off, every ``generate`` call invokes the LLM."""
        mock_gw, chat_mock = _make_mock_llm_gateway(["问题A", "问题A"])
        gen = QuestionGenerator(llm_gateway=mock_gw, cache=False)
        chunks = [{"id": "c1", "content": "相同内容", "document_id": "d1"}]
        await gen.generate(chunks)
        await gen.generate(chunks)
        # The LLM is awaited twice.
        assert chat_mock.await_count == 2

    async def test_clear_cache(self):
        """After ``clear_cache`` the LLM is invoked again for the same content."""
        mock_gw, chat_mock = _make_mock_llm_gateway(["问题A", "问题A"])
        gen = QuestionGenerator(llm_gateway=mock_gw, cache=True)
        chunks = [{"id": "c1", "content": "相同内容", "document_id": "d1"}]
        await gen.generate(chunks)
        gen.clear_cache()
        await gen.generate(chunks)
        # Two awaits in total: one before and one after clearing the cache.
        assert chat_mock.await_count == 2

    async def test_max_questions_in_prompt(self):
        """The prompt mentions the configured ``max_questions_per_chunk``."""
        mock_gw, chat_mock = _make_mock_llm_gateway(["问题A"])
        gen = QuestionGenerator(llm_gateway=mock_gw, max_questions_per_chunk=5)
        chunks = [{"id": "c1", "content": "内容", "document_id": "d1"}]
        await gen.generate(chunks)
        prompt_content = self._prompt_content(chat_mock)
        # The prompt should mention the number 5 (e.g. "5 questions").
        assert "5" in prompt_content

    async def test_truncates_long_chunk(self):
        """Over-long chunk content is truncated before being put in the prompt."""
        mock_gw, chat_mock = _make_mock_llm_gateway(["问题A"])
        gen = QuestionGenerator(llm_gateway=mock_gw)
        long_content = "A" * 3000
        chunks = [{"id": "c1", "content": long_content, "document_id": "d1"}]
        await gen.generate(chunks)
        prompt_content = self._prompt_content(chat_mock)
        # The full 3000-character content must not appear in the prompt
        # (the original test notes a 2000-character truncation limit).
        assert prompt_content.count("A") < 3000
class TestNumberPrefixRegex:
    """Tests for the ``_NUMBER_PREFIX_RE`` regular expression."""

    def test_dot_prefix(self):
        """``N.`` prefixes match, including multi-digit numbers."""
        assert _NUMBER_PREFIX_RE.match("1. 问题")
        assert _NUMBER_PREFIX_RE.match("12. 问题")

    def test_paren_prefix(self):
        """``N)`` prefixes match."""
        assert _NUMBER_PREFIX_RE.match("1) 问题")
        assert _NUMBER_PREFIX_RE.match("2) 问题")

    def test_chinese_paren_prefix(self):
        """Prefixes using the Chinese enumeration comma (、) match."""
        assert _NUMBER_PREFIX_RE.match("1、 问题")

    def test_no_match_plain_text(self):
        """Text without a numbering prefix does not match."""
        assert not _NUMBER_PREFIX_RE.match("问题一")
        assert not _NUMBER_PREFIX_RE.match("什么是 RAG")