fischer-agentkit/src/agentkit/memory/local_rag.py

538 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""LocalRAGService - 本地文档 RAG 服务
实现 KnowledgeBase 协议,支持文档摄取、语义检索、删除和来源追溯。
提供两种实现:
- LocalRAGService: 基于 pgvector + PostgreSQL(生产环境)
- InMemoryLocalRAGService: 基于内存(测试和开发环境)
"""
from __future__ import annotations
import json
import logging
import re
from datetime import datetime, timezone
from typing import TYPE_CHECKING, TypeAlias
from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker
from agentkit.memory.document_loader import Document as LoaderDocument
from agentkit.memory.embedder import Embedder
from agentkit.memory.knowledge_base import (
Document,
QueryResult,
SourceInfo,
)
from agentkit.utils.vector_math import compute_cosine_similarity
if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncSession
_SAFE_TABLE_NAME_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
logger = logging.getLogger(__name__)
# InMemoryLocalRAGService 内部存储的文档元信息结构。
# 字段title/source_id/format/chunk_ids/metadata/created_at — 值为标量或 list[str]。
InMemoryDocInfo: TypeAlias = dict[str, object]
# 内部 chunk 存储结构content/embedding/metadata/source_doc_id。
InMemoryChunkInfo: TypeAlias = dict[str, object]
def _loader_doc_to_kb_doc(loader_doc: LoaderDocument) -> Document:
    """Convert a ``document_loader.Document`` into a ``knowledge_base.Document``.

    The loader's ``metadata["source"]`` entry (``""`` when absent) becomes the
    knowledge-base ``source_id``. The metadata mapping itself is passed through
    as the same object, not a copy.
    """
    meta = loader_doc.metadata
    source = meta.get("source", "")
    return Document(
        doc_id=loader_doc.doc_id,
        content=loader_doc.content,
        title=loader_doc.title,
        source_id=source,
        metadata=meta,
    )
class LocalRAGService:
    """pgvector-backed local RAG service.

    Implements the KnowledgeBase protocol, using pgvector for storage and
    retrieval. It follows the pgvector infrastructure pattern of EpisodicMemory.

    Ingestion pipeline: upload -> parse -> chunk -> embed -> write to pgvector.
    """

    # Client-side fallback only: chunks scoring below this cosine similarity
    # are dropped from the results.
    _MIN_CLIENT_SIDE_SCORE = 0.1

    def __init__(
        self,
        session_factory: object,
        embedder: Embedder,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
        table_name: str = "knowledge_chunks",
        pgvector_enabled: bool = True,
    ):
        """
        Args:
            session_factory: Factory returning an async context manager that
                yields a database session.
            embedder: Embedder used to generate vectors.
            chunk_size: Chunk size passed to both chunkers.
            chunk_overlap: Chunk overlap passed to both chunkers.
            table_name: Table used by every SQL statement. It is interpolated
                into SQL text, so it must be a plain identifier.
            pgvector_enabled: Use pgvector's native ``<=>`` search; when False,
                fall back to client-side cosine similarity.

        Raises:
            ValueError: If ``table_name`` is not a plain SQL identifier.
        """
        self._session_factory = session_factory
        self._embedder = embedder
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._table_name = table_name
        # fullmatch (not match) so a trailing newline cannot slip past a `$` anchor.
        if not _SAFE_TABLE_NAME_PATTERN.fullmatch(self._table_name):
            raise ValueError(
                f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*"
            )
        self._pgvector_enabled = pgvector_enabled
        self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        self._structural_chunker = StructuralChunker(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )

    async def ingest(self, documents: list[Document]) -> list[str]:
        """Ingest a list of documents.

        Each document is chunked, embedded and stored independently. A failure
        is logged and that document is skipped, so one bad document does not
        abort the batch. ``document_loader.Document`` objects are also
        accepted and converted first, as in InMemoryLocalRAGService.

        Args:
            documents: knowledge_base.Document objects.

        Returns:
            IDs of the documents that were stored successfully.
        """
        ingested_ids: list[str] = []
        for doc in documents:
            if isinstance(doc, LoaderDocument):
                doc = _loader_doc_to_kb_doc(doc)
            try:
                chunks = self._chunk_document(doc)
                await self._store_chunks(doc, chunks)
            except Exception as e:
                logger.error("Failed to ingest document '%s': %s", doc.title, e)
                continue
            ingested_ids.append(doc.doc_id)
            logger.info("Ingested document '%s' with %d chunks", doc.title, len(chunks))
        return ingested_ids

    async def query(self, text: str, top_k: int = 5) -> list[QueryResult]:
        """Semantic search over the stored chunks.

        Args:
            text: Query text.
            top_k: Maximum number of results; values <= 0 return ``[]``.

        Returns:
            Results ordered by descending score. On a database error the
            error is logged and ``[]`` is returned.
        """
        # Guard here so both search paths agree: a negative slice in the
        # client-side path would otherwise return "all but the last" results.
        if top_k <= 0:
            return []
        query_embedding = await self._embedder.embed(text)
        async with self._session_factory() as db:
            try:
                if self._pgvector_enabled:
                    return await self._query_pgvector(db, query_embedding, top_k)
                return await self._query_client_side(db, query_embedding, top_k)
            except Exception as e:
                logger.error("Failed to query knowledge base: %s", e)
                return []

    async def delete_by_id(self, id: str) -> bool:
        """Delete every chunk belonging to a document.

        Args:
            id: Document ID (matched against ``source_doc_id``).

        Returns:
            True if chunks were deleted (or the driver cannot report a row
            count). False if no chunk matched or the delete failed. This
            matches InMemoryLocalRAGService, which returns False for unknown IDs.
        """
        async with self._session_factory() as db:
            try:
                from sqlalchemy import text as sql_text

                sql = sql_text(f"DELETE FROM {self._table_name} WHERE source_doc_id = :doc_id")
                result = await db.execute(sql, {"doc_id": id})
                await db.commit()
            except Exception as e:
                await db.rollback()
                logger.error("Failed to delete document %s: %s", id, e)
                return False
        rowcount = getattr(result, "rowcount", None)
        # A missing or negative rowcount means "unknown"; keep reporting success then.
        if isinstance(rowcount, int) and rowcount >= 0:
            return rowcount > 0
        return True

    async def list_sources(self) -> list[SourceInfo]:
        """List the ingested documents, newest first.

        Returns:
            One SourceInfo per document. ``document_count`` carries the number
            of stored chunks for that document. On a database error the error
            is logged and ``[]`` is returned.
        """
        async with self._session_factory() as db:
            try:
                from sqlalchemy import text as sql_text

                # doc_metadata is deliberately not selected: it is unused here,
                # and PostgreSQL defines no MIN() aggregate for json/jsonb.
                sql = sql_text(
                    f"SELECT source_doc_id, source_title, doc_format, "
                    f"COUNT(*) as chunk_count, "
                    f"MIN(created_at) as created_at "
                    f"FROM {self._table_name} "
                    f"GROUP BY source_doc_id, source_title, doc_format "
                    f"ORDER BY MIN(created_at) DESC"
                )
                result = await db.execute(sql)
                rows = result.mappings().all()
                return [
                    SourceInfo(
                        source_id=row["source_doc_id"],
                        # `or` also maps SQL NULL to the default, not just a missing key.
                        source_name=row.get("source_title") or "",
                        source_type=row.get("doc_format") or "local",
                        document_count=row.get("chunk_count") or 0,
                        last_updated=row.get("created_at"),
                    )
                    for row in rows
                ]
            except Exception as e:
                logger.error("Failed to list sources: %s", e)
                return []

    async def health_check(self) -> bool:
        """Return True if the chunk table can be queried, else log and return False."""
        async with self._session_factory() as db:
            try:
                from sqlalchemy import text as sql_text

                await db.execute(sql_text(f"SELECT 1 FROM {self._table_name} LIMIT 1"))
                return True
            except Exception as e:
                logger.error("Health check failed: %s", e)
                return False

    def _chunk_document(self, doc: Document) -> list[Chunk]:
        """Split a document into chunks.

        Markdown and HTML (per ``metadata["format"]``) use the structural
        chunker; everything else uses the plain text chunker.
        """
        doc_format = doc.metadata.get("format", "text")
        chunker = (
            self._structural_chunker
            if doc_format in ("markdown", "html")
            else self._text_chunker
        )
        return chunker.chunk(
            doc.content,
            source_doc_id=doc.doc_id,
            metadata=doc.metadata,
        )

    async def _store_chunks(self, doc: Document, chunks: list[Chunk]) -> None:
        """Embed ``chunks`` and insert them into the pgvector table in one transaction.

        Raises:
            ValueError: If the embedder returns a different number of vectors
                than there are chunks.
            Exception: Any database or embedder error, re-raised after a
                rollback and a log entry.
        """
        if not chunks:
            # Nothing to embed or insert; avoid issuing an INSERT with no
            # parameter sets.
            return
        async with self._session_factory() as db:
            try:
                from sqlalchemy import text as sql_text

                # Prefer one batched embedding call when the embedder supports it.
                embeddings: list[list[float]] = []
                if hasattr(self._embedder, "embed_batch"):
                    embeddings = await self._embedder.embed_batch([c.content for c in chunks])
                else:
                    for chunk in chunks:
                        embeddings.append(await self._embedder.embed(chunk.content))
                # zip() would silently drop chunks on a count mismatch.
                if len(embeddings) != len(chunks):
                    raise ValueError(
                        f"Embedder returned {len(embeddings)} embeddings "
                        f"for {len(chunks)} chunks"
                    )
                # Batch INSERT: a list of parameter dicts runs as executemany.
                sql = sql_text(
                    f"INSERT INTO {self._table_name} "
                    f"(chunk_id, source_doc_id, source_title, doc_format, "
                    f"content, embedding, chunk_metadata, doc_metadata, created_at) "
                    f"VALUES (:chunk_id, :doc_id, :title, :format, "
                    f":content, :embedding, :chunk_meta, :doc_meta, :created_at)"
                )
                now = datetime.now(timezone.utc)
                doc_format = doc.metadata.get("format", "unknown")
                doc_meta = json.dumps(doc.metadata, ensure_ascii=False)
                params_list = [
                    {
                        "chunk_id": chunk.chunk_id,
                        "doc_id": doc.doc_id,
                        "title": doc.title,
                        "format": doc_format,
                        "content": chunk.content,
                        # Text form "[x, y, ...]" of the vector.
                        "embedding": str(embedding),
                        "chunk_meta": json.dumps(chunk.metadata, ensure_ascii=False),
                        "doc_meta": doc_meta,
                        "created_at": now,
                    }
                    for chunk, embedding in zip(chunks, embeddings)
                ]
                await db.execute(sql, params_list)
                await db.commit()
            except Exception as e:
                await db.rollback()
                logger.error("Failed to store chunks for document '%s': %s", doc.title, e)
                raise

    @staticmethod
    def _parse_json_metadata(raw: object) -> dict:
        """Decode a stored ``chunk_metadata`` value into a dict.

        Accepts the JSON text written by ``_store_chunks`` or a dict the
        database driver has already decoded. Missing, malformed or non-object
        values yield ``{}``.
        """
        if not raw:
            return {}
        if isinstance(raw, dict):
            return raw
        try:
            parsed = json.loads(raw)
        except (TypeError, ValueError):  # JSONDecodeError is a ValueError
            return {}
        return parsed if isinstance(parsed, dict) else {}

    @staticmethod
    def _parse_embedding(raw: object) -> list[float] | None:
        """Turn a stored embedding (JSON-style text or a sequence) into a list.

        Returns None when the value is missing or cannot be parsed.
        """
        if raw is None:
            return None
        try:
            if isinstance(raw, str):
                return json.loads(raw)
            return list(raw)
        except (TypeError, ValueError):
            return None

    @classmethod
    def _row_to_result(cls, row, score: float) -> QueryResult:
        """Build a QueryResult from a chunk row mapping and its similarity score."""
        title = row.get("source_title", "")
        return QueryResult(
            content=row["content"],
            source_id=row["source_doc_id"],
            source_name=title,
            score=score,
            metadata=cls._parse_json_metadata(row.get("chunk_metadata")),
            doc_id=row["source_doc_id"],
            title=title,
        )

    async def _query_pgvector(
        self,
        db: AsyncSession,
        query_embedding: list[float],
        top_k: int,
    ) -> list[QueryResult]:
        """Rank chunks server-side with pgvector's ``<=>`` operator.

        ``<=>`` yields cosine distance (1 - cosine similarity). Scores are
        converted back to similarity and clamped at 0.
        """
        from sqlalchemy import text as sql_text

        # NULL embeddings would sort last and produce a NULL distance; exclude them.
        sql = sql_text(
            f"SELECT chunk_id, source_doc_id, source_title, content, "
            f"chunk_metadata, embedding <=> :query_vec AS distance "
            f"FROM {self._table_name} "
            f"WHERE embedding IS NOT NULL "
            f"ORDER BY embedding <=> :query_vec "
            f"LIMIT :lim"
        )
        result = await db.execute(
            sql,
            {
                "query_vec": str(query_embedding),
                "lim": top_k,
            },
        )
        results: list[QueryResult] = []
        for row in result.mappings().all():
            distance = row.get("distance")
            if distance is None:
                continue
            cosine = max(0.0, 1.0 - float(distance))
            results.append(self._row_to_result(row, cosine))
        return results

    async def _query_client_side(
        self,
        db: AsyncSession,
        query_embedding: list[float],
        top_k: int,
    ) -> list[QueryResult]:
        """Fallback: O(N) cosine similarity computed in Python.

        Only the first 500 rows returned by the database are scored (no ORDER
        BY), so on larger tables this searches an arbitrary subset.
        """
        from sqlalchemy import text as sql_text

        sql = sql_text(
            f"SELECT chunk_id, source_doc_id, source_title, content, "
            f"chunk_metadata, embedding "
            f"FROM {self._table_name} "
            f"LIMIT 500"
        )
        result = await db.execute(sql)
        candidates: list[QueryResult] = []
        for row in result.mappings().all():
            stored_embedding = self._parse_embedding(row.get("embedding"))
            if stored_embedding is None:
                continue
            cosine = compute_cosine_similarity(query_embedding, stored_embedding)
            if cosine < self._MIN_CLIENT_SIDE_SCORE:
                continue
            candidates.append(self._row_to_result(row, cosine))
        candidates.sort(key=lambda r: r.score, reverse=True)
        return candidates[:top_k]
class InMemoryLocalRAGService:
    """In-memory local RAG service.

    Intended for tests and development; it needs no pgvector dependency.
    Implements the KnowledgeBase protocol.
    """

    # Chunks scoring below this cosine similarity are dropped from query results.
    _MIN_SCORE = 0.1

    def __init__(
        self,
        embedder: Embedder,
        chunk_size: int = 1000,
        chunk_overlap: int = 200,
    ):
        """
        Args:
            embedder: Embedder used to generate vectors.
            chunk_size: Chunk size passed to both chunkers.
            chunk_overlap: Chunk overlap passed to both chunkers.
        """
        self._embedder = embedder
        self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
        self._structural_chunker = StructuralChunker(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        # chunk_id -> {content, embedding, metadata, source_doc_id}
        self._chunks: dict[str, InMemoryChunkInfo] = {}
        # doc_id -> {title, source_id, format, chunk_ids, metadata, created_at}
        self._documents: dict[str, InMemoryDocInfo] = {}

    async def ingest(self, documents: list[Document]) -> list[str]:
        """Ingest a list of documents.

        ``document_loader.Document`` objects are accepted too and converted
        to ``knowledge_base.Document``. Each document is stored only after
        all of its chunks were embedded, so a failure part-way through leaves
        no partial state. Re-ingesting an existing doc_id replaces its
        previous chunks.

        Returns:
            IDs of the documents that were stored successfully.
        """
        ingested_ids: list[str] = []
        for doc in documents:
            if isinstance(doc, LoaderDocument):
                doc = _loader_doc_to_kb_doc(doc)
            try:
                chunks = self._chunk_document(doc)
                chunk_ids: list[str] = []
                pending: dict[str, InMemoryChunkInfo] = {}
                for chunk in chunks:
                    embedding = await self._embedder.embed(chunk.content)
                    pending[chunk.chunk_id] = {
                        "content": chunk.content,
                        "embedding": embedding,
                        "metadata": chunk.metadata,
                        "source_doc_id": doc.doc_id,
                    }
                    chunk_ids.append(chunk.chunk_id)
                doc_info: InMemoryDocInfo = {
                    "title": doc.title,
                    "source_id": doc.source_id,
                    "format": doc.metadata.get("format", "unknown"),
                    "chunk_ids": chunk_ids,
                    "metadata": doc.metadata,
                    "created_at": datetime.now(timezone.utc),
                }
            except Exception as e:
                logger.error("Failed to ingest document '%s': %s", doc.title, e)
                continue
            # Commit only now that everything succeeded. Drop chunks from a
            # previous ingest of this doc_id so stale ones stop matching queries.
            self._drop_chunks(doc.doc_id)
            self._chunks.update(pending)
            self._documents[doc.doc_id] = doc_info
            ingested_ids.append(doc.doc_id)
            logger.info("Ingested document '%s' with %d chunks", doc.title, len(chunks))
        return ingested_ids

    async def query(self, text: str, top_k: int = 5) -> list[QueryResult]:
        """Semantic search over all stored chunks.

        Args:
            text: Query text.
            top_k: Maximum number of results; values <= 0 return ``[]``.

        Returns:
            Chunks scoring at least ``_MIN_SCORE``, best first.
        """
        # A negative top_k would otherwise slice off only the last results.
        if top_k <= 0:
            return []
        query_embedding = await self._embedder.embed(text)
        candidates: list[QueryResult] = []
        for chunk_data in self._chunks.values():
            cosine = compute_cosine_similarity(query_embedding, chunk_data["embedding"])
            if cosine < self._MIN_SCORE:
                continue
            source_doc_id = chunk_data["source_doc_id"]
            doc_info = self._documents.get(source_doc_id, {})
            title = doc_info.get("title", "")
            candidates.append(
                QueryResult(
                    content=chunk_data["content"],
                    source_id=source_doc_id,
                    source_name=title,
                    score=cosine,
                    metadata=chunk_data.get("metadata", {}),
                    doc_id=source_doc_id,
                    title=title,
                )
            )
        candidates.sort(key=lambda r: r.score, reverse=True)
        return candidates[:top_k]

    async def delete_by_id(self, id: str) -> bool:
        """Delete a document and its chunks; return False if the ID is unknown."""
        if id not in self._documents:
            return False
        self._drop_chunks(id)
        del self._documents[id]
        return True

    async def list_sources(self) -> list[SourceInfo]:
        """List the ingested documents in insertion order.

        ``document_count`` carries the number of chunks stored for each document.
        """
        return [
            SourceInfo(
                source_id=doc_id,
                source_name=doc_info["title"],
                source_type=doc_info.get("format", "local"),
                document_count=len(doc_info.get("chunk_ids", [])),
                last_updated=doc_info.get("created_at"),
            )
            for doc_id, doc_info in self._documents.items()
        ]

    async def health_check(self) -> bool:
        """Always healthy: there is no external dependency to probe."""
        return True

    def _drop_chunks(self, doc_id: str) -> None:
        """Remove the chunks recorded for ``doc_id``, if that document is known."""
        previous = self._documents.get(doc_id)
        if previous is None:
            return
        for chunk_id in previous.get("chunk_ids", []):
            self._chunks.pop(chunk_id, None)

    def _chunk_document(self, doc: Document) -> list[Chunk]:
        """Split a document into chunks.

        Markdown and HTML (per ``metadata["format"]``) use the structural
        chunker; everything else uses the plain text chunker.
        """
        doc_format = doc.metadata.get("format", "text")
        chunker = (
            self._structural_chunker
            if doc_format in ("markdown", "html")
            else self._text_chunker
        )
        return chunker.chunk(
            doc.content,
            source_doc_id=doc.doc_id,
            metadata=doc.metadata,
        )