611 lines
21 KiB
Python
611 lines
21 KiB
Python
"""MultiSourceRAG 单元测试 - 多源混合检索
|
||
|
||
测试场景:
|
||
- 指定单个信息源 → 仅从该源检索
|
||
- 指定多个信息源 → 并行检索,结果融合排序
|
||
- 不指定信息源 → 从所有可用源检索
|
||
- 来源追溯 → 每个结果包含来源信息
|
||
- AE4: 指定"合规文档库"和"法务知识库" → 仅从这两个源检索
|
||
"""
|
||
|
||
import pytest
|
||
|
||
from agentkit.memory.embedder import MockEmbedder
|
||
from agentkit.memory.knowledge_base import Document, KnowledgeBase, QueryResult, SourceInfo
|
||
from agentkit.memory.local_rag import InMemoryLocalRAGService
|
||
from agentkit.memory.multi_source_retriever import MultiSourceRetriever
|
||
from agentkit.memory.retriever import MemoryRetriever
|
||
|
||
|
||
# ── Fixtures ──────────────────────────────────────────────
|
||
|
||
|
||
@pytest.fixture
def embedder():
    """Provide a MockEmbedder configured with dimension=128."""
    mock_embedder = MockEmbedder(dimension=128)
    return mock_embedder
|
||
|
||
|
||
@pytest.fixture
def local_rag(embedder):
    """In-memory RAG service used as the local compliance document store."""
    chunking = {"chunk_size": 500, "chunk_overlap": 50}
    return InMemoryLocalRAGService(embedder=embedder, **chunking)
|
||
|
||
|
||
@pytest.fixture
def legal_rag(embedder):
    """In-memory RAG service used as the legal knowledge base."""
    chunking = {"chunk_size": 500, "chunk_overlap": 50}
    return InMemoryLocalRAGService(embedder=embedder, **chunking)
|
||
|
||
|
||
@pytest.fixture
def tech_rag(embedder):
    """In-memory RAG service used as the technical document store."""
    chunking = {"chunk_size": 500, "chunk_overlap": 50}
    return InMemoryLocalRAGService(embedder=embedder, **chunking)
|
||
|
||
|
||
@pytest.fixture
def compliance_docs():
    """Two compliance documents: data protection and cross-border transfer."""
    data_protection = Document(
        doc_id="compliance-1",
        content="数据保护合规要求:所有用户数据必须加密存储,访问需经授权审批。",
        title="数据保护合规指南",
        source_id="compliance_data_protection",
        metadata={"source": "compliance_data_protection", "format": "text"},
    )
    cross_border = Document(
        doc_id="compliance-2",
        content="跨境数据传输需遵守 GDPR 和中国网络安全法的相关规定。",
        title="跨境数据传输合规",
        source_id="compliance_cross_border",
        metadata={"source": "compliance_cross_border", "format": "text"},
    )
    return [data_protection, cross_border]
|
||
|
||
|
||
@pytest.fixture
def legal_docs():
    """Two legal documents: contract review and labor law."""
    contract_review = Document(
        doc_id="legal-1",
        content="合同审查要点:注意违约责任条款、知识产权归属和保密义务。",
        title="合同审查指南",
        source_id="legal_contract_review",
        metadata={"source": "legal_contract_review", "format": "text"},
    )
    labor_law = Document(
        doc_id="legal-2",
        content="劳动法规定:员工加班需支付加班费,标准为平时工资的1.5倍至3倍。",
        title="劳动法要点",
        source_id="legal_labor_law",
        metadata={"source": "legal_labor_law", "format": "text"},
    )
    return [contract_review, labor_law]
|
||
|
||
|
||
@pytest.fixture
def tech_docs():
    """A single technical document describing the API gateway configuration."""
    api_gateway = Document(
        doc_id="tech-1",
        content="API 网关配置:限流策略为每分钟 1000 次请求,超时设置 30 秒。",
        title="API 网关配置手册",
        source_id="tech_api_gateway",
        metadata={"source": "tech_api_gateway", "format": "text"},
    )
    return [api_gateway]
|
||
|
||
|
||
# ── MultiSourceRetriever 核心测试 ─────────────────────────
|
||
|
||
|
||
class TestMultiSourceRetrieverBasic:
    """Basic source registration and listing behaviour of MultiSourceRetriever."""

    def test_register_source(self, local_rag, legal_rag):
        """Sources added via register_source() are reported by get_source_names()."""
        retriever = MultiSourceRetriever()
        retriever.register_source("local:合规文档", local_rag)
        retriever.register_source("法务知识库", legal_rag)

        names = retriever.get_source_names()
        assert "local:合规文档" in names
        assert "法务知识库" in names

    def test_register_source_via_constructor(self, local_rag, legal_rag):
        """Sources passed to the constructor are registered under their given names."""
        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )

        names = retriever.get_source_names()
        assert len(names) == 2
        # A bare count would also pass if the wrong names were stored,
        # so check the names themselves as well.
        assert "local:合规文档" in names
        assert "法务知识库" in names

    def test_unregister_source(self, local_rag, legal_rag):
        """Unregistering a known source returns True and removes only that source."""
        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )

        result = retriever.unregister_source("local:合规文档")
        assert result is True

        remaining = retriever.get_source_names()
        assert "local:合规文档" not in remaining
        # The other source must be left untouched.
        assert "法务知识库" in remaining

    def test_unregister_nonexistent_source(self, local_rag):
        """Unregistering an unknown name returns False and leaves existing sources alone."""
        retriever = MultiSourceRetriever(sources={"local:合规文档": local_rag})

        result = retriever.unregister_source("不存在")
        assert result is False
        assert "local:合规文档" in retriever.get_source_names()

    @pytest.mark.asyncio
    async def test_list_all_sources(self, local_rag, legal_rag, compliance_docs, legal_docs):
        """list_all_sources() is keyed by registered name and maps to SourceInfo values."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)

        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )

        sources = await retriever.list_all_sources()
        assert "local:合规文档" in sources
        assert "法务知识库" in sources
        # Only the values are inspected, so iterate the values view directly.
        for info in sources.values():
            assert isinstance(info, SourceInfo)
|
||
|
||
|
||
class TestMultiSourceRetrieverSearch:
    """Search behaviour of MultiSourceRetriever across one, several, or all sources."""

    @pytest.mark.asyncio
    async def test_search_single_source(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """Naming a single source restricts retrieval to that source."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)

        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )

        # Query the compliance store only.
        results = await retriever.search("合规", top_k=5, sources=["local:合规文档"])

        # Every hit must be attributed to the compliance store.
        for r in results:
            assert r.source_name == "local:合规文档"

    @pytest.mark.asyncio
    async def test_search_multiple_sources(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """Naming several sources yields results drawn only from those sources."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)

        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )

        results = await retriever.search(
            "合规 法务", top_k=10, sources=["local:合规文档", "法务知识库"]
        )

        # Results may only be attributed to the two requested sources.
        source_names = {r.source_name for r in results}
        assert source_names.issubset({"local:合规文档", "法务知识库"})

    @pytest.mark.asyncio
    async def test_search_all_sources_when_none_specified(
        self, local_rag, legal_rag, tech_rag, compliance_docs, legal_docs, tech_docs
    ):
        """Omitting `sources` searches the registered sources."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)
        await tech_rag.ingest(tech_docs)

        registered = {
            "local:合规文档": local_rag,
            "法务知识库": legal_rag,
            "技术文档库": tech_rag,
        }
        retriever = MultiSourceRetriever(sources=registered)

        results = await retriever.search("合规 法务 技术", top_k=10)

        source_names = {r.source_name for r in results}
        # At least one registered source must contribute a result.
        # NOTE(review): not every source is required to appear, because which
        # sources rank into the top_k depends on the embedder - confirm whether
        # a stricter check is wanted.
        assert len(source_names) >= 1
        # Results must never be attributed to a name that was not registered.
        assert source_names.issubset(set(registered))

    @pytest.mark.asyncio
    async def test_search_no_sources_registered(self):
        """With no registered sources, search returns an empty result."""
        retriever = MultiSourceRetriever()

        results = await retriever.search("anything", top_k=5)
        assert len(results) == 0

    @pytest.mark.asyncio
    async def test_search_nonexistent_source(self, local_rag, compliance_docs):
        """Requesting an unknown source is skipped and yields an empty result."""
        await local_rag.ingest(compliance_docs)

        retriever = MultiSourceRetriever(sources={"local:合规文档": local_rag})

        results = await retriever.search("合规", top_k=5, sources=["不存在的源"])
        assert len(results) == 0

    @pytest.mark.asyncio
    async def test_search_with_weights(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """A weight on a source does not lower that source's best score."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)

        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )

        # Baseline: search without any weights.
        results_no_weight = await retriever.search(
            "合规", top_k=10, sources=["local:合规文档", "法务知识库"]
        )

        # Weighted: boost the compliance store.
        results_with_weight = await retriever.search(
            "合规",
            top_k=10,
            sources=["local:合规文档", "法务知识库"],
            weights={"local:合规文档": 2.0},
        )

        compliance_scores_weighted = [
            r.score for r in results_with_weight if r.source_name == "local:合规文档"
        ]
        compliance_scores_unweighted = [
            r.score for r in results_no_weight if r.source_name == "local:合规文档"
        ]

        # NOTE(review): this comparison is skipped, and the test passes vacuously,
        # when either search returns no compliance hits - confirm whether hits
        # are guaranteed here.
        if compliance_scores_weighted and compliance_scores_unweighted:
            assert max(compliance_scores_weighted) >= max(compliance_scores_unweighted)
|
||
|
||
|
||
class TestMultiSourceRetrieverSourceTracing:
    """Source-tracing checks on MultiSourceRetriever results."""

    @pytest.mark.asyncio
    async def test_result_contains_source_info(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """Each search result carries a source id, a registered source name and a title."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)

        registered_names = ("local:合规文档", "法务知识库")
        retriever = MultiSourceRetriever(
            sources={registered_names[0]: local_rag, registered_names[1]: legal_rag}
        )

        hits = await retriever.search("合规", top_k=5)

        for hit in hits:
            # The source id must not be blank.
            assert hit.source_id != ""
            # The source name must be one of the registered names.
            assert hit.source_name in registered_names
            # The title must not be blank.
            assert hit.title != ""

    @pytest.mark.asyncio
    async def test_result_contains_document_title(
        self, local_rag, compliance_docs
    ):
        """Each search result exposes a non-empty title and doc_id."""
        await local_rag.ingest(compliance_docs)

        retriever = MultiSourceRetriever(sources={"local:合规文档": local_rag})

        hits = await retriever.search("数据保护", top_k=5)

        for hit in hits:
            assert hit.title != ""
            assert hit.doc_id != ""
|
||
|
||
|
||
class TestMultiSourceRetrieverDedup:
    """De-duplication checks on MultiSourceRetriever results."""

    @pytest.mark.asyncio
    async def test_deduplicate_identical_content(
        self, local_rag, legal_rag, embedder
    ):
        """Identical content ingested into two sources appears only once in the results."""
        # Ingest the very same document into both sources.
        duplicate = Document(
            doc_id="same-doc",
            content="这是一段完全相同的内容用于测试去重功能。",
            title="重复文档",
            source_id="same_source",
            metadata={"source": "same_source", "format": "text"},
        )
        await local_rag.ingest([duplicate])
        await legal_rag.ingest([duplicate])

        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )

        results = await retriever.search("去重", top_k=10)

        # Tally how often each distinct content string occurs in the results.
        contents = [r.content for r in results]
        content_counts: dict[str, int] = {
            text: contents.count(text) for text in contents
        }

        # Each distinct content must occur exactly once.
        for content, count in content_counts.items():
            assert count == 1, f"内容 '{content[:30]}...' 出现了 {count} 次,应去重为 1 次"
|
||
|
||
|
||
# ── AE4: 合规文档库 + 法务知识库指定检索 ──────────────────
|
||
|
||
|
||
class TestAE4ComplianceAndLegalSearch:
    """AE4 scenario: requesting "合规文档库" and "法务知识库" searches only those two sources."""

    @pytest.mark.asyncio
    async def test_search_compliance_and_legal_only(
        self, local_rag, legal_rag, tech_rag, compliance_docs, legal_docs, tech_docs
    ):
        """With compliance and legal requested, no result is attributed to the tech store."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)
        await tech_rag.ingest(tech_docs)

        all_sources = {
            "合规文档库": local_rag,
            "法务知识库": legal_rag,
            "技术文档库": tech_rag,
        }
        retriever = MultiSourceRetriever(sources=all_sources)

        requested = ["合规文档库", "法务知识库"]
        hits = await retriever.search("合规 法务", top_k=10, sources=requested)

        # No hit may be attributed to the technical document store.
        for hit in hits:
            assert hit.source_name != "技术文档库"
            assert hit.source_name in ("合规文档库", "法务知识库")

    @pytest.mark.asyncio
    async def test_search_compliance_and_legal_results_merged(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """Results from the compliance and legal sources are merged in descending score order."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)

        retriever = MultiSourceRetriever(
            sources={"合规文档库": local_rag, "法务知识库": legal_rag}
        )

        requested = ["合规文档库", "法务知识库"]
        hits = await retriever.search("合规 法务", top_k=10, sources=requested)

        # At least one source must have contributed a result.
        contributing = {hit.source_name for hit in hits}
        assert len(contributing) >= 1

        # Adjacent results must be in non-increasing score order.
        for earlier, later in zip(hits, hits[1:]):
            assert earlier.score >= later.score
|
||
|
||
@pytest.mark.asyncio
|
||
async def test_search_compliance_and_legal_results_merged(
|
||
self, local_rag, legal_rag, compliance_docs, legal_docs
|
||
):
|
||
"""合规和法务源的结果应合并排序"""
|
||
await local_rag.ingest(compliance_docs)
|
||
await legal_rag.ingest(legal_docs)
|
||
|
||
retriever = MultiSourceRetriever(
|
||
sources={"合规文档库": local_rag, "法务知识库": legal_rag}
|
||
)
|
||
|
||
results = await retriever.search(
|
||
"合规 法务", top_k=10, sources=["合规文档库", "法务知识库"]
|
||
)
|
||
|
||
# 应有来自两个源的结果
|
||
source_names = {r.source_name for r in results}
|
||
assert len(source_names) >= 1
|
||
|
||
# 结果应按 score 降序排列
|
||
for i in range(len(results) - 1):
|
||
assert results[i].score >= results[i + 1].score
|
||
|
||
|
||
# ── MemoryRetriever 集成测试 ──────────────────────────────
|
||
|
||
|
||
class TestMemoryRetrieverIntegration:
    """Integration of MemoryRetriever with MultiSourceRetriever."""

    @pytest.mark.asyncio
    async def test_retrieve_with_sources_parameter(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """retrieve(sources=...) returns items exposing key, value, score and metadata."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)

        knowledge = {"local:合规文档": local_rag, "法务知识库": legal_rag}
        retriever = MemoryRetriever(knowledge_sources=knowledge)

        memory_items = await retriever.retrieve(
            "合规", top_k=5, sources=["local:合规文档"]
        )

        # Each returned item should look like a memory item.
        for entry in memory_items:
            for attribute in ("key", "value", "score", "metadata"):
                assert hasattr(entry, attribute)

    @pytest.mark.asyncio
    async def test_retrieve_without_sources_keeps_current_behavior(
        self, local_rag, compliance_docs
    ):
        """Without `sources`, retrieve() still returns a list."""
        await local_rag.ingest(compliance_docs)

        knowledge = {"local:合规文档": local_rag}
        retriever = MemoryRetriever(knowledge_sources=knowledge)

        # No `sources` argument: per the original author's note this takes the
        # three-layer memory path, which holds nothing in this test.
        memory_items = await retriever.retrieve("合规", top_k=5)
        assert isinstance(memory_items, list)

    @pytest.mark.asyncio
    async def test_retrieve_from_sources_with_source_tracing(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """Items retrieved from named sources carry source-tracing metadata."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)

        knowledge = {"local:合规文档": local_rag, "法务知识库": legal_rag}
        retriever = MemoryRetriever(knowledge_sources=knowledge)

        memory_items = await retriever.retrieve(
            "合规", top_k=5, sources=["local:合规文档", "法务知识库"]
        )

        for entry in memory_items:
            meta = entry.metadata
            assert meta.get("source") == "rag"
            assert "source_name" in meta
            assert "document_title" in meta

    @pytest.mark.asyncio
    async def test_multi_source_retriever_property(
        self, local_rag, compliance_docs
    ):
        """The multi_source_retriever property exposes the underlying MultiSourceRetriever."""
        knowledge = {"local:合规文档": local_rag}
        retriever = MemoryRetriever(knowledge_sources=knowledge)

        inner = retriever.multi_source_retriever
        assert isinstance(inner, MultiSourceRetriever)
        assert "local:合规文档" in inner.get_source_names()

    @pytest.mark.asyncio
    async def test_register_source_via_property(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """A source registered through the property can be searched via retrieve()."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)

        knowledge = {"local:合规文档": local_rag}
        retriever = MemoryRetriever(knowledge_sources=knowledge)

        # Register the legal knowledge base after construction.
        inner = retriever.multi_source_retriever
        inner.register_source("法务知识库", legal_rag)

        # The newly registered source can now be queried by name.
        memory_items = await retriever.retrieve(
            "合同", top_k=5, sources=["法务知识库"]
        )

        for entry in memory_items:
            assert entry.metadata.get("source_name") == "法务知识库"

    @pytest.mark.asyncio
    async def test_retrieve_with_source_weights(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """retrieve() accepts a source_weights argument and returns a list."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)

        knowledge = {"local:合规文档": local_rag, "法务知识库": legal_rag}
        retriever = MemoryRetriever(knowledge_sources=knowledge)

        memory_items = await retriever.retrieve(
            "合规",
            top_k=5,
            sources=["local:合规文档", "法务知识库"],
            source_weights={"local:合规文档": 1.5},
        )

        assert isinstance(memory_items, list)
|
||
|
||
|
||
# ── 边界情况测试 ──────────────────────────────────────────
|
||
|
||
|
||
class TestEdgeCases:
    """Edge-case behaviour of MultiSourceRetriever.search()."""

    @pytest.mark.asyncio
    async def test_search_with_empty_source_list(self, local_rag, compliance_docs):
        """sources=[] queries no source at all and returns an empty result."""
        await local_rag.ingest(compliance_docs)

        retriever = MultiSourceRetriever(sources={"local:合规文档": local_rag})

        results = await retriever.search("合规", top_k=5, sources=[])
        assert len(results) == 0

    @pytest.mark.asyncio
    async def test_search_top_k_limits_results(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """top_k caps the number of merged results."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)

        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )

        results = await retriever.search("合规", top_k=1)
        assert len(results) <= 1

    @pytest.mark.asyncio
    async def test_source_query_failure_graceful(
        self, local_rag, compliance_docs
    ):
        """A source whose query() raises must not break the search over the other sources."""
        await local_rag.ingest(compliance_docs)

        # Stand-in source whose query() always raises.
        class FailingSource:
            async def ingest(self, documents):
                return []

            async def query(self, text, top_k=5):
                raise ConnectionError("Service unavailable")

            # Parameter name kept as-is; presumably it mirrors the knowledge-base
            # interface - TODO confirm before renaming the builtin-shadowing `id`.
            async def delete_by_id(self, id):
                return False

            async def list_sources(self):
                return [SourceInfo(source_id="failing", source_name="Failing", source_type="mock")]

            async def health_check(self):
                return False

        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "failing_source": FailingSource()}
        )

        # The search must not raise even though one source fails.
        results = await retriever.search("合规", top_k=5)
        assert isinstance(results, list)
        # The failing source raised on query(), so nothing may be attributed to it.
        for r in results:
            assert r.source_name != "failing_source"

    @pytest.mark.asyncio
    async def test_search_results_sorted_by_score(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """Merged results are ordered by descending score."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)

        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )

        results = await retriever.search("合规 法务", top_k=10)

        # Compare each result with its successor; no pairs exist for 0 or 1 results.
        for earlier, later in zip(results, results[1:]):
            assert earlier.score >= later.score
|