fischer-agentkit/tests/unit/rag_platform/test_hit_processing.py

320 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U6 测试 — 命中处理(model_opt / direct 模式)。
测试场景:
1. model_opt 模式:LLM 基于检索结果生成回答
2. direct 模式:直接返回匹配段落
3. 空结果返回"未找到相关内容"
4. 无 LLM 网关时降级为 direct 模式
5. LLM 调用失败时降级为 direct 模式
6. 缓存命中返回 cached=True
7. KB 默认模式生效
8. Agent 运行时覆盖默认模式
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock
from agentkit.rag_platform.hit_processing import (
HIT_PROCESSING_DIRECT,
HIT_PROCESSING_MODEL_OPT,
HitProcessor,
)
from agentkit.rag_platform.models import QueryResult
from agentkit.rag_platform.settings import KBSettings
def _make_query_result(
    chunk_id: str,
    content: str,
    score: float = 0.5,
    document_id: str = "doc1",
    kb_id: str = "kb1",
) -> QueryResult:
    """Build a ``QueryResult`` fixture for the hit-processing tests.

    ``document_id`` and ``kb_id`` are supplied twice: as top-level fields
    and mirrored inside the ``metadata`` dict.
    """
    fields = {
        "chunk_id": chunk_id,
        "content": content,
        "score": score,
        "metadata": {"document_id": document_id, "kb_id": kb_id},
        "document_id": document_id,
        "kb_id": kb_id,
    }
    return QueryResult(**fields)
def _make_mock_llm_gateway(response_content: str = "LLM 生成的回答") -> MagicMock:
    """Build a mock LLM gateway whose awaited ``chat()`` yields a reply.

    The reply is a ``MagicMock`` whose ``content`` attribute is set to
    ``response_content``.
    """
    reply = MagicMock()
    reply.content = response_content

    gateway = MagicMock()
    # ``chat`` must be awaitable, so it is replaced with an AsyncMock.
    gateway.chat = AsyncMock(return_value=reply)
    return gateway
# ---------------------------------------------------------------------------
# direct 模式测试
# ---------------------------------------------------------------------------
class TestDirectMode:
    """Tests for the ``direct`` hit-processing mode."""

    async def test_direct_returns_concatenated_chunks(self):
        """Direct mode returns the matched chunks joined into one answer."""
        processor = HitProcessor(llm_gateway=None, cache_enabled=False)
        results = [
            _make_query_result("c1", "段落一", score=0.9),
            _make_query_result("c2", "段落二", score=0.7),
        ]

        result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)

        assert result.mode == HIT_PROCESSING_DIRECT
        assert "段落一" in result.answer
        assert "段落二" in result.answer
        assert result.cached is False
        assert len(result.sources) == 2
        # The highest-scoring chunk is expected first.
        assert result.sources[0].chunk_id == "c1"

    async def test_direct_format_includes_index(self):
        """Direct mode prefixes each chunk with a bracketed index such as [1]."""
        processor = HitProcessor(llm_gateway=None, cache_enabled=False)
        results = [_make_query_result("c1", "唯一段落")]

        result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)

        assert "[1]" in result.answer
        assert "唯一段落" in result.answer

    async def test_direct_empty_results(self):
        """Empty results yield the fixed 'nothing found' answer."""
        # NOTE: this docstring must not end with a double-quote character; a
        # quote directly before the closing triple quote produces four quotes
        # in a row, which leaves an unterminated string and breaks the module.
        processor = HitProcessor(llm_gateway=None, cache_enabled=False)

        result = await processor.process("query", [], mode=HIT_PROCESSING_DIRECT)

        assert result.answer == "未找到相关内容。"
        assert result.sources == []
        assert result.cached is False
# ---------------------------------------------------------------------------
# model_opt 模式测试
# ---------------------------------------------------------------------------
class TestModelOptMode:
    """Tests for the ``model_opt`` hit-processing mode."""

    @staticmethod
    def _prompt_text(gateway: MagicMock) -> str:
        """Join the ``content`` of every message passed to ``gateway.chat``."""
        recorded_call = gateway.chat.await_args
        # Messages may be passed by keyword or as the first positional argument.
        sent = recorded_call.kwargs.get("messages") or recorded_call.args[0]
        return " ".join(message["content"] for message in sent)

    async def test_model_opt_calls_llm(self):
        """model_opt mode asks the LLM to produce the answer."""
        gateway = _make_mock_llm_gateway("基于资料的回答")
        hits = [_make_query_result("c1", "参考资料内容")]
        processor = HitProcessor(llm_gateway=gateway, cache_enabled=False)

        outcome = await processor.process("问题", hits, mode=HIT_PROCESSING_MODEL_OPT)

        assert outcome.mode == HIT_PROCESSING_MODEL_OPT
        assert outcome.answer == "基于资料的回答"
        assert len(outcome.sources) == 1
        gateway.chat.assert_awaited_once()

    async def test_model_opt_includes_context_in_prompt(self):
        """The retrieved chunk text is part of the messages sent to the LLM."""
        gateway = _make_mock_llm_gateway("回答")
        hits = [_make_query_result("c1", "独特的资料文本")]
        processor = HitProcessor(llm_gateway=gateway, cache_enabled=False)

        await processor.process("问题", hits, mode=HIT_PROCESSING_MODEL_OPT)

        # The chunk content must appear somewhere in the prompt.
        assert "独特的资料文本" in self._prompt_text(gateway)

    async def test_model_opt_passes_query_in_prompt(self):
        """The user query is part of the messages sent to the LLM."""
        gateway = _make_mock_llm_gateway("回答")
        hits = [_make_query_result("c1", "资料")]
        processor = HitProcessor(llm_gateway=gateway, cache_enabled=False)

        await processor.process("独特查询文本", hits, mode=HIT_PROCESSING_MODEL_OPT)

        assert "独特查询文本" in self._prompt_text(gateway)

    async def test_model_opt_no_llm_falls_back_to_direct(self):
        """Without an LLM gateway the chunk text is still returned."""
        hits = [_make_query_result("c1", "段落内容")]
        processor = HitProcessor(llm_gateway=None, cache_enabled=False)

        outcome = await processor.process("query", hits, mode=HIT_PROCESSING_MODEL_OPT)

        # Fallback to direct: the answer carries the chunk text itself.
        assert "段落内容" in outcome.answer

    async def test_model_opt_llm_failure_falls_back_to_direct(self):
        """A failing LLM call does not lose the retrieved chunks."""
        failing_gateway = MagicMock()
        failing_gateway.chat = AsyncMock(side_effect=RuntimeError("LLM 不可用"))
        hits = [_make_query_result("c1", "降级段落")]
        processor = HitProcessor(llm_gateway=failing_gateway, cache_enabled=False)

        outcome = await processor.process("query", hits, mode=HIT_PROCESSING_MODEL_OPT)

        assert "降级段落" in outcome.answer

    async def test_model_opt_empty_results(self):
        """Empty results give the 'nothing found' answer without an LLM call."""
        gateway = _make_mock_llm_gateway()
        processor = HitProcessor(llm_gateway=gateway, cache_enabled=False)

        outcome = await processor.process("query", [], mode=HIT_PROCESSING_MODEL_OPT)

        assert outcome.answer == "未找到相关内容。"
        gateway.chat.assert_not_awaited()

    async def test_model_opt_sorts_by_score(self):
        """Sources come back ordered by descending score."""
        gateway = _make_mock_llm_gateway("回答")
        # Deliberately supplied lowest score first.
        hits = [
            _make_query_result("c2", "低分段落", score=0.3),
            _make_query_result("c1", "高分段落", score=0.9),
        ]
        processor = HitProcessor(llm_gateway=gateway, cache_enabled=False)

        outcome = await processor.process("query", hits, mode=HIT_PROCESSING_MODEL_OPT)

        ordered_ids = [source.chunk_id for source in outcome.sources[:2]]
        assert ordered_ids[0] == "c1"
        assert ordered_ids[1] == "c2"
# ---------------------------------------------------------------------------
# 缓存测试
# ---------------------------------------------------------------------------
class TestCaching:
    """Tests for the result cache of ``HitProcessor``."""

    async def test_cache_hit_returns_cached_true(self):
        """A repeated identical request is reported with ``cached=True``."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
        results = [_make_query_result("c1", "内容")]

        first = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
        second = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)

        assert first.cached is False
        assert second.cached is True
        # The LLM is called once only; the second request is served from cache.
        assert mock_llm.chat.await_count == 1

    async def test_cache_disabled_no_caching(self):
        """With caching disabled every request calls the LLM."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        results = [_make_query_result("c1", "内容")]

        await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
        await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)

        assert mock_llm.chat.await_count == 2

    async def test_cache_key_includes_query(self):
        """Different queries do not share a cache entry."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
        results = [_make_query_result("c1", "内容")]

        await processor.process("query1", results, mode=HIT_PROCESSING_MODEL_OPT)
        await processor.process("query2", results, mode=HIT_PROCESSING_MODEL_OPT)

        assert mock_llm.chat.await_count == 2

    async def test_cache_key_includes_chunk_ids(self):
        """Different retrieval results do not share a cache entry."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)

        await processor.process(
            "query", [_make_query_result("c1", "内容")], mode=HIT_PROCESSING_MODEL_OPT
        )
        await processor.process(
            "query", [_make_query_result("c2", "内容")], mode=HIT_PROCESSING_MODEL_OPT
        )

        assert mock_llm.chat.await_count == 2

    async def test_cache_key_includes_mode(self):
        """Different modes do not share a cache entry."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
        results = [_make_query_result("c1", "内容")]

        await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
        second = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)

        # Direct mode never calls the LLM, so the call count alone cannot tell
        # whether the direct request wrongly reused the model_opt cache entry.
        # Checking the second result itself is what detects cross-mode sharing.
        assert second.cached is False
        assert second.mode == HIT_PROCESSING_DIRECT
        assert mock_llm.chat.await_count == 1

    async def test_cache_not_used_for_empty_results(self):
        """Empty results are never reported as cached."""
        mock_llm = _make_mock_llm_gateway()
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)

        first = await processor.process("query", [], mode=HIT_PROCESSING_MODEL_OPT)
        second = await processor.process("query", [], mode=HIT_PROCESSING_MODEL_OPT)

        assert first.cached is False
        assert second.cached is False
# ---------------------------------------------------------------------------
# KB 默认模式 + 运行时覆盖测试
# ---------------------------------------------------------------------------
class TestKBDefaultMode:
    """Tests that the knowledge-base default mode is honoured."""

    async def test_kb_default_model_opt(self):
        """A KB left on its default setting is processed in model_opt mode."""
        gateway = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=gateway, cache_enabled=False)
        # No explicit mode: this test expects the KBSettings default to be model_opt.
        kb_settings = KBSettings(kb_id="kb1")
        hits = [_make_query_result("c1", "内容")]

        chosen_mode = kb_settings.default_hit_processing
        outcome = await processor.process("query", hits, mode=chosen_mode)

        assert outcome.mode == HIT_PROCESSING_MODEL_OPT
        gateway.chat.assert_awaited_once()

    async def test_kb_default_direct(self):
        """A KB configured for direct mode is processed in direct mode."""
        processor = HitProcessor(llm_gateway=None, cache_enabled=False)
        kb_settings = KBSettings(kb_id="kb1", default_hit_processing=HIT_PROCESSING_DIRECT)
        hits = [_make_query_result("c1", "内容")]

        chosen_mode = kb_settings.default_hit_processing
        outcome = await processor.process("query", hits, mode=chosen_mode)

        assert outcome.mode == HIT_PROCESSING_DIRECT
class TestModeOverride:
    """Tests for an agent overriding the default mode at run time."""

    async def test_override_to_direct(self):
        """Overriding to direct mode must not call the LLM."""
        gateway = _make_mock_llm_gateway("LLM 回答")
        hits = [_make_query_result("c1", "段落")]
        processor = HitProcessor(llm_gateway=gateway, cache_enabled=False)

        # Scenario: the KB default is model_opt, the caller overrides to direct.
        outcome = await processor.process("query", hits, mode=HIT_PROCESSING_DIRECT)

        assert outcome.mode == HIT_PROCESSING_DIRECT
        gateway.chat.assert_not_awaited()

    async def test_override_to_model_opt(self):
        """Overriding to model_opt mode calls the LLM."""
        gateway = _make_mock_llm_gateway("LLM 回答")
        hits = [_make_query_result("c1", "段落")]
        processor = HitProcessor(llm_gateway=gateway, cache_enabled=False)

        # Scenario: the KB default is direct, the caller overrides to model_opt.
        outcome = await processor.process("query", hits, mode=HIT_PROCESSING_MODEL_OPT)

        assert outcome.mode == HIT_PROCESSING_MODEL_OPT
        gateway.chat.assert_awaited_once()