"""U3+U7 tests — document processing pipeline.

Test scenarios:
1. parse + segment pipeline (mock file I/O)
2. preview returns a list of chunks
3. vectorize calls the embed model + vector store
4. failures set the error status
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
|
||
from agentkit.rag_platform.document_processor import DocumentProcessor
|
||
from agentkit.rag_platform.models import DocumentStatus
|
||
from agentkit.rag_platform.preview import PreviewResult, generate_preview
|
||
|
||
|
||
class TestParseAndSegment:
    """Tests for the parse and segment stages of DocumentProcessor."""

    def test_parse_extracts_text(self, tmp_path):
        """parse() returns the text written to the file."""
        # Create the input file on disk.
        source = tmp_path / "test.txt"
        source.write_text("Hello, world!\nThis is a test document.", encoding="utf-8")

        parsed = DocumentProcessor().parse(str(source), "txt")

        for fragment in ("Hello, world!", "test document"):
            assert fragment in parsed

    def test_parse_applies_sanitization(self, tmp_path):
        """parse() strips the <script> tag from text formats but keeps the prose."""
        source = tmp_path / "test.md"
        source.write_text("Hello <script>alert(1)</script> world", encoding="utf-8")

        parsed = DocumentProcessor().parse(str(source), "md")

        assert "<script>" not in parsed
        for fragment in ("Hello", "world"):
            assert fragment in parsed

    def test_segment_splits_text(self):
        """segment() returns whatever the (mocked) SentenceSplitter produces."""
        processor = DocumentProcessor(chunk_size=100, chunk_overlap=10)
        text = "This is a sentence. " * 50  # long enough to trigger segmentation

        with patch("llama_index.core.node_parser.SentenceSplitter") as splitter_cls:
            # The instance returned by the patched class is what segment() will use.
            splitter = splitter_cls.return_value
            splitter.split_text.return_value = ["chunk1", "chunk2", "chunk3"]

            chunks = processor.segment(text, chunk_size=100, chunk_overlap=10)

        assert chunks == ["chunk1", "chunk2", "chunk3"]
        splitter.split_text.assert_called_once_with(text)

    def test_segment_uses_custom_params(self):
        """segment() forwards explicit chunk_size / chunk_overlap to the splitter."""
        processor = DocumentProcessor()

        with patch("llama_index.core.node_parser.SentenceSplitter") as splitter_cls:
            splitter_cls.return_value.split_text.return_value = ["chunk"]

            processor.segment("text", chunk_size=256, chunk_overlap=20)

        # The splitter class must have been constructed with the custom values.
        ctor_kwargs = splitter_cls.call_args.kwargs
        assert ctor_kwargs["chunk_size"] == 256
        assert ctor_kwargs["chunk_overlap"] == 20

    def test_parse_file_not_found(self):
        """parse() raises FileNotFoundError when the file does not exist."""
        processor = DocumentProcessor()

        with pytest.raises(FileNotFoundError):
            processor.parse("/nonexistent/file.txt", "txt")
|
||
|
||
|
||
class TestPreview:
    """Tests for DocumentProcessor.preview."""

    def test_preview_returns_chunks(self, tmp_path):
        """preview() returns a list of chunk dicts with index, content and metadata."""
        source = tmp_path / "test.txt"
        source.write_text("Test content for preview.", encoding="utf-8")

        processor = DocumentProcessor(chunk_size=100, chunk_overlap=10)

        with patch("llama_index.core.node_parser.SentenceSplitter") as splitter_cls:
            splitter_cls.return_value.split_text.return_value = ["chunk1", "chunk2"]

            chunks = processor.preview(str(source), "txt")

        assert len(chunks) == 2
        # Each chunk carries its position and the text produced by the splitter.
        for position, expected_content in enumerate(("chunk1", "chunk2")):
            assert chunks[position]["index"] == position
            assert chunks[position]["content"] == expected_content
        assert chunks[0]["metadata"] == {}

    def test_preview_uses_default_params(self, tmp_path):
        """preview() falls back to the chunking parameters given to the constructor."""
        source = tmp_path / "test.txt"
        source.write_text("Test content.", encoding="utf-8")

        processor = DocumentProcessor(chunk_size=512, chunk_overlap=50)

        with patch("llama_index.core.node_parser.SentenceSplitter") as splitter_cls:
            splitter_cls.return_value.split_text.return_value = ["chunk"]

            processor.preview(str(source), "txt")

        ctor_kwargs = splitter_cls.call_args.kwargs
        assert ctor_kwargs["chunk_size"] == 512
        assert ctor_kwargs["chunk_overlap"] == 50
|
||
|
||
|
||
class TestGeneratePreview:
    """Tests for the generate_preview function."""

    def test_generate_preview_returns_preview_result(self, tmp_path):
        """generate_preview() returns a PreviewResult describing every chunk."""
        source = tmp_path / "test.txt"
        source.write_text("Test content for preview.", encoding="utf-8")

        with patch("llama_index.core.node_parser.SentenceSplitter") as splitter_cls:
            splitter_cls.return_value.split_text.return_value = [
                "chunk1",
                "chunk2",
                "chunk3",
            ]

            result = generate_preview(str(source), "txt")

        assert isinstance(result, PreviewResult)
        assert result.total_chunks == 3
        assert len(result.chunks) == 3
        # Spot-check the first and last chunk.
        for position, expected_content in ((0, "chunk1"), (2, "chunk3")):
            assert result.chunks[position].index == position
            assert result.chunks[position].content == expected_content
        assert result.document_id == ""  # no document_id exists at preview time

    def test_generate_preview_with_custom_params(self, tmp_path):
        """generate_preview() forwards custom segmentation parameters to the splitter."""
        source = tmp_path / "test.txt"
        source.write_text("Test content.", encoding="utf-8")

        with patch("llama_index.core.node_parser.SentenceSplitter") as splitter_cls:
            splitter_cls.return_value.split_text.return_value = ["chunk"]

            generate_preview(str(source), "txt", chunk_size=256, chunk_overlap=20)

        ctor_kwargs = splitter_cls.call_args.kwargs
        assert ctor_kwargs["chunk_size"] == 256
        assert ctor_kwargs["chunk_overlap"] == 20
|
||
|
||
|
||
class TestVectorize:
    """Tests for DocumentProcessor.vectorize."""

    # NOTE(review): the async tests carry no explicit asyncio marker; presumably
    # the project runs pytest with an auto async mode — confirm in the pytest config.

    def _make_mock_embed_model(self):
        """Build a mock embedding model whose async embed call returns a fixed vector."""
        return MagicMock(aget_text_embedding=AsyncMock(return_value=[0.1] * 1536))

    def _make_mock_vector_store(self):
        """Build a mock vector store exposing an awaitable async_add."""
        return MagicMock(async_add=AsyncMock())

    async def test_vectorize_calls_embed_model(self):
        """vectorize() awaits the embedding model once per chunk."""
        processor = DocumentProcessor()
        embed_model = self._make_mock_embed_model()
        vector_store = self._make_mock_vector_store()

        chunks = ["chunk1", "chunk2", "chunk3"]

        with patch("llama_index.core.schema.TextNode") as node_cls:
            # One distinct mock node is handed out per TextNode construction.
            node_cls.side_effect = [MagicMock() for _ in chunks]

            await processor.vectorize(chunks, "kb-1", "doc-1", vector_store, embed_model)

        # Every chunk triggered an embed call.
        assert embed_model.aget_text_embedding.await_count == 3

    async def test_vectorize_calls_vector_store(self):
        """vectorize() awaits the vector store's async_add exactly once."""
        processor = DocumentProcessor()
        embed_model = self._make_mock_embed_model()
        vector_store = self._make_mock_vector_store()

        chunks = ["chunk1", "chunk2"]

        with patch("llama_index.core.schema.TextNode") as node_cls:
            node_cls.side_effect = [MagicMock() for _ in chunks]

            await processor.vectorize(chunks, "kb-1", "doc-1", vector_store, embed_model)

        vector_store.async_add.assert_awaited_once()
        # The number of nodes passed to the store matches the number of chunks.
        added_nodes = vector_store.async_add.call_args.args[0]
        assert len(added_nodes) == 2

    async def test_vectorize_accepts_dict_chunks(self):
        """vectorize() accepts chunks given as dicts."""
        processor = DocumentProcessor()
        embed_model = self._make_mock_embed_model()
        vector_store = self._make_mock_vector_store()

        chunks = [
            {"index": position, "content": content, "metadata": {}}
            for position, content in enumerate(("chunk1", "chunk2"))
        ]

        with patch("llama_index.core.schema.TextNode") as node_cls:
            node_cls.side_effect = [MagicMock() for _ in chunks]

            await processor.vectorize(chunks, "kb-1", "doc-1", vector_store, embed_model)

        assert embed_model.aget_text_embedding.await_count == 2

    async def test_vectorize_failure_propagates(self):
        """vectorize() lets an embedding failure propagate to the caller."""
        processor = DocumentProcessor()
        embed_model = MagicMock(
            aget_text_embedding=AsyncMock(side_effect=RuntimeError("embed failed")),
        )
        vector_store = self._make_mock_vector_store()

        chunks = ["chunk1", "chunk2"]

        with patch("llama_index.core.schema.TextNode"), pytest.raises(
            RuntimeError, match="embed failed"
        ):
            await processor.vectorize(chunks, "kb-1", "doc-1", vector_store, embed_model)

    async def test_vectorize_sets_metadata(self):
        """vectorize() builds each node with the expected metadata."""
        processor = DocumentProcessor()
        embed_model = self._make_mock_embed_model()
        vector_store = self._make_mock_vector_store()

        chunks = ["chunk1", "chunk2"]

        with patch("llama_index.core.schema.TextNode") as node_cls:
            node_cls.side_effect = [MagicMock() for _ in chunks]

            await processor.vectorize(chunks, "kb-1", "doc-1", vector_store, embed_model)

        # TextNode must have been constructed once per chunk with the right metadata.
        assert node_cls.call_count == 2
        first_metadata = node_cls.call_args_list[0].kwargs["metadata"]
        assert first_metadata["kb_id"] == "kb-1"
        assert first_metadata["document_id"] == "doc-1"
        assert first_metadata["chunk_index"] == 0
|
||
|
||
|
||
class TestProcess:
    """Tests for DocumentProcessor.process (full pipeline + status transitions)."""

    # NOTE(review): the async tests carry no explicit asyncio marker; presumably
    # the project runs pytest with an auto async mode — confirm in the pytest config.

    def _make_mock_store(self):
        """Build a mock KBStore exposing an awaitable update_document_status."""
        return MagicMock(update_document_status=AsyncMock())

    async def test_process_success_transitions(self, tmp_path):
        """On success the status goes parsing → segmenting → vectorizing → indexed."""
        source = tmp_path / "test.txt"
        source.write_text("Test content.", encoding="utf-8")

        processor = DocumentProcessor()
        embed_model = MagicMock(
            aget_text_embedding=AsyncMock(return_value=[0.1] * 1536),
        )
        vector_store = MagicMock(async_add=AsyncMock())
        kb_store = self._make_mock_store()

        with patch("llama_index.core.node_parser.SentenceSplitter") as splitter_cls:
            splitter_cls.return_value.split_text.return_value = ["chunk1"]

            with patch("llama_index.core.schema.TextNode"):
                await processor.process(
                    str(source),
                    "txt",
                    "kb-1",
                    "doc-1",
                    vector_store,
                    embed_model,
                    kb_store,
                )

        # Verify the exact sequence of status updates (second positional argument).
        observed = [
            call.args[1] for call in kb_store.update_document_status.call_args_list
        ]
        assert observed == [
            DocumentStatus.parsing,
            DocumentStatus.segmenting,
            DocumentStatus.vectorizing,
            DocumentStatus.indexed,
        ]

    async def test_process_failure_sets_failed_status(self, tmp_path):
        """When embedding fails, the final status is failed with an error message."""
        source = tmp_path / "test.txt"
        source.write_text("Test content.", encoding="utf-8")

        processor = DocumentProcessor()
        # The embed model fails on every call.
        embed_model = MagicMock(
            aget_text_embedding=AsyncMock(side_effect=RuntimeError("embed error")),
        )
        vector_store = MagicMock(async_add=AsyncMock())
        kb_store = self._make_mock_store()

        with patch("llama_index.core.node_parser.SentenceSplitter") as splitter_cls:
            splitter_cls.return_value.split_text.return_value = ["chunk1"]

            with patch("llama_index.core.schema.TextNode"), pytest.raises(
                RuntimeError, match="embed error"
            ):
                await processor.process(
                    str(source),
                    "txt",
                    "kb-1",
                    "doc-1",
                    vector_store,
                    embed_model,
                    kb_store,
                )

        # The last status update must be `failed` and carry the error message.
        last_call = kb_store.update_document_status.call_args_list[-1]
        assert last_call.args[1] == DocumentStatus.failed
        assert "embed error" in last_call.kwargs.get("error_message", "")

    async def test_process_parse_failure_sets_failed_status(self, tmp_path):
        """When the parse stage fails, the final status is failed."""
        processor = DocumentProcessor()
        embed_model = MagicMock()
        vector_store = MagicMock()
        kb_store = self._make_mock_store()

        # A path under tmp_path that was never created: guaranteed to be missing
        # regardless of what exists on the host filesystem.
        missing_path = tmp_path / "missing.txt"

        with pytest.raises(FileNotFoundError):
            await processor.process(
                str(missing_path),
                "txt",
                "kb-1",
                "doc-1",
                vector_store,
                embed_model,
                kb_store,
            )

        # Verify the status ended up as failed.
        status_calls = kb_store.update_document_status.call_args_list
        assert len(status_calls) >= 2  # at least parsing + failed
        assert status_calls[-1].args[1] == DocumentStatus.failed
|