279 lines
8.1 KiB
Python
279 lines
8.1 KiB
Python
"""知识图谱批量构建API测试"""
|
|
import uuid
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from httpx import AsyncClient, ASGITransport
|
|
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
|
|
from sqlalchemy.pool import StaticPool
|
|
|
|
from app.database import Base
|
|
from app.main import app
|
|
from app.api.deps import get_db, get_current_user
|
|
from app.models.user import User
|
|
from app.models.knowledge import KnowledgeBase
|
|
from app.models.organization import Organization
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_engine():
    """Yield an async SQLite in-memory engine with every model table created.

    ``StaticPool`` reuses a single DBAPI connection, so all sessions opened on
    this engine share the same in-memory database for the duration of a test.
    The engine is disposed once the test finishes.
    """
    sqlite_engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        poolclass=StaticPool,
        connect_args={"check_same_thread": False},
    )

    # Build the schema once, inside a single transaction.
    async with sqlite_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield sqlite_engine

    await sqlite_engine.dispose()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_session(async_engine):
    """Yield a single ``AsyncSession`` bound to the per-test engine.

    ``expire_on_commit=False`` keeps attributes of committed objects loaded, so
    fixtures can hand committed ORM instances to tests without re-fetching.
    Autoflush is disabled; flushing happens on explicit commit.
    """
    make_session = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )

    async with make_session() as session:
        yield session
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_user(async_session):
    """Persist and return an active, email-verified user on the "free" plan.

    NOTE(review): ``id`` is passed as a ``str`` here, while the organization and
    knowledge-base fixtures pass a ``uuid.UUID``; presumably ``User.id`` is a
    string column — confirm against ``app.models.user``.
    """
    attributes = {
        "id": str(uuid.uuid4()),
        "email": "test@example.com",
        "password": "hashed_password",
        "firstName": "Test User",
        "plan": "free",
        "max_queries": 5,
        "isActive": True,
        "emailVerified": True,
    }
    created = User(**attributes)

    async_session.add(created)
    await async_session.commit()
    # Re-read the row so the returned instance reflects what was stored.
    await async_session.refresh(created)
    return created
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_org(async_session):
    """Persist and return the organization used as owner of ``test_kb``."""
    organization = Organization(
        name="Test Org",
        slug="test-org",
        id=uuid.uuid4(),
    )

    async_session.add(organization)
    await async_session.commit()
    # Re-read the row so the returned instance reflects what was stored.
    await async_session.refresh(organization)
    return organization
|
|
|
|
|
|
@pytest_asyncio.fixture
async def test_kb(async_session, test_org):
    """Persist and return an "industry" knowledge base owned by ``test_org``."""
    knowledge_base = KnowledgeBase(
        id=uuid.uuid4(),
        name="Test KB",
        description="Test knowledge base",
        type="industry",
        organization_id=test_org.id,
    )

    async_session.add(knowledge_base)
    await async_session.commit()
    # Re-read the row so the returned instance reflects what was stored.
    await async_session.refresh(knowledge_base)
    return knowledge_base
|
|
|
|
|
|
@pytest_asyncio.fixture
async def async_client(async_session, test_user):
    """Yield an ``httpx.AsyncClient`` that calls the FastAPI ``app`` in-process.

    Dependency overrides route ``get_db`` to the test session and make
    ``get_current_user`` return ``test_user``, so every request runs as that
    user against the in-memory database.

    ``app`` is a module-level global shared by all tests. Any overrides that
    existed before this fixture ran are therefore restored on teardown, rather
    than wiped with ``clear()``. The restore sits in ``finally`` so it also
    happens if closing the client raises.
    """

    async def override_get_db():
        yield async_session

    async def override_get_current_user():
        return test_user

    saved_overrides = dict(app.dependency_overrides)
    app.dependency_overrides[get_db] = override_get_db
    app.dependency_overrides[get_current_user] = override_get_current_user

    try:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            yield client
    finally:
        # Put back exactly what was there before; equivalent to clear() when
        # no overrides were installed beforehand.
        app.dependency_overrides.clear()
        app.dependency_overrides.update(saved_overrides)
|
|
|
|
|
|
# Batch entity-creation endpoint; fill in with BATCH_URL.format(kb_id=...).
BATCH_URL: str = "/api/v1/knowledge-bases/{kb_id}/entities/batch"
|
|
|
|
|
|
class TestBatchCreateEntitiesEmptyInput:
    """Validation of an empty batch payload."""

    @pytest.mark.asyncio
    async def test_empty_entities_list_returns_400(self, async_client, test_kb):
        """Posting an empty list is rejected with HTTP 400."""
        endpoint = BATCH_URL.format(kb_id=test_kb.id)
        resp = await async_client.post(endpoint, json=[])

        assert resp.status_code == 400
        body = resp.json()
        assert "detail" in body
        # The detail message must contain "不能为空" ("must not be empty").
        assert "不能为空" in body["detail"]
|
|
|
|
|
|
class TestBatchCreateEntitiesSizeLimit:
    """Enforcement of the per-request batch size limit."""

    @pytest.mark.asyncio
    async def test_over_100_entities_returns_400(self, async_client, test_kb):
        """A batch of 101 entities is rejected with HTTP 400.

        The error detail is expected to mention the limit, "100".
        """
        oversized_batch = [
            {
                "name": f"Entity {idx}",
                "entity_type": "CONCEPT",
                "description": f"Test entity {idx}",
            }
            for idx in range(101)
        ]
        endpoint = BATCH_URL.format(kb_id=test_kb.id)
        resp = await async_client.post(endpoint, json=oversized_batch)

        assert resp.status_code == 400
        body = resp.json()
        assert "detail" in body
        assert "100" in body["detail"]
|
|
|
|
|
|
class TestBatchCreateEntitiesKBNotFound:
    """Behavior when the target knowledge base does not exist."""

    @pytest.mark.asyncio
    async def test_nonexistent_kb_returns_404(self, async_client):
        """Posting to a random, never-created knowledge-base id yields HTTP 404."""
        missing_kb_id = str(uuid.uuid4())
        payload = [
            {"name": "Entity 1", "entity_type": "CONCEPT", "description": "Test"},
        ]

        resp = await async_client.post(
            BATCH_URL.format(kb_id=missing_kb_id), json=payload
        )

        assert resp.status_code == 404
        body = resp.json()
        assert "detail" in body
        # The detail message must contain "不存在" ("does not exist").
        assert "不存在" in body["detail"]
|
|
|
|
|
|
class TestBatchCreateEntitiesSuccess:
    """Successful batch-creation scenarios."""

    @pytest.mark.asyncio
    async def test_batch_create_entities_success(self, async_client, test_kb):
        """Two entities are created, and their names, types and properties round-trip.

        The order of the returned list is not asserted; entities are matched
        by name instead.
        """
        entities = [
            {
                "name": "公司A",
                "entity_type": "ORGANIZATION",
                "description": "测试公司A",
                "properties": {"industry": "科技"},
            },
            {
                "name": "产品B",
                "entity_type": "PRODUCT",
                "description": "测试产品B",
            },
        ]
        response = await async_client.post(
            BATCH_URL.format(kb_id=test_kb.id),
            json=entities,
        )
        assert response.status_code == 200
        data = response.json()
        assert data["created_count"] == 2
        assert len(data["entities"]) == 2

        for entity_data in data["entities"]:
            assert "id" in entity_data
            assert "name" in entity_data
            assert "entity_type" in entity_data

        # Each created entity must get its own id.
        assert len({entity_data["id"] for entity_data in data["entities"]}) == 2

        # Compare the echoed values against the input, independent of order.
        by_name = {entity_data["name"]: entity_data for entity_data in data["entities"]}
        assert set(by_name) == {"公司A", "产品B"}
        assert by_name["公司A"]["entity_type"] == "ORGANIZATION"
        assert by_name["产品B"]["entity_type"] == "PRODUCT"
        assert by_name["公司A"]["properties"]["industry"] == "科技"

    @pytest.mark.asyncio
    async def test_batch_create_single_entity(self, async_client, test_kb):
        """A one-element batch creates exactly that entity."""
        entities = [
            {
                "name": "单个实体",
                "entity_type": "PERSON",
                "description": "测试单个实体",
            }
        ]
        response = await async_client.post(
            BATCH_URL.format(kb_id=test_kb.id),
            json=entities,
        )
        assert response.status_code == 200
        data = response.json()
        assert data["created_count"] == 1
        assert len(data["entities"]) == 1
        assert data["entities"][0]["name"] == "单个实体"
        assert data["entities"][0]["entity_type"] == "PERSON"

    @pytest.mark.asyncio
    async def test_batch_create_with_properties(self, async_client, test_kb):
        """User-supplied ``properties`` are returned unchanged."""
        entities = [
            {
                "name": "带属性实体",
                "entity_type": "TECHNOLOGY",
                "description": "测试带属性",
                "properties": {"version": "1.0", "category": "AI"},
            }
        ]
        response = await async_client.post(
            BATCH_URL.format(kb_id=test_kb.id),
            json=entities,
        )
        assert response.status_code == 200
        data = response.json()
        assert data["entities"][0]["properties"]["version"] == "1.0"
        assert data["entities"][0]["properties"]["category"] == "AI"

    @pytest.mark.asyncio
    async def test_batch_create_without_properties_defaults_to_empty(
        self, async_client, test_kb
    ):
        """Omitting ``properties`` yields an empty dict in the response."""
        entities = [
            {
                "name": "无属性实体",
                "entity_type": "BRAND",
                "description": "测试无属性",
            }
        ]
        response = await async_client.post(
            BATCH_URL.format(kb_id=test_kb.id),
            json=entities,
        )
        assert response.status_code == 200
        data = response.json()
        assert data["entities"][0]["properties"] == {}

    @pytest.mark.asyncio
    async def test_batch_create_exactly_100_entities(self, async_client, test_kb):
        """A batch exactly at the 100-entity limit is accepted in full."""
        entities = [
            {
                "name": f"Entity {i}",
                "entity_type": "CONCEPT",
                "description": f"Test entity {i}",
            }
            for i in range(100)
        ]
        response = await async_client.post(
            BATCH_URL.format(kb_id=test_kb.id),
            json=entities,
        )
        assert response.status_code == 200
        data = response.json()
        assert data["created_count"] == 100
        # Like the other success tests, also check the returned list itself.
        assert len(data["entities"]) == 100
|