import asyncio
import json
import os
from typing import AsyncIterator

import httpx

from .base import LLMError, LLMProvider, LLMResponse
from .rate_limiter import get_rate_limiter

_DEFAULT_MODEL = "deepseek-chat"
_DEFAULT_MAX_CONTEXT = 64_000
_MAX_RETRIES = 3
_RETRYABLE_STATUS = {429, 500, 502, 503}


def _retry_delay(attempt: int, retry_after: str | None = None) -> float:
    """Return the number of seconds to wait after failed attempt ``attempt``.

    A numeric ``Retry-After`` header (delay in seconds) wins when present.
    Otherwise, or when the header is in the HTTP-date form / unparsable, fall
    back to exponential backoff: 1s, 2s, 4s, ... for attempt 0, 1, 2, ...
    """
    if retry_after:
        try:
            return max(0.0, float(retry_after))
        except ValueError:
            # HTTP-date or garbage: deliberately ignore and use backoff instead
            # of letting ValueError escape the retry loop.
            pass
    return float(2**attempt)


def _http_error(status_code: int, body: str) -> LLMError:
    """Build the standard ``LLMError`` for a non-200 response (body truncated to 300 chars)."""
    return LLMError(
        f"HTTP {status_code}: {body[:300]}",
        provider="deepseek",
        status_code=status_code,
    )


class DeepSeekProvider(LLMProvider):
    """DeepSeek LLM provider (OpenAI-compatible wire format, pure httpx implementation).

    The instance owns an ``httpx.AsyncClient``; release it with :meth:`close`
    or by using the provider as an ``async with`` context manager.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        base_url: str | None = None,
    ):
        """Create a provider bound to one DeepSeek-compatible endpoint.

        Args:
            api_key: Bearer token; falls back to ``DEEPSEEK_API_KEY``.
            model: Model name; falls back to ``DEEPSEEK_MODEL``, then
                ``"deepseek-chat"``.
            base_url: API root; falls back to ``DEEPSEEK_BASE_URL``, then
                ``https://api.deepseek.com/v1``. A trailing ``/`` is stripped.

        Raises:
            LLMError: if no API key is available.
            ValueError: if ``DEEPSEEK_MAX_CONTEXT`` is set but not an integer.
        """
        self._api_key = api_key or os.getenv("DEEPSEEK_API_KEY", "")
        if not self._api_key:
            raise LLMError("DEEPSEEK_API_KEY 未配置", provider="deepseek")

        self._model = model or os.getenv("DEEPSEEK_MODEL", _DEFAULT_MODEL)
        self._base_url = (
            base_url or os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1")
        ).rstrip("/")
        self._endpoint = f"{self._base_url}/chat/completions"
        self._max_context = int(
            os.getenv("DEEPSEEK_MAX_CONTEXT", str(_DEFAULT_MAX_CONTEXT))
        )

        # Long read timeout: completions can take a while to generate.
        self._client = httpx.AsyncClient(
            timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
            headers={
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            },
        )

    # ---- LLMProvider interface ----

    @property
    def provider_name(self) -> str:
        """Identifier of this provider: ``"deepseek"``."""
        return "deepseek"

    @property
    def model_name(self) -> str:
        """Model name sent in every request payload."""
        return self._model

    @property
    def max_context_length(self) -> int:
        """Configured context window size (from ``DEEPSEEK_MAX_CONTEXT``, default 64000)."""
        return self._max_context

    async def chat(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 4096,
        stop: list[str] | None = None,
    ) -> LLMResponse:
        """Run a non-streaming chat completion.

        Args:
            messages: OpenAI-style message dicts, sent as-is.
            temperature: Sampling temperature.
            max_tokens: Completion token limit.
            stop: Optional stop sequences (omitted from the payload when empty).

        Returns:
            LLMResponse with the first choice's text, the model reported by the
            server (or the configured one), and token usage (0 when missing).

        Raises:
            LLMError: on non-retryable HTTP errors, exhausted retries, or a
                response body that does not match the chat-completions shape.
        """
        payload = self._build_payload(messages, temperature, max_tokens, stop)
        data = await self._request_with_retry(payload, stream=False)

        try:
            message = data["choices"][0]["message"]
        except (KeyError, IndexError, TypeError) as exc:
            raise LLMError(f"响应格式异常: {exc!r}", provider="deepseek") from exc

        # `usage` may be absent or explicitly null; both become zeros.
        usage = data.get("usage") or {}
        return LLMResponse(
            # OpenAI-compatible responses may carry `content: null`; normalise to
            # "" so callers always receive a string (presumably what
            # LLMResponse.content expects — confirm against base.py).
            content=message.get("content") or "",
            model=data.get("model", self._model),
            usage={
                "prompt_tokens": usage.get("prompt_tokens", 0),
                "completion_tokens": usage.get("completion_tokens", 0),
                "total_tokens": usage.get("total_tokens", 0),
            },
        )

    async def chat_stream(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 4096,
        stop: list[str] | None = None,
    ) -> AsyncIterator[str]:
        """Run a streaming chat completion, yielding text deltas as they arrive.

        Args mirror :meth:`chat`; ``stop`` is optional and backward-compatible.

        Raises:
            LLMError: on non-retryable HTTP errors, exhausted retries, or a
                connection failure after some text was already yielded.
        """
        payload = self._build_payload(
            messages, temperature, max_tokens, stop, stream=True
        )
        async for chunk in self._stream_request(payload):
            yield chunk

    # ---- Internal helpers ----

    def _build_payload(
        self,
        messages: list[dict],
        temperature: float,
        max_tokens: int,
        stop: list[str] | None = None,
        *,
        stream: bool = False,
    ) -> dict:
        """Assemble the chat-completions request body shared by both chat paths."""
        payload: dict = {
            "model": self._model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if stop:
            payload["stop"] = stop
        if stream:
            payload["stream"] = True
        return payload

    async def _request_with_retry(self, payload: dict, *, stream: bool = False) -> dict:
        """POST ``payload`` and return the decoded JSON body, retrying transient failures.

        Makes up to ``_MAX_RETRIES`` attempts. Statuses in ``_RETRYABLE_STATUS``
        and transport errors are retried after :func:`_retry_delay`; there is no
        sleep after the final attempt. ``stream`` is accepted for call-site
        compatibility but unused (streaming goes through ``_stream_request``).

        Raises:
            LLMError: for non-retryable statuses, an undecodable 200 body, or
                once all attempts have failed (the last failure is raised).
        """
        # Global rate limit.
        # NOTE(review): acquired once per logical call, so the retries below are
        # not throttled by the limiter — confirm this is intended.
        await get_rate_limiter().acquire()

        last_error: LLMError | None = None
        for attempt in range(_MAX_RETRIES):
            try:
                response = await self._client.post(self._endpoint, json=payload)
            except httpx.TransportError as exc:
                last_error = LLMError(f"网络错误: {exc}", provider="deepseek")
                last_error.__cause__ = exc
                delay = _retry_delay(attempt)
            else:
                if response.status_code == 200:
                    try:
                        return response.json()
                    except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
                        raise LLMError(
                            f"响应解析失败: {exc}", provider="deepseek"
                        ) from exc

                error = _http_error(response.status_code, response.text)
                if response.status_code not in _RETRYABLE_STATUS:
                    raise error
                last_error = error
                delay = _retry_delay(attempt, response.headers.get("retry-after"))

            # Only back off if another attempt will actually follow.
            if attempt < _MAX_RETRIES - 1:
                await asyncio.sleep(delay)

        raise last_error or LLMError("超过最大重试次数", provider="deepseek")

    async def _stream_request(self, payload: dict) -> AsyncIterator[str]:
        """POST ``payload`` as an SSE stream and yield non-empty ``delta.content`` strings.

        Parses OpenAI-compatible ``data: ...`` lines; stops at ``data: [DONE]``
        or when the server closes the stream. Lines without the ``data: ``
        prefix and undecodable JSON events are skipped.

        Retries (same policy as ``_request_with_retry``) happen only before any
        text has been yielded: once the caller has received output, replaying
        the request would duplicate it, so a mid-stream transport error is
        raised instead.
        """
        # Global rate limit (acquired once; see note in _request_with_retry).
        await get_rate_limiter().acquire()

        last_error: LLMError | None = None
        for attempt in range(_MAX_RETRIES):
            emitted = False  # has this attempt yielded anything to the caller?
            delay = 0.0
            try:
                async with self._client.stream(
                    "POST", self._endpoint, json=payload
                ) as response:
                    if response.status_code == 200:
                        async for line in response.aiter_lines():
                            if not line.startswith("data: "):
                                continue
                            data_str = line[len("data: "):]
                            if data_str.strip() == "[DONE]":
                                return
                            try:
                                chunk = json.loads(data_str)
                            except json.JSONDecodeError:
                                continue
                            choices = chunk.get("choices") or []
                            if not choices:
                                continue
                            content = (choices[0].get("delta") or {}).get("content")
                            if content:
                                emitted = True
                                yield content
                        return  # stream ended normally

                    raw = await response.aread()
                    error = _http_error(
                        response.status_code, raw.decode("utf-8", errors="replace")
                    )
                    if response.status_code not in _RETRYABLE_STATUS:
                        raise error
                    last_error = error
                    delay = _retry_delay(attempt, response.headers.get("retry-after"))
            except httpx.TransportError as exc:
                if emitted:
                    raise LLMError(f"流式响应中断: {exc}", provider="deepseek") from exc
                last_error = LLMError(f"网络错误: {exc}", provider="deepseek")
                last_error.__cause__ = exc
                delay = _retry_delay(attempt)

            # Sleep outside `async with` so the connection is released first,
            # and skip the pointless wait after the final attempt.
            if attempt < _MAX_RETRIES - 1:
                await asyncio.sleep(delay)

        raise last_error or LLMError("超过最大重试次数", provider="deepseek")

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def __aenter__(self) -> "DeepSeekProvider":
        return self

    async def __aexit__(self, *exc) -> None:
        await self.close()