154 lines
4.9 KiB
Python
154 lines
4.9 KiB
Python
import uuid
|
|
from datetime import datetime
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.api.deps import get_current_user
|
|
|
|
|
|
@pytest.fixture
def mock_query():
    """Return a mock query object.

    The stand-in is an ``AsyncMock`` whose fields are preconfigured through the
    mock constructor (equivalent to assigning each attribute afterwards), so
    route code that reads fields off a patched service result gets concrete
    values. The three timestamps are taken from the local, naive clock when
    the fixture runs.
    """
    return AsyncMock(
        id=uuid.UUID("22345678-1234-1234-1234-123456789abc"),
        user_id=uuid.UUID("12345678-1234-1234-1234-123456789abc"),
        keyword="test keyword",
        target_brand="TestBrand",
        brand_aliases=[],
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
        last_queried_at=None,
        next_query_at=datetime.now(),
        created_at=datetime.now(),
        updated_at=datetime.now(),
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_query_success(
    async_client, override_get_current_user, auth_headers, mock_query
):
    """POST /api/v1/queries/ answers 201 and echoes the created query's fields."""
    payload = {
        "keyword": "test keyword",
        "target_brand": "TestBrand",
        "platforms": ["wenxin", "kimi"],
        "frequency": "weekly",
    }

    # Patch the name as looked up in the route module so the route receives
    # the fixture record instead of calling the real service.
    with patch("app.api.queries.create_query", return_value=mock_query):
        response = await async_client.post(
            "/api/v1/queries/", json=payload, headers=auth_headers
        )

    assert response.status_code == 201
    body = response.json()
    assert body["keyword"] == payload["keyword"]
    assert body["target_brand"] == payload["target_brand"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_create_query_exceeds_limit(
    async_client, override_get_current_user, auth_headers
):
    """A PermissionError raised by the service must surface as HTTP 403.

    The response detail is expected to carry the exception's message.
    """
    limit_error = PermissionError("Query limit exceeded")
    request_body = {
        "keyword": "test keyword",
        "target_brand": "TestBrand",
        "platforms": ["wenxin"],
        "frequency": "daily",
    }

    with patch("app.api.queries.create_query", side_effect=limit_error):
        response = await async_client.post(
            "/api/v1/queries/", json=request_body, headers=auth_headers
        )

    assert response.status_code == 403
    error_payload = response.json()
    assert "Query limit exceeded" in error_payload["detail"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_list_queries(
    async_client, override_get_current_user, auth_headers, mock_query
):
    """GET /api/v1/queries/ returns a page containing the single mocked query."""
    # The patched service appears to return an (items, total) pair, judging by
    # the "items"/"total" keys asserted on the response below.
    page = ([mock_query], 1)
    with patch("app.api.queries.get_queries", return_value=page):
        response = await async_client.get("/api/v1/queries/", headers=auth_headers)

    assert response.status_code == 200
    listing = response.json()
    assert listing["total"] == 1
    items = listing["items"]
    assert len(items) == 1
    assert items[0]["keyword"] == "test keyword"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_update_query(
    async_client, override_get_current_user, auth_headers, mock_query
):
    """PUT /api/v1/queries/{id} returns the record produced by the update service.

    ``updated_query`` mirrors the ``mock_query`` fixture with the edited fields
    (keyword, platforms, frequency) applied.
    """
    updated_query = AsyncMock()
    updated_query.id = mock_query.id
    # user_id was previously left unset. On an AsyncMock an unset attribute
    # reads back as a child mock rather than a UUID, which would break response
    # serialization if the response schema exposes user_id (the mock_query
    # fixture sets it, suggesting the schema does).
    updated_query.user_id = mock_query.user_id
    updated_query.keyword = "updated keyword"
    updated_query.target_brand = "TestBrand"
    updated_query.brand_aliases = []
    updated_query.platforms = ["wenxin"]
    updated_query.frequency = "daily"
    updated_query.status = "active"
    updated_query.last_queried_at = None
    updated_query.next_query_at = datetime.now()
    updated_query.created_at = datetime.now()
    updated_query.updated_at = datetime.now()

    with patch("app.api.queries.update_query", return_value=updated_query):
        response = await async_client.put(
            f"/api/v1/queries/{mock_query.id}",
            headers=auth_headers,
            json={"keyword": "updated keyword", "frequency": "daily"},
        )

    assert response.status_code == 200
    data = response.json()
    assert data["keyword"] == "updated keyword"
    assert data["frequency"] == "daily"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_delete_query(
    async_client, override_get_current_user, auth_headers, mock_query
):
    """DELETE /api/v1/queries/{id} answers 204 when the service reports success."""
    target_url = f"/api/v1/queries/{mock_query.id}"

    with patch("app.api.queries.delete_query", return_value=True):
        response = await async_client.delete(target_url, headers=auth_headers)

    assert response.status_code == 204
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_query_not_found(
    async_client, override_get_current_user, auth_headers
):
    """GET on an id the service cannot resolve yields 404 with "Query not found"."""
    missing_id = uuid.UUID("33333333-3333-3333-3333-333333333333")

    # A None result from get_query is how the lookup signals "no such query".
    with patch("app.api.queries.get_query", return_value=None):
        response = await async_client.get(
            f"/api/v1/queries/{missing_id}", headers=auth_headers
        )

    assert response.status_code == 404
    assert "Query not found" in response.json()["detail"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_query_belongs_to_other_user(
    async_client, override_get_current_user, auth_headers
):
    """A query owned by another user must look exactly like a missing one.

    Ownership is simulated by having ``get_query`` return ``None``. Presumably
    the service scopes lookups to the current user; confirm against
    ``app.api.queries``. Both the 404 status and the "Query not found" detail
    used for genuinely missing ids are asserted, so the API does not reveal
    that the id exists under another account.

    NOTE(review): because ``get_query`` is patched, this does not exercise the
    real ownership filter. It only pins how the route responds to a ``None``
    lookup.
    """
    other_user_query_id = uuid.UUID("44444444-4444-4444-4444-444444444444")

    with patch("app.api.queries.get_query", return_value=None):
        response = await async_client.get(
            f"/api/v1/queries/{other_user_query_id}",
            headers=auth_headers,
        )

    assert response.status_code == 404
    # Same detail as the plain not-found case: foreign ids must not be
    # distinguishable from nonexistent ones.
    assert "Query not found" in response.json()["detail"]
|