420 lines
14 KiB
Python
420 lines
14 KiB
Python
"""Tests for RemoteLLMProvider (U3) — client-side LLM gateway forwarding."""
|
|
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import httpx
|
|
import pytest
|
|
from pytest_httpx import HTTPXMock
|
|
|
|
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
|
from agentkit.llm.protocol import LLMRequest, LLMResponse, StreamChunk
|
|
from agentkit.llm.remote_provider import RemoteLLMProvider
|
|
|
|
# Base URL of the fake gateway that the provider under test is pointed at.
SERVER_URL = "http://test-server:8000"

# Non-streaming chat endpoint, and its streaming sibling one path segment below.
CHAT_URL = SERVER_URL + "/api/v1/llm/chat"
STREAM_URL = CHAT_URL + "/stream"
|
|
|
|
|
|
def _make_provider() -> RemoteLLMProvider:
    """Return a provider aimed at ``SERVER_URL`` that always sends a fixed JWT.

    The 30-second timeout is passed explicitly so tests do not depend on the
    provider's own default.
    """

    def _fixed_token():
        return "test-jwt-token"

    return RemoteLLMProvider(
        timeout=30.0,
        auth_token_provider=_fixed_token,
        server_url=SERVER_URL,
    )
|
|
|
|
|
|
def _make_request() -> LLMRequest:
    """Return a minimal single-turn request (one user message saying "Hello")."""
    greeting = {"role": "user", "content": "Hello"}
    return LLMRequest(
        model="gpt-4o-mini",
        messages=[greeting],
        max_tokens=1000,
        temperature=0.5,
    )
|
|
|
|
|
|
def _ok_response_body(
|
|
content: str = "Hi there!",
|
|
model: str = "gpt-4o-mini",
|
|
tool_calls: list | None = None,
|
|
) -> dict:
|
|
return {
|
|
"content": content,
|
|
"model": model,
|
|
"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
|
|
"tool_calls": tool_calls or [],
|
|
"latency_ms": 123.4,
|
|
}
|
|
|
|
|
|
class TestRemoteLLMProviderChat:
    """Tests for ``RemoteLLMProvider.chat()``, the non-streaming request path.

    Outgoing HTTP is intercepted by the ``httpx_mock`` fixture from
    pytest-httpx. Each test registers the canned gateway reply (or exception)
    first, then drives the provider and checks either the returned value or
    the captured outgoing request.

    NOTE(review): the coroutine tests carry no explicit asyncio marker. They
    presumably rely on an asyncio pytest plugin in auto mode — confirm in the
    project's pytest configuration.
    """

    async def test_chat_returns_llm_response(self, httpx_mock: HTTPXMock):
        httpx_mock.add_response(url=CHAT_URL, json=_ok_response_body())
        provider = _make_provider()

        result = await provider.chat(_make_request())

        assert isinstance(result, LLMResponse)
        assert result.content == "Hi there!"
        assert result.model == "gpt-4o-mini"
        # These usage figures are the fixed values from _ok_response_body().
        usage = result.usage
        assert usage.prompt_tokens == 10
        assert usage.completion_tokens == 5
        assert usage.total_tokens == 15
        assert result.latency_ms == 123.4
        assert result.tool_calls == []

    async def test_chat_with_tool_calls(self, httpx_mock: HTTPXMock):
        weather_call = {"id": "call_1", "name": "get_weather", "arguments": {"city": "Beijing"}}
        httpx_mock.add_response(
            url=CHAT_URL,
            json=_ok_response_body(content="", tool_calls=[weather_call]),
        )
        provider = _make_provider()

        result = await provider.chat(_make_request())

        assert result.has_tool_calls
        assert len(result.tool_calls) == 1
        call = result.tool_calls[0]
        assert call.id == "call_1"
        assert call.name == "get_weather"
        assert call.arguments == {"city": "Beijing"}

    async def test_chat_401_raises_connection_error(self, httpx_mock: HTTPXMock):
        provider = _make_provider()
        httpx_mock.add_response(url=CHAT_URL, status_code=401, json={"detail": "Unauthorized"})

        with pytest.raises(ConnectionError, match="Authentication failed"):
            await provider.chat(_make_request())

    async def test_chat_404_raises_model_not_found(self, httpx_mock: HTTPXMock):
        provider = _make_provider()
        httpx_mock.add_response(
            url=CHAT_URL,
            status_code=404,
            json={"detail": "Model not found: gpt-4o-mini"},
        )

        with pytest.raises(ModelNotFoundError):
            await provider.chat(_make_request())

    async def test_chat_502_raises_provider_error(self, httpx_mock: HTTPXMock):
        provider = _make_provider()
        httpx_mock.add_response(
            url=CHAT_URL,
            status_code=502,
            json={"detail": "Upstream provider error"},
        )

        with pytest.raises(LLMProviderError, match="Server LLM gateway error"):
            await provider.chat(_make_request())

    async def test_chat_timeout_raises_provider_error(self, httpx_mock: HTTPXMock):
        provider = _make_provider()
        # The exception is registered without a URL filter.
        httpx_mock.add_exception(httpx.ReadTimeout("read timeout"))

        with pytest.raises(LLMProviderError, match="Request timeout"):
            await provider.chat(_make_request())

    async def test_chat_payload_correct(self, httpx_mock: HTTPXMock):
        httpx_mock.add_response(url=CHAT_URL, json=_ok_response_body())
        provider = _make_provider()
        request = LLMRequest(
            messages=[{"role": "user", "content": "Hi"}],
            model="gpt-4o-mini",
            temperature=0.3,
            max_tokens=500,
            tools=[{"type": "function", "function": {"name": "search"}}],
            tool_choice="auto",
        )

        await provider.chat(request)

        # Every request field should reach the gateway unchanged in the JSON body.
        body = json.loads(httpx_mock.get_request().content)
        expected = {
            "messages": [{"role": "user", "content": "Hi"}],
            "model": "gpt-4o-mini",
            "temperature": 0.3,
            "max_tokens": 500,
            "tools": [{"type": "function", "function": {"name": "search"}}],
            "tool_choice": "auto",
        }
        for field, value in expected.items():
            assert body[field] == value, field

    async def test_chat_headers_include_auth(self, httpx_mock: HTTPXMock):
        httpx_mock.add_response(url=CHAT_URL, json=_ok_response_body())
        provider = _make_provider()

        await provider.chat(_make_request())

        headers = httpx_mock.get_request().headers
        assert headers["Authorization"] == "Bearer test-jwt-token"
        assert headers["Content-Type"] == "application/json"

    async def test_chat_unexpected_status_raises_provider_error(self, httpx_mock: HTTPXMock):
        provider = _make_provider()
        httpx_mock.add_response(url=CHAT_URL, status_code=500, text="Internal Server Error")

        with pytest.raises(LLMProviderError, match="Unexpected status 500"):
            await provider.chat(_make_request())
|
|
|
|
|
|
class TestRemoteLLMProviderStream:
    """Tests for ``RemoteLLMProvider.chat_stream()``, the SSE streaming path.

    These tests do not go through pytest-httpx. Instead, each one replaces the
    provider's ``_client.stream`` with a ``MagicMock`` whose return value is a
    hand-built async context manager that replays canned SSE lines.
    """

    @staticmethod
    def _make_stream_response(sse_lines: list[str], status_code: int = 200) -> MagicMock:
        """Create a mock httpx streaming response context manager.

        Entering it with ``async with`` yields a fake response that exposes:
        - ``status_code``;
        - ``aiter_lines()``, which replays *sse_lines* in order;
        - ``aread()``, which resolves to ``b""``;
        - an empty ``text``.
        """

        async def replay_lines():
            for raw in sse_lines:
                yield raw

        fake_response = MagicMock()
        fake_response.status_code = status_code
        fake_response.aiter_lines = replay_lines
        fake_response.aread = AsyncMock(return_value=b"")
        fake_response.text = ""

        ctx = MagicMock()
        ctx.__aenter__ = AsyncMock(return_value=fake_response)
        ctx.__aexit__ = AsyncMock(return_value=False)
        return ctx

    @staticmethod
    def _sse(data: dict) -> str:
        """Serialize *data* as one SSE ``data:`` line."""
        return "data: " + json.dumps(data)

    def _provider_streaming(self, sse_lines: list[str], status_code: int = 200) -> RemoteLLMProvider:
        """Return a test provider whose ``_client.stream`` replays *sse_lines*."""
        provider = _make_provider()
        fake_stream = self._make_stream_response(sse_lines, status_code=status_code)
        provider._client.stream = MagicMock(return_value=fake_stream)
        return provider

    @staticmethod
    async def _collect(provider: RemoteLLMProvider) -> list[StreamChunk]:
        """Drain ``chat_stream`` for the default request and return every chunk."""
        return [chunk async for chunk in provider.chat_stream(_make_request())]

    async def test_stream_parses_sse(self):
        model = "gpt-4o-mini"
        final_usage = {"prompt_tokens": 5, "completion_tokens": 3, "total_tokens": 8}
        sse_lines = [
            self._sse({"content": "Hello", "model": model, "is_final": False}),
            "",
            self._sse({"content": " world", "model": model, "is_final": False}),
            "",
            self._sse({"content": "", "model": model, "is_final": True, "usage": final_usage}),
            "",
            "data: [DONE]",
            "",
        ]
        provider = self._provider_streaming(sse_lines)

        chunks = await self._collect(provider)

        assert len(chunks) == 3
        first, second, last = chunks
        assert first.content == "Hello"
        assert first.is_final is False
        assert second.content == " world"
        assert last.is_final is True
        assert last.usage is not None
        assert last.usage.prompt_tokens == 5
        assert last.usage.completion_tokens == 3

    async def test_stream_done_terminates(self):
        """data: [DONE] should stop iteration — later lines must not be yielded."""
        sse_lines = [
            self._sse({"content": "Hi", "model": "gpt-4o-mini", "is_final": False}),
            "",
            "data: [DONE]",
            "",
            # This frame arrives after [DONE] and must never surface as a chunk.
            self._sse({"content": "should not appear", "model": "gpt-4o-mini", "is_final": False}),
            "",
        ]
        provider = self._provider_streaming(sse_lines)

        chunks = await self._collect(provider)

        assert len(chunks) == 1
        assert chunks[0].content == "Hi"

    async def test_stream_error_raises(self):
        """An error payload in the stream should raise LLMProviderError."""
        sse_lines = [self._sse({"error": "provider_error", "detail": "upstream failed"}), ""]
        provider = self._provider_streaming(sse_lines)

        with pytest.raises(LLMProviderError, match="Stream error"):
            await self._collect(provider)

    async def test_stream_empty(self):
        """A stream with only [DONE] yields no chunks."""
        provider = self._provider_streaming(["data: [DONE]", ""])

        chunks = await self._collect(provider)

        assert chunks == []

    async def test_stream_is_final(self):
        """A chunk with is_final=True should be correctly marked."""
        only_frame = {
            "content": "response",
            "model": "gpt-4o-mini",
            "is_final": True,
            "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
        }
        provider = self._provider_streaming([self._sse(only_frame), "", "data: [DONE]", ""])

        chunks = await self._collect(provider)

        assert len(chunks) == 1
        (chunk,) = chunks
        assert isinstance(chunk, StreamChunk)
        assert chunk.is_final is True
        assert chunk.usage is not None
        assert chunk.usage.total_tokens == 2

    async def test_stream_tool_calls_in_chunk(self):
        """Tool calls in a stream chunk should be parsed correctly."""
        frame = {
            "content": "",
            "model": "gpt-4o-mini",
            "is_final": True,
            "tool_calls": [{"id": "tc_1", "name": "search", "arguments": {"q": "test"}}],
            "usage": {"prompt_tokens": 2, "completion_tokens": 2, "total_tokens": 4},
        }
        provider = self._provider_streaming([self._sse(frame), "", "data: [DONE]", ""])

        chunks = await self._collect(provider)

        assert len(chunks) == 1
        calls = chunks[0].tool_calls
        assert len(calls) == 1
        assert calls[0].name == "search"
        assert calls[0].arguments == {"q": "test"}

    async def test_stream_401_raises_connection_error(self):
        # An empty body is enough here: the status code alone drives the error.
        provider = self._provider_streaming([], status_code=401)

        with pytest.raises(ConnectionError, match="Authentication failed"):
            await self._collect(provider)

    async def test_stream_skips_non_data_lines(self):
        """Lines without 'data: ' prefix should be skipped."""
        sse_lines = [
            ": comment line",
            "",
            self._sse({"content": "ok", "model": "gpt-4o-mini", "is_final": False}),
            "",
            "event: ping",
            "data: [DONE]",
            "",
        ]
        provider = self._provider_streaming(sse_lines)

        chunks = await self._collect(provider)

        assert len(chunks) == 1
        assert chunks[0].content == "ok"
|
|
|
|
|
|
class TestRemoteLLMProviderHelpers:
    """Tests for RemoteLLMProvider's private helpers.

    Covers header construction, server-URL normalisation, and request-payload
    building.
    """

    def test_headers_include_bearer_token(self):
        provider = RemoteLLMProvider(
            auth_token_provider=lambda: "abc123",
            server_url="https://api.example.com",
        )

        built = provider._headers()

        assert built["Authorization"] == "Bearer abc123"
        assert built["Content-Type"] == "application/json"

    def test_server_url_trailing_slash_stripped(self):
        provider = RemoteLLMProvider(
            auth_token_provider=lambda: "token",
            server_url="https://api.example.com/",
        )

        assert provider._server_url == "https://api.example.com"

    def test_build_payload_includes_all_fields(self):
        provider = _make_provider()
        request = LLMRequest(
            messages=[{"role": "user", "content": "Hi"}],
            model="gpt-4",
            temperature=0.1,
            max_tokens=50,
            tools=[{"type": "function"}],
            tool_choice="required",
            timeout=10.0,
        )

        payload = provider._build_payload(request)

        # Each LLMRequest field, including timeout, must be copied into the payload.
        expected = {
            "messages": [{"role": "user", "content": "Hi"}],
            "model": "gpt-4",
            "temperature": 0.1,
            "max_tokens": 50,
            "tools": [{"type": "function"}],
            "tool_choice": "required",
            "timeout": 10.0,
        }
        for field, value in expected.items():
            assert payload[field] == value, field
|