"""Alert settings API tests.

Exercises the ``/api/v1/alerts/settings`` endpoints against an in-memory
SQLite database. Database access and the current user are injected through
the app's ``dependency_overrides``, so most tests bypass real authentication.
"""
import uuid

import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
from sqlalchemy.pool import StaticPool

from app.database import Base
from app.main import app
from app.models.user import User
from app.models.brand import Brand
from app.models.alert_setting import AlertSetting
from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password, create_access_token


# Base path of the alert settings endpoints under test.
SETTINGS_URL = "/api/v1/alerts/settings"


# ==================== Fixtures ====================


@pytest_asyncio.fixture
async def async_engine():
    """Create an async in-memory SQLite engine with all tables created.

    ``StaticPool`` reuses a single connection so the ``:memory:`` database
    persists across every session created from this engine. The engine is
    disposed on teardown even if schema creation fails.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        await engine.dispose()


@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the test engine.

    ``expire_on_commit=False`` keeps ORM attributes loaded after commit, so
    fixtures can return committed objects whose fields tests read directly.
    """
    session_factory = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )
    async with session_factory() as session:
        yield session


@pytest_asyncio.fixture
async def test_user(async_session):
    """Create and persist an active, email-verified free-plan test user."""
    user = User(
        id=uuid.uuid4(),
        email="test@example.com",
        password_hash=hash_password("Test@123456"),
        name="Test User",
        plan="free",
        max_queries=5,
        is_active=True,
        email_verified=True,
    )
    async_session.add(user)
    await async_session.commit()
    await async_session.refresh(user)
    return user


@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Create and persist an active brand owned by ``test_user``."""
    brand = Brand(
        id=uuid.uuid4(),
        user_id=test_user.id,
        name="Test Brand",
        aliases=["TestBrand", "TB"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )
    async_session.add(brand)
    await async_session.commit()
    await async_session.refresh(brand)
    return brand


@pytest_asyncio.fixture
async def test_alert_setting(async_session, test_user, test_brand):
    """Create and persist an enabled ``score_drop`` setting (threshold 5.0)."""
    setting = AlertSetting(
        id=uuid.uuid4(),
        brand_id=test_brand.id,
        user_id=test_user.id,
        alert_type="score_drop",
        enabled=True,
        threshold=5.0,
    )
    async_session.add(setting)
    await async_session.commit()
    await async_session.refresh(setting)
    return setting


@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an HTTP client wired to the app with DB and auth overridden.

    ``get_db`` yields the shared test session and ``get_current_user``
    returns ``test_user``, so requests need no token. The overrides are
    cleared on teardown even if closing the client raises.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        app.dependency_overrides.clear()


@pytest.fixture
def auth_headers(test_user):
    """Build a bearer ``Authorization`` header for ``test_user``.

    Not requested by the tests in this module, which authenticate through
    the ``get_current_user`` override instead.
    """
    token = create_access_token(data={"sub": str(test_user.id)})
    return {"Authorization": f"Bearer {token}"}


# ==================== Test classes ====================


class TestAlertSettingsAPI:
    """Happy-path and not-found tests for the alert settings endpoints."""

    @pytest.mark.asyncio
    async def test_get_alert_settings_success(self, async_client, test_alert_setting):
        """Listing settings returns ``items``/``total`` with the expected fields."""
        response = await async_client.get(SETTINGS_URL)

        assert response.status_code == 200
        data = response.json()
        assert "items" in data
        assert "total" in data
        assert data["total"] >= 1
        assert len(data["items"]) >= 1

        first_item = data["items"][0]
        for field in ("id", "brand_id", "alert_type", "enabled", "threshold"):
            assert field in first_item

    @pytest.mark.asyncio
    async def test_get_alert_settings_by_brand(self, async_client, test_brand, test_alert_setting):
        """Filtering by ``brand_id`` returns only that brand's settings."""
        response = await async_client.get(f"{SETTINGS_URL}?brand_id={test_brand.id}")

        assert response.status_code == 200
        data = response.json()
        assert data["total"] >= 1
        for item in data["items"]:
            assert item["brand_id"] == str(test_brand.id)

    @pytest.mark.asyncio
    async def test_update_alert_settings_success(self, async_client, test_brand):
        """Updating settings via PUT returns the new values."""
        update_data = {
            "settings": [
                {
                    "brand_id": str(test_brand.id),
                    "alert_type": "score_drop",
                    "enabled": True,
                    "threshold": 20.0,
                }
            ]
        }

        response = await async_client.put(SETTINGS_URL, json=update_data)

        assert response.status_code == 200
        data = response.json()
        assert "items" in data
        assert len(data["items"]) >= 1

        updated_setting = data["items"][0]
        assert updated_setting["alert_type"] == "score_drop"
        assert updated_setting["threshold"] == 20.0
        assert updated_setting["enabled"] is True

    @pytest.mark.asyncio
    async def test_create_alert_setting_success(self, async_client, test_brand):
        """Creating a setting for a brand returns 201 and the new setting."""
        create_data = {
            "brand_id": str(test_brand.id),
            "alert_type": "negative_sentiment",
            "enabled": True,
            "threshold": 1.0,
        }

        response = await async_client.post(SETTINGS_URL, json=create_data)

        assert response.status_code == 201
        data = response.json()
        assert data["alert_type"] == "negative_sentiment"
        assert data["threshold"] == 1.0
        assert data["enabled"] is True
        assert "id" in data

    @pytest.mark.asyncio
    async def test_delete_alert_setting_success(self, async_client, test_alert_setting):
        """Deleting a setting returns 204 and removes it from the listing."""
        response = await async_client.delete(f"{SETTINGS_URL}/{test_alert_setting.id}")

        assert response.status_code == 204

        # Confirm the listing call itself succeeded before inspecting it.
        get_response = await async_client.get(SETTINGS_URL)
        assert get_response.status_code == 200
        remaining_ids = [item["id"] for item in get_response.json()["items"]]
        assert str(test_alert_setting.id) not in remaining_ids

    @pytest.mark.asyncio
    async def test_delete_alert_setting_not_found(self, async_client):
        """Deleting an unknown setting id returns 404."""
        non_existent_id = uuid.uuid4()

        response = await async_client.delete(f"{SETTINGS_URL}/{non_existent_id}")

        assert response.status_code == 404


class TestAlertSettingsValidation:
    """Validation and authentication tests for the alert settings endpoints."""

    @pytest.mark.asyncio
    async def test_brand_not_found_returns_404(self, async_client):
        """Creating a setting for an unknown brand returns 404."""
        create_data = {
            "brand_id": str(uuid.uuid4()),
            "alert_type": "score_drop",
            "enabled": True,
            "threshold": 5.0,
        }

        response = await async_client.post(SETTINGS_URL, json=create_data)

        assert response.status_code == 404

    @pytest.mark.asyncio
    async def test_unauthorized_returns_401(self, async_session):
        """An invalid bearer token is rejected with 401.

        Only ``get_db`` is overridden; ``get_current_user`` is left to the
        app so the token is actually checked. The override is removed in
        ``finally`` so a failing assertion cannot leak it into later tests.
        """

        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        try:
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                response = await client.get(
                    SETTINGS_URL,
                    headers={"Authorization": "Bearer invalid_token"},
                )
            assert response.status_code == 401
        finally:
            app.dependency_overrides.clear()

    @pytest.mark.asyncio
    async def test_invalid_alert_type_returns_422(self, async_client, test_brand):
        """An unrecognised ``alert_type`` is rejected with 422."""
        create_data = {
            "brand_id": str(test_brand.id),
            "alert_type": "invalid_type",
            "enabled": True,
            "threshold": 5.0,
        }

        response = await async_client.post(SETTINGS_URL, json=create_data)

        assert response.status_code == 422

    @pytest.mark.asyncio
    async def test_negative_threshold_returns_422(self, async_client, test_brand):
        """A negative ``threshold`` is rejected with 422."""
        create_data = {
            "brand_id": str(test_brand.id),
            "alert_type": "score_drop",
            "enabled": True,
            "threshold": -10.0,
        }

        response = await async_client.post(SETTINGS_URL, json=create_data)

        assert response.status_code == 422

    @pytest.mark.asyncio
    async def test_threshold_over_100_returns_422(self, async_client, test_brand):
        """A ``threshold`` above 100 is rejected with 422."""
        create_data = {
            "brand_id": str(test_brand.id),
            "alert_type": "score_drop",
            "enabled": True,
            "threshold": 150.0,
        }

        response = await async_client.post(SETTINGS_URL, json=create_data)

        assert response.status_code == 422