"""TDD tests for the knowledge graph module.

Test strategy:
- Use a real database (in-memory SQLite) for the tests
- Do not mock database operations
- LLM calls are made for real (if an API key is configured) or skipped
"""
|
||
import uuid
|
||
from datetime import datetime
|
||
|
||
import pytest
|
||
import pytest_asyncio
|
||
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
|
||
from sqlalchemy.pool import StaticPool
|
||
|
||
from app.database import Base
|
||
from app.models.knowledge_graph import (
|
||
KnowledgeEntity,
|
||
KnowledgeRelation,
|
||
EntityType,
|
||
RelationType,
|
||
)
|
||
from app.models.knowledge import KnowledgeBase, KnowledgeDocument, KnowledgeChunk
|
||
from app.services.knowledge.graph_builder import GraphBuilder
|
||
from app.services.knowledge.graph_query import GraphQuery
|
||
|
||
|
||
# ============================================================================
|
||
# Fixtures
|
||
# ============================================================================
|
||
|
||
@pytest_asyncio.fixture
async def kg_db_engine():
    """Yield an in-memory SQLite async engine for the knowledge graph tests.

    All tables registered on ``Base.metadata`` are created before the engine
    is handed to the test; the engine is disposed once the test is finished.
    """
    engine_options = {
        "connect_args": {"check_same_thread": False},
        # StaticPool reuses a single connection, so every session sees the
        # same in-memory database.
        "poolclass": StaticPool,
    }
    db_engine = create_async_engine("sqlite+aiosqlite:///:memory:", **engine_options)

    async with db_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield db_engine

    await db_engine.dispose()
|
||
|
||
|
||
@pytest_asyncio.fixture
async def kg_db_session(kg_db_engine):
    """Yield an ``AsyncSession`` bound to the test engine.

    The session is created with ``expire_on_commit``, ``autoflush`` and
    ``autocommit`` all disabled, and is rolled back after the test has run.
    """
    session_options = {
        "class_": AsyncSession,
        "expire_on_commit": False,
        "autoflush": False,
        "autocommit": False,
    }
    session_factory = async_sessionmaker(kg_db_engine, **session_options)

    async with session_factory() as db:
        yield db
        # Discard anything the test left uncommitted.
        await db.rollback()
|
||
|
||
|
||
@pytest_asyncio.fixture
async def kg_test_data(kg_db_session):
    """Seed a knowledge base, one document and one chunk for the tests.

    Returns:
        A dict with the ids of the created rows under the keys ``kb_id``,
        ``doc_id`` and ``chunk_id``.
    """
    knowledge_base = KnowledgeBase(
        id=uuid.uuid4(),
        name="测试知识库",
        description="用于测试的知识库",
    )
    document = KnowledgeDocument(
        id=uuid.uuid4(),
        knowledge_base_id=knowledge_base.id,
        title="华为公司介绍",
        source="test",
    )
    text_chunk = KnowledgeChunk(
        id=uuid.uuid4(),
        document_id=document.id,
        content="华为是全球领先的ICT解决方案供应商,总部位于深圳。华为与小米是竞争对手。",
        chunk_index=0,
    )

    # Register the rows parent-first, then persist them in one commit.
    for record in (knowledge_base, document, text_chunk):
        kg_db_session.add(record)
    await kg_db_session.commit()

    return {
        "kb_id": knowledge_base.id,
        "doc_id": document.id,
        "chunk_id": text_chunk.id,
    }
|
||
|
||
|
||
# ============================================================================
|
||
# Knowledge entity tests
|
||
# ============================================================================
|
||
|
||
class TestKnowledgeEntity:
    """Tests for persisting knowledge entities and their relations."""

    @pytest.mark.asyncio
    async def test_create_entity(self, kg_db_session, kg_test_data):
        """A fully populated entity can be committed and keeps its fields."""
        attributes = {
            "knowledge_base_id": kg_test_data["kb_id"],
            "name": "华为技术有限公司",
            "entity_type": EntityType.ORGANIZATION,
            "description": "一家全球领先的ICT公司",
            "properties": {"founded": "1987", "headquarters": "深圳"},
            "source_chunk_id": kg_test_data["chunk_id"],
            "confidence": "high",
        }
        stored = KnowledgeEntity(**attributes)
        kg_db_session.add(stored)
        await kg_db_session.commit()

        # The entity must have been persisted with the values given above.
        assert stored.id is not None
        assert stored.name == "华为技术有限公司"
        assert stored.entity_type == EntityType.ORGANIZATION
        assert stored.properties["founded"] == "1987"

    @pytest.mark.asyncio
    async def test_entity_relationships(self, kg_db_session, kg_test_data):
        """A relation between two flushed entities can be committed."""
        source = KnowledgeEntity(
            knowledge_base_id=kg_test_data["kb_id"],
            name="华为",
            entity_type=EntityType.ORGANIZATION,
        )
        target = KnowledgeEntity(
            knowledge_base_id=kg_test_data["kb_id"],
            name="小米",
            entity_type=EntityType.ORGANIZATION,
        )
        kg_db_session.add_all([source, target])
        # Flush before reading the entity ids used by the relation below.
        await kg_db_session.flush()

        # Link the two entities with a "competes with" relation.
        link = KnowledgeRelation(
            source_entity_id=source.id,
            target_entity_id=target.id,
            relation_type=RelationType.COMPETES_WITH,
            source_chunk_id=kg_test_data["chunk_id"],
            confidence="high",
        )
        kg_db_session.add(link)
        await kg_db_session.commit()

        assert link.id is not None
        assert link.source_entity_id == source.id
        assert link.target_entity_id == target.id
        assert link.relation_type == RelationType.COMPETES_WITH
|
||
|
||
|
||
# ============================================================================
|
||
# Graph query service tests
|
||
# ============================================================================
|
||
|
||
class TestGraphQuery:
    """Tests for the ``GraphQuery`` service."""

    @pytest.mark.asyncio
    async def test_search_entities_by_name(self, kg_db_session, kg_test_data):
        """Searching for "华为" returns the two entities named with it."""
        kb_id = kg_test_data["kb_id"]
        seed_specs = [
            ("华为技术有限公司", EntityType.ORGANIZATION),
            ("华为技术", EntityType.TECHNOLOGY),
            ("小米科技", EntityType.ORGANIZATION),
        ]
        kg_db_session.add_all(
            [
                KnowledgeEntity(knowledge_base_id=kb_id, name=label, entity_type=kind)
                for label, kind in seed_specs
            ]
        )
        await kg_db_session.commit()

        matches = await GraphQuery().search_entities(kg_db_session, kb_id, "华为")

        # Only the two "华为..." entities are expected.
        assert len(matches) == 2
        found_names = [match["name"] for match in matches]
        assert "华为技术有限公司" in found_names
        assert "华为技术" in found_names

    @pytest.mark.asyncio
    async def test_search_entities_by_type(self, kg_db_session, kg_test_data):
        """An empty keyword plus ``entity_type`` filters entities by type."""
        kb_id = kg_test_data["kb_id"]
        organization = KnowledgeEntity(
            knowledge_base_id=kb_id,
            name="华为",
            entity_type=EntityType.ORGANIZATION,
        )
        technology = KnowledgeEntity(
            knowledge_base_id=kb_id,
            name="5G技术",
            entity_type=EntityType.TECHNOLOGY,
        )
        kg_db_session.add_all([organization, technology])
        await kg_db_session.commit()

        # Search with no keyword, restricted to the TECHNOLOGY type.
        matches = await GraphQuery().search_entities(
            kg_db_session, kb_id, "", entity_type="TECHNOLOGY"
        )

        assert len(matches) == 1
        only_match = matches[0]
        assert only_match["name"] == "5G技术"
        assert only_match["entity_type"] == "TECHNOLOGY"

    @pytest.mark.asyncio
    async def test_get_entity_neighbors(self, kg_db_session, kg_test_data):
        """The neighbors of an entity include its outgoing relation target."""
        kb_id = kg_test_data["kb_id"]
        source_org = KnowledgeEntity(
            knowledge_base_id=kb_id,
            name="华为",
            entity_type=EntityType.ORGANIZATION,
        )
        target_org = KnowledgeEntity(
            knowledge_base_id=kb_id,
            name="小米",
            entity_type=EntityType.ORGANIZATION,
        )
        kg_db_session.add_all([source_org, target_org])
        # Flush before reading the entity ids used by the relation below.
        await kg_db_session.flush()

        kg_db_session.add(
            KnowledgeRelation(
                source_entity_id=source_org.id,
                target_entity_id=target_org.id,
                relation_type=RelationType.COMPETES_WITH,
            )
        )
        await kg_db_session.commit()

        neighborhood = await GraphQuery().get_entity_neighbors(kg_db_session, source_org.id)

        assert neighborhood is not None
        assert neighborhood["entity"]["name"] == "华为"
        outgoing = neighborhood["outgoing"]
        assert len(outgoing) == 1
        assert outgoing[0]["entity"]["name"] == "小米"
        assert outgoing[0]["relation"]["relation_type"] == "COMPETES_WITH"

    @pytest.mark.asyncio
    async def test_get_entity_path(self, kg_db_session, kg_test_data):
        """A two-hop path is found between the ends of an A -> B -> C chain."""
        kb_id = kg_test_data["kb_id"]
        # Build the entity chain: A -> B -> C.
        chain = [
            KnowledgeEntity(knowledge_base_id=kb_id, name="华为", entity_type=EntityType.ORGANIZATION),
            KnowledgeEntity(knowledge_base_id=kb_id, name="荣耀", entity_type=EntityType.BRAND),
            KnowledgeEntity(knowledge_base_id=kb_id, name="华为终端", entity_type=EntityType.ORGANIZATION),
        ]
        kg_db_session.add_all(chain)
        await kg_db_session.flush()

        # Connect each entity to the next one in the chain.
        head, middle, tail = chain
        first_link = KnowledgeRelation(
            source_entity_id=head.id,
            target_entity_id=middle.id,
            relation_type=RelationType.PART_OF,
        )
        second_link = KnowledgeRelation(
            source_entity_id=middle.id,
            target_entity_id=tail.id,
            relation_type=RelationType.PART_OF,
        )
        kg_db_session.add_all([first_link, second_link])
        await kg_db_session.commit()

        path = await GraphQuery().get_entity_path(
            kg_db_session, "华为", "华为终端", max_hops=3
        )

        assert len(path) == 2  # two hops
        first_hop, second_hop = path[0], path[1]
        assert first_hop["from"] == "华为"
        assert first_hop["to"] == "荣耀"
        assert second_hop["from"] == "荣耀"
        assert second_hop["to"] == "华为终端"

    @pytest.mark.asyncio
    async def test_get_statistics(self, kg_db_session, kg_test_data):
        """Statistics report entity/relation counts and the type distribution."""
        kb_id = kg_test_data["kb_id"]
        # Five entities: even indices are organizations, odd ones technologies.
        nodes = []
        for index in range(5):
            if index % 2 == 0:
                kind = EntityType.ORGANIZATION
            else:
                kind = EntityType.TECHNOLOGY
            nodes.append(
                KnowledgeEntity(knowledge_base_id=kb_id, name=f"实体{index}", entity_type=kind)
            )
        kg_db_session.add_all(nodes)
        await kg_db_session.flush()

        # Relate every entity to its successor, giving four relations.
        for current, following in zip(nodes, nodes[1:]):
            kg_db_session.add(
                KnowledgeRelation(
                    source_entity_id=current.id,
                    target_entity_id=following.id,
                    relation_type=RelationType.RELATED_TO,
                )
            )
        await kg_db_session.commit()

        stats = await GraphQuery().get_statistics(kg_db_session, kb_id)

        assert stats["entity_count"] == 5
        assert stats["relation_count"] == 4
        distribution = stats["entity_type_distribution"]
        assert "ORGANIZATION" in distribution
        assert "TECHNOLOGY" in distribution
        assert distribution["ORGANIZATION"] == 3
        assert distribution["TECHNOLOGY"] == 2
|
||
|
||
|
||
# ============================================================================
|
||
# Graph builder service tests
|
||
# ============================================================================
|
||
|
||
class TestGraphBuilder:
    """Tests for the ``GraphBuilder`` service."""

    @pytest.mark.asyncio
    async def test_build_from_chunk_requires_valid_chunk(self, kg_db_session):
        """Building from an unknown chunk id raises ``ValueError``."""
        graph_builder = GraphBuilder()
        unknown_chunk_id = str(uuid.uuid4())

        with pytest.raises(ValueError, match="Chunk not found"):
            await graph_builder.build_from_chunk(kg_db_session, chunk_id=unknown_chunk_id)

    @pytest.mark.asyncio
    async def test_get_chunk_kb_id(self, kg_db_session, kg_test_data):
        """The knowledge base id of a chunk is resolved from the chunk id."""
        graph_builder = GraphBuilder()

        resolved_kb_id = await graph_builder._get_chunk_kb_id(
            kg_db_session, kg_test_data["chunk_id"]
        )

        assert resolved_kb_id == kg_test_data["kb_id"]

    @pytest.mark.asyncio
    async def test_get_or_create_entity_creates_new(self, kg_db_session, kg_test_data):
        """An extracted entity that is not stored yet is created."""
        from app.services.knowledge.entity_extractor import ExtractedEntity

        graph_builder = GraphBuilder()
        candidate = ExtractedEntity(
            name="新实体",
            entity_type="ORGANIZATION",
            description="测试描述",
            properties={"confidence": "high"},
        )

        stored, was_created = await graph_builder._get_or_create_entity(
            kg_db_session, kg_test_data["chunk_id"], candidate
        )

        assert was_created is True
        assert stored.name == "新实体"
        assert stored.entity_type == EntityType.ORGANIZATION

    @pytest.mark.asyncio
    async def test_get_or_create_entity_returns_existing(self, kg_db_session, kg_test_data):
        """An extracted entity matching a stored one returns the stored row."""
        from app.services.knowledge.entity_extractor import ExtractedEntity

        # Persist an entity up front.
        already_stored = KnowledgeEntity(
            knowledge_base_id=kg_test_data["kb_id"],
            name="已存在实体",
            entity_type=EntityType.ORGANIZATION,
        )
        kg_db_session.add(already_stored)
        await kg_db_session.commit()

        # Ask the builder for an entity with the same name and type.
        graph_builder = GraphBuilder()
        candidate = ExtractedEntity(name="已存在实体", entity_type="ORGANIZATION")

        found, was_created = await graph_builder._get_or_create_entity(
            kg_db_session, kg_test_data["chunk_id"], candidate
        )

        assert was_created is False
        assert found.id == already_stored.id

    @pytest.mark.asyncio
    async def test_create_relation_creates_new(self, kg_db_session, kg_test_data):
        """Creating a relation that does not exist yet returns ``True``."""
        from app.services.knowledge.entity_extractor import ExtractedRelation

        kb_id = kg_test_data["kb_id"]
        source = KnowledgeEntity(
            knowledge_base_id=kb_id,
            name="源实体",
            entity_type=EntityType.ORGANIZATION,
        )
        target = KnowledgeEntity(
            knowledge_base_id=kb_id,
            name="目标实体",
            entity_type=EntityType.ORGANIZATION,
        )
        kg_db_session.add_all([source, target])
        # Flush before reading the entity ids passed to the builder below.
        await kg_db_session.flush()

        graph_builder = GraphBuilder()
        candidate = ExtractedRelation(
            source_entity="源实体",
            target_entity="目标实体",
            relation_type="COMPETES_WITH",
            properties={"confidence": "high"},
        )

        was_created = await graph_builder._create_relation(
            kg_db_session, kg_test_data["chunk_id"], source.id, target.id, candidate
        )

        assert was_created is True

    @pytest.mark.asyncio
    async def test_create_relation_returns_existing(self, kg_db_session, kg_test_data):
        """Creating a relation that already exists returns ``False``."""
        from app.services.knowledge.entity_extractor import ExtractedRelation

        kb_id = kg_test_data["kb_id"]
        # Two entities plus a relation that is already stored between them.
        source = KnowledgeEntity(
            knowledge_base_id=kb_id,
            name="实体A",
            entity_type=EntityType.ORGANIZATION,
        )
        target = KnowledgeEntity(
            knowledge_base_id=kb_id,
            name="实体B",
            entity_type=EntityType.ORGANIZATION,
        )
        kg_db_session.add_all([source, target])
        await kg_db_session.flush()

        kg_db_session.add(
            KnowledgeRelation(
                source_entity_id=source.id,
                target_entity_id=target.id,
                relation_type=RelationType.COMPETES_WITH,
            )
        )
        await kg_db_session.commit()

        # Try to create the very same relation again.
        graph_builder = GraphBuilder()
        candidate = ExtractedRelation(
            source_entity="实体A",
            target_entity="实体B",
            relation_type="COMPETES_WITH",
        )

        was_created = await graph_builder._create_relation(
            kg_db_session, kg_test_data["chunk_id"], source.id, target.id, candidate
        )

        assert was_created is False
|
||
|
||
|
||
# ============================================================================
|
||
# Entity type enum tests
|
||
# ============================================================================
|
||
|
||
class TestEntityType:
    """Tests for the ``EntityType`` enum."""

    def test_entity_type_values(self):
        """Each checked member's value equals its own name."""
        checked_members = (
            (EntityType.ORGANIZATION, "ORGANIZATION"),
            (EntityType.PRODUCT, "PRODUCT"),
            (EntityType.PERSON, "PERSON"),
            (EntityType.TECHNOLOGY, "TECHNOLOGY"),
            (EntityType.BRAND, "BRAND"),
        )
        for member, expected_value in checked_members:
            assert member.value == expected_value

    def test_entity_type_from_string(self):
        """An entity type can be looked up from its string value."""
        assert EntityType("ORGANIZATION") == EntityType.ORGANIZATION
|
||
|
||
|
||
# ============================================================================
|
||
# Relation type enum tests
|
||
# ============================================================================
|
||
|
||
class TestRelationType:
    """Tests for the ``RelationType`` enum."""

    def test_relation_type_values(self):
        """Each checked member's value equals its own name."""
        checked_members = (
            (RelationType.COMPETES_WITH, "COMPETES_WITH"),
            (RelationType.PARTNERS_WITH, "PARTNERS_WITH"),
            (RelationType.PRODUCES, "PRODUCES"),
            (RelationType.RELATED_TO, "RELATED_TO"),
        )
        for member, expected_value in checked_members:
            assert member.value == expected_value

    def test_relation_type_from_string(self):
        """A relation type can be looked up from its string value."""
        assert RelationType("COMPETES_WITH") == RelationType.COMPETES_WITH
|