fischer-agentkit/src/agentkit/llm/remote_provider.py

252 lines
10 KiB
Python

"""Remote LLM Provider — forwards requests to a server-side LLM gateway."""
import json
import logging
from collections.abc import AsyncIterator, Awaitable, Callable
import httpx
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
from agentkit.llm.protocol import (
LLMProvider,
LLMRequest,
LLMResponse,
StreamChunk,
TokenUsage,
ToolCall,
)
# Module-level logger; the provider uses it to report token-refresh retries,
# refresh-callback failures, and SSE lines that cannot be parsed.
logger = logging.getLogger(__name__)
class RemoteLLMProvider(LLMProvider):
    """LLM provider that forwards requests to a server-side LLM gateway.

    This provider does NOT store API keys locally — all LLM calls go through
    the server's LLM gateway which manages keys, usage tracking, and cost
    control.

    On 401 responses, the provider can optionally invoke a ``refresh_callback``
    to obtain a fresh JWT, then retry the request once. This mirrors the
    browser-side refresh flow in ``auth.ts``.
    """

    def __init__(
        self,
        server_url: str,
        auth_token_provider: Callable[[], str],
        timeout: float = 120.0,
        refresh_callback: Callable[[], Awaitable[bool]] | None = None,
    ):
        """Initialize the remote provider.

        Args:
            server_url: Base URL of the server (e.g., "https://api.example.com").
                A trailing slash is stripped.
            auth_token_provider: Callable that returns the current JWT token.
                It is invoked for every request attempt, so a refreshed token
                is picked up on retry.
            timeout: Request timeout in seconds, applied to the HTTP client.
            refresh_callback: Optional async callable that attempts a token
                refresh and returns ``True`` on success. When provided, a 401
                response triggers one refresh + retry cycle.
        """
        self._server_url = server_url.rstrip("/")
        self._auth_token_provider = auth_token_provider
        self._timeout = timeout
        self._refresh_callback = refresh_callback
        self._client = httpx.AsyncClient(timeout=timeout)

    async def close(self) -> None:
        """Close the HTTP client connection pool."""
        await self._client.aclose()

    def _headers(self) -> dict[str, str]:
        """Build request headers with JWT authentication."""
        token = self._auth_token_provider()
        return {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }

    def _build_payload(self, request: LLMRequest) -> dict[str, object]:
        """Convert an LLMRequest to the server API payload."""
        return {
            "messages": request.messages,
            "model": request.model,
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "tools": request.tools,
            "tool_choice": request.tool_choice,
            "timeout": request.timeout,
        }

    def _extract_error_detail(self, resp: httpx.Response) -> str:
        """Extract a human-readable error detail from an error response.

        Prefers the ``detail`` key, then ``error``, of a JSON object body.
        Falls back to the stringified JSON value, or to the raw response text
        when the body is not valid JSON. The response body must already have
        been read.
        """
        try:
            body = resp.json()
        except ValueError:
            # JSON and Unicode decode failures are both ValueError subclasses.
            return resp.text
        if isinstance(body, dict):
            for key in ("detail", "error"):
                if key in body:
                    return str(body[key])
        return str(body)

    @staticmethod
    def _build_usage(raw: dict[str, object]) -> TokenUsage:
        """Build a TokenUsage from a usage mapping; missing counts become 0."""
        return TokenUsage(
            prompt_tokens=raw.get("prompt_tokens", 0),
            completion_tokens=raw.get("completion_tokens", 0),
        )

    @staticmethod
    def _parse_tool_calls(raw: list[dict[str, object]] | None) -> list[ToolCall]:
        """Convert the gateway's ``tool_calls`` list (possibly null) to ToolCall objects."""
        return [
            ToolCall(
                id=tc.get("id", ""),
                name=tc.get("name", ""),
                arguments=tc.get("arguments") or {},
            )
            for tc in raw or []
        ]

    def _parse_response(self, data: dict[str, object], request: LLMRequest) -> LLMResponse:
        """Parse server response JSON into an LLMResponse.

        Missing or null ``usage`` yields zero token counts. Missing or null
        ``content`` / ``model`` fall back to ``""`` and ``request.model``.
        """
        return LLMResponse(
            # ``or`` (not a .get default) so an explicit JSON null also falls
            # back, consistent with the usage / tool_calls handling.
            content=data.get("content") or "",
            model=data.get("model") or request.model,
            usage=self._build_usage(data.get("usage") or {}),
            tool_calls=self._parse_tool_calls(data.get("tool_calls")),
            latency_ms=data.get("latency_ms", 0.0),
        )

    def _parse_chunk(self, data: dict[str, object], request: LLMRequest) -> StreamChunk:
        """Parse a single SSE data payload into a StreamChunk.

        ``usage`` stays ``None`` unless the payload carries a non-empty usage
        object. Missing or null ``content`` / ``model`` fall back to ``""``
        and ``request.model``.
        """
        usage_data = data.get("usage")
        usage: TokenUsage | None = self._build_usage(usage_data) if usage_data else None
        return StreamChunk(
            content=data.get("content") or "",
            model=data.get("model") or request.model,
            tool_calls=self._parse_tool_calls(data.get("tool_calls")),
            usage=usage,
            is_final=data.get("is_final", False),
        )

    def _raise_for_status(self, resp: httpx.Response, request: LLMRequest) -> None:
        """Raise the exception matching a non-200 gateway status.

        Returns normally for status 200. For statuses other than 401 and 404
        the response body is inspected, so it must already have been read
        (streaming callers need ``await resp.aread()`` first).

        Raises:
            ConnectionError: On 401.
            ModelNotFoundError: On 404.
            LLMProviderError: On 502 and any other non-200 status.
        """
        status = resp.status_code
        if status == 200:
            return
        if status == 401:
            raise ConnectionError("Authentication failed")
        if status == 404:
            raise ModelNotFoundError(request.model)
        if status == 502:
            detail = self._extract_error_detail(resp)
            raise LLMProviderError("remote", f"Server LLM gateway error: {detail}")
        raise LLMProviderError("remote", f"Unexpected status {status}: {resp.text}")

    async def _try_refresh(self) -> bool:
        """Attempt a token refresh via the configured callback.

        Returns ``True`` only if a callback is configured and it reported
        success. Returns ``False`` when no callback is configured (a 401 is
        then not recoverable), when the callback returns a falsy value, or
        when it raises; the exception is logged, not propagated.
        """
        if self._refresh_callback is None:
            return False
        try:
            return bool(await self._refresh_callback())
        except Exception as exc:
            # Deliberate best-effort: a failing refresh must surface as the
            # original 401, not as an unrelated callback error.
            logger.warning("Token refresh callback failed: %s", exc)
            return False

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Send a non-streaming chat request to the server gateway.

        Raises:
            LLMProviderError: On timeout, transport error, a 502 or other
                unexpected status, or a 200 response whose body is not a
                JSON object.
            ConnectionError: On a 401 that could not be recovered by a token
                refresh.
            ModelNotFoundError: On 404.
        """
        url = f"{self._server_url}/api/v1/llm/chat"
        payload = self._build_payload(request)

        for attempt in range(2):  # original + 1 retry after refresh
            # Rebuilt per attempt so the retry carries the refreshed token.
            headers = self._headers()
            try:
                resp = await self._client.post(url, json=payload, headers=headers)
            except httpx.TimeoutException as e:
                raise LLMProviderError("remote", f"Request timeout after {self._timeout}s") from e
            except httpx.HTTPError as e:
                raise LLMProviderError("remote", str(e)) from e

            if resp.status_code == 401 and attempt == 0 and await self._try_refresh():
                logger.info("LLM gateway returned 401; refreshed token, retrying once.")
                continue

            self._raise_for_status(resp, request)

            try:
                data = resp.json()
            except ValueError as e:
                raise LLMProviderError(
                    "remote", f"Invalid JSON in gateway response: {e}"
                ) from e
            if not isinstance(data, dict):
                raise LLMProviderError(
                    "remote",
                    f"Unexpected gateway response payload type: {type(data).__name__}",
                )
            return self._parse_response(data, request)

        # Should be unreachable — loop exits via return or raise above.
        raise ConnectionError("Authentication failed")

    async def chat_stream(self, request: LLMRequest) -> AsyncIterator[StreamChunk]:
        """Send a streaming chat request to the server gateway.

        Parses SSE response (data: {json}\\n\\n) into StreamChunk objects.
        Terminates on ``data: [DONE]``. Lines that are not ``data`` fields,
        that fail to parse as JSON, or whose JSON is not an object are
        skipped (the latter two with a warning).

        Raises:
            LLMProviderError: On timeout, transport error, a 502 or other
                unexpected status, or an error payload inside the stream.
            ConnectionError: On a 401 that could not be recovered by a token
                refresh.
            ModelNotFoundError: On 404.
        """
        url = f"{self._server_url}/api/v1/llm/chat/stream"
        payload = self._build_payload(request)

        try:
            for attempt in range(2):  # original + 1 retry after refresh
                headers = self._headers()
                async with self._client.stream(
                    "POST", url, json=payload, headers=headers
                ) as response:
                    status = response.status_code
                    if status == 401 and attempt == 0 and await self._try_refresh():
                        logger.info(
                            "LLM gateway stream returned 401; refreshed token, retrying once."
                        )
                        # Leaving the ``async with`` closes this response
                        # before the retry opens a new one.
                        continue

                    if status not in (200, 401, 404):
                        # The error body is needed for the message; a streamed
                        # response must be read explicitly first.
                        await response.aread()
                    self._raise_for_status(response, request)

                    async for raw_line in response.aiter_lines():
                        line = raw_line.strip()
                        # SSE allows "data:" with or without a following space.
                        if not line.startswith("data:"):
                            continue
                        data_str = line[len("data:"):].strip()
                        if not data_str:
                            continue
                        if data_str == "[DONE]":
                            break
                        try:
                            data = json.loads(data_str)
                        except json.JSONDecodeError:
                            logger.warning("Failed to parse SSE line: %s", data_str)
                            continue
                        if not isinstance(data, dict):
                            logger.warning("Ignoring non-object SSE payload: %s", data_str)
                            continue
                        if "error" in data:
                            raise LLMProviderError(
                                "remote",
                                f"Stream error: {data.get('detail', data['error'])}",
                            )
                        yield self._parse_chunk(data, request)

                    # Stream completed successfully — no retry needed.
                    return
        except httpx.TimeoutException as e:
            raise LLMProviderError("remote", f"Request timeout after {self._timeout}s") from e
        except httpx.HTTPError as e:
            raise LLMProviderError("remote", str(e)) from e