"""Diagnosis API tests (SEO, GEO and combined diagnosis endpoints)."""

import uuid

import pytest
import pytest_asyncio
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
from sqlalchemy.pool import StaticPool

from app.database import Base
from app.main import app
from app.models.user import User
from app.models.brand import Brand
from app.api.deps import get_current_user, get_db
from app.services.auth import hash_password
from tests.fixtures.auth import _to_uuid


@pytest_asyncio.fixture
async def async_engine():
    """Create an in-memory SQLite async engine with every table from ``Base``.

    ``StaticPool`` reuses a single connection so all sessions share the same
    in-memory database (a fresh ``:memory:`` connection would otherwise start
    with an empty database). The engine is disposed on teardown.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        # The single pooled connection may be used from more than one thread.
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()


@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the test engine.

    ``expire_on_commit=False`` keeps attributes of committed objects loaded,
    so fixtures can return them to tests without triggering a reload.
    """
    async_session_maker = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )
    async with async_session_maker() as session:
        yield session


@pytest_asyncio.fixture
async def test_user(async_session):
    """Create and persist an active, email-verified test user on the free plan."""
    user = User(
        id=str(uuid.uuid4()),
        email="test@example.com",
        password=hash_password("Test@123456"),
        firstName="Test User",
        plan="free",
        max_queries=5,
        isActive=True,
        emailVerified=True,
    )
    async_session.add(user)
    await async_session.commit()
    await async_session.refresh(user)
    return user


@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Create and persist an active brand owned by ``test_user``."""
    brand = Brand(
        id=uuid.uuid4(),
        user_id=_to_uuid(test_user.id),
        name="Test Brand",
        aliases=["TestBrand", "TB"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )
    async_session.add(brand)
    await async_session.commit()
    await async_session.refresh(brand)
    return brand


@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an HTTP client wired directly to the ASGI app.

    ``get_db`` is overridden to yield the test session, and ``get_current_user``
    is overridden to return ``test_user``. The overrides live on the global
    ``app``, so they are cleared in ``finally`` even if client setup fails.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        app.dependency_overrides.clear()


class TestDiagnosisAPI:
    """Diagnosis API tests."""

    @pytest.mark.asyncio
    async def test_seo_diagnosis_success(self, async_client, test_brand):
        """The SEO diagnosis endpoint returns 200 with the expected top-level fields."""
        response = await async_client.get(f"/api/v1/diagnosis/seo/{test_brand.id}")

        assert response.status_code == 200
        data = response.json()
        assert "overall_score" in data
        assert "health_level" in data
        assert "dimensions" in data
        assert "recommendations" in data
        assert isinstance(data["overall_score"], (int, float))
        assert isinstance(data["dimensions"], list)
        assert isinstance(data["recommendations"], list)

    @pytest.mark.asyncio
    async def test_geo_diagnosis_success(self, async_client, test_brand):
        """The GEO diagnosis endpoint answers 202 with a task id or a status."""
        response = await async_client.post(f"/api/v1/diagnosis/geo/{test_brand.id}")

        assert response.status_code == 202
        data = response.json()
        assert "task_id" in data or "status" in data

    @pytest.mark.asyncio
    async def test_combined_diagnosis_success(self, async_client, test_brand):
        """The combined diagnosis endpoint returns SEO, GEO and combined scores."""
        response = await async_client.get(f"/api/v1/diagnosis/combined/{test_brand.id}")

        assert response.status_code == 200
        data = response.json()
        assert "seo_score" in data
        assert "geo_score" in data
        assert "combined_score" in data
        assert "seo_diagnosis" in data
        assert "geo_diagnosis" in data
        assert isinstance(data["seo_score"], (int, float))
        assert isinstance(data["geo_score"], (int, float))
        assert isinstance(data["combined_score"], (int, float))

    @pytest.mark.asyncio
    async def test_diagnosis_brand_not_found(self, async_client):
        """Every diagnosis endpoint returns 404 for an unknown brand id."""
        non_existent_id = uuid.uuid4()

        seo_response = await async_client.get(f"/api/v1/diagnosis/seo/{non_existent_id}")
        assert seo_response.status_code == 404

        geo_response = await async_client.post(f"/api/v1/diagnosis/geo/{non_existent_id}")
        assert geo_response.status_code == 404

        combined_response = await async_client.get(
            f"/api/v1/diagnosis/combined/{non_existent_id}"
        )
        assert combined_response.status_code == 404

    @pytest.mark.asyncio
    async def test_diagnosis_unauthorized_access(self, async_session):
        """An invalid bearer token yields 401 on every diagnosis endpoint."""

        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        # Make sure the app's own get_current_user dependency handles the
        # token, even if an earlier test left an override behind.
        app.dependency_overrides.pop(get_current_user, None)
        try:
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                headers = {"Authorization": "Bearer invalid_token"}

                seo_response = await client.get(
                    f"/api/v1/diagnosis/seo/{uuid.uuid4()}", headers=headers
                )
                assert seo_response.status_code == 401

                geo_response = await client.post(
                    f"/api/v1/diagnosis/geo/{uuid.uuid4()}", headers=headers
                )
                assert geo_response.status_code == 401

                combined_response = await client.get(
                    f"/api/v1/diagnosis/combined/{uuid.uuid4()}", headers=headers
                )
                assert combined_response.status_code == 401
        finally:
            # The overrides live on the global app object; clear them even when
            # an assertion above fails so they cannot leak into later tests.
            app.dependency_overrides.clear()

    @pytest.mark.asyncio
    async def test_diagnosis_result_format(self, async_client, test_brand):
        """The SEO diagnosis payload has well-formed scores, dimensions and recommendations."""
        response = await async_client.get(f"/api/v1/diagnosis/seo/{test_brand.id}")

        assert response.status_code == 200
        data = response.json()

        assert 0 <= data["overall_score"] <= 100
        assert data["health_level"] in ["excellent", "good", "pass", "danger"]

        for dimension in data["dimensions"]:
            assert "name" in dimension
            assert "score" in dimension
            assert "max_score" in dimension
            assert "percentage" in dimension
            assert "status" in dimension
            assert "items" in dimension
            assert isinstance(dimension["items"], list)

            for item in dimension["items"]:
                assert "name" in item
                assert "status" in item
                assert "description" in item
                assert "suggestion" in item
                assert "score" in item

        for rec in data["recommendations"]:
            assert "priority" in rec
            assert "dimension" in rec
            assert "description" in rec