342 lines
12 KiB
Python
342 lines
12 KiB
Python
"""U5 测试 — LLM-based 问题生成。

测试场景:
1. 无 LLM gateway 时返回空列表
2. 空 chunks 列表返回空列表
3. LLM 生成的问题被正确解析(按行切分,去除编号前缀)
4. 每个 chunk 生成的问题关联正确的 chunk_id 和 document_id
5. LLM 失败时该 chunk 返回空问题列表(不影响其他 chunk)
6. 缓存生效(同一 chunk 不重复调用 LLM)
7. 跳过空内容 chunk
8. _parse_questions 正确处理编号前缀和空行
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
from agentkit.rag_platform.question_gen import (
|
||
GeneratedQuestion,
|
||
QuestionGenerator,
|
||
_NUMBER_PREFIX_RE,
|
||
)
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# Test helper functions
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _make_llm_response(content: str):
|
||
"""创建 mock LLM 响应。"""
|
||
response = MagicMock()
|
||
response.content = content
|
||
return response
|
||
|
||
|
||
def _make_mock_llm_gateway(
    responses: list[str] | None = None,
) -> tuple[MagicMock, AsyncMock]:
    """Build a mock LLM gateway whose ``chat`` coroutine is an ``AsyncMock``.

    Args:
        responses: Response contents handed out by ``chat`` in call order,
            one per awaited call. When ``None``, every call returns the same
            default three-line response.

    Returns:
        ``(mock_gateway, chat_mock)`` — the gateway mock and its ``chat``
        ``AsyncMock`` (the same object as ``mock_gateway.chat``).
    """
    mock_gateway = MagicMock()
    mock_gateway.chat = AsyncMock()

    if responses is not None:
        # A list side_effect yields one prepared response per call.
        mock_gateway.chat.side_effect = [_make_llm_response(r) for r in responses]
    else:
        mock_gateway.chat.return_value = _make_llm_response("问题一\n问题二\n问题三")

    return mock_gateway, mock_gateway.chat
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# GeneratedQuestion model tests
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestGeneratedQuestion:
    """Tests for the ``GeneratedQuestion`` model."""

    def test_fields(self):
        """Constructor keyword arguments are exposed as same-named attributes."""
        q = GeneratedQuestion(
            question="什么是 RAG?",
            chunk_id="c1",
            document_id="d1",
        )
        assert q.question == "什么是 RAG?"
        assert q.chunk_id == "c1"
        assert q.document_id == "d1"
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# _parse_questions tests
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestParseQuestions:
    """Tests for ``QuestionGenerator._parse_questions``.

    Every test calls the method directly on the class, without an instance.
    """

    def test_plain_lines(self):
        """Plain text is split into one question per line."""
        result = QuestionGenerator._parse_questions("问题一\n问题二\n问题三")
        assert result == ["问题一", "问题二", "问题三"]

    def test_numbered_lines(self):
        """``N.`` numbering prefixes are stripped."""
        result = QuestionGenerator._parse_questions("1. 问题一\n2. 问题二\n3. 问题三")
        assert result == ["问题一", "问题二", "问题三"]

    def test_paren_numbered_lines(self):
        """``N)`` numbering prefixes are stripped."""
        result = QuestionGenerator._parse_questions("1) 问题一\n2) 问题二")
        assert result == ["问题一", "问题二"]

    def test_chinese_numbered_lines(self):
        """Numbering prefixes using the Chinese enumeration comma (、) are stripped."""
        result = QuestionGenerator._parse_questions("1、 问题一\n2、 问题二")
        assert result == ["问题一", "问题二"]

    def test_empty_lines_filtered(self):
        """Empty and whitespace-only lines are dropped."""
        result = QuestionGenerator._parse_questions("问题一\n\n \n问题二")
        assert result == ["问题一", "问题二"]

    def test_empty_input(self):
        """An empty string yields an empty list."""
        assert QuestionGenerator._parse_questions("") == []

    def test_whitespace_only(self):
        """Whitespace-only input yields an empty list."""
        assert QuestionGenerator._parse_questions(" \n \n") == []
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# QuestionGenerator tests
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestQuestionGeneratorNoLLM:
    """Tests for ``QuestionGenerator.generate`` paths that must not reach the LLM.

    Covers the missing-gateway case and the empty-input case (the latter does
    supply a gateway, and asserts it is never awaited).
    """

    async def test_no_llm_returns_empty(self):
        """Without an LLM gateway, ``generate`` returns an empty list."""
        gen = QuestionGenerator(llm_gateway=None)
        chunks = [{"id": "c1", "content": "内容", "document_id": "d1"}]

        result = await gen.generate(chunks)

        assert result == []

    async def test_empty_chunks_returns_empty(self):
        """An empty chunk list returns an empty list without calling the LLM."""
        mock_gw, _ = _make_mock_llm_gateway()
        gen = QuestionGenerator(llm_gateway=mock_gw)

        result = await gen.generate([])

        assert result == []
        mock_gw.chat.assert_not_awaited()
|
||
|
||
|
||
class TestQuestionGeneratorWithLLM:
    """Tests for ``QuestionGenerator.generate`` when an LLM gateway is supplied."""

    @staticmethod
    def _last_prompt(chat_mock: AsyncMock) -> str:
        """Return the first message's content from the most recent awaited ``chat`` call.

        The message list is read from the first positional argument when one
        was passed, otherwise from the ``messages`` keyword argument.
        """
        call_args = chat_mock.await_args
        messages = call_args.args[0] if call_args.args else call_args.kwargs["messages"]
        return messages[0]["content"]

    async def test_generates_questions_for_chunks(self):
        """Each chunk gets its own questions, tagged with that chunk's ids."""
        mock_gw, chat_mock = _make_mock_llm_gateway(["问题A1\n问题A2", "问题B1\n问题B2"])
        gen = QuestionGenerator(llm_gateway=mock_gw, max_questions_per_chunk=2)

        chunks = [
            {"id": "c1", "content": "内容A", "document_id": "d1"},
            {"id": "c2", "content": "内容B", "document_id": "d1"},
        ]

        result = await gen.generate(chunks)

        assert len(result) == 4
        # Questions from the first chunk.
        assert result[0].question == "问题A1"
        assert result[0].chunk_id == "c1"
        assert result[0].document_id == "d1"
        assert result[1].question == "问题A2"
        assert result[1].chunk_id == "c1"
        # Questions from the second chunk.
        assert result[2].question == "问题B1"
        assert result[2].chunk_id == "c2"
        assert result[3].question == "问题B2"
        assert result[3].chunk_id == "c2"

        # The LLM is awaited twice: once per chunk.
        assert chat_mock.await_count == 2

    async def test_questions_relate_to_chunk_content(self):
        """The prompt sent to the LLM contains the chunk's content."""
        mock_gw, chat_mock = _make_mock_llm_gateway(["什么是 RAG?"])
        gen = QuestionGenerator(llm_gateway=mock_gw, max_questions_per_chunk=1)

        chunks = [
            {
                "id": "c1",
                "content": "RAG 是检索增强生成的缩写,结合了检索和生成模型。",
                "document_id": "d1",
            }
        ]

        result = await gen.generate(chunks)

        assert len(result) == 1
        assert result[0].question == "什么是 RAG?"

        # Verify that the prompt embeds the chunk content.
        prompt_content = self._last_prompt(chat_mock)
        assert "RAG" in prompt_content
        assert "检索增强生成" in prompt_content

    async def test_llm_failure_returns_empty_for_chunk(self):
        """A failing LLM call yields no questions for that chunk only."""
        mock_gw = MagicMock()
        mock_gw.chat = AsyncMock(
            side_effect=[
                RuntimeError("LLM error"),  # first chunk fails
                _make_llm_response("问题B1"),  # second chunk succeeds
            ]
        )
        gen = QuestionGenerator(llm_gateway=mock_gw, cache=False)

        chunks = [
            {"id": "c1", "content": "内容A", "document_id": "d1"},
            {"id": "c2", "content": "内容B", "document_id": "d1"},
        ]

        result = await gen.generate(chunks)

        # First chunk failed (no questions); second chunk produced one.
        assert len(result) == 1
        assert result[0].question == "问题B1"
        assert result[0].chunk_id == "c2"

    async def test_skips_empty_content_chunks(self):
        """Chunks with empty or whitespace-only content never reach the LLM."""
        mock_gw, chat_mock = _make_mock_llm_gateway(["问题A"])
        gen = QuestionGenerator(llm_gateway=mock_gw)

        chunks = [
            {"id": "c1", "content": "", "document_id": "d1"},  # empty content
            {"id": "c2", "content": " ", "document_id": "d1"},  # whitespace only
            {"id": "c3", "content": "内容C", "document_id": "d1"},  # valid
        ]

        result = await gen.generate(chunks)

        # Only c3 produces a question.
        assert len(result) == 1
        assert result[0].chunk_id == "c3"
        # The LLM is awaited exactly once.
        assert chat_mock.await_count == 1

    async def test_cache_avoids_duplicate_calls(self):
        """With caching on, identical chunk content does not trigger a second LLM call."""
        mock_gw, chat_mock = _make_mock_llm_gateway(["问题A"])
        gen = QuestionGenerator(llm_gateway=mock_gw, cache=True)

        chunks = [{"id": "c1", "content": "相同内容", "document_id": "d1"}]

        # First call.
        result1 = await gen.generate(chunks)
        assert len(result1) == 1

        # Second call with the same content.
        result2 = await gen.generate(chunks)
        assert len(result2) == 1

        # The LLM is awaited only once (cache hit on the second call).
        assert chat_mock.await_count == 1

    async def test_no_cache_calls_each_time(self):
        """With caching off, every ``generate`` call reaches the LLM."""
        mock_gw, chat_mock = _make_mock_llm_gateway(["问题A", "问题A"])
        gen = QuestionGenerator(llm_gateway=mock_gw, cache=False)

        chunks = [{"id": "c1", "content": "相同内容", "document_id": "d1"}]

        await gen.generate(chunks)
        await gen.generate(chunks)

        # The LLM is awaited twice.
        assert chat_mock.await_count == 2

    async def test_clear_cache(self):
        """After ``clear_cache`` the LLM is called again for the same content."""
        mock_gw, chat_mock = _make_mock_llm_gateway(["问题A", "问题A"])
        gen = QuestionGenerator(llm_gateway=mock_gw, cache=True)

        chunks = [{"id": "c1", "content": "相同内容", "document_id": "d1"}]

        await gen.generate(chunks)
        gen.clear_cache()
        await gen.generate(chunks)

        # Clearing the cache forces a second LLM call.
        assert chat_mock.await_count == 2

    async def test_max_questions_in_prompt(self):
        """The prompt mentions the configured ``max_questions_per_chunk`` value."""
        mock_gw, chat_mock = _make_mock_llm_gateway(["问题A"])
        gen = QuestionGenerator(llm_gateway=mock_gw, max_questions_per_chunk=5)

        chunks = [{"id": "c1", "content": "内容", "document_id": "d1"}]
        await gen.generate(chunks)

        prompt_content = self._last_prompt(chat_mock)
        # The prompt is expected to contain "5 个问题"; only the digit is checked.
        assert "5" in prompt_content

    async def test_truncates_long_chunk(self):
        """Over-long chunk content is truncated before being placed in the prompt.

        The intended limit is 2000 characters; the assertion only checks that
        fewer than the full 3000 characters reach the prompt.
        """
        mock_gw, chat_mock = _make_mock_llm_gateway(["问题A"])
        gen = QuestionGenerator(llm_gateway=mock_gw)

        long_content = "A" * 3000
        chunks = [{"id": "c1", "content": long_content, "document_id": "d1"}]
        await gen.generate(chunks)

        prompt_content = self._last_prompt(chat_mock)
        # The prompt must not contain the complete 3000-character content.
        assert prompt_content.count("A") < 3000
|
||
|
||
|
||
class TestNumberPrefixRegex:
    """Tests for the ``_NUMBER_PREFIX_RE`` regular expression."""

    def test_dot_prefix(self):
        """``N.`` prefixes match, including multi-digit numbers."""
        assert _NUMBER_PREFIX_RE.match("1. 问题")
        assert _NUMBER_PREFIX_RE.match("12. 问题")

    def test_paren_prefix(self):
        """``N)`` prefixes match."""
        assert _NUMBER_PREFIX_RE.match("1) 问题")
        assert _NUMBER_PREFIX_RE.match("2) 问题")

    def test_chinese_paren_prefix(self):
        """Prefixes using the Chinese enumeration comma (、) match.

        Despite the method name, the input uses ``、`` rather than a
        parenthesis.
        """
        assert _NUMBER_PREFIX_RE.match("1、 问题")

    def test_no_match_plain_text(self):
        """Text without a numbering prefix does not match."""
        assert not _NUMBER_PREFIX_RE.match("问题一")
        assert not _NUMBER_PREFIX_RE.match("什么是 RAG?")
|