217 lines
7.2 KiB
Python
217 lines
7.2 KiB
Python
"""诊断API测试"""
|
|
import uuid
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from httpx import AsyncClient, ASGITransport
|
|
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from app.database import Base
|
|
from app.main import app
|
|
from app.models.user import User
|
|
from app.models.brand import Brand
|
|
from app.api.deps import get_current_user, get_db
|
|
from app.services.auth import hash_password
|
|
from tests.fixtures.auth import _to_uuid
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_engine():
    """Yield an in-memory SQLite async engine with every ORM table created.

    The engine is built with ``StaticPool`` and ``check_same_thread=False``
    and is disposed once the consuming test has finished.
    """
    sqlite_url = "sqlite+aiosqlite:///:memory:"
    test_engine = create_async_engine(
        sqlite_url,
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )

    # Create the schema from the declarative metadata before handing it out.
    async with test_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield test_engine

    await test_engine.dispose()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the test engine.

    ``expire_on_commit=False`` keeps attribute values loaded on instances
    after ``commit()``; autoflush is disabled.
    """
    session_factory = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    async with session_factory() as db_session:
        yield db_session
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, email-verified user on the ``free`` plan."""
    user_fields = dict(
        id=str(uuid.uuid4()),
        email="test@example.com",
        password=hash_password("Test@123456"),
        firstName="Test User",
        plan="free",
        max_queries=5,
        isActive=True,
        emailVerified=True,
    )
    new_user = User(**user_fields)

    async_session.add(new_user)
    await async_session.commit()
    # Reload the row so the returned instance reflects what was stored.
    await async_session.refresh(new_user)
    return new_user
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist and return an active brand owned by ``test_user``."""
    brand_id = uuid.uuid4()
    # The user id is stored as a string; convert it for the brand's FK column.
    owner_id = _to_uuid(test_user.id)

    new_brand = Brand(
        id=brand_id,
        user_id=owner_id,
        name="Test Brand",
        aliases=["TestBrand", "TB"],
        website="https://testbrand.com",
        industry="technology",
        platforms=["wenxin", "kimi"],
        frequency="weekly",
        status="active",
    )

    async_session.add(new_brand)
    await async_session.commit()
    await async_session.refresh(new_brand)
    return new_brand
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an ``httpx.AsyncClient`` that talks to ``app`` in-process.

    The ``get_db`` dependency is overridden to hand out the test session, and
    ``get_current_user`` is overridden to return ``test_user``, so requests
    bypass real authentication. All dependency overrides are cleared on
    teardown.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # Runs even if the client fails to open or close, so the global app
        # never keeps overrides pointing at this test's session and user.
        app.dependency_overrides.clear()
|
|
|
|
|
|
class TestDiagnosisAPI:
    """Tests for the ``/api/v1/diagnosis`` endpoints."""

    @pytest.mark.asyncio
    async def test_seo_diagnosis_success(self, async_client, test_brand):
        """The SEO diagnosis endpoint returns 200 with the expected fields."""
        response = await async_client.get(f"/api/v1/diagnosis/seo/{test_brand.id}")

        assert response.status_code == 200, response.text
        data = response.json()
        assert "overall_score" in data
        assert "health_level" in data
        assert "dimensions" in data
        assert "recommendations" in data
        assert isinstance(data["overall_score"], (int, float))
        assert isinstance(data["dimensions"], list)
        assert isinstance(data["recommendations"], list)

    @pytest.mark.asyncio
    async def test_geo_diagnosis_success(self, async_client, test_brand):
        """The GEO diagnosis endpoint accepts the request with 202."""
        response = await async_client.post(f"/api/v1/diagnosis/geo/{test_brand.id}")

        assert response.status_code == 202, response.text
        data = response.json()
        assert "task_id" in data or "status" in data

    @pytest.mark.asyncio
    async def test_combined_diagnosis_success(self, async_client, test_brand):
        """The combined diagnosis endpoint returns 200 with SEO and GEO parts."""
        response = await async_client.get(f"/api/v1/diagnosis/combined/{test_brand.id}")

        assert response.status_code == 200, response.text
        data = response.json()
        assert "seo_score" in data
        assert "geo_score" in data
        assert "combined_score" in data
        assert "seo_diagnosis" in data
        assert "geo_diagnosis" in data
        assert isinstance(data["seo_score"], (int, float))
        assert isinstance(data["geo_score"], (int, float))
        assert isinstance(data["combined_score"], (int, float))

    @pytest.mark.asyncio
    async def test_diagnosis_brand_not_found(self, async_client):
        """Every diagnosis endpoint returns 404 for an unknown brand id."""
        non_existent_id = uuid.uuid4()

        seo_response = await async_client.get(f"/api/v1/diagnosis/seo/{non_existent_id}")
        assert seo_response.status_code == 404, seo_response.text

        geo_response = await async_client.post(f"/api/v1/diagnosis/geo/{non_existent_id}")
        assert geo_response.status_code == 404, geo_response.text

        combined_response = await async_client.get(f"/api/v1/diagnosis/combined/{non_existent_id}")
        assert combined_response.status_code == 404, combined_response.text

    @pytest.mark.asyncio
    async def test_diagnosis_unauthorized_access(self, async_session):
        """Every diagnosis endpoint returns 401 for an invalid bearer token."""

        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        # Make sure the real authentication dependency runs; a leftover
        # override would otherwise bypass the token check entirely.
        app.dependency_overrides.pop(get_current_user, None)
        try:
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                headers = {"Authorization": "Bearer invalid_token"}

                seo_response = await client.get(f"/api/v1/diagnosis/seo/{uuid.uuid4()}", headers=headers)
                assert seo_response.status_code == 401, seo_response.text

                geo_response = await client.post(f"/api/v1/diagnosis/geo/{uuid.uuid4()}", headers=headers)
                assert geo_response.status_code == 401, geo_response.text

                combined_response = await client.get(f"/api/v1/diagnosis/combined/{uuid.uuid4()}", headers=headers)
                assert combined_response.status_code == 401, combined_response.text
        finally:
            # Always undo the global override, even when an assertion above
            # fails, so later tests do not inherit this test's session.
            app.dependency_overrides.clear()

    @pytest.mark.asyncio
    async def test_diagnosis_result_format(self, async_client, test_brand):
        """The SEO diagnosis payload has well-formed scores, dimensions and recommendations."""
        response = await async_client.get(f"/api/v1/diagnosis/seo/{test_brand.id}")

        assert response.status_code == 200, response.text
        data = response.json()

        assert 0 <= data["overall_score"] <= 100
        assert data["health_level"] in ["excellent", "good", "pass", "danger"]

        for dimension in data["dimensions"]:
            assert "name" in dimension
            assert "score" in dimension
            assert "max_score" in dimension
            assert "percentage" in dimension
            assert "status" in dimension
            assert "items" in dimension
            assert isinstance(dimension["items"], list)

            for item in dimension["items"]:
                assert "name" in item
                assert "status" in item
                assert "description" in item
                assert "suggestion" in item
                assert "score" in item

        for rec in data["recommendations"]:
            assert "priority" in rec
            assert "dimension" in rec
            assert "description" in rec
|