"""U1 测试 — RAG 平台模块骨架 + LlamaIndex 集成。 测试场景: 1. LlamaIndex PGVectorStore 连接现有 pgvector 扩展(mock 验证参数) 2. 基础 ingest(文档 → chunk → embedding → pgvector INSERT)端到端工作 3. 基础 query(query → embedding → pgvector cosine 检索)返回结果 """ from __future__ import annotations from unittest.mock import AsyncMock, MagicMock, patch from agentkit.rag_platform.indexing import ( DEFAULT_EMBED_DIM, KB_CHUNKS_TABLE, _async_to_sync_url, create_vector_store, ) from agentkit.rag_platform.models import ( Chunk, Document, DocumentStatus, KBStatus, KnowledgeBase, QueryMode, QueryResult, ) from agentkit.rag_platform.pipeline import RAGPipeline class TestModels: """领域模型测试。""" def test_knowledge_base_defaults(self): """KnowledgeBase 默认值正确。""" kb = KnowledgeBase(name="test-kb", owner="user1") assert kb.status == KBStatus.active assert kb.default_query_mode == QueryMode.blend assert kb.default_hit_processing == "model_opt" assert kb.caching_disabled is False assert kb.id # auto-generated UUID def test_document_status_transitions(self): """DocumentStatus 状态机值正确。""" doc = Document( kb_id="kb1", filename="test.pdf", file_type="pdf", file_size=1024, ) assert doc.status == DocumentStatus.pending assert doc.error_message is None def test_chunk_model(self): """Chunk 模型字段正确。""" chunk = Chunk( document_id="doc1", kb_id="kb1", content="hello world", ) assert chunk.embedding is None assert chunk.metadata == {} def test_query_result_model(self): """QueryResult 模型字段正确。""" result = QueryResult( chunk_id="c1", content="hello", score=0.95, document_id="doc1", kb_id="kb1", ) assert result.score == 0.95 assert result.metadata == {} class TestIndexing: """pgvector 索引管理测试。""" def test_async_to_sync_url_conversion(self): """asyncpg URL 正确转换为 psycopg2 URL。""" async_url = "postgresql+asyncpg://user:pass@localhost:5432/db" sync_url = _async_to_sync_url(async_url) assert sync_url == "postgresql://user:pass@localhost:5432/db" def test_sync_url_unchanged(self): """已经是 sync URL 时不转换。""" sync_url = "postgresql://user:pass@localhost:5432/db" assert _async_to_sync_url(sync_url) == sync_url @patch("llama_index.vector_stores.postgres.PGVectorStore.from_params") def test_create_vector_store_uses_explicit_table_name(self, mock_from_params): """create_vector_store 使用显式表名(schema 隔离)。""" mock_store = MagicMock() mock_from_params.return_value = mock_store result = create_vector_store( "postgresql+asyncpg://user:pass@localhost:5432/db", embed_dim=768, ) # 验证 from_params 被调用时使用了正确的表名 mock_from_params.assert_called_once() call_kwargs = mock_from_params.call_args.kwargs assert call_kwargs["table_name"] == KB_CHUNKS_TABLE assert call_kwargs["embed_dim"] == 768 assert call_kwargs["hybrid_search"] is True assert result is mock_store @patch("llama_index.vector_stores.postgres.PGVectorStore.from_params") def test_create_vector_store_default_embed_dim(self, mock_from_params): """create_vector_store 默认 embed_dim 为 1536。""" mock_from_params.return_value = MagicMock() create_vector_store("postgresql://localhost/db") call_kwargs = mock_from_params.call_args.kwargs assert call_kwargs["embed_dim"] == DEFAULT_EMBED_DIM @patch("llama_index.vector_stores.postgres.PGVectorStore.from_params") def test_create_vector_store_converts_async_url(self, mock_from_params): """create_vector_store 将 asyncpg URL 转换为 sync URL。""" mock_from_params.return_value = MagicMock() create_vector_store("postgresql+asyncpg://user:pass@localhost:5432/db") call_kwargs = mock_from_params.call_args.kwargs assert call_kwargs["database"] == "postgresql://user:pass@localhost:5432/db" class TestRAGPipeline: """RAGPipeline 管道测试。""" def _make_mock_embed_model(self): """创建 mock embedding 模型。""" mock = MagicMock() mock.aget_text_embedding = AsyncMock(return_value=[0.1] * 1536) return mock def _make_mock_vector_store(self): """创建 mock vector store。""" mock = MagicMock() mock.aquery = AsyncMock() return mock def _make_mock_text_node(self, node_id: str, content: str, metadata: dict | None = None): """创建 mock TextNode。""" node = MagicMock() node.node_id = node_id node.get_content.return_value = content node.metadata = metadata or {} return node async def test_ingest_calls_pipeline_arun(self): """ingest 调用 IngestionPipeline.arun 并返回 TextNode 列表。""" mock_vs = self._make_mock_vector_store() mock_embed = self._make_mock_embed_model() mock_nodes = [ self._make_mock_text_node("n1", "chunk 1", {"kb_id": "kb1"}), self._make_mock_text_node("n2", "chunk 2", {"kb_id": "kb1"}), ] with patch("llama_index.core.ingestion.IngestionPipeline") as mock_pipeline_cls: mock_pipeline = MagicMock() mock_pipeline.arun = AsyncMock(return_value=mock_nodes) mock_pipeline_cls.return_value = mock_pipeline rag = RAGPipeline(mock_vs, mock_embed, chunk_size=256, chunk_overlap=20) nodes = await rag.ingest("hello world", metadata={"kb_id": "kb1", "document_id": "d1"}) assert len(nodes) == 2 assert nodes[0].node_id == "n1" mock_pipeline.arun.assert_awaited_once() async def test_query_returns_query_results(self): """query 返回 QueryResult 列表,按相似度排序。""" mock_vs = self._make_mock_vector_store() mock_embed = self._make_mock_embed_model() # 模拟 vector store query 返回 mock_result = MagicMock() mock_result.nodes = [ self._make_mock_text_node( "n1", "relevant chunk", {"kb_id": "kb1", "document_id": "d1"} ), self._make_mock_text_node("n2", "another chunk", {"kb_id": "kb1", "document_id": "d1"}), ] mock_result.similarities = [0.95, 0.80] mock_vs.aquery.return_value = mock_result with patch("llama_index.core.ingestion.IngestionPipeline"): rag = RAGPipeline(mock_vs, mock_embed) results = await rag.query("test query", top_k=2) assert len(results) == 2 assert isinstance(results[0], QueryResult) assert results[0].chunk_id == "n1" assert results[0].score == 0.95 assert results[0].kb_id == "kb1" assert results[1].score == 0.80 async def test_query_calls_embed_model(self): """query 调用 embedding 模型获取查询向量。""" mock_vs = self._make_mock_vector_store() mock_embed = self._make_mock_embed_model() mock_result = MagicMock() mock_result.nodes = [] mock_result.similarities = [] mock_vs.aquery.return_value = mock_result with patch("llama_index.core.ingestion.IngestionPipeline"): rag = RAGPipeline(mock_vs, mock_embed) await rag.query("test", top_k=5) mock_embed.aget_text_embedding.assert_awaited_once_with("test") async def test_query_passes_top_k_to_vector_store(self): """query 将 top_k 传递给 vector store query。""" mock_vs = self._make_mock_vector_store() mock_embed = self._make_mock_embed_model() mock_result = MagicMock() mock_result.nodes = [] mock_result.similarities = [] mock_vs.aquery.return_value = mock_result with patch("llama_index.core.ingestion.IngestionPipeline"): rag = RAGPipeline(mock_vs, mock_embed) await rag.query("test", top_k=10) # 验证 aquery 被调用时 similarity_top_k=10 call_args = mock_vs.aquery.call_args vs_query = call_args.args[0] assert vs_query.similarity_top_k == 10 def test_pipeline_init_with_custom_chunk_params(self): """RAGPipeline 接受自定义 chunk_size 和 chunk_overlap。""" mock_vs = self._make_mock_vector_store() mock_embed = self._make_mock_embed_model() with patch("llama_index.core.ingestion.IngestionPipeline") as mock_pipeline_cls: RAGPipeline(mock_vs, mock_embed, chunk_size=1024, chunk_overlap=100) # 验证 SentenceSplitter 被创建时使用了自定义参数 call_args = mock_pipeline_cls.call_args transformations = call_args.kwargs.get("transformations") or call_args[1].get( "transformations" ) assert transformations is not None # 第一个 transformation 是 SentenceSplitter splitter = transformations[0] assert splitter.chunk_size == 1024 assert splitter.chunk_overlap == 100