fischer-agentkit/src/agentkit/llm/providers/openai.py

291 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""OpenAI-compatible provider - supports OpenAI/DeepSeek/Anthropic and other compatible APIs.

Requests are sent to ``{base_url}/chat/completions`` in the OpenAI chat-completions format.
"""
import json
import logging
import time
import httpx
from agentkit.core.exceptions import LLMProviderError
from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, StreamChunk, TokenUsage, ToolCall
from agentkit.llm.retry import (
CircuitBreaker,
CircuitBreakerConfig,
RetryConfig,
RetryPolicy,
)
# Module-level logger; this module does not configure handlers or levels itself.
logger = logging.getLogger(__name__)
class _StreamContext:
    """Adapter exposing an already-entered httpx streaming context to ``async with``.

    Used so the retry / circuit-breaker wrappers can return an object that the
    caller enters with ``async with ctx as response:``.

    The wrapped context has already been entered by whoever built this object.
    ``__aenter__`` therefore only hands back the stored response. ``__aexit__``
    forwards to the wrapped context's ``__aexit__`` and returns its result.
    """

    def __init__(self, response_ctx, response):
        # Keep both: the context for cleanup, the response for the caller.
        self._response_ctx, self._response = response_ctx, response

    async def __aenter__(self):
        # Nothing to open here; the underlying stream is already live.
        live_response = self._response
        return live_response

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Delegate teardown, propagating the wrapped context's suppress/no-suppress result.
        exit_hook = self._response_ctx.__aexit__
        suppress = await exit_hook(exc_type, exc_val, exc_tb)
        return suppress
class OpenAICompatibleProvider(LLMProvider):
    """LLM provider for OpenAI-compatible ``/chat/completions`` endpoints.

    Any server speaking the OpenAI chat-completions wire format can be used by
    pointing ``base_url`` at it. Optional retry and circuit-breaker policies
    wrap each request. For streaming requests they only guard opening the
    connection, not the iteration over chunks.

    Transport failures, non-200 responses and malformed response bodies are
    all surfaced as ``LLMProviderError("openai", ...)``.
    """

    def __init__(
        self,
        api_key: str,
        base_url: str = "https://api.openai.com/v1",
        default_model: str = "gpt-4o-mini",
        retry_config: RetryConfig | None = None,
        circuit_breaker_config: CircuitBreakerConfig | None = None,
        timeout: float = 60.0,
    ):
        """Create the provider.

        Args:
            api_key: Bearer token sent in the ``Authorization`` header.
            base_url: API root; a trailing slash is stripped.
            default_model: Model used when ``request.model`` is empty or None.
            retry_config: When given, requests go through a ``RetryPolicy``.
            circuit_breaker_config: When given, requests go through a ``CircuitBreaker``.
            timeout: Timeout (seconds) passed to the underlying ``httpx.AsyncClient``.
        """
        self._api_key = api_key
        self._base_url = base_url.rstrip("/")
        self._default_model = default_model
        self._client = httpx.AsyncClient(timeout=timeout)
        self._retry_policy = RetryPolicy(retry_config) if retry_config else None
        self._circuit_breaker = (
            CircuitBreaker(circuit_breaker_config, provider="openai")
            if circuit_breaker_config
            else None
        )

    async def close(self) -> None:
        """Close the HTTP client and its connection pool."""
        await self._client.aclose()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Send a non-streaming chat request, applying retry / circuit breaker if configured."""
        return await self._call_with_policies(self._chat_impl, request)

    async def chat_stream(self, request: LLMRequest):
        """Stream a chat response over SSE, yielding ``StreamChunk`` objects.

        Retry and circuit breaker only protect opening the connection. Once the
        stream is open, chunks are iterated without retry. Transport errors
        raised mid-stream are re-raised as ``LLMProviderError`` for consistency
        with ``chat``.
        """
        ctx = await self._call_with_policies(self._open_stream, request)
        async with ctx as response:
            try:
                async for chunk in self._iterate_stream(response, request):
                    yield chunk
            except httpx.HTTPError as e:
                raise LLMProviderError("openai", str(e)) from e

    # ------------------------------------------------------------------
    # Request helpers
    # ------------------------------------------------------------------

    async def _call_with_policies(self, func, request: LLMRequest):
        """Run ``func(request)`` through whichever of circuit breaker / retry are configured.

        When both are configured, the circuit breaker invokes the retry policy,
        which in turn invokes ``func``.
        """
        if self._circuit_breaker and self._retry_policy:
            return await self._circuit_breaker.execute(
                self._retry_policy.execute, func, request
            )
        if self._retry_policy:
            return await self._retry_policy.execute(func, request)
        if self._circuit_breaker:
            return await self._circuit_breaker.execute(func, request)
        return await func(request)

    def _resolve_model(self, request: LLMRequest) -> str:
        """Return ``request.model``, falling back to the configured default model."""
        return request.model or self._default_model

    def _build_headers(self) -> dict[str, str]:
        """Return the auth + content-type headers sent with every request."""
        return {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
        }

    def _build_payload(self, request: LLMRequest, model: str, *, stream: bool) -> dict:
        """Build the chat-completions JSON body; tools are only sent when present."""
        payload: dict = {
            "model": model,
            "messages": request.messages,
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
        }
        if stream:
            payload["stream"] = True
        if request.tools:
            payload["tools"] = request.tools
            payload["tool_choice"] = request.tool_choice
        return payload

    @staticmethod
    def _extract_error_message(response, missing_default: str) -> str:
        """Pull a human-readable error message out of an error response.

        Returns ``"HTTP <status>"`` when the body is not a JSON object. If the
        object has no usable error message, returns ``missing_default``. Only
        the error text is returned, never the full body.
        """
        try:
            body = response.json()
        except Exception:  # best-effort: non-JSON body, unreadable stream, etc.
            return f"HTTP {response.status_code}"
        if not isinstance(body, dict):
            return f"HTTP {response.status_code}"
        error = body.get("error")
        if isinstance(error, str) and error:
            # Some compatible servers send {"error": "<message>"} instead of an object.
            return error
        if isinstance(error, dict):
            return error.get("message") or missing_default
        return missing_default

    @staticmethod
    def _parse_tool_arguments(raw):
        """Decode tool-call arguments.

        - JSON strings are decoded.
        - An empty string or None yields ``{}``.
        - Undecodable strings are preserved as ``{"raw": <string>}``.
        - Already-decoded values are returned unchanged.
        """
        if raw is None:
            return {}
        if not isinstance(raw, str):
            return raw
        if not raw:
            return {}
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return {"raw": raw}

    @staticmethod
    def _parse_usage(usage_data) -> TokenUsage:
        """Build ``TokenUsage`` from a (possibly missing/null) ``usage`` object."""
        usage_data = usage_data or {}
        return TokenUsage(
            prompt_tokens=usage_data.get("prompt_tokens") or 0,
            completion_tokens=usage_data.get("completion_tokens") or 0,
        )

    # ------------------------------------------------------------------
    # Non-streaming
    # ------------------------------------------------------------------

    async def _chat_impl(self, request: LLMRequest) -> LLMResponse:
        """Send one chat-completions request (no retry) and parse the result.

        Raises:
            LLMProviderError: on transport failure, a non-200 status, or a
                response body lacking the expected ``choices[0].message`` shape.
        """
        url = f"{self._base_url}/chat/completions"
        model = self._resolve_model(request)
        payload = self._build_payload(request, model, stream=False)
        logger.debug(
            "Chat request to %s: model=%s, messages=%d, tools=%d",
            url,
            model,
            len(request.messages),
            len(request.tools or []),
        )

        start = time.monotonic()
        try:
            resp = await self._client.post(url, json=payload, headers=self._build_headers())
        except httpx.HTTPError as e:
            raise LLMProviderError("openai", str(e)) from e
        latency_ms = (time.monotonic() - start) * 1000

        if resp.status_code != 200:
            error_msg = self._extract_error_message(resp, "Request failed")
            logger.error("Chat request failed: HTTP %s, error: %s", resp.status_code, error_msg)
            # Do not expose the full response body in the error message, to prevent API key leakage.
            raise LLMProviderError("openai", f"HTTP {resp.status_code}: {error_msg}")

        try:
            data = resp.json()
            message = data["choices"][0]["message"]
            tool_calls: list[ToolCall] = [
                ToolCall(
                    id=tc["id"],
                    name=tc["function"]["name"],
                    arguments=self._parse_tool_arguments(tc["function"].get("arguments")),
                )
                for tc in message.get("tool_calls") or []
            ]
        except (ValueError, LookupError, TypeError, AttributeError) as e:
            # Invalid JSON or an unexpected shape: surface as a provider error, not a raw KeyError.
            logger.error("Chat response malformed: %s: %s", type(e).__name__, e)
            raise LLMProviderError("openai", f"Malformed response: {type(e).__name__}: {e}") from e

        return LLMResponse(
            content=message.get("content") or "",
            model=data.get("model", model),
            usage=self._parse_usage(data.get("usage")),
            tool_calls=tool_calls,
            latency_ms=latency_ms,
        )

    # ------------------------------------------------------------------
    # Streaming
    # ------------------------------------------------------------------

    async def _open_stream(self, request: LLMRequest) -> "_StreamContext":
        """Open the streaming HTTP connection and return a ``_StreamContext``.

        The caller must enter/exit the returned context to release the
        connection. If opening fails or the status is not 200, the connection
        is released here before the error propagates.

        Raises:
            LLMProviderError: on transport failure or a non-200 status.
        """
        url = f"{self._base_url}/chat/completions"
        model = self._resolve_model(request)
        payload = self._build_payload(request, model, stream=True)
        if request.tools:
            tool_names = [(t.get("function") or {}).get("name", "?") for t in request.tools]
            logger.info(
                "OpenAIProvider stream: model=%s, tools=%d %s",
                model,
                len(request.tools),
                tool_names,
            )
        else:
            logger.info("OpenAIProvider stream: model=%s, NO tools", model)

        response_ctx = self._client.stream("POST", url, json=payload, headers=self._build_headers())
        try:
            response = await response_ctx.__aenter__()
        except httpx.HTTPError as e:
            raise LLMProviderError("openai", str(e)) from e

        if response.status_code != 200:
            try:
                # Streaming bodies must be read explicitly before they can be parsed.
                await response.aread()
            except httpx.HTTPError:
                logger.debug("Could not read error body of failed stream request")
            finally:
                # Always release the connection, even if reading the body failed.
                await response_ctx.__aexit__(None, None, None)
            error_msg = self._extract_error_message(response, f"HTTP {response.status_code}")
            logger.error("Stream request failed: HTTP %s, error: %s", response.status_code, error_msg)
            raise LLMProviderError("openai", f"HTTP {response.status_code}: {error_msg}")

        return _StreamContext(response_ctx, response)

    async def _iterate_stream(self, response, request: LLMRequest):
        """Parse an already-open SSE stream into ``StreamChunk`` objects.

        Yields:
            - one chunk per non-empty content delta;
            - an ``is_final`` usage chunk for each usage-only event (no choices);
            - if tool calls were seen, one final chunk with the reassembled calls.

        Tool-call fragments are merged per ``index``: ids/names are taken from
        whichever fragment carries them, and argument strings are concatenated.
        """
        stream_model = self._resolve_model(request)
        # index -> {"id": str, "name": str, "arguments_str": str}
        accumulated_tool_calls: dict[int, dict] = {}

        async for raw_line in response.aiter_lines():
            line = raw_line.strip()
            # SSE allows both "data: x" and "data:x"; other fields/comments are ignored.
            if not line.startswith("data:"):
                continue
            data_str = line[5:].strip()
            if data_str == "[DONE]":
                break
            try:
                data = json.loads(data_str)
            except json.JSONDecodeError:
                continue
            if not isinstance(data, dict):
                continue
            stream_model = data.get("model") or stream_model

            choices = data.get("choices") or []
            if not choices:
                # Usage-only event.
                usage_data = data.get("usage")
                if usage_data:
                    yield StreamChunk(
                        content="",
                        model=stream_model,
                        usage=self._parse_usage(usage_data),
                        is_final=True,
                    )
                continue

            delta = choices[0].get("delta") or {}

            for tc in delta.get("tool_calls") or []:
                slot = accumulated_tool_calls.setdefault(
                    tc.get("index", 0), {"id": "", "name": "", "arguments_str": ""}
                )
                if tc.get("id"):
                    slot["id"] = tc["id"]
                func = tc.get("function") or {}
                if func.get("name"):
                    slot["name"] = func["name"]
                if func.get("arguments"):
                    slot["arguments_str"] += func["arguments"]

            # Only yield content chunks (not empty / null deltas).
            content = delta.get("content")
            if content:
                yield StreamChunk(content=content, model=stream_model)

        if accumulated_tool_calls:
            tool_calls = [
                ToolCall(
                    id=slot["id"],
                    name=slot["name"],
                    arguments=self._parse_tool_arguments(slot["arguments_str"]),
                )
                for _, slot in sorted(accumulated_tool_calls.items())
            ]
            yield StreamChunk(
                content="",
                model=stream_model,
                tool_calls=tool_calls,
                is_final=True,
            )