"""U6 tests — hit processing (model_opt / direct modes).

Scenarios covered:
1. model_opt mode: the LLM generates an answer from the retrieved results
2. direct mode: the matched passages are returned as-is
3. Empty results return the fixed "未找到相关内容。" answer
4. Without an LLM gateway, model_opt degrades to direct mode
5. When the LLM call fails, model_opt degrades to direct mode
6. A cache hit returns cached=True
7. The KB default mode takes effect
8. The agent can override the default mode at runtime

NOTE(review): the async tests carry no explicit asyncio marker — presumably
the project runs pytest-asyncio in auto mode; confirm against the pytest
configuration.
"""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock

from agentkit.rag_platform.hit_processing import (
    HIT_PROCESSING_DIRECT,
    HIT_PROCESSING_MODEL_OPT,
    HitProcessor,
)
from agentkit.rag_platform.models import QueryResult
from agentkit.rag_platform.settings import KBSettings


def _make_query_result(
    chunk_id: str,
    content: str,
    score: float = 0.5,
    document_id: str = "doc1",
    kb_id: str = "kb1",
) -> QueryResult:
    """Build a ``QueryResult`` for tests.

    ``document_id`` and ``kb_id`` are written both to the dedicated fields and
    into ``metadata`` so the two views of the result agree.
    """
    return QueryResult(
        chunk_id=chunk_id,
        content=content,
        score=score,
        metadata={"document_id": document_id, "kb_id": kb_id},
        document_id=document_id,
        kb_id=kb_id,
    )


def _make_mock_llm_gateway(response_content: str = "LLM 生成的回答") -> MagicMock:
    """Build a mock LLM gateway whose async ``chat()`` returns a response.

    The returned response object exposes ``response_content`` as ``.content``.
    """
    mock = MagicMock()
    mock_response = MagicMock()
    mock_response.content = response_content
    mock.chat = AsyncMock(return_value=mock_response)
    return mock


def _prompt_text(mock_llm: MagicMock) -> str:
    """Join the ``content`` of every message sent in the last awaited ``chat()``."""
    call_args = mock_llm.chat.await_args
    # ``messages`` may be passed by keyword or as the first positional argument.
    messages = call_args.kwargs.get("messages") or call_args.args[0]
    return " ".join(m["content"] for m in messages)


# ---------------------------------------------------------------------------
# direct mode tests
# ---------------------------------------------------------------------------


class TestDirectMode:
    """Tests for direct mode."""

    async def test_direct_returns_concatenated_chunks(self) -> None:
        """direct mode returns the matched passages joined, by descending score."""
        processor = HitProcessor(llm_gateway=None, cache_enabled=False)
        results = [
            _make_query_result("c1", "段落一", score=0.9),
            _make_query_result("c2", "段落二", score=0.7),
        ]
        result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)
        assert result.mode == HIT_PROCESSING_DIRECT
        assert "段落一" in result.answer
        assert "段落二" in result.answer
        assert result.cached is False
        assert len(result.sources) == 2
        # Sorted by score, descending.
        assert result.sources[0].chunk_id == "c1"

    async def test_direct_format_includes_index(self) -> None:
        """direct mode prefixes passages with an index such as [1] [2]."""
        processor = HitProcessor(llm_gateway=None, cache_enabled=False)
        results = [_make_query_result("c1", "唯一段落")]
        result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)
        assert "[1]" in result.answer
        assert "唯一段落" in result.answer

    async def test_direct_empty_results(self) -> None:
        """Empty results yield the fixed "未找到相关内容。" answer."""
        processor = HitProcessor(llm_gateway=None, cache_enabled=False)
        result = await processor.process("query", [], mode=HIT_PROCESSING_DIRECT)
        assert result.answer == "未找到相关内容。"
        assert result.sources == []
        assert result.cached is False


# ---------------------------------------------------------------------------
# model_opt mode tests
# ---------------------------------------------------------------------------


class TestModelOptMode:
    """Tests for model_opt mode."""

    async def test_model_opt_calls_llm(self) -> None:
        """model_opt mode calls the LLM to generate the answer."""
        mock_llm = _make_mock_llm_gateway("基于资料的回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        results = [_make_query_result("c1", "参考资料内容")]
        result = await processor.process("问题", results, mode=HIT_PROCESSING_MODEL_OPT)
        assert result.mode == HIT_PROCESSING_MODEL_OPT
        assert result.answer == "基于资料的回答"
        assert len(result.sources) == 1
        mock_llm.chat.assert_awaited_once()

    async def test_model_opt_includes_context_in_prompt(self) -> None:
        """model_opt mode passes the retrieved results to the LLM as context."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        results = [_make_query_result("c1", "独特的资料文本")]
        await processor.process("问题", results, mode=HIT_PROCESSING_MODEL_OPT)
        # The retrieved content must appear somewhere in the prompt.
        assert "独特的资料文本" in _prompt_text(mock_llm)

    async def test_model_opt_passes_query_in_prompt(self) -> None:
        """model_opt mode includes the user query in the LLM prompt."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        results = [_make_query_result("c1", "资料")]
        await processor.process("独特查询文本", results, mode=HIT_PROCESSING_MODEL_OPT)
        assert "独特查询文本" in _prompt_text(mock_llm)

    async def test_model_opt_no_llm_falls_back_to_direct(self) -> None:
        """Without an LLM gateway, model_opt degrades to direct mode."""
        processor = HitProcessor(llm_gateway=None, cache_enabled=False)
        results = [_make_query_result("c1", "段落内容")]
        result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
        # Degraded to direct — the answer is built from the passages.
        assert "段落内容" in result.answer

    async def test_model_opt_llm_failure_falls_back_to_direct(self) -> None:
        """A failing LLM call degrades to direct mode (retrieval is not lost)."""
        mock_llm = MagicMock()
        mock_llm.chat = AsyncMock(side_effect=RuntimeError("LLM 不可用"))
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        results = [_make_query_result("c1", "降级段落")]
        result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
        assert "降级段落" in result.answer

    async def test_model_opt_empty_results(self) -> None:
        """Empty results yield "未找到相关内容。" without calling the LLM."""
        mock_llm = _make_mock_llm_gateway()
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        result = await processor.process("query", [], mode=HIT_PROCESSING_MODEL_OPT)
        assert result.answer == "未找到相关内容。"
        mock_llm.chat.assert_not_awaited()

    async def test_model_opt_sorts_by_score(self) -> None:
        """model_opt mode orders sources by descending score."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        # Deliberately supplied lowest score first.
        results = [
            _make_query_result("c2", "低分段落", score=0.3),
            _make_query_result("c1", "高分段落", score=0.9),
        ]
        result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
        assert result.sources[0].chunk_id == "c1"
        assert result.sources[1].chunk_id == "c2"


# ---------------------------------------------------------------------------
# Cache tests
# ---------------------------------------------------------------------------


class TestCaching:
    """Tests for result caching."""

    async def test_cache_hit_returns_cached_true(self) -> None:
        """The second call with identical input returns cached=True."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
        results = [_make_query_result("c1", "内容")]
        first = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
        second = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
        assert first.cached is False
        assert second.cached is True
        # The LLM is called only once (the second call hits the cache).
        assert mock_llm.chat.await_count == 1

    async def test_cache_disabled_no_caching(self) -> None:
        """With caching disabled, every call reaches the LLM."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        results = [_make_query_result("c1", "内容")]
        await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
        await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
        assert mock_llm.chat.await_count == 2

    async def test_cache_key_includes_query(self) -> None:
        """Different queries do not share a cache entry."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
        results = [_make_query_result("c1", "内容")]
        await processor.process("query1", results, mode=HIT_PROCESSING_MODEL_OPT)
        await processor.process("query2", results, mode=HIT_PROCESSING_MODEL_OPT)
        assert mock_llm.chat.await_count == 2

    async def test_cache_key_includes_chunk_ids(self) -> None:
        """Different retrieval results do not share a cache entry."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
        await processor.process(
            "query", [_make_query_result("c1", "内容")], mode=HIT_PROCESSING_MODEL_OPT
        )
        await processor.process(
            "query", [_make_query_result("c2", "内容")], mode=HIT_PROCESSING_MODEL_OPT
        )
        assert mock_llm.chat.await_count == 2

    async def test_cache_key_includes_mode(self) -> None:
        """Different modes do not share a cache entry."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
        results = [_make_query_result("c1", "内容")]
        await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
        direct_result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)
        # direct mode does not call the LLM, so the count stays at 1 ...
        assert mock_llm.chat.await_count == 1
        # ... but the count would also be 1 if the direct call had wrongly been
        # served from the model_opt cache entry, so check the result itself.
        assert direct_result.cached is False
        assert direct_result.mode == HIT_PROCESSING_DIRECT

    async def test_cache_not_used_for_empty_results(self) -> None:
        """Empty results are not written to the cache."""
        mock_llm = _make_mock_llm_gateway()
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
        first = await processor.process("query", [], mode=HIT_PROCESSING_MODEL_OPT)
        second = await processor.process("query", [], mode=HIT_PROCESSING_MODEL_OPT)
        assert first.cached is False
        assert second.cached is False


# ---------------------------------------------------------------------------
# KB default mode + runtime override tests
# ---------------------------------------------------------------------------


class TestKBDefaultMode:
    """Tests that the KB default mode takes effect."""

    async def test_kb_default_model_opt(self) -> None:
        """The KB default of model_opt takes effect."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        # No explicit mode given: the test expects the default to be model_opt.
        settings = KBSettings(kb_id="kb1")
        results = [_make_query_result("c1", "内容")]
        result = await processor.process("query", results, mode=settings.default_hit_processing)
        assert result.mode == HIT_PROCESSING_MODEL_OPT
        mock_llm.chat.assert_awaited_once()

    async def test_kb_default_direct(self) -> None:
        """A KB default of direct takes effect."""
        processor = HitProcessor(llm_gateway=None, cache_enabled=False)
        settings = KBSettings(kb_id="kb1", default_hit_processing=HIT_PROCESSING_DIRECT)
        results = [_make_query_result("c1", "内容")]
        result = await processor.process("query", results, mode=settings.default_hit_processing)
        assert result.mode == HIT_PROCESSING_DIRECT


class TestModeOverride:
    """Tests for the agent overriding the default mode at runtime."""

    async def test_override_to_direct(self) -> None:
        """Overriding to direct at runtime does not call the LLM."""
        mock_llm = _make_mock_llm_gateway("LLM 回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        results = [_make_query_result("c1", "段落")]
        # Scenario: KB default is model_opt, overridden to direct at runtime.
        result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)
        assert result.mode == HIT_PROCESSING_DIRECT
        mock_llm.chat.assert_not_awaited()

    async def test_override_to_model_opt(self) -> None:
        """Overriding to model_opt at runtime calls the LLM."""
        mock_llm = _make_mock_llm_gateway("LLM 回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        results = [_make_query_result("c1", "段落")]
        # Scenario: KB default is direct, overridden to model_opt at runtime.
        result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
        assert result.mode == HIT_PROCESSING_MODEL_OPT
        mock_llm.chat.assert_awaited_once()