320 lines
13 KiB
Python
320 lines
13 KiB
Python
"""U6 测试 — 命中处理(model_opt / direct 模式)。

测试场景:
1. model_opt 模式:LLM 基于检索结果生成回答
2. direct 模式:直接返回匹配段落
3. 空结果返回"未找到相关内容。"
4. 无 LLM 网关时降级为 direct 模式
5. LLM 调用失败时降级为 direct 模式
6. 缓存命中返回 cached=True
7. KB 默认模式生效
8. Agent 运行时覆盖默认模式
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
from agentkit.rag_platform.hit_processing import (
|
||
HIT_PROCESSING_DIRECT,
|
||
HIT_PROCESSING_MODEL_OPT,
|
||
HitProcessor,
|
||
)
|
||
from agentkit.rag_platform.models import QueryResult
|
||
from agentkit.rag_platform.settings import KBSettings
|
||
|
||
|
||
def _make_query_result(
    chunk_id: str,
    content: str,
    score: float = 0.5,
    document_id: str = "doc1",
    kb_id: str = "kb1",
) -> QueryResult:
    """Build a ``QueryResult`` fixture for the tests in this module.

    The document id and knowledge-base id are stored both in the dedicated
    fields and inside ``metadata``, so either access path sees the same values.
    """
    metadata = {"document_id": document_id, "kb_id": kb_id}
    return QueryResult(
        kb_id=kb_id,
        document_id=document_id,
        chunk_id=chunk_id,
        score=score,
        content=content,
        metadata=metadata,
    )
|
||
|
||
|
||
def _make_mock_llm_gateway(response_content: str = "LLM 生成的回答") -> MagicMock:
    """Return a mock LLM gateway whose async ``chat()`` yields a canned reply.

    Awaiting ``gateway.chat(...)`` resolves to an object whose ``content``
    attribute is *response_content*. The ``AsyncMock`` records call/await
    history on ``gateway.chat`` so tests can assert on it.
    """
    reply = MagicMock()
    reply.content = response_content
    gateway = MagicMock()
    gateway.chat = AsyncMock(return_value=reply)
    return gateway
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# direct 模式测试
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestDirectMode:
    """Tests for the ``direct`` hit-processing mode (matched chunks returned as-is)."""

    async def test_direct_returns_concatenated_chunks(self):
        """direct mode returns the matched chunks concatenated, highest score first."""
        processor = HitProcessor(llm_gateway=None, cache_enabled=False)
        results = [
            _make_query_result("c1", "段落一", score=0.9),
            _make_query_result("c2", "段落二", score=0.7),
        ]
        result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)

        assert result.mode == HIT_PROCESSING_DIRECT
        assert "段落一" in result.answer
        assert "段落二" in result.answer
        assert result.cached is False
        assert len(result.sources) == 2
        # Sources are ordered by score, descending.
        assert result.sources[0].chunk_id == "c1"

    async def test_direct_sorts_by_score(self):
        """direct mode reorders unsorted input by descending score.

        The previous test feeds chunks that are already in descending order,
        so it cannot detect a missing sort. Here the low-score chunk comes first.
        """
        processor = HitProcessor(llm_gateway=None, cache_enabled=False)
        results = [
            _make_query_result("c2", "低分段落", score=0.3),
            _make_query_result("c1", "高分段落", score=0.9),
        ]
        result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)

        assert [source.chunk_id for source in result.sources] == ["c1", "c2"]
        # The concatenated answer follows the same order as the sources.
        assert result.answer.index("高分段落") < result.answer.index("低分段落")

    async def test_direct_format_includes_index(self):
        """direct mode prefixes each chunk with a 1-based index: [1], [2], ..."""
        processor = HitProcessor(llm_gateway=None, cache_enabled=False)
        results = [
            _make_query_result("c1", "第一段落", score=0.9),
            _make_query_result("c2", "第二段落", score=0.5),
        ]
        result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)

        # Two chunks are needed to check that numbering advances past [1].
        assert "[1]" in result.answer
        assert "[2]" in result.answer
        assert "第一段落" in result.answer
        assert "第二段落" in result.answer

    async def test_direct_empty_results(self):
        """An empty result list yields the "未找到相关内容。" placeholder and no sources."""
        processor = HitProcessor(llm_gateway=None, cache_enabled=False)
        result = await processor.process("query", [], mode=HIT_PROCESSING_DIRECT)

        assert result.answer == "未找到相关内容。"
        assert result.sources == []
        assert result.cached is False
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# model_opt 模式测试
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestModelOptMode:
    """Tests for the ``model_opt`` hit-processing mode (an LLM writes the answer)."""

    @staticmethod
    def _prompt_text(mock_llm: MagicMock) -> str:
        """Join the ``content`` of every message sent in the most recent ``chat()`` await.

        ``messages`` is read from the keyword argument when present (and not
        None), otherwise from the first positional argument. Each message is
        assumed to be a dict with a ``"content"`` key (chat-completion style).
        TODO confirm against HitProcessor's gateway call.
        """
        call_args = mock_llm.chat.await_args
        # Fail with a clear message instead of an AttributeError on None.
        assert call_args is not None, "llm_gateway.chat() was never awaited"
        messages = call_args.kwargs.get("messages")
        if messages is None:
            messages = call_args.args[0]
        return " ".join(message["content"] for message in messages)

    async def test_model_opt_calls_llm(self):
        """model_opt asks the LLM to produce the answer."""
        mock_llm = _make_mock_llm_gateway("基于资料的回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        results = [_make_query_result("c1", "参考资料内容")]
        result = await processor.process("问题", results, mode=HIT_PROCESSING_MODEL_OPT)

        assert result.mode == HIT_PROCESSING_MODEL_OPT
        assert result.answer == "基于资料的回答"
        assert len(result.sources) == 1
        mock_llm.chat.assert_awaited_once()

    async def test_model_opt_includes_context_in_prompt(self):
        """model_opt passes the retrieved chunks to the LLM as context."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        results = [_make_query_result("c1", "独特的资料文本")]
        await processor.process("问题", results, mode=HIT_PROCESSING_MODEL_OPT)

        # The chunk text must appear somewhere in the prompt.
        assert "独特的资料文本" in self._prompt_text(mock_llm)

    async def test_model_opt_passes_query_in_prompt(self):
        """model_opt passes the user query to the LLM prompt."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        results = [_make_query_result("c1", "资料")]
        await processor.process("独特查询文本", results, mode=HIT_PROCESSING_MODEL_OPT)

        assert "独特查询文本" in self._prompt_text(mock_llm)

    async def test_model_opt_no_llm_falls_back_to_direct(self):
        """With no LLM gateway configured, model_opt degrades to direct output."""
        processor = HitProcessor(llm_gateway=None, cache_enabled=False)
        results = [_make_query_result("c1", "段落内容")]
        result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)

        # Degraded to direct: the answer is the concatenated chunk text.
        assert "段落内容" in result.answer

    async def test_model_opt_llm_failure_falls_back_to_direct(self):
        """If the LLM call raises, model_opt degrades to direct output without losing hits."""
        mock_llm = MagicMock()
        mock_llm.chat = AsyncMock(side_effect=RuntimeError("LLM 不可用"))
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        results = [_make_query_result("c1", "降级段落")]
        result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)

        # The LLM was actually attempted. Without this, a processor that never
        # called it would pass too.
        mock_llm.chat.assert_awaited()
        assert "降级段落" in result.answer
        # The retrieved chunk is still reported as a source.
        assert len(result.sources) == 1

    async def test_model_opt_empty_results(self):
        """Empty results yield the "未找到相关内容。" placeholder without calling the LLM."""
        mock_llm = _make_mock_llm_gateway()
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        result = await processor.process("query", [], mode=HIT_PROCESSING_MODEL_OPT)

        assert result.answer == "未找到相关内容。"
        mock_llm.chat.assert_not_awaited()

    async def test_model_opt_sorts_by_score(self):
        """model_opt orders both the sources and the LLM context by descending score."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        results = [
            _make_query_result("c2", "低分段落", score=0.3),
            _make_query_result("c1", "高分段落", score=0.9),
        ]
        result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)

        assert result.sources[0].chunk_id == "c1"
        assert result.sources[1].chunk_id == "c2"
        # The context sent to the LLM is ordered the same way.
        prompt = self._prompt_text(mock_llm)
        assert prompt.index("高分段落") < prompt.index("低分段落")
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# 缓存测试
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestCaching:
    """Tests for HitProcessor's result cache and what its cache key includes."""

    async def test_cache_hit_returns_cached_true(self):
        """A repeated identical request is served from cache with cached=True."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
        results = [_make_query_result("c1", "内容")]

        first = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
        second = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)

        assert first.cached is False
        assert second.cached is True
        # A cache hit must return the same answer that was stored.
        assert second.answer == first.answer
        # The LLM runs only once; the second request hits the cache.
        assert mock_llm.chat.await_count == 1

    async def test_cache_disabled_no_caching(self):
        """With caching disabled, every request calls the LLM."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        results = [_make_query_result("c1", "内容")]

        await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
        await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)

        assert mock_llm.chat.await_count == 2

    async def test_cache_key_includes_query(self):
        """Different queries do not share a cache entry."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
        results = [_make_query_result("c1", "内容")]

        await processor.process("query1", results, mode=HIT_PROCESSING_MODEL_OPT)
        await processor.process("query2", results, mode=HIT_PROCESSING_MODEL_OPT)

        assert mock_llm.chat.await_count == 2

    async def test_cache_key_includes_chunk_ids(self):
        """Different retrieved chunks do not share a cache entry."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)

        await processor.process(
            "query", [_make_query_result("c1", "内容")], mode=HIT_PROCESSING_MODEL_OPT
        )
        await processor.process(
            "query", [_make_query_result("c2", "内容")], mode=HIT_PROCESSING_MODEL_OPT
        )

        assert mock_llm.chat.await_count == 2

    async def test_cache_key_includes_mode(self):
        """Different modes do not share a cache entry."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)
        results = [_make_query_result("c1", "内容")]

        model_opt = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)
        direct = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)

        assert model_opt.cached is False
        # The LLM call count cannot reveal a wrong cache hit here, because
        # direct mode never calls the LLM. Inspect the direct result instead.
        # Had it wrongly hit the model_opt entry, it would be cached=True,
        # report model_opt mode, and carry the LLM answer.
        assert direct.cached is False
        assert direct.mode == HIT_PROCESSING_DIRECT
        assert "内容" in direct.answer
        assert mock_llm.chat.await_count == 1

    async def test_cache_not_used_for_empty_results(self):
        """Empty results are never written to (or served from) the cache."""
        mock_llm = _make_mock_llm_gateway()
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=True)

        first = await processor.process("query", [], mode=HIT_PROCESSING_MODEL_OPT)
        second = await processor.process("query", [], mode=HIT_PROCESSING_MODEL_OPT)

        assert first.cached is False
        assert second.cached is False
        mock_llm.chat.assert_not_awaited()
|
||
|
||
|
||
# ---------------------------------------------------------------------------
# KB 默认模式 + 运行时覆盖测试
# ---------------------------------------------------------------------------
|
||
|
||
|
||
class TestKBDefaultMode:
    """Tests that a knowledge base's default hit-processing mode takes effect."""

    async def test_kb_default_model_opt(self):
        """A KB with no explicit setting defaults to model_opt and calls the LLM."""
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        settings = KBSettings(kb_id="kb1")
        results = [_make_query_result("c1", "内容")]

        # Pin the KBSettings default explicitly rather than only via a comment.
        assert settings.default_hit_processing == HIT_PROCESSING_MODEL_OPT
        result = await processor.process("query", results, mode=settings.default_hit_processing)
        assert result.mode == HIT_PROCESSING_MODEL_OPT
        mock_llm.chat.assert_awaited_once()

    async def test_kb_default_direct(self):
        """A KB configured with default ``direct`` skips the LLM entirely.

        A mock gateway is supplied on purpose, so the LLM *could* be called.
        With ``llm_gateway=None``, model_opt would also degrade to direct, and
        the test might not be able to tell the two modes apart.
        """
        mock_llm = _make_mock_llm_gateway("回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        settings = KBSettings(kb_id="kb1", default_hit_processing=HIT_PROCESSING_DIRECT)
        results = [_make_query_result("c1", "内容")]

        result = await processor.process("query", results, mode=settings.default_hit_processing)
        assert result.mode == HIT_PROCESSING_DIRECT
        mock_llm.chat.assert_not_awaited()
|
||
|
||
|
||
class TestModeOverride:
    """Tests that a runtime (agent-supplied) mode overrides the KB default."""

    async def test_override_to_direct(self):
        """A runtime ``direct`` override beats a KB default of model_opt; no LLM call."""
        mock_llm = _make_mock_llm_gateway("LLM 回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        settings = KBSettings(kb_id="kb1")
        results = [_make_query_result("c1", "段落")]

        # Precondition: the KB default (model_opt) differs from the override.
        assert settings.default_hit_processing == HIT_PROCESSING_MODEL_OPT
        result = await processor.process("query", results, mode=HIT_PROCESSING_DIRECT)

        assert result.mode == HIT_PROCESSING_DIRECT
        assert "段落" in result.answer
        mock_llm.chat.assert_not_awaited()

    async def test_override_to_model_opt(self):
        """A runtime ``model_opt`` override beats a KB default of direct; the LLM is called."""
        mock_llm = _make_mock_llm_gateway("LLM 回答")
        processor = HitProcessor(llm_gateway=mock_llm, cache_enabled=False)
        settings = KBSettings(kb_id="kb1", default_hit_processing=HIT_PROCESSING_DIRECT)
        results = [_make_query_result("c1", "段落")]

        # Precondition: the KB default (direct) differs from the override.
        assert settings.default_hit_processing == HIT_PROCESSING_DIRECT
        result = await processor.process("query", results, mode=HIT_PROCESSING_MODEL_OPT)

        assert result.mode == HIT_PROCESSING_MODEL_OPT
        assert result.answer == "LLM 回答"
        mock_llm.chat.assert_awaited_once()
|