# GEO知识库系统 - 技术实施指南
|
||
|
||
## 快速开始
|
||
|
||
### 1. 环境设置
|
||
|
||
```bash
|
||
# Backend dependencies
cd backend
pip install langchain langchain-community langchain-postgres
pip install llama-index llama-index-readers-web llama-index-embeddings-huggingface
# Quoted so shells that glob square brackets (e.g. zsh) pass the extras through.
pip install "unstructured[pdf,docx]" sentence-transformers
# psycopg (v3) is the driver behind the postgresql+psycopg:// URL used by PGVector.
pip install pgvector sqlalchemy asyncpg "psycopg[binary]"
pip install tiktoken  # token counting
pip install python-multipart  # required by FastAPI to parse UploadFile form data
cd ..  # return to the repository root before entering frontend/

# Frontend dependencies (only needed when developing the UI)
cd frontend
npm install recharts lucide-react  # result visualisation + icons used by the UI component
|
||
```
|
||
|
||
### 2. PostgreSQL pgvector扩展
|
||
|
||
```sql
|
||
-- Enable pgvector in the GEO database
CREATE EXTENSION IF NOT EXISTS vector;

-- Knowledge-base chunk table.
-- IF NOT EXISTS keeps the whole script re-runnable, matching CREATE EXTENSION above.
CREATE TABLE IF NOT EXISTS knowledge_chunks (
    id BIGSERIAL PRIMARY KEY,
    chunk_id VARCHAR(255) UNIQUE NOT NULL,
    content TEXT NOT NULL,
    embedding vector(1024),           -- bge-m3 embedding dimension
    metadata JSONB,
    knowledge_base_type VARCHAR(50),  -- 'industry' or 'company'
    source_id VARCHAR(255),
    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    -- NOTE: only set on INSERT; nothing in this script refreshes it on UPDATE.
    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);

-- HNSW index for fast approximate nearest-neighbour search (cosine distance).
-- The index is named so that IF NOT EXISTS can be used.
CREATE INDEX IF NOT EXISTS knowledge_chunks_embedding_hnsw_idx
    ON knowledge_chunks
    USING hnsw (embedding vector_cosine_ops)
    WITH (m = 16, ef_construction = 64);

-- Full-text search index.
-- The expression must stay identical to the one used in the search query,
-- otherwise the planner cannot use this index.
-- NOTE(review): the 'english' configuration does not segment Chinese text;
-- a Chinese parser extension is likely needed for Chinese content - confirm.
CREATE INDEX IF NOT EXISTS knowledge_chunks_content_tsvector_idx
    ON knowledge_chunks USING GIN (to_tsvector('english', content));
|
||
```
|
||
|
||
### 3. 核心RAG类实现
|
||
|
||
```python
|
||
# backend/app/services/rag_service.py
|
||
|
||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||
from langchain_community.embeddings import HuggingFaceEmbeddings
|
||
from langchain_postgres import PGVector
|
||
from sentence_transformers import CrossEncoder
|
||
import asyncpg
|
||
|
||
class RAGService:
    """Retrieval service: dense vector search + keyword search + re-ranking.

    Dense search goes through the ``PGVector`` store; keyword search runs a
    PostgreSQL full-text query against the ``knowledge_chunks`` table; the two
    candidate lists are fused and then re-ranked by a cross-encoder.

    NOTE(review): ``PGVector`` manages its own tables, while ``bm25_search``
    reads ``knowledge_chunks``. Both must be populated with the same chunks
    for hybrid search to be meaningful - confirm the ingestion path.
    """

    def __init__(
        self,
        connection: str = "postgresql+psycopg://...",
        pg_dsn: str = "postgresql://...",
        device: str = "cuda",
    ):
        """Load the models and connect the vector store.

        Args:
            connection: SQLAlchemy URL handed to ``PGVector`` (placeholder default).
            pg_dsn: libpq-style DSN used by asyncpg for the keyword search
                (placeholder default).
            device: Torch device for the embedding model, e.g. ``"cuda"`` or ``"cpu"``.
        """
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=512,
            chunk_overlap=50,
            separators=["\n\n\n", "\n\n", "\n", " ", ""],
        )

        self.embeddings = HuggingFaceEmbeddings(
            model_name="BAAI/bge-m3",
            model_kwargs={"device": device},
        )

        # langchain_postgres.PGVector takes ``embeddings`` / ``connection``
        # (``embedding_function`` / ``connection_string`` belong to the older
        # langchain_community class). ``async_mode=True`` is required because
        # this service only calls the async search API.
        self.vector_store = PGVector(
            embeddings=self.embeddings,
            connection=connection,
            collection_name="knowledge_base",
            async_mode=True,
        )

        # Multilingual MS MARCO re-ranker (published model id).
        self.cross_encoder = CrossEncoder(
            "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
        )

        self._pg_dsn = pg_dsn

    async def chunk_document(self, content: str) -> list:
        """Split ``content`` into overlapping text chunks."""
        return self.text_splitter.split_text(content)

    async def hybrid_search(
        self,
        query: str,
        top_k: int = 10,
        alpha: float = 0.5,
    ) -> list:
        """Hybrid search: keyword + vector retrieval, fused and re-ranked.

        Args:
            query: Free-text search query.
            top_k: Maximum number of results to return.
            alpha: Weight of the vector score in the fusion step
                (``1 - alpha`` is the keyword weight).

        Returns:
            Up to ``top_k`` dicts ordered by re-ranker score. Each dict has
            ``content``, ``metadata``, ``score`` (cross-encoder score) and
            ``fused_score`` (pre-rerank fused score).
        """
        import asyncio  # local import: stdlib only, keeps the snippet self-contained

        if top_k <= 0:
            return []

        # 1 + 2. The two retrievals are independent, so run them concurrently.
        #        Over-fetch (2x) so the re-ranker has candidates to choose from.
        vector_results, bm25_results = await asyncio.gather(
            self.vector_search(query, top_k * 2),
            self.bm25_search(query, top_k * 2),
        )

        # 3. Score fusion
        candidates = self._fuse_scores(vector_results, bm25_results, alpha)
        if not candidates:
            return []

        # 4. Re-rank. ``rank`` returns {'corpus_id', 'score'} entries, so map
        #    them back to the candidate payload; float() keeps scores JSON-safe.
        hits = self.cross_encoder.rank(
            query,
            [doc["content"] for doc in candidates],
            top_k=top_k,
        )
        ranked = []
        for hit in hits:
            candidate = candidates[hit["corpus_id"]]
            ranked.append(
                {
                    **candidate,
                    "fused_score": candidate["score"],
                    "score": float(hit["score"]),
                }
            )
        return ranked

    async def vector_search(self, query: str, top_k: int = 10) -> list:
        """Pure vector search.

        Returns the vector store's ``(document, score)`` pairs unchanged; the
        store embeds ``query`` itself, so no separate embedding call is needed.
        """
        return await self.vector_store.asimilarity_search_with_score(
            query,
            k=top_k,
        )

    async def bm25_search(self, query: str, top_k: int) -> list:
        """Keyword search via PostgreSQL full-text search.

        NOTE: ``ts_rank`` is PostgreSQL's built-in ranking, not true BM25.
        The ``to_tsvector('english', content)`` expression must match the GIN
        index definition so the index can be used.

        Returns:
            Rows exposing ``id``, ``content`` and ``score``, best match first.
        """
        # asyncpg connections are not async context managers, so close explicitly.
        # NOTE(review): a shared asyncpg pool would avoid reconnecting per query.
        conn = await asyncpg.connect(self._pg_dsn)
        try:
            return await conn.fetch(
                """
                SELECT id, content,
                       ts_rank(to_tsvector('english', content),
                               plainto_tsquery('english', $1)) as score
                FROM knowledge_chunks
                WHERE to_tsvector('english', content) @@
                      plainto_tsquery('english', $1)
                ORDER BY score DESC
                LIMIT $2
                """,
                query, top_k
            )
        finally:
            await conn.close()

    def _fuse_scores(self, vector_results, bm25_results, alpha):
        """Fuse vector and keyword results into one ranked candidate list.

        Args:
            vector_results: Iterable of ``(document, distance)`` pairs where
                ``document`` exposes ``page_content`` (and optionally
                ``metadata``). The score is treated as a distance: lower is
                better - assumes the store's distance-based scoring.
            bm25_results: Iterable of mappings with ``content`` and ``score``
                keys (higher is better).
            alpha: Weight of the vector score; ``1 - alpha`` weights keywords.

        Returns:
            List of ``{"content", "metadata", "score"}`` dicts sorted by fused
            score, best first. Candidates are de-duplicated by content because
            the two sources do not share an id.
        """

        def _min_max(raw):
            # Scale each source to [0, 1] so the two score ranges are comparable.
            if not raw:
                return {}
            low, high = min(raw.values()), max(raw.values())
            span = high - low
            if span == 0:
                return dict.fromkeys(raw, 1.0)
            return {key: (value - low) / span for key, value in raw.items()}

        dense = {}
        metadata = {}
        for doc, distance in vector_results:
            text = doc.page_content
            similarity = -float(distance)  # negate: smaller distance = better
            if text not in dense or similarity > dense[text]:
                dense[text] = similarity
            metadata.setdefault(text, getattr(doc, "metadata", None) or {})

        sparse = {}
        for row in bm25_results:
            text = row["content"]
            rank = float(row["score"])
            if text not in sparse or rank > sparse[text]:
                sparse[text] = rank

        dense_norm = _min_max(dense)
        sparse_norm = _min_max(sparse)

        fused = [
            {
                "content": text,
                "metadata": metadata.get(text, {}),
                "score": alpha * dense_norm.get(text, 0.0)
                + (1 - alpha) * sparse_norm.get(text, 0.0),
            }
            # Dict merge keeps first-seen order, which makes ties deterministic.
            for text in {**dense, **sparse}
        ]
        fused.sort(key=lambda item: item["score"], reverse=True)
        return fused
|
||
```
|
||
|
||
### 4. FastAPI集成
|
||
|
||
```python
|
||
# backend/app/api/knowledge_base.py
|
||
|
||
from fastapi import APIRouter, UploadFile, File
|
||
from app.services.rag_service import RAGService
|
||
|
||
router = APIRouter(prefix="/api/knowledge_base", tags=["knowledge_base"])
|
||
rag_service = RAGService()
|
||
|
||
@router.post("/upload")
async def upload_document(file: UploadFile = File(...)):
    """Upload a knowledge-base document, chunk it and store it in the vector store.

    The upload must be UTF-8 text; binary formats (PDF, DOCX, ...) need to be
    converted to text before they reach this endpoint.

    Returns:
        ``{"chunks": <number of stored chunks>, "status": "success"}``.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as UTF-8.
    """
    from fastapi import HTTPException  # local import keeps the snippet self-contained

    raw = await file.read()
    try:
        text = raw.decode("utf-8")
    except UnicodeDecodeError as exc:
        raise HTTPException(
            status_code=400,
            detail="Only UTF-8 encoded text files are supported",
        ) from exc

    # Document processing
    chunks = await rag_service.chunk_document(text)

    # Persist the chunks. The vector store embeds them in a batch, instead of
    # embedding one chunk at a time and discarding the result.
    # NOTE(review): keyword search reads the knowledge_chunks table; confirm
    # that table is populated from the same chunks.
    if chunks:
        await rag_service.vector_store.aadd_texts(
            chunks,
            metadatas=[
                {"source": file.filename, "chunk_index": index}
                for index in range(len(chunks))
            ],
        )

    return {"chunks": len(chunks), "status": "success"}
|
||
|
||
@router.post("/search")
async def search(
    query: str,
    kb_type: str = "both",  # 'industry', 'company', 'both'
    hybrid: bool = True,
    alpha: float = 0.5
):
    """Search the knowledge base.

    All parameters are query-string parameters.

    Args:
        query: Free-text search query.
        kb_type: Knowledge base to search.
            NOTE(review): accepted but not yet applied as a filter.
        hybrid: Use hybrid (keyword + vector + re-rank) search when true,
            pure vector search otherwise.
        alpha: Vector-score weight for hybrid fusion, within [0, 1].

    Returns:
        ``{"query", "results", "count"}``.

    Raises:
        HTTPException: 422 when ``alpha`` is outside [0, 1].
    """
    from fastapi import HTTPException  # local import keeps the snippet self-contained

    if not 0.0 <= alpha <= 1.0:
        raise HTTPException(status_code=422, detail="alpha must be between 0 and 1")

    top_k = 10

    if hybrid:
        results = await rag_service.hybrid_search(query, top_k=top_k, alpha=alpha)
    else:
        # vector_search requires top_k and yields (document, distance) pairs;
        # flatten them into JSON-friendly dicts shaped like the hybrid results.
        scored_docs = await rag_service.vector_search(query, top_k)
        results = [
            {
                "content": doc.page_content,
                "metadata": doc.metadata,
                # assumes the store's default cosine distance - TODO confirm
                "score": 1.0 - float(distance),
            }
            for doc, distance in scored_docs
        ]

    return {
        "query": query,
        "results": results,
        "count": len(results)
    }
|
||
```
|
||
|
||
### 5. 知识库管理UI(Next.js组件参考)
|
||
|
||
```typescript
|
||
// frontend/components/knowledge_base/KnowledgeBaseManager.tsx
|
||
|
||
import React, { useState } from 'react'
|
||
import { Upload, Search } from 'lucide-react'
|
||
|
||
// Shape of one search hit as rendered below; title is optional because the
// search endpoint is not guaranteed to return one.
type SearchResult = {
  title?: string
  content: string
  score?: number
}

/**
 * Knowledge-base management panel: uploads a document and runs searches
 * against the /api/knowledge_base endpoints.
 */
export const KnowledgeBaseManager = () => {
  const [uploadStatus, setUploadStatus] = useState('')
  const [searchQuery, setSearchQuery] = useState('')
  const [searchResults, setSearchResults] = useState<SearchResult[]>([])
  const [errorMessage, setErrorMessage] = useState('')

  // Send the selected file as multipart form data.
  const handleFileUpload = async (file: File) => {
    const formData = new FormData()
    formData.append('file', file)

    setErrorMessage('')
    setUploadStatus('上传中...')
    try {
      const response = await fetch('/api/knowledge_base/upload', {
        method: 'POST',
        body: formData
      })
      if (!response.ok) {
        throw new Error(`HTTP ${response.status}`)
      }

      const result = await response.json()
      console.log(`上传成功: ${result.chunks}个文档块`)
      setUploadStatus(`上传成功: ${result.chunks}个文档块`)
    } catch (err) {
      setUploadStatus('')
      setErrorMessage(`上传失败: ${err instanceof Error ? err.message : String(err)}`)
    }
  }

  const handleSearch = async () => {
    const trimmed = searchQuery.trim()
    if (!trimmed) {
      return
    }

    // The endpoint declares query/hybrid/alpha as query-string parameters,
    // so they go in the URL rather than in a JSON body.
    const params = new URLSearchParams({
      query: trimmed,
      hybrid: 'true',
      alpha: '0.5'
    })

    setErrorMessage('')
    try {
      const response = await fetch(
        `/api/knowledge_base/search?${params.toString()}`,
        { method: 'POST' }
      )
      if (!response.ok) {
        throw new Error(`HTTP ${response.status}`)
      }

      const result = await response.json()
      setSearchResults(Array.isArray(result.results) ? result.results : [])
    } catch (err) {
      setSearchResults([])
      setErrorMessage(`搜索失败: ${err instanceof Error ? err.message : String(err)}`)
    }
  }

  return (
    <div className="p-6">
      <div className="mb-8">
        <h2 className="text-2xl font-bold mb-4">知识库管理</h2>

        {/* A label wrapping a hidden file input makes the whole area clickable. */}
        <label className="block bg-blue-50 p-4 rounded-lg cursor-pointer">
          <Upload className="inline mr-2" />
          <span>拖拽文件上传或点击选择</span>
          <input
            type="file"
            className="hidden"
            onChange={(e) => {
              const selected = e.target.files?.[0]
              if (selected) {
                void handleFileUpload(selected)
              }
            }}
          />
        </label>
        {uploadStatus && <p className="text-sm text-gray-600 mt-2">{uploadStatus}</p>}
      </div>

      <div>
        <input
          type="text"
          value={searchQuery}
          onChange={(e) => setSearchQuery(e.target.value)}
          placeholder="搜索知识库..."
          className="w-full px-4 py-2 border rounded-lg"
        />
        <button
          onClick={handleSearch}
          className="mt-2 px-4 py-2 bg-blue-600 text-white rounded-lg"
        >
          <Search className="inline mr-2" />
          搜索
        </button>
        {errorMessage && <p className="text-sm text-red-600 mt-2">{errorMessage}</p>}
      </div>

      <div className="mt-6">
        {searchResults.map((result, idx) => (
          <div key={idx} className="mb-4 p-4 border rounded-lg">
            {result.title && <p className="font-semibold mb-2">{result.title}</p>}
            <p className="text-sm text-gray-600">{result.content}</p>
            <p className="text-xs text-gray-400 mt-2">
              相似度: {typeof result.score === 'number' ? result.score.toFixed(3) : '-'}
            </p>
          </div>
        ))}
      </div>
    </div>
  )
}
|
||
```
|
||
|
||
---
|
||
|
||
## 性能调优检查清单
|
||
|
||
- [ ] pgvector HNSW索引已创建
|
||
- [ ] 全文搜索索引已创建
|
||
- [ ] Embedding模型缓存已配置
|
||
- [ ] Hybrid search alpha参数已根据业务调整
|
||
- [ ] Cross-Encoder模型已选择(性能vs精度权衡)
|
||
- [ ] Context Window管理已实现(防止溢出)
|
||
- [ ] 查询日志已记录(用于后续优化)
|
||
|
||
---
|
||
|
||
## 监控和可观测性
|
||
|
||
```python
|
||
# 添加到FastAPI中间件
|
||
from opentelemetry import trace, metrics
|
||
from opentelemetry.exporter.jaeger.thrift import JaegerExporter
|
||
|
||
tracer = trace.get_tracer(__name__)
|
||
|
||
@router.post("/search")
async def search(query: str):
    """Run a knowledge-base search inside a ``kb_search`` tracing span.

    Records the query, the number of results and the search latency in
    milliseconds as span attributes.

    NOTE(review): this illustrates instrumenting the existing search endpoint;
    merge it into that handler rather than registering a second "/search" route.
    """
    import time  # local import keeps the snippet self-contained

    with tracer.start_as_current_span("kb_search") as span:
        span.set_attribute("query", query)

        # Run the search and time it with a monotonic clock.
        started = time.perf_counter()
        results = await rag_service.hybrid_search(query)
        latency = (time.perf_counter() - started) * 1000.0

        span.set_attribute("results.count", len(results))
        span.set_attribute("search.latency_ms", latency)

    return results
|
||
```
|
||
|
||
---
|
||
|
||
## 测试脚本
|
||
|
||
```python
|
||
# tests/test_rag_service.py
|
||
|
||
import pytest
|
||
from app.services.rag_service import RAGService
|
||
|
||
@pytest.fixture
def rag_service():
    """Provide a freshly constructed ``RAGService`` to each test."""
    service = RAGService()
    return service
|
||
|
||
@pytest.mark.asyncio
async def test_hybrid_search(rag_service):
    """Hybrid search returns at most ``top_k`` hits, best score first.

    Requires a populated knowledge base. Absolute score thresholds are
    avoided because the re-ranker's score scale depends on the model.
    """
    top_k = 10
    results = await rag_service.hybrid_search(
        "GEO平台品牌优化",
        top_k=top_k,
        alpha=0.5
    )

    assert 0 < len(results) <= top_k
    scores = [hit['score'] for hit in results]
    assert scores == sorted(scores, reverse=True)
|
||
|
||
@pytest.mark.asyncio
async def test_chunking(rag_service):
    """A long document is split into several chunks within the size limit."""
    max_chunk_chars = 512  # must match the splitter's chunk_size in RAGService
    test_content = "这是一个很长的测试文档。" * 100
    chunks = await rag_service.chunk_document(test_content)

    assert len(chunks) > 1
    assert all(len(chunk) <= max_chunk_chars for chunk in chunks)
|
||
```
|
||
|