252 lines
10 KiB
Python
252 lines
10 KiB
Python
"""Remote LLM Provider — forwards requests to a server-side LLM gateway."""
|
|
|
|
import json
|
|
import logging
|
|
from collections.abc import AsyncIterator, Awaitable, Callable
|
|
|
|
import httpx
|
|
|
|
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
|
from agentkit.llm.protocol import (
|
|
LLMProvider,
|
|
LLMRequest,
|
|
LLMResponse,
|
|
StreamChunk,
|
|
TokenUsage,
|
|
ToolCall,
|
|
)
|
|
|
|
# Module logger. RemoteLLMProvider below reports token-refresh retries,
# refresh-callback failures and unparseable SSE lines through it.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class RemoteLLMProvider(LLMProvider):
    """LLM Provider that forwards requests to a server-side LLM gateway.

    This provider does NOT store API keys locally — all LLM calls go through
    the server's LLM gateway which manages keys, usage tracking, and cost control.

    On 401 responses, the provider can optionally invoke a ``refresh_callback``
    to obtain a fresh JWT, then retry the request once. This mirrors the
    browser-side refresh flow in ``auth.ts``.

    The provider owns an ``httpx.AsyncClient``; call :meth:`close` when it is
    no longer needed to release the connection pool.
    """

    def __init__(
        self,
        server_url: str,
        auth_token_provider: Callable[[], str],
        timeout: float = 120.0,
        refresh_callback: Callable[[], Awaitable[bool]] | None = None,
    ):
        """Initialize the remote provider.

        Args:
            server_url: Base URL of the server (e.g., "https://api.example.com").
                Trailing slashes are stripped.
            auth_token_provider: Callable that returns the current JWT token.
                It is called once per request attempt.
            timeout: Request timeout in seconds for the local HTTP client.
            refresh_callback: Optional async callable that attempts a token
                refresh and returns ``True`` on success. When provided, a 401
                response triggers one refresh + retry cycle.
        """
        self._server_url = server_url.rstrip("/")
        self._auth_token_provider = auth_token_provider
        self._timeout = timeout
        self._refresh_callback = refresh_callback
        self._client = httpx.AsyncClient(timeout=timeout)

    async def close(self) -> None:
        """Close the HTTP client connection pool."""
        await self._client.aclose()

    # ------------------------------------------------------------------
    # Request building / response helpers
    # ------------------------------------------------------------------

    def _headers(self) -> dict[str, str]:
        """Build request headers with JWT authentication.

        Called once per request attempt. If ``refresh_callback`` updates the
        token that ``auth_token_provider`` returns, the retry uses the new
        token.
        """
        token = self._auth_token_provider()
        return {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }

    def _build_payload(self, request: LLMRequest) -> dict[str, object]:
        """Convert an LLMRequest to the gateway's JSON payload.

        ``request.timeout`` is forwarded to the server as-is. The local HTTP
        timeout is the ``timeout`` passed to ``__init__``.
        """
        return {
            "messages": request.messages,
            "model": request.model,
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "tools": request.tools,
            "tool_choice": request.tool_choice,
            "timeout": request.timeout,
        }

    def _extract_error_detail(self, resp: httpx.Response) -> str:
        """Extract a human-readable error detail from an error response.

        Lookup order:
        1. The ``detail`` key of a JSON-object body.
        2. The ``error`` key of a JSON-object body.
        3. The stringified JSON body.
        4. The raw text, when the body is not JSON.

        The response body must already have been read.
        """
        try:
            body = resp.json()
        except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
            return resp.text
        if isinstance(body, dict):
            for key in ("detail", "error"):
                if key in body:
                    return str(body[key])
        return str(body)

    async def _raise_for_status(self, resp: httpx.Response, request: LLMRequest) -> None:
        """Map a non-200 gateway response to the provider's exceptions.

        Shared by :meth:`chat` and :meth:`chat_stream`, so both report errors
        identically.

        Raises:
            ConnectionError: On 401.
            ModelNotFoundError: On 404.
            LLMProviderError: On 502 (with the extracted detail) or any other
                non-200 status (with the raw body text).
        """
        status = resp.status_code
        if status == 200:
            return
        if status == 401:
            raise ConnectionError("Authentication failed")
        if status == 404:
            raise ModelNotFoundError(request.model)
        # Streaming responses have not been read yet. Load the body so the
        # detail/text below is available. For an already-read response this
        # returns the cached content.
        await resp.aread()
        if status == 502:
            detail = self._extract_error_detail(resp)
            raise LLMProviderError("remote", f"Server LLM gateway error: {detail}")
        raise LLMProviderError("remote", f"Unexpected status {status}: {resp.text}")

    @staticmethod
    def _decode_json_object(resp: httpx.Response) -> dict[str, object]:
        """Decode a successful response body, requiring a JSON object.

        Raises:
            LLMProviderError: The body is not valid JSON or is not an object.
                Previously these cases escaped as raw ``ValueError`` /
                ``AttributeError``.
        """
        try:
            data = resp.json()
        except ValueError as exc:
            raise LLMProviderError(
                "remote", f"Invalid JSON from server LLM gateway: {exc}"
            ) from exc
        if not isinstance(data, dict):
            raise LLMProviderError(
                "remote",
                f"Unexpected response payload type: {type(data).__name__}",
            )
        return data

    @staticmethod
    def _parse_usage(usage_data: dict) -> TokenUsage:
        """Build a TokenUsage from a ``usage`` mapping, defaulting counts to 0."""
        return TokenUsage(
            prompt_tokens=usage_data.get("prompt_tokens", 0),
            completion_tokens=usage_data.get("completion_tokens", 0),
        )

    @staticmethod
    def _parse_tool_calls(raw_calls: list | None) -> list[ToolCall]:
        """Build ToolCall objects from the ``tool_calls`` list (``None`` → empty).

        Missing ids and names become ``""``. Missing or empty arguments
        become ``{}``.
        """
        return [
            ToolCall(
                id=tc.get("id", ""),
                name=tc.get("name", ""),
                arguments=tc.get("arguments") or {},
            )
            for tc in raw_calls or []
        ]

    def _parse_response(self, data: dict[str, object], request: LLMRequest) -> LLMResponse:
        """Parse server response JSON into an LLMResponse.

        A missing ``usage`` block yields zero token counts. A missing
        ``model`` falls back to ``request.model``.
        """
        return LLMResponse(
            content=data.get("content", ""),
            model=data.get("model", request.model),
            usage=self._parse_usage(data.get("usage") or {}),
            tool_calls=self._parse_tool_calls(data.get("tool_calls")),
            latency_ms=data.get("latency_ms", 0.0),
        )

    def _parse_chunk(self, data: dict[str, object], request: LLMRequest) -> StreamChunk:
        """Parse a single SSE data payload into a StreamChunk.

        Unlike :meth:`_parse_response`, ``usage`` stays ``None`` when the
        chunk carries no (or an empty) ``usage`` block.
        """
        usage_data = data.get("usage")
        return StreamChunk(
            content=data.get("content", ""),
            model=data.get("model", request.model),
            tool_calls=self._parse_tool_calls(data.get("tool_calls")),
            usage=self._parse_usage(usage_data) if usage_data else None,
            is_final=data.get("is_final", False),
        )

    @staticmethod
    def _sse_data(line: str) -> str | None:
        """Return the payload of an SSE ``data:`` line, or ``None`` to skip it.

        Per the SSE spec the space after ``data:`` is optional, so both
        ``data: {...}`` and ``data:{...}`` are accepted. Blank lines, other
        fields and empty payloads return ``None``.
        """
        stripped = line.strip()
        if not stripped.startswith("data:"):
            return None
        return stripped[len("data:"):].strip() or None

    @staticmethod
    def _decode_sse_payload(data_str: str) -> dict[str, object] | None:
        """Decode one SSE JSON payload into a dict.

        Returns ``None``, after logging a warning, when the payload is not a
        valid JSON object. A single malformed event therefore does not abort
        the stream.

        Raises:
            LLMProviderError: The payload is an error event (has an ``error``
                key). Its ``detail`` is reported when present.
        """
        try:
            data = json.loads(data_str)
        except json.JSONDecodeError:
            logger.warning("Failed to parse SSE line: %s", data_str)
            return None
        if not isinstance(data, dict):
            logger.warning("Ignoring non-object SSE payload: %s", data_str)
            return None
        if "error" in data:
            raise LLMProviderError(
                "remote",
                f"Stream error: {data.get('detail', data['error'])}",
            )
        return data

    # ------------------------------------------------------------------
    # Auth refresh
    # ------------------------------------------------------------------

    async def _try_refresh(self) -> bool:
        """Attempt a token refresh via the configured callback.

        Returns ``True`` only if a callback is configured and reports success.

        Returns ``False`` in three cases:
        - No callback is configured (a 401 is then not recoverable).
        - The callback returns a falsy value.
        - The callback raises. The exception is logged and swallowed so the
          caller surfaces the original 401.
        """
        if self._refresh_callback is None:
            return False
        try:
            return bool(await self._refresh_callback())
        except Exception as exc:  # best-effort: any refresh failure means "not refreshed"
            logger.warning("Token refresh callback failed: %s", exc)
            return False

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Send a non-streaming chat request to the server gateway.

        Raises:
            ConnectionError: Authentication failed (401), after at most one
                refresh + retry.
            ModelNotFoundError: The gateway returned 404 for ``request.model``.
            LLMProviderError: Timeout, transport error, gateway error (502),
                any other non-200 status, or a malformed response body.
        """
        url = f"{self._server_url}/api/v1/llm/chat"
        payload = self._build_payload(request)

        for attempt in range(2):  # original request + 1 retry after refresh
            headers = self._headers()
            try:
                resp = await self._client.post(url, json=payload, headers=headers)
            except httpx.TimeoutException as e:
                raise LLMProviderError("remote", f"Request timeout after {self._timeout}s") from e
            except httpx.HTTPError as e:
                raise LLMProviderError("remote", str(e)) from e

            if resp.status_code == 401 and attempt == 0:
                if not await self._try_refresh():
                    raise ConnectionError("Authentication failed")
                logger.info("LLM gateway returned 401; refreshed token, retrying once.")
                continue

            await self._raise_for_status(resp, request)
            return self._parse_response(self._decode_json_object(resp), request)

        # Unreachable: the second attempt never retries, so the loop always
        # exits via ``return`` or ``raise`` above.
        raise ConnectionError("Authentication failed")

    async def chat_stream(self, request: LLMRequest) -> AsyncIterator[StreamChunk]:
        """Send a streaming chat request to the server gateway.

        Parses SSE response (data: {json}\\n\\n) into StreamChunk objects.
        Terminates on ``data: [DONE]`` or when the server closes the stream.
        Non-``data:`` lines are ignored. Payloads that are not JSON objects
        are skipped with a warning.

        Raises:
            ConnectionError: Authentication failed (401), after at most one
                refresh + retry.
            ModelNotFoundError: The gateway returned 404 for ``request.model``.
            LLMProviderError: Timeout, transport error, any other non-200
                status, or an error payload inside the stream.
        """
        url = f"{self._server_url}/api/v1/llm/chat/stream"
        payload = self._build_payload(request)

        try:
            for attempt in range(2):  # original request + 1 retry after refresh
                headers = self._headers()
                async with self._client.stream(
                    "POST", url, json=payload, headers=headers
                ) as response:
                    if not (response.status_code == 401 and attempt == 0):
                        await self._raise_for_status(response, request)
                        async for line in response.aiter_lines():
                            data_str = self._sse_data(line)
                            if data_str is None:
                                continue
                            if data_str == "[DONE]":
                                return
                            data = self._decode_sse_payload(data_str)
                            if data is not None:
                                yield self._parse_chunk(data, request)
                        # Stream ended without [DONE]; treat it as complete.
                        return

                # The 401 response context has exited, so its connection is
                # released before the (possibly slow) refresh runs.
                if not await self._try_refresh():
                    raise ConnectionError("Authentication failed")
                logger.info(
                    "LLM gateway stream returned 401; refreshed token, retrying once."
                )
        except httpx.TimeoutException as e:
            raise LLMProviderError("remote", f"Request timeout after {self._timeout}s") from e
        except httpx.HTTPError as e:
            raise LLMProviderError("remote", str(e)) from e