"""Remote LLM Provider — forwards requests to a server-side LLM gateway."""

import json
import logging
from collections.abc import AsyncIterator, Awaitable, Callable

import httpx

from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
from agentkit.llm.protocol import (
    LLMProvider,
    LLMRequest,
    LLMResponse,
    StreamChunk,
    TokenUsage,
    ToolCall,
)

logger = logging.getLogger(__name__)


class RemoteLLMProvider(LLMProvider):
    """LLM Provider that forwards requests to a server-side LLM gateway.

    This provider does NOT store API keys locally — all LLM calls go through
    the server's LLM gateway which manages keys, usage tracking, and cost
    control.

    On 401 responses, the provider can optionally invoke a
    ``refresh_callback`` to obtain a fresh JWT, then retry the request once.
    This mirrors the browser-side refresh flow in ``auth.ts``.
    """

    def __init__(
        self,
        server_url: str,
        auth_token_provider: Callable[[], str],
        timeout: float = 120.0,
        refresh_callback: Callable[[], Awaitable[bool]] | None = None,
    ):
        """Initialize the remote provider.

        Args:
            server_url: Base URL of the server (e.g., "https://api.example.com").
                Trailing slashes are stripped.
            auth_token_provider: Callable that returns the current JWT token.
                It is invoked for every request attempt, so a refreshed token
                is picked up on retry.
            timeout: Request timeout in seconds.
            refresh_callback: Optional async callable that attempts a token
                refresh and returns ``True`` on success. When provided, a 401
                response triggers one refresh + retry cycle.
        """
        self._server_url = server_url.rstrip("/")
        self._auth_token_provider = auth_token_provider
        self._timeout = timeout
        self._refresh_callback = refresh_callback
        self._client = httpx.AsyncClient(timeout=timeout)

    async def close(self) -> None:
        """Close the HTTP client connection pool."""
        await self._client.aclose()

    def _headers(self) -> dict[str, str]:
        """Build request headers with JWT authentication."""
        token = self._auth_token_provider()
        return {
            "Authorization": f"Bearer {token}",
            "Content-Type": "application/json",
        }

    def _build_payload(self, request: LLMRequest) -> dict[str, object]:
        """Convert LLMRequest to server API payload."""
        return {
            "messages": request.messages,
            "model": request.model,
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
            "tools": request.tools,
            "tool_choice": request.tool_choice,
            "timeout": request.timeout,
        }

    def _extract_error_detail(self, resp: httpx.Response) -> str:
        """Extract a human-readable error detail from an error response.

        Prefers the ``detail`` key, then the ``error`` key, of a JSON object
        body. Falls back to the stringified JSON value, or to the raw response
        text when the body is not valid JSON.
        """
        try:
            body = resp.json()
        except Exception:
            # Deliberately broad: this is best-effort formatting of an error
            # that is about to be raised anyway.
            return resp.text
        if isinstance(body, dict):
            if "detail" in body:
                return str(body["detail"])
            if "error" in body:
                return str(body["error"])
        return str(body)

    @staticmethod
    def _build_usage(usage_data: dict[str, object]) -> TokenUsage:
        """Build a TokenUsage from a usage mapping; missing counts become 0."""
        return TokenUsage(
            prompt_tokens=usage_data.get("prompt_tokens", 0),
            completion_tokens=usage_data.get("completion_tokens", 0),
        )

    @staticmethod
    def _build_tool_calls(raw_calls: object) -> list[ToolCall]:
        """Build ToolCall objects from the payload's ``tool_calls`` value.

        A missing or null value yields an empty list. Missing ``id``/``name``
        default to ``""`` and missing or null ``arguments`` default to ``{}``.
        """
        return [
            ToolCall(
                id=tc.get("id", ""),
                name=tc.get("name", ""),
                arguments=tc.get("arguments") or {},
            )
            for tc in raw_calls or []
        ]

    def _parse_response(self, data: dict[str, object], request: LLMRequest) -> LLMResponse:
        """Parse server response JSON into an LLMResponse.

        ``model`` falls back to ``request.model`` and token counts fall back
        to zero when the server omits them.
        """
        return LLMResponse(
            content=data.get("content", ""),
            model=data.get("model", request.model),
            usage=self._build_usage(data.get("usage") or {}),
            tool_calls=self._build_tool_calls(data.get("tool_calls")),
            latency_ms=data.get("latency_ms", 0.0),
        )

    def _parse_chunk(self, data: dict[str, object], request: LLMRequest) -> StreamChunk:
        """Parse a single SSE data payload into a StreamChunk.

        Unlike ``_parse_response``, ``usage`` stays ``None`` when the payload
        carries no (or an empty) usage object.
        """
        usage_data = data.get("usage")
        usage: TokenUsage | None = self._build_usage(usage_data) if usage_data else None
        return StreamChunk(
            content=data.get("content", ""),
            model=data.get("model", request.model),
            tool_calls=self._build_tool_calls(data.get("tool_calls")),
            usage=usage,
            is_final=data.get("is_final", False),
        )

    async def _try_refresh(self) -> bool:
        """Attempt a token refresh via the configured callback.

        Returns ``True`` only if a callback is configured and it reported
        success. Returns ``False`` when no callback is configured (a 401 is
        then not recoverable) or when the callback raises; the exception is
        logged and swallowed so the caller can surface the original 401.
        """
        if self._refresh_callback is None:
            return False
        try:
            return bool(await self._refresh_callback())
        except Exception as exc:
            logger.warning("Token refresh callback failed: %s", exc)
            return False

    async def chat(self, request: LLMRequest) -> LLMResponse:
        """Send a non-streaming chat request to the server gateway.

        Raises:
            ConnectionError: The gateway returned 401 and the token could not
                be refreshed, or the retry after a refresh was also rejected.
            ModelNotFoundError: The gateway returned 404.
            LLMProviderError: Timeout, transport error, non-200 status, or a
                200 response whose body is not a JSON object.
        """
        url = f"{self._server_url}/api/v1/llm/chat"
        payload = self._build_payload(request)

        for attempt in range(2):  # original + 1 retry after refresh
            # Rebuilt per attempt so the retry carries the refreshed token.
            headers = self._headers()
            try:
                resp = await self._client.post(url, json=payload, headers=headers)
            except httpx.TimeoutException as e:
                raise LLMProviderError(
                    "remote", f"Request timeout after {self._timeout}s"
                ) from e
            except httpx.HTTPError as e:
                raise LLMProviderError("remote", str(e)) from e

            if resp.status_code == 401 and attempt == 0 and await self._try_refresh():
                logger.info("LLM gateway returned 401; refreshed token, retrying once.")
                continue

            if resp.status_code == 401:
                raise ConnectionError("Authentication failed")
            if resp.status_code == 404:
                raise ModelNotFoundError(request.model)
            if resp.status_code == 502:
                detail = self._extract_error_detail(resp)
                raise LLMProviderError("remote", f"Server LLM gateway error: {detail}")
            if resp.status_code != 200:
                raise LLMProviderError(
                    "remote", f"Unexpected status {resp.status_code}: {resp.text}"
                )

            # A 200 with a malformed body is still a gateway failure; report
            # it as a provider error instead of leaking a decode error.
            try:
                data = resp.json()
            except ValueError as e:
                raise LLMProviderError(
                    "remote", "Gateway returned a response that is not valid JSON"
                ) from e
            if not isinstance(data, dict):
                raise LLMProviderError(
                    "remote", "Gateway returned a JSON response that is not an object"
                )
            return self._parse_response(data, request)

        # Should be unreachable — loop exits via return or raise above.
        raise ConnectionError("Authentication failed")

    async def chat_stream(self, request: LLMRequest) -> AsyncIterator[StreamChunk]:
        """Send a streaming chat request to the server gateway.

        Parses SSE response (data: {json}\\n\\n) into StreamChunk objects.
        Terminates on ``data: [DONE]``. The space after ``data:`` is optional.
        Lines that are not ``data`` fields, empty ``data`` fields, and
        payloads that are not JSON objects are skipped.

        Raises:
            ConnectionError: The gateway returned 401 and the token could not
                be refreshed, or the retry after a refresh was also rejected.
            ModelNotFoundError: The gateway returned 404.
            LLMProviderError: Timeout, transport error, non-200 status, or the
                stream contains an error payload.
        """
        url = f"{self._server_url}/api/v1/llm/chat/stream"
        payload = self._build_payload(request)

        try:
            for attempt in range(2):  # original + 1 retry after refresh
                # Rebuilt per attempt so the retry carries the refreshed token.
                headers = self._headers()
                async with self._client.stream(
                    "POST", url, json=payload, headers=headers
                ) as response:
                    status = response.status_code
                    if status == 401 and attempt == 0 and await self._try_refresh():
                        logger.info(
                            "LLM gateway stream returned 401; refreshed token, retrying once."
                        )
                        continue

                    if status == 401:
                        raise ConnectionError("Authentication failed")
                    if status == 404:
                        raise ModelNotFoundError(request.model)
                    if status == 502:
                        # The body of a streamed response must be read
                        # explicitly before it can be inspected.
                        await response.aread()
                        detail = self._extract_error_detail(response)
                        raise LLMProviderError(
                            "remote", f"Server LLM gateway error: {detail}"
                        )
                    if status != 200:
                        await response.aread()
                        raise LLMProviderError(
                            "remote",
                            f"Unexpected status {response.status_code}: {response.text}",
                        )

                    async for raw_line in response.aiter_lines():
                        line = raw_line.strip()
                        # Only "data" fields carry payloads; blank separators,
                        # comments and other SSE fields are ignored.
                        if not line.startswith("data:"):
                            continue
                        data_str = line.removeprefix("data:").strip()
                        if not data_str:
                            continue
                        if data_str == "[DONE]":
                            break

                        try:
                            data = json.loads(data_str)
                        except json.JSONDecodeError:
                            logger.warning("Failed to parse SSE line: %s", data_str)
                            continue

                        # Valid JSON that is not an object cannot be a chunk;
                        # treat it like any other malformed line.
                        if not isinstance(data, dict):
                            logger.warning(
                                "Ignoring SSE payload that is not a JSON object: %s",
                                data_str,
                            )
                            continue

                        if "error" in data:
                            raise LLMProviderError(
                                "remote",
                                f"Stream error: {data.get('detail', data['error'])}",
                            )

                        yield self._parse_chunk(data, request)

                # Stream completed successfully — exit retry loop.
                break

        except httpx.TimeoutException as e:
            raise LLMProviderError(
                "remote", f"Request timeout after {self._timeout}s"
            ) from e
        except httpx.HTTPError as e:
            raise LLMProviderError("remote", str(e)) from e