141 lines
5.1 KiB
Python
141 lines
5.1 KiB
Python
"""Tests for Chat API routes."""
|
|
|
|
import tempfile
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from fastapi.testclient import TestClient
|
|
|
|
from agentkit.session.manager import SessionManager
|
|
from agentkit.session.store import InMemorySessionStore
|
|
from agentkit.session.models import SessionStatus
|
|
|
|
|
|
@pytest.fixture
def app_with_chat():
    """Build a FastAPI app exposing the Chat router under ``/api/v1``.

    The session manager is a real ``SessionManager`` over an
    ``InMemorySessionStore``, so session endpoints operate on genuine
    in-process state. The LLM gateway, agent pool and server config are
    ``MagicMock`` stand-ins. ``server_config.api_key`` is explicitly ``None``
    (presumably so the routes skip API-key auth — confirm against the
    router's auth dependency).
    """
    from fastapi import FastAPI

    from agentkit.server.routes.chat import router

    application = FastAPI()
    application.include_router(router, prefix="/api/v1")

    state = application.state
    state.session_manager = SessionManager(store=InMemorySessionStore())
    # Mocked collaborators; none of the visible tests configure their behaviour.
    state.llm_gateway = MagicMock()
    state.agent_pool = MagicMock()
    state.server_config = MagicMock(api_key=None)
    return application
|
|
|
|
|
|
@pytest.fixture
def client(app_with_chat):
    """Provide a ``TestClient`` wired to the ``app_with_chat`` application.

    The client is not entered as a context manager, so the app's
    startup/shutdown lifespan handlers are not run.
    """
    test_client = TestClient(app_with_chat)
    return test_client
|
|
|
|
|
|
class TestChatSessionCRUD:
|
|
def test_create_session(self, client):
|
|
resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
|
|
assert resp.status_code == 200
|
|
data = resp.json()
|
|
assert data["agent_name"] == "test-agent"
|
|
assert data["status"] == "active"
|
|
assert "session_id" in data
|
|
|
|
def test_get_session(self, client):
|
|
create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
|
|
session_id = create_resp.json()["session_id"]
|
|
|
|
get_resp = client.get(f"/api/v1/chat/sessions/{session_id}")
|
|
assert get_resp.status_code == 200
|
|
assert get_resp.json()["session_id"] == session_id
|
|
|
|
def test_get_nonexistent_session(self, client):
|
|
resp = client.get("/api/v1/chat/sessions/nonexistent")
|
|
assert resp.status_code == 404
|
|
|
|
def test_close_session(self, client):
|
|
create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
|
|
session_id = create_resp.json()["session_id"]
|
|
|
|
close_resp = client.delete(f"/api/v1/chat/sessions/{session_id}")
|
|
assert close_resp.status_code == 200
|
|
assert close_resp.json()["status"] == "closed"
|
|
|
|
def test_close_nonexistent_session(self, client):
|
|
resp = client.delete("/api/v1/chat/sessions/nonexistent")
|
|
assert resp.status_code == 404
|
|
|
|
def test_get_messages_empty(self, client):
|
|
create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
|
|
session_id = create_resp.json()["session_id"]
|
|
|
|
msgs_resp = client.get(f"/api/v1/chat/sessions/{session_id}/messages")
|
|
assert msgs_resp.status_code == 200
|
|
assert msgs_resp.json() == []
|
|
|
|
def test_get_messages_nonexistent_session(self, client):
|
|
resp = client.get("/api/v1/chat/sessions/nonexistent/messages")
|
|
assert resp.status_code == 404
|
|
|
|
def test_send_message_closed_session(self, client):
|
|
create_resp = client.post("/api/v1/chat/sessions", json={"agent_name": "test-agent"})
|
|
session_id = create_resp.json()["session_id"]
|
|
|
|
client.delete(f"/api/v1/chat/sessions/{session_id}")
|
|
|
|
msg_resp = client.post(
|
|
f"/api/v1/chat/sessions/{session_id}/messages",
|
|
json={"content": "Hello"},
|
|
)
|
|
assert msg_resp.status_code == 400
|
|
|
|
def test_send_message_nonexistent_session(self, client):
|
|
resp = client.post(
|
|
"/api/v1/chat/sessions/nonexistent/messages",
|
|
json={"content": "Hello"},
|
|
)
|
|
assert resp.status_code == 404
|
|
|
|
|
|
class TestChatFileUpload:
    """Upload/download endpoints, each test pointed at a throwaway ``UPLOAD_DIR``."""

    def test_upload_and_download_file(self, client, monkeypatch):
        """A small upload is described correctly and downloads back unchanged."""
        from agentkit.server.routes import chat as chat_module

        with tempfile.TemporaryDirectory() as tmpdir:
            monkeypatch.setattr(chat_module, "UPLOAD_DIR", Path(tmpdir))
            resp = client.post(
                "/api/v1/chat/upload",
                files={"file": ("hello.txt", b"hello world", "text/plain")},
            )
            assert resp.status_code == 200
            data = resp.json()
            assert data["filename"] == "hello.txt"
            assert data["content_type"] == "text/plain"
            assert data["size"] == 11
            assert "/api/v1/chat/uploads/" in data["download_url"]

            stored_name = data["stored_name"]
            download = client.get(f"/api/v1/chat/uploads/{stored_name}")
            assert download.status_code == 200
            assert download.content == b"hello world"

    def test_upload_file_exceeds_size_limit(self, client, monkeypatch):
        """An upload one byte over ``MAX_UPLOAD_SIZE`` is rejected with 413."""
        from agentkit.server.routes import chat as chat_module

        with tempfile.TemporaryDirectory() as tmpdir:
            monkeypatch.setattr(chat_module, "UPLOAD_DIR", Path(tmpdir))
            large_content = b"x" * (chat_module.MAX_UPLOAD_SIZE + 1)
            resp = client.post(
                "/api/v1/chat/upload",
                files={"file": ("big.bin", large_content, "application/octet-stream")},
            )
            assert resp.status_code == 413

    def test_download_missing_file(self, client, monkeypatch):
        """Requesting a stored name that does not exist returns 404."""
        from agentkit.server.routes import chat as chat_module

        # Use an empty temp dir, like the tests above. Otherwise the result
        # depends on whatever the real upload directory happens to contain.
        with tempfile.TemporaryDirectory() as tmpdir:
            monkeypatch.setattr(chat_module, "UPLOAD_DIR", Path(tmpdir))
            resp = client.get("/api/v1/chat/uploads/notfound.bin")
            assert resp.status_code == 404
|