# Source-viewer metadata: 256 lines, 9.3 KiB, Python
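"""U1 tests — RAG platform module skeleton + LlamaIndex integration.

Test scenarios:
1. LlamaIndex PGVectorStore connects to the existing pgvector extension
   (connection parameters verified via mocks).
2. Basic ingest flow (document → chunk → embedding → pgvector INSERT),
   exercised end to end with mocked components.
3. Basic query flow (query → embedding → pgvector cosine search) returns results.
"""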
|
||
|
||
from __future__ import annotations
|
||
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
from agentkit.rag_platform.indexing import (
|
||
DEFAULT_EMBED_DIM,
|
||
KB_CHUNKS_TABLE,
|
||
_async_to_sync_url,
|
||
create_vector_store,
|
||
)
|
||
from agentkit.rag_platform.models import (
|
||
Chunk,
|
||
Document,
|
||
DocumentStatus,
|
||
KBStatus,
|
||
KnowledgeBase,
|
||
QueryMode,
|
||
QueryResult,
|
||
)
|
||
from agentkit.rag_platform.pipeline import RAGPipeline
|
||
|
||
|
||
class TestModels:
    """Unit tests for the RAG platform domain models."""

    def test_knowledge_base_defaults(self):
        """A freshly built KnowledgeBase carries the expected default settings."""
        knowledge_base = KnowledgeBase(name="test-kb", owner="user1")

        assert knowledge_base.status == KBStatus.active
        assert knowledge_base.default_query_mode == QueryMode.blend
        assert knowledge_base.default_hit_processing == "model_opt"
        assert knowledge_base.caching_disabled is False
        # The id is generated automatically (described as a UUID); only its
        # presence (truthiness) is checked here.
        assert knowledge_base.id

    def test_document_status_transitions(self):
        """A new Document starts in the pending state with no error message."""
        # NOTE(review): despite the test name, only the initial state is
        # checked; no status transition is exercised.
        document_fields = {
            "kb_id": "kb1",
            "filename": "test.pdf",
            "file_type": "pdf",
            "file_size": 1024,
        }
        document = Document(**document_fields)

        assert document.status == DocumentStatus.pending
        assert document.error_message is None

    def test_chunk_model(self):
        """A Chunk built without embedding/metadata gets empty defaults."""
        new_chunk = Chunk(document_id="doc1", kb_id="kb1", content="hello world")

        assert new_chunk.embedding is None
        assert new_chunk.metadata == {}

    def test_query_result_model(self):
        """QueryResult keeps the supplied score and defaults metadata to {}."""
        expected_score = 0.95
        query_result = QueryResult(
            chunk_id="c1",
            content="hello",
            score=expected_score,
            document_id="doc1",
            kb_id="kb1",
        )

        assert query_result.score == expected_score
        assert query_result.metadata == {}
|
||
|
||
|
||
class TestIndexing:
    """Tests for the pgvector index-management helpers in ``rag_platform.indexing``."""

    def test_async_to_sync_url_conversion(self):
        """An asyncpg URL is rewritten to a plain ``postgresql://`` (sync) URL."""
        async_url = "postgresql+asyncpg://user:pass@localhost:5432/db"
        sync_url = _async_to_sync_url(async_url)
        assert sync_url == "postgresql://user:pass@localhost:5432/db"

    def test_sync_url_unchanged(self):
        """A URL that is already synchronous is returned unchanged."""
        sync_url = "postgresql://user:pass@localhost:5432/db"
        assert _async_to_sync_url(sync_url) == sync_url

    @patch("llama_index.vector_stores.postgres.PGVectorStore.from_params")
    def test_create_vector_store_uses_explicit_table_name(self, mock_from_params):
        """create_vector_store passes the explicit KB_CHUNKS_TABLE name (schema isolation)."""
        mock_store = MagicMock()
        mock_from_params.return_value = mock_store

        result = create_vector_store(
            "postgresql+asyncpg://user:pass@localhost:5432/db",
            embed_dim=768,
        )

        # from_params must be called exactly once, with the expected table name.
        mock_from_params.assert_called_once()
        call_kwargs = mock_from_params.call_args.kwargs
        assert call_kwargs["table_name"] == KB_CHUNKS_TABLE
        assert call_kwargs["embed_dim"] == 768
        assert call_kwargs["hybrid_search"] is True
        assert result is mock_store

    @patch("llama_index.vector_stores.postgres.PGVectorStore.from_params")
    def test_create_vector_store_default_embed_dim(self, mock_from_params):
        """Without ``embed_dim``, create_vector_store uses DEFAULT_EMBED_DIM.

        The default was documented as 1536; this test compares against the
        module constant rather than the literal.
        """
        mock_from_params.return_value = MagicMock()

        create_vector_store("postgresql://localhost/db")

        # Verify the call happened first: on an uncalled mock ``call_args`` is
        # None, and ``.kwargs`` would raise AttributeError instead of a clear
        # assertion failure.
        mock_from_params.assert_called_once()
        call_kwargs = mock_from_params.call_args.kwargs
        assert call_kwargs["embed_dim"] == DEFAULT_EMBED_DIM

    @patch("llama_index.vector_stores.postgres.PGVectorStore.from_params")
    def test_create_vector_store_converts_async_url(self, mock_from_params):
        """create_vector_store converts an asyncpg URL into a sync URL."""
        mock_from_params.return_value = MagicMock()

        create_vector_store("postgresql+asyncpg://user:pass@localhost:5432/db")

        # Guard against a missing call before dereferencing call_args.
        mock_from_params.assert_called_once()
        call_kwargs = mock_from_params.call_args.kwargs
        # NOTE(review): the full sync URL is expected in the ``database`` kwarg;
        # PGVectorStore.from_params also exposes ``connection_string`` — confirm
        # which parameter the implementation intends to use.
        assert call_kwargs["database"] == "postgresql://user:pass@localhost:5432/db"
|
||
|
||
|
||
class TestRAGPipeline:
    """Tests for RAGPipeline ingest/query behaviour using mocked LlamaIndex parts."""

    # NOTE(review): the ``async def`` tests below carry no explicit marker; they
    # assume an async-capable pytest plugin running in auto mode (e.g.
    # pytest-asyncio with asyncio_mode=auto) — confirm in the project config.
    # NOTE(review): patching ``llama_index.core.ingestion.IngestionPipeline``
    # only affects RAGPipeline if ``rag_platform.pipeline`` looks the name up at
    # call time (e.g. a function-scope import) — confirm in that module.

    def _make_mock_embed_model(self):
        """Return a mock embedding model whose async text embedding is 1536 floats."""
        mock = MagicMock()
        mock.aget_text_embedding = AsyncMock(return_value=[0.1] * 1536)
        return mock

    def _make_mock_vector_store(self):
        """Return a mock vector store with an awaitable ``aquery``."""
        mock = MagicMock()
        mock.aquery = AsyncMock()
        return mock

    def _make_mock_text_node(self, node_id: str, content: str, metadata: dict | None = None):
        """Return a mock TextNode exposing ``node_id``, ``get_content()`` and ``metadata``."""
        node = MagicMock()
        node.node_id = node_id
        node.get_content.return_value = content
        node.metadata = metadata or {}
        return node

    async def test_ingest_calls_pipeline_arun(self):
        """ingest awaits IngestionPipeline.arun and returns its node list."""
        mock_vs = self._make_mock_vector_store()
        mock_embed = self._make_mock_embed_model()

        mock_nodes = [
            self._make_mock_text_node("n1", "chunk 1", {"kb_id": "kb1"}),
            self._make_mock_text_node("n2", "chunk 2", {"kb_id": "kb1"}),
        ]

        with patch("llama_index.core.ingestion.IngestionPipeline") as mock_pipeline_cls:
            mock_pipeline = MagicMock()
            mock_pipeline.arun = AsyncMock(return_value=mock_nodes)
            mock_pipeline_cls.return_value = mock_pipeline

            rag = RAGPipeline(mock_vs, mock_embed, chunk_size=256, chunk_overlap=20)
            nodes = await rag.ingest("hello world", metadata={"kb_id": "kb1", "document_id": "d1"})

        assert len(nodes) == 2
        assert nodes[0].node_id == "n1"
        mock_pipeline.arun.assert_awaited_once()

    async def test_query_returns_query_results(self):
        """query maps vector-store hits to QueryResult objects in the store's order."""
        mock_vs = self._make_mock_vector_store()
        mock_embed = self._make_mock_embed_model()

        # Simulated vector-store response, already in descending similarity.
        mock_result = MagicMock()
        mock_result.nodes = [
            self._make_mock_text_node(
                "n1", "relevant chunk", {"kb_id": "kb1", "document_id": "d1"}
            ),
            self._make_mock_text_node("n2", "another chunk", {"kb_id": "kb1", "document_id": "d1"}),
        ]
        mock_result.similarities = [0.95, 0.80]
        mock_vs.aquery.return_value = mock_result

        with patch("llama_index.core.ingestion.IngestionPipeline"):
            rag = RAGPipeline(mock_vs, mock_embed)
            results = await rag.query("test query", top_k=2)

        assert len(results) == 2
        assert isinstance(results[0], QueryResult)
        assert results[0].chunk_id == "n1"
        assert results[0].score == 0.95
        assert results[0].kb_id == "kb1"
        assert results[1].score == 0.80

    async def test_query_calls_embed_model(self):
        """query embeds the query text via the embedding model."""
        mock_vs = self._make_mock_vector_store()
        mock_embed = self._make_mock_embed_model()

        mock_result = MagicMock()
        mock_result.nodes = []
        mock_result.similarities = []
        mock_vs.aquery.return_value = mock_result

        with patch("llama_index.core.ingestion.IngestionPipeline"):
            rag = RAGPipeline(mock_vs, mock_embed)
            await rag.query("test", top_k=5)

        mock_embed.aget_text_embedding.assert_awaited_once_with("test")

    async def test_query_passes_top_k_to_vector_store(self):
        """query forwards ``top_k`` as ``similarity_top_k`` on the vector-store query."""
        mock_vs = self._make_mock_vector_store()
        mock_embed = self._make_mock_embed_model()

        mock_result = MagicMock()
        mock_result.nodes = []
        mock_result.similarities = []
        mock_vs.aquery.return_value = mock_result

        with patch("llama_index.core.ingestion.IngestionPipeline"):
            rag = RAGPipeline(mock_vs, mock_embed)
            await rag.query("test", top_k=10)

        # Ensure aquery was actually awaited; otherwise call_args is None.
        mock_vs.aquery.assert_awaited_once()
        # Assumes the query object is passed as the first positional argument.
        vs_query = mock_vs.aquery.call_args.args[0]
        assert vs_query.similarity_top_k == 10

    def test_pipeline_init_with_custom_chunk_params(self):
        """RAGPipeline forwards custom chunk_size / chunk_overlap to its splitter."""
        mock_vs = self._make_mock_vector_store()
        mock_embed = self._make_mock_embed_model()

        with patch("llama_index.core.ingestion.IngestionPipeline") as mock_pipeline_cls:
            RAGPipeline(mock_vs, mock_embed, chunk_size=1024, chunk_overlap=100)

        # The pipeline class must have been instantiated, otherwise call_args is None.
        assert mock_pipeline_cls.called
        # ``call_args[1]`` is the same mapping as ``call_args.kwargs``, so a
        # single keyword lookup is sufficient.
        transformations = mock_pipeline_cls.call_args.kwargs.get("transformations")
        assert transformations is not None

        # The first transformation is expected to be the SentenceSplitter.
        splitter = transformations[0]
        assert splitter.chunk_size == 1024
        assert splitter.chunk_overlap == 100
|