266 lines
9.1 KiB
Python
266 lines
9.1 KiB
Python
"""Unit tests for LLM Gateway proxy routes (U1).

Covers:
- Non-streaming chat returns serialized LLMResponse JSON
- Streaming chat returns SSE-formatted chunks
- Invalid model returns 404 (ModelNotFoundError)
- LLM provider failure returns 502 (LLMProviderError)
- Empty messages list returns 422 (Pydantic validation)
- Errors raised during streaming (ModelNotFoundError / LLMProviderError) are
  emitted as an SSE error payload followed by [DONE]
"""
|
|
|
|
import json
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
from fastapi import FastAPI
|
|
from fastapi.testclient import TestClient
|
|
|
|
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
|
|
from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage, ToolCall
|
|
from agentkit.server.routes import llm_gateway as llm_gateway_module
|
|
|
|
|
|
def _make_response(
    content: str = "Hello from LLM",
    model: str = "test-model",
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
    tool_calls: list[ToolCall] | None = None,
    latency_ms: float = 123.4,
) -> LLMResponse:
    """Build a canned ``LLMResponse`` whose fields can be overridden per test.

    Token counts are wrapped in a ``TokenUsage``. A falsy ``tool_calls``
    (``None`` or empty) is replaced with a fresh empty list.
    """
    usage = TokenUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens)
    calls = tool_calls if tool_calls else []
    return LLMResponse(
        content=content,
        model=model,
        usage=usage,
        tool_calls=calls,
        latency_ms=latency_ms,
    )
|
|
|
|
|
|
def _make_chunks() -> list[StreamChunk]:
    """Return a canned stream: two content deltas, then a final usage-bearing chunk.

    All chunks report ``model="test-model"``. The final chunk has empty content,
    ``is_final=True``, and 10 prompt / 20 completion tokens.
    """
    deltas = [StreamChunk(content=text, model="test-model") for text in ("Hello", " world")]
    final_chunk = StreamChunk(
        content="",
        model="test-model",
        usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
        is_final=True,
    )
    return [*deltas, final_chunk]
|
|
|
|
|
|
@pytest.fixture
def mock_gateway():
    """A mock LLMGateway.

    - ``chat`` is an ``AsyncMock``: awaiting it yields ``return_value`` or
      raises ``side_effect``.
    - ``chat_stream`` is a plain ``MagicMock``, not an ``AsyncMock``. Calling it
      returns whatever the test assigns to ``return_value``, which is expected
      to be an async generator object rather than a coroutine. This matches the
      real gateway's async-generator behavior.
    """
    # Keyword arguments to MagicMock are set as attributes on the new mock.
    return MagicMock(chat=AsyncMock(), chat_stream=MagicMock())
|
|
|
|
|
|
@pytest.fixture
def app(mock_gateway):
    """Minimal FastAPI app that exposes only the LLM gateway router under ``/api/v1``.

    The mock gateway is placed on ``app.state.llm_gateway``. No real gateway or
    startup logic is involved.
    """
    fastapi_app = FastAPI()
    fastapi_app.state.llm_gateway = mock_gateway
    fastapi_app.include_router(llm_gateway_module.router, prefix="/api/v1")
    return fastapi_app
|
|
|
|
|
|
@pytest.fixture
def client(app):
    """Synchronous HTTP test client bound to the ``app`` fixture."""
    test_client = TestClient(app)
    return test_client
|
|
|
|
|
|
class TestLLMGatewayChatNonStreaming:
    """POST /api/v1/llm/chat — non-streaming.

    Each test configures ``mock_gateway.chat`` (an AsyncMock), either its
    ``return_value`` or its ``side_effect``. It then checks how the route maps
    that result onto the HTTP response.
    """

    def test_chat_returns_llm_response_json(self, client, mock_gateway):
        """Non-streaming chat returns serialized LLMResponse JSON."""
        mock_gateway.chat.return_value = _make_response(
            content="Hello from LLM",
            model="test-model",
            prompt_tokens=10,
            completion_tokens=20,
            tool_calls=[ToolCall(id="tc_1", name="search", arguments={"q": "test"})],
            latency_ms=123.4,
        )

        response = client.post(
            "/api/v1/llm/chat",
            json={
                "messages": [{"role": "user", "content": "Hi"}],
                "model": "test-model",
            },
        )

        assert response.status_code == 200
        data = response.json()
        assert data["content"] == "Hello from LLM"
        assert data["model"] == "test-model"
        assert data["usage"]["prompt_tokens"] == 10
        assert data["usage"]["completion_tokens"] == 20
        # _make_response never passes total_tokens, so this checks that it is
        # derived from prompt + completion tokens.
        assert data["usage"]["total_tokens"] == 30
        assert data["latency_ms"] == 123.4
        assert len(data["tool_calls"]) == 1
        assert data["tool_calls"][0]["id"] == "tc_1"
        assert data["tool_calls"][0]["name"] == "search"
        assert data["tool_calls"][0]["arguments"] == {"q": "test"}

        # Verify gateway.chat was awaited once, with the request's model and messages.
        mock_gateway.chat.assert_awaited_once()
        call_kwargs = mock_gateway.chat.await_args.kwargs
        assert call_kwargs["model"] == "test-model"
        assert call_kwargs["messages"] == [{"role": "user", "content": "Hi"}]

    def test_chat_invalid_model_returns_404(self, client, mock_gateway):
        """Invalid model returns 404 (ModelNotFoundError)."""
        mock_gateway.chat.side_effect = ModelNotFoundError("nonexistent/model")

        response = client.post(
            "/api/v1/llm/chat",
            json={
                "messages": [{"role": "user", "content": "Hi"}],
                "model": "nonexistent/model",
            },
        )

        assert response.status_code == 404
        assert "nonexistent/model" in response.json()["detail"]
        # An unmatched route would also produce a 404. Make sure this one
        # really came from the gateway's ModelNotFoundError.
        mock_gateway.chat.assert_awaited_once()

    def test_chat_provider_error_returns_502(self, client, mock_gateway):
        """LLM Provider failure returns 502 (LLMProviderError)."""
        mock_gateway.chat.side_effect = LLMProviderError("openai", "API error")

        response = client.post(
            "/api/v1/llm/chat",
            json={
                "messages": [{"role": "user", "content": "Hi"}],
                "model": "openai/gpt-4o",
            },
        )

        assert response.status_code == 502
        assert "openai" in response.json()["detail"]
        mock_gateway.chat.assert_awaited_once()

    def test_chat_empty_messages_returns_422(self, client, mock_gateway):
        """Empty messages list returns 422 (Pydantic validation)."""
        response = client.post(
            "/api/v1/llm/chat",
            json={
                "messages": [],
                "model": "test-model",
            },
        )

        assert response.status_code == 422
        # Validation must reject the request before the gateway is reached.
        mock_gateway.chat.assert_not_called()
|
|
|
|
|
|
class TestLLMGatewayChatStream:
    """POST /api/v1/llm/chat/stream — SSE streaming.

    ``mock_gateway.chat_stream`` is a plain MagicMock. Each test gives it a
    pre-built async generator via ``return_value``, and that generator is what
    the route receives regardless of the arguments it passes.
    """

    @staticmethod
    def _open_stream(client, model):
        """Start a streaming POST containing a single user message for ``model``.

        Returns the ``client.stream(...)`` context manager. Use it in a
        ``with`` statement so the connection is closed.
        """
        return client.stream(
            "POST",
            "/api/v1/llm/chat/stream",
            json={
                "messages": [{"role": "user", "content": "Hi"}],
                "model": model,
            },
        )

    @staticmethod
    def _sse_data_payloads(response):
        """Return the payload of every ``data: `` line in the SSE body, in order.

        Must be called inside the ``with`` block while the stream is still open.
        """
        prefix = "data: "
        return [line[len(prefix) :] for line in response.iter_lines() if line.startswith(prefix)]

    def _assert_stream_error(
        self, client, mock_gateway, *, exc, model, error_code, detail_fragment
    ):
        """Make the gateway stream raise ``exc`` and check the SSE error contract.

        Expected contract: HTTP 200, then exactly one JSON ``data`` event with
        ``error == error_code`` and ``detail_fragment`` in ``detail``, then the
        ``[DONE]`` terminator.
        """

        async def failing_stream():
            raise exc
            yield  # pragma: no cover - makes this an async generator

        mock_gateway.chat_stream.return_value = failing_stream()

        with self._open_stream(client, model) as response:
            assert response.status_code == 200
            data_lines = self._sse_data_payloads(response)

        mock_gateway.chat_stream.assert_called_once()
        # Exactly one error event, followed by the [DONE] terminator.
        assert len(data_lines) == 2, data_lines
        assert data_lines[-1] == "[DONE]"
        error_payload = json.loads(data_lines[0])
        assert error_payload["error"] == error_code
        assert detail_fragment in error_payload["detail"]

    def test_stream_returns_sse_format(self, client, mock_gateway):
        """Streaming chat returns SSE-formatted chunks terminated by [DONE]."""

        async def fake_stream():
            for chunk in _make_chunks():
                yield chunk

        mock_gateway.chat_stream.return_value = fake_stream()

        with self._open_stream(client, "test-model") as response:
            assert response.status_code == 200
            assert response.headers["content-type"].startswith("text/event-stream")
            assert response.headers["cache-control"] == "no-cache"
            assert response.headers["connection"] == "keep-alive"
            data_lines = self._sse_data_payloads(response)

        mock_gateway.chat_stream.assert_called_once()

        # Each data line is a JSON-serialized chunk, except the final [DONE].
        assert data_lines, "stream produced no SSE data lines"
        assert data_lines[-1] == "[DONE]"

        parsed = [json.loads(d) for d in data_lines[:-1]]
        assert len(parsed) == 3
        assert parsed[0]["content"] == "Hello"
        assert parsed[1]["content"] == " world"
        assert parsed[2]["is_final"] is True
        assert parsed[2]["usage"]["total_tokens"] == 30

    def test_stream_invalid_model_emits_error_then_done(self, client, mock_gateway):
        """ModelNotFoundError during stream emits error payload then [DONE]."""
        self._assert_stream_error(
            client,
            mock_gateway,
            exc=ModelNotFoundError("nonexistent/model"),
            model="nonexistent/model",
            error_code="model_not_found",
            detail_fragment="nonexistent/model",
        )

    def test_stream_provider_error_emits_error_then_done(self, client, mock_gateway):
        """LLMProviderError during stream emits error payload then [DONE]."""
        self._assert_stream_error(
            client,
            mock_gateway,
            exc=LLMProviderError("openai", "API error"),
            model="openai/gpt-4o",
            error_code="provider_error",
            detail_fragment="openai",
        )
|