# geo/backend/tests/test_api/test_alert_settings_api.py
#
# 331 lines
# 10 KiB
# Python

"""Tests for the alert settings API."""
import uuid
import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.main import app
from app.models.user import User
from app.models.brand import Brand
from app.models.alert_setting import AlertSetting
from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password, create_access_token
from tests.fixtures.auth import _to_uuid
# ==================== Fixtures ====================
@pytest_asyncio.fixture
async def async_engine():
    """Provide an in-memory SQLite async engine with all tables created.

    The engine is disposed once the consuming test has finished.
    """
    engine_options = {
        "connect_args": {"check_same_thread": False},
        # A single shared connection keeps the in-memory database alive
        # for every session opened against this engine.
        "poolclass": StaticPool,
    }
    test_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:", **engine_options
    )

    # Create the schema declared on Base before handing the engine out.
    async with test_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield test_engine

    await test_engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an AsyncSession bound to the test engine.

    The session is closed when the consuming test finishes.
    """
    session_options = {
        "class_": AsyncSession,
        # Keep attribute values loaded after commit so fixtures can
        # return ORM objects that tests read without a new query.
        "expire_on_commit": False,
        "autoflush": False,
        "autocommit": False,
    }
    make_session = async_sessionmaker(async_engine, **session_options)

    async with make_session() as db_session:
        yield db_session
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return a User row for use by the tests.

    The user is created with a string UUID id and a password produced by
    ``hash_password``.
    """
    user_fields = {
        "id": str(uuid.uuid4()),
        "email": "test@example.com",
        "password": hash_password("Test@123456"),
        "firstName": "Test User",
        "plan": "free",
        "max_queries": 5,
        "isActive": True,
        "emailVerified": True,
    }
    created_user = User(**user_fields)

    async_session.add(created_user)
    await async_session.commit()
    # Reload the row after commit before returning it.
    await async_session.refresh(created_user)
    return created_user
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist and return a Brand row owned by ``test_user``."""
    # NOTE(review): _to_uuid presumably converts the user's string id into
    # a UUID object matching Brand.user_id -- confirm in tests.fixtures.auth.
    owner_id = _to_uuid(test_user.id)

    brand_fields = {
        "id": uuid.uuid4(),
        "user_id": owner_id,
        "name": "Test Brand",
        "aliases": ["TestBrand", "TB"],
        "website": "https://testbrand.com",
        "industry": "technology",
        "platforms": ["wenxin", "kimi"],
        "frequency": "weekly",
        "status": "active",
    }
    created_brand = Brand(**brand_fields)

    async_session.add(created_brand)
    await async_session.commit()
    await async_session.refresh(created_brand)
    return created_brand
@pytest_asyncio.fixture
async def test_alert_setting(async_session, test_user, test_brand):
    """Persist and return an enabled ``score_drop`` AlertSetting.

    The setting is attached to ``test_brand`` and ``test_user`` and uses a
    threshold of 5.0.
    """
    setting_fields = {
        "id": uuid.uuid4(),
        "brand_id": test_brand.id,
        "user_id": _to_uuid(test_user.id),
        "alert_type": "score_drop",
        "enabled": True,
        "threshold": 5.0,
    }
    created_setting = AlertSetting(**setting_fields)

    async_session.add(created_setting)
    await async_session.commit()
    await async_session.refresh(created_setting)
    return created_setting
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an httpx AsyncClient wired to the FastAPI app for API tests.

    The app's ``get_db`` and ``get_current_user`` dependencies are
    overridden so that requests use the test session and are treated as
    coming from ``test_user``. The overrides are always cleared on
    teardown, even if creating or closing the client raises, so they
    cannot leak into later tests through the shared global ``app``.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # Guaranteed cleanup: the app object is module-global state.
        app.dependency_overrides.clear()
@pytest.fixture
def auth_headers(test_user):
    """Return an Authorization header carrying a bearer token for ``test_user``."""
    subject = str(test_user.id)
    access_token = create_access_token(data={"sub": subject})
    return {"Authorization": f"Bearer {access_token}"}
# ==================== Test classes ====================
class TestAlertSettingsAPI:
    """Tests for listing, updating, creating and deleting alert settings."""

    @pytest.mark.asyncio
    async def test_get_alert_settings_success(self, async_client, test_alert_setting):
        """Listing settings returns an items/total payload with the expected fields."""
        resp = await async_client.get("/api/v1/alerts/settings")
        assert resp.status_code == 200

        payload = resp.json()
        assert "items" in payload
        assert "total" in payload
        assert payload["total"] >= 1
        assert len(payload["items"]) >= 1

        # Every listed setting is expected to expose these keys.
        first_entry = payload["items"][0]
        for field_name in ("id", "brand_id", "alert_type", "enabled", "threshold"):
            assert field_name in first_entry

    @pytest.mark.asyncio
    async def test_get_alert_settings_by_brand(self, async_client, test_brand, test_alert_setting):
        """Filtering by brand_id only returns settings of that brand."""
        resp = await async_client.get(
            f"/api/v1/alerts/settings?brand_id={test_brand.id}"
        )
        assert resp.status_code == 200

        payload = resp.json()
        assert payload["total"] >= 1

        expected_brand_id = str(test_brand.id)
        for entry in payload["items"]:
            assert entry["brand_id"] == expected_brand_id

    @pytest.mark.asyncio
    async def test_update_alert_settings_success(self, async_client, test_brand):
        """A PUT with a settings list succeeds and returns the new values."""
        desired_setting = {
            "brand_id": str(test_brand.id),
            "alert_type": "score_drop",
            "enabled": True,
            "threshold": 20.0,
        }
        resp = await async_client.put(
            "/api/v1/alerts/settings", json={"settings": [desired_setting]}
        )
        assert resp.status_code == 200

        payload = resp.json()
        assert "items" in payload
        assert len(payload["items"]) >= 1

        returned = payload["items"][0]
        assert returned["alert_type"] == "score_drop"
        assert returned["threshold"] == 20.0
        assert returned["enabled"] is True

    @pytest.mark.asyncio
    async def test_create_alert_setting_success(self, async_client, test_brand):
        """A POST for an existing brand creates the setting and returns 201."""
        new_setting = {
            "brand_id": str(test_brand.id),
            "alert_type": "negative_sentiment",
            "enabled": True,
            "threshold": 1.0,
        }
        resp = await async_client.post("/api/v1/alerts/settings", json=new_setting)
        assert resp.status_code == 201

        created = resp.json()
        assert created["alert_type"] == "negative_sentiment"
        assert created["threshold"] == 1.0
        assert created["enabled"] is True
        assert "id" in created

    @pytest.mark.asyncio
    async def test_delete_alert_setting_success(self, async_client, test_alert_setting):
        """Deleting an existing setting returns 204 and removes it from the listing."""
        delete_resp = await async_client.delete(
            f"/api/v1/alerts/settings/{test_alert_setting.id}"
        )
        assert delete_resp.status_code == 204

        # The deleted id must no longer appear in a fresh listing.
        list_resp = await async_client.get("/api/v1/alerts/settings")
        listing = list_resp.json()
        remaining_ids = [entry["id"] for entry in listing["items"]]
        assert str(test_alert_setting.id) not in remaining_ids

    @pytest.mark.asyncio
    async def test_delete_alert_setting_not_found(self, async_client):
        """Deleting an unknown setting id returns 404."""
        missing_id = uuid.uuid4()
        resp = await async_client.delete(f"/api/v1/alerts/settings/{missing_id}")
        assert resp.status_code == 404
class TestAlertSettingsValidation:
    """Validation and error-path tests for the alert settings endpoints."""

    @pytest.mark.asyncio
    async def test_brand_not_found_returns_404(self, async_client):
        """Creating a setting for an unknown brand returns 404."""
        non_existent_brand_id = uuid.uuid4()
        create_data = {
            "brand_id": str(non_existent_brand_id),
            "alert_type": "score_drop",
            "enabled": True,
            "threshold": 5.0,
        }
        response = await async_client.post(
            "/api/v1/alerts/settings", json=create_data
        )
        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_unauthorized_returns_401(self, async_session):
        """A request with an invalid bearer token returns 401.

        Only ``get_db`` is overridden here; ``get_current_user`` is left
        untouched so the app's own authentication handles the token.
        """

        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        try:
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                headers = {"Authorization": "Bearer invalid_token"}
                response = await client.get(
                    "/api/v1/alerts/settings",
                    headers=headers,
                )
            assert response.status_code == 401
        finally:
            # Clear the override even when the assertion fails; otherwise
            # it would stay registered on the shared global app.
            app.dependency_overrides.clear()

    @pytest.mark.asyncio
    async def test_invalid_alert_type_returns_422(self, async_client, test_brand):
        """An unrecognised alert_type returns 422."""
        create_data = {
            "brand_id": str(test_brand.id),
            "alert_type": "invalid_type",
            "enabled": True,
            "threshold": 5.0,
        }
        response = await async_client.post(
            "/api/v1/alerts/settings", json=create_data
        )
        assert response.status_code == 422

    @pytest.mark.asyncio
    async def test_negative_threshold_returns_422(self, async_client, test_brand):
        """A negative threshold returns 422."""
        create_data = {
            "brand_id": str(test_brand.id),
            "alert_type": "score_drop",
            "enabled": True,
            "threshold": -10.0,
        }
        response = await async_client.post(
            "/api/v1/alerts/settings", json=create_data
        )
        assert response.status_code == 422

    @pytest.mark.asyncio
    async def test_threshold_over_100_returns_422(self, async_client, test_brand):
        """A threshold above 100 returns 422."""
        create_data = {
            "brand_id": str(test_brand.id),
            "alert_type": "score_drop",
            "enabled": True,
            "threshold": 150.0,
        }
        response = await async_client.post(
            "/api/v1/alerts/settings", json=create_data
        )
        assert response.status_code == 422