fischer-agentkit/tests/unit/rag_platform/test_pipeline.py

256 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U1 测试 — RAG 平台模块骨架 + LlamaIndex 集成。
测试场景:
1. LlamaIndex PGVectorStore 连接现有 pgvector 扩展mock 验证参数)
2. 基础 ingest文档 → chunk → embedding → pgvector INSERT端到端工作
3. 基础 queryquery → embedding → pgvector cosine 检索)返回结果
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
from agentkit.rag_platform.indexing import (
DEFAULT_EMBED_DIM,
KB_CHUNKS_TABLE,
_async_to_sync_url,
create_vector_store,
)
from agentkit.rag_platform.models import (
Chunk,
Document,
DocumentStatus,
KBStatus,
KnowledgeBase,
QueryMode,
QueryResult,
)
from agentkit.rag_platform.pipeline import RAGPipeline
class TestModels:
    """Domain-model tests covering the defaults of the RAG platform models."""

    def test_knowledge_base_defaults(self):
        """A KnowledgeBase built from name/owner only carries the expected defaults."""
        knowledge_base = KnowledgeBase(name="test-kb", owner="user1")

        assert knowledge_base.status == KBStatus.active
        assert knowledge_base.default_query_mode == QueryMode.blend
        assert knowledge_base.default_hit_processing == "model_opt"
        assert knowledge_base.caching_disabled is False
        # The id is expected to be an auto-generated UUID, so it only has to be truthy.
        assert knowledge_base.id

    def test_document_status_transitions(self):
        """A new Document starts in the pending state without an error message."""
        document_fields = {
            "kb_id": "kb1",
            "filename": "test.pdf",
            "file_type": "pdf",
            "file_size": 1024,
        }
        document = Document(**document_fields)

        assert document.status == DocumentStatus.pending
        assert document.error_message is None

    def test_chunk_model(self):
        """A Chunk without explicit embedding/metadata gets None and an empty dict."""
        piece = Chunk(document_id="doc1", kb_id="kb1", content="hello world")

        assert piece.embedding is None
        assert piece.metadata == {}

    def test_query_result_model(self):
        """A QueryResult keeps the given score and defaults metadata to an empty dict."""
        hit_fields = {
            "chunk_id": "c1",
            "content": "hello",
            "score": 0.95,
            "document_id": "doc1",
            "kb_id": "kb1",
        }
        hit = QueryResult(**hit_fields)

        assert hit.score == 0.95
        assert hit.metadata == {}
class TestIndexing:
    """Tests for the pgvector index-management helpers."""

    # Single patch target shared by every create_vector_store test, so the
    # dotted path cannot drift between tests.
    _FROM_PARAMS = "llama_index.vector_stores.postgres.PGVectorStore.from_params"

    def test_async_to_sync_url_conversion(self):
        """An asyncpg URL is converted to a plain ``postgresql://`` URL."""
        async_url = "postgresql+asyncpg://user:pass@localhost:5432/db"
        sync_url = _async_to_sync_url(async_url)
        assert sync_url == "postgresql://user:pass@localhost:5432/db"

    def test_sync_url_unchanged(self):
        """A URL that is already synchronous is returned unchanged."""
        sync_url = "postgresql://user:pass@localhost:5432/db"
        assert _async_to_sync_url(sync_url) == sync_url

    @patch(_FROM_PARAMS)
    def test_create_vector_store_uses_explicit_table_name(self, mock_from_params):
        """create_vector_store passes the explicit table name (schema isolation)."""
        mock_store = MagicMock()
        mock_from_params.return_value = mock_store

        result = create_vector_store(
            "postgresql+asyncpg://user:pass@localhost:5432/db",
            embed_dim=768,
        )

        # Verify from_params was invoked with the expected table name and options.
        mock_from_params.assert_called_once()
        call_kwargs = mock_from_params.call_args.kwargs
        assert call_kwargs["table_name"] == KB_CHUNKS_TABLE
        assert call_kwargs["embed_dim"] == 768
        assert call_kwargs["hybrid_search"] is True
        assert result is mock_store

    @patch(_FROM_PARAMS)
    def test_create_vector_store_default_embed_dim(self, mock_from_params):
        """create_vector_store uses DEFAULT_EMBED_DIM when embed_dim is omitted."""
        mock_from_params.return_value = MagicMock()

        create_vector_store("postgresql://localhost/db")

        # Assert the call first: call_args is None on an uncalled mock, which
        # would otherwise surface as an opaque AttributeError.
        mock_from_params.assert_called_once()
        call_kwargs = mock_from_params.call_args.kwargs
        assert call_kwargs["embed_dim"] == DEFAULT_EMBED_DIM

    @patch(_FROM_PARAMS)
    def test_create_vector_store_converts_async_url(self, mock_from_params):
        """create_vector_store hands the sync form of an asyncpg URL to from_params."""
        mock_from_params.return_value = MagicMock()

        create_vector_store("postgresql+asyncpg://user:pass@localhost:5432/db")

        # Same guard as above: fail with a clear message if nothing was called.
        mock_from_params.assert_called_once()
        call_kwargs = mock_from_params.call_args.kwargs
        assert call_kwargs["database"] == "postgresql://user:pass@localhost:5432/db"
class TestRAGPipeline:
    """Tests for the RAGPipeline ingest/query flow using mocked collaborators.

    NOTE(review): the ``async def`` tests carry no explicit asyncio marker, so
    they presumably rely on an asyncio-aware pytest configuration (e.g.
    pytest-asyncio auto mode) — confirm against the project's pytest settings.
    """

    def _make_mock_embed_model(self):
        """Build a mock embedding model whose async embed call returns a fixed vector."""
        mock = MagicMock()
        mock.aget_text_embedding = AsyncMock(return_value=[0.1] * 1536)
        return mock

    def _make_mock_vector_store(self):
        """Build a mock vector store exposing an awaitable ``aquery``."""
        mock = MagicMock()
        mock.aquery = AsyncMock()
        return mock

    def _make_mock_text_node(self, node_id: str, content: str, metadata: dict | None = None):
        """Build a mock TextNode with the given id, content and metadata.

        Args:
            node_id: Value exposed as ``node.node_id``.
            content: Value returned by ``node.get_content()``.
            metadata: Value exposed as ``node.metadata``; a falsy value becomes ``{}``.
        """
        node = MagicMock()
        node.node_id = node_id
        node.get_content.return_value = content
        node.metadata = metadata or {}
        return node

    def _make_empty_query_result(self):
        """Build a mock vector-store query result with no nodes and no similarities."""
        result = MagicMock()
        result.nodes = []
        result.similarities = []
        return result

    async def test_ingest_calls_pipeline_arun(self):
        """ingest awaits IngestionPipeline.arun and returns the nodes it produced."""
        mock_vs = self._make_mock_vector_store()
        mock_embed = self._make_mock_embed_model()
        mock_nodes = [
            self._make_mock_text_node("n1", "chunk 1", {"kb_id": "kb1"}),
            self._make_mock_text_node("n2", "chunk 2", {"kb_id": "kb1"}),
        ]

        with patch("llama_index.core.ingestion.IngestionPipeline") as mock_pipeline_cls:
            mock_pipeline = MagicMock()
            mock_pipeline.arun = AsyncMock(return_value=mock_nodes)
            mock_pipeline_cls.return_value = mock_pipeline

            rag = RAGPipeline(mock_vs, mock_embed, chunk_size=256, chunk_overlap=20)
            nodes = await rag.ingest("hello world", metadata={"kb_id": "kb1", "document_id": "d1"})

            assert len(nodes) == 2
            assert nodes[0].node_id == "n1"
            mock_pipeline.arun.assert_awaited_once()

    async def test_query_returns_query_results(self):
        """query maps vector-store hits to QueryResult objects, keeping order and scores."""
        mock_vs = self._make_mock_vector_store()
        mock_embed = self._make_mock_embed_model()

        # Simulate the vector store returning two hits, best match first.
        mock_result = MagicMock()
        mock_result.nodes = [
            self._make_mock_text_node(
                "n1", "relevant chunk", {"kb_id": "kb1", "document_id": "d1"}
            ),
            self._make_mock_text_node("n2", "another chunk", {"kb_id": "kb1", "document_id": "d1"}),
        ]
        mock_result.similarities = [0.95, 0.80]
        mock_vs.aquery.return_value = mock_result

        with patch("llama_index.core.ingestion.IngestionPipeline"):
            rag = RAGPipeline(mock_vs, mock_embed)
            results = await rag.query("test query", top_k=2)

            assert len(results) == 2
            assert isinstance(results[0], QueryResult)
            assert results[0].chunk_id == "n1"
            assert results[0].score == 0.95
            assert results[0].kb_id == "kb1"
            assert results[1].score == 0.80

    async def test_query_calls_embed_model(self):
        """query asks the embedding model for the query vector exactly once."""
        mock_vs = self._make_mock_vector_store()
        mock_embed = self._make_mock_embed_model()
        mock_vs.aquery.return_value = self._make_empty_query_result()

        with patch("llama_index.core.ingestion.IngestionPipeline"):
            rag = RAGPipeline(mock_vs, mock_embed)
            await rag.query("test", top_k=5)

            mock_embed.aget_text_embedding.assert_awaited_once_with("test")

    async def test_query_passes_top_k_to_vector_store(self):
        """query forwards top_k to the vector store as ``similarity_top_k``."""
        mock_vs = self._make_mock_vector_store()
        mock_embed = self._make_mock_embed_model()
        mock_vs.aquery.return_value = self._make_empty_query_result()

        with patch("llama_index.core.ingestion.IngestionPipeline"):
            rag = RAGPipeline(mock_vs, mock_embed)
            await rag.query("test", top_k=10)

            # Fail with a clear message if aquery was never awaited; call_args
            # would be None in that case.
            mock_vs.aquery.assert_awaited()
            # The query object is expected as the first positional argument.
            vs_query = mock_vs.aquery.call_args.args[0]
            assert vs_query.similarity_top_k == 10

    def test_pipeline_init_with_custom_chunk_params(self):
        """RAGPipeline forwards custom chunk_size/chunk_overlap to its splitter."""
        mock_vs = self._make_mock_vector_store()
        mock_embed = self._make_mock_embed_model()

        with patch("llama_index.core.ingestion.IngestionPipeline") as mock_pipeline_cls:
            RAGPipeline(mock_vs, mock_embed, chunk_size=1024, chunk_overlap=100)

            # The pipeline class must have been instantiated before call_args is read.
            mock_pipeline_cls.assert_called()
            # call_args.kwargs and call_args[1] are the same mapping, so a single
            # lookup is sufficient.
            transformations = mock_pipeline_cls.call_args.kwargs.get("transformations")
            assert transformations is not None

            # The first transformation is expected to be the SentenceSplitter
            # configured with the custom parameters.
            splitter = transformations[0]
            assert splitter.chunk_size == 1024
            assert splitter.chunk_overlap == 100