"""知识图谱批量构建API测试""" import uuid import pytest import pytest_asyncio from httpx import AsyncClient, ASGITransport from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine from sqlalchemy.pool import StaticPool from app.database import Base from app.main import app from app.api.deps import get_db, get_current_user from app.models.user import User from app.models.knowledge import KnowledgeBase from app.models.organization import Organization @pytest_asyncio.fixture async def async_engine(): engine = create_async_engine( "sqlite+aiosqlite:///:memory:", connect_args={"check_same_thread": False}, poolclass=StaticPool, ) async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) yield engine await engine.dispose() @pytest_asyncio.fixture async def async_session(async_engine): async_session_maker = async_sessionmaker( async_engine, class_=AsyncSession, expire_on_commit=False, autoflush=False, autocommit=False, ) async with async_session_maker() as session: yield session @pytest_asyncio.fixture async def test_user(async_session): user = User( id=str(uuid.uuid4()), email="test@example.com", password="hashed_password", firstName="Test User", plan="free", max_queries=5, isActive=True, emailVerified=True, ) async_session.add(user) await async_session.commit() await async_session.refresh(user) return user @pytest_asyncio.fixture async def test_org(async_session): org = Organization( id=uuid.uuid4(), name="Test Org", slug="test-org", ) async_session.add(org) await async_session.commit() await async_session.refresh(org) return org @pytest_asyncio.fixture async def test_kb(async_session, test_org): kb = KnowledgeBase( id=uuid.uuid4(), organization_id=test_org.id, name="Test KB", type="industry", description="Test knowledge base", ) async_session.add(kb) await async_session.commit() await async_session.refresh(kb) return kb @pytest_asyncio.fixture async def async_client(async_session, test_user): async def override_get_db(): yield async_session async def override_get_current_user(): return test_user app.dependency_overrides[get_db] = override_get_db app.dependency_overrides[get_current_user] = override_get_current_user transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: yield client app.dependency_overrides.clear() BATCH_URL = "/api/v1/knowledge-bases/{kb_id}/entities/batch" class TestBatchCreateEntitiesEmptyInput: """空输入验证测试""" @pytest.mark.asyncio async def test_empty_entities_list_returns_400(self, async_client, test_kb): response = await async_client.post( BATCH_URL.format(kb_id=test_kb.id), json=[], ) assert response.status_code == 400 data = response.json() assert "detail" in data assert "不能为空" in data["detail"] class TestBatchCreateEntitiesSizeLimit: """批量大小限制测试""" @pytest.mark.asyncio async def test_over_100_entities_returns_400(self, async_client, test_kb): entities = [ { "name": f"Entity {i}", "entity_type": "CONCEPT", "description": f"Test entity {i}", } for i in range(101) ] response = await async_client.post( BATCH_URL.format(kb_id=test_kb.id), json=entities, ) assert response.status_code == 400 data = response.json() assert "detail" in data assert "100" in data["detail"] class TestBatchCreateEntitiesKBNotFound: """知识库不存在测试""" @pytest.mark.asyncio async def test_nonexistent_kb_returns_404(self, async_client): fake_kb_id = str(uuid.uuid4()) entities = [ { "name": "Entity 1", "entity_type": "CONCEPT", "description": "Test", } ] response = await async_client.post( BATCH_URL.format(kb_id=fake_kb_id), json=entities, ) assert response.status_code == 404 data = response.json() assert "detail" in data assert "不存在" in data["detail"] class TestBatchCreateEntitiesSuccess: """批量创建成功测试""" @pytest.mark.asyncio async def test_batch_create_entities_success(self, async_client, test_kb): entities = [ { "name": "公司A", "entity_type": "ORGANIZATION", "description": "测试公司A", "properties": {"industry": "科技"}, }, { "name": "产品B", "entity_type": "PRODUCT", "description": "测试产品B", }, ] response = await async_client.post( BATCH_URL.format(kb_id=test_kb.id), json=entities, ) assert response.status_code == 200 data = response.json() assert data["created_count"] == 2 assert len(data["entities"]) == 2 for entity_data in data["entities"]: assert "id" in entity_data assert "name" in entity_data assert "entity_type" in entity_data @pytest.mark.asyncio async def test_batch_create_single_entity(self, async_client, test_kb): entities = [ { "name": "单个实体", "entity_type": "PERSON", "description": "测试单个实体", } ] response = await async_client.post( BATCH_URL.format(kb_id=test_kb.id), json=entities, ) assert response.status_code == 200 data = response.json() assert data["created_count"] == 1 assert len(data["entities"]) == 1 assert data["entities"][0]["name"] == "单个实体" assert data["entities"][0]["entity_type"] == "PERSON" @pytest.mark.asyncio async def test_batch_create_with_properties(self, async_client, test_kb): entities = [ { "name": "带属性实体", "entity_type": "TECHNOLOGY", "description": "测试带属性", "properties": {"version": "1.0", "category": "AI"}, } ] response = await async_client.post( BATCH_URL.format(kb_id=test_kb.id), json=entities, ) assert response.status_code == 200 data = response.json() assert data["entities"][0]["properties"]["version"] == "1.0" assert data["entities"][0]["properties"]["category"] == "AI" @pytest.mark.asyncio async def test_batch_create_without_properties_defaults_to_empty( self, async_client, test_kb ): entities = [ { "name": "无属性实体", "entity_type": "BRAND", "description": "测试无属性", } ] response = await async_client.post( BATCH_URL.format(kb_id=test_kb.id), json=entities, ) assert response.status_code == 200 data = response.json() assert data["entities"][0]["properties"] == {} @pytest.mark.asyncio async def test_batch_create_exactly_100_entities(self, async_client, test_kb): entities = [ { "name": f"Entity {i}", "entity_type": "CONCEPT", "description": f"Test entity {i}", } for i in range(100) ] response = await async_client.post( BATCH_URL.format(kb_id=test_kb.id), json=entities, ) assert response.status_code == 200 data = response.json() assert data["created_count"] == 100