291 lines
11 KiB
Python
291 lines
11 KiB
Python
"""OpenAI Compatible Provider - 支持 OpenAI/DeepSeek/Anthropic 等兼容 API"""
|
||
|
||
import json
|
||
import logging
|
||
import time
|
||
|
||
import httpx
|
||
|
||
from agentkit.core.exceptions import LLMProviderError
|
||
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, StreamChunk, TokenUsage, ToolCall
|
||
from agentkit.llm.retry import (
|
||
CircuitBreaker,
|
||
CircuitBreakerConfig,
|
||
RetryConfig,
|
||
RetryPolicy,
|
||
)
|
||
|
||
# Module-level logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(__name__)
|
||
|
||
|
||
class _StreamContext:
|
||
"""Wraps an httpx streaming response context manager for use with retry/circuit breaker.
|
||
|
||
The ``__aenter__`` returns the httpx response so callers can use
|
||
``async with ctx as response:`` naturally.
|
||
"""
|
||
|
||
def __init__(self, response_ctx, response):
|
||
self._response_ctx = response_ctx
|
||
self._response = response
|
||
|
||
async def __aenter__(self):
|
||
return self._response
|
||
|
||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||
return await self._response_ctx.__aexit__(exc_type, exc_val, exc_tb)
|
||
|
||
|
||
class OpenAICompatibleProvider(LLMProvider):
    """LLM provider for APIs that speak the OpenAI ``/chat/completions`` protocol.

    Point ``base_url`` at OpenAI or at a compatible endpoint. Optional retry
    and circuit-breaker policies wrap the non-streaming request and the
    connection phase of streaming requests.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.openai.com/v1",
        default_model: str = "gpt-4o-mini",
        retry_config: RetryConfig | None = None,
        circuit_breaker_config: CircuitBreakerConfig | None = None,
        timeout: float | httpx.Timeout = 60.0,
    ):
        """Create the provider and its shared HTTP client.

        Args:
            api_key: Sent as a ``Bearer`` token on every request.
            base_url: API root; a trailing ``/`` is stripped before
                ``/chat/completions`` is appended.
            default_model: Stored as ``self._default_model``. NOTE(review):
                not read inside this class — presumably used elsewhere.
            retry_config: When given, requests run under a ``RetryPolicy``.
            circuit_breaker_config: When given, requests run under a
                ``CircuitBreaker`` labelled ``"openai"``.
            timeout: Passed to ``httpx.AsyncClient`` (seconds or an
                ``httpx.Timeout``). Defaults to 60 seconds.
        """
        self._api_key = api_key
        self._base_url = base_url.rstrip("/")
        self._default_model = default_model
        self._client = httpx.AsyncClient(timeout=timeout)
        self._retry_policy = RetryPolicy(retry_config) if retry_config else None
        self._circuit_breaker = (
            CircuitBreaker(circuit_breaker_config, provider="openai")
            if circuit_breaker_config
            else None
        )

    async def close(self) -> None:
        """Close the HTTP client's connection pool."""
        await self._client.aclose()

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    async def _guarded(self, func, request: LLMRequest):
        """Await ``func(request)`` under the configured circuit breaker and/or retry policy.

        When both are configured the circuit breaker is outermost, so one
        breaker call covers all retry attempts.
        """
        if self._circuit_breaker is not None and self._retry_policy is not None:
            return await self._circuit_breaker.execute(
                self._retry_policy.execute, func, request
            )
        if self._retry_policy is not None:
            return await self._retry_policy.execute(func, request)
        if self._circuit_breaker is not None:
            return await self._circuit_breaker.execute(func, request)
        return await func(request)

    def _build_headers(self) -> dict[str, str]:
        """Return the auth and content-type headers sent with every request."""
        return {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _build_payload(request: LLMRequest, *, stream: bool = False) -> dict:
        """Build the JSON body for ``/chat/completions``.

        ``tools`` / ``tool_choice`` are only included when the request has tools.
        """
        payload: dict = {
            "model": request.model,
            "messages": request.messages,
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
        }
        if stream:
            payload["stream"] = True
        if request.tools:
            payload["tools"] = request.tools
            payload["tool_choice"] = request.tool_choice
        return payload

    @staticmethod
    def _extract_error_message(response: httpx.Response, default: str) -> str:
        """Best-effort extraction of a human-readable error from an error response.

        Understands both ``{"error": {"message": ...}}`` and ``{"error": "..."}``.
        Only the error message is surfaced, never the full response body, so
        that secrets echoed back by a server are not leaked into exceptions.
        Falls back to ``HTTP <status>`` when the body is not JSON, and to
        ``default`` when it is JSON but carries no usable message.
        """
        try:
            body = response.json()
        except Exception:
            # Deliberately broad: parsing the error must never mask the HTTP failure.
            return f"HTTP {response.status_code}"
        error = body.get("error") if isinstance(body, dict) else None
        if isinstance(error, dict):
            message = error.get("message")
            return str(message) if message else default
        if isinstance(error, str) and error:
            return error
        return default

    @staticmethod
    def _parse_tool_arguments(raw):
        """Decode tool-call ``arguments`` into a dict.

        Non-string values are assumed to be already decoded and are returned
        as-is (``None`` becomes ``{}``). An empty string becomes ``{}``, and
        invalid JSON is preserved as ``{"raw": <original string>}`` instead of
        raising.
        """
        if raw is None:
            return {}
        if not isinstance(raw, str):
            return raw
        if not raw:
            return {}
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return {"raw": raw}

    @staticmethod
    def _sse_data(line: str) -> str | None:
        """Return the payload of an SSE ``data:`` line, or ``None`` for any other line.

        Accepts both ``data: {...}`` and ``data:{...}``; the SSE format makes
        the space after the colon optional.
        """
        stripped = line.strip()
        if not stripped.startswith("data:"):
            return None
        return stripped[len("data:"):].strip()

    # ------------------------------------------------------------------
    # Non-streaming
    # ------------------------------------------------------------------

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Send a chat request, under retry / circuit breaker when configured."""
        return await self._guarded(self._chat_impl, request)

    async def _chat_impl(self, request: LLMRequest) -> LLMResponse:
        """Send one chat request and parse the response.

        Raises:
            LLMProviderError: on transport errors, a non-200 status, or a
                response body that is not a well-formed chat completion.
        """
        url = f"{self._base_url}/chat/completions"
        payload = self._build_payload(request)

        logger.debug(
            "Chat request to %s: model=%s, messages=%d, tools=%d",
            url,
            request.model,
            len(request.messages),
            len(request.tools or []),
        )

        start = time.monotonic()
        try:
            resp = await self._client.post(url, json=payload, headers=self._build_headers())
        except httpx.HTTPError as e:
            raise LLMProviderError("openai", str(e)) from e
        latency_ms = (time.monotonic() - start) * 1000

        if resp.status_code != 200:
            error_msg = self._extract_error_message(resp, "Request failed")
            logger.error("Chat request failed: HTTP %s, error: %s", resp.status_code, error_msg)
            # Do not put the full response body in the error, to avoid leaking the API key.
            raise LLMProviderError("openai", f"HTTP {resp.status_code}: {error_msg}")

        try:
            data = resp.json()
            message = data["choices"][0]["message"]
            if not isinstance(message, dict):
                raise TypeError("choices[0].message is not an object")
        except (ValueError, KeyError, IndexError, TypeError) as e:
            raise LLMProviderError(
                "openai", f"Malformed response: {type(e).__name__}: {e}"
            ) from e

        # Some compatible servers send "usage": null.
        usage_data = data.get("usage") or {}
        usage = TokenUsage(
            prompt_tokens=usage_data.get("prompt_tokens", 0),
            completion_tokens=usage_data.get("completion_tokens", 0),
        )

        tool_calls: list[ToolCall] = []
        for tc in message.get("tool_calls") or []:
            func = tc["function"]
            tool_calls.append(
                ToolCall(
                    id=tc["id"],
                    name=func["name"],
                    arguments=self._parse_tool_arguments(func.get("arguments")),
                )
            )

        return LLMResponse(
            content=message.get("content") or "",
            model=data.get("model", request.model),
            usage=usage,
            tool_calls=tool_calls,
            latency_ms=latency_ms,
        )

    # ------------------------------------------------------------------
    # Streaming
    # ------------------------------------------------------------------

    async def chat_stream(self, request: LLMRequest):
        """Stream a chat response over SSE, yielding ``StreamChunk`` objects.

        Retry / circuit breaker (when configured) only protect opening the
        connection. Once the stream is open it is iterated without retry, and
        the underlying connection is released when iteration ends or the
        generator is closed.
        """
        ctx = await self._guarded(self._open_stream, request)
        async with ctx as response:
            async for chunk in self._iterate_stream(response, request):
                yield chunk

    async def _open_stream(self, request: LLMRequest) -> "_StreamContext":
        """Open the streaming HTTP connection and return a ``_StreamContext``.

        On a non-200 status the body is read (best-effort, for the error
        message) and the httpx stream context is always exited before raising.

        Raises:
            LLMProviderError: on transport errors while connecting, or on a
                non-200 HTTP status.
        """
        url = f"{self._base_url}/chat/completions"
        payload = self._build_payload(request, stream=True)
        if request.tools:
            tool_names = [t.get("function", {}).get("name", "?") for t in request.tools]
            logger.info(
                "OpenAIProvider stream: model=%s, tools=%d %s",
                request.model,
                len(request.tools),
                tool_names,
            )
        else:
            logger.info("OpenAIProvider stream: model=%s, NO tools", request.model)

        response_ctx = self._client.stream("POST", url, json=payload, headers=self._build_headers())
        try:
            response = await response_ctx.__aenter__()
        except httpx.HTTPError as e:
            # Same contract as _chat_impl: transport failures surface as LLMProviderError.
            raise LLMProviderError("openai", str(e)) from e

        if response.status_code != 200:
            try:
                await response.aread()
            except httpx.HTTPError:
                # The body only feeds the error message; the status is the real failure.
                logger.debug("Could not read body of failed stream request", exc_info=True)
            finally:
                # Always release the connection, even if reading the body failed.
                await response_ctx.__aexit__(None, None, None)
            error_msg = self._extract_error_message(response, f"HTTP {response.status_code}")
            logger.error(
                "Stream request failed: HTTP %s, error: %s", response.status_code, error_msg
            )
            raise LLMProviderError("openai", f"HTTP {response.status_code}: {error_msg}")

        return _StreamContext(response_ctx, response)

    async def _iterate_stream(self, response, request: LLMRequest):
        """Iterate over an already-open SSE stream and yield ``StreamChunk`` objects.

        Non-empty content deltas are yielded as they arrive. Tool-call
        fragments are accumulated per ``index``. Usage is taken from the most
        recent chunk that reports it, whether or not that chunk also has
        ``choices``. Tool calls and usage are then emitted together in a
        single trailing chunk with ``is_final=True``, so the final chunk really
        is the last one. No trailing chunk is emitted when there is neither.
        """
        # index -> {"id", "name", "arguments_str"} while tool-call deltas stream in
        accumulated_tool_calls: dict[int, dict] = {}
        usage: TokenUsage | None = None
        model = request.model

        async for raw_line in response.aiter_lines():
            data_str = self._sse_data(raw_line)
            if data_str is None:
                continue
            if data_str == "[DONE]":
                break

            try:
                data = json.loads(data_str)
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                continue

            model = data.get("model") or request.model

            usage_data = data.get("usage")
            if usage_data:
                usage = TokenUsage(
                    prompt_tokens=usage_data.get("prompt_tokens", 0),
                    completion_tokens=usage_data.get("completion_tokens", 0),
                )

            choices = data.get("choices")
            if not choices:
                continue

            delta = choices[0].get("delta") or {}

            for tc in delta.get("tool_calls") or []:
                entry = accumulated_tool_calls.setdefault(
                    tc.get("index") or 0,
                    {"id": "", "name": "", "arguments_str": ""},
                )
                if tc.get("id"):
                    entry["id"] = tc["id"]
                func = tc.get("function") or {}
                if func.get("name"):
                    entry["name"] = func["name"]
                if func.get("arguments"):
                    entry["arguments_str"] += func["arguments"]

            content = delta.get("content")
            if content:
                yield StreamChunk(content=content, model=model)

        final_fields: dict = {}
        if accumulated_tool_calls:
            final_fields["tool_calls"] = [
                ToolCall(
                    id=accumulated_tool_calls[idx]["id"],
                    name=accumulated_tool_calls[idx]["name"],
                    arguments=self._parse_tool_arguments(
                        accumulated_tool_calls[idx]["arguments_str"]
                    ),
                )
                for idx in sorted(accumulated_tool_calls)
            ]
        if usage is not None:
            final_fields["usage"] = usage
        if final_fields:
            yield StreamChunk(content="", model=model, is_final=True, **final_fields)
|