geo/backend/app/services/llm/openai_provider.py

260 lines
8.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import os
import time
from typing import AsyncIterator
import httpx
from .base import LLMError, LLMProvider, LLMResponse
from .rate_limiter import get_rate_limiter
from app.monitoring.llm_metrics import get_llm_metrics
# Supported models and their context-window sizes in tokens: OpenAI models
# plus Alibaba Bailian (百炼) "Coding Plan" models.
_OPENAI_MODELS: dict[str, int] = {
    "gpt-4o": 128_000,
    "gpt-4o-mini": 128_000,
    "gpt-4-turbo": 128_000,
    "gpt-3.5-turbo": 16_385,
    # Bailian Coding Plan models
    "qwen3-coder-plus": 131_072,
    "qwen3-coder-next": 131_072,
    "qwen3.5-plus": 131_072,
    "qwen3-max-2026-01-23": 131_072,
    "kimi-k2.5": 131_072,
    "glm-5": 128_000,
    "glm-4.7": 128_000,
    "MiniMax-M2.5": 128_000,
}
_DEFAULT_MODEL = "qwen3-coder-plus"
# Total number of attempts per request (the first try included).
_MAX_RETRIES = 3
# HTTP statuses treated as transient: rate limiting (429) plus server and
# gateway errors. 504 belongs here alongside 502/503 - a gateway timeout is as
# transient as a bad gateway or an unavailable upstream.
_RETRYABLE_STATUS = frozenset({429, 500, 502, 503, 504})
class OpenAIProvider(LLMProvider):
    """LLM provider for OpenAI-compatible chat-completions APIs, built on raw httpx.

    Requests are POSTed to ``{base_url}/chat/completions``. The base URL defaults
    to the public OpenAI API but may point at any OpenAI-compatible endpoint.
    Each instance owns one ``httpx.AsyncClient``; release it with ``close()`` or
    by using the instance as an ``async with`` context manager.
    """

    def __init__(
        self,
        api_key: str | None = None,
        model: str | None = None,
        base_url: str | None = None,
    ):
        """Resolve configuration and create the shared HTTP client.

        Args:
            api_key: Bearer token; falls back to ``OPENAI_API_KEY``.
            model: Model name; falls back to ``OPENAI_MODEL``, then
                ``DEFAULT_LLM_MODEL``, then ``_DEFAULT_MODEL``.
            base_url: API root; falls back to ``OPENAI_BASE_URL``, then
                ``https://api.openai.com/v1``. A trailing ``/`` is stripped.

        Raises:
            LLMError: if no API key is available.
        """
        self._api_key = api_key or os.getenv("OPENAI_API_KEY", "")
        if not self._api_key:
            raise LLMError("OPENAI_API_KEY 未配置", provider="openai")
        # Precedence: explicit argument > OPENAI_MODEL > DEFAULT_LLM_MODEL > default.
        # `or` chains (instead of getenv defaults) so a variable that is set but
        # empty also falls through rather than producing an empty model name.
        self._model = (
            model
            or os.getenv("OPENAI_MODEL")
            or os.getenv("DEFAULT_LLM_MODEL")
            or _DEFAULT_MODEL
        )
        self._base_url = (
            base_url
            or os.getenv("OPENAI_BASE_URL")
            or "https://api.openai.com/v1"
        ).rstrip("/")
        self._endpoint = f"{self._base_url}/chat/completions"
        # Unknown models fall back to a 128k-token context window.
        self._max_context = _OPENAI_MODELS.get(self._model, 128_000)
        self._client = httpx.AsyncClient(
            # Generous read timeout: completions can take a long time to produce.
            timeout=httpx.Timeout(connect=10.0, read=120.0, write=10.0, pool=10.0),
            headers={
                "Authorization": f"Bearer {self._api_key}",
                "Content-Type": "application/json",
            },
        )

    # ---- LLMProvider interface ----

    @property
    def provider_name(self) -> str:
        """Provider identifier (also used as the metrics label)."""
        return "openai"

    @property
    def model_name(self) -> str:
        """Model name sent with every request."""
        return self._model

    @property
    def max_context_length(self) -> int:
        """Context window in tokens for the configured model."""
        return self._max_context

    async def chat(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 4096,
        stop: list[str] | None = None,
    ) -> LLMResponse:
        """Run one non-streaming chat completion.

        Args:
            messages: Chat messages, sent unchanged as the ``messages`` field.
            temperature: Sampling temperature.
            max_tokens: Upper bound on completion tokens.
            stop: Optional stop sequences; omitted from the request when empty.

        Returns:
            LLMResponse holding the first choice's text ("" when the API sent no
            text content), the model reported by the API (or the configured one),
            and token usage with missing counts reported as 0.

        Raises:
            LLMError: on HTTP/network failure after retries or a malformed body.
        """
        payload: dict = {
            "model": self._model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if stop:
            payload["stop"] = stop
        metrics = get_llm_metrics(self.provider_name, self._model)
        start_time = time.perf_counter()
        try:
            data = await self._request_with_retry(payload, stream=False)
            content, usage = self._parse_completion(data)
        except Exception:
            metrics.record_request(
                status="error",
                duration=time.perf_counter() - start_time,
            )
            raise
        # Recorded outside the try so a failure here is never also counted as an error.
        metrics.record_request(
            status="success",
            duration=time.perf_counter() - start_time,
            prompt_tokens=usage.get("prompt_tokens"),
            completion_tokens=usage.get("completion_tokens"),
        )
        prompt_tokens = usage.get("prompt_tokens") or 0
        completion_tokens = usage.get("completion_tokens") or 0
        return LLMResponse(
            content=content,
            model=data.get("model") or self._model,
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                # Derive the total when the provider omits it.
                "total_tokens": usage.get("total_tokens") or prompt_tokens + completion_tokens,
            },
        )

    async def chat_stream(
        self,
        messages: list[dict],
        temperature: float = 0.7,
        max_tokens: int = 4096,
    ) -> AsyncIterator[str]:
        """Stream a chat completion, yielding text deltas as they arrive.

        Records request metrics like ``chat``. The recorded duration spans the
        whole iteration, including time the consumer spends between chunks. If
        the consumer stops iterating early, nothing is recorded.

        Raises:
            LLMError: on HTTP/network failure (see ``_stream_request``).
        """
        payload = {
            "model": self._model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": True,
        }
        metrics = get_llm_metrics(self.provider_name, self._model)
        start_time = time.perf_counter()
        try:
            async for chunk in self._stream_request(payload):
                yield chunk
        except Exception:
            metrics.record_request(
                status="error",
                duration=time.perf_counter() - start_time,
            )
            raise
        metrics.record_request(
            status="success",
            duration=time.perf_counter() - start_time,
        )

    # ---- internal helpers ----

    @staticmethod
    def _http_error(status_code: int, body: str) -> LLMError:
        """Build the LLMError for a non-200 response (body truncated to 300 chars)."""
        return LLMError(
            f"HTTP {status_code}: {body[:300]}",
            provider="openai",
            status_code=status_code,
        )

    @staticmethod
    def _retry_delay(retry_after: str | None, attempt: int) -> float:
        """Seconds to wait before the next attempt.

        Honors a numeric ``Retry-After`` header (delay-seconds). Falls back to
        exponential backoff (2**attempt) when the header is absent, uses the
        HTTP-date form, or is negative/NaN/infinite.
        """
        backoff = float(2**attempt)
        if not retry_after:
            return backoff
        try:
            seconds = float(retry_after)
        except ValueError:
            return backoff
        # The chained comparison is False for NaN, negatives and +inf.
        return seconds if 0 <= seconds < float("inf") else backoff

    @staticmethod
    def _parse_completion(data: dict) -> tuple[str, dict]:
        """Extract (content, usage) from a non-streaming completion body.

        Raises:
            LLMError: if the body lacks ``choices[0].message``.
        """
        try:
            content = data["choices"][0]["message"].get("content")
        except (KeyError, IndexError, TypeError, AttributeError) as exc:
            raise LLMError(
                f"Unexpected completion format: {str(data)[:300]}",
                provider="openai",
            ) from exc
        # Some providers send "usage": null or omit it entirely.
        usage = data.get("usage") or {}
        return content or "", usage

    @staticmethod
    def _sse_data(line: str) -> str | None:
        """Return the value of an SSE ``data:`` line, or None for any other line.

        Per the SSE format, one space after the colon is optional and is removed.
        """
        if not line.startswith("data:"):
            return None
        value = line[len("data:"):]
        return value[1:] if value.startswith(" ") else value

    @staticmethod
    def _delta_content(data_str: str) -> str | None:
        """Return ``choices[0].delta.content`` from one streamed JSON chunk, if any.

        Chunks that are not valid JSON objects, or that have no choices or
        delta, yield None.
        """
        try:
            chunk = json.loads(data_str)
        except json.JSONDecodeError:
            return None
        if not isinstance(chunk, dict):
            return None
        choices = chunk.get("choices") or []
        if not choices:
            return None
        delta = choices[0].get("delta") or {}
        return delta.get("content")

    async def _request_with_retry(self, payload: dict, *, stream: bool = False) -> dict:
        """POST ``payload`` and return the decoded JSON body, retrying transient failures.

        Makes at most ``_MAX_RETRIES`` attempts. Each attempt first passes
        through the global rate limiter.

        Network errors and statuses in ``_RETRYABLE_STATUS`` are retried. The
        wait is ``Retry-After`` seconds when the server sends a numeric value,
        otherwise 2**attempt seconds (1s, 2s, ...). There is no sleep after the
        final attempt. ``stream`` is accepted for call-site compatibility and is
        unused.

        Raises:
            LLMError: on a non-retryable status, an undecodable 200 body, or
                after the last attempt fails (carrying that attempt's error).
        """
        last_error: LLMError | None = None
        for attempt in range(_MAX_RETRIES):
            # Acquire per attempt (not once per call) so retries - notably after
            # a 429 - are throttled like any other request.
            await get_rate_limiter().acquire()
            try:
                response = await self._client.post(self._endpoint, json=payload)
            except httpx.TransportError as exc:
                last_error = LLMError(f"网络错误: {exc}", provider="openai")
                delay = float(2**attempt)
            else:
                status = response.status_code
                if status == 200:
                    try:
                        return response.json()
                    except ValueError as exc:  # JSONDecodeError/UnicodeDecodeError
                        raise LLMError(
                            f"Invalid JSON in HTTP 200 response: {response.text[:300]}",
                            provider="openai",
                            status_code=status,
                        ) from exc
                last_error = self._http_error(status, response.text)
                if status not in _RETRYABLE_STATUS:
                    # Non-retryable (e.g. 400, 401, 403): fail immediately.
                    raise last_error
                delay = self._retry_delay(response.headers.get("retry-after"), attempt)
            if attempt + 1 < _MAX_RETRIES:
                await asyncio.sleep(delay)
        raise last_error or LLMError("超过最大重试次数", provider="openai")

    async def _stream_request(self, payload: dict) -> AsyncIterator[str]:
        """POST ``payload`` as an SSE stream and yield text deltas.

        Uses the same retry policy as ``_request_with_retry``. A connection
        failure that happens after any text was already yielded is NOT retried:
        replaying the request would deliver duplicate output to the consumer.
        The stream ends at ``data: [DONE]`` or when the server closes the body.

        Raises:
            LLMError: on a non-retryable status, a mid-stream network failure,
                or after the last attempt fails.
        """
        last_error: LLMError | None = None
        for attempt in range(_MAX_RETRIES):
            # Acquire per attempt so retries also respect the global limit.
            await get_rate_limiter().acquire()
            yielded_any = False
            try:
                async with self._client.stream(
                    "POST", self._endpoint, json=payload
                ) as response:
                    status = response.status_code
                    if status != 200:
                        raw = await response.aread()
                        # errors="replace": an undecodable error body must not
                        # mask the HTTP error with a UnicodeDecodeError.
                        error = self._http_error(status, raw.decode("utf-8", errors="replace"))
                        if status not in _RETRYABLE_STATUS:
                            raise error
                        last_error = error
                        delay = self._retry_delay(response.headers.get("retry-after"), attempt)
                    else:
                        async for line in response.aiter_lines():
                            data_str = self._sse_data(line)
                            if data_str is None:
                                continue
                            if data_str.strip() == "[DONE]":
                                return
                            content = self._delta_content(data_str)
                            if content:
                                yielded_any = True
                                yield content
                        return  # server closed the stream without [DONE]
            except httpx.TransportError as exc:
                if yielded_any:
                    raise LLMError(f"网络错误: {exc}", provider="openai") from exc
                last_error = LLMError(f"网络错误: {exc}", provider="openai")
                delay = float(2**attempt)
            if attempt + 1 < _MAX_RETRIES:
                await asyncio.sleep(delay)
        raise last_error or LLMError("超过最大重试次数", provider="openai")

    async def close(self) -> None:
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def __aenter__(self) -> "OpenAIProvider":
        return self

    async def __aexit__(self, *exc) -> None:
        await self.close()