"""OpenAI-compatible chat-completions provider implemented directly on httpx."""

import asyncio
import json
import math
import os
from typing import AsyncIterator

import httpx

from .base import LLMError, LLMProvider, LLMResponse
from .rate_limiter import get_rate_limiter

# Supported models and their context lengths (Bailian Coding Plan + OpenAI).
_OPENAI_MODELS: dict[str, int] = {
    "gpt-4o": 128_000,
    "gpt-4o-mini": 128_000,
    "gpt-4-turbo": 128_000,
    "gpt-3.5-turbo": 16_385,
    # Bailian Coding Plan models
    "qwen3-coder-plus": 131_072,
    "qwen3-coder-next": 131_072,
    "qwen3.5-plus": 131_072,
    "qwen3-max-2026-01-23": 131_072,
    "kimi-k2.5": 131_072,
    "glm-5": 128_000,
    "glm-4.7": 128_000,
    "MiniMax-M2.5": 128_000,
}

_DEFAULT_MODEL = "qwen3-coder-plus"
_DEFAULT_BASE_URL = "https://api.openai.com/v1"
_MAX_RETRIES = 3
_RETRYABLE_STATUS = {429, 500, 502, 503}
_SSE_DATA_PREFIX = "data:"


def _retry_delay(retry_after: str | None, attempt: int) -> float:
    """Return the number of seconds to wait before the next attempt.

    A usable ``Retry-After`` value (finite, non-negative number of seconds)
    takes precedence. Anything else — a missing header, the HTTP-date form,
    or a nonsensical number — falls back to exponential backoff
    (``2**attempt`` seconds) instead of raising.
    """
    backoff = float(2**attempt)
    if not retry_after:
        return backoff
    try:
        delay = float(retry_after)
    except ValueError:
        return backoff
    if not math.isfinite(delay) or delay < 0:
        return backoff
    return delay


class OpenAIProvider(LLMProvider):
    """OpenAI LLM service provider (pure httpx implementation)."""

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        base_url: str | None = None,
    ):
        """Configure credentials, model and endpoint, and create the HTTP client.

        Args:
            api_key: API key; falls back to the ``OPENAI_API_KEY`` env var.
            model: Model name; see the precedence comment below.
            base_url: API base URL; falls back to ``OPENAI_BASE_URL`` and then
                to the public OpenAI endpoint. Trailing slashes are stripped.

        Raises:
            LLMError: If no API key is available.
        """
        self._api_key = api_key or os.getenv("OPENAI_API_KEY", "")
        if not self._api_key:
            raise LLMError("OPENAI_API_KEY 未配置", provider="openai")

        # Precedence: explicit argument > OPENAI_MODEL env var >
        # DEFAULT_LLM_MODEL env var > built-in default. Chained with `or` so
        # that an env var set to an empty string does not win over the default.
        self._model = (
            model
            or os.getenv("OPENAI_MODEL")
            or os.getenv("DEFAULT_LLM_MODEL")
            or _DEFAULT_MODEL
        )
        self._base_url = (
            base_url or os.getenv("OPENAI_BASE_URL") or _DEFAULT_BASE_URL
        ).rstrip("/")
        self._endpoint = f"{self._base_url}/chat/completions"
        # Models missing from the table are given a 128k context length.
        self._max_context = _OPENAI_MODELS.get(self._model, 128_000)

        self._client = httpx.AsyncClient(
            timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
            headers={
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            },
        )

    # ---- LLMProvider interface ----

    @property
    def provider_name(self) -> str:
        """Provider identifier; always ``"openai"``."""
        return "openai"

    @property
    def model_name(self) -> str:
        """Name of the model this provider sends in requests."""
        return self._model

    @property
    def max_context_length(self) -> int:
        """Context length looked up for the configured model."""
        return self._max_context

    async def chat(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 4096,
        stop: list[str] | None = None,
    ) -> LLMResponse:
        """Send a non-streaming chat completion request.

        Args:
            messages: Chat messages, forwarded as the ``messages`` field.
            temperature: Sampling temperature.
            max_tokens: Completion token limit.
            stop: Optional stop sequences; omitted from the payload when empty.

        Returns:
            An ``LLMResponse`` with the first choice's content, the model name
            reported by the server (or the configured one), and token usage
            counters (0 for any counter the server did not report).

        Raises:
            LLMError: On HTTP/network failure or an unexpected response shape.
        """
        payload: dict = {
            "model": self._model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if stop:
            payload["stop"] = stop

        data = await self._request_with_retry(payload, stream=False)

        try:
            message = data["choices"][0]["message"]
            # `content` may be null in the response; normalise to "".
            content = message.get("content") or ""
            # `usage` may be absent or null.
            usage = data.get("usage") or {}
            model = data.get("model", self._model)
            token_usage = {
                key: usage.get(key, 0)
                for key in ("prompt_tokens", "completion_tokens", "total_tokens")
            }
        except (AttributeError, IndexError, KeyError, TypeError) as exc:
            raise LLMError(f"响应格式无效: {exc!r}", provider="openai") from exc

        return LLMResponse(content=content, model=model, usage=token_usage)

    async def chat_stream(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 4096,
        stop: list[str] | None = None,
    ) -> AsyncIterator[str]:
        """Send a streaming chat completion request and yield content pieces.

        Args:
            messages: Chat messages, forwarded as the ``messages`` field.
            temperature: Sampling temperature.
            max_tokens: Completion token limit.
            stop: Optional stop sequences; omitted from the payload when empty.

        Yields:
            Each non-empty ``delta.content`` string, in arrival order.

        Raises:
            LLMError: On HTTP/network failure.
        """
        payload: dict = {
            "model": self._model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": True,
        }
        if stop:
            payload["stop"] = stop

        async for chunk in self._stream_request(payload):
            yield chunk

    # ---- Internal helpers ----

    async def _request_with_retry(self, payload: dict, *, stream: bool = False) -> dict:
        """POST ``payload`` and return the decoded JSON body, retrying on failure.

        Up to ``_MAX_RETRIES`` attempts are made. Transport errors and the
        status codes in ``_RETRYABLE_STATUS`` are retried with exponential
        backoff (1s, 2s) or the server's ``Retry-After`` value; other statuses
        fail immediately. No sleep happens after the final attempt.

        Args:
            payload: JSON request body.
            stream: Unused; kept for interface compatibility.

        Raises:
            LLMError: For non-retryable statuses, an undecodable 200 body, or
                once all attempts are exhausted (the last error is raised).
        """
        # Global rate limit: acquired once per logical request, not per attempt.
        await get_rate_limiter().acquire()

        last_error: LLMError | None = None
        for attempt in range(_MAX_RETRIES):
            try:
                response = await self._client.post(self._endpoint, json=payload)
            except httpx.TransportError as exc:
                last_error = LLMError(f"网络错误: {exc}", provider="openai")
                delay = float(2**attempt)
            else:
                status = response.status_code
                if status == 200:
                    try:
                        return response.json()
                    except ValueError as exc:
                        # A 200 with a non-JSON body (e.g. a proxy error page).
                        raise LLMError(
                            f"响应不是有效的JSON: {response.text[:300]}",
                            provider="openai",
                            status_code=status,
                        ) from exc

                error = LLMError(
                    f"HTTP {status}: {response.text[:300]}",
                    provider="openai",
                    status_code=status,
                )
                if status not in _RETRYABLE_STATUS:
                    # Non-retryable errors (e.g. 401, 403, 400).
                    raise error
                last_error = error
                delay = _retry_delay(response.headers.get("retry-after"), attempt)

            # Sleeping after the last attempt would only delay the failure.
            if attempt < _MAX_RETRIES - 1:
                await asyncio.sleep(delay)

        raise last_error or LLMError("超过最大重试次数", provider="openai")

    async def _stream_request(self, payload: dict) -> AsyncIterator[str]:
        """POST ``payload`` as an SSE stream and yield content deltas.

        Connection-level failures and retryable statuses are retried like in
        ``_request_with_retry``, but only while nothing has been yielded yet:
        restarting after partial output would hand the consumer duplicated
        text, so a mid-stream transport error is raised immediately.

        Yields:
            Each non-empty ``choices[0].delta.content`` string.

        Raises:
            LLMError: For non-retryable statuses, a mid-stream network error,
                or once all attempts are exhausted (the last error is raised).
        """
        # Global rate limit: acquired once per logical request, not per attempt.
        await get_rate_limiter().acquire()

        last_error: LLMError | None = None
        for attempt in range(_MAX_RETRIES):
            emitted = False
            delay = float(2**attempt)
            try:
                async with self._client.stream(
                    "POST", self._endpoint, json=payload
                ) as response:
                    status = response.status_code
                    if status != 200:
                        raw = await response.aread()
                        body = raw.decode("utf-8", errors="replace")
                        error = LLMError(
                            f"HTTP {status}: {body[:300]}",
                            provider="openai",
                            status_code=status,
                        )
                        if status not in _RETRYABLE_STATUS:
                            raise error
                        last_error = error
                        delay = _retry_delay(
                            response.headers.get("retry-after"), attempt
                        )
                    else:
                        async for line in response.aiter_lines():
                            # Accept "data:" with or without the optional
                            # space after the colon.
                            if not line.startswith(_SSE_DATA_PREFIX):
                                continue
                            data_str = line[len(_SSE_DATA_PREFIX):].strip()
                            if data_str == "[DONE]":
                                return
                            try:
                                chunk = json.loads(data_str)
                            except json.JSONDecodeError:
                                # Skip data lines that are not valid JSON.
                                continue
                            if not isinstance(chunk, dict):
                                continue
                            choices = chunk.get("choices") or []
                            if not choices:
                                continue
                            delta = choices[0].get("delta") or {}
                            content = delta.get("content")
                            if content:
                                emitted = True
                                yield content
                        return  # Stream ended normally.
            except httpx.TransportError as exc:
                error = LLMError(f"网络错误: {exc}", provider="openai")
                if emitted:
                    # Partial output already delivered; a retry would replay it.
                    raise error from exc
                last_error = error

            # Back off outside the `async with` so the failed response is
            # released first; skip the pointless sleep after the last attempt.
            if attempt < _MAX_RETRIES - 1:
                await asyncio.sleep(delay)

        raise last_error or LLMError("超过最大重试次数", provider="openai")

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def __aenter__(self) -> "OpenAIProvider":
        """Return this provider for use in ``async with``."""
        return self

    async def __aexit__(self, *exc) -> None:
        """Close the HTTP client when leaving the ``async with`` block."""
        await self.close()