"""API tests for the detection-task endpoints (``/api/v1/detection/tasks``).

Each test gets a fresh in-memory SQLite database. The FastAPI ``get_db`` and
``get_current_user`` dependencies are overridden so requests use the test
session and are treated as coming from ``test_user`` (except in
``test_unauthorized_access``, which exercises the real auth dependency).
"""

import uuid

import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool

from app.api.deps import get_current_user, get_db
from app.database import Base
from app.main import app
from app.models.brand import Brand
from app.models.user import User
from app.services.auth import hash_password
from tests.fixtures.auth import _to_uuid

TASKS_URL = "/api/v1/detection/tasks"


@pytest_asyncio.fixture
async def async_engine():
    """Yield an async engine bound to a fresh in-memory SQLite database.

    ``StaticPool`` keeps a single shared connection; a ``:memory:`` database
    only exists within one connection, so every checkout must reuse it.
    All tables from ``Base.metadata`` are created before the engine is yielded.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()


@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield one ``AsyncSession`` shared by the fixtures and the app under test.

    ``expire_on_commit=False`` keeps fixture objects (``test_user``,
    ``test_brand``) readable after their commit without an implicit reload.
    """
    session_maker = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )
    async with session_maker() as session:
        yield session


@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, email-verified user on the ``free`` plan."""
    user = User(
        id=str(uuid.uuid4()),
        email="test@example.com",
        password=hash_password("Test@123456"),
        firstName="Test User",
        plan="free",
        max_queries=5,
        isActive=True,
        emailVerified=True,
    )
    async_session.add(user)
    await async_session.commit()
    await async_session.refresh(user)
    return user


@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist and return an active brand owned by ``test_user``."""
    brand = Brand(
        id=uuid.uuid4(),
        # User ids are stored as strings while Brand.user_id takes a UUID.
        user_id=_to_uuid(test_user.id),
        name="Test Brand",
        aliases=["TestBrand", "TB"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )
    async_session.add(brand)
    await async_session.commit()
    await async_session.refresh(brand)
    return brand


@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an HTTP client for ``app`` with DB and auth dependencies overridden.

    Requests reuse ``async_session`` and resolve the current user to
    ``test_user``, so no token is needed. The overrides are cleared on
    teardown even if closing the client raises.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # ``app`` is a module-level global; never leak overrides into other tests.
        app.dependency_overrides.clear()


async def _create_task(client: AsyncClient, brand: Brand, **fields) -> dict:
    """POST a detection task for *brand* and return the decoded response body.

    Asserts the API answered 201, so a failed setup step surfaces as a clear
    assertion (showing the response body) instead of a ``KeyError`` on ``"id"``.
    """
    payload = {"brand_id": str(brand.id), **fields}
    response = await client.post(TASKS_URL, json=payload)
    assert response.status_code == 201, response.text
    return response.json()


class TestDetectionTaskAPI:
    """CRUD, trigger, and auth behavior of the detection-task endpoints."""

    @pytest.mark.asyncio
    async def test_create_detection_task(self, async_client, test_brand):
        """Creating a task returns 201 and echoes the submitted fields."""
        task_data = {
            "brand_id": str(test_brand.id),
            "name": "每日品牌检测",
            "frequency": "daily",
            "engines": ["chatgpt", "perplexity"],
            "queries": ["最佳保险品牌", "保险推荐"],
            "competitor_names": ["竞品A"],
        }
        response = await async_client.post(TASKS_URL, json=task_data)
        assert response.status_code == 201, response.text
        data = response.json()
        assert data["name"] == "每日品牌检测"
        assert data["frequency"] == "daily"
        assert data["engines"] == ["chatgpt", "perplexity"]
        assert data["queries"] == ["最佳保险品牌", "保险推荐"]
        assert data["is_active"] is True
        assert "id" in data

    @pytest.mark.asyncio
    async def test_get_detection_tasks(self, async_client, test_brand):
        """Listing tasks for a brand returns a paginated ``items``/``total`` body."""
        response = await async_client.get(
            TASKS_URL, params={"brand_id": str(test_brand.id)}
        )
        assert response.status_code == 200, response.text
        data = response.json()
        assert "items" in data
        assert "total" in data

    @pytest.mark.asyncio
    async def test_update_detection_task(self, async_client, test_brand):
        """PUT on an existing task applies the new name, frequency, and engines."""
        created = await _create_task(
            async_client,
            test_brand,
            name="原始任务",
            frequency="weekly",
            engines=["chatgpt"],
            queries=["查询1"],
        )
        update_data = {
            "name": "更新后任务",
            "frequency": "daily",
            "engines": ["chatgpt", "perplexity"],
        }
        response = await async_client.put(f"{TASKS_URL}/{created['id']}", json=update_data)
        assert response.status_code == 200, response.text
        data = response.json()
        assert data["name"] == "更新后任务"
        assert data["frequency"] == "daily"
        assert data["engines"] == ["chatgpt", "perplexity"]

    @pytest.mark.asyncio
    async def test_delete_detection_task(self, async_client, test_brand):
        """DELETE on an existing task returns 204 No Content."""
        created = await _create_task(
            async_client,
            test_brand,
            name="待删除任务",
            frequency="daily",
            engines=["chatgpt"],
            queries=["查询1"],
        )
        response = await async_client.delete(f"{TASKS_URL}/{created['id']}")
        assert response.status_code == 204, response.text

    @pytest.mark.asyncio
    async def test_trigger_detection_task(self, async_client, test_brand):
        """Manually triggering a task returns 200 with ``status == "success"``."""
        created = await _create_task(
            async_client,
            test_brand,
            name="手动触发任务",
            frequency="daily",
            engines=["chatgpt"],
            queries=["查询1"],
        )
        response = await async_client.post(f"{TASKS_URL}/{created['id']}/trigger")
        assert response.status_code == 200, response.text
        data = response.json()
        assert data["status"] == "success"

    @pytest.mark.asyncio
    async def test_unauthorized_access(self, async_session):
        """An invalid bearer token is expected to yield 401.

        Only ``get_db`` is overridden here, so ``get_current_user`` runs
        unmodified against the supplied token.
        """

        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        try:
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                headers = {"Authorization": "Bearer invalid_token"}
                response = await client.get(TASKS_URL, headers=headers)
            assert response.status_code == 401, response.text
        finally:
            # Clear even when the assertion fails so the override cannot leak.
            app.dependency_overrides.clear()