"""Unit tests for LLM Gateway proxy routes (U1). Covers: - Non-streaming chat returns serialized LLMResponse JSON - Streaming chat returns SSE-formatted chunks - Invalid model returns 404 (ModelNotFoundError) - LLM provider failure returns 502 (LLMProviderError) - Empty messages list returns 422 (Pydantic validation) """ import json from unittest.mock import AsyncMock, MagicMock import pytest from fastapi import FastAPI from fastapi.testclient import TestClient from agentkit.core.exceptions import LLMProviderError, ModelNotFoundError from agentkit.llm.protocol import LLMResponse, StreamChunk, TokenUsage, ToolCall from agentkit.server.routes import llm_gateway as llm_gateway_module def _make_response( content: str = "Hello from LLM", model: str = "test-model", prompt_tokens: int = 10, completion_tokens: int = 20, tool_calls: list[ToolCall] | None = None, latency_ms: float = 123.4, ) -> LLMResponse: return LLMResponse( content=content, model=model, usage=TokenUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens), tool_calls=tool_calls or [], latency_ms=latency_ms, ) def _make_chunks() -> list[StreamChunk]: return [ StreamChunk(content="Hello", model="test-model"), StreamChunk(content=" world", model="test-model"), StreamChunk( content="", model="test-model", usage=TokenUsage(prompt_tokens=10, completion_tokens=20), is_final=True, ), ] @pytest.fixture def mock_gateway(): """A mock LLMGateway. `chat` is an AsyncMock (returns awaitable). `chat_stream` is a regular MagicMock that returns an async generator when called (matching the real gateway's async-generator behavior). """ gateway = MagicMock() gateway.chat = AsyncMock() # chat_stream must return an async generator when called, not a coroutine. gateway.chat_stream = MagicMock() return gateway @pytest.fixture def app(mock_gateway): application = FastAPI() application.state.llm_gateway = mock_gateway application.include_router(llm_gateway_module.router, prefix="/api/v1") return application @pytest.fixture def client(app): return TestClient(app) class TestLLMGatewayChatNonStreaming: """POST /api/v1/llm/chat — non-streaming""" def test_chat_returns_llm_response_json(self, client, mock_gateway): """Non-streaming chat returns serialized LLMResponse JSON.""" mock_gateway.chat.return_value = _make_response( content="Hello from LLM", model="test-model", prompt_tokens=10, completion_tokens=20, tool_calls=[ToolCall(id="tc_1", name="search", arguments={"q": "test"})], latency_ms=123.4, ) response = client.post( "/api/v1/llm/chat", json={ "messages": [{"role": "user", "content": "Hi"}], "model": "test-model", }, ) assert response.status_code == 200 data = response.json() assert data["content"] == "Hello from LLM" assert data["model"] == "test-model" assert data["usage"]["prompt_tokens"] == 10 assert data["usage"]["completion_tokens"] == 20 assert data["usage"]["total_tokens"] == 30 assert data["latency_ms"] == 123.4 assert len(data["tool_calls"]) == 1 assert data["tool_calls"][0]["id"] == "tc_1" assert data["tool_calls"][0]["name"] == "search" assert data["tool_calls"][0]["arguments"] == {"q": "test"} # Verify gateway.chat was called with the right args mock_gateway.chat.assert_awaited_once() call_kwargs = mock_gateway.chat.await_args.kwargs assert call_kwargs["model"] == "test-model" assert call_kwargs["messages"] == [{"role": "user", "content": "Hi"}] def test_chat_invalid_model_returns_404(self, client, mock_gateway): """Invalid model returns 404 (ModelNotFoundError).""" mock_gateway.chat.side_effect = ModelNotFoundError("nonexistent/model") response = client.post( "/api/v1/llm/chat", json={ "messages": [{"role": "user", "content": "Hi"}], "model": "nonexistent/model", }, ) assert response.status_code == 404 assert "nonexistent/model" in response.json()["detail"] def test_chat_provider_error_returns_502(self, client, mock_gateway): """LLM Provider failure returns 502 (LLMProviderError).""" mock_gateway.chat.side_effect = LLMProviderError("openai", "API error") response = client.post( "/api/v1/llm/chat", json={ "messages": [{"role": "user", "content": "Hi"}], "model": "openai/gpt-4o", }, ) assert response.status_code == 502 assert "openai" in response.json()["detail"] def test_chat_empty_messages_returns_422(self, client): """Empty messages list returns 422 (Pydantic validation).""" response = client.post( "/api/v1/llm/chat", json={ "messages": [], "model": "test-model", }, ) assert response.status_code == 422 class TestLLMGatewayChatStream: """POST /api/v1/llm/chat/stream — SSE streaming""" def test_stream_returns_sse_format(self, client, mock_gateway): """Streaming chat returns SSE-formatted chunks terminated by [DONE].""" async def fake_stream(**_kwargs): for chunk in _make_chunks(): yield chunk mock_gateway.chat_stream.return_value = fake_stream() with client.stream( "POST", "/api/v1/llm/chat/stream", json={ "messages": [{"role": "user", "content": "Hi"}], "model": "test-model", }, ) as response: assert response.status_code == 200 assert response.headers["content-type"].startswith("text/event-stream") assert response.headers["cache-control"] == "no-cache" assert response.headers["connection"] == "keep-alive" chunks_text = "" for line in response.iter_lines(): chunks_text += line + "\n" # Each data line should be valid JSON, ending with [DONE] data_lines = [ line[len("data: ") :] for line in chunks_text.split("\n") if line.startswith("data: ") ] assert data_lines[-1] == "[DONE]" parsed = [json.loads(d) for d in data_lines[:-1]] assert len(parsed) == 3 assert parsed[0]["content"] == "Hello" assert parsed[1]["content"] == " world" assert parsed[2]["is_final"] is True assert parsed[2]["usage"]["total_tokens"] == 30 def test_stream_invalid_model_emits_error_then_done(self, client, mock_gateway): """ModelNotFoundError during stream emits error payload then [DONE].""" async def failing_stream(**_kwargs): raise ModelNotFoundError("nonexistent/model") yield # pragma: no cover - makes this an async generator mock_gateway.chat_stream.return_value = failing_stream() with client.stream( "POST", "/api/v1/llm/chat/stream", json={ "messages": [{"role": "user", "content": "Hi"}], "model": "nonexistent/model", }, ) as response: assert response.status_code == 200 chunks_text = "" for line in response.iter_lines(): chunks_text += line + "\n" data_lines = [ line[len("data: ") :] for line in chunks_text.split("\n") if line.startswith("data: ") ] assert data_lines[-1] == "[DONE]" error_payload = json.loads(data_lines[0]) assert error_payload["error"] == "model_not_found" assert "nonexistent/model" in error_payload["detail"] def test_stream_provider_error_emits_error_then_done(self, client, mock_gateway): """LLMProviderError during stream emits error payload then [DONE].""" async def failing_stream(**_kwargs): raise LLMProviderError("openai", "API error") yield # pragma: no cover - makes this an async generator mock_gateway.chat_stream.return_value = failing_stream() with client.stream( "POST", "/api/v1/llm/chat/stream", json={ "messages": [{"role": "user", "content": "Hi"}], "model": "openai/gpt-4o", }, ) as response: assert response.status_code == 200 chunks_text = "" for line in response.iter_lines(): chunks_text += line + "\n" data_lines = [ line[len("data: ") :] for line in chunks_text.split("\n") if line.startswith("data: ") ] assert data_lines[-1] == "[DONE]" error_payload = json.loads(data_lines[0]) assert error_payload["error"] == "provider_error" assert "openai" in error_payload["detail"]