"""LocalRAGService - 本地文档 RAG 服务 实现 KnowledgeBase 协议,支持文档摄取、语义检索、删除和来源追溯。 提供两种实现: - LocalRAGService: 基于 pgvector + PostgreSQL(生产环境) - InMemoryLocalRAGService: 基于内存(测试和开发环境) """ from __future__ import annotations import json import logging import re import uuid from datetime import datetime, timezone from typing import Any _SAFE_TABLE_NAME_PATTERN = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') from agentkit.memory.chunking import Chunk, StructuralChunker, TextChunker from agentkit.memory.document_loader import Document as LoaderDocument from agentkit.memory.embedder import Embedder from agentkit.memory.knowledge_base import ( Document, KnowledgeBase, QueryResult, SourceInfo, ) from agentkit.utils.vector_math import compute_cosine_similarity logger = logging.getLogger(__name__) def _loader_doc_to_kb_doc(loader_doc: LoaderDocument) -> Document: """将 document_loader.Document 转换为 knowledge_base.Document""" return Document( doc_id=loader_doc.doc_id, content=loader_doc.content, title=loader_doc.title, source_id=loader_doc.metadata.get("source", ""), metadata=loader_doc.metadata, ) class LocalRAGService: """基于 pgvector 的本地 RAG 服务 实现 KnowledgeBase 协议,使用 pgvector 存储 + 检索。 复用 EpisodicMemory 的 pgvector 基础设施模式。 摄取 Pipeline:上传 → 解析 → 分块 → 嵌入 → 写入 pgvector """ def __init__( self, session_factory: Any, embedder: Embedder, chunk_size: int = 1000, chunk_overlap: int = 200, table_name: str = "knowledge_chunks", pgvector_enabled: bool = True, ): """ Args: session_factory: 返回 async context manager 的工厂 embedder: 嵌入器,用于生成向量 chunk_size: 分块大小 chunk_overlap: 分块重叠 table_name: pgvector 查询使用的表名 pgvector_enabled: 是否使用 pgvector 原生检索 """ self._session_factory = session_factory self._embedder = embedder self._chunk_size = chunk_size self._chunk_overlap = chunk_overlap self._table_name = table_name if not _SAFE_TABLE_NAME_PATTERN.match(self._table_name): raise ValueError(f"Invalid table_name: {self._table_name}. Must match [a-zA-Z_][a-zA-Z0-9_]*") self._pgvector_enabled = pgvector_enabled self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) async def ingest(self, documents: list[Document]) -> list[str]: """摄取文档列表 Args: documents: knowledge_base.Document 对象列表 Returns: 成功摄取的文档 ID 列表 """ ingested_ids = [] for doc in documents: try: chunks = self._chunk_document(doc) await self._store_chunks(doc, chunks) ingested_ids.append(doc.doc_id) logger.info(f"Ingested document '{doc.title}' with {len(chunks)} chunks") except Exception as e: logger.error(f"Failed to ingest document '{doc.title}': {e}") return ingested_ids async def query(self, text: str, top_k: int = 5) -> list[QueryResult]: """语义检索 Args: text: 查询文本 top_k: 返回结果数量 Returns: 检索结果列表 """ query_embedding = await self._embedder.embed(text) async with self._session_factory() as db: try: if self._pgvector_enabled: return await self._query_pgvector(db, query_embedding, top_k) return await self._query_client_side(db, query_embedding, top_k) except Exception as e: logger.error(f"Failed to query knowledge base: {e}") return [] async def delete_by_id(self, id: str) -> bool: """按文档 ID 删除 Args: id: 文档 ID Returns: 是否删除成功 """ async with self._session_factory() as db: try: from sqlalchemy import text as sql_text sql = sql_text( f"DELETE FROM {self._table_name} WHERE source_doc_id = :doc_id" ) await db.execute(sql, {"doc_id": id}) await db.commit() return True except Exception as e: await db.rollback() logger.error(f"Failed to delete document {id}: {e}") return False async def list_sources(self) -> list[SourceInfo]: """列出已摄取的文档 Returns: 文档元信息列表 """ async with self._session_factory() as db: try: from sqlalchemy import text as sql_text sql = sql_text( f"SELECT source_doc_id, source_title, doc_format, " f"COUNT(*) as chunk_count, " f"MIN(created_at) as created_at, " f"MIN(doc_metadata) as doc_metadata " f"FROM {self._table_name} " f"GROUP BY source_doc_id, source_title, doc_format " f"ORDER BY MIN(created_at) DESC" ) result = await db.execute(sql) rows = result.mappings().all() sources = [] for row in rows: meta = {} if row.get("doc_metadata"): try: meta = json.loads(row["doc_metadata"]) except (json.JSONDecodeError, TypeError): pass sources.append(SourceInfo( source_id=row["source_doc_id"], source_name=row.get("source_title", ""), source_type=row.get("doc_format", "local"), document_count=row.get("chunk_count", 0), last_updated=row["created_at"] if row.get("created_at") else None, )) return sources except Exception as e: logger.error(f"Failed to list sources: {e}") return [] async def health_check(self) -> bool: """检查服务健康状态""" async with self._session_factory() as db: try: from sqlalchemy import text as sql_text await db.execute(sql_text(f"SELECT 1 FROM {self._table_name} LIMIT 1")) return True except Exception as e: logger.error(f"Health check failed: {e}") return False def _chunk_document(self, doc: Document) -> list[Chunk]: """将文档分块""" doc_format = doc.metadata.get("format", "text") # Markdown 和 HTML 使用结构化分块 if doc_format in ("markdown", "html"): chunks = self._structural_chunker.chunk( doc.content, source_doc_id=doc.doc_id, metadata=doc.metadata, ) else: chunks = self._text_chunker.chunk( doc.content, source_doc_id=doc.doc_id, metadata=doc.metadata, ) return chunks async def _store_chunks(self, doc: Document, chunks: list[Chunk]) -> None: """存储文档块到 pgvector""" async with self._session_factory() as db: try: from sqlalchemy import text as sql_text for chunk in chunks: # 生成嵌入 embedding = await self._embedder.embed(chunk.content) sql = sql_text( f"INSERT INTO {self._table_name} " f"(chunk_id, source_doc_id, source_title, doc_format, " f"content, embedding, chunk_metadata, doc_metadata, created_at) " f"VALUES (:chunk_id, :doc_id, :title, :format, " f":content, :embedding, :chunk_meta, :doc_meta, :created_at)" ) await db.execute(sql, { "chunk_id": chunk.chunk_id, "doc_id": doc.doc_id, "title": doc.title, "format": doc.metadata.get("format", "unknown"), "content": chunk.content, "embedding": str(embedding), "chunk_meta": json.dumps(chunk.metadata, ensure_ascii=False), "doc_meta": json.dumps(doc.metadata, ensure_ascii=False), "created_at": datetime.now(timezone.utc), }) await db.commit() except Exception as e: await db.rollback() logger.error(f"Failed to store chunks for document '{doc.title}': {e}") raise async def _query_pgvector( self, db: Any, query_embedding: list[float], top_k: int, ) -> list[QueryResult]: """使用 pgvector <=> 算符检索""" from sqlalchemy import text as sql_text sql = sql_text( f"SELECT chunk_id, source_doc_id, source_title, content, " f"chunk_metadata, embedding <=> :query_vec AS distance " f"FROM {self._table_name} " f"ORDER BY embedding <=> :query_vec " f"LIMIT :lim" ) result = await db.execute(sql, { "query_vec": str(query_embedding), "lim": top_k, }) rows = result.mappings().all() results = [] for row in rows: # 从 distance 计算 cosine similarity distance = row.get("distance", 0.0) # pgvector <=> 返回 cosine distance = 1 - cosine_similarity cosine = max(0.0, 1.0 - float(distance)) chunk_meta = {} if row.get("chunk_metadata"): try: chunk_meta = json.loads(row["chunk_metadata"]) except (json.JSONDecodeError, TypeError): pass results.append(QueryResult( content=row["content"], source_id=row["source_doc_id"], source_name=row.get("source_title", ""), score=cosine, metadata=chunk_meta, doc_id=row["source_doc_id"], title=row.get("source_title", ""), )) return results async def _query_client_side( self, db: Any, query_embedding: list[float], top_k: int, ) -> list[QueryResult]: """客户端 O(N) cosine similarity 检索(回退路径)""" from sqlalchemy import text as sql_text sql = sql_text( f"SELECT chunk_id, source_doc_id, source_title, content, " f"chunk_metadata, embedding " f"FROM {self._table_name} " f"LIMIT 500" ) result = await db.execute(sql) rows = result.mappings().all() candidates = [] for row in rows: row_embedding = row.get("embedding") if row_embedding is None: continue # 解析存储的 embedding try: if isinstance(row_embedding, str): stored_embedding = json.loads(row_embedding) else: stored_embedding = list(row_embedding) except (json.JSONDecodeError, TypeError): continue cosine = compute_cosine_similarity(query_embedding, stored_embedding) if cosine < 0.1: continue chunk_meta = {} if row.get("chunk_metadata"): try: chunk_meta = json.loads(row["chunk_metadata"]) except (json.JSONDecodeError, TypeError): pass candidates.append(QueryResult( content=row["content"], source_id=row["source_doc_id"], source_name=row.get("source_title", ""), score=cosine, metadata=chunk_meta, doc_id=row["source_doc_id"], title=row.get("source_title", ""), )) candidates.sort(key=lambda x: x.score, reverse=True) return candidates[:top_k] class InMemoryLocalRAGService: """基于内存的本地 RAG 服务 用于测试和开发环境,无需 pgvector 依赖。 实现 KnowledgeBase 协议。 """ def __init__( self, embedder: Embedder, chunk_size: int = 1000, chunk_overlap: int = 200, ): """ Args: embedder: 嵌入器 chunk_size: 分块大小 chunk_overlap: 分块重叠 """ self._embedder = embedder self._text_chunker = TextChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) self._structural_chunker = StructuralChunker(chunk_size=chunk_size, chunk_overlap=chunk_overlap) # 内存存储 self._chunks: dict[str, dict[str, Any]] = {} # chunk_id → {content, embedding, metadata} self._documents: dict[str, dict[str, Any]] = {} # doc_id → {title, format, chunk_ids, metadata, created_at} async def ingest(self, documents: list[Document]) -> list[str]: """摄取文档列表 也支持传入 document_loader.Document,会自动转换为 knowledge_base.Document。 """ ingested_ids = [] for doc in documents: # 支持 document_loader.Document 自动转换 if isinstance(doc, LoaderDocument): doc = _loader_doc_to_kb_doc(doc) try: chunks = self._chunk_document(doc) chunk_ids = [] for chunk in chunks: embedding = await self._embedder.embed(chunk.content) self._chunks[chunk.chunk_id] = { "content": chunk.content, "embedding": embedding, "metadata": chunk.metadata, "source_doc_id": doc.doc_id, } chunk_ids.append(chunk.chunk_id) self._documents[doc.doc_id] = { "title": doc.title, "source_id": doc.source_id, "format": doc.metadata.get("format", "unknown"), "chunk_ids": chunk_ids, "metadata": doc.metadata, "created_at": datetime.now(timezone.utc), } ingested_ids.append(doc.doc_id) logger.info(f"Ingested document '{doc.title}' with {len(chunks)} chunks") except Exception as e: logger.error(f"Failed to ingest document '{doc.title}': {e}") return ingested_ids async def query(self, text: str, top_k: int = 5) -> list[QueryResult]: """语义检索""" query_embedding = await self._embedder.embed(text) candidates = [] for chunk_id, chunk_data in self._chunks.items(): stored_embedding = chunk_data["embedding"] cosine = compute_cosine_similarity(query_embedding, stored_embedding) if cosine < 0.1: continue source_doc_id = chunk_data["source_doc_id"] doc_info = self._documents.get(source_doc_id, {}) candidates.append(QueryResult( content=chunk_data["content"], source_id=source_doc_id, source_name=doc_info.get("title", ""), score=cosine, metadata=chunk_data.get("metadata", {}), doc_id=source_doc_id, title=doc_info.get("title", ""), )) candidates.sort(key=lambda x: x.score, reverse=True) return candidates[:top_k] async def delete_by_id(self, id: str) -> bool: """按文档 ID 删除""" if id not in self._documents: return False doc_info = self._documents[id] for chunk_id in doc_info.get("chunk_ids", []): self._chunks.pop(chunk_id, None) del self._documents[id] return True async def list_sources(self) -> list[SourceInfo]: """列出已摄取的文档""" sources = [] for doc_id, doc_info in self._documents.items(): sources.append(SourceInfo( source_id=doc_id, source_name=doc_info["title"], source_type=doc_info.get("format", "local"), document_count=len(doc_info.get("chunk_ids", [])), last_updated=doc_info.get("created_at"), )) return sources async def health_check(self) -> bool: """检查服务健康状态""" return True def _chunk_document(self, doc: Document) -> list[Chunk]: """将文档分块""" doc_format = doc.metadata.get("format", "text") if doc_format in ("markdown", "html"): return self._structural_chunker.chunk( doc.content, source_doc_id=doc.doc_id, metadata=doc.metadata, ) return self._text_chunker.chunk( doc.content, source_doc_id=doc.doc_id, metadata=doc.metadata, )