"""U5 测试 — LLM-based 问题生成。 测试场景: 1. 无 LLM gateway 时返回空列表 2. 空 chunks 列表返回空列表 3. LLM 生成的问题被正确解析(按行切分,去除编号前缀) 4. 每个 chunk 生成的问题关联正确的 chunk_id 和 document_id 5. LLM 失败时该 chunk 返回空问题列表(不影响其他 chunk) 6. 缓存生效(同一 chunk 不重复调用 LLM) 7. 跳过空内容 chunk 8. _parse_questions 正确处理编号前缀和空行 """ from __future__ import annotations from unittest.mock import AsyncMock, MagicMock from agentkit.rag_platform.question_gen import ( GeneratedQuestion, QuestionGenerator, _NUMBER_PREFIX_RE, ) # --------------------------------------------------------------------------- # 测试辅助函数 # --------------------------------------------------------------------------- def _make_llm_response(content: str): """创建 mock LLM 响应。""" response = MagicMock() response.content = content return response def _make_mock_llm_gateway(responses: list[str] | None = None): """创建 mock LLM gateway。 Args: responses: chat 方法返回的响应内容列表(按调用顺序返回) 若为 None,返回默认响应 Returns: (mock_gateway, chat_mock) — mock gateway 和 chat AsyncMock """ mock_gateway = MagicMock() mock_gateway.chat = AsyncMock() if responses is not None: mock_gateway.chat.side_effect = [_make_llm_response(r) for r in responses] else: mock_gateway.chat.return_value = _make_llm_response("问题一\n问题二\n问题三") return mock_gateway, mock_gateway.chat # --------------------------------------------------------------------------- # GeneratedQuestion 模型测试 # --------------------------------------------------------------------------- class TestGeneratedQuestion: """GeneratedQuestion 模型测试。""" def test_fields(self): """模型字段正确。""" q = GeneratedQuestion( question="什么是 RAG?", chunk_id="c1", document_id="d1", ) assert q.question == "什么是 RAG?" assert q.chunk_id == "c1" assert q.document_id == "d1" # --------------------------------------------------------------------------- # _parse_questions 测试 # --------------------------------------------------------------------------- class TestParseQuestions: """_parse_questions 静态方法测试。""" def test_plain_lines(self): """纯文本行被正确切分。""" result = QuestionGenerator._parse_questions("问题一\n问题二\n问题三") assert result == ["问题一", "问题二", "问题三"] def test_numbered_lines(self): """编号前缀被去除。""" result = QuestionGenerator._parse_questions("1. 问题一\n2. 问题二\n3. 问题三") assert result == ["问题一", "问题二", "问题三"] def test_paren_numbered_lines(self): """括号编号前缀被去除。""" result = QuestionGenerator._parse_questions("1) 问题一\n2) 问题二") assert result == ["问题一", "问题二"] def test_chinese_numbered_lines(self): """中文编号前缀(、)被去除。""" result = QuestionGenerator._parse_questions("1、 问题一\n2、 问题二") assert result == ["问题一", "问题二"] def test_empty_lines_filtered(self): """空行被过滤。""" result = QuestionGenerator._parse_questions("问题一\n\n \n问题二") assert result == ["问题一", "问题二"] def test_empty_input(self): """空输入返回空列表。""" assert QuestionGenerator._parse_questions("") == [] def test_whitespace_only(self): """纯空白返回空列表。""" assert QuestionGenerator._parse_questions(" \n \n") == [] # --------------------------------------------------------------------------- # QuestionGenerator 测试 # --------------------------------------------------------------------------- class TestQuestionGeneratorNoLLM: """无 LLM gateway 时的测试。""" async def test_no_llm_returns_empty(self): """无 LLM gateway 时返回空列表。""" gen = QuestionGenerator(llm_gateway=None) chunks = [{"id": "c1", "content": "内容", "document_id": "d1"}] result = await gen.generate(chunks) assert result == [] async def test_empty_chunks_returns_empty(self): """空 chunks 列表返回空列表。""" mock_gw, _ = _make_mock_llm_gateway() gen = QuestionGenerator(llm_gateway=mock_gw) result = await gen.generate([]) assert result == [] mock_gw.chat.assert_not_awaited() class TestQuestionGeneratorWithLLM: """有 LLM gateway 时的测试。""" async def test_generates_questions_for_chunks(self): """为每个 chunk 生成相关问题。""" mock_gw, chat_mock = _make_mock_llm_gateway(["问题A1\n问题A2", "问题B1\n问题B2"]) gen = QuestionGenerator(llm_gateway=mock_gw, max_questions_per_chunk=2) chunks = [ {"id": "c1", "content": "内容A", "document_id": "d1"}, {"id": "c2", "content": "内容B", "document_id": "d1"}, ] result = await gen.generate(chunks) assert len(result) == 4 # 第一个 chunk 的问题 assert result[0].question == "问题A1" assert result[0].chunk_id == "c1" assert result[0].document_id == "d1" assert result[1].question == "问题A2" assert result[1].chunk_id == "c1" # 第二个 chunk 的问题 assert result[2].question == "问题B1" assert result[2].chunk_id == "c2" assert result[3].question == "问题B2" assert result[3].chunk_id == "c2" # LLM 被调用 2 次(每个 chunk 一次) assert chat_mock.await_count == 2 async def test_questions_relate_to_chunk_content(self): """生成的问题与 chunk 内容相关(验证 prompt 包含 chunk 内容)。""" mock_gw, chat_mock = _make_mock_llm_gateway(["什么是 RAG?"]) gen = QuestionGenerator(llm_gateway=mock_gw, max_questions_per_chunk=1) chunks = [ { "id": "c1", "content": "RAG 是检索增强生成的缩写,结合了检索和生成模型。", "document_id": "d1", } ] result = await gen.generate(chunks) assert len(result) == 1 assert result[0].question == "什么是 RAG?" # 验证 prompt 包含 chunk 内容 call_args = chat_mock.await_args messages = call_args.args[0] if call_args.args else call_args.kwargs["messages"] prompt_content = messages[0]["content"] assert "RAG" in prompt_content assert "检索增强生成" in prompt_content async def test_llm_failure_returns_empty_for_chunk(self): """LLM 失败时该 chunk 返回空问题列表,不影响其他 chunk。""" mock_gw = MagicMock() mock_gw.chat = AsyncMock( side_effect=[ RuntimeError("LLM error"), # 第一个 chunk 失败 _make_llm_response("问题B1"), # 第二个 chunk 成功 ] ) gen = QuestionGenerator(llm_gateway=mock_gw, cache=False) chunks = [ {"id": "c1", "content": "内容A", "document_id": "d1"}, {"id": "c2", "content": "内容B", "document_id": "d1"}, ] result = await gen.generate(chunks) # 第一个 chunk 失败,无问题;第二个 chunk 成功,1 个问题 assert len(result) == 1 assert result[0].question == "问题B1" assert result[0].chunk_id == "c2" async def test_skips_empty_content_chunks(self): """空内容 chunk 被跳过(不调用 LLM)。""" mock_gw, chat_mock = _make_mock_llm_gateway(["问题A"]) gen = QuestionGenerator(llm_gateway=mock_gw) chunks = [ {"id": "c1", "content": "", "document_id": "d1"}, # 空内容 {"id": "c2", "content": " ", "document_id": "d1"}, # 纯空白 {"id": "c3", "content": "内容C", "document_id": "d1"}, # 有效 ] result = await gen.generate(chunks) # 只有 c3 生成问题 assert len(result) == 1 assert result[0].chunk_id == "c3" # LLM 只被调用 1 次 assert chat_mock.await_count == 1 async def test_cache_avoids_duplicate_calls(self): """缓存生效 — 同一 chunk 内容不重复调用 LLM。""" mock_gw, chat_mock = _make_mock_llm_gateway(["问题A"]) gen = QuestionGenerator(llm_gateway=mock_gw, cache=True) chunks = [{"id": "c1", "content": "相同内容", "document_id": "d1"}] # 第一次调用 result1 = await gen.generate(chunks) assert len(result1) == 1 # 第二次调用相同内容 result2 = await gen.generate(chunks) assert len(result2) == 1 # LLM 只被调用 1 次(缓存命中) assert chat_mock.await_count == 1 async def test_no_cache_calls_each_time(self): """禁用缓存时每次都调用 LLM。""" mock_gw, chat_mock = _make_mock_llm_gateway(["问题A", "问题A"]) gen = QuestionGenerator(llm_gateway=mock_gw, cache=False) chunks = [{"id": "c1", "content": "相同内容", "document_id": "d1"}] await gen.generate(chunks) await gen.generate(chunks) # LLM 被调用 2 次 assert chat_mock.await_count == 2 async def test_clear_cache(self): """clear_cache 清除缓存后重新调用 LLM。""" mock_gw, chat_mock = _make_mock_llm_gateway(["问题A", "问题A"]) gen = QuestionGenerator(llm_gateway=mock_gw, cache=True) chunks = [{"id": "c1", "content": "相同内容", "document_id": "d1"}] await gen.generate(chunks) gen.clear_cache() await gen.generate(chunks) # 缓存清除后 LLM 被调用 2 次 assert chat_mock.await_count == 2 async def test_max_questions_in_prompt(self): """prompt 中包含 max_questions_per_chunk 数量。""" mock_gw, chat_mock = _make_mock_llm_gateway(["问题A"]) gen = QuestionGenerator(llm_gateway=mock_gw, max_questions_per_chunk=5) chunks = [{"id": "c1", "content": "内容", "document_id": "d1"}] await gen.generate(chunks) call_args = chat_mock.await_args messages = call_args.args[0] if call_args.args else call_args.kwargs["messages"] prompt_content = messages[0]["content"] # prompt 中应包含 "5 个问题" assert "5" in prompt_content async def test_truncates_long_chunk(self): """超长 chunk 被截断到 2000 字符(避免 prompt 过长)。""" mock_gw, chat_mock = _make_mock_llm_gateway(["问题A"]) gen = QuestionGenerator(llm_gateway=mock_gw) long_content = "A" * 3000 chunks = [{"id": "c1", "content": long_content, "document_id": "d1"}] await gen.generate(chunks) call_args = chat_mock.await_args messages = call_args.args[0] if call_args.args else call_args.kwargs["messages"] prompt_content = messages[0]["content"] # prompt 中不应包含完整的 3000 字符内容(被截断到 2000) assert prompt_content.count("A") < 3000 class TestNumberPrefixRegex: """_NUMBER_PREFIX_RE 正则测试。""" def test_dot_prefix(self): """点号编号前缀匹配。""" assert _NUMBER_PREFIX_RE.match("1. 问题") assert _NUMBER_PREFIX_RE.match("12. 问题") def test_paren_prefix(self): """括号编号前缀匹配。""" assert _NUMBER_PREFIX_RE.match("1) 问题") assert _NUMBER_PREFIX_RE.match("2) 问题") def test_chinese_paren_prefix(self): """中文顿号编号前缀匹配。""" assert _NUMBER_PREFIX_RE.match("1、 问题") def test_no_match_plain_text(self): """纯文本不匹配。""" assert not _NUMBER_PREFIX_RE.match("问题一") assert not _NUMBER_PREFIX_RE.match("什么是 RAG?")