fischer-agentkit/tests/unit/memory/test_local_rag.py

523 lines
17 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

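"""Unit tests for LocalRAGService, the local-document RAG service.

The tests run against InMemoryLocalRAGService, so no pgvector dependency is required.
The chunking strategies (TextChunker / StructuralChunker) are tested here as well.
"""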
import pytest
from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker
from agentkit.memory.document_loader import Document as LoaderDocument
from agentkit.memory.embedder import MockEmbedder
from agentkit.memory.knowledge_base import Document, KnowledgeBase, QueryResult, SourceInfo
from agentkit.memory.local_rag import InMemoryLocalRAGService
# ── Fixtures ──────────────────────────────────────────────
@pytest.fixture
def embedder():
    """Provide a MockEmbedder constructed with ``dimension=128``."""
    mock_embedder = MockEmbedder(dimension=128)
    return mock_embedder
@pytest.fixture
def rag_service(embedder):
    """Provide an InMemoryLocalRAGService built on the ``embedder`` fixture.

    The service is created with ``chunk_size=500`` and ``chunk_overlap=50``.
    """
    chunking_options = {"chunk_size": 500, "chunk_overlap": 50}
    return InMemoryLocalRAGService(embedder=embedder, **chunking_options)
@pytest.fixture
def sample_documents():
    """Provide two plain-text test documents as knowledge_base.Document objects."""
    python_doc = Document(
        doc_id="doc-1",
        content="Python 是一种通用编程语言。它支持多种编程范式包括面向对象、命令式、函数式和过程式编程。Python 的设计哲学强调代码的可读性和简洁性。",
        title="Python 入门指南",
        source_id="python_intro.txt",
        metadata={"source": "python_intro.txt", "format": "text"},
    )
    ml_doc = Document(
        doc_id="doc-2",
        content="机器学习是人工智能的一个分支,它使计算机系统能够从数据中学习和改进。常见的机器学习算法包括线性回归、决策树、支持向量机和神经网络。",
        title="机器学习基础",
        source_id="ml_basics.txt",
        metadata={"source": "ml_basics.txt", "format": "text"},
    )
    return [python_doc, ml_doc]
@pytest.fixture
def markdown_document():
    """Provide one Markdown document (headings from ``#`` to ``###``) as a Document."""
    # The literal's continuation lines stay at column 0 so the heading markers
    # remain at the start of each line of the document content.
    markdown_body = """# API 文档
## 认证
所有 API 请求需要 Bearer Token 认证。请在请求头中添加 Authorization 字段。
## 用户接口
### 获取用户信息
GET /api/users/{id}
返回指定用户的详细信息。
### 创建用户
POST /api/users
创建一个新用户。
## 数据接口
### 查询数据
POST /api/data/query
根据条件查询数据。
"""
    return Document(
        doc_id="doc-md-1",
        content=markdown_body,
        title="API 文档",
        source_id="api_doc.md",
        metadata={"source": "api_doc.md", "format": "markdown"},
    )
# ── TextChunker 测试 ──────────────────────────────────────
class TestTextChunker:
    """Unit tests for TextChunker."""

    def test_chunk_short_text(self):
        """Text shorter than chunk_size yields one chunk with source/position metadata."""
        splitter = TextChunker(chunk_size=1000, chunk_overlap=100)
        pieces = splitter.chunk("Short text", source_doc_id="doc-1")
        assert len(pieces) == 1
        only = pieces[0]
        assert only.content == "Short text"
        assert only.metadata["source_doc"] == "doc-1"
        assert only.metadata["position"] == 0

    def test_chunk_empty_text(self):
        """An empty string yields no chunks."""
        splitter = TextChunker(chunk_size=1000, chunk_overlap=100)
        pieces = splitter.chunk("", source_doc_id="doc-1")
        assert len(pieces) == 0

    def test_chunk_whitespace_only(self):
        """Whitespace-only input yields no chunks."""
        splitter = TextChunker(chunk_size=1000, chunk_overlap=100)
        pieces = splitter.chunk(" \n\n \t ", source_doc_id="doc-1")
        assert len(pieces) == 0

    def test_chunk_long_text(self):
        """Text longer than chunk_size is split into several bounded chunks."""
        splitter = TextChunker(chunk_size=100, chunk_overlap=20)
        long_text = "A" * 300
        pieces = splitter.chunk(long_text, source_doc_id="doc-1")
        assert len(pieces) >= 2
        # Chunks should stay near chunk_size; some slack is tolerated for
        # sentence-boundary handling, hence the 150-character ceiling.
        assert all(len(piece.content) <= 150 for piece in pieces)

    def test_chunk_preserves_metadata(self):
        """Caller-supplied metadata is carried onto the chunk alongside source_doc."""
        splitter = TextChunker(chunk_size=1000, chunk_overlap=100)
        extra = {"format": "pdf", "page_count": 5}
        pieces = splitter.chunk("Some content", source_doc_id="doc-1", metadata=extra)
        assert len(pieces) == 1
        chunk_meta = pieces[0].metadata
        assert chunk_meta["format"] == "pdf"
        assert chunk_meta["page_count"] == 5
        assert chunk_meta["source_doc"] == "doc-1"

    def test_chunk_with_multiple_paragraphs(self):
        """Paragraphs separated by blank lines produce only non-empty chunks."""
        splitter = TextChunker(chunk_size=200, chunk_overlap=20, separator="\n\n")
        paragraphs = "第一段内容,包含一些文字。\n\n第二段内容,也有一些文字。\n\n第三段内容,同样有文字。"
        pieces = splitter.chunk(paragraphs, source_doc_id="doc-1")
        assert len(pieces) >= 1
        assert all(len(piece.content) > 0 for piece in pieces)

    def test_invalid_overlap(self):
        """Constructing with chunk_overlap equal to chunk_size raises ValueError."""
        with pytest.raises(ValueError):
            TextChunker(chunk_size=100, chunk_overlap=100)

    def test_chunk_with_separator(self):
        """Splitting with an explicit separator produces only non-empty chunks."""
        splitter = TextChunker(chunk_size=200, chunk_overlap=20, separator="\n\n")
        paragraphs = "第一段内容\n\n第二段内容\n\n第三段内容"
        pieces = splitter.chunk(paragraphs, source_doc_id="doc-1")
        assert len(pieces) >= 1
        assert all(len(piece.content) > 0 for piece in pieces)
class TestStructuralChunker:
    """Unit tests for StructuralChunker."""

    def test_chunk_markdown_by_headings(self):
        """Each ``##`` section shows up as a chunk tagged with its heading text."""
        splitter = StructuralChunker(chunk_size=1000, chunk_overlap=50)
        markdown = """# Title
## Section A
Content for section A.
## Section B
Content for section B.
## Section C
Content for section C."""
        pieces = splitter.chunk(markdown, source_doc_id="doc-1")
        assert len(pieces) >= 3
        # Every section's heading should be recorded in the chunk metadata.
        seen_headings = [piece.metadata.get("heading") for piece in pieces]
        for expected in ("Section A", "Section B", "Section C"):
            assert expected in seen_headings

    def test_chunk_markdown_no_headings(self):
        """Text without headings comes back as a single, unchanged chunk."""
        splitter = StructuralChunker(chunk_size=1000, chunk_overlap=50)
        plain = "Just some text without any headings."
        pieces = splitter.chunk(plain, source_doc_id="doc-1")
        assert len(pieces) == 1
        assert pieces[0].content == "Just some text without any headings."

    def test_chunk_empty_text(self):
        """An empty string yields no chunks."""
        splitter = StructuralChunker(chunk_size=1000, chunk_overlap=50)
        pieces = splitter.chunk("", source_doc_id="doc-1")
        assert len(pieces) == 0

    def test_chunk_large_section_falls_back_to_text_chunker(self):
        """A section larger than chunk_size is split further, keeping its heading."""
        splitter = StructuralChunker(chunk_size=100, chunk_overlap=20)
        markdown = """# Large Section
""" + "A" * 300
        pieces = splitter.chunk(markdown, source_doc_id="doc-1")
        # The oversized section is expected to be sub-divided (by TextChunker,
        # going by the test name).
        assert len(pieces) >= 2
        assert all(
            piece.metadata.get("heading") == "Large Section" for piece in pieces
        )

    def test_heading_levels(self):
        """With heading_levels=2 the document still splits into at least two chunks."""
        splitter = StructuralChunker(chunk_size=1000, heading_levels=2)
        markdown = """# H1
Content 1.
## H2
Content 2.
### H3
This should be part of H2 section since heading_levels=2.
"""
        pieces = splitter.chunk(markdown, source_doc_id="doc-1")
        # Intent: H3 must not open a section of its own.
        # NOTE(review): the assertion below only checks the chunk count; it does
        # not verify where the H3 content ends up.
        assert len(pieces) >= 2
# ── Chunk 数据类测试 ──────────────────────────────────────
class TestChunk:
    """Tests for the Chunk data class."""

    def test_default_metadata(self):
        """A Chunk built without metadata exposes empty source_doc and position 0."""
        bare = Chunk(chunk_id="c1", content="test")
        defaults = bare.metadata
        assert defaults["source_doc"] == ""
        assert defaults["position"] == 0

    def test_to_dict(self):
        """to_dict() exposes chunk_id, content and the nested metadata."""
        supplied_meta = {"source_doc": "doc-1", "position": 0}
        chunk = Chunk(chunk_id="c1", content="test content", metadata=supplied_meta)
        as_dict = chunk.to_dict()
        assert as_dict["chunk_id"] == "c1"
        assert as_dict["content"] == "test content"
        assert as_dict["metadata"]["source_doc"] == "doc-1"
# ── InMemoryLocalRAGService 测试 ──────────────────────────
class TestInMemoryLocalRAGService:
    """Unit tests for InMemoryLocalRAGService."""

    @pytest.mark.asyncio
    async def test_ingest_documents(self, rag_service, sample_documents):
        """Ingesting two documents returns both of their doc_ids."""
        ids = await rag_service.ingest(sample_documents)
        assert len(ids) == 2
        assert "doc-1" in ids
        assert "doc-2" in ids

    @pytest.mark.asyncio
    async def test_query_after_ingest(self, rag_service, sample_documents):
        """A query after ingestion returns QueryResult objects with related content."""
        await rag_service.ingest(sample_documents)
        results = await rag_service.query("编程语言", top_k=2)
        assert len(results) >= 1
        assert all(isinstance(r, QueryResult) for r in results)
        # At least one hit should contain relevant content.
        assert any("Python" in r.content or "编程" in r.content for r in results)

    @pytest.mark.asyncio
    async def test_query_returns_source_info(self, rag_service, sample_documents):
        """Every query result carries a non-empty source_id and source_name."""
        await rag_service.ingest(sample_documents)
        results = await rag_service.query("机器学习", top_k=5)
        assert len(results) >= 1
        for r in results:
            assert r.source_id != ""
            assert r.source_name != ""

    @pytest.mark.asyncio
    async def test_query_no_results_when_empty(self, rag_service):
        """Querying a service with nothing ingested returns no results."""
        results = await rag_service.query("anything", top_k=5)
        assert len(results) == 0

    @pytest.mark.asyncio
    async def test_delete_by_id(self, rag_service, sample_documents):
        """Deleting an ingested document returns True and removes it from results."""
        await rag_service.ingest(sample_documents)
        deleted = await rag_service.delete_by_id("doc-1")
        assert deleted is True
        # After deletion, no result may originate from the removed document.
        results = await rag_service.query("Python", top_k=5)
        assert all(r.source_id != "doc-1" for r in results)

    @pytest.mark.asyncio
    async def test_delete_nonexistent_id(self, rag_service):
        """Deleting an unknown id returns False."""
        deleted = await rag_service.delete_by_id("nonexistent")
        assert deleted is False

    @pytest.mark.asyncio
    async def test_list_sources(self, rag_service, sample_documents):
        """list_sources() reports one populated SourceInfo per ingested document."""
        await rag_service.ingest(sample_documents)
        sources = await rag_service.list_sources()
        assert len(sources) == 2
        source_ids = {s.source_id for s in sources}
        assert "doc-1" in source_ids
        assert "doc-2" in source_ids
        for s in sources:
            assert isinstance(s, SourceInfo)
            assert s.source_name != ""
            assert s.document_count > 0

    @pytest.mark.asyncio
    async def test_list_sources_empty(self, rag_service):
        """list_sources() is empty before anything is ingested."""
        sources = await rag_service.list_sources()
        assert len(sources) == 0

    @pytest.mark.asyncio
    async def test_health_check(self, rag_service):
        """health_check() returns True for the in-memory service."""
        assert await rag_service.health_check() is True

    @pytest.mark.asyncio
    async def test_ingest_markdown_with_structural_chunking(self, rag_service, markdown_document):
        """A Markdown document is ingested and listed with source_type 'markdown'."""
        ids = await rag_service.ingest([markdown_document])
        assert len(ids) == 1
        sources = await rag_service.list_sources()
        assert len(sources) == 1
        assert sources[0].source_type == "markdown"

    @pytest.mark.asyncio
    async def test_query_markdown_by_section(self, rag_service, markdown_document):
        """Queries against an ingested Markdown document return well-formed results."""
        await rag_service.ingest([markdown_document])
        results = await rag_service.query("认证", top_k=3)
        # Per the original author's note, MockEmbedder is hash-based, so semantic
        # relevance is not guaranteed and this query may be filtered down to
        # nothing. The previous check (`len(results) >= 0`) could never fail, so
        # assert invariants that hold for any result set, including an empty one.
        assert len(results) <= 3
        assert all(isinstance(r, QueryResult) for r in results)
        # A query closer to the document's own wording must retrieve something.
        results = await rag_service.query("API 文档 认证", top_k=3)
        assert len(results) >= 1

    @pytest.mark.asyncio
    async def test_ingest_empty_document(self, rag_service):
        """An empty document is registered as a source but contributes no chunks."""
        doc = Document(
            doc_id="empty-doc",
            content="",
            title="Empty",
            source_id="empty.txt",
            metadata={"source": "empty.txt", "format": "text"},
        )
        ids = await rag_service.ingest([doc])
        # No chunks are generated for empty content, yet the doc_id is still returned.
        assert len(ids) == 1
        sources = await rag_service.list_sources()
        assert len(sources) == 1
        assert sources[0].document_count == 0

    @pytest.mark.asyncio
    async def test_ingest_large_document_chunking(self, embedder):
        """A large document is split into more than one chunk."""
        rag = InMemoryLocalRAGService(embedder=embedder, chunk_size=200, chunk_overlap=20)
        large_content = "这是一段很长的文本。" * 200  # ~2000 characters
        doc = Document(
            doc_id="large-doc",
            content=large_content,
            title="Large Document",
            source_id="large.txt",
            metadata={"source": "large.txt", "format": "text"},
        )
        ids = await rag.ingest([doc])
        assert len(ids) == 1
        sources = await rag.list_sources()
        assert sources[0].document_count > 1  # must have been split into several chunks

    @pytest.mark.asyncio
    async def test_query_result_has_score(self, rag_service, sample_documents):
        """Every result score lies in the closed interval [0.0, 1.0]."""
        await rag_service.ingest(sample_documents)
        results = await rag_service.query("编程", top_k=5)
        for r in results:
            assert 0.0 <= r.score <= 1.0

    @pytest.mark.asyncio
    async def test_ingest_loader_document(self, rag_service):
        """A document_loader.Document can be ingested and then queried."""
        loader_doc = LoaderDocument(
            doc_id="loader-doc-1",
            title="Test Loader Doc",
            content="This is content from document_loader.",
            metadata={"source": "test.txt", "format": "text"},
        )
        ids = await rag_service.ingest([loader_doc])
        assert len(ids) == 1
        results = await rag_service.query("content", top_k=3)
        assert len(results) >= 1

    @pytest.mark.asyncio
    async def test_multiple_ingest_same_doc_id(self, rag_service):
        """Ingesting two documents that share a doc_id keeps that id listed."""
        doc1 = Document(
            doc_id="same-id",
            content="First version content",
            title="Version 1",
            source_id="v1.txt",
            metadata={"source": "v1.txt", "format": "text"},
        )
        doc2 = Document(
            doc_id="same-id",
            content="Second version content with more text",
            title="Version 2",
            source_id="v2.txt",
            metadata={"source": "v2.txt", "format": "text"},
        )
        await rag_service.ingest([doc1])
        await rag_service.ingest([doc2])
        # The second ingest is expected to overwrite the first (per the original
        # note about the in-memory implementation); only presence is asserted.
        sources = await rag_service.list_sources()
        source_ids = [s.source_id for s in sources]
        assert "same-id" in source_ids
# ── KnowledgeBase 协议测试 ────────────────────────────────
class TestKnowledgeBaseProtocol:
    """Compatibility tests against the KnowledgeBase protocol."""

    @pytest.mark.asyncio
    async def test_inmemory_service_implements_protocol(self, rag_service):
        """InMemoryLocalRAGService passes the isinstance check for KnowledgeBase."""
        assert isinstance(rag_service, KnowledgeBase)

    @pytest.mark.asyncio
    async def test_protocol_methods_exist(self, rag_service):
        """Each method the tests rely on is present on the service and callable."""
        required = ("ingest", "query", "delete_by_id", "list_sources", "health_check")
        # First confirm every attribute exists ...
        for method_name in required:
            assert hasattr(rag_service, method_name)
        # ... then confirm each one can be called.
        for method_name in required:
            assert callable(getattr(rag_service, method_name))
# ── QueryResult / SourceInfo 测试 ─────────────────────────
class TestQueryResult:
    """Tests for the QueryResult data class."""

    def test_creation(self):
        """The four required fields are stored as given."""
        required = {
            "content": "test content",
            "source_id": "doc-1",
            "source_name": "Test Doc",
            "score": 0.95,
        }
        result = QueryResult(**required)
        for field_name, expected in required.items():
            assert getattr(result, field_name) == expected

    def test_with_optional_fields(self):
        """The optional metadata, doc_id and title fields are stored as given."""
        result = QueryResult(
            content="test content",
            source_id="doc-1",
            source_name="Test Doc",
            score=0.95,
            metadata={"position": 0},
            doc_id="doc-1",
            title="Test Doc",
        )
        assert result.doc_id == "doc-1"
        assert result.title == "Test Doc"
        assert result.metadata["position"] == 0
class TestSourceInfo:
    """Tests for the SourceInfo data class."""

    def test_creation(self):
        """All constructor fields, including last_updated, are stored as given."""
        from datetime import datetime, timezone

        now = datetime.now(timezone.utc)
        info = SourceInfo(
            source_id="doc-1",
            source_name="Test",
            source_type="local",
            document_count=5,
            last_updated=now,
        )
        assert info.source_id == "doc-1"
        assert info.source_name == "Test"
        assert info.source_type == "local"
        assert info.document_count == 5
        # The timestamp was passed in but previously never checked.
        assert info.last_updated == now