fischer-agentkit/tests/unit/test_llm_gateway_routes.py

266 lines
9.1 KiB
Python

"""Unit tests for LLM Gateway proxy routes (U1).
Covers:
- Non-streaming chat returns serialized LLMResponse JSON
- Streaming chat returns SSE-formatted chunks terminated by [DONE]
- Errors raised while streaming are emitted as an SSE error payload, then [DONE]
- Invalid model returns 404 (ModelNotFoundError)
- LLM provider failure returns 502 (LLMProviderError)
- Empty messages list returns 422 (Pydantic validation)
"""
import json
from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError
from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage, ToolCall
from agentkit.server.routes import llm_gateway as llm_gateway_module
def _make_response(
    content: str = "Hello from LLM",
    model: str = "test-model",
    prompt_tokens: int = 10,
    completion_tokens: int = 20,
    tool_calls: list[ToolCall] | None = None,
    latency_ms: float = 123.4,
) -> LLMResponse:
    """Build an ``LLMResponse`` for tests, with overridable defaults.

    Every argument is forwarded to the matching ``LLMResponse`` field; the two
    token counts are wrapped in a ``TokenUsage``. A ``None`` (or empty)
    ``tool_calls`` becomes a new empty list.
    """
    token_usage = TokenUsage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
    # Falsy input (None or an empty list) is replaced by a fresh list.
    calls = tool_calls if tool_calls else []
    return LLMResponse(
        content=content,
        model=model,
        usage=token_usage,
        tool_calls=calls,
        latency_ms=latency_ms,
    )
def _make_chunks() -> list[StreamChunk]:
    """Return the three-chunk stream used by the streaming tests.

    Two content chunks ("Hello", " world") are followed by an empty final
    chunk that carries the token usage and has ``is_final=True``.
    """
    model_name = "test-model"
    text_chunks = [
        StreamChunk(content=text, model=model_name) for text in ("Hello", " world")
    ]
    closing_chunk = StreamChunk(
        content="",
        model=model_name,
        usage=TokenUsage(prompt_tokens=10, completion_tokens=20),
        is_final=True,
    )
    return [*text_chunks, closing_chunk]
@pytest.fixture
def mock_gateway():
    """Provide a stand-in for the LLM gateway.

    ``chat`` is an ``AsyncMock``, so calling it yields an awaitable.
    ``chat_stream`` is a plain ``MagicMock``: calling it returns its
    ``return_value`` directly instead of a coroutine. The streaming tests in
    this module assign an async-generator object to that ``return_value``.
    """
    chat = AsyncMock()
    # Deliberately not an AsyncMock: the call itself must not need awaiting.
    chat_stream = MagicMock()

    gateway = MagicMock()
    gateway.chat = chat
    gateway.chat_stream = chat_stream
    return gateway
@pytest.fixture
def app(mock_gateway):
    """Create a FastAPI app with the gateway router mounted under ``/api/v1``.

    The mocked gateway is stored on ``app.state.llm_gateway``.
    """
    api = FastAPI()
    api.state.llm_gateway = mock_gateway
    gateway_router = llm_gateway_module.router
    api.include_router(gateway_router, prefix="/api/v1")
    return api
@pytest.fixture
def client(app):
    """Wrap the ``app`` fixture in a ``TestClient``."""
    test_client = TestClient(app)
    return test_client
class TestLLMGatewayChatNonStreaming:
    """POST /api/v1/llm/chat — non-streaming"""

    def test_chat_returns_llm_response_json(self, client, mock_gateway):
        """A successful chat is returned as the JSON form of the LLMResponse."""
        search_call = ToolCall(id="tc_1", name="search", arguments={"q": "test"})
        mock_gateway.chat.return_value = _make_response(
            content="Hello from LLM",
            model="test-model",
            prompt_tokens=10,
            completion_tokens=20,
            tool_calls=[search_call],
            latency_ms=123.4,
        )

        request_body = {
            "messages": [{"role": "user", "content": "Hi"}],
            "model": "test-model",
        }
        response = client.post("/api/v1/llm/chat", json=request_body)

        assert response.status_code == 200
        body = response.json()
        assert body["content"] == "Hello from LLM"
        assert body["model"] == "test-model"

        usage = body["usage"]
        assert usage["prompt_tokens"] == 10
        assert usage["completion_tokens"] == 20
        assert usage["total_tokens"] == 30
        assert body["latency_ms"] == 123.4

        returned_calls = body["tool_calls"]
        assert len(returned_calls) == 1
        first_call = returned_calls[0]
        assert first_call["id"] == "tc_1"
        assert first_call["name"] == "search"
        assert first_call["arguments"] == {"q": "test"}

        # The route must forward the model and messages to gateway.chat.
        mock_gateway.chat.assert_awaited_once()
        forwarded = mock_gateway.chat.await_args.kwargs
        assert forwarded["model"] == "test-model"
        assert forwarded["messages"] == [{"role": "user", "content": "Hi"}]

    def test_chat_invalid_model_returns_404(self, client, mock_gateway):
        """ModelNotFoundError from the gateway maps to HTTP 404."""
        mock_gateway.chat.side_effect = ModelNotFoundError("nonexistent/model")

        request_body = {
            "messages": [{"role": "user", "content": "Hi"}],
            "model": "nonexistent/model",
        }
        response = client.post("/api/v1/llm/chat", json=request_body)

        assert response.status_code == 404
        detail = response.json()["detail"]
        assert "nonexistent/model" in detail

    def test_chat_provider_error_returns_502(self, client, mock_gateway):
        """LLMProviderError from the gateway maps to HTTP 502."""
        mock_gateway.chat.side_effect = LLMProviderError("openai", "API error")

        request_body = {
            "messages": [{"role": "user", "content": "Hi"}],
            "model": "openai/gpt-4o",
        }
        response = client.post("/api/v1/llm/chat", json=request_body)

        assert response.status_code == 502
        detail = response.json()["detail"]
        assert "openai" in detail

    def test_chat_empty_messages_returns_422(self, client):
        """An empty messages list is rejected with HTTP 422 (Pydantic validation)."""
        request_body = {"messages": [], "model": "test-model"}
        response = client.post("/api/v1/llm/chat", json=request_body)

        assert response.status_code == 422
class TestLLMGatewayChatStream:
    """POST /api/v1/llm/chat/stream — SSE streaming"""

    @staticmethod
    def _read_sse_data(response) -> list[str]:
        """Drain a streamed response and return its ``data: `` payloads.

        Args:
            response: An open streaming response exposing ``iter_lines()``
                that yields text lines.

        Returns:
            The text following the ``data: `` prefix of every line that has
            that prefix, in arrival order. Other lines are ignored.
        """
        prefix = "data: "
        # Filter the lines as they arrive rather than joining them into one
        # string and splitting that string again.
        return [
            line.removeprefix(prefix)
            for line in response.iter_lines()
            if line.startswith(prefix)
        ]

    def test_stream_returns_sse_format(self, client, mock_gateway):
        """Streaming chat returns SSE-formatted chunks terminated by [DONE]."""

        async def fake_stream(**_kwargs):
            for chunk in _make_chunks():
                yield chunk

        mock_gateway.chat_stream.return_value = fake_stream()

        with client.stream(
            "POST",
            "/api/v1/llm/chat/stream",
            json={
                "messages": [{"role": "user", "content": "Hi"}],
                "model": "test-model",
            },
        ) as response:
            assert response.status_code == 200
            assert response.headers["content-type"].startswith("text/event-stream")
            assert response.headers["cache-control"] == "no-cache"
            assert response.headers["connection"] == "keep-alive"
            # The body must be consumed while the stream is still open.
            data_lines = self._read_sse_data(response)

        # Every data line except the trailing [DONE] sentinel is JSON.
        assert data_lines[-1] == "[DONE]"
        parsed = [json.loads(d) for d in data_lines[:-1]]
        assert len(parsed) == 3
        assert parsed[0]["content"] == "Hello"
        assert parsed[1]["content"] == " world"
        assert parsed[2]["is_final"] is True
        assert parsed[2]["usage"]["total_tokens"] == 30

    def test_stream_invalid_model_emits_error_then_done(self, client, mock_gateway):
        """ModelNotFoundError during stream emits error payload then [DONE]."""

        async def failing_stream(**_kwargs):
            raise ModelNotFoundError("nonexistent/model")
            yield  # pragma: no cover - makes this an async generator

        mock_gateway.chat_stream.return_value = failing_stream()

        with client.stream(
            "POST",
            "/api/v1/llm/chat/stream",
            json={
                "messages": [{"role": "user", "content": "Hi"}],
                "model": "nonexistent/model",
            },
        ) as response:
            # The failure is reported inside the stream, so the status is 200.
            assert response.status_code == 200
            data_lines = self._read_sse_data(response)

        assert data_lines[-1] == "[DONE]"
        error_payload = json.loads(data_lines[0])
        assert error_payload["error"] == "model_not_found"
        assert "nonexistent/model" in error_payload["detail"]

    def test_stream_provider_error_emits_error_then_done(self, client, mock_gateway):
        """LLMProviderError during stream emits error payload then [DONE]."""

        async def failing_stream(**_kwargs):
            raise LLMProviderError("openai", "API error")
            yield  # pragma: no cover - makes this an async generator

        mock_gateway.chat_stream.return_value = failing_stream()

        with client.stream(
            "POST",
            "/api/v1/llm/chat/stream",
            json={
                "messages": [{"role": "user", "content": "Hi"}],
                "model": "openai/gpt-4o",
            },
        ) as response:
            # The failure is reported inside the stream, so the status is 200.
            assert response.status_code == 200
            data_lines = self._read_sse_data(response)

        assert data_lines[-1] == "[DONE]"
        error_payload = json.loads(data_lines[0])
        assert error_payload["error"] == "provider_error"
        assert "openai" in error_payload["detail"]