243 lines
8.1 KiB
Python
243 lines
8.1 KiB
Python
import asyncio
|
||
import json
|
||
import os
|
||
import time
|
||
from typing import AsyncIterator
|
||
|
||
import httpx
|
||
|
||
from .base import LLMError, LLMProvider, LLMResponse
|
||
from .rate_limiter import get_rate_limiter
|
||
from app.monitoring.llm_metrics import get_llm_metrics
|
||
|
||
_DEFAULT_MODEL = "deepseek-chat"
|
||
_DEFAULT_MAX_CONTEXT = 64_000
|
||
_MAX_RETRIES = 3
|
||
_RETRYABLE_STATUS = {429, 500, 502, 503}
|
||
|
||
|
||
class DeepSeekProvider(LLMProvider):
    """DeepSeek LLM provider (OpenAI-compatible chat completions, pure httpx).

    Configuration is resolved from constructor arguments first, then from the
    ``DEEPSEEK_API_KEY`` / ``DEEPSEEK_MODEL`` / ``DEEPSEEK_BASE_URL`` /
    ``DEEPSEEK_MAX_CONTEXT`` environment variables, then from module defaults.

    The instance owns an ``httpx.AsyncClient``; release it with ``close()``
    or by using the provider as an ``async with`` context manager.
    """

    # Upper bound (seconds) on a server-supplied Retry-After delay, so a
    # misbehaving upstream cannot stall a single request indefinitely.
    _MAX_RETRY_AFTER = 60.0

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        base_url: str | None = None,
    ):
        """Resolve configuration and create the shared HTTP client.

        Args:
            api_key: DeepSeek API key; falls back to ``DEEPSEEK_API_KEY``.
            model: Model name; falls back to ``DEEPSEEK_MODEL``, then
                ``_DEFAULT_MODEL``.
            base_url: API root; falls back to ``DEEPSEEK_BASE_URL``, then
                ``https://api.deepseek.com/v1``. Trailing slashes are stripped.

        Raises:
            LLMError: if no API key is configured.
            ValueError: if ``DEEPSEEK_MAX_CONTEXT`` is set but not an integer.
        """
        self._api_key = api_key or os.getenv("DEEPSEEK_API_KEY", "")
        if not self._api_key:
            raise LLMError("DEEPSEEK_API_KEY 未配置", provider="deepseek")

        # `or` chains (rather than getenv defaults) so that a variable which is
        # set but empty still falls back to the default instead of yielding "".
        self._model = model or os.getenv("DEEPSEEK_MODEL") or _DEFAULT_MODEL
        self._base_url = (
            base_url
            or os.getenv("DEEPSEEK_BASE_URL")
            or "https://api.deepseek.com/v1"
        ).rstrip("/")
        self._endpoint = f"{self._base_url}/chat/completions"
        self._max_context = int(
            os.getenv("DEEPSEEK_MAX_CONTEXT") or _DEFAULT_MAX_CONTEXT
        )

        # Long read timeout: completions can take a while to be generated.
        self._client = httpx.AsyncClient(
            timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
            headers={
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            },
        )

    # ---- LLMProvider interface ----

    @property
    def provider_name(self) -> str:
        """Provider identifier used for metrics and errors."""
        return "deepseek"

    @property
    def model_name(self) -> str:
        """Configured model name sent with every request."""
        return self._model

    @property
    def max_context_length(self) -> int:
        """Configured context-window size in tokens."""
        return self._max_context

    async def chat(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 4096,
        stop: list[str] | None = None,
    ) -> LLMResponse:
        """Send a non-streaming chat completion request.

        Args:
            messages: OpenAI-style chat messages, forwarded unchanged.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            stop: Optional stop sequences; omitted from the payload when empty.

        Returns:
            An ``LLMResponse`` holding the first choice's content, the model
            reported by the API (falling back to the configured model), and
            token usage with missing counters reported as 0.

        Raises:
            LLMError: on HTTP errors, network failures after all retries, or a
                response body without the expected shape.

        Records one ``success`` or ``error`` sample in the LLM metrics per call.
        """
        payload = self._build_payload(messages, temperature, max_tokens, stop)

        start_time = time.perf_counter()
        metrics = get_llm_metrics(self.provider_name, self._model)

        try:
            data = await self._request_with_retry(payload, stream=False)
            content, usage = self._parse_completion(data)
        except Exception:
            metrics.record_request(
                status="error",
                duration=time.perf_counter() - start_time,
            )
            raise

        metrics.record_request(
            status="success",
            duration=time.perf_counter() - start_time,
            prompt_tokens=usage.get("prompt_tokens"),
            completion_tokens=usage.get("completion_tokens"),
        )

        return LLMResponse(
            content=content,
            model=data.get("model") or self._model,
            usage={
                "prompt_tokens": usage.get("prompt_tokens") or 0,
                "completion_tokens": usage.get("completion_tokens") or 0,
                "total_tokens": usage.get("total_tokens") or 0,
            },
        )

    async def chat_stream(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 4096,
        stop: list[str] | None = None,
    ) -> AsyncIterator[str]:
        """Stream a chat completion, yielding content fragments as they arrive.

        Args:
            messages: OpenAI-style chat messages, forwarded unchanged.
            temperature: Sampling temperature.
            max_tokens: Maximum number of tokens to generate.
            stop: Optional stop sequences (same semantics as ``chat``).

        Yields:
            Non-empty content deltas, in order.

        Raises:
            LLMError: see ``_stream_request``.

        Records a ``success`` sample when the stream completes and an
        ``error`` sample when it fails; a stream abandoned by the consumer
        records nothing (mirroring ``chat`` on cancellation).
        """
        # Python 3.10+, the same floor already implied by the `X | None` hints.
        from contextlib import aclosing

        payload = self._build_payload(
            messages, temperature, max_tokens, stop, stream=True
        )

        start_time = time.perf_counter()
        metrics = get_llm_metrics(self.provider_name, self._model)

        try:
            # aclosing() closes the inner generator (and thus the HTTP
            # response) promptly if the consumer stops iterating early.
            async with aclosing(self._stream_request(payload)) as chunks:
                async for chunk in chunks:
                    yield chunk
        except Exception:
            metrics.record_request(
                status="error",
                duration=time.perf_counter() - start_time,
            )
            raise

        metrics.record_request(
            status="success",
            duration=time.perf_counter() - start_time,
        )

    # ---- internal helpers ----

    def _build_payload(
        self,
        messages: list[dict],
        temperature: float,
        max_tokens: int,
        stop: list[str] | None,
        *,
        stream: bool = False,
    ) -> dict:
        """Build the chat-completions request body shared by both entry points."""
        payload: dict = {
            "model": self._model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if stop:
            payload["stop"] = stop
        if stream:
            payload["stream"] = True
        return payload

    @staticmethod
    def _parse_completion(data: dict) -> tuple:
        """Return ``(content, usage)`` from a non-streaming completion body.

        ``usage`` is always a dict (empty when the API omitted it or sent null).

        Raises:
            LLMError: if ``choices[0].message.content`` is missing.
        """
        try:
            content = data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError) as exc:
            raise LLMError(
                f"响应格式异常: {str(data)[:300]}",
                provider="deepseek",
            ) from exc
        usage = data.get("usage") or {}
        if not isinstance(usage, dict):
            usage = {}
        return content, usage

    @staticmethod
    def _delta_content(chunk: object) -> str | None:
        """Extract ``choices[0].delta.content`` from a stream event, or None."""
        try:
            return chunk["choices"][0]["delta"].get("content") or None
        except (KeyError, IndexError, TypeError, AttributeError):
            # Events without a content delta (role-only, usage, keep-alive).
            return None

    @staticmethod
    def _http_error(status_code: int, body: str) -> LLMError:
        """Build the LLMError for a non-200 response (body truncated to 300 chars)."""
        return LLMError(
            f"HTTP {status_code}: {body[:300]}",
            provider="deepseek",
            status_code=status_code,
        )

    def _retry_delay(self, attempt: int, retry_after: str | None = None) -> float:
        """Seconds to wait after failed attempt number ``attempt`` (0-based).

        A non-negative numeric ``Retry-After`` header is honoured, capped at
        ``_MAX_RETRY_AFTER``. Otherwise, including the HTTP-date form of the
        header that ``float()`` cannot parse, exponential backoff is used:
        1s, 2s, 4s, ...
        """
        if retry_after:
            try:
                seconds = float(retry_after)
            except ValueError:
                seconds = -1.0  # not delta-seconds; fall back to backoff
            # `>= 0` also rejects NaN.
            if seconds >= 0:
                return min(seconds, self._MAX_RETRY_AFTER)
        return float(2**attempt)

    async def _request_with_retry(
        self, payload: dict, *, stream: bool = False
    ) -> dict:
        """POST ``payload`` and return the decoded JSON body.

        Makes up to ``_MAX_RETRIES`` attempts. Statuses in
        ``_RETRYABLE_STATUS`` and ``httpx.TransportError`` are retried after
        ``_retry_delay``; no time is wasted sleeping after the final attempt.
        Every attempt passes through the global rate limiter.

        ``stream`` is accepted for backward compatibility and is unused;
        streaming requests go through ``_stream_request``.

        Raises:
            LLMError: for a non-retryable status, a 200 response whose body is
                not valid JSON, or when all attempts fail (the last failure is
                raised).
        """
        last_error: LLMError | None = None

        for attempt in range(_MAX_RETRIES):
            is_last = attempt == _MAX_RETRIES - 1
            # Global rate limit: every HTTP request, retries included, counts.
            await get_rate_limiter().acquire()

            try:
                response = await self._client.post(self._endpoint, json=payload)
            except httpx.TransportError as exc:
                last_error = LLMError(
                    f"网络错误: {exc}",
                    provider="deepseek",
                )
                if not is_last:
                    await asyncio.sleep(self._retry_delay(attempt))
                continue

            if response.status_code == 200:
                try:
                    return response.json()
                except ValueError as exc:  # json.JSONDecodeError
                    raise LLMError(
                        f"响应不是有效的JSON: {response.text[:300]}",
                        provider="deepseek",
                        status_code=response.status_code,
                    ) from exc

            error = self._http_error(response.status_code, response.text)
            if response.status_code not in _RETRYABLE_STATUS:
                raise error
            last_error = error
            if not is_last:
                await asyncio.sleep(
                    self._retry_delay(attempt, response.headers.get("retry-after"))
                )

        raise last_error or LLMError("超过最大重试次数", provider="deepseek")

    async def _stream_request(self, payload: dict) -> AsyncIterator[str]:
        """POST ``payload`` as an SSE stream (OpenAI format) and yield content deltas.

        Retries follow ``_request_with_retry`` (same statuses, Retry-After
        handling and per-attempt rate limiting), but only until the first
        content fragment has been yielded: after that a retry would replay
        the completion from the start and duplicate text for the consumer,
        so a later transport error is raised instead.

        The stream ends at the ``data: [DONE]`` sentinel, or when the server
        closes the response. Events that are not valid JSON are skipped.

        Raises:
            LLMError: for a non-retryable status, a transport error after
                output was yielded, or when all attempts fail.
        """
        last_error: LLMError | None = None

        for attempt in range(_MAX_RETRIES):
            is_last = attempt == _MAX_RETRIES - 1
            await get_rate_limiter().acquire()

            retry_after: str | None = None
            yielded_any = False
            try:
                async with self._client.stream(
                    "POST", self._endpoint, json=payload
                ) as response:
                    if response.status_code == 200:
                        async for line in response.aiter_lines():
                            # SSE permits "data:" with or without a following space.
                            if not line.startswith("data:"):
                                continue
                            data_str = line[len("data:"):].strip()
                            if data_str == "[DONE]":
                                return
                            try:
                                event = json.loads(data_str)
                            except json.JSONDecodeError:
                                continue
                            content = self._delta_content(event)
                            if content:
                                yielded_any = True
                                yield content
                        return  # server closed the stream normally

                    body = (await response.aread()).decode("utf-8", errors="replace")
                    error = self._http_error(response.status_code, body)
                    if response.status_code not in _RETRYABLE_STATUS:
                        raise error
                    last_error = error
                    retry_after = response.headers.get("retry-after")
            except httpx.TransportError as exc:
                error = LLMError(
                    f"网络错误: {exc}",
                    provider="deepseek",
                )
                if yielded_any:
                    # Partial output already reached the caller; don't replay it.
                    raise error from exc
                last_error = error

            # Sleep outside the `async with` so the connection is released first.
            if not is_last:
                await asyncio.sleep(self._retry_delay(attempt, retry_after))

        raise last_error or LLMError("超过最大重试次数", provider="deepseek")

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def __aenter__(self) -> "DeepSeekProvider":
        return self

    async def __aexit__(self, *exc) -> None:
        await self.close()
|