geo/docs/KB_Implementation_Guide.md

9.5 KiB
Raw Blame History

GEO知识库系统 - 技术实施指南

快速开始

1. 环境设置

# Backend dependencies
cd backend
pip install langchain langchain-community langchain-postgres
pip install llama-index llama-index-readers-web llama-index-embeddings-huggingface
# Quote extras so shells such as zsh do not treat the brackets as a glob pattern
pip install "unstructured[pdf,docx]" sentence-transformers
# psycopg 3 is the driver behind the postgresql+psycopg:// URL used by PGVector
pip install pgvector sqlalchemy asyncpg "psycopg[binary]"
pip install tiktoken  # token counting
pip install pytest pytest-asyncio  # test suite (uses @pytest.mark.asyncio)

# Frontend dependencies (only needed for UI development)
cd ../frontend
npm install recharts lucide-react  # result visualisation + icons used by KnowledgeBaseManager

2. PostgreSQL pgvector扩展

-- Enable pgvector in the GEO database
CREATE EXTENSION IF NOT EXISTS vector;

-- Knowledge-base chunk table.
-- NOTE(review): langchain_postgres.PGVector keeps its embeddings in its own
-- tables (langchain_pg_collection / langchain_pg_embedding by default), so the
-- `embedding` column here is only filled if chunks are also written to this
-- table directly — confirm which store the application treats as canonical.
CREATE TABLE knowledge_chunks (
    id BIGSERIAL PRIMARY KEY,
    chunk_id VARCHAR(255) UNIQUE NOT NULL,
    content TEXT NOT NULL,
    embedding vector(1024),  -- bge-m3 dense embedding dimension
    metadata JSONB,
    knowledge_base_type VARCHAR(50)
        CHECK (knowledge_base_type IN ('industry', 'company')),
    source_id VARCHAR(255),
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP  -- not auto-maintained: set it in the app or via a trigger
);

-- HNSW index for fast approximate nearest-neighbour search (cosine distance)
CREATE INDEX knowledge_chunks_embedding_hnsw_idx ON knowledge_chunks
USING hnsw (embedding vector_cosine_ops)
WITH (m = 16, ef_construction = 64);

-- Supports filtering by knowledge base ('industry' / 'company')
CREATE INDEX knowledge_chunks_kb_type_idx ON knowledge_chunks (knowledge_base_type);

-- Full-text search index.
-- The indexed expression must match the query side exactly
-- (to_tsvector('english', content)), otherwise the planner cannot use it.
-- The 'english' configuration performs no Chinese word segmentation; for
-- Chinese content consider an extension such as zhparser or pg_jieba.
CREATE INDEX knowledge_chunks_content_tsvector_idx
ON knowledge_chunks USING GIN (to_tsvector('english', content));

3. 核心RAG类实现

# backend/app/services/rag_service.py

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_postgres import PGVector
from sentence_transformers import CrossEncoder
import asyncpg

class RAGService:
    """Hybrid retrieval over the PostgreSQL / pgvector knowledge base.

    Pipeline: vector similarity search (``langchain_postgres.PGVector``) and
    PostgreSQL full-text search run concurrently, their rankings are merged
    with weighted Reciprocal Rank Fusion, and the fused candidates are
    reranked with a cross-encoder.
    """

    def __init__(
        self,
        connection: str = "postgresql+psycopg://...",
        asyncpg_dsn: str = "postgresql://...",
        device: str = "cuda",
        collection_name: str = "knowledge_base",
    ):
        """Build the text splitter, embedding model, vector store and reranker.

        Args:
            connection: SQLAlchemy URL for PGVector (psycopg 3 driver).
            asyncpg_dsn: plain ``postgresql://`` DSN for the full-text query;
                asyncpg does not understand the ``+psycopg`` dialect suffix.
            device: torch device for the embedding model (``"cuda"``, ``"cpu"``, ...).
            collection_name: PGVector collection to read from and write to.
        """
        self.asyncpg_dsn = asyncpg_dsn

        # chunk_size / chunk_overlap are measured with the splitter's length
        # function (characters by default).
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=512,
            chunk_overlap=50,
            separators=["\n\n\n", "\n\n", "\n", " ", ""],
        )

        self.embeddings = HuggingFaceEmbeddings(
            model_name="BAAI/bge-m3",
            model_kwargs={"device": device},
        )

        # langchain_postgres.PGVector takes `embeddings=` / `connection=`
        # (the older langchain_community class used `embedding_function=` /
        # `connection_string=`). async_mode=True is required for the a* methods.
        self.vector_store = PGVector(
            embeddings=self.embeddings,
            connection=connection,
            collection_name=collection_name,
            embedding_length=1024,  # bge-m3 dense dimension
            use_jsonb=True,
            async_mode=True,
        )

        # Multilingual mMARCO cross-encoder (the knowledge base content is Chinese).
        self.cross_encoder = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")

    async def chunk_document(self, content: str) -> list:
        """Split *content* into overlapping chunks using the configured splitter."""
        return self.text_splitter.split_text(content)

    async def hybrid_search(
        self,
        query: str,
        top_k: int = 10,
        alpha: float = 0.5,
    ) -> list:
        """Hybrid search: vector + full-text retrieval, rank fusion, cross-encoder rerank.

        Args:
            query: user query text.
            top_k: number of results to return; each retriever fetches
                ``2 * top_k`` candidates.
            alpha: weight of the vector ranking during fusion, in [0, 1]
                (1.0 = vector only, 0.0 = full-text only).

        Returns:
            Up to ``top_k`` dicts with ``content``, ``metadata``,
            ``fusion_score`` and ``score`` (cross-encoder score), best first.

        Raises:
            ValueError: if ``alpha`` is outside [0, 1].
        """
        import asyncio

        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")

        candidate_k = top_k * 2
        # The two retrievers are independent, so run them concurrently.
        vector_results, bm25_results = await asyncio.gather(
            self.vector_search(query, candidate_k),
            self.bm25_search(query, candidate_k),
        )

        candidates = self._fuse_scores(vector_results, bm25_results, alpha)
        if not candidates:
            return []

        # Cross-encoder inference is synchronous and compute-bound; run it in a
        # worker thread so it does not block the event loop.
        ranked = await asyncio.to_thread(
            self.cross_encoder.rank,
            query,
            [doc["content"] for doc in candidates],
            top_k=top_k,
        )

        # rank() returns {"corpus_id", "score"} hits; map them back to the
        # candidate documents so callers receive the text and metadata.
        return [
            {
                **candidates[hit["corpus_id"]],
                "fusion_score": candidates[hit["corpus_id"]]["score"],
                "score": float(hit["score"]),
            }
            for hit in ranked
        ]

    async def vector_search(self, query: str, top_k: int) -> list:
        """Pure vector similarity search.

        Returns:
            ``(Document, score)`` pairs from PGVector, best first. With the
            default cosine strategy the score is a distance (lower = closer)
            — confirm if the distance strategy is changed.
        """
        # PGVector embeds the query itself; no separate embed_query call needed.
        return await self.vector_store.asimilarity_search_with_score(query, k=top_k)

    async def bm25_search(self, query: str, top_k: int) -> list:
        """Keyword search using PostgreSQL full-text search.

        ``ts_rank`` is PostgreSQL's own ranking function, not true BM25. The
        ``'english'`` configuration must stay identical to the GIN index
        expression or the index will not be used; it performs no Chinese word
        segmentation.

        Returns:
            Dicts with ``id``, ``content`` and ``score``, best first.
        """
        # asyncpg.Connection is not an async context manager; close explicitly.
        # TODO(review): use an asyncpg pool instead of a connection per query.
        conn = await asyncpg.connect(self.asyncpg_dsn)
        try:
            rows = await conn.fetch(
                """
                SELECT id, content, 
                    ts_rank(to_tsvector('english', content),
                           plainto_tsquery('english', $1)) as score
                FROM knowledge_chunks
                WHERE to_tsvector('english', content) @@ 
                      plainto_tsquery('english', $1)
                ORDER BY score DESC
                LIMIT $2
                """,
                query, top_k
            )
        finally:
            await conn.close()
        return [dict(row) for row in rows]

    def _fuse_scores(self, vector_results, bm25_results, alpha, rrf_k: int = 60) -> list:
        """Merge the two ranked lists with weighted Reciprocal Rank Fusion.

        Rank-based fusion avoids comparing incompatible score scales (cosine
        distance vs. ts_rank). A document at 1-based rank ``r`` contributes
        ``alpha / (rrf_k + r)`` from the vector list and
        ``(1 - alpha) / (rrf_k + r)`` from the full-text list. Documents are
        matched across lists by their text, because the two searches read
        different tables and share no id.

        Args:
            vector_results: ``(Document, score)`` pairs, best first.
            bm25_results: mappings with a ``content`` key, best first.
            alpha: weight of the vector list, in [0, 1].
            rrf_k: RRF smoothing constant; 60 is the value used in the
                original RRF paper (Cormack et al., 2009).

        Returns:
            Dicts with ``content``, ``metadata`` and fused ``score``, highest first.
        """
        fused: dict = {}

        def _accumulate(content, metadata, weight, rank):
            # First occurrence fixes the metadata; vector hits are added first.
            entry = fused.setdefault(
                content, {"content": content, "metadata": metadata, "score": 0.0}
            )
            entry["score"] += weight / (rrf_k + rank)

        for rank, (doc, _distance) in enumerate(vector_results, start=1):
            _accumulate(doc.page_content, doc.metadata, alpha, rank)
        for rank, row in enumerate(bm25_results, start=1):
            _accumulate(row["content"], {}, 1.0 - alpha, rank)

        return sorted(fused.values(), key=lambda entry: entry["score"], reverse=True)

4. FastAPI集成

# backend/app/api/knowledge_base.py

from fastapi import APIRouter, UploadFile, File
from app.services.rag_service import RAGService

# Router for all knowledge-base endpoints; route paths below are relative to this prefix.
router = APIRouter(prefix="/api/knowledge_base", tags=["knowledge_base"])
# Single module-level service instance shared by every request. Constructing it
# at import time builds the embedding model, vector store and cross-encoder.
rag_service = RAGService()

@router.post("/upload")
async def upload_document(file: UploadFile = File(...)):
    """Upload a UTF-8 text document, chunk it and store the chunks in the vector store.

    Returns the number of chunks stored. Non-UTF-8 input (e.g. raw PDF/DOCX
    bytes) is rejected with HTTP 415; convert such files to text first.

    NOTE(review): chunks are written only to the PGVector collection behind
    ``rag_service.vector_store``. ``RAGService.bm25_search`` reads the
    ``knowledge_chunks`` table, which this endpoint does not populate, so
    uploaded documents are not yet reachable through full-text search — TODO.
    """
    from fastapi import HTTPException

    raw = await file.read()
    try:
        text = raw.decode("utf-8")
    except UnicodeDecodeError as err:
        raise HTTPException(
            status_code=415,
            detail="Only UTF-8 text files are supported; convert PDF/DOCX to text first.",
        ) from err

    chunks = await rag_service.chunk_document(text)

    if chunks:
        # One batched call: aadd_texts embeds and persists the chunks itself,
        # instead of embedding each chunk and then discarding the vector.
        await rag_service.vector_store.aadd_texts(
            chunks,
            metadatas=[{"source": file.filename} for _ in chunks],
        )

    return {"chunks": len(chunks), "status": "success"}

@router.post("/search")
async def search(
    query: str,
    kb_type: str = "both",  # 'industry', 'company', 'both'
    hybrid: bool = True,
    alpha: float = 0.5,
    top_k: int = 10,
):
    """Search the knowledge base.

    All parameters are plain scalars, so FastAPI reads them from the URL query
    string (``POST /api/knowledge_base/search?query=...&alpha=0.5``), not from
    a JSON body.

    Args:
        query: search text.
        kb_type: intended knowledge-base filter. NOTE(review): not yet applied
            by RAGService — TODO.
        hybrid: use hybrid search (vector + full-text + rerank) when true,
            pure vector search otherwise.
        alpha: vector weight for hybrid fusion, in [0, 1].
        top_k: number of results to return.
    """
    from fastapi import HTTPException

    if not 0.0 <= alpha <= 1.0:
        raise HTTPException(status_code=422, detail="alpha must be between 0 and 1")

    if hybrid:
        results = await rag_service.hybrid_search(query, top_k=top_k, alpha=alpha)
    else:
        # vector_search has no default for top_k, so it must be passed explicitly.
        results = await rag_service.vector_search(query, top_k)

    return {
        "query": query,
        "results": results,
        "count": len(results)
    }

5. 知识库管理UI（Next.js组件参考）

// frontend/components/knowledge_base/KnowledgeBaseManager.tsx
'use client'

import React, { useState } from 'react'
import { Upload, Search } from 'lucide-react'

// One hit from POST /api/knowledge_base/search. Only `content` is relied on;
// `score` and `metadata` may be absent depending on the backend search mode.
interface SearchResult {
  content: string
  score?: number
  metadata?: Record<string, unknown>
}

export const KnowledgeBaseManager = () => {
  const [uploadStatus, setUploadStatus] = useState('')
  const [searchQuery, setSearchQuery] = useState('')
  const [searchResults, setSearchResults] = useState<SearchResult[]>([])

  // Upload one file as multipart/form-data and report the chunk count.
  const handleFileUpload = async (file: File) => {
    const formData = new FormData()
    formData.append('file', file)

    const response = await fetch('/api/knowledge_base/upload', {
      method: 'POST',
      body: formData,
    })
    if (!response.ok) {
      setUploadStatus(`上传失败: HTTP ${response.status}`)
      return
    }

    const result = await response.json()
    setUploadStatus(`上传成功: ${result.chunks}个文档块`)
  }

  const handleSearch = async () => {
    if (!searchQuery.trim()) return

    // The FastAPI handler declares query/hybrid/alpha as scalar parameters,
    // which FastAPI reads from the URL query string rather than a JSON body.
    const params = new URLSearchParams({
      query: searchQuery,
      hybrid: 'true',
      alpha: '0.5',
    })
    const response = await fetch(`/api/knowledge_base/search?${params.toString()}`, {
      method: 'POST',
    })
    if (!response.ok) {
      console.error(`搜索失败: HTTP ${response.status}`)
      setSearchResults([])
      return
    }

    const result = await response.json()
    setSearchResults(result.results ?? [])
  }

  return (
    <div className="p-6">
      <div className="mb-8">
        <h2 className="text-2xl font-bold mb-4">知识库管理</h2>

        <label className="block bg-blue-50 p-4 rounded-lg cursor-pointer">
          <Upload className="inline mr-2" />
          <span>点击选择文件上传</span>
          <input
            type="file"
            className="hidden"
            onChange={(e) => {
              const file = e.target.files?.[0]
              if (file) void handleFileUpload(file)
            }}
          />
        </label>
        {uploadStatus && <p className="mt-2 text-sm">{uploadStatus}</p>}
      </div>

      <div>
        <input
          type="text"
          value={searchQuery}
          onChange={(e) => setSearchQuery(e.target.value)}
          placeholder="搜索知识库..."
          className="w-full px-4 py-2 border rounded-lg"
        />
        <button
          onClick={handleSearch}
          className="mt-2 px-4 py-2 bg-blue-600 text-white rounded-lg"
        >
          <Search className="inline mr-2" />
          搜索
        </button>
      </div>

      <div className="mt-6">
        {searchResults.map((result, idx) => (
          <div key={idx} className="mb-4 p-4 border rounded-lg">
            {result.metadata?.source != null && (
              <p className="font-semibold mb-2">{String(result.metadata?.source)}</p>
            )}
            <p className="text-sm text-gray-600">{result.content}</p>
            {typeof result.score === 'number' && (
              <p className="text-xs text-gray-400 mt-2">相似度: {result.score.toFixed(3)}</p>
            )}
          </div>
        ))}
      </div>
    </div>
  )
}

性能调优检查清单

  • pgvector HNSW索引已创建
  • 全文搜索索引已创建
  • Embedding模型缓存已配置
  • Hybrid search alpha参数已根据业务调整
  • Cross-Encoder模型已选择（性能 vs 精度权衡）
  • Context Window管理已实现（防止溢出）
  • 查询日志已记录(用于后续优化)

监控和可观测性

# 添加到FastAPI中间件
from opentelemetry import trace, metrics
from opentelemetry.exporter.jaeger.thrift import JaegerExporter

# NOTE(review): exporter wiring (TracerProvider + span processor) is not shown
# here. The Jaeger thrift exporter is deprecated upstream; Jaeger can ingest
# OTLP directly, so an OTLP exporter is the usual replacement.
tracer = trace.get_tracer(__name__)


@router.post("/search")
async def search(query: str):
    """Search endpoint wrapped in an OpenTelemetry span.

    This is meant to replace the plain /search handler, not sit beside it:
    routes are matched in registration order, so a second handler on the same
    path would never be reached.
    """
    import time

    with tracer.start_as_current_span("kb_search") as span:
        # NOTE(review): records the raw user query on the span; redact or hash
        # it if queries can contain personal data.
        span.set_attribute("query", query)

        start = time.perf_counter()
        results = await rag_service.hybrid_search(query)
        latency_ms = (time.perf_counter() - start) * 1000.0

        span.set_attribute("results.count", len(results))
        span.set_attribute("search.latency_ms", latency_ms)

    # Same response shape as the non-instrumented endpoint.
    return {"query": query, "results": results, "count": len(results)}

测试脚本

# tests/test_rag_service.py

import pytest
from app.services.rag_service import RAGService

@pytest.fixture(scope="module")
def rag_service():
    """Shared RAGService instance.

    Module scope so the embedding model and cross-encoder are loaded once per
    test module instead of once per test.
    """
    return RAGService()

@pytest.mark.asyncio
async def test_hybrid_search(rag_service):
    """Hybrid search returns a non-empty, correctly ordered result list.

    Integration test: requires a populated knowledge base and the models.
    The cross-encoder score scale depends on the model's activation (often raw
    logits), so a fixed threshold such as 0.5 is not meaningful; check the
    result count and ordering instead.
    """
    results = await rag_service.hybrid_search(
        "GEO平台品牌优化",
        top_k=10,
        alpha=0.5
    )

    assert 0 < len(results) <= 10
    scores = [result['score'] for result in results]
    assert scores == sorted(scores, reverse=True)

@pytest.mark.asyncio
async def test_chunking(rag_service):
    """Long text is split into several non-empty chunks within the configured size."""
    test_content = "这是一个很长的测试文档。" * 100
    chunks = await rag_service.chunk_document(test_content)

    assert len(chunks) > 1
    assert all(chunks)
    # RAGService configures the splitter with chunk_size=512 (characters).
    assert all(len(chunk) <= 512 for chunk in chunks)