feat(rag_platform): U1 — RAG platform skeleton + LlamaIndex integration
Create src/agentkit/rag_platform/ module with:
- models.py: Pydantic domain models (KB, Document, Chunk, QueryResult)
- indexing.py: PGVectorStore wrapper with explicit table name (rag_platform_kb_chunks) for schema isolation from episodic_memory
- pipeline.py: RAGPipeline wrapping LlamaIndex IngestionPipeline (SentenceSplitter + embedding + vector store)

Add dependencies: llama-index-core, llama-index-vector-stores-postgres, llama-index-embeddings-openai, pgvector, jieba.

Tests: 14 unit tests covering models, indexing (URL conversion, table name isolation, embed_dim), and pipeline (ingest, query, chunk params).
This commit is contained in:
parent
22c89763e2
commit
27d0184392
|
|
@ -37,6 +37,12 @@ dependencies = [
|
||||||
"docxtpl>=0.16",
|
"docxtpl>=0.16",
|
||||||
"jinja2>=3.1",
|
"jinja2>=3.1",
|
||||||
"markdown>=3.5",
|
"markdown>=3.5",
|
||||||
|
# RAG platform (P1 — LlamaIndex + pgvector)
|
||||||
|
"llama-index-core>=0.12",
|
||||||
|
"llama-index-vector-stores-postgres>=0.4",
|
||||||
|
"llama-index-embeddings-openai>=0.3",
|
||||||
|
"pgvector>=0.3",
|
||||||
|
"jieba>=0.42",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
"""RAG 平台模块 — 企业知识库场景的工业级 RAG 管道。"""
|
||||||
|
|
||||||
|
from agentkit.rag_platform.models import (
|
||||||
|
Chunk,
|
||||||
|
Document,
|
||||||
|
DocumentStatus,
|
||||||
|
KBStatus,
|
||||||
|
KnowledgeBase,
|
||||||
|
QueryMode,
|
||||||
|
QueryResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Public API of the rag_platform package: the domain models re-exported from
# agentkit.rag_platform.models by the import statement above this list.
__all__ = [
    "Chunk",
    "Document",
    "DocumentStatus",
    "KBStatus",
    "KnowledgeBase",
    "QueryMode",
    "QueryResult",
]
|
||||||
|
|
@ -0,0 +1,79 @@
|
||||||
|
"""pgvector 索引管理 — LlamaIndex PGVectorStore 封装。
|
||||||
|
|
||||||
|
Schema 隔离:使用显式表名 `rag_platform_kb_chunks`,不可使用默认 `data_<name>`。
|
||||||
|
确认 `create_if_not_exists` 不会触碰 `episodic_memory` 或任何 `memory/` 所属表。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from llama_index.vector_stores.postgres import PGVectorStore
|
||||||
|
|
||||||
|
# Explicit table name handed to PGVectorStore so that RAG chunks stay fully
# separate from the episodic_memories table owned by memory/episodic.py.
# NOTE(review): LlamaIndex's PGVectorStore appears to prefix the physical table
# name with "data_" (i.e. "data_rag_platform_kb_chunks") — confirm against the
# installed version before relying on this exact name in SQL or migrations.
KB_CHUNKS_TABLE = "rag_platform_kb_chunks"

# Default embedding dimension (OpenAI text-embedding-3-small).
DEFAULT_EMBED_DIM = 1536
|
||||||
|
|
||||||
|
|
||||||
|
def _async_to_sync_url(database_url: str) -> str:
|
||||||
|
"""将 asyncpg URL 转换为 psycopg2 URL(LlamaIndex PGVectorStore 使用同步驱动)。"""
|
||||||
|
return database_url.replace("postgresql+asyncpg://", "postgresql://")
|
||||||
|
|
||||||
|
|
||||||
|
def create_vector_store(
    database_url: str,
    embed_dim: int = DEFAULT_EMBED_DIM,
    hybrid_search: bool = True,
) -> "PGVectorStore":
    """Create a PGVectorStore bound to the dedicated RAG chunk table.

    The store is given both a synchronous and an asyncpg connection string
    derived from ``database_url``. ``PGVectorStore.from_params`` treats its
    ``database`` argument as a bare database *name*, so a full URL must be
    supplied through ``connection_string`` / ``async_connection_string``.

    Args:
        database_url: SQLAlchemy database URL (async or sync form accepted).
        embed_dim: Dimension of the embedding vectors.
        hybrid_search: Whether to enable hybrid (vector + full-text) search.

    Returns:
        A LlamaIndex ``PGVectorStore`` instance.

    Raises:
        ImportError: If ``llama-index-vector-stores-postgres`` is not installed.
    """
    # Imported lazily so that importing this module does not require LlamaIndex.
    from llama_index.vector_stores.postgres import PGVectorStore

    sync_url = _async_to_sync_url(database_url)

    # The async code paths of the store need an asyncpg URL; rebuild it from
    # the part after the scheme so both sync and async inputs are accepted.
    scheme, separator, remainder = database_url.partition("://")
    if separator and scheme.split("+", 1)[0] == "postgresql":
        async_url = f"postgresql+asyncpg://{remainder}"
    else:
        async_url = database_url

    logger.info(
        "Creating PGVectorStore: table=%s, embed_dim=%d, hybrid=%s",
        KB_CHUNKS_TABLE,
        embed_dim,
        hybrid_search,
    )

    return PGVectorStore.from_params(
        connection_string=sync_url,
        async_connection_string=async_url,
        table_name=KB_CHUNKS_TABLE,
        embed_dim=embed_dim,
        hybrid_search=hybrid_search,
        text_search_config="english",  # U4 will switch this to jieba-based Chinese segmentation
        # schema_name is deliberately left at its default ("public") to avoid
        # the operational overhead of a dedicated schema; the rag_platform_
        # table-name prefix is considered sufficient isolation.
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def ensure_vector_store_schema(database_url: str, embed_dim: int = DEFAULT_EMBED_DIM) -> None:
    """Ensure the vector store table exists (intended to be idempotent).

    Meant to be called at application start-up. It builds a store via
    ``create_vector_store`` and triggers its initialisation so the table
    structure is created when missing.

    Args:
        database_url: SQLAlchemy database URL (async or sync form accepted).
        embed_dim: Dimension of the embedding vectors.
    """
    import asyncio

    vs = create_vector_store(database_url, embed_dim=embed_dim)
    # _initialize is a synchronous call, so run it in a worker thread rather
    # than blocking the event loop this coroutine runs on.
    # NOTE(review): _initialize is a private PGVectorStore method — confirm it
    # still exists and still performs table creation in the pinned version.
    await asyncio.to_thread(vs._initialize)
    logger.info("Vector store schema ensured: table=%s", KB_CHUNKS_TABLE)
|
||||||
|
|
@ -0,0 +1,106 @@
|
||||||
|
"""RAG 平台 Pydantic 数据模型。
|
||||||
|
|
||||||
|
这些是领域模型(非 ORM),ORM 模型定义在 store.py(U2)。
|
||||||
|
与 memory/knowledge_base.py 的 KnowledgeBase protocol 分离 —
|
||||||
|
rag_platform 服务企业知识库场景,memory/ 服务 Agent 运行时记忆。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
|
||||||
|
def _utcnow() -> datetime:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
def _uuid() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
class KBStatus(str, Enum):
    """Status values a knowledge base can carry.

    Mixes in ``str`` so each member compares equal to its plain string value.
    """

    active = "active"
    archived = "archived"
|
||||||
|
|
||||||
|
|
||||||
|
class DocumentStatus(str, Enum):
    """Document processing state machine.

    Intended flow: pending → parsing → segmenting → vectorizing → indexed | failed.
    The enum only declares the states; no transition rules are enforced here.
    """

    pending = "pending"
    parsing = "parsing"
    segmenting = "segmenting"
    vectorizing = "vectorizing"
    indexed = "indexed"
    failed = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
class QueryMode(str, Enum):
    """Retrieval mode: embedding (semantic), keywords (full-text), blend (both indexes merged)."""

    embedding = "embedding"
    keywords = "keywords"
    blend = "blend"
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeBase(BaseModel):
    """Domain model for a knowledge base."""

    # from_attributes lets the model be validated from attribute-bearing
    # objects rather than only from dicts.
    model_config = ConfigDict(from_attributes=True)

    # Auto-generated UUID string when not supplied.
    id: str = Field(default_factory=_uuid)
    name: str
    description: str = ""
    owner: str
    status: KBStatus = KBStatus.active
    # Default retrieval and hit-processing settings (can be overridden by the
    # agent at runtime).
    default_query_mode: QueryMode = QueryMode.blend
    default_hit_processing: str = "model_opt"  # model_opt | direct
    caching_disabled: bool = False
    # Both timestamps default to the current UTC time at construction.
    created_at: datetime = Field(default_factory=_utcnow)
    updated_at: datetime = Field(default_factory=_utcnow)
|
||||||
|
|
||||||
|
|
||||||
|
class Document(BaseModel):
    """Domain model for a document that belongs to a knowledge base."""

    # from_attributes lets the model be validated from attribute-bearing
    # objects rather than only from dicts.
    model_config = ConfigDict(from_attributes=True)

    # Auto-generated UUID string when not supplied.
    id: str = Field(default_factory=_uuid)
    # Identifier of the owning knowledge base.
    kb_id: str
    filename: str
    file_type: str
    file_size: int  # NOTE(review): unit is presumably bytes — confirm with the uploader
    # New documents start in the "pending" state.
    status: DocumentStatus = DocumentStatus.pending
    # Optional error text; None by default.
    error_message: str | None = None
    # Both timestamps default to the current UTC time at construction.
    created_at: datetime = Field(default_factory=_utcnow)
    updated_at: datetime = Field(default_factory=_utcnow)
|
||||||
|
|
||||||
|
|
||||||
|
class Chunk(BaseModel):
    """A text chunk produced by splitting a document."""

    # from_attributes lets the model be validated from attribute-bearing
    # objects rather than only from dicts.
    model_config = ConfigDict(from_attributes=True)

    # Auto-generated UUID string when not supplied.
    id: str = Field(default_factory=_uuid)
    # Identifiers of the source document and its knowledge base.
    document_id: str
    kb_id: str
    # The chunk text itself.
    content: str
    # Free-form metadata; a fresh empty dict per instance by default.
    metadata: dict = Field(default_factory=dict)
    # Embedding vector, or None when none has been attached.
    embedding: list[float] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class QueryResult(BaseModel):
    """A single retrieval result entry."""

    # from_attributes lets the model be validated from attribute-bearing
    # objects rather than only from dicts.
    model_config = ConfigDict(from_attributes=True)

    # Identifier of the matched chunk.
    chunk_id: str
    # Text content of the matched chunk.
    content: str
    # Score reported for the match by the retrieval backend.
    score: float
    # Free-form metadata; a fresh empty dict per instance by default.
    metadata: dict = Field(default_factory=dict)
    # Identifiers of the source document and its knowledge base.
    document_id: str
    kb_id: str
|
||||||
|
|
@ -0,0 +1,122 @@
|
||||||
|
"""LlamaIndex IngestionPipeline 封装 — 文档处理管道。
|
||||||
|
|
||||||
|
管道流程:文档 → SentenceSplitter 分段 → embedding → pgvector 索引。
|
||||||
|
U3 将扩展为完整 IngestionPipeline(含解析、预览、净化)。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from agentkit.rag_platform.indexing import KB_CHUNKS_TABLE
|
||||||
|
from agentkit.rag_platform.models import QueryResult
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from llama_index.core.embeddings import BaseEmbedding
|
||||||
|
from llama_index.core.schema import TextNode
|
||||||
|
from llama_index.vector_stores.postgres import PGVectorStore
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Defaults used by RAGPipeline when the caller does not override them.
DEFAULT_CHUNK_SIZE = 512  # chunk size handed to SentenceSplitter (tokens)
DEFAULT_CHUNK_OVERLAP = 50  # overlap between consecutive chunks (tokens)
DEFAULT_TOP_K = 5  # number of hits requested per query
|
||||||
|
|
||||||
|
|
||||||
|
class RAGPipeline:
    """Wrap a LlamaIndex IngestionPipeline for knowledge-base document processing.

    The pipeline splits text with ``SentenceSplitter``, embeds the chunks with
    the supplied embedding model and writes them to the supplied vector store.

    Args:
        vector_store: LlamaIndex PGVectorStore instance.
        embed_model: LlamaIndex embedding model.
        chunk_size: Chunk size (in tokens).
        chunk_overlap: Overlap between chunks (in tokens).
    """

    def __init__(
        self,
        vector_store: "PGVectorStore",
        embed_model: "BaseEmbedding",
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
    ) -> None:
        # Imported lazily so that importing this module does not require LlamaIndex.
        from llama_index.core.ingestion import IngestionPipeline
        from llama_index.core.node_parser import SentenceSplitter

        self._vector_store = vector_store
        self._embed_model = embed_model

        # Transformations run in order: split first, then embed each chunk.
        self._pipeline = IngestionPipeline(
            transformations=[
                SentenceSplitter(
                    chunk_size=chunk_size,
                    chunk_overlap=chunk_overlap,
                ),
                embed_model,
            ],
            vector_store=vector_store,
        )
        logger.info(
            "RAGPipeline initialized: chunk_size=%d, chunk_overlap=%d, table=%s",
            chunk_size,
            chunk_overlap,
            KB_CHUNKS_TABLE,
        )

    async def ingest(self, text: str, metadata: dict[str, Any] | None = None) -> list["TextNode"]:
        """Ingest a piece of text into the vector store.

        Args:
            text: Document text.
            metadata: Metadata attached to the document (kb_id, document_id,
                etc.); an empty dict is used when omitted.

        Returns:
            The nodes returned by the ingestion pipeline run.
        """
        from llama_index.core.schema import Document as LIDocument

        doc = LIDocument(text=text, metadata=metadata or {})
        nodes = await self._pipeline.arun(documents=[doc])
        logger.info("Ingested %d chunks from document", len(nodes))
        return nodes

    async def query(
        self,
        query_text: str,
        top_k: int = DEFAULT_TOP_K,
    ) -> list[QueryResult]:
        """Run a semantic (embedding-only) search.

        Args:
            query_text: Query text.
            top_k: Number of results to request.

        Returns:
            ``QueryResult`` entries in the order returned by the vector store
            (expected to be descending similarity; no re-sorting happens here).
            A hit without a similarity score gets a score of ``0.0``.
        """
        from llama_index.core.vector_stores.types import VectorStoreQuery

        # NOTE(review): LlamaIndex embeddings also expose aget_query_embedding;
        # confirm whether the query-specific variant is wanted here.
        query_embedding = await self._embed_model.aget_text_embedding(query_text)

        vs_query = VectorStoreQuery(
            query_embedding=query_embedding,
            similarity_top_k=top_k,
        )
        result = await self._vector_store.aquery(vs_query)

        # nodes / similarities may be None on a query result; treat that as
        # "nothing returned" rather than failing inside zip().
        nodes = list(result.nodes or [])
        scores: list[float | None] = list(result.similarities or [])
        # Pad missing scores so a short score list never drops returned hits.
        scores.extend([None] * (len(nodes) - len(scores)))

        results: list[QueryResult] = []
        for node, score in zip(nodes, scores):
            meta = node.metadata or {}
            results.append(
                QueryResult(
                    chunk_id=node.node_id,
                    content=node.get_content(),
                    score=float(score) if score is not None else 0.0,
                    metadata=meta,
                    document_id=meta.get("document_id", ""),
                    kb_id=meta.get("kb_id", ""),
                )
            )
        logger.info("Query returned %d results (top_k=%d)", len(results), top_k)
        return results
|
||||||
|
|
@ -0,0 +1,255 @@
|
||||||
|
"""U1 测试 — RAG 平台模块骨架 + LlamaIndex 集成。
|
||||||
|
|
||||||
|
测试场景:
|
||||||
|
1. LlamaIndex PGVectorStore 连接现有 pgvector 扩展(mock 验证参数)
|
||||||
|
2. 基础 ingest(文档 → chunk → embedding → pgvector INSERT)端到端工作
|
||||||
|
3. 基础 query(query → embedding → pgvector cosine 检索)返回结果
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from agentkit.rag_platform.indexing import (
|
||||||
|
DEFAULT_EMBED_DIM,
|
||||||
|
KB_CHUNKS_TABLE,
|
||||||
|
_async_to_sync_url,
|
||||||
|
create_vector_store,
|
||||||
|
)
|
||||||
|
from agentkit.rag_platform.models import (
|
||||||
|
Chunk,
|
||||||
|
Document,
|
||||||
|
DocumentStatus,
|
||||||
|
KBStatus,
|
||||||
|
KnowledgeBase,
|
||||||
|
QueryMode,
|
||||||
|
QueryResult,
|
||||||
|
)
|
||||||
|
from agentkit.rag_platform.pipeline import RAGPipeline
|
||||||
|
|
||||||
|
|
||||||
|
class TestModels:
    """Tests for the domain models."""

    def test_knowledge_base_defaults(self):
        """A KnowledgeBase built from required fields only gets the expected defaults."""
        knowledge_base = KnowledgeBase(name="test-kb", owner="user1")

        assert knowledge_base.status == KBStatus.active
        assert knowledge_base.default_query_mode == QueryMode.blend
        assert knowledge_base.default_hit_processing == "model_opt"
        assert knowledge_base.caching_disabled is False
        # The id is auto-generated, so it only needs to be non-empty.
        assert knowledge_base.id

    def test_document_status_transitions(self):
        """A new Document starts as pending with no error message."""
        fields = {
            "kb_id": "kb1",
            "filename": "test.pdf",
            "file_type": "pdf",
            "file_size": 1024,
        }
        document = Document(**fields)

        assert document.status == DocumentStatus.pending
        assert document.error_message is None

    def test_chunk_model(self):
        """A Chunk defaults to no embedding and empty metadata."""
        piece = Chunk(document_id="doc1", kb_id="kb1", content="hello world")

        assert piece.embedding is None
        assert piece.metadata == {}

    def test_query_result_model(self):
        """A QueryResult keeps its score and defaults to empty metadata."""
        hit = QueryResult(
            chunk_id="c1",
            content="hello",
            score=0.95,
            document_id="doc1",
            kb_id="kb1",
        )

        assert hit.score == 0.95
        assert hit.metadata == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestIndexing:
    """Tests for pgvector index management."""

    def test_async_to_sync_url_conversion(self):
        """An asyncpg URL is converted to a synchronous URL."""
        async_url = "postgresql+asyncpg://user:pass@localhost:5432/db"
        sync_url = _async_to_sync_url(async_url)
        assert sync_url == "postgresql://user:pass@localhost:5432/db"

    def test_sync_url_unchanged(self):
        """A URL that is already synchronous is returned unchanged."""
        sync_url = "postgresql://user:pass@localhost:5432/db"
        assert _async_to_sync_url(sync_url) == sync_url

    @patch("llama_index.vector_stores.postgres.PGVectorStore.from_params")
    def test_create_vector_store_uses_explicit_table_name(self, mock_from_params):
        """create_vector_store passes the explicit table name (schema isolation)."""
        mock_store = MagicMock()
        mock_from_params.return_value = mock_store

        result = create_vector_store(
            "postgresql+asyncpg://user:pass@localhost:5432/db",
            embed_dim=768,
        )

        # from_params must be called exactly once, with the dedicated table name.
        mock_from_params.assert_called_once()
        call_kwargs = mock_from_params.call_args.kwargs
        assert call_kwargs["table_name"] == KB_CHUNKS_TABLE
        assert call_kwargs["embed_dim"] == 768
        assert call_kwargs["hybrid_search"] is True
        assert result is mock_store

    @patch("llama_index.vector_stores.postgres.PGVectorStore.from_params")
    def test_create_vector_store_default_embed_dim(self, mock_from_params):
        """create_vector_store defaults embed_dim to DEFAULT_EMBED_DIM."""
        mock_from_params.return_value = MagicMock()

        create_vector_store("postgresql://localhost/db")

        call_kwargs = mock_from_params.call_args.kwargs
        assert call_kwargs["embed_dim"] == DEFAULT_EMBED_DIM

    @patch("llama_index.vector_stores.postgres.PGVectorStore.from_params")
    def test_create_vector_store_converts_async_url(self, mock_from_params):
        """create_vector_store hands the converted synchronous URL to from_params."""
        mock_from_params.return_value = MagicMock()

        create_vector_store("postgresql+asyncpg://user:pass@localhost:5432/db")

        call_kwargs = mock_from_params.call_args.kwargs
        # Check the URL itself rather than the keyword that carries it:
        # from_params documents `database` as a bare database name, so a full
        # URL belongs in `connection_string`. Accept either so the assertion
        # pins the conversion, not the keyword.
        sync_target = call_kwargs.get("connection_string", call_kwargs.get("database"))
        assert sync_target == "postgresql://user:pass@localhost:5432/db"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRAGPipeline:
    """Tests for the RAGPipeline wrapper.

    NOTE(review): the async tests carry no explicit asyncio marker; presumably
    the project runs pytest-asyncio in auto mode — confirm in the pytest config.
    """

    def _make_mock_embed_model(self):
        """Build a mock embedding model whose async embed call returns a 1536-dim vector."""
        embed_model = MagicMock()
        embed_model.aget_text_embedding = AsyncMock(return_value=[0.1] * 1536)
        return embed_model

    def _make_mock_vector_store(self):
        """Build a mock vector store with an awaitable ``aquery``."""
        vector_store = MagicMock()
        vector_store.aquery = AsyncMock()
        return vector_store

    def _make_mock_text_node(self, node_id: str, content: str, metadata: dict | None = None):
        """Build a mock TextNode exposing node_id, get_content() and metadata."""
        text_node = MagicMock()
        text_node.node_id = node_id
        text_node.get_content.return_value = content
        text_node.metadata = metadata or {}
        return text_node

    def _make_empty_query_result(self):
        """Build a mock query result that contains no nodes and no similarities."""
        empty = MagicMock()
        empty.nodes = []
        empty.similarities = []
        return empty

    async def test_ingest_calls_pipeline_arun(self):
        """ingest awaits IngestionPipeline.arun and returns its node list."""
        store = self._make_mock_vector_store()
        embedder = self._make_mock_embed_model()
        expected_nodes = [
            self._make_mock_text_node("n1", "chunk 1", {"kb_id": "kb1"}),
            self._make_mock_text_node("n2", "chunk 2", {"kb_id": "kb1"}),
        ]

        with patch("llama_index.core.ingestion.IngestionPipeline") as mock_pipeline_cls:
            fake_pipeline = MagicMock()
            fake_pipeline.arun = AsyncMock(return_value=expected_nodes)
            mock_pipeline_cls.return_value = fake_pipeline

            rag = RAGPipeline(store, embedder, chunk_size=256, chunk_overlap=20)
            nodes = await rag.ingest("hello world", metadata={"kb_id": "kb1", "document_id": "d1"})

        assert len(nodes) == 2
        assert nodes[0].node_id == "n1"
        fake_pipeline.arun.assert_awaited_once()

    async def test_query_returns_query_results(self):
        """query maps the store's nodes and similarities to QueryResult entries in order."""
        store = self._make_mock_vector_store()
        embedder = self._make_mock_embed_model()

        # Simulate what the vector store query returns.
        shared_meta = {"kb_id": "kb1", "document_id": "d1"}
        store_result = MagicMock()
        store_result.nodes = [
            self._make_mock_text_node("n1", "relevant chunk", dict(shared_meta)),
            self._make_mock_text_node("n2", "another chunk", dict(shared_meta)),
        ]
        store_result.similarities = [0.95, 0.80]
        store.aquery.return_value = store_result

        with patch("llama_index.core.ingestion.IngestionPipeline"):
            rag = RAGPipeline(store, embedder)
            results = await rag.query("test query", top_k=2)

        assert len(results) == 2
        first, second = results
        assert isinstance(first, QueryResult)
        assert first.chunk_id == "n1"
        assert first.score == 0.95
        assert first.kb_id == "kb1"
        assert second.score == 0.80

    async def test_query_calls_embed_model(self):
        """query asks the embedding model for the query vector."""
        store = self._make_mock_vector_store()
        embedder = self._make_mock_embed_model()
        store.aquery.return_value = self._make_empty_query_result()

        with patch("llama_index.core.ingestion.IngestionPipeline"):
            rag = RAGPipeline(store, embedder)
            await rag.query("test", top_k=5)

        embedder.aget_text_embedding.assert_awaited_once_with("test")

    async def test_query_passes_top_k_to_vector_store(self):
        """query forwards top_k to the vector store query."""
        store = self._make_mock_vector_store()
        embedder = self._make_mock_embed_model()
        store.aquery.return_value = self._make_empty_query_result()

        with patch("llama_index.core.ingestion.IngestionPipeline"):
            rag = RAGPipeline(store, embedder)
            await rag.query("test", top_k=10)

        # aquery must have received a query object with similarity_top_k=10.
        sent_query = store.aquery.call_args.args[0]
        assert sent_query.similarity_top_k == 10

    def test_pipeline_init_with_custom_chunk_params(self):
        """RAGPipeline honours custom chunk_size and chunk_overlap."""
        store = self._make_mock_vector_store()
        embedder = self._make_mock_embed_model()

        with patch("llama_index.core.ingestion.IngestionPipeline") as mock_pipeline_cls:
            RAGPipeline(store, embedder, chunk_size=1024, chunk_overlap=100)

        # The SentenceSplitter handed to IngestionPipeline must carry the
        # custom parameters.
        transformations = mock_pipeline_cls.call_args.kwargs.get("transformations")
        assert transformations is not None
        # The first transformation is the SentenceSplitter.
        splitter = transformations[0]
        assert splitter.chunk_size == 1024
        assert splitter.chunk_overlap == 100
|
||||||
Loading…
Reference in New Issue