geo/backend/app/services/knowledge/embedder.py

125 lines
3.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
EmbeddingService: Embedding 抽象基类及实现
- EmbeddingService: ABC
- OpenAIEmbedder: 调用 OpenAI text-embedding-3-small (httpx async)
- MockEmbedder: 返回随机向量,用于测试
"""
import logging
from abc import ABC, abstractmethod
from typing import Optional
import httpx
logger = logging.getLogger(__name__)
class EmbeddingService(ABC):
    """Abstract interface shared by every text-embedding backend.

    Subclasses must provide the vector length (``dimension``) and coroutine
    methods that turn one text, or a list of texts, into float vectors.
    """

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Length of each vector produced by this service."""

    @abstractmethod
    async def embed(self, text: str) -> list[float]:
        """Return the embedding vector for a single ``text``."""

    @abstractmethod
    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding vector per element of ``texts``."""
class OpenAIEmbedder(EmbeddingService):
    """OpenAI embeddings backend that calls the REST endpoint via ``httpx`` (async).

    Defaults to ``text-embedding-3-small``. ``dimension`` reports the vector
    length for the configured model, or the explicitly requested size,
    instead of always assuming 1536.
    """

    # Native output sizes of OpenAI embedding models. Unknown models fall back
    # to 1536, the value this class previously hard-coded for every model.
    _NATIVE_DIMENSIONS: dict[str, int] = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        api_key: Optional[str] = None,
        base_url: str = "https://api.openai.com/v1/embeddings",
        timeout: float = 30.0,
        dimension: Optional[int] = None,
    ):
        """Configure the embedder.

        Args:
            model: OpenAI embedding model name, sent as ``model``.
            api_key: API key. When falsy, ``OPENAI_API_KEY`` is read from
                ``app.config.settings`` (empty string if that is unavailable).
            base_url: Full URL of the embeddings endpoint (the request is
                POSTed to it directly, so it is not just the API root).
            timeout: Timeout in seconds passed to ``httpx.AsyncClient``.
            dimension: Optional output size. When given, it is sent to the API
                as ``dimensions`` (a parameter of the text-embedding-3 family)
                and reported by ``dimension``. Otherwise the model's native
                size is reported and nothing extra is sent.

        Raises:
            ValueError: If ``dimension`` is given and is smaller than 1.
        """
        if dimension is not None and dimension < 1:
            raise ValueError(f"dimension must be a positive integer, got {dimension!r}")
        self.model = model
        self.base_url = base_url
        self.timeout = timeout
        # Kept separately so the API only receives ``dimensions`` when the
        # caller asked for it explicitly.
        self._requested_dimension = dimension
        self._dimension = (
            dimension
            if dimension is not None
            else self._NATIVE_DIMENSIONS.get(model, 1536)
        )
        # API key: an explicit argument wins; otherwise fall back to settings.
        if api_key:
            self.api_key = api_key
        else:
            self.api_key = self._load_api_key_from_settings()

    @staticmethod
    def _load_api_key_from_settings() -> str:
        """Read ``OPENAI_API_KEY`` from ``app.config.settings`` on a best-effort basis; return "" on failure."""
        try:
            from app.config import settings
        except Exception:
            # Deliberately non-fatal (e.g. tooling without app config), but not
            # silent: an empty key otherwise only shows up later as an HTTP 401.
            logger.warning(
                "Could not load app.config.settings; OPENAI_API_KEY will be empty",
                exc_info=True,
            )
            return ""
        # ``or ""`` normalizes a setting that exists but is None.
        return getattr(settings, "OPENAI_API_KEY", "") or ""

    @property
    def dimension(self) -> int:
        """Length of the vectors this embedder is configured to produce."""
        return self._dimension

    async def embed(self, text: str) -> list[float]:
        """Embed a single text by delegating to ``embed_batch``."""
        results = await self.embed_batch([text])
        return results[0]

    async def embed_batch(
        self, texts: list[str], batch_size: int = 100
    ) -> list[list[float]]:
        """Embed ``texts``, sending at most ``batch_size`` inputs per request.

        Returns one vector per input text, in input order. An empty list
        returns ``[]`` without opening an HTTP client.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1, or if the API
                returns a different number of embeddings than inputs sent.
            httpx.HTTPStatusError: On a non-success HTTP response.
        """
        if not texts:
            return []
        if batch_size < 1:
            # range() raises on a zero step, and a negative step yields no
            # batches at all, which would silently drop every input.
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        all_embeddings: list[list[float]] = []
        # One client (and its connection pool) is shared by all batches of this call.
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            for start in range(0, len(texts), batch_size):
                batch = texts[start : start + batch_size]
                all_embeddings.extend(await self._call_api(client, batch))
        return all_embeddings

    async def _call_api(
        self, client: httpx.AsyncClient, texts: list[str]
    ) -> list[list[float]]:
        """Send one embeddings request for ``texts`` and return the vectors in input order."""
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        payload: dict[str, object] = {
            "model": self.model,
            "input": texts,
        }
        if self._requested_dimension is not None:
            payload["dimensions"] = self._requested_dimension
        response = await client.post(self.base_url, json=payload, headers=headers)
        response.raise_for_status()
        data = response.json()
        # Expected shape: {"data": [{"index": i, "embedding": [...]}, ...]}.
        # Items are re-sorted by "index" rather than trusting the returned order.
        items = sorted(data["data"], key=lambda item: item["index"])
        if len(items) != len(texts):
            # Guard against silently pairing vectors with the wrong texts.
            raise ValueError(
                f"OpenAI returned {len(items)} embeddings for {len(texts)} inputs"
            )
        return [item["embedding"] for item in items]
class MockEmbedder(EmbeddingService):
    """Deterministic fake embedder for tests and local development.

    Each text maps to a pseudo-random vector of ``dimension`` floats in
    [0, 1). The seed comes from a SHA-256 digest of the text, so the same
    text yields the same vector across processes and restarts. The built-in
    ``hash()``, used previously, is salted per process for ``str``.
    """

    def __init__(self, dimension: int = 1536):
        """Create a mock embedder that produces vectors of length ``dimension``."""
        self._dimension = dimension

    @property
    def dimension(self) -> int:
        """Length of each generated vector."""
        return self._dimension

    async def embed(self, text: str) -> list[float]:
        """Return a pseudo-random vector seeded deterministically from ``text``."""
        import hashlib
        import random

        # Stable seed: hash(str) changes between interpreter runs (PYTHONHASHSEED),
        # so "same text -> same vector" previously held only within one process.
        # "surrogatepass" keeps every str encodable, as hash() accepted any str.
        digest = hashlib.sha256(text.encode("utf-8", "surrogatepass")).digest()
        rng = random.Random(int.from_bytes(digest[:8], "big"))
        return [rng.random() for _ in range(self._dimension)]

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Embed each text independently, preserving input order."""
        return [await self.embed(t) for t in texts]