"""LocalRAGService - local document RAG service.

Implements the KnowledgeBase protocol with support for document ingestion,
semantic retrieval, deletion and source tracing.

Two implementations are provided:

- LocalRAGService: backed by pgvector + PostgreSQL (production)
- InMemoryLocalRAGService: backed by in-process memory (tests and development)
"""

from __future__ import annotations

import heapq
import json
import logging
import re
from datetime import datetime, timezone
from typing import TYPE_CHECKING, TypeAlias

from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker
from agentkit.memory.document_loader import Document as LoaderDocument
from agentkit.memory.embedder import Embedder
from agentkit.memory.knowledge_base import (
    Document,
    QueryResult,
    SourceInfo,
)
from agentkit.utils.vector_math import compute_cosine_similarity

if TYPE_CHECKING:
    from sqlalchemy.ext.asyncio import AsyncSession

_SAFE_TABLE_NAME_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# InMemoryLocalRAGService 内部存储的文档元信息结构。
|
||
# 字段:title/source_id/format/chunk_ids/metadata/created_at — 值为标量或 list[str]。
|
||
InMemoryDocInfo: TypeAlias = dict[str, object]
|
||
# 内部 chunk 存储结构:content/embedding/metadata/source_doc_id。
|
||
InMemoryChunkInfo: TypeAlias = dict[str, object]
|
||
|
||
|
||
def _loader_doc_to_kb_doc(loader_doc: LoaderDocument) -> Document:
    """Convert a document_loader.Document into a knowledge_base.Document.

    The ``source_id`` is taken from the ``"source"`` metadata entry (empty
    string when absent). The metadata mapping is shallow-copied so the
    converted document does not share a mutable dict with the loader document.

    Args:
        loader_doc: Document produced by the document loader.

    Returns:
        The equivalent knowledge-base document.
    """
    # Copy rather than alias; also tolerates a loader document without metadata.
    metadata = dict(loader_doc.metadata or {})
    return Document(
        doc_id=loader_doc.doc_id,
        content=loader_doc.content,
        title=loader_doc.title,
        source_id=metadata.get("source", ""),
        metadata=metadata,
    )


class LocalRAGService:
    """pgvector-backed local RAG service.

    Implements the KnowledgeBase protocol, storing and retrieving chunks from
    a PostgreSQL table via pgvector. Follows the same pgvector infrastructure
    pattern as EpisodicMemory.

    Ingestion pipeline: upload -> parse -> chunk -> embed -> write to pgvector.
    """

    # Maximum number of rows the client-side fallback fetches and scores.
    _CLIENT_SIDE_SCAN_LIMIT = 500
    # The client-side fallback drops candidates scoring below this similarity.
    _MIN_SIMILARITY = 0.1

    def __init__(
        self,
        session_factory: object,
        embedder: Embedder,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        table_name: str = "knowledge_chunks",
        pgvector_enabled: bool = True,
    ):
        """Create the service.

        Args:
            session_factory: Factory returning an async context manager that
                yields a database session.
            embedder: Embedder used to generate vectors.
            chunk_size: Chunk size passed to the chunkers.
            chunk_overlap: Chunk overlap passed to the chunkers.
            table_name: Table used by every SQL statement of this service.
            pgvector_enabled: Whether to use pgvector's native retrieval;
                when False, similarity is computed client-side.

        Raises:
            ValueError: If ``table_name`` is not a plain identifier.
        """
        # The table name is interpolated into SQL text, so validate it before
        # anything else. ``fullmatch`` also rejects a trailing newline.
        if not _SAFE_TABLE_NAME_PATTERN.fullmatch(table_name):
            raise ValueError(
                f"Invalid table_name: {table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*"
            )

        self._session_factory = session_factory
        self._embedder = embedder
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._table_name = table_name
        self._pgvector_enabled = pgvector_enabled
        self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        self._structural_chunker = StructuralChunker(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )

    async def ingest(self, documents: list[Document]) -> list[str]:
        """Ingest a list of documents.

        Each document is chunked, embedded and stored independently. A
        document that fails is logged and skipped; the remaining documents
        are still processed.

        Args:
            documents: knowledge_base.Document objects to ingest.

        Returns:
            IDs of the documents that were ingested successfully.
        """
        ingested_ids = []

        for doc in documents:
            try:
                chunks = self._chunk_document(doc)
                await self._store_chunks(doc, chunks)
            except Exception as e:
                # Best-effort batch: one bad document must not abort the rest.
                logger.error("Failed to ingest document '%s': %s", doc.title, e)
            else:
                ingested_ids.append(doc.doc_id)
                logger.info("Ingested document '%s' with %s chunks", doc.title, len(chunks))

        return ingested_ids

    async def query(self, text: str, top_k: int = 5) -> list[QueryResult]:
        """Run a semantic search.

        Args:
            text: Query text.
            top_k: Maximum number of results to return. Values <= 0 yield an
                empty list without embedding the query or touching the
                database.

        Returns:
            Matching chunks, best match first. An empty list is returned when
            the database query fails (the error is logged).
        """
        if top_k <= 0:
            return []

        query_embedding = await self._embedder.embed(text)

        async with self._session_factory() as db:
            try:
                if self._pgvector_enabled:
                    return await self._query_pgvector(db, query_embedding, top_k)
                return await self._query_client_side(db, query_embedding, top_k)
            except Exception as e:
                logger.error("Failed to query knowledge base: %s", e)
                return []

    async def delete_by_id(self, id: str) -> bool:
        """Delete every chunk belonging to a document.

        Args:
            id: Document ID (matched against ``source_doc_id``).

        Returns:
            True when the DELETE was executed and committed (this includes the
            case where no row matched); False when it failed and was rolled
            back.
        """
        async with self._session_factory() as db:
            try:
                from sqlalchemy import text as sql_text

                sql = sql_text(f"DELETE FROM {self._table_name} WHERE source_doc_id = :doc_id")
                await db.execute(sql, {"doc_id": id})
                await db.commit()
                return True
            except Exception as e:
                await db.rollback()
                logger.error("Failed to delete document %s: %s", id, e)
                return False

    async def list_sources(self) -> list[SourceInfo]:
        """List the ingested documents.

        Returns:
            One SourceInfo per (source_doc_id, source_title, doc_format)
            group, newest first. ``document_count`` carries the number of
            chunks stored for that document. An empty list is returned when
            the query fails (the error is logged).
        """
        async with self._session_factory() as db:
            try:
                from sqlalchemy import text as sql_text

                sql = sql_text(
                    "SELECT source_doc_id, source_title, doc_format, "
                    "COUNT(*) as chunk_count, "
                    "MIN(created_at) as created_at, "
                    "MIN(doc_metadata) as doc_metadata "
                    f"FROM {self._table_name} "
                    "GROUP BY source_doc_id, source_title, doc_format "
                    "ORDER BY MIN(created_at) DESC"
                )
                result = await db.execute(sql)

                return [
                    SourceInfo(
                        source_id=row["source_doc_id"],
                        source_name=row.get("source_title", ""),
                        source_type=row.get("doc_format", "local"),
                        document_count=row.get("chunk_count", 0),
                        last_updated=row.get("created_at") or None,
                    )
                    for row in result.mappings().all()
                ]
            except Exception as e:
                logger.error("Failed to list sources: %s", e)
                return []

    async def health_check(self) -> bool:
        """Check service health by selecting one row from the chunk table.

        Returns:
            True if the probe query executed, False otherwise.
        """
        async with self._session_factory() as db:
            try:
                from sqlalchemy import text as sql_text

                await db.execute(sql_text(f"SELECT 1 FROM {self._table_name} LIMIT 1"))
                return True
            except Exception as e:
                logger.error("Health check failed: %s", e)
                return False

    def _chunk_document(self, doc: Document) -> list[Chunk]:
        """Split a document into chunks.

        Markdown and HTML documents (per ``metadata["format"]``) go through
        the structural chunker; everything else, including documents without
        a format, goes through the plain-text chunker.
        """
        doc_format = doc.metadata.get("format", "text")
        chunker = (
            self._structural_chunker
            if doc_format in ("markdown", "html")
            else self._text_chunker
        )
        return chunker.chunk(
            doc.content,
            source_doc_id=doc.doc_id,
            metadata=doc.metadata,
        )

    async def _embed_chunks(self, chunks: list[Chunk]) -> list[list[float]]:
        """Embed every chunk, using ``embed_batch`` when the embedder has it.

        Raises:
            ValueError: If the embedder returns a different number of vectors
                than chunks (pairing them up would silently drop data).
        """
        texts = [chunk.content for chunk in chunks]
        if hasattr(self._embedder, "embed_batch"):
            embeddings = list(await self._embedder.embed_batch(texts))
        else:
            embeddings = [await self._embedder.embed(chunk_text) for chunk_text in texts]

        if len(embeddings) != len(chunks):
            raise ValueError(
                f"Embedder returned {len(embeddings)} embeddings for {len(chunks)} chunks"
            )
        return embeddings

    async def _store_chunks(self, doc: Document, chunks: list[Chunk]) -> None:
        """Embed the chunks of one document and insert them in a single batch.

        Does nothing when ``chunks`` is empty. On a database error the
        transaction is rolled back, the error is logged and re-raised.
        """
        if not chunks:
            # Nothing to insert; an executemany with no rows is not a valid INSERT.
            return

        # Embed before opening the session so no DB session is held open
        # while waiting on the embedder.
        embeddings = await self._embed_chunks(chunks)

        from sqlalchemy import text as sql_text

        sql = sql_text(
            f"INSERT INTO {self._table_name} "
            "(chunk_id, source_doc_id, source_title, doc_format, "
            "content, embedding, chunk_metadata, doc_metadata, created_at) "
            "VALUES (:chunk_id, :doc_id, :title, :format, "
            ":content, :embedding, :chunk_meta, :doc_meta, :created_at)"
        )

        now = datetime.now(timezone.utc)
        doc_format = doc.metadata.get("format", "unknown")
        doc_meta = json.dumps(doc.metadata, ensure_ascii=False)
        params_list = [
            {
                "chunk_id": chunk.chunk_id,
                "doc_id": doc.doc_id,
                "title": doc.title,
                "format": doc_format,
                "content": chunk.content,
                # Vectors are sent in their textual list form, e.g. "[0.1, 0.2]".
                "embedding": str(embedding),
                "chunk_meta": json.dumps(chunk.metadata, ensure_ascii=False),
                "doc_meta": doc_meta,
                "created_at": now,
            }
            for chunk, embedding in zip(chunks, embeddings)
        ]

        async with self._session_factory() as db:
            try:
                # A list of parameter dicts makes this a batch (executemany) insert.
                await db.execute(sql, params_list)
                await db.commit()
            except Exception as e:
                await db.rollback()
                logger.error("Failed to store chunks for document '%s': %s", doc.title, e)
                raise

    @staticmethod
    def _parse_chunk_metadata(raw: object) -> dict:
        """Decode the ``chunk_metadata`` column value.

        Accepts a JSON string or an already-decoded dict (as a driver may
        return for JSON-typed columns). Empty or undecodable values yield
        ``{}``.
        """
        if not raw:
            return {}
        if isinstance(raw, dict):
            return raw
        try:
            return json.loads(raw)
        except (json.JSONDecodeError, TypeError):
            return {}

    @staticmethod
    def _row_to_result(row, score: float) -> QueryResult:
        """Build a QueryResult from a chunk row and its similarity score."""
        title = row.get("source_title", "")
        return QueryResult(
            content=row["content"],
            source_id=row["source_doc_id"],
            source_name=title,
            score=score,
            metadata=LocalRAGService._parse_chunk_metadata(row.get("chunk_metadata")),
            doc_id=row["source_doc_id"],
            title=title,
        )

    async def _query_pgvector(
        self,
        db: AsyncSession,
        query_embedding: list[float],
        top_k: int,
    ) -> list[QueryResult]:
        """Retrieve the nearest chunks with pgvector's ``<=>`` operator."""
        from sqlalchemy import text as sql_text

        sql = sql_text(
            "SELECT chunk_id, source_doc_id, source_title, content, "
            "chunk_metadata, embedding <=> :query_vec AS distance "
            f"FROM {self._table_name} "
            "ORDER BY embedding <=> :query_vec "
            "LIMIT :lim"
        )

        result = await db.execute(
            sql,
            {
                "query_vec": str(query_embedding),
                "lim": top_k,
            },
        )

        results = []
        for row in result.mappings().all():
            distance = row.get("distance", 0.0)
            if distance is None:
                # Row without an embedding: there is no distance to score.
                continue
            # pgvector's <=> yields cosine distance = 1 - cosine similarity;
            # negative similarities are clamped to 0.
            cosine = max(0.0, 1.0 - float(distance))
            results.append(self._row_to_result(row, cosine))

        return results

    async def _query_client_side(
        self,
        db: AsyncSession,
        query_embedding: list[float],
        top_k: int,
    ) -> list[QueryResult]:
        """Fallback retrieval: O(N) client-side cosine similarity.

        Fetches at most ``_CLIENT_SIDE_SCAN_LIMIT`` rows, scores each against
        the query and returns the ``top_k`` best candidates whose similarity
        is at least ``_MIN_SIMILARITY``.
        """
        from sqlalchemy import text as sql_text

        sql = sql_text(
            "SELECT chunk_id, source_doc_id, source_title, content, "
            "chunk_metadata, embedding "
            f"FROM {self._table_name} "
            f"LIMIT {int(self._CLIENT_SIDE_SCAN_LIMIT)}"
        )

        result = await db.execute(sql)

        candidates = []
        for row in result.mappings().all():
            row_embedding = row.get("embedding")
            if row_embedding is None:
                continue

            # The stored embedding is either its textual list form or a sequence.
            try:
                if isinstance(row_embedding, str):
                    stored_embedding = json.loads(row_embedding)
                else:
                    stored_embedding = list(row_embedding)
            except (json.JSONDecodeError, TypeError):
                continue

            cosine = compute_cosine_similarity(query_embedding, stored_embedding)
            if cosine < self._MIN_SIMILARITY:
                continue

            candidates.append(self._row_to_result(row, cosine))

        # Same result as sorting by score descending and slicing to top_k.
        return heapq.nlargest(top_k, candidates, key=lambda item: item.score)


class InMemoryLocalRAGService:
    """In-memory local RAG service.

    Intended for tests and development; needs no pgvector. Implements the
    KnowledgeBase protocol.
    """

    # Candidates whose cosine similarity is below this value are dropped.
    _MIN_SIMILARITY = 0.1

    def __init__(
        self,
        embedder: Embedder,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
    ):
        """Create the service.

        Args:
            embedder: Embedder used to generate vectors.
            chunk_size: Chunk size passed to the chunkers.
            chunk_overlap: Chunk overlap passed to the chunkers.
        """
        self._embedder = embedder
        self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        self._structural_chunker = StructuralChunker(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )

        # In-memory stores.
        # chunk_id -> {content, embedding, metadata, source_doc_id}
        self._chunks: dict[str, InMemoryChunkInfo] = {}
        # doc_id -> {title, source_id, format, chunk_ids, metadata, created_at}
        self._documents: dict[str, InMemoryDocInfo] = {}

    async def ingest(self, documents: list[Document]) -> list[str]:
        """Ingest a list of documents.

        document_loader.Document instances are accepted too and converted to
        knowledge_base.Document automatically. A document that fails is
        logged and skipped without leaving any of its chunks in the store.
        Ingesting a ``doc_id`` that already exists replaces its old chunks.

        Args:
            documents: Documents to ingest.

        Returns:
            IDs of the documents that were ingested successfully.
        """
        ingested_ids = []

        for doc in documents:
            # Accept document_loader.Document by converting it first.
            if isinstance(doc, LoaderDocument):
                doc = _loader_doc_to_kb_doc(doc)

            try:
                chunks = self._chunk_document(doc)

                # Stage everything first so an embedding failure midway does
                # not leave orphan chunks behind in the live store.
                staged: dict[str, InMemoryChunkInfo] = {}
                chunk_ids = []
                for chunk in chunks:
                    embedding = await self._embedder.embed(chunk.content)
                    staged[chunk.chunk_id] = {
                        "content": chunk.content,
                        "embedding": embedding,
                        "metadata": chunk.metadata,
                        "source_doc_id": doc.doc_id,
                    }
                    chunk_ids.append(chunk.chunk_id)

                # Re-ingestion: drop the chunks of the previous version so
                # stale content is no longer searchable.
                previous = self._documents.get(doc.doc_id)
                if previous is not None:
                    for stale_id in previous.get("chunk_ids", []):
                        self._chunks.pop(stale_id, None)

                self._chunks.update(staged)
                self._documents[doc.doc_id] = {
                    "title": doc.title,
                    "source_id": doc.source_id,
                    "format": doc.metadata.get("format", "unknown"),
                    "chunk_ids": chunk_ids,
                    "metadata": doc.metadata,
                    "created_at": datetime.now(timezone.utc),
                }
            except Exception as e:
                # Best-effort batch: one bad document must not abort the rest.
                logger.error("Failed to ingest document '%s': %s", doc.title, e)
            else:
                ingested_ids.append(doc.doc_id)
                logger.info("Ingested document '%s' with %s chunks", doc.title, len(chunks))

        return ingested_ids

    async def query(self, text: str, top_k: int = 5) -> list[QueryResult]:
        """Run a semantic search over every stored chunk.

        Args:
            text: Query text.
            top_k: Maximum number of results to return. Values <= 0 yield an
                empty list without embedding the query.

        Returns:
            Chunks whose cosine similarity is at least ``_MIN_SIMILARITY``,
            best match first, at most ``top_k`` of them.
        """
        if top_k <= 0:
            return []

        query_embedding = await self._embedder.embed(text)

        candidates = []
        for chunk_data in self._chunks.values():
            cosine = compute_cosine_similarity(query_embedding, chunk_data["embedding"])
            if cosine < self._MIN_SIMILARITY:
                continue

            source_doc_id = chunk_data["source_doc_id"]
            title = self._documents.get(source_doc_id, {}).get("title", "")

            candidates.append(
                QueryResult(
                    content=chunk_data["content"],
                    source_id=source_doc_id,
                    source_name=title,
                    score=cosine,
                    metadata=chunk_data.get("metadata", {}),
                    doc_id=source_doc_id,
                    title=title,
                )
            )

        # Same result as sorting by score descending and slicing to top_k.
        return heapq.nlargest(top_k, candidates, key=lambda item: item.score)

    async def delete_by_id(self, id: str) -> bool:
        """Delete a document and all of its chunks.

        Args:
            id: Document ID.

        Returns:
            True if the document existed and was removed, False otherwise.
        """
        doc_info = self._documents.pop(id, None)
        if doc_info is None:
            return False

        for chunk_id in doc_info.get("chunk_ids", []):
            self._chunks.pop(chunk_id, None)
        return True

    async def list_sources(self) -> list[SourceInfo]:
        """List the ingested documents.

        Returns:
            One SourceInfo per document; ``document_count`` carries the number
            of chunks recorded for that document.
        """
        return [
            SourceInfo(
                source_id=doc_id,
                source_name=doc_info["title"],
                source_type=doc_info.get("format", "local"),
                document_count=len(doc_info.get("chunk_ids", [])),
                last_updated=doc_info.get("created_at"),
            )
            for doc_id, doc_info in self._documents.items()
        ]

    async def health_check(self) -> bool:
        """Check service health; the in-memory store always reports True."""
        return True

    def _chunk_document(self, doc: Document) -> list[Chunk]:
        """Split a document into chunks.

        Markdown and HTML documents (per ``metadata["format"]``) go through
        the structural chunker; everything else goes through the plain-text
        chunker.
        """
        doc_format = doc.metadata.get("format", "text")
        chunker = (
            self._structural_chunker
            if doc_format in ("markdown", "html")
            else self._text_chunker
        )
        return chunker.chunk(
            doc.content,
            source_doc_id=doc.doc_id,
            metadata=doc.metadata,
        )