# fischer-agentkit/tests/unit/test_remote_provider.py
#
# 420 lines
# 14 KiB
# Python

"""Tests for RemoteLLMProvider (U3) — client-side LLM gateway forwarding."""
import json
from unittest.mock import AsyncMock, MagicMock
import httpx
import pytest
from pytest_httpx import HTTPXMock
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
from agentkit.llm.protocol import LLMRequest, LLMResponse, StreamChunk
from agentkit.llm.remote_provider import RemoteLLMProvider
# Base URL of the fake LLM gateway the provider under test is pointed at.
SERVER_URL = "http://test-server:8000"
# Non-streaming chat endpoint; HTTPXMock responses are registered for this URL.
CHAT_URL = SERVER_URL + "/api/v1/llm/chat"
# NOTE(review): STREAM_URL is not referenced by any test in this file; the
# streaming tests patch ``_client.stream`` directly instead of using HTTPXMock.
STREAM_URL = CHAT_URL + "/stream"
def _make_provider() -> RemoteLLMProvider:
    """Build a provider aimed at ``SERVER_URL`` that always presents a fixed JWT."""

    def fixed_token():
        return "test-jwt-token"

    return RemoteLLMProvider(
        server_url=SERVER_URL,
        auth_token_provider=fixed_token,
        timeout=30.0,
    )
def _make_request() -> LLMRequest:
    """Return the default single-turn user request shared by most tests."""
    greeting = {"role": "user", "content": "Hello"}
    return LLMRequest(
        model="gpt-4o-mini",
        messages=[greeting],
        max_tokens=1000,
        temperature=0.5,
    )
def _ok_response_body(
content: str = "Hi there!",
model: str = "gpt-4o-mini",
tool_calls: list | None = None,
) -> dict:
return {
"content": content,
"model": model,
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
"tool_calls": tool_calls or [],
"latency_ms": 123.4,
}
class TestRemoteLLMProviderChat:
    """chat() method tests.

    Each test registers a canned gateway reply (or a transport exception)
    with ``httpx_mock`` and then drives ``RemoteLLMProvider.chat()``.
    """

    async def test_chat_returns_llm_response(self, httpx_mock: HTTPXMock):
        httpx_mock.add_response(url=CHAT_URL, json=_ok_response_body())
        remote = _make_provider()

        result = await remote.chat(_make_request())

        assert isinstance(result, LLMResponse)
        assert result.content == "Hi there!"
        assert result.model == "gpt-4o-mini"
        usage = result.usage
        assert usage.prompt_tokens == 10
        assert usage.completion_tokens == 5
        assert usage.total_tokens == 15
        assert result.latency_ms == 123.4
        assert result.tool_calls == []

    async def test_chat_with_tool_calls(self, httpx_mock: HTTPXMock):
        weather_call = {"id": "call_1", "name": "get_weather", "arguments": {"city": "Beijing"}}
        httpx_mock.add_response(
            url=CHAT_URL,
            json=_ok_response_body(content="", tool_calls=[weather_call]),
        )
        remote = _make_provider()

        result = await remote.chat(_make_request())

        assert result.has_tool_calls
        assert len(result.tool_calls) == 1
        first = result.tool_calls[0]
        assert first.id == "call_1"
        assert first.name == "get_weather"
        assert first.arguments == {"city": "Beijing"}

    async def test_chat_401_raises_connection_error(self, httpx_mock: HTTPXMock):
        httpx_mock.add_response(url=CHAT_URL, status_code=401, json={"detail": "Unauthorized"})
        remote = _make_provider()

        with pytest.raises(ConnectionError, match="Authentication failed"):
            await remote.chat(_make_request())

    async def test_chat_404_raises_model_not_found(self, httpx_mock: HTTPXMock):
        missing = {"detail": "Model not found: gpt-4o-mini"}
        httpx_mock.add_response(url=CHAT_URL, status_code=404, json=missing)
        remote = _make_provider()

        with pytest.raises(ModelNotFoundError):
            await remote.chat(_make_request())

    async def test_chat_502_raises_provider_error(self, httpx_mock: HTTPXMock):
        upstream = {"detail": "Upstream provider error"}
        httpx_mock.add_response(url=CHAT_URL, status_code=502, json=upstream)
        remote = _make_provider()

        with pytest.raises(LLMProviderError, match="Server LLM gateway error"):
            await remote.chat(_make_request())

    async def test_chat_timeout_raises_provider_error(self, httpx_mock: HTTPXMock):
        httpx_mock.add_exception(httpx.ReadTimeout("read timeout"))
        remote = _make_provider()

        with pytest.raises(LLMProviderError, match="Request timeout"):
            await remote.chat(_make_request())

    async def test_chat_payload_correct(self, httpx_mock: HTTPXMock):
        httpx_mock.add_response(url=CHAT_URL, json=_ok_response_body())
        remote = _make_provider()
        outgoing = LLMRequest(
            messages=[{"role": "user", "content": "Hi"}],
            model="gpt-4o-mini",
            temperature=0.3,
            max_tokens=500,
            tools=[{"type": "function", "function": {"name": "search"}}],
            tool_choice="auto",
        )

        await remote.chat(outgoing)

        sent_body = json.loads(httpx_mock.get_request().content)
        expected = {
            "messages": [{"role": "user", "content": "Hi"}],
            "model": "gpt-4o-mini",
            "temperature": 0.3,
            "max_tokens": 500,
            "tools": [{"type": "function", "function": {"name": "search"}}],
            "tool_choice": "auto",
        }
        for key, value in expected.items():
            assert sent_body[key] == value

    async def test_chat_headers_include_auth(self, httpx_mock: HTTPXMock):
        httpx_mock.add_response(url=CHAT_URL, json=_ok_response_body())
        remote = _make_provider()

        await remote.chat(_make_request())

        sent_headers = httpx_mock.get_request().headers
        assert sent_headers["Authorization"] == "Bearer test-jwt-token"
        assert sent_headers["Content-Type"] == "application/json"

    async def test_chat_unexpected_status_raises_provider_error(self, httpx_mock: HTTPXMock):
        httpx_mock.add_response(url=CHAT_URL, status_code=500, text="Internal Server Error")
        remote = _make_provider()

        with pytest.raises(LLMProviderError, match="Unexpected status 500"):
            await remote.chat(_make_request())
class TestRemoteLLMProviderStream:
    """chat_stream() method tests.

    These tests bypass HTTPXMock. Each one replaces the provider's
    ``_client.stream`` with a mock that replays a fixed list of SSE lines,
    then checks which ``StreamChunk`` objects ``chat_stream()`` yields.
    """

    @staticmethod
    def _make_stream_response(sse_lines: list[str], status_code: int = 200) -> MagicMock:
        """Create a mock httpx streaming response context manager.

        Entering the returned object with ``async with`` gives a response
        whose ``status_code`` is *status_code*, whose ``aiter_lines()``
        replays *sse_lines* in order, and whose body (``aread()`` / ``text``)
        is empty.
        """
        response = MagicMock()
        response.status_code = status_code

        async def aiter_lines():
            for line in sse_lines:
                yield line

        response.aiter_lines = aiter_lines
        response.aread = AsyncMock(return_value=b"")
        response.text = ""
        context = MagicMock()
        context.__aenter__ = AsyncMock(return_value=response)
        context.__aexit__ = AsyncMock(return_value=False)
        return context

    @staticmethod
    def _sse(data: dict) -> str:
        """Encode *data* as a single SSE ``data:`` line."""
        return f"data: {json.dumps(data)}"

    @classmethod
    def _provider_streaming(
        cls, sse_lines: list[str], status_code: int = 200
    ) -> RemoteLLMProvider:
        """Return a provider whose ``_client.stream()`` replays *sse_lines*."""
        provider = _make_provider()
        provider._client.stream = MagicMock(
            return_value=cls._make_stream_response(sse_lines, status_code=status_code)
        )
        return provider

    @staticmethod
    async def _collect(provider: RemoteLLMProvider) -> list[StreamChunk]:
        """Drain ``provider.chat_stream()`` for the default request into a list."""
        return [chunk async for chunk in provider.chat_stream(_make_request())]

    async def test_stream_parses_sse(self):
        sse_lines = [
            self._sse({"content": "Hello", "model": "gpt-4o-mini", "is_final": False}),
            "",
            self._sse({"content": " world", "model": "gpt-4o-mini", "is_final": False}),
            "",
            self._sse(
                {
                    "content": "",
                    "model": "gpt-4o-mini",
                    "is_final": True,
                    "usage": {
                        "prompt_tokens": 5,
                        "completion_tokens": 3,
                        "total_tokens": 8,
                    },
                }
            ),
            "",
            "data: [DONE]",
            "",
        ]

        chunks = await self._collect(self._provider_streaming(sse_lines))

        assert len(chunks) == 3
        assert chunks[0].content == "Hello"
        assert chunks[0].is_final is False
        assert chunks[1].content == " world"
        assert chunks[1].is_final is False
        assert chunks[2].is_final is True
        assert chunks[2].usage is not None
        assert chunks[2].usage.prompt_tokens == 5
        assert chunks[2].usage.completion_tokens == 3
        assert chunks[2].usage.total_tokens == 8

    async def test_stream_done_terminates(self):
        """data: [DONE] should stop iteration — later lines must not be yielded."""
        sse_lines = [
            self._sse({"content": "Hi", "model": "gpt-4o-mini", "is_final": False}),
            "",
            "data: [DONE]",
            "",
            self._sse({"content": "should not appear", "model": "gpt-4o-mini", "is_final": False}),
            "",
        ]

        chunks = await self._collect(self._provider_streaming(sse_lines))

        assert len(chunks) == 1
        assert chunks[0].content == "Hi"

    async def test_stream_error_raises(self):
        """An error payload in the stream should raise LLMProviderError."""
        sse_lines = [
            self._sse({"error": "provider_error", "detail": "upstream failed"}),
            "",
        ]
        provider = self._provider_streaming(sse_lines)

        with pytest.raises(LLMProviderError, match="Stream error"):
            await self._collect(provider)

    async def test_stream_empty(self):
        """A stream with only [DONE] yields no chunks."""
        chunks = await self._collect(self._provider_streaming(["data: [DONE]", ""]))

        assert chunks == []

    async def test_stream_is_final(self):
        """A chunk with is_final=True should be correctly marked."""
        sse_lines = [
            self._sse(
                {
                    "content": "response",
                    "model": "gpt-4o-mini",
                    "is_final": True,
                    "usage": {
                        "prompt_tokens": 1,
                        "completion_tokens": 1,
                        "total_tokens": 2,
                    },
                }
            ),
            "",
            "data: [DONE]",
            "",
        ]

        chunks = await self._collect(self._provider_streaming(sse_lines))

        assert len(chunks) == 1
        assert isinstance(chunks[0], StreamChunk)
        assert chunks[0].is_final is True
        assert chunks[0].usage is not None
        assert chunks[0].usage.total_tokens == 2

    async def test_stream_tool_calls_in_chunk(self):
        """Tool calls in a stream chunk should be parsed correctly."""
        sse_lines = [
            self._sse(
                {
                    "content": "",
                    "model": "gpt-4o-mini",
                    "is_final": True,
                    "tool_calls": [
                        {"id": "tc_1", "name": "search", "arguments": {"q": "test"}},
                    ],
                    "usage": {
                        "prompt_tokens": 2,
                        "completion_tokens": 2,
                        "total_tokens": 4,
                    },
                }
            ),
            "",
            "data: [DONE]",
            "",
        ]

        chunks = await self._collect(self._provider_streaming(sse_lines))

        assert len(chunks) == 1
        assert len(chunks[0].tool_calls) == 1
        assert chunks[0].tool_calls[0].name == "search"
        assert chunks[0].tool_calls[0].arguments == {"q": "test"}

    async def test_stream_401_raises_connection_error(self):
        provider = self._provider_streaming([], status_code=401)

        with pytest.raises(ConnectionError, match="Authentication failed"):
            await self._collect(provider)

    async def test_stream_skips_non_data_lines(self):
        """Lines without 'data: ' prefix should be skipped."""
        sse_lines = [
            ": comment line",
            "",
            self._sse({"content": "ok", "model": "gpt-4o-mini", "is_final": False}),
            "",
            "event: ping",
            "data: [DONE]",
            "",
        ]

        chunks = await self._collect(self._provider_streaming(sse_lines))

        assert len(chunks) == 1
        assert chunks[0].content == "ok"
class TestRemoteLLMProviderHelpers:
    """Tests for the provider's private helpers and URL normalisation."""

    def test_headers_include_bearer_token(self):
        remote = RemoteLLMProvider(
            server_url="https://api.example.com",
            auth_token_provider=lambda: "abc123",
        )

        produced = remote._headers()

        assert produced["Authorization"] == "Bearer abc123"
        assert produced["Content-Type"] == "application/json"

    def test_server_url_trailing_slash_stripped(self):
        remote = RemoteLLMProvider(
            server_url="https://api.example.com/",
            auth_token_provider=lambda: "token",
        )

        assert remote._server_url == "https://api.example.com"

    def test_build_payload_includes_all_fields(self):
        remote = _make_provider()
        outgoing = LLMRequest(
            messages=[{"role": "user", "content": "Hi"}],
            model="gpt-4",
            temperature=0.1,
            max_tokens=50,
            tools=[{"type": "function"}],
            tool_choice="required",
            timeout=10.0,
        )

        payload = remote._build_payload(outgoing)

        expected = {
            "messages": [{"role": "user", "content": "Hi"}],
            "model": "gpt-4",
            "temperature": 0.1,
            "max_tokens": 50,
            "tools": [{"type": "function"}],
            "tool_choice": "required",
            "timeout": 10.0,
        }
        for field, value in expected.items():
            assert payload[field] == value