125 lines
3.7 KiB
Python
125 lines
3.7 KiB
Python
"""
|
||
EmbeddingService: Embedding 抽象基类及实现
|
||
- EmbeddingService: ABC
|
||
- OpenAIEmbedder: 调用 OpenAI text-embedding-3-small(httpx async)
|
||
- MockEmbedder: 返回随机向量,用于测试
|
||
"""
|
||
import logging
|
||
from abc import ABC, abstractmethod
|
||
from typing import Optional
|
||
|
||
import httpx
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class EmbeddingService(ABC):
    """Abstract interface that every embedding backend must implement.

    Concrete subclasses provide the vector size via ``dimension`` and the two
    coroutines ``embed`` and ``embed_batch``.
    """

    @property
    @abstractmethod
    def dimension(self) -> int:
        """Return the length of the vectors produced by this service."""

    @abstractmethod
    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding vector per entry of ``texts``."""

    @abstractmethod
    async def embed(self, text: str) -> list[float]:
        """Return the embedding vector for a single ``text``."""
|
||
|
||
|
||
class OpenAIEmbedder(EmbeddingService):
    """Embedding service backed by the OpenAI embeddings HTTP API.

    Requests are POSTed to ``base_url`` with ``httpx.AsyncClient``. The
    default model is ``text-embedding-3-small``.
    """

    # Native output sizes of the known OpenAI embedding models, so that
    # ``dimension`` is reported correctly when a non-default model is chosen.
    _KNOWN_DIMENSIONS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }
    # Fallback for models not listed above (the previously hard-coded value).
    _DEFAULT_DIMENSION = 1536

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        api_key: Optional[str] = None,
        base_url: str = "https://api.openai.com/v1/embeddings",
        timeout: float = 30.0,
    ):
        """Configure the embedder.

        Args:
            model: Name of the embedding model sent in each request.
            api_key: API key. When empty or ``None`` the key is read from
                ``app.config.settings.OPENAI_API_KEY`` on a best-effort basis.
            base_url: Full URL of the embeddings endpoint.
            timeout: Timeout passed to ``httpx.AsyncClient``.
        """
        self.model = model
        self.base_url = base_url
        self.timeout = timeout
        self._dimension = self._KNOWN_DIMENSIONS.get(model, self._DEFAULT_DIMENSION)

        # An explicit argument wins; otherwise fall back to the settings module.
        self.api_key = api_key if api_key else self._api_key_from_settings()
        if not self.api_key:
            logger.warning(
                "OpenAIEmbedder: no API key configured; requests will be sent "
                "without an Authorization header"
            )

    @staticmethod
    def _api_key_from_settings() -> str:
        """Return ``settings.OPENAI_API_KEY``, or ``""`` if it cannot be read."""
        try:
            from app.config import settings

            # ``or ""`` guards against a setting that exists but is None,
            # which would otherwise end up as the header "Bearer None".
            return getattr(settings, "OPENAI_API_KEY", "") or ""
        except Exception as exc:
            # Deliberately best-effort: the settings module may be missing or
            # fail to load. Record the reason instead of hiding it.
            logger.warning(
                "OpenAIEmbedder: could not read OPENAI_API_KEY from settings: %s",
                exc,
            )
            return ""

    @property
    def dimension(self) -> int:
        """Vector size of the configured model (1536 for unknown models)."""
        return self._dimension

    async def embed(self, text: str) -> list[float]:
        """Return the embedding of a single text via a one-element batch."""
        results = await self.embed_batch([text])
        return results[0]

    async def embed_batch(
        self, texts: list[str], batch_size: int = 100
    ) -> list[list[float]]:
        """Embed ``texts`` in order, sending at most ``batch_size`` per request.

        Args:
            texts: Texts to embed. An empty list returns ``[]`` without any
                network access.
            batch_size: Maximum number of texts per API request; must be >= 1.

        Returns:
            One embedding per input text, in input order.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if not texts:
            return []
        # A negative step would make range() empty and silently drop every
        # input; zero would fail only after the HTTP client had been opened.
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}"
            )

        all_embeddings: list[list[float]] = []

        # One client is shared by all batches of this call.
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            for start in range(0, len(texts), batch_size):
                batch = texts[start : start + batch_size]
                all_embeddings.extend(await self._call_api(client, batch))

        return all_embeddings

    async def _call_api(
        self, client: "httpx.AsyncClient", texts: list[str]
    ) -> list[list[float]]:
        """Send one embeddings request and return the vectors in input order.

        Raises:
            ValueError: If the response does not contain exactly one embedding
                per input text.

        Any exception raised by ``client.post`` or by
        ``response.raise_for_status()`` propagates unchanged.
        """
        headers = {"Content-Type": "application/json"}
        # Skip the header entirely rather than send a bare "Bearer " value
        # when no key is configured.
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        payload = {
            "model": self.model,
            "input": texts,
        }

        response = await client.post(self.base_url, json=payload, headers=headers)
        response.raise_for_status()

        data = response.json()
        # Response format: {"data": [{"index": i, "embedding": [...]}]}
        # Sort by index so the output lines up with ``texts``.
        items = sorted(data["data"], key=lambda item: item["index"])
        if len(items) != len(texts):
            raise ValueError(
                f"embedding API returned {len(items)} embeddings "
                f"for {len(texts)} inputs"
            )
        return [item["embedding"] for item in items]
|
||
|
||
|
||
class MockEmbedder(EmbeddingService):
    """Mock embedding service for tests and development.

    Vectors are pseudo-random but derived only from the input text, so the
    same text yields the same vector on every call and in every process.
    """

    def __init__(self, dimension: int = 1536):
        """Create a mock that produces vectors of length ``dimension``."""
        self._dimension = dimension

    @property
    def dimension(self) -> int:
        """Length of the vectors returned by this mock."""
        return self._dimension

    async def embed(self, text: str) -> list[float]:
        """Return a deterministic pseudo-random vector for ``text``.

        Each component is a float in ``[0.0, 1.0)`` drawn from a
        ``random.Random`` seeded with a digest of the text.
        """
        import hashlib
        import random

        # The builtin hash() of a str is salted per interpreter process, so it
        # cannot give "same text -> same vector" across runs. Derive the
        # 32-bit seed from a stable digest of the text instead.
        digest = hashlib.sha256(text.encode("utf-8", "surrogatepass")).digest()
        seed = int.from_bytes(digest[:4], "big")
        rng = random.Random(seed)
        return [rng.random() for _ in range(self._dimension)]

    async def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """Return the mock embedding of every text, in input order."""
        return [await self.embed(t) for t in texts]
|