# fischer-agentkit/tests/unit/memory/test_multi_source_rag.py
#
# Note: code hosts may flag this file as containing "ambiguous Unicode
# characters". The non-ASCII text here (Chinese docstrings, comments, source
# names and document contents) is intentional and the warning can be ignored.

"""MultiSourceRAG 单元测试 - 多源混合检索
测试场景:
- 指定单个信息源 → 仅从该源检索
- 指定多个信息源 → 并行检索,结果融合排序
- 不指定信息源 → 从所有可用源检索
- 来源追溯 → 每个结果包含来源信息
- AE4: 指定"合规文档库""法务知识库" → 仅从这两个源检索
"""
from collections import Counter

import pytest

from agentkit.memory.embedder import MockEmbedder
from agentkit.memory.knowledge_base import Document, KnowledgeBase, QueryResult, SourceInfo
from agentkit.memory.local_rag import InMemoryLocalRAGService
from agentkit.memory.multi_source_retriever import MultiSourceRetriever
from agentkit.memory.retriever import MemoryRetriever
# ── Fixtures ──────────────────────────────────────────────
@pytest.fixture
def embedder():
    """Mock embedder producing 128-dimensional vectors.

    Requested by every RAG-service fixture below, so all stores in a single
    test share this one instance.
    """
    mock_embedder = MockEmbedder(dimension=128)
    return mock_embedder
def _make_rag(embedder):
    """Build an in-memory RAG service with the chunking settings shared by every store fixture.

    Centralising the constructor call keeps the three stores configured
    identically; previously the same arguments were repeated in each fixture.
    """
    return InMemoryLocalRAGService(embedder=embedder, chunk_size=500, chunk_overlap=50)


@pytest.fixture
def local_rag(embedder):
    """Local compliance document store."""
    return _make_rag(embedder)


@pytest.fixture
def legal_rag(embedder):
    """Legal knowledge base."""
    return _make_rag(embedder)


@pytest.fixture
def tech_rag(embedder):
    """Technical documentation store."""
    return _make_rag(embedder)
def _make_doc(doc_id, content, title, source_id):
    """Build a plain-text Document whose ``metadata["source"]`` mirrors ``source_id``.

    Deriving the metadata from ``source_id`` prevents the two values from
    drifting apart when a fixture document is edited. A fresh metadata dict is
    created on every call.
    """
    return Document(
        doc_id=doc_id,
        content=content,
        title=title,
        source_id=source_id,
        metadata={"source": source_id, "format": "text"},
    )


@pytest.fixture
def compliance_docs():
    """Compliance documents (two distinct source ids)."""
    return [
        _make_doc(
            "compliance-1",
            "数据保护合规要求:所有用户数据必须加密存储,访问需经授权审批。",
            "数据保护合规指南",
            "compliance_data_protection",
        ),
        _make_doc(
            "compliance-2",
            "跨境数据传输需遵守 GDPR 和中国网络安全法的相关规定。",
            "跨境数据传输合规",
            "compliance_cross_border",
        ),
    ]


@pytest.fixture
def legal_docs():
    """Legal documents (two distinct source ids)."""
    return [
        _make_doc(
            "legal-1",
            "合同审查要点:注意违约责任条款、知识产权归属和保密义务。",
            "合同审查指南",
            "legal_contract_review",
        ),
        _make_doc(
            "legal-2",
            "劳动法规定员工加班需支付加班费标准为平时工资的1.5倍至3倍。",
            "劳动法要点",
            "legal_labor_law",
        ),
    ]


@pytest.fixture
def tech_docs():
    """Technical documents (a single API-gateway manual)."""
    return [
        _make_doc(
            "tech-1",
            "API 网关配置:限流策略为每分钟 1000 次请求,超时设置 30 秒。",
            "API 网关配置手册",
            "tech_api_gateway",
        ),
    ]
# ── MultiSourceRetriever 核心测试 ─────────────────────────
class TestMultiSourceRetrieverBasic:
    """Source registration, unregistration and listing on MultiSourceRetriever."""

    def test_register_source(self, local_rag, legal_rag):
        """Sources registered one by one are reported by get_source_names()."""
        retriever = MultiSourceRetriever()
        retriever.register_source("local:合规文档", local_rag)
        retriever.register_source("法务知识库", legal_rag)
        names = retriever.get_source_names()
        assert "local:合规文档" in names
        assert "法务知识库" in names

    def test_register_source_via_constructor(self, local_rag, legal_rag):
        """Sources passed to the constructor are registered under their dict keys."""
        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )
        names = retriever.get_source_names()
        assert len(names) == 2
        # Check the actual names too, so a mis-keyed registration cannot pass
        # on the count alone.
        assert set(names) == {"local:合规文档", "法务知识库"}

    def test_unregister_source(self, local_rag, legal_rag):
        """Unregistering returns True and removes only the named source."""
        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )
        result = retriever.unregister_source("local:合规文档")
        assert result is True
        names = retriever.get_source_names()
        assert "local:合规文档" not in names
        # The other source must be left untouched.
        assert "法务知识库" in names

    def test_unregister_nonexistent_source(self, local_rag):
        """Unregistering an unknown name returns False."""
        retriever = MultiSourceRetriever(sources={"local:合规文档": local_rag})
        result = retriever.unregister_source("不存在")
        assert result is False

    @pytest.mark.asyncio
    async def test_list_all_sources(self, local_rag, legal_rag, compliance_docs, legal_docs):
        """list_all_sources() is keyed by registered name with SourceInfo values."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)
        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )
        sources = await retriever.list_all_sources()
        assert "local:合规文档" in sources
        assert "法务知识库" in sources
        # Only the values are inspected here, so iterate them directly.
        for info in sources.values():
            assert isinstance(info, SourceInfo)
class TestMultiSourceRetrieverSearch:
    """Retrieval behaviour of MultiSourceRetriever.search()."""

    @pytest.mark.asyncio
    async def test_search_single_source(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """A single requested source → results come only from that source."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)
        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )
        # Query the compliance store only.
        results = await retriever.search("合规", top_k=5, sources=["local:合规文档"])
        # Every hit must originate from the compliance store.
        for r in results:
            assert r.source_name == "local:合规文档"

    @pytest.mark.asyncio
    async def test_search_multiple_sources(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """Several requested sources → searched together, results merged."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)
        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )
        results = await retriever.search(
            "合规 法务", top_k=10, sources=["local:合规文档", "法务知识库"]
        )
        # Hits may come from either or both requested sources, never from elsewhere.
        source_names = {r.source_name for r in results}
        assert source_names.issubset({"local:合规文档", "法务知识库"})

    @pytest.mark.asyncio
    async def test_search_all_sources_when_none_specified(
        self, local_rag, legal_rag, tech_rag, compliance_docs, legal_docs, tech_docs
    ):
        """No sources given → every registered source is eligible."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)
        await tech_rag.ingest(tech_docs)
        retriever = MultiSourceRetriever(
            sources={
                "local:合规文档": local_rag,
                "法务知识库": legal_rag,
                "技术文档库": tech_rag,
            }
        )
        results = await retriever.search("合规 法务 技术", top_k=10)
        # All three sources are eligible, but this test only requires that at
        # least one of them contributed a result.
        source_names = {r.source_name for r in results}
        assert len(source_names) >= 1

    @pytest.mark.asyncio
    async def test_search_no_sources_registered(self):
        """With no registered sources the search returns nothing."""
        retriever = MultiSourceRetriever()
        results = await retriever.search("anything", top_k=5)
        assert len(results) == 0

    @pytest.mark.asyncio
    async def test_search_nonexistent_source(self, local_rag, compliance_docs):
        """An unknown requested source is skipped, yielding no results."""
        await local_rag.ingest(compliance_docs)
        retriever = MultiSourceRetriever(sources={"local:合规文档": local_rag})
        results = await retriever.search("合规", top_k=5, sources=["不存在的源"])
        assert len(results) == 0

    @pytest.mark.asyncio
    async def test_search_with_weights(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """Per-source weights adjust the scores of that source's hits."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)
        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )
        # Baseline: same query without weights.
        results_no_weight = await retriever.search(
            "合规", top_k=10, sources=["local:合规文档", "法务知识库"]
        )
        # Same query with the compliance store boosted by 2.0.
        results_with_weight = await retriever.search(
            "合规",
            top_k=10,
            sources=["local:合规文档", "法务知识库"],
            weights={"local:合规文档": 2.0},
        )
        compliance_scores_weighted = [
            r.score for r in results_with_weight if r.source_name == "local:合规文档"
        ]
        compliance_scores_unweighted = [
            r.score for r in results_no_weight if r.source_name == "local:合规文档"
        ]
        # The boosted run's best compliance score must not be lower than the
        # baseline's. The comparison is skipped when either run returned no
        # compliance hits.
        if compliance_scores_weighted and compliance_scores_unweighted:
            assert max(compliance_scores_weighted) >= max(compliance_scores_unweighted)
class TestMultiSourceRetrieverSourceTracing:
    """Every search hit must carry enough metadata to trace it back to its origin."""

    @pytest.mark.asyncio
    async def test_result_contains_source_info(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """Each hit exposes a non-empty source_id, a registered source_name and a title."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)
        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )
        hits = await retriever.search("合规", top_k=5)
        for hit in hits:
            assert hit.source_id != ""
            # source_name must be one of the names used at registration.
            assert hit.source_name in ("local:合规文档", "法务知识库")
            assert hit.title != ""

    @pytest.mark.asyncio
    async def test_result_contains_document_title(
        self, local_rag, compliance_docs
    ):
        """Hits from a single source still expose both a title and a doc_id."""
        await local_rag.ingest(compliance_docs)
        retriever = MultiSourceRetriever(sources={"local:合规文档": local_rag})
        hits = await retriever.search("数据保护", top_k=5)
        for hit in hits:
            assert hit.title != "" and hit.doc_id != ""
class TestMultiSourceRetrieverDedup:
    """De-duplication of identical content returned by several sources."""

    @pytest.mark.asyncio
    async def test_deduplicate_identical_content(self, local_rag, legal_rag):
        """Identical content returned by different sources appears only once.

        The intended behaviour is to keep the higher-scoring copy; this test
        asserts uniqueness only.
        """
        # Both sources ingest the very same document.
        same_doc = Document(
            doc_id="same-doc",
            content="这是一段完全相同的内容用于测试去重功能。",
            title="重复文档",
            source_id="same_source",
            metadata={"source": "same_source", "format": "text"},
        )
        await local_rag.ingest([same_doc])
        await legal_rag.ingest([same_doc])
        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )
        results = await retriever.search("去重", top_k=10)
        # After merging, every distinct content string must appear exactly once.
        content_counts = Counter(r.content for r in results)
        for content, count in content_counts.items():
            assert count == 1, f"内容 '{content[:30]}...' 出现了 {count} 次,应去重为 1 次"
# ── AE4: 合规文档库 + 法务知识库指定检索 ──────────────────
class TestAE4ComplianceAndLegalSearch:
    """AE4: requesting the "合规文档库" and "法务知识库" sources searches only those two."""

    @pytest.mark.asyncio
    async def test_search_compliance_and_legal_only(
        self, local_rag, legal_rag, tech_rag, compliance_docs, legal_docs, tech_docs
    ):
        """Requesting compliance and legal sources never returns tech-store hits."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)
        await tech_rag.ingest(tech_docs)
        retriever = MultiSourceRetriever(
            sources={
                "合规文档库": local_rag,
                "法务知识库": legal_rag,
                "技术文档库": tech_rag,
            }
        )
        results = await retriever.search(
            "合规 法务", top_k=10, sources=["合规文档库", "法务知识库"]
        )
        # Membership in the requested pair also rules out "技术文档库".
        for r in results:
            assert r.source_name in ("合规文档库", "法务知识库")

    @pytest.mark.asyncio
    async def test_search_compliance_and_legal_results_merged(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """Compliance and legal hits are merged into one score-ordered list."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)
        retriever = MultiSourceRetriever(
            sources={"合规文档库": local_rag, "法务知识库": legal_rag}
        )
        results = await retriever.search(
            "合规 法务", top_k=10, sources=["合规文档库", "法务知识库"]
        )
        # Both sources are eligible; the test requires at least one to answer.
        source_names = {r.source_name for r in results}
        assert len(source_names) >= 1
        # Merged results are ordered by descending score.
        for earlier, later in zip(results, results[1:]):
            assert earlier.score >= later.score
# ── MemoryRetriever 集成测试 ──────────────────────────────
class TestMemoryRetrieverIntegration:
    """Integration between MemoryRetriever and MultiSourceRetriever."""

    @pytest.mark.asyncio
    async def test_retrieve_with_sources_parameter(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """MemoryRetriever.retrieve(sources=...) delegates to MultiSourceRetriever."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)
        retriever = MemoryRetriever(
            knowledge_sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )
        items = await retriever.retrieve(
            "合规", top_k=5, sources=["local:合规文档"]
        )
        # Results must have the MemoryItem shape (key / value / score / metadata).
        for item in items:
            assert hasattr(item, "key")
            assert hasattr(item, "value")
            assert hasattr(item, "score")
            assert hasattr(item, "metadata")

    @pytest.mark.asyncio
    async def test_retrieve_without_sources_keeps_current_behavior(
        self, local_rag, compliance_docs
    ):
        """Without ``sources`` the original three-tier memory retrieval path is used."""
        await local_rag.ingest(compliance_docs)
        retriever = MemoryRetriever(
            knowledge_sources={"local:合规文档": local_rag}
        )
        # No ``sources`` → three-tier memory path. The memory tiers are empty
        # here, but only the return type is asserted.
        items = await retriever.retrieve("合规", top_k=5)
        assert isinstance(items, list)

    @pytest.mark.asyncio
    async def test_retrieve_from_sources_with_source_tracing(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """Multi-source retrieval through MemoryRetriever keeps source-tracing metadata."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)
        retriever = MemoryRetriever(
            knowledge_sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )
        items = await retriever.retrieve(
            "合规", top_k=5, sources=["local:合规文档", "法务知识库"]
        )
        for item in items:
            assert item.metadata.get("source") == "rag"
            assert "source_name" in item.metadata
            assert "document_title" in item.metadata

    @pytest.mark.asyncio
    async def test_multi_source_retriever_property(self, local_rag):
        """The multi_source_retriever property exposes the underlying retriever."""
        retriever = MemoryRetriever(
            knowledge_sources={"local:合规文档": local_rag}
        )
        ms_retriever = retriever.multi_source_retriever
        assert isinstance(ms_retriever, MultiSourceRetriever)
        # Knowledge sources passed to MemoryRetriever are registered on it.
        assert "local:合规文档" in ms_retriever.get_source_names()

    @pytest.mark.asyncio
    async def test_register_source_via_property(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """Sources can be registered dynamically through multi_source_retriever."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)
        retriever = MemoryRetriever(
            knowledge_sources={"local:合规文档": local_rag}
        )
        # Register the legal knowledge base after construction.
        retriever.multi_source_retriever.register_source("法务知识库", legal_rag)
        # It must now be searchable by name through MemoryRetriever.
        items = await retriever.retrieve(
            "合同", top_k=5, sources=["法务知识库"]
        )
        for item in items:
            assert item.metadata.get("source_name") == "法务知识库"

    @pytest.mark.asyncio
    async def test_retrieve_with_source_weights(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """MemoryRetriever accepts a ``source_weights`` argument."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)
        retriever = MemoryRetriever(
            knowledge_sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )
        items = await retriever.retrieve(
            "合规",
            top_k=5,
            sources=["local:合规文档", "法务知识库"],
            source_weights={"local:合规文档": 1.5},
        )
        assert isinstance(items, list)
# ── 边界情况测试 ──────────────────────────────────────────
class TestEdgeCases:
    """Boundary conditions for MultiSourceRetriever.search()."""

    @pytest.mark.asyncio
    async def test_search_with_empty_source_list(self, local_rag, compliance_docs):
        """An explicit empty ``sources`` list queries nothing and returns no results."""
        await local_rag.ingest(compliance_docs)
        retriever = MultiSourceRetriever(sources={"local:合规文档": local_rag})
        results = await retriever.search("合规", top_k=5, sources=[])
        assert len(results) == 0

    @pytest.mark.asyncio
    async def test_search_top_k_limits_results(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """``top_k`` caps the number of merged results."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)
        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )
        results = await retriever.search("合规", top_k=1)
        assert len(results) <= 1

    @pytest.mark.asyncio
    async def test_source_query_failure_graceful(
        self, local_rag, compliance_docs
    ):
        """If one source fails to query, the other sources' results are still returned."""
        await local_rag.ingest(compliance_docs)

        # Stub source whose query() always raises.
        class FailingSource:
            async def ingest(self, documents):
                return []

            async def query(self, text, top_k=5):
                raise ConnectionError("Service unavailable")

            async def delete_by_id(self, id):
                return False

            async def list_sources(self):
                return [SourceInfo(source_id="failing", source_name="Failing", source_type="mock")]

            async def health_check(self):
                return False

        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "failing_source": FailingSource()}
        )
        # Must not raise; any returned hit can only come from the healthy store,
        # since the failing source never produces results.
        results = await retriever.search("合规", top_k=5)
        assert isinstance(results, list)
        assert all(r.source_name == "local:合规文档" for r in results)

    @pytest.mark.asyncio
    async def test_search_results_sorted_by_score(
        self, local_rag, legal_rag, compliance_docs, legal_docs
    ):
        """Merged results are ordered by descending score."""
        await local_rag.ingest(compliance_docs)
        await legal_rag.ingest(legal_docs)
        retriever = MultiSourceRetriever(
            sources={"local:合规文档": local_rag, "法务知识库": legal_rag}
        )
        results = await retriever.search("合规 法务", top_k=10)
        for earlier, later in zip(results, results[1:]):
            assert earlier.score >= later.score