fischer-agentkit/tests/unit/rag_platform/test_document_processor.py

383 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U3+U7 tests — document processing pipeline.

Test scenarios:
1. parse + segment pipeline (mock file I/O)
2. preview returns a list of chunks
3. vectorize calls the embed model + vector store
4. a failure sets the error status
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agentkit.rag_platform.document_processor import DocumentProcessor
from agentkit.rag_platform.models import DocumentStatus
from agentkit.rag_platform.preview import PreviewResult, generate_preview
class TestParseAndSegment:
    """Tests for the parse and segment stages of DocumentProcessor."""

    # Import path of the splitter class that the segment tests swap for a mock.
    _SPLITTER_TARGET = "llama_index.core.node_parser.SentenceSplitter"

    def test_parse_extracts_text(self, tmp_path):
        """parse returns the text content of a file written to disk."""
        source = tmp_path / "test.txt"
        source.write_text("Hello, world!\nThis is a test document.", encoding="utf-8")

        parsed = DocumentProcessor().parse(str(source), "txt")

        assert "Hello, world!" in parsed
        assert "test document" in parsed

    def test_parse_applies_sanitization(self, tmp_path):
        """parse drops the <script> tag from a markdown file but keeps the plain words."""
        source = tmp_path / "test.md"
        source.write_text("Hello <script>alert(1)</script> world", encoding="utf-8")

        parsed = DocumentProcessor().parse(str(source), "md")

        assert "<script>" not in parsed
        assert "Hello" in parsed
        assert "world" in parsed

    def test_segment_splits_text(self):
        """segment hands the text to the splitter and returns its chunks."""
        processor = DocumentProcessor(chunk_size=100, chunk_overlap=10)
        long_text = "This is a sentence. " * 50  # long enough to trigger splitting

        with patch(self._SPLITTER_TARGET) as splitter_cls:
            splitter = splitter_cls.return_value
            splitter.split_text.return_value = ["chunk1", "chunk2", "chunk3"]
            pieces = processor.segment(long_text, chunk_size=100, chunk_overlap=10)

        assert pieces == ["chunk1", "chunk2", "chunk3"]
        splitter.split_text.assert_called_once_with(long_text)

    def test_segment_uses_custom_params(self):
        """segment forwards an explicit chunk_size / chunk_overlap to the splitter."""
        processor = DocumentProcessor()

        with patch(self._SPLITTER_TARGET) as splitter_cls:
            splitter_cls.return_value.split_text.return_value = ["chunk"]
            processor.segment("text", chunk_size=256, chunk_overlap=20)

        # The splitter class must have been constructed with the custom values.
        ctor_kwargs = splitter_cls.call_args.kwargs
        assert (ctor_kwargs["chunk_size"], ctor_kwargs["chunk_overlap"]) == (256, 20)

    def test_parse_file_not_found(self):
        """parse raises FileNotFoundError for a path that does not exist."""
        processor = DocumentProcessor()

        with pytest.raises(FileNotFoundError):
            processor.parse("/nonexistent/file.txt", "txt")
class TestPreview:
    """Tests for DocumentProcessor.preview."""

    # Import path of the splitter class that these tests swap for a mock.
    _SPLITTER_TARGET = "llama_index.core.node_parser.SentenceSplitter"

    def test_preview_returns_chunks(self, tmp_path):
        """preview returns one dict per chunk with index, content and metadata."""
        source = tmp_path / "test.txt"
        source.write_text("Test content for preview.", encoding="utf-8")
        processor = DocumentProcessor(chunk_size=100, chunk_overlap=10)

        with patch(self._SPLITTER_TARGET) as splitter_cls:
            splitter_cls.return_value.split_text.return_value = ["chunk1", "chunk2"]
            previewed = processor.preview(str(source), "txt")

        assert len(previewed) == 2
        for position, expected_text in enumerate(("chunk1", "chunk2")):
            assert previewed[position]["index"] == position
            assert previewed[position]["content"] == expected_text
        assert previewed[0]["metadata"] == {}

    def test_preview_uses_default_params(self, tmp_path):
        """preview falls back to the chunking values given to the constructor."""
        source = tmp_path / "test.txt"
        source.write_text("Test content.", encoding="utf-8")
        processor = DocumentProcessor(chunk_size=512, chunk_overlap=50)

        with patch(self._SPLITTER_TARGET) as splitter_cls:
            splitter_cls.return_value.split_text.return_value = ["chunk"]
            processor.preview(str(source), "txt")

        ctor_kwargs = splitter_cls.call_args.kwargs
        assert (ctor_kwargs["chunk_size"], ctor_kwargs["chunk_overlap"]) == (512, 50)
class TestGeneratePreview:
    """Tests for the generate_preview function."""

    # Import path of the splitter class that these tests swap for a mock.
    _SPLITTER_TARGET = "llama_index.core.node_parser.SentenceSplitter"

    def test_generate_preview_returns_preview_result(self, tmp_path):
        """generate_preview wraps the splitter's chunks in a PreviewResult."""
        source = tmp_path / "test.txt"
        source.write_text("Test content for preview.", encoding="utf-8")

        with patch(self._SPLITTER_TARGET) as splitter_cls:
            splitter_cls.return_value.split_text.return_value = ["chunk1", "chunk2", "chunk3"]
            result = generate_preview(str(source), "txt")

        assert isinstance(result, PreviewResult)
        assert result.total_chunks == 3
        assert len(result.chunks) == 3
        first, last = result.chunks[0], result.chunks[2]
        assert (first.index, first.content) == (0, "chunk1")
        assert (last.index, last.content) == (2, "chunk3")
        # No document exists yet at preview time, so the id is expected to be empty.
        assert result.document_id == ""

    def test_generate_preview_with_custom_params(self, tmp_path):
        """generate_preview forwards explicit chunking parameters to the splitter."""
        source = tmp_path / "test.txt"
        source.write_text("Test content.", encoding="utf-8")

        with patch(self._SPLITTER_TARGET) as splitter_cls:
            splitter_cls.return_value.split_text.return_value = ["chunk"]
            generate_preview(str(source), "txt", chunk_size=256, chunk_overlap=20)

        ctor_kwargs = splitter_cls.call_args.kwargs
        assert (ctor_kwargs["chunk_size"], ctor_kwargs["chunk_overlap"]) == (256, 20)
class TestVectorize:
    """Tests for DocumentProcessor.vectorize.

    NOTE(review): the async tests carry no explicit asyncio marker; presumably
    the project runs pytest-asyncio in auto mode — confirm against the config.
    """

    # Import path of the node class that these tests swap for a mock.
    _TEXT_NODE_TARGET = "llama_index.core.schema.TextNode"

    def _make_mock_embed_model(self):
        """Build an embed-model double whose aget_text_embedding awaits to 1536 floats."""
        return MagicMock(aget_text_embedding=AsyncMock(return_value=[0.1] * 1536))

    def _make_mock_vector_store(self):
        """Build a vector-store double exposing an awaitable async_add."""
        return MagicMock(async_add=AsyncMock())

    async def test_vectorize_calls_embed_model(self):
        """vectorize awaits the embed model once per chunk."""
        processor = DocumentProcessor()
        embed_model = self._make_mock_embed_model()
        vector_store = self._make_mock_vector_store()
        texts = ["chunk1", "chunk2", "chunk3"]
        fake_nodes = [MagicMock() for _ in texts]

        with patch(self._TEXT_NODE_TARGET, side_effect=fake_nodes):
            await processor.vectorize(texts, "kb-1", "doc-1", vector_store, embed_model)

        # One embedding request per chunk.
        assert embed_model.aget_text_embedding.await_count == 3

    async def test_vectorize_calls_vector_store(self):
        """vectorize awaits the vector store's async_add once with all nodes."""
        processor = DocumentProcessor()
        embed_model = self._make_mock_embed_model()
        vector_store = self._make_mock_vector_store()
        texts = ["chunk1", "chunk2"]
        fake_nodes = [MagicMock() for _ in texts]

        with patch(self._TEXT_NODE_TARGET, side_effect=fake_nodes):
            await processor.vectorize(texts, "kb-1", "doc-1", vector_store, embed_model)

        vector_store.async_add.assert_awaited_once()
        # The first positional argument must hold one node per chunk.
        stored_nodes = vector_store.async_add.call_args.args[0]
        assert len(stored_nodes) == 2

    async def test_vectorize_accepts_dict_chunks(self):
        """vectorize also accepts chunks given as dicts."""
        processor = DocumentProcessor()
        embed_model = self._make_mock_embed_model()
        vector_store = self._make_mock_vector_store()
        dict_chunks = [
            {"index": position, "content": text, "metadata": {}}
            for position, text in enumerate(("chunk1", "chunk2"))
        ]
        fake_nodes = [MagicMock() for _ in dict_chunks]

        with patch(self._TEXT_NODE_TARGET, side_effect=fake_nodes):
            await processor.vectorize(dict_chunks, "kb-1", "doc-1", vector_store, embed_model)

        assert embed_model.aget_text_embedding.await_count == 2

    async def test_vectorize_failure_propagates(self):
        """vectorize lets an embed-model error reach the caller."""
        processor = DocumentProcessor()
        embed_model = MagicMock()
        embed_model.aget_text_embedding = AsyncMock(side_effect=RuntimeError("embed failed"))
        vector_store = self._make_mock_vector_store()
        texts = ["chunk1", "chunk2"]

        with patch(self._TEXT_NODE_TARGET), pytest.raises(RuntimeError, match="embed failed"):
            await processor.vectorize(texts, "kb-1", "doc-1", vector_store, embed_model)

    async def test_vectorize_sets_metadata(self):
        """vectorize builds each node with kb_id, document_id and chunk_index metadata."""
        processor = DocumentProcessor()
        embed_model = self._make_mock_embed_model()
        vector_store = self._make_mock_vector_store()
        texts = ["chunk1", "chunk2"]
        fake_nodes = [MagicMock() for _ in texts]

        with patch(self._TEXT_NODE_TARGET, side_effect=fake_nodes) as node_cls:
            await processor.vectorize(texts, "kb-1", "doc-1", vector_store, embed_model)

        # Inspect the keyword arguments used to construct the first node.
        assert node_cls.call_count == 2
        first_metadata = node_cls.call_args_list[0].kwargs["metadata"]
        assert first_metadata["kb_id"] == "kb-1"
        assert first_metadata["document_id"] == "doc-1"
        assert first_metadata["chunk_index"] == 0
class TestProcess:
    """Tests for DocumentProcessor.process: the full pipeline plus status transitions.

    NOTE(review): the async tests carry no explicit asyncio marker; presumably
    the project runs pytest-asyncio in auto mode — confirm against the config.
    """

    # Import paths of the llama_index classes that these tests swap for mocks.
    _SPLITTER_TARGET = "llama_index.core.node_parser.SentenceSplitter"
    _TEXT_NODE_TARGET = "llama_index.core.schema.TextNode"

    def _make_mock_store(self):
        """Build a KB-store double exposing an awaitable update_document_status."""
        store = MagicMock()
        store.update_document_status = AsyncMock()
        return store

    async def test_process_success_transitions(self, tmp_path):
        """A successful run reports parsing → segmenting → vectorizing → indexed."""
        source = tmp_path / "test.txt"
        source.write_text("Test content.", encoding="utf-8")
        processor = DocumentProcessor()
        embed_model = MagicMock()
        embed_model.aget_text_embedding = AsyncMock(return_value=[0.1] * 1536)
        vector_store = MagicMock()
        vector_store.async_add = AsyncMock()
        kb_store = self._make_mock_store()

        with patch(self._SPLITTER_TARGET) as splitter_cls, patch(self._TEXT_NODE_TARGET):
            splitter_cls.return_value.split_text.return_value = ["chunk1"]
            await processor.process(
                str(source),
                "txt",
                "kb-1",
                "doc-1",
                vector_store,
                embed_model,
                kb_store,
            )

        # Compare the whole sequence at once so that a failure shows every
        # status that was actually reported, not just the first mismatch.
        reported = [call.args[1] for call in kb_store.update_document_status.call_args_list]
        assert reported == [
            DocumentStatus.parsing,
            DocumentStatus.segmenting,
            DocumentStatus.vectorizing,
            DocumentStatus.indexed,
        ]

    async def test_process_failure_sets_failed_status(self, tmp_path):
        """An embed-model error re-raises and leaves the document in failed status."""
        source = tmp_path / "test.txt"
        source.write_text("Test content.", encoding="utf-8")
        processor = DocumentProcessor()
        # The embed model is the component that fails in this scenario.
        embed_model = MagicMock()
        embed_model.aget_text_embedding = AsyncMock(side_effect=RuntimeError("embed error"))
        vector_store = MagicMock()
        vector_store.async_add = AsyncMock()
        kb_store = self._make_mock_store()

        with patch(self._SPLITTER_TARGET) as splitter_cls, patch(self._TEXT_NODE_TARGET):
            splitter_cls.return_value.split_text.return_value = ["chunk1"]
            with pytest.raises(RuntimeError, match="embed error"):
                await processor.process(
                    str(source),
                    "txt",
                    "kb-1",
                    "doc-1",
                    vector_store,
                    embed_model,
                    kb_store,
                )

        # The final status update must be "failed" and carry the error text.
        last_call = kb_store.update_document_status.call_args_list[-1]
        assert last_call.args[1] == DocumentStatus.failed
        # `or ""` keeps this an assertion failure (not a TypeError) if the
        # keyword is present but None.
        assert "embed error" in (last_call.kwargs.get("error_message") or "")

    async def test_process_parse_failure_sets_failed_status(self):
        """A parse-stage error re-raises and leaves the document in failed status."""
        processor = DocumentProcessor()
        embed_model = MagicMock()
        vector_store = MagicMock()
        kb_store = self._make_mock_store()

        # A nonexistent path makes the parse stage fail.
        with pytest.raises(FileNotFoundError):
            await processor.process(
                "/nonexistent/file.txt",
                "txt",
                "kb-1",
                "doc-1",
                vector_store,
                embed_model,
                kb_store,
            )

        status_calls = kb_store.update_document_status.call_args_list
        assert len(status_calls) >= 2  # at least parsing + failed
        assert status_calls[-1].args[1] == DocumentStatus.failed