238 lines
9.4 KiB
Python
238 lines
9.4 KiB
Python
"""RAG 服务单元测试"""
|
||
import pytest
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# RecursiveChunker 测试
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestRecursiveChunker:
    """Unit tests for ``RecursiveChunker`` using its default configuration."""

    # Two-paragraph English text long enough to clear the chunker's
    # min_chunk_size; test_chunker_basic_split asserts it yields chunks, so
    # tests built on it cannot pass vacuously over an empty chunk list.
    _LONG_TEXT = (
        "This is a simple test with enough content to pass the minimum chunk size threshold."
        "\n\nAnother paragraph here with more content to ensure it meets the size requirement."
    )

    # Three 400-character paragraphs separated by blank lines.
    _PARAGRAPH_TEXT = "A" * 400 + "\n\n" + "B" * 400 + "\n\n" + "C" * 400

    @pytest.fixture
    def chunker(self):
        """Provide a ``RecursiveChunker`` with default settings."""
        from app.services.knowledge.chunker import RecursiveChunker

        return RecursiveChunker()

    def test_chunker_basic_split(self, chunker):
        """Splitting ordinary text yields a non-empty list of well-formed chunks."""
        chunks = chunker.chunk(self._LONG_TEXT)

        assert len(chunks) > 0
        for chunk in chunks:
            assert "content" in chunk
            assert "chunk_index" in chunk
            assert "metadata" in chunk

    def test_chunker_chinese_text(self, chunker):
        """Chinese text made of '。'-terminated sentences yields non-blank chunks."""
        # Long enough to clear min_chunk_size.
        text = "这是第一句话,内容需要足够长才能满足最小分块大小的要求。这是第二句话,同样需要足够的长度来确保分块器能够正确处理。这是第三句话,继续增加文本长度以满足分块条件。"
        chunks = chunker.chunk(text)

        assert len(chunks) > 0
        for chunk in chunks:
            assert chunk["content"].strip() != ""

    def test_chunker_respects_max_size(self, chunker):
        """Paragraph-separated text is split into at least one non-empty chunk.

        No upper bound on chunk length is asserted: RecursiveChunker splits
        on paragraph boundaries and does not re-split a single over-long
        paragraph, so individual chunks may exceed chunk_size.
        """
        chunks = chunker.chunk(self._PARAGRAPH_TEXT)

        assert len(chunks) >= 1
        for chunk in chunks:
            assert len(chunk["content"]) > 0

    def test_chunker_overlap(self, chunker):
        """Long single-paragraph text splits into adjacent, non-empty chunks.

        Overlap itself is not asserted because a boundary may fall exactly
        between words, leaving the two chunks with no shared vocabulary.
        If the text does not produce at least two chunks the test is
        skipped, so the missing precondition is visible instead of the
        test passing without checking anything.
        """
        text = "Alpha beta gamma delta epsilon. " * 30
        chunks = chunker.chunk(text)

        if len(chunks) < 2:
            pytest.skip(f"need at least 2 chunks to compare, got {len(chunks)}")

        first_words = set(chunks[0]["content"].split())
        second_words = set(chunks[1]["content"].split())
        assert len(first_words) > 0
        assert len(second_words) > 0

    def test_chunker_min_chunk_filter(self, chunker):
        """Text shorter than min_chunk_size collapses into at most one chunk."""
        chunks = chunker.chunk("Short.")

        # Undersized pieces are either merged into a single chunk or dropped.
        assert len(chunks) <= 1

    def test_chunker_empty_text(self, chunker):
        """Empty or whitespace-only text yields an empty list."""
        assert chunker.chunk("") == []
        assert chunker.chunk(" ") == []

    def test_chunker_metadata_passed(self, chunker):
        """Caller-supplied metadata is attached to every chunk."""
        meta = {"source": "test_doc", "author": "tester"}
        # Use text known to produce chunks. A short string may fall below
        # min_chunk_size and yield no chunks, leaving the loop with nothing
        # to check.
        chunks = chunker.chunk(self._LONG_TEXT, metadata=meta)

        assert len(chunks) > 0
        for chunk in chunks:
            assert chunk["metadata"] == meta

    def test_chunker_chunk_index_sequential(self, chunker):
        """chunk_index starts at 0 and increases by exactly one per chunk."""
        # Short paragraphs may be merged or filtered by min_chunk_size, so use
        # long paragraphs that are known to produce chunks.
        chunks = chunker.chunk(self._PARAGRAPH_TEXT)

        assert len(chunks) > 0
        assert [chunk["chunk_index"] for chunk in chunks] == list(range(len(chunks)))
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# MockEmbedder 测试
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestMockEmbedder:
    """Unit tests for ``MockEmbedder``."""

    @pytest.fixture
    def embedder(self):
        """Provide a ``MockEmbedder`` with its default dimension."""
        from app.services.knowledge.embedder import MockEmbedder

        return MockEmbedder()

    @pytest.mark.asyncio
    async def test_mock_embedder_deterministic(self, embedder):
        """Embedding the same text twice returns the same vector."""
        text = "Hello, this is a test sentence"
        v1 = await embedder.embed(text)
        v2 = await embedder.embed(text)
        assert v1 == v2

    @pytest.mark.asyncio
    async def test_mock_embedder_dimension(self, embedder):
        """Vectors from the default embedder have 1536 components."""
        v = await embedder.embed("test")
        assert len(v) == 1536

    @pytest.mark.asyncio
    async def test_mock_embedder_different_texts(self, embedder):
        """Different texts produce different vectors."""
        v1 = await embedder.embed("apple")
        v2 = await embedder.embed("banana")
        assert v1 != v2

    @pytest.mark.asyncio
    async def test_mock_embedder_batch(self, embedder):
        """Batch embedding returns one default-dimension vector per input."""
        texts = ["text1", "text2", "text3"]
        results = await embedder.embed_batch(texts)

        assert len(results) == 3
        for v in results:
            assert len(v) == 1536

    @pytest.mark.asyncio
    async def test_mock_embedder_custom_dimension(self):
        """A custom dimension is stored on the embedder and honored by ``embed``."""
        from app.services.knowledge.embedder import MockEmbedder

        embedder = MockEmbedder(dimension=768)
        assert embedder.dimension == 768

        # Checking only the attribute would miss an embedder that ignores it
        # when generating vectors.
        v = await embedder.embed("test")
        assert len(v) == 768
|
||
|
||
|
||
# ---------------------------------------------------------------------------
|
||
# HybridRetriever RRF Fusion 测试(不需要真实 DB)
|
||
# ---------------------------------------------------------------------------
|
||
|
||
class TestRRFFusion:
    """Tests for ``HybridRetriever._rrf_fusion``; no real database is needed."""

    @pytest.fixture
    def retriever(self):
        """Provide a retriever weighted 0.7 toward vector and 0.3 toward keyword results."""
        from app.services.knowledge.retriever import HybridRetriever
        from app.services.knowledge.embedder import MockEmbedder

        return HybridRetriever(embedder=MockEmbedder(), vector_weight=0.7, keyword_weight=0.3)

    @staticmethod
    def _make_results(ids: list[str]) -> list[dict]:
        """Build one retrieval-result dict per id.

        Every result gets the same score (1.0), so list position is the only
        thing distinguishing them.
        """
        return [
            {
                "chunk_id": cid,
                "content": f"Content of {cid}",
                "score": 1.0,
                "document_id": "doc1",
                "document_title": "Test Doc",
                "metadata": {},
            }
            for cid in ids
        ]

    def test_rrf_fusion_ranking(self, retriever):
        """Fused results are ordered by descending score."""
        vector_results = self._make_results(["A", "B", "C"])
        keyword_results = self._make_results(["A", "C", "D"])

        fused = retriever._rrf_fusion(vector_results, keyword_results)

        assert len(fused) > 0
        scores = [item["score"] for item in fused]
        assert scores == sorted(scores, reverse=True)
        # "A" is first in both input lists, so it must lead the fused list.
        assert fused[0]["chunk_id"] == "A"

    def test_rrf_fusion_common_item_gets_higher_score(self, retriever):
        """An item present in both result lists outscores single-source items."""
        vector_results = self._make_results(["shared", "only_vector"])
        keyword_results = self._make_results(["shared", "only_keyword"])

        fused = retriever._rrf_fusion(vector_results, keyword_results)
        fused_by_id = {item["chunk_id"]: item["score"] for item in fused}

        assert fused_by_id["shared"] > fused_by_id.get("only_vector", 0)
        assert fused_by_id["shared"] > fused_by_id.get("only_keyword", 0)

    def test_rrf_weights(self):
        """The more heavily weighted source contributes the larger fused score."""
        from app.services.knowledge.retriever import HybridRetriever
        from app.services.knowledge.embedder import MockEmbedder

        favor_vector = HybridRetriever(
            embedder=MockEmbedder(), vector_weight=0.9, keyword_weight=0.1
        )
        favor_keyword = HybridRetriever(
            embedder=MockEmbedder(), vector_weight=0.1, keyword_weight=0.9
        )

        # One item per source, each at the top of its list, so only the
        # weights differ between "V" and "K".
        vector_only = self._make_results(["V"])
        keyword_only = self._make_results(["K"])

        hv_by_id = {
            item["chunk_id"]: item["score"]
            for item in favor_vector._rrf_fusion(vector_only, keyword_only)
        }
        hk_by_id = {
            item["chunk_id"]: item["score"]
            for item in favor_keyword._rrf_fusion(vector_only, keyword_only)
        }

        assert hv_by_id["V"] > hv_by_id["K"]
        assert hk_by_id["K"] > hk_by_id["V"]

    def test_rrf_empty_results(self, retriever):
        """Fusing two empty lists yields an empty list."""
        assert retriever._rrf_fusion([], []) == []

    def test_rrf_one_empty(self, retriever):
        """When one source is empty, the other source's results pass through."""
        vector_results = self._make_results(["A", "B"])

        fused = retriever._rrf_fusion(vector_results, [])

        assert len(fused) == 2
        # With a single contributing source, fusion must keep its rank order.
        assert [item["chunk_id"] for item in fused] == ["A", "B"]
|