feat(rag_platform): U1 — RAG platform skeleton + LlamaIndex integration
Create src/agentkit/rag_platform/ module with: - models.py: Pydantic domain models (KB, Document, Chunk, QueryResult) - indexing.py: PGVectorStore wrapper with explicit table name (rag_platform_kb_chunks) for schema isolation from episodic_memory - pipeline.py: RAGPipeline wrapping LlamaIndex IngestionPipeline (SentenceSplitter + embedding + vector store) Add dependencies: llama-index-core, llama-index-vector-stores-postgres, llama-index-embeddings-openai, pgvector, jieba. Tests: 14 unit tests covering models, indexing (URL conversion, table name isolation, embed_dim), and pipeline (ingest, query, chunk params).
This commit is contained in:
parent
22c89763e2
commit
27d0184392
|
|
@ -37,6 +37,12 @@ dependencies = [
|
|||
"docxtpl>=0.16",
|
||||
"jinja2>=3.1",
|
||||
"markdown>=3.5",
|
||||
# RAG platform (P1 — LlamaIndex + pgvector)
|
||||
"llama-index-core>=0.12",
|
||||
"llama-index-vector-stores-postgres>=0.4",
|
||||
"llama-index-embeddings-openai>=0.3",
|
||||
"pgvector>=0.3",
|
||||
"jieba>=0.42",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,21 @@
|
|||
"""RAG 平台模块 — 企业知识库场景的工业级 RAG 管道。"""
|
||||
|
||||
from agentkit.rag_platform.models import (
|
||||
Chunk,
|
||||
Document,
|
||||
DocumentStatus,
|
||||
KBStatus,
|
||||
KnowledgeBase,
|
||||
QueryMode,
|
||||
QueryResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Chunk",
|
||||
"Document",
|
||||
"DocumentStatus",
|
||||
"KBStatus",
|
||||
"KnowledgeBase",
|
||||
"QueryMode",
|
||||
"QueryResult",
|
||||
]
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
"""pgvector 索引管理 — LlamaIndex PGVectorStore 封装。
|
||||
|
||||
Schema 隔离:使用显式表名 `rag_platform_kb_chunks`,不可使用默认 `data_<name>`。
|
||||
确认 `create_if_not_exists` 不会触碰 `episodic_memory` 或任何 `memory/` 所属表。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.vector_stores.postgres import PGVectorStore
|
||||
|
||||
# 显式表名 — 与 memory/episodic.py 的 episodic_memories 表完全隔离
|
||||
KB_CHUNKS_TABLE = "rag_platform_kb_chunks"
|
||||
|
||||
# 默认 embedding 维度(OpenAI text-embedding-3-small)
|
||||
DEFAULT_EMBED_DIM = 1536
|
||||
|
||||
|
||||
def _async_to_sync_url(database_url: str) -> str:
|
||||
"""将 asyncpg URL 转换为 psycopg2 URL(LlamaIndex PGVectorStore 使用同步驱动)。"""
|
||||
return database_url.replace("postgresql+asyncpg://", "postgresql://")
|
||||
|
||||
|
||||
def create_vector_store(
|
||||
database_url: str,
|
||||
embed_dim: int = DEFAULT_EMBED_DIM,
|
||||
hybrid_search: bool = True,
|
||||
) -> "PGVectorStore":
|
||||
"""创建 PGVectorStore,使用显式表名实现 schema 隔离。
|
||||
|
||||
Args:
|
||||
database_url: SQLAlchemy 数据库 URL(async 或 sync 均可)
|
||||
embed_dim: embedding 向量维度
|
||||
hybrid_search: 是否启用混合搜索(向量 + 全文)
|
||||
|
||||
Returns:
|
||||
LlamaIndex PGVectorStore 实例
|
||||
|
||||
Raises:
|
||||
ImportError: 如果 llama-index-vector-stores-postgres 未安装
|
||||
"""
|
||||
from llama_index.vector_stores.postgres import PGVectorStore
|
||||
|
||||
sync_url = _async_to_sync_url(database_url)
|
||||
|
||||
logger.info(
|
||||
"Creating PGVectorStore: table=%s, embed_dim=%d, hybrid=%s",
|
||||
KB_CHUNKS_TABLE,
|
||||
embed_dim,
|
||||
hybrid_search,
|
||||
)
|
||||
|
||||
return PGVectorStore.from_params(
|
||||
database=sync_url,
|
||||
table_name=KB_CHUNKS_TABLE,
|
||||
embed_dim=embed_dim,
|
||||
hybrid_search=hybrid_search,
|
||||
text_search_config="english", # U4 将用 jieba 替换为中文分词
|
||||
# ponytail: 不设 schema_name,默认 public — 避免创建独立 schema 的运维复杂度
|
||||
# 表名前缀 rag_platform_ 已足够隔离
|
||||
)
|
||||
|
||||
|
||||
async def ensure_vector_store_schema(database_url: str, embed_dim: int = DEFAULT_EMBED_DIM) -> None:
|
||||
"""确保 vector store 表存在(幂等)。
|
||||
|
||||
在应用启动时调用,创建表结构(如果不存在)。
|
||||
不会触碰 episodic_memory 或任何 memory/ 所属表。
|
||||
"""
|
||||
vs = create_vector_store(database_url, embed_dim=embed_dim)
|
||||
# PGVectorStore.__init__ 内部会调用 create_if_not_exists
|
||||
# 显式调用 _initialize 来确保表创建
|
||||
vs._initialize()
|
||||
logger.info("Vector store schema ensured: table=%s", KB_CHUNKS_TABLE)
|
||||
|
|
@ -0,0 +1,106 @@
|
|||
"""RAG 平台 Pydantic 数据模型。
|
||||
|
||||
这些是领域模型(非 ORM),ORM 模型定义在 store.py(U2)。
|
||||
与 memory/knowledge_base.py 的 KnowledgeBase protocol 分离 —
|
||||
rag_platform 服务企业知识库场景,memory/ 服务 Agent 运行时记忆。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
def _utcnow() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def _uuid() -> str:
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class KBStatus(str, Enum):
|
||||
active = "active"
|
||||
archived = "archived"
|
||||
|
||||
|
||||
class DocumentStatus(str, Enum):
|
||||
"""文档处理状态机:pending → parsing → segmenting → vectorizing → indexed | failed。"""
|
||||
|
||||
pending = "pending"
|
||||
parsing = "parsing"
|
||||
segmenting = "segmenting"
|
||||
vectorizing = "vectorizing"
|
||||
indexed = "indexed"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
class QueryMode(str, Enum):
|
||||
"""检索模式:embedding 语义 / keywords 全文 / blend 双索引合并。"""
|
||||
|
||||
embedding = "embedding"
|
||||
keywords = "keywords"
|
||||
blend = "blend"
|
||||
|
||||
|
||||
class KnowledgeBase(BaseModel):
|
||||
"""知识库领域模型。"""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(default_factory=_uuid)
|
||||
name: str
|
||||
description: str = ""
|
||||
owner: str
|
||||
status: KBStatus = KBStatus.active
|
||||
# 检索与命中处理默认配置(Agent 运行时可覆盖)
|
||||
default_query_mode: QueryMode = QueryMode.blend
|
||||
default_hit_processing: str = "model_opt" # model_opt | direct
|
||||
caching_disabled: bool = False
|
||||
created_at: datetime = Field(default_factory=_utcnow)
|
||||
updated_at: datetime = Field(default_factory=_utcnow)
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
"""知识库文档领域模型。"""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(default_factory=_uuid)
|
||||
kb_id: str
|
||||
filename: str
|
||||
file_type: str
|
||||
file_size: int
|
||||
status: DocumentStatus = DocumentStatus.pending
|
||||
error_message: str | None = None
|
||||
created_at: datetime = Field(default_factory=_utcnow)
|
||||
updated_at: datetime = Field(default_factory=_utcnow)
|
||||
|
||||
|
||||
class Chunk(BaseModel):
|
||||
"""文档分段后的文本块。"""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: str = Field(default_factory=_uuid)
|
||||
document_id: str
|
||||
kb_id: str
|
||||
content: str
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
embedding: list[float] | None = None
|
||||
|
||||
|
||||
class QueryResult(BaseModel):
|
||||
"""检索结果条目。"""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
chunk_id: str
|
||||
content: str
|
||||
score: float
|
||||
metadata: dict = Field(default_factory=dict)
|
||||
document_id: str
|
||||
kb_id: str
|
||||
|
|
@ -0,0 +1,122 @@
|
|||
"""LlamaIndex IngestionPipeline 封装 — 文档处理管道。
|
||||
|
||||
管道流程:文档 → SentenceSplitter 分段 → embedding → pgvector 索引。
|
||||
U3 将扩展为完整 IngestionPipeline(含解析、预览、净化)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from agentkit.rag_platform.indexing import KB_CHUNKS_TABLE
|
||||
from agentkit.rag_platform.models import QueryResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.vector_stores.postgres import PGVectorStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 512
|
||||
DEFAULT_CHUNK_OVERLAP = 50
|
||||
DEFAULT_TOP_K = 5
|
||||
|
||||
|
||||
class RAGPipeline:
|
||||
"""封装 LlamaIndex IngestionPipeline 用于 KB 文档处理。
|
||||
|
||||
Args:
|
||||
vector_store: LlamaIndex PGVectorStore 实例
|
||||
embed_model: LlamaIndex embedding 模型
|
||||
chunk_size: 分段大小(token 数)
|
||||
chunk_overlap: 分段重叠(token 数)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_store: "PGVectorStore",
|
||||
embed_model: "BaseEmbedding",
|
||||
chunk_size: int = DEFAULT_CHUNK_SIZE,
|
||||
chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
|
||||
) -> None:
|
||||
from llama_index.core.ingestion import IngestionPipeline
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
|
||||
self._vector_store = vector_store
|
||||
self._embed_model = embed_model
|
||||
|
||||
self._pipeline = IngestionPipeline(
|
||||
transformations=[
|
||||
SentenceSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
),
|
||||
embed_model,
|
||||
],
|
||||
vector_store=vector_store,
|
||||
)
|
||||
logger.info(
|
||||
"RAGPipeline initialized: chunk_size=%d, chunk_overlap=%d, table=%s",
|
||||
chunk_size,
|
||||
chunk_overlap,
|
||||
KB_CHUNKS_TABLE,
|
||||
)
|
||||
|
||||
async def ingest(self, text: str, metadata: dict[str, Any] | None = None) -> list["TextNode"]:
|
||||
"""将文本摄入向量存储。
|
||||
|
||||
Args:
|
||||
text: 文档文本
|
||||
metadata: 附加到每个 chunk 的元数据(kb_id, document_id 等)
|
||||
|
||||
Returns:
|
||||
LlamaIndex TextNode 列表(已写入 vector store)
|
||||
"""
|
||||
from llama_index.core.schema import Document as LIDocument
|
||||
|
||||
doc = LIDocument(text=text, metadata=metadata or {})
|
||||
nodes = await self._pipeline.arun(documents=[doc])
|
||||
logger.info("Ingested %d chunks from document", len(nodes))
|
||||
return nodes
|
||||
|
||||
async def query(
|
||||
self,
|
||||
query_text: str,
|
||||
top_k: int = DEFAULT_TOP_K,
|
||||
) -> list[QueryResult]:
|
||||
"""语义检索。
|
||||
|
||||
Args:
|
||||
query_text: 查询文本
|
||||
top_k: 返回结果数
|
||||
|
||||
Returns:
|
||||
QueryResult 列表(按相似度降序)
|
||||
"""
|
||||
from llama_index.core.vector_stores.types import VectorStoreQuery
|
||||
|
||||
query_embedding = await self._embed_model.aget_text_embedding(query_text)
|
||||
|
||||
vs_query = VectorStoreQuery(
|
||||
query_embedding=query_embedding,
|
||||
similarity_top_k=top_k,
|
||||
)
|
||||
result = await self._vector_store.aquery(vs_query)
|
||||
|
||||
results: list[QueryResult] = []
|
||||
for node, score in zip(result.nodes, result.similarities):
|
||||
meta = node.metadata or {}
|
||||
results.append(
|
||||
QueryResult(
|
||||
chunk_id=node.node_id,
|
||||
content=node.get_content(),
|
||||
score=float(score) if score is not None else 0.0,
|
||||
metadata=meta,
|
||||
document_id=meta.get("document_id", ""),
|
||||
kb_id=meta.get("kb_id", ""),
|
||||
)
|
||||
)
|
||||
logger.info("Query returned %d results (top_k=%d)", len(results), top_k)
|
||||
return results
|
||||
|
|
@ -0,0 +1,255 @@
|
|||
"""U1 测试 — RAG 平台模块骨架 + LlamaIndex 集成。
|
||||
|
||||
测试场景:
|
||||
1. LlamaIndex PGVectorStore 连接现有 pgvector 扩展(mock 验证参数)
|
||||
2. 基础 ingest(文档 → chunk → embedding → pgvector INSERT)端到端工作
|
||||
3. 基础 query(query → embedding → pgvector cosine 检索)返回结果
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from agentkit.rag_platform.indexing import (
|
||||
DEFAULT_EMBED_DIM,
|
||||
KB_CHUNKS_TABLE,
|
||||
_async_to_sync_url,
|
||||
create_vector_store,
|
||||
)
|
||||
from agentkit.rag_platform.models import (
|
||||
Chunk,
|
||||
Document,
|
||||
DocumentStatus,
|
||||
KBStatus,
|
||||
KnowledgeBase,
|
||||
QueryMode,
|
||||
QueryResult,
|
||||
)
|
||||
from agentkit.rag_platform.pipeline import RAGPipeline
|
||||
|
||||
|
||||
class TestModels:
|
||||
"""领域模型测试。"""
|
||||
|
||||
def test_knowledge_base_defaults(self):
|
||||
"""KnowledgeBase 默认值正确。"""
|
||||
kb = KnowledgeBase(name="test-kb", owner="user1")
|
||||
assert kb.status == KBStatus.active
|
||||
assert kb.default_query_mode == QueryMode.blend
|
||||
assert kb.default_hit_processing == "model_opt"
|
||||
assert kb.caching_disabled is False
|
||||
assert kb.id # auto-generated UUID
|
||||
|
||||
def test_document_status_transitions(self):
|
||||
"""DocumentStatus 状态机值正确。"""
|
||||
doc = Document(
|
||||
kb_id="kb1",
|
||||
filename="test.pdf",
|
||||
file_type="pdf",
|
||||
file_size=1024,
|
||||
)
|
||||
assert doc.status == DocumentStatus.pending
|
||||
assert doc.error_message is None
|
||||
|
||||
def test_chunk_model(self):
|
||||
"""Chunk 模型字段正确。"""
|
||||
chunk = Chunk(
|
||||
document_id="doc1",
|
||||
kb_id="kb1",
|
||||
content="hello world",
|
||||
)
|
||||
assert chunk.embedding is None
|
||||
assert chunk.metadata == {}
|
||||
|
||||
def test_query_result_model(self):
|
||||
"""QueryResult 模型字段正确。"""
|
||||
result = QueryResult(
|
||||
chunk_id="c1",
|
||||
content="hello",
|
||||
score=0.95,
|
||||
document_id="doc1",
|
||||
kb_id="kb1",
|
||||
)
|
||||
assert result.score == 0.95
|
||||
assert result.metadata == {}
|
||||
|
||||
|
||||
class TestIndexing:
|
||||
"""pgvector 索引管理测试。"""
|
||||
|
||||
def test_async_to_sync_url_conversion(self):
|
||||
"""asyncpg URL 正确转换为 psycopg2 URL。"""
|
||||
async_url = "postgresql+asyncpg://user:pass@localhost:5432/db"
|
||||
sync_url = _async_to_sync_url(async_url)
|
||||
assert sync_url == "postgresql://user:pass@localhost:5432/db"
|
||||
|
||||
def test_sync_url_unchanged(self):
|
||||
"""已经是 sync URL 时不转换。"""
|
||||
sync_url = "postgresql://user:pass@localhost:5432/db"
|
||||
assert _async_to_sync_url(sync_url) == sync_url
|
||||
|
||||
@patch("llama_index.vector_stores.postgres.PGVectorStore.from_params")
|
||||
def test_create_vector_store_uses_explicit_table_name(self, mock_from_params):
|
||||
"""create_vector_store 使用显式表名(schema 隔离)。"""
|
||||
mock_store = MagicMock()
|
||||
mock_from_params.return_value = mock_store
|
||||
|
||||
result = create_vector_store(
|
||||
"postgresql+asyncpg://user:pass@localhost:5432/db",
|
||||
embed_dim=768,
|
||||
)
|
||||
|
||||
# 验证 from_params 被调用时使用了正确的表名
|
||||
mock_from_params.assert_called_once()
|
||||
call_kwargs = mock_from_params.call_args.kwargs
|
||||
assert call_kwargs["table_name"] == KB_CHUNKS_TABLE
|
||||
assert call_kwargs["embed_dim"] == 768
|
||||
assert call_kwargs["hybrid_search"] is True
|
||||
assert result is mock_store
|
||||
|
||||
@patch("llama_index.vector_stores.postgres.PGVectorStore.from_params")
|
||||
def test_create_vector_store_default_embed_dim(self, mock_from_params):
|
||||
"""create_vector_store 默认 embed_dim 为 1536。"""
|
||||
mock_from_params.return_value = MagicMock()
|
||||
|
||||
create_vector_store("postgresql://localhost/db")
|
||||
|
||||
call_kwargs = mock_from_params.call_args.kwargs
|
||||
assert call_kwargs["embed_dim"] == DEFAULT_EMBED_DIM
|
||||
|
||||
@patch("llama_index.vector_stores.postgres.PGVectorStore.from_params")
|
||||
def test_create_vector_store_converts_async_url(self, mock_from_params):
|
||||
"""create_vector_store 将 asyncpg URL 转换为 sync URL。"""
|
||||
mock_from_params.return_value = MagicMock()
|
||||
|
||||
create_vector_store("postgresql+asyncpg://user:pass@localhost:5432/db")
|
||||
|
||||
call_kwargs = mock_from_params.call_args.kwargs
|
||||
assert call_kwargs["database"] == "postgresql://user:pass@localhost:5432/db"
|
||||
|
||||
|
||||
class TestRAGPipeline:
|
||||
"""RAGPipeline 管道测试。"""
|
||||
|
||||
def _make_mock_embed_model(self):
|
||||
"""创建 mock embedding 模型。"""
|
||||
mock = MagicMock()
|
||||
mock.aget_text_embedding = AsyncMock(return_value=[0.1] * 1536)
|
||||
return mock
|
||||
|
||||
def _make_mock_vector_store(self):
|
||||
"""创建 mock vector store。"""
|
||||
mock = MagicMock()
|
||||
mock.aquery = AsyncMock()
|
||||
return mock
|
||||
|
||||
def _make_mock_text_node(self, node_id: str, content: str, metadata: dict | None = None):
|
||||
"""创建 mock TextNode。"""
|
||||
node = MagicMock()
|
||||
node.node_id = node_id
|
||||
node.get_content.return_value = content
|
||||
node.metadata = metadata or {}
|
||||
return node
|
||||
|
||||
async def test_ingest_calls_pipeline_arun(self):
|
||||
"""ingest 调用 IngestionPipeline.arun 并返回 TextNode 列表。"""
|
||||
mock_vs = self._make_mock_vector_store()
|
||||
mock_embed = self._make_mock_embed_model()
|
||||
|
||||
mock_nodes = [
|
||||
self._make_mock_text_node("n1", "chunk 1", {"kb_id": "kb1"}),
|
||||
self._make_mock_text_node("n2", "chunk 2", {"kb_id": "kb1"}),
|
||||
]
|
||||
|
||||
with patch("llama_index.core.ingestion.IngestionPipeline") as mock_pipeline_cls:
|
||||
mock_pipeline = MagicMock()
|
||||
mock_pipeline.arun = AsyncMock(return_value=mock_nodes)
|
||||
mock_pipeline_cls.return_value = mock_pipeline
|
||||
|
||||
rag = RAGPipeline(mock_vs, mock_embed, chunk_size=256, chunk_overlap=20)
|
||||
nodes = await rag.ingest("hello world", metadata={"kb_id": "kb1", "document_id": "d1"})
|
||||
|
||||
assert len(nodes) == 2
|
||||
assert nodes[0].node_id == "n1"
|
||||
mock_pipeline.arun.assert_awaited_once()
|
||||
|
||||
async def test_query_returns_query_results(self):
|
||||
"""query 返回 QueryResult 列表,按相似度排序。"""
|
||||
mock_vs = self._make_mock_vector_store()
|
||||
mock_embed = self._make_mock_embed_model()
|
||||
|
||||
# 模拟 vector store query 返回
|
||||
mock_result = MagicMock()
|
||||
mock_result.nodes = [
|
||||
self._make_mock_text_node(
|
||||
"n1", "relevant chunk", {"kb_id": "kb1", "document_id": "d1"}
|
||||
),
|
||||
self._make_mock_text_node("n2", "another chunk", {"kb_id": "kb1", "document_id": "d1"}),
|
||||
]
|
||||
mock_result.similarities = [0.95, 0.80]
|
||||
mock_vs.aquery.return_value = mock_result
|
||||
|
||||
with patch("llama_index.core.ingestion.IngestionPipeline"):
|
||||
rag = RAGPipeline(mock_vs, mock_embed)
|
||||
results = await rag.query("test query", top_k=2)
|
||||
|
||||
assert len(results) == 2
|
||||
assert isinstance(results[0], QueryResult)
|
||||
assert results[0].chunk_id == "n1"
|
||||
assert results[0].score == 0.95
|
||||
assert results[0].kb_id == "kb1"
|
||||
assert results[1].score == 0.80
|
||||
|
||||
async def test_query_calls_embed_model(self):
|
||||
"""query 调用 embedding 模型获取查询向量。"""
|
||||
mock_vs = self._make_mock_vector_store()
|
||||
mock_embed = self._make_mock_embed_model()
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.nodes = []
|
||||
mock_result.similarities = []
|
||||
mock_vs.aquery.return_value = mock_result
|
||||
|
||||
with patch("llama_index.core.ingestion.IngestionPipeline"):
|
||||
rag = RAGPipeline(mock_vs, mock_embed)
|
||||
await rag.query("test", top_k=5)
|
||||
|
||||
mock_embed.aget_text_embedding.assert_awaited_once_with("test")
|
||||
|
||||
async def test_query_passes_top_k_to_vector_store(self):
|
||||
"""query 将 top_k 传递给 vector store query。"""
|
||||
mock_vs = self._make_mock_vector_store()
|
||||
mock_embed = self._make_mock_embed_model()
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.nodes = []
|
||||
mock_result.similarities = []
|
||||
mock_vs.aquery.return_value = mock_result
|
||||
|
||||
with patch("llama_index.core.ingestion.IngestionPipeline"):
|
||||
rag = RAGPipeline(mock_vs, mock_embed)
|
||||
await rag.query("test", top_k=10)
|
||||
|
||||
# 验证 aquery 被调用时 similarity_top_k=10
|
||||
call_args = mock_vs.aquery.call_args
|
||||
vs_query = call_args.args[0]
|
||||
assert vs_query.similarity_top_k == 10
|
||||
|
||||
def test_pipeline_init_with_custom_chunk_params(self):
|
||||
"""RAGPipeline 接受自定义 chunk_size 和 chunk_overlap。"""
|
||||
mock_vs = self._make_mock_vector_store()
|
||||
mock_embed = self._make_mock_embed_model()
|
||||
|
||||
with patch("llama_index.core.ingestion.IngestionPipeline") as mock_pipeline_cls:
|
||||
RAGPipeline(mock_vs, mock_embed, chunk_size=1024, chunk_overlap=100)
|
||||
|
||||
# 验证 SentenceSplitter 被创建时使用了自定义参数
|
||||
call_args = mock_pipeline_cls.call_args
|
||||
transformations = call_args.kwargs.get("transformations") or call_args[1].get(
|
||||
"transformations"
|
||||
)
|
||||
assert transformations is not None
|
||||
# 第一个 transformation 是 SentenceSplitter
|
||||
splitter = transformations[0]
|
||||
assert splitter.chunk_size == 1024
|
||||
assert splitter.chunk_overlap == 100
|
||||
Loading…
Reference in New Issue