"""Tests for RemoteLLMProvider (U3) — client-side LLM gateway forwarding."""

import json
from unittest.mock import AsyncMock, MagicMock

import httpx
import pytest
from pytest_httpx import HTTPXMock

from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
from agentkit.llm.protocol import LLMRequest, LLMResponse, StreamChunk
from agentkit.llm.remote_provider import RemoteLLMProvider

# NOTE(review): the async tests below carry no ``@pytest.mark.asyncio`` marker;
# presumably the suite runs with pytest-asyncio in ``asyncio_mode = "auto"`` —
# confirm against the project's pytest configuration.

SERVER_URL = "http://test-server:8000"
CHAT_URL = f"{SERVER_URL}/api/v1/llm/chat"
STREAM_URL = f"{SERVER_URL}/api/v1/llm/chat/stream"


def _make_provider() -> RemoteLLMProvider:
    """Build a provider pointed at SERVER_URL with a fixed bearer token."""
    return RemoteLLMProvider(
        server_url=SERVER_URL,
        auth_token_provider=lambda: "test-jwt-token",
        timeout=30.0,
    )


def _make_request() -> LLMRequest:
    """Build the minimal single-message request shared by most tests."""
    return LLMRequest(
        messages=[{"role": "user", "content": "Hello"}],
        model="gpt-4o-mini",
        temperature=0.5,
        max_tokens=1000,
    )


def _ok_response_body(
    content: str = "Hi there!",
    model: str = "gpt-4o-mini",
    tool_calls: list | None = None,
) -> dict:
    """Return a successful ``/chat`` JSON body, as mocked server responses use it.

    Args:
        content: Assistant text to return.
        model: Model name echoed back by the gateway.
        tool_calls: Optional tool-call dicts; ``None`` becomes an empty list.
    """
    return {
        "content": content,
        "model": model,
        "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
        "tool_calls": tool_calls or [],
        "latency_ms": 123.4,
    }


class TestRemoteLLMProviderChat:
    """chat() method tests."""

    async def test_chat_returns_llm_response(self, httpx_mock: HTTPXMock):
        httpx_mock.add_response(url=CHAT_URL, json=_ok_response_body())
        provider = _make_provider()

        response = await provider.chat(_make_request())

        assert isinstance(response, LLMResponse)
        assert response.content == "Hi there!"
        assert response.model == "gpt-4o-mini"
        assert response.usage.prompt_tokens == 10
        assert response.usage.completion_tokens == 5
        assert response.usage.total_tokens == 15
        assert response.latency_ms == 123.4
        assert response.tool_calls == []

    async def test_chat_with_tool_calls(self, httpx_mock: HTTPXMock):
        httpx_mock.add_response(
            url=CHAT_URL,
            json=_ok_response_body(
                content="",
                tool_calls=[
                    {"id": "call_1", "name": "get_weather", "arguments": {"city": "Beijing"}},
                ],
            ),
        )
        provider = _make_provider()

        response = await provider.chat(_make_request())

        assert response.has_tool_calls
        assert len(response.tool_calls) == 1
        tc = response.tool_calls[0]
        assert tc.id == "call_1"
        assert tc.name == "get_weather"
        assert tc.arguments == {"city": "Beijing"}

    async def test_chat_401_raises_connection_error(self, httpx_mock: HTTPXMock):
        httpx_mock.add_response(url=CHAT_URL, status_code=401, json={"detail": "Unauthorized"})
        provider = _make_provider()

        with pytest.raises(ConnectionError, match="Authentication failed"):
            await provider.chat(_make_request())

    async def test_chat_404_raises_model_not_found(self, httpx_mock: HTTPXMock):
        httpx_mock.add_response(
            url=CHAT_URL, status_code=404, json={"detail": "Model not found: gpt-4o-mini"}
        )
        provider = _make_provider()

        with pytest.raises(ModelNotFoundError):
            await provider.chat(_make_request())

    async def test_chat_502_raises_provider_error(self, httpx_mock: HTTPXMock):
        httpx_mock.add_response(
            url=CHAT_URL, status_code=502, json={"detail": "Upstream provider error"}
        )
        provider = _make_provider()

        with pytest.raises(LLMProviderError, match="Server LLM gateway error"):
            await provider.chat(_make_request())

    async def test_chat_timeout_raises_provider_error(self, httpx_mock: HTTPXMock):
        httpx_mock.add_exception(httpx.ReadTimeout("read timeout"))
        provider = _make_provider()

        with pytest.raises(LLMProviderError, match="Request timeout"):
            await provider.chat(_make_request())

    async def test_chat_payload_correct(self, httpx_mock: HTTPXMock):
        httpx_mock.add_response(url=CHAT_URL, json=_ok_response_body())
        provider = _make_provider()
        request = LLMRequest(
            messages=[{"role": "user", "content": "Hi"}],
            model="gpt-4o-mini",
            temperature=0.3,
            max_tokens=500,
            tools=[{"type": "function", "function": {"name": "search"}}],
            tool_choice="auto",
        )

        await provider.chat(request)

        sent = httpx_mock.get_request()
        body = json.loads(sent.content)
        assert body["messages"] == [{"role": "user", "content": "Hi"}]
        assert body["model"] == "gpt-4o-mini"
        assert body["temperature"] == 0.3
        assert body["max_tokens"] == 500
        assert body["tools"] == [{"type": "function", "function": {"name": "search"}}]
        assert body["tool_choice"] == "auto"

    async def test_chat_headers_include_auth(self, httpx_mock: HTTPXMock):
        httpx_mock.add_response(url=CHAT_URL, json=_ok_response_body())
        provider = _make_provider()

        await provider.chat(_make_request())

        sent = httpx_mock.get_request()
        assert sent.headers["Authorization"] == "Bearer test-jwt-token"
        assert sent.headers["Content-Type"] == "application/json"

    async def test_chat_unexpected_status_raises_provider_error(self, httpx_mock: HTTPXMock):
        httpx_mock.add_response(url=CHAT_URL, status_code=500, text="Internal Server Error")
        provider = _make_provider()

        with pytest.raises(LLMProviderError, match="Unexpected status 500"):
            await provider.chat(_make_request())


class TestRemoteLLMProviderStream:
    """chat_stream() method tests."""

    @staticmethod
    def _make_stream_response(sse_lines: list[str], status_code: int = 200) -> MagicMock:
        """Create a mock httpx streaming response context manager.

        The returned object supports ``async with``; the entered response
        yields ``sse_lines`` from ``aiter_lines()`` and reports ``status_code``.
        """
        response = MagicMock()
        response.status_code = status_code

        async def aiter_lines():
            for line in sse_lines:
                yield line

        response.aiter_lines = aiter_lines
        response.aread = AsyncMock(return_value=b"")
        response.text = ""

        context = MagicMock()
        context.__aenter__ = AsyncMock(return_value=response)
        context.__aexit__ = AsyncMock(return_value=False)
        return context

    @staticmethod
    def _sse(data: dict) -> str:
        """Encode ``data`` as a single SSE ``data:`` line."""
        return f"data: {json.dumps(data)}"

    @classmethod
    def _provider_with_stream(
        cls, sse_lines: list[str], status_code: int = 200
    ) -> RemoteLLMProvider:
        """Build a provider whose ``_client.stream`` replays ``sse_lines``."""
        provider = _make_provider()
        provider._client.stream = MagicMock(
            return_value=cls._make_stream_response(sse_lines, status_code=status_code)
        )
        return provider

    @staticmethod
    async def _collect(provider: RemoteLLMProvider) -> list[StreamChunk]:
        """Drain ``chat_stream()`` for the default request into a list."""
        return [chunk async for chunk in provider.chat_stream(_make_request())]

    async def test_stream_parses_sse(self):
        sse_lines = [
            self._sse({"content": "Hello", "model": "gpt-4o-mini", "is_final": False}),
            "",
            self._sse({"content": " world", "model": "gpt-4o-mini", "is_final": False}),
            "",
            self._sse(
                {
                    "content": "",
                    "model": "gpt-4o-mini",
                    "is_final": True,
                    "usage": {
                        "prompt_tokens": 5,
                        "completion_tokens": 3,
                        "total_tokens": 8,
                    },
                }
            ),
            "",
            "data: [DONE]",
            "",
        ]
        provider = self._provider_with_stream(sse_lines)

        chunks = await self._collect(provider)

        assert len(chunks) == 3
        assert chunks[0].content == "Hello"
        assert chunks[0].is_final is False
        assert chunks[1].content == " world"
        assert chunks[2].is_final is True
        assert chunks[2].usage is not None
        assert chunks[2].usage.prompt_tokens == 5
        assert chunks[2].usage.completion_tokens == 3

    async def test_stream_done_terminates(self):
        """data: [DONE] should stop iteration — later lines must not be yielded."""
        sse_lines = [
            self._sse({"content": "Hi", "model": "gpt-4o-mini", "is_final": False}),
            "",
            "data: [DONE]",
            "",
            self._sse({"content": "should not appear", "model": "gpt-4o-mini", "is_final": False}),
            "",
        ]
        provider = self._provider_with_stream(sse_lines)

        chunks = await self._collect(provider)

        assert len(chunks) == 1
        assert chunks[0].content == "Hi"

    async def test_stream_error_raises(self):
        """An error payload in the stream should raise LLMProviderError."""
        sse_lines = [
            self._sse({"error": "provider_error", "detail": "upstream failed"}),
            "",
        ]
        provider = self._provider_with_stream(sse_lines)

        with pytest.raises(LLMProviderError, match="Stream error"):
            await self._collect(provider)

    async def test_stream_empty(self):
        """A stream with only [DONE] yields no chunks."""
        provider = self._provider_with_stream(["data: [DONE]", ""])

        chunks = await self._collect(provider)

        assert chunks == []

    async def test_stream_is_final(self):
        """A chunk with is_final=True should be correctly marked."""
        sse_lines = [
            self._sse(
                {
                    "content": "response",
                    "model": "gpt-4o-mini",
                    "is_final": True,
                    "usage": {
                        "prompt_tokens": 1,
                        "completion_tokens": 1,
                        "total_tokens": 2,
                    },
                }
            ),
            "",
            "data: [DONE]",
            "",
        ]
        provider = self._provider_with_stream(sse_lines)

        chunks = await self._collect(provider)

        assert len(chunks) == 1
        assert isinstance(chunks[0], StreamChunk)
        assert chunks[0].is_final is True
        assert chunks[0].usage is not None
        assert chunks[0].usage.total_tokens == 2

    async def test_stream_tool_calls_in_chunk(self):
        """Tool calls in a stream chunk should be parsed correctly."""
        sse_lines = [
            self._sse(
                {
                    "content": "",
                    "model": "gpt-4o-mini",
                    "is_final": True,
                    "tool_calls": [
                        {"id": "tc_1", "name": "search", "arguments": {"q": "test"}},
                    ],
                    "usage": {
                        "prompt_tokens": 2,
                        "completion_tokens": 2,
                        "total_tokens": 4,
                    },
                }
            ),
            "",
            "data: [DONE]",
            "",
        ]
        provider = self._provider_with_stream(sse_lines)

        chunks = await self._collect(provider)

        assert len(chunks) == 1
        assert len(chunks[0].tool_calls) == 1
        assert chunks[0].tool_calls[0].name == "search"
        assert chunks[0].tool_calls[0].arguments == {"q": "test"}

    async def test_stream_401_raises_connection_error(self):
        provider = self._provider_with_stream([], status_code=401)

        with pytest.raises(ConnectionError, match="Authentication failed"):
            await self._collect(provider)

    async def test_stream_skips_non_data_lines(self):
        """Lines without 'data: ' prefix should be skipped."""
        sse_lines = [
            ": comment line",
            "",
            self._sse({"content": "ok", "model": "gpt-4o-mini", "is_final": False}),
            "",
            "event: ping",
            "data: [DONE]",
            "",
        ]
        provider = self._provider_with_stream(sse_lines)

        chunks = await self._collect(provider)

        assert len(chunks) == 1
        assert chunks[0].content == "ok"

    async def test_stream_targets_stream_endpoint(self):
        """chat_stream() must open its stream against the /chat/stream endpoint."""
        provider = self._provider_with_stream(["data: [DONE]", ""])

        await self._collect(provider)

        call = provider._client.stream.call_args
        assert call is not None
        # Accept the URL positionally or by keyword, absolute or relative to a
        # client base_url — only the endpoint path is pinned here.
        stream_path = STREAM_URL.removeprefix(SERVER_URL)
        sent_values = [*call.args, *call.kwargs.values()]
        assert any(str(value).endswith(stream_path) for value in sent_values)


class TestRemoteLLMProviderHelpers:
    """Tests for helper methods."""

    def test_headers_include_bearer_token(self):
        provider = RemoteLLMProvider(
            server_url="https://api.example.com",
            auth_token_provider=lambda: "abc123",
        )

        headers = provider._headers()

        assert headers["Authorization"] == "Bearer abc123"
        assert headers["Content-Type"] == "application/json"

    def test_server_url_trailing_slash_stripped(self):
        provider = RemoteLLMProvider(
            server_url="https://api.example.com/",
            auth_token_provider=lambda: "token",
        )

        assert provider._server_url == "https://api.example.com"

    def test_build_payload_includes_all_fields(self):
        provider = _make_provider()
        request = LLMRequest(
            messages=[{"role": "user", "content": "Hi"}],
            model="gpt-4",
            temperature=0.1,
            max_tokens=50,
            tools=[{"type": "function"}],
            tool_choice="required",
            timeout=10.0,
        )

        payload = provider._build_payload(request)

        assert payload["messages"] == [{"role": "user", "content": "Hi"}]
        assert payload["model"] == "gpt-4"
        assert payload["temperature"] == 0.1
        assert payload["max_tokens"] == 50
        assert payload["tools"] == [{"type": "function"}]
        assert payload["tool_choice"] == "required"
        assert payload["timeout"] == 10.0