# geo/backend/tests/test_api/test_detection_api.py
# (source listing metadata: 204 lines, 6.4 KiB, Python)

import uuid
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.pool import StaticPool
from app.api.deps import get_current_user, get_db
from app.database import Base
from app.main import app
from app.models.brand import Brand
from app.models.user import User
from app.services.auth import hash_password
from tests.fixtures.auth import _to_uuid
@pytest_asyncio.fixture
async def async_engine():
    """Yield an async engine backed by a fresh in-memory SQLite database.

    ``StaticPool`` together with ``check_same_thread=False`` lets every
    connection reuse the same in-memory database, so the tables created here
    are visible to sessions opened later in the test. All tables registered
    on ``Base.metadata`` are created before the engine is handed out.

    The engine is disposed on teardown, and also when schema creation itself
    fails, so a broken model definition does not leak the engine.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        await engine.dispose()
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield one ``AsyncSession`` bound to the per-test ``async_engine``.

    ``expire_on_commit=False`` keeps ORM attributes populated after a commit,
    so fixtures can return committed objects that tests read directly.
    Autoflush is disabled; the session is closed when the test finishes.
    """
    factory = async_sessionmaker(
        bind=async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    async with factory() as db_session:
        yield db_session
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, email-verified user on the ``free`` plan.

    The primary key is a random UUID stored as a string. The password is
    stored hashed via ``hash_password``.
    """
    user_fields = dict(
        id=str(uuid.uuid4()),
        email="test@example.com",
        password=hash_password("Test@123456"),
        firstName="Test User",
        plan="free",
        max_queries=5,
        isActive=True,
        emailVerified=True,
    )
    created = User(**user_fields)
    async_session.add(created)
    await async_session.commit()
    # Reload the row so the returned instance reflects what was stored.
    await async_session.refresh(created)
    return created
@pytest_asyncio.fixture
async def test_brand(async_session, test_user):
    """Persist and return an active, weekly-monitored brand owned by ``test_user``.

    ``test_user.id`` is stored as a string, so it is converted with
    ``_to_uuid`` before being used as the brand's ``user_id``.
    """
    brand_fields = {
        "id": uuid.uuid4(),
        "user_id": _to_uuid(test_user.id),
        "name": "Test Brand",
        "aliases": ["TestBrand", "TB"],
        "website": "https://testbrand.com",
        "industry": "technology",
        "platforms": ["wenxin", "kimi"],
        "frequency": "weekly",
        "status": "active",
    }
    created = Brand(**brand_fields)
    async_session.add(created)
    await async_session.commit()
    # Reload the row so the returned instance reflects what was stored.
    await async_session.refresh(created)
    return created
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an in-process ``AsyncClient`` that acts as ``test_user``.

    Two dependencies are overridden on the app:
    - ``get_db`` yields this test's ``async_session``;
    - ``get_current_user`` returns ``test_user``, so requests skip the real
      authentication dependency.

    ``app.dependency_overrides`` is global state shared by every test. It is
    cleared in a ``finally`` so the overrides cannot leak into other tests,
    even if closing the client raises.
    """
    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user
    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        app.dependency_overrides.clear()
class TestDetectionTaskAPI:
    """Endpoint tests for ``/api/v1/detection/tasks``.

    Most tests use the ``async_client`` fixture, which overrides ``get_db`` and
    ``get_current_user`` so requests run as ``test_user`` against the per-test
    SQLite session. ``test_unauthorized_access`` builds its own client without
    the auth override.
    """

    _TASKS_URL = "/api/v1/detection/tasks"

    async def _create_task(self, client, brand, **fields):
        """POST a new detection task for *brand* and return the response JSON.

        *fields* supplies the rest of the request payload (``name``,
        ``frequency``, ``engines``, ``queries``, ...). The 201 check lives here
        so a failed setup POST is reported directly, instead of surfacing later
        as an opaque ``KeyError`` on ``["id"]``.
        """
        payload = {"brand_id": str(brand.id), **fields}
        response = await client.post(self._TASKS_URL, json=payload)
        assert response.status_code == 201, response.text
        return response.json()

    @pytest.mark.asyncio
    async def test_create_detection_task(self, async_client, test_brand):
        """Creating a task echoes the submitted fields and marks it active."""
        data = await self._create_task(
            async_client,
            test_brand,
            name="每日品牌检测",
            frequency="daily",
            engines=["chatgpt", "perplexity"],
            queries=["最佳保险品牌", "保险推荐"],
            competitor_names=["竞品A"],
        )
        assert data["name"] == "每日品牌检测"
        assert data["frequency"] == "daily"
        assert data["engines"] == ["chatgpt", "perplexity"]
        assert data["queries"] == ["最佳保险品牌", "保险推荐"]
        assert data["is_active"] is True
        assert "id" in data

    @pytest.mark.asyncio
    async def test_get_detection_tasks(self, async_client, test_brand):
        """Listing tasks for a brand returns a paginated ``items``/``total`` payload."""
        response = await async_client.get(
            self._TASKS_URL, params={"brand_id": str(test_brand.id)}
        )
        assert response.status_code == 200
        data = response.json()
        assert "items" in data
        assert "total" in data

    @pytest.mark.asyncio
    async def test_update_detection_task(self, async_client, test_brand):
        """PUT updates the name and frequency of an existing task."""
        task = await self._create_task(
            async_client,
            test_brand,
            name="原始任务",
            frequency="weekly",
            engines=["chatgpt"],
            queries=["查询1"],
        )
        update_data = {
            "name": "更新后任务",
            "frequency": "daily",
            "engines": ["chatgpt", "perplexity"],
        }
        response = await async_client.put(
            f"{self._TASKS_URL}/{task['id']}", json=update_data
        )
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "更新后任务"
        assert data["frequency"] == "daily"

    @pytest.mark.asyncio
    async def test_delete_detection_task(self, async_client, test_brand):
        """DELETE on an existing task responds with 204 No Content."""
        task = await self._create_task(
            async_client,
            test_brand,
            name="待删除任务",
            frequency="daily",
            engines=["chatgpt"],
            queries=["查询1"],
        )
        response = await async_client.delete(f"{self._TASKS_URL}/{task['id']}")
        assert response.status_code == 204

    @pytest.mark.asyncio
    async def test_trigger_detection_task(self, async_client, test_brand):
        """Manually triggering a task reports ``status == "success"``."""
        task = await self._create_task(
            async_client,
            test_brand,
            name="手动触发任务",
            frequency="daily",
            engines=["chatgpt"],
            queries=["查询1"],
        )
        response = await async_client.post(f"{self._TASKS_URL}/{task['id']}/trigger")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "success"

    @pytest.mark.asyncio
    async def test_unauthorized_access(self, async_session):
        """An invalid bearer token is rejected with 401.

        Only ``get_db`` is overridden. ``get_current_user`` is deliberately left
        alone so the real auth dependency processes the bogus token.
        """
        async def override_get_db():
            yield async_session

        app.dependency_overrides[get_db] = override_get_db
        try:
            transport = ASGITransport(app=app)
            async with AsyncClient(transport=transport, base_url="http://test") as client:
                headers = {"Authorization": "Bearer invalid_token"}
                response = await client.get(self._TASKS_URL, headers=headers)
            assert response.status_code == 401
        finally:
            # Restore the global app state even if the assertion fails, so the
            # override (bound to this test's session) cannot leak into others.
            app.dependency_overrides.clear()