523 lines
17 KiB
Python
523 lines
17 KiB
Python
"""Unit tests for LocalRAGService, the local-document RAG service.

The tests run against InMemoryLocalRAGService, so no pgvector dependency is needed.
The chunking strategies (TextChunker / StructuralChunker) are covered as well.
"""
|
||
|
||
import pytest
|
||
|
||
from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker
|
||
from agentkit.memory.document_loader import Document as LoaderDocument
|
||
from agentkit.memory.embedder import MockEmbedder
|
||
from agentkit.memory.knowledge_base import Document, KnowledgeBase, QueryResult, SourceInfo
|
||
from agentkit.memory.local_rag import InMemoryLocalRAGService
|
||
|
||
|
||
# ── Fixtures ──────────────────────────────────────────────
|
||
|
||
|
||
@pytest.fixture
def embedder():
    """Provide a MockEmbedder configured with dimension=128."""
    mock_embedder = MockEmbedder(dimension=128)
    return mock_embedder
|
||
|
||
|
||
@pytest.fixture
def rag_service(embedder):
    """Provide an InMemoryLocalRAGService built on the ``embedder`` fixture.

    The service is created with chunk_size=500 and chunk_overlap=50.
    """
    service = InMemoryLocalRAGService(
        embedder=embedder,
        chunk_size=500,
        chunk_overlap=50,
    )
    return service
|
||
|
||
|
||
@pytest.fixture
def sample_documents():
    """Return two plain-text test documents in knowledge_base.Document form.

    The documents carry the ids "doc-1" and "doc-2" and are returned in that order.
    """
    python_doc = Document(
        doc_id="doc-1",
        content="Python 是一种通用编程语言。它支持多种编程范式,包括面向对象、命令式、函数式和过程式编程。Python 的设计哲学强调代码的可读性和简洁性。",
        title="Python 入门指南",
        source_id="python_intro.txt",
        metadata={"source": "python_intro.txt", "format": "text"},
    )
    ml_doc = Document(
        doc_id="doc-2",
        content="机器学习是人工智能的一个分支,它使计算机系统能够从数据中学习和改进。常见的机器学习算法包括线性回归、决策树、支持向量机和神经网络。",
        title="机器学习基础",
        source_id="ml_basics.txt",
        metadata={"source": "ml_basics.txt", "format": "text"},
    )
    return [python_doc, ml_doc]
|
||
|
||
|
||
@pytest.fixture
def markdown_document():
    """Return a single Markdown test document with nested headings.

    The content has one H1, three H2 sections and three H3 subsections, and
    the metadata marks the format as "markdown".
    """
    # Kept at column 0 so every heading starts at the beginning of its line.
    api_markdown = """# API 文档

## 认证

所有 API 请求需要 Bearer Token 认证。请在请求头中添加 Authorization 字段。

## 用户接口

### 获取用户信息

GET /api/users/{id}

返回指定用户的详细信息。

### 创建用户

POST /api/users

创建一个新用户。

## 数据接口

### 查询数据

POST /api/data/query

根据条件查询数据。
"""
    return Document(
        doc_id="doc-md-1",
        content=api_markdown,
        title="API 文档",
        source_id="api_doc.md",
        metadata={"source": "api_doc.md", "format": "markdown"},
    )
|
||
|
||
|
||
# ── TextChunker 测试 ──────────────────────────────────────
|
||
|
||
|
||
class TestTextChunker:
    """Unit tests for TextChunker."""

    def test_chunk_short_text(self):
        """Text shorter than chunk_size yields one chunk with source metadata."""
        pieces = TextChunker(chunk_size=1000, chunk_overlap=100).chunk(
            "Short text", source_doc_id="doc-1"
        )

        assert len(pieces) == 1
        only = pieces[0]
        assert only.content == "Short text"
        assert only.metadata["source_doc"] == "doc-1"
        assert only.metadata["position"] == 0

    def test_chunk_empty_text(self):
        """An empty string produces no chunks."""
        splitter = TextChunker(chunk_size=1000, chunk_overlap=100)
        assert len(splitter.chunk("", source_doc_id="doc-1")) == 0

    def test_chunk_whitespace_only(self):
        """Whitespace-only input produces no chunks."""
        splitter = TextChunker(chunk_size=1000, chunk_overlap=100)
        assert len(splitter.chunk(" \n\n \t ", source_doc_id="doc-1")) == 0

    def test_chunk_long_text(self):
        """Text longer than chunk_size is split into at least two chunks."""
        splitter = TextChunker(chunk_size=100, chunk_overlap=20)
        pieces = splitter.chunk("A" * 300, source_doc_id="doc-1")

        assert len(pieces) >= 2
        # The bound is deliberately looser than chunk_size (100) to leave
        # headroom for sentence-boundary adjustments.
        assert all(len(piece.content) <= 150 for piece in pieces)

    def test_chunk_preserves_metadata(self):
        """Caller-supplied metadata is kept alongside the source_doc entry."""
        extra = {"format": "pdf", "page_count": 5}
        pieces = TextChunker(chunk_size=1000, chunk_overlap=100).chunk(
            "Some content",
            source_doc_id="doc-1",
            metadata=extra,
        )

        assert len(pieces) == 1
        meta = pieces[0].metadata
        assert meta["format"] == "pdf"
        assert meta["page_count"] == 5
        assert meta["source_doc"] == "doc-1"

    def test_chunk_with_multiple_paragraphs(self):
        """Multi-paragraph text yields at least one chunk, none of them empty."""
        splitter = TextChunker(chunk_size=200, chunk_overlap=20, separator="\n\n")
        paragraphs = "第一段内容,包含一些文字。\n\n第二段内容,也有一些文字。\n\n第三段内容,同样有文字。"
        pieces = splitter.chunk(paragraphs, source_doc_id="doc-1")

        assert len(pieces) >= 1
        assert all(len(piece.content) > 0 for piece in pieces)

    def test_invalid_overlap(self):
        """An overlap equal to chunk_size is rejected with ValueError."""
        with pytest.raises(ValueError):
            TextChunker(chunk_size=100, chunk_overlap=100)

    def test_chunk_with_separator(self):
        """A custom paragraph separator yields at least one non-empty chunk."""
        splitter = TextChunker(chunk_size=200, chunk_overlap=20, separator="\n\n")
        paragraphs = "第一段内容\n\n第二段内容\n\n第三段内容"
        pieces = splitter.chunk(paragraphs, source_doc_id="doc-1")

        assert len(pieces) >= 1
        assert all(len(piece.content) > 0 for piece in pieces)
|
||
|
||
|
||
class TestStructuralChunker:
    """Unit tests for StructuralChunker."""

    def test_chunk_markdown_by_headings(self):
        """Each H2 section shows up as a chunk tagged with its heading text."""
        markdown = (
            "# Title\n\n"
            "## Section A\n\n"
            "Content for section A.\n\n"
            "## Section B\n\n"
            "Content for section B.\n\n"
            "## Section C\n\n"
            "Content for section C."
        )
        splitter = StructuralChunker(chunk_size=1000, chunk_overlap=50)
        pieces = splitter.chunk(markdown, source_doc_id="doc-1")

        assert len(pieces) >= 3
        # Every section heading must be present in the chunk metadata.
        found_headings = [piece.metadata.get("heading") for piece in pieces]
        for expected in ("Section A", "Section B", "Section C"):
            assert expected in found_headings

    def test_chunk_markdown_no_headings(self):
        """Text without headings comes back as one unchanged chunk."""
        plain = "Just some text without any headings."
        pieces = StructuralChunker(chunk_size=1000, chunk_overlap=50).chunk(
            plain, source_doc_id="doc-1"
        )

        assert len(pieces) == 1
        assert pieces[0].content == "Just some text without any headings."

    def test_chunk_empty_text(self):
        """An empty string produces no chunks."""
        splitter = StructuralChunker(chunk_size=1000, chunk_overlap=50)
        assert len(splitter.chunk("", source_doc_id="doc-1")) == 0

    def test_chunk_large_section_falls_back_to_text_chunker(self):
        """A section larger than chunk_size is split and keeps its heading."""
        markdown = "# Large Section\n\n" + "A" * 300
        splitter = StructuralChunker(chunk_size=100, chunk_overlap=20)
        pieces = splitter.chunk(markdown, source_doc_id="doc-1")

        # The oversized section is expected to be split further.
        assert len(pieces) >= 2
        assert all(
            piece.metadata.get("heading") == "Large Section" for piece in pieces
        )

    def test_heading_levels(self):
        """With heading_levels=2 the document still yields at least two chunks."""
        markdown = (
            "# H1\n\n"
            "Content 1.\n\n"
            "## H2\n\n"
            "Content 2.\n\n"
            "### H3\n\n"
            "This should be part of H2 section since heading_levels=2.\n"
        )
        splitter = StructuralChunker(chunk_size=1000, heading_levels=2)
        pieces = splitter.chunk(markdown, source_doc_id="doc-1")

        # Intent: H3 must not start a section of its own. Only the lower
        # bound on the chunk count is actually verified here.
        assert len(pieces) >= 2
|
||
|
||
|
||
# ── Chunk 数据类测试 ──────────────────────────────────────
|
||
|
||
|
||
class TestChunk:
    """Tests for the Chunk data class."""

    def test_default_metadata(self):
        """A Chunk built without metadata gets default source_doc and position."""
        meta = Chunk(chunk_id="c1", content="test").metadata
        assert meta["source_doc"] == ""
        assert meta["position"] == 0

    def test_to_dict(self):
        """to_dict() exposes chunk_id, content and the nested metadata."""
        original = Chunk(
            chunk_id="c1",
            content="test content",
            metadata={"source_doc": "doc-1", "position": 0},
        )
        as_dict = original.to_dict()

        assert as_dict["chunk_id"] == "c1"
        assert as_dict["content"] == "test content"
        assert as_dict["metadata"]["source_doc"] == "doc-1"
|
||
|
||
|
||
# ── InMemoryLocalRAGService 测试 ──────────────────────────
|
||
|
||
|
||
class TestInMemoryLocalRAGService:
    """Unit tests for InMemoryLocalRAGService."""

    @pytest.mark.asyncio
    async def test_ingest_documents(self, rag_service, sample_documents):
        """Ingesting two documents returns both of their doc ids."""
        ids = await rag_service.ingest(sample_documents)

        assert len(ids) == 2
        assert "doc-1" in ids
        assert "doc-2" in ids

    @pytest.mark.asyncio
    async def test_query_after_ingest(self, rag_service, sample_documents):
        """A query after ingestion returns QueryResult items with related text."""
        await rag_service.ingest(sample_documents)

        results = await rag_service.query("编程语言", top_k=2)

        assert len(results) >= 1
        assert all(isinstance(r, QueryResult) for r in results)
        # At least one hit should contain related content.
        assert any("Python" in r.content or "编程" in r.content for r in results)

    @pytest.mark.asyncio
    async def test_query_returns_source_info(self, rag_service, sample_documents):
        """Every query hit carries a non-empty source_id and source_name."""
        await rag_service.ingest(sample_documents)

        results = await rag_service.query("机器学习", top_k=5)

        assert len(results) >= 1
        for r in results:
            assert r.source_id != ""
            assert r.source_name != ""

    @pytest.mark.asyncio
    async def test_query_no_results_when_empty(self, rag_service):
        """Querying a service with nothing ingested returns no results."""
        results = await rag_service.query("anything", top_k=5)
        assert len(results) == 0

    @pytest.mark.asyncio
    async def test_delete_by_id(self, rag_service, sample_documents):
        """Deleting an ingested document returns True and removes its hits."""
        await rag_service.ingest(sample_documents)

        deleted = await rag_service.delete_by_id("doc-1")
        assert deleted is True

        # After deletion no query hit may originate from the removed document.
        results = await rag_service.query("Python", top_k=5)
        assert all(r.source_id != "doc-1" for r in results)

    @pytest.mark.asyncio
    async def test_delete_nonexistent_id(self, rag_service):
        """Deleting an unknown id returns False."""
        deleted = await rag_service.delete_by_id("nonexistent")
        assert deleted is False

    @pytest.mark.asyncio
    async def test_list_sources(self, rag_service, sample_documents):
        """list_sources() reports one populated SourceInfo per ingested document."""
        await rag_service.ingest(sample_documents)

        sources = await rag_service.list_sources()

        assert len(sources) == 2
        source_ids = {s.source_id for s in sources}
        assert "doc-1" in source_ids
        assert "doc-2" in source_ids

        for s in sources:
            assert isinstance(s, SourceInfo)
            assert s.source_name != ""
            assert s.document_count > 0

    @pytest.mark.asyncio
    async def test_list_sources_empty(self, rag_service):
        """list_sources() is empty before anything has been ingested."""
        sources = await rag_service.list_sources()
        assert len(sources) == 0

    @pytest.mark.asyncio
    async def test_health_check(self, rag_service):
        """health_check() returns True for the in-memory service."""
        assert await rag_service.health_check() is True

    @pytest.mark.asyncio
    async def test_ingest_markdown_with_structural_chunking(self, rag_service, markdown_document):
        """A Markdown document is listed as one source of type "markdown"."""
        ids = await rag_service.ingest([markdown_document])

        assert len(ids) == 1
        sources = await rag_service.list_sources()
        assert len(sources) == 1
        assert sources[0].source_type == "markdown"

    @pytest.mark.asyncio
    async def test_query_markdown_by_section(self, rag_service, markdown_document):
        """Queries against an ingested Markdown document return QueryResult items."""
        await rag_service.ingest([markdown_document])

        results = await rag_service.query("认证", top_k=3)

        # MockEmbedder does not guarantee semantic relevance, so this short
        # query may legitimately return nothing (e.g. threshold filtering).
        # Instead of the former always-true ``len(results) >= 0`` check,
        # verify that whatever comes back has the right type; this also
        # holds for an empty result.
        assert all(isinstance(r, QueryResult) for r in results)

        # A query closer to the document text must retrieve something.
        results = await rag_service.query("API 文档 认证", top_k=3)
        assert len(results) >= 1

    @pytest.mark.asyncio
    async def test_ingest_empty_document(self, rag_service):
        """An empty document is registered as a source with zero chunks."""
        doc = Document(
            doc_id="empty-doc",
            content="",
            title="Empty",
            source_id="empty.txt",
            metadata={"source": "empty.txt", "format": "text"},
        )
        ids = await rag_service.ingest([doc])

        # No chunks are produced for empty content, but the doc_id is still
        # returned and the source is still listed.
        assert len(ids) == 1
        sources = await rag_service.list_sources()
        assert len(sources) == 1
        assert sources[0].document_count == 0

    @pytest.mark.asyncio
    async def test_ingest_large_document_chunking(self, embedder):
        """A document much larger than chunk_size is stored as several chunks."""
        rag = InMemoryLocalRAGService(embedder=embedder, chunk_size=200, chunk_overlap=20)

        large_content = "这是一段很长的文本。" * 200  # 10 characters x 200 = 2000
        doc = Document(
            doc_id="large-doc",
            content=large_content,
            title="Large Document",
            source_id="large.txt",
            metadata={"source": "large.txt", "format": "text"},
        )
        ids = await rag.ingest([doc])

        assert len(ids) == 1
        sources = await rag.list_sources()
        assert sources[0].document_count > 1  # must have been split into multiple chunks

    @pytest.mark.asyncio
    async def test_query_result_has_score(self, rag_service, sample_documents):
        """Every returned score lies within [0.0, 1.0]."""
        await rag_service.ingest(sample_documents)

        results = await rag_service.query("编程", top_k=5)

        # NOTE(review): this loop passes vacuously when the query returns
        # nothing; confirm whether a non-empty result can be required here.
        for r in results:
            assert 0.0 <= r.score <= 1.0

    @pytest.mark.asyncio
    async def test_ingest_loader_document(self, rag_service):
        """A document_loader.Document can be ingested and then queried."""
        loader_doc = LoaderDocument(
            doc_id="loader-doc-1",
            title="Test Loader Doc",
            content="This is content from document_loader.",
            metadata={"source": "test.txt", "format": "text"},
        )
        ids = await rag_service.ingest([loader_doc])

        assert len(ids) == 1
        results = await rag_service.query("content", top_k=3)
        assert len(results) >= 1

    @pytest.mark.asyncio
    async def test_multiple_ingest_same_doc_id(self, rag_service):
        """Ingesting two documents with the same doc_id keeps that id listed."""
        doc1 = Document(
            doc_id="same-id",
            content="First version content",
            title="Version 1",
            source_id="v1.txt",
            metadata={"source": "v1.txt", "format": "text"},
        )
        doc2 = Document(
            doc_id="same-id",
            content="Second version content with more text",
            title="Version 2",
            source_id="v2.txt",
            metadata={"source": "v2.txt", "format": "text"},
        )

        await rag_service.ingest([doc1])
        await rag_service.ingest([doc2])

        # The second ingest is expected to overwrite the first in the
        # in-memory implementation; only the presence of the id is checked.
        sources = await rag_service.list_sources()
        source_ids = [s.source_id for s in sources]
        assert "same-id" in source_ids
|
||
|
||
|
||
# ── KnowledgeBase 协议测试 ────────────────────────────────
|
||
|
||
|
||
class TestKnowledgeBaseProtocol:
    """Compatibility tests against the KnowledgeBase protocol."""

    @pytest.mark.asyncio
    async def test_inmemory_service_implements_protocol(self, rag_service):
        """InMemoryLocalRAGService passes an isinstance check against KnowledgeBase."""
        assert isinstance(rag_service, KnowledgeBase)

    @pytest.mark.asyncio
    async def test_protocol_methods_exist(self, rag_service):
        """All five protocol methods exist on the service and are callable."""
        protocol_methods = (
            "ingest",
            "query",
            "delete_by_id",
            "list_sources",
            "health_check",
        )

        # First pass: every attribute is present.
        for method_name in protocol_methods:
            assert hasattr(rag_service, method_name)

        # Second pass: every attribute is callable.
        for method_name in protocol_methods:
            assert callable(getattr(rag_service, method_name))
|
||
|
||
|
||
# ── QueryResult / SourceInfo 测试 ─────────────────────────
|
||
|
||
|
||
class TestQueryResult:
    """Tests for the QueryResult data class."""

    def test_creation(self):
        """The four required fields are stored as given."""
        hit = QueryResult(
            content="test content",
            source_id="doc-1",
            source_name="Test Doc",
            score=0.95,
        )

        assert hit.content == "test content"
        assert hit.source_id == "doc-1"
        assert hit.source_name == "Test Doc"
        assert hit.score == 0.95

    def test_with_optional_fields(self):
        """The optional metadata, doc_id and title fields are stored as given."""
        optional_fields = {
            "metadata": {"position": 0},
            "doc_id": "doc-1",
            "title": "Test Doc",
        }
        hit = QueryResult(
            content="test content",
            source_id="doc-1",
            source_name="Test Doc",
            score=0.95,
            **optional_fields,
        )

        assert hit.doc_id == "doc-1"
        assert hit.title == "Test Doc"
        assert hit.metadata["position"] == 0
|
||
|
||
|
||
class TestSourceInfo:
    """Tests for the SourceInfo data class."""

    def test_creation(self):
        """Every constructor argument, including last_updated, is stored as given."""
        from datetime import datetime, timezone

        now = datetime.now(timezone.utc)
        info = SourceInfo(
            source_id="doc-1",
            source_name="Test",
            source_type="local",
            document_count=5,
            last_updated=now,
        )

        assert info.source_id == "doc-1"
        assert info.source_name == "Test"
        assert info.source_type == "local"
        assert info.document_count == 5
        # The timestamp was previously passed in but never verified.
        assert info.last_updated == now
|