# geo/tests/test_knowledge_graph.py
#
# 564 lines
# 18 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other
# characters. If you think that this is intentional, you can safely ignore
# this warning. Use the Escape button to reveal them.

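"""知识图谱模块TDD测试 (TDD tests for the knowledge-graph module).

测试策略 (test strategy):
- 使用真实数据库(内存SQLite)进行测试 -- use a real in-memory SQLite database.
- 不使用Mock测试数据库操作 -- do not mock database operations.
- LLM调用: 使用真实调用(如果配置了API Key)或跳过 -- LLM calls are made for
  real when an API key is configured; otherwise the test is skipped.
"""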
import uuid
from datetime import datetime
import pytest
import pytest_asyncio
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession
from sqlalchemy.pool import StaticPool
from app.database import Base
from app.models.knowledge_graph import (
KnowledgeEntity,
KnowledgeRelation,
EntityType,
RelationType,
)
from app.models.knowledge import KnowledgeBase, KnowledgeDocument, KnowledgeChunk
from app.services.knowledge.graph_builder import GraphBuilder
from app.services.knowledge.graph_query import GraphQuery
# ============================================================================
# Fixtures
# ============================================================================
@pytest_asyncio.fixture
async def kg_db_engine():
    """Create an in-memory SQLite async engine for knowledge-graph tests.

    The tables for every model registered on ``Base.metadata`` are created
    before the engine is yielded. ``StaticPool`` hands out one shared
    connection, which keeps the ``:memory:`` database visible to every
    session opened on this engine during a single test.

    Yields:
        The async engine bound to the in-memory database.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        poolclass=StaticPool,
    )
    # try/finally so the engine (and its pooled connection) is disposed even
    # when schema creation fails before the yield is reached.
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        await engine.dispose()
@pytest_asyncio.fixture
async def kg_db_session(kg_db_engine):
    """Yield an ``AsyncSession`` bound to the in-memory test engine.

    The session does not expire objects on commit, so tests can keep reading
    attributes such as ``id`` after ``commit()``. Autoflush is disabled, so
    tests call ``flush()`` themselves before relying on pending rows. Once the
    test body finishes, any uncommitted work is rolled back.
    """
    session_factory = async_sessionmaker(
        bind=kg_db_engine,
        class_=AsyncSession,
        autocommit=False,
        autoflush=False,
        expire_on_commit=False,
    )
    async with session_factory() as test_session:
        yield test_session
        await test_session.rollback()
@pytest_asyncio.fixture
async def kg_test_data(kg_db_session):
    """Persist one knowledge base -> document -> chunk chain for graph tests.

    Returns:
        A dict holding the ids of the committed rows under the keys
        ``"kb_id"``, ``"doc_id"`` and ``"chunk_id"``.
    """
    knowledge_base = KnowledgeBase(
        id=uuid.uuid4(),
        name="测试知识库",
        description="用于测试的知识库",
    )
    document = KnowledgeDocument(
        id=uuid.uuid4(),
        knowledge_base_id=knowledge_base.id,
        title="华为公司介绍",
        source="test",
    )
    # Chunk text mentions 华为 (Huawei), 深圳 (Shenzhen) and 小米 (Xiaomi).
    chunk = KnowledgeChunk(
        id=uuid.uuid4(),
        document_id=document.id,
        content="华为是全球领先的ICT解决方案供应商总部位于深圳。华为与小米是竞争对手。",
        chunk_index=0,
    )
    # Added parent-first: the document references the knowledge base id and
    # the chunk references the document id.
    kg_db_session.add_all([knowledge_base, document, chunk])
    await kg_db_session.commit()
    return {
        "kb_id": knowledge_base.id,
        "doc_id": document.id,
        "chunk_id": chunk.id,
    }
# ============================================================================
# 知识实体测试
# ============================================================================
class TestKnowledgeEntity:
    """Persistence tests for ``KnowledgeEntity`` and ``KnowledgeRelation`` rows."""

    @pytest.mark.asyncio
    async def test_create_entity(self, kg_db_session, kg_test_data):
        """An entity with description, properties and provenance survives commit."""
        kb_id = kg_test_data["kb_id"]
        org = KnowledgeEntity(
            knowledge_base_id=kb_id,
            name="华为技术有限公司",
            entity_type=EntityType.ORGANIZATION,
            description="一家全球领先的ICT公司",
            properties={"founded": "1987", "headquarters": "深圳"},
            source_chunk_id=kg_test_data["chunk_id"],
            confidence="high",
        )
        kg_db_session.add(org)
        await kg_db_session.commit()

        assert org.id is not None
        assert org.name == "华为技术有限公司"
        assert org.entity_type == EntityType.ORGANIZATION
        assert org.properties["founded"] == "1987"

    @pytest.mark.asyncio
    async def test_entity_relationships(self, kg_db_session, kg_test_data):
        """A COMPETES_WITH relation links two entities by their ids."""
        huawei, xiaomi = (
            KnowledgeEntity(
                knowledge_base_id=kg_test_data["kb_id"],
                name=label,
                entity_type=EntityType.ORGANIZATION,
            )
            for label in ("华为", "小米")
        )
        kg_db_session.add_all([huawei, xiaomi])
        # Flush before reading the entity ids (presumably assigned at flush time).
        await kg_db_session.flush()

        competes = KnowledgeRelation(
            source_entity_id=huawei.id,
            target_entity_id=xiaomi.id,
            relation_type=RelationType.COMPETES_WITH,
            source_chunk_id=kg_test_data["chunk_id"],
            confidence="high",
        )
        kg_db_session.add(competes)
        await kg_db_session.commit()

        assert competes.id is not None
        assert competes.source_entity_id == huawei.id
        assert competes.target_entity_id == xiaomi.id
        assert competes.relation_type == RelationType.COMPETES_WITH
# ============================================================================
# 图谱查询服务测试
# ============================================================================
class TestGraphQuery:
    """Tests for the ``GraphQuery`` service against real in-memory SQLite data."""

    @pytest.mark.asyncio
    async def test_search_entities_by_name(self, kg_db_session, kg_test_data):
        """Searching for "华为" returns exactly the two Huawei-named entities."""
        entity1 = KnowledgeEntity(
            knowledge_base_id=kg_test_data["kb_id"],
            name="华为技术有限公司",
            entity_type=EntityType.ORGANIZATION,
        )
        entity2 = KnowledgeEntity(
            knowledge_base_id=kg_test_data["kb_id"],
            name="华为技术",
            entity_type=EntityType.TECHNOLOGY,
        )
        # Control entity whose name does not contain the search term.
        entity3 = KnowledgeEntity(
            knowledge_base_id=kg_test_data["kb_id"],
            name="小米科技",
            entity_type=EntityType.ORGANIZATION,
        )
        kg_db_session.add_all([entity1, entity2, entity3])
        await kg_db_session.commit()

        query = GraphQuery()
        results = await query.search_entities(
            kg_db_session,
            kg_test_data["kb_id"],
            "华为",
        )

        # Compare as a set: result order is not what this test is checking.
        assert len(results) == 2
        assert {r["name"] for r in results} == {"华为技术有限公司", "华为技术"}

    @pytest.mark.asyncio
    async def test_search_entities_by_type(self, kg_db_session, kg_test_data):
        """An empty name query with entity_type="TECHNOLOGY" returns only that type."""
        entity1 = KnowledgeEntity(
            knowledge_base_id=kg_test_data["kb_id"],
            name="华为",
            entity_type=EntityType.ORGANIZATION,
        )
        entity2 = KnowledgeEntity(
            knowledge_base_id=kg_test_data["kb_id"],
            name="5G技术",
            entity_type=EntityType.TECHNOLOGY,
        )
        kg_db_session.add_all([entity1, entity2])
        await kg_db_session.commit()

        query = GraphQuery()
        results = await query.search_entities(
            kg_db_session,
            kg_test_data["kb_id"],
            "",
            entity_type="TECHNOLOGY",
        )

        assert len(results) == 1
        assert results[0]["name"] == "5G技术"
        assert results[0]["entity_type"] == "TECHNOLOGY"

    @pytest.mark.asyncio
    async def test_get_entity_neighbors(self, kg_db_session, kg_test_data):
        """An entity's outgoing neighbors include the target and the relation."""
        huawei = KnowledgeEntity(
            knowledge_base_id=kg_test_data["kb_id"],
            name="华为",
            entity_type=EntityType.ORGANIZATION,
        )
        xiaomi = KnowledgeEntity(
            knowledge_base_id=kg_test_data["kb_id"],
            name="小米",
            entity_type=EntityType.ORGANIZATION,
        )
        kg_db_session.add_all([huawei, xiaomi])
        # Flush before reading the entity ids for the relation.
        await kg_db_session.flush()

        relation = KnowledgeRelation(
            source_entity_id=huawei.id,
            target_entity_id=xiaomi.id,
            relation_type=RelationType.COMPETES_WITH,
        )
        kg_db_session.add(relation)
        await kg_db_session.commit()

        query = GraphQuery()
        neighbors = await query.get_entity_neighbors(kg_db_session, huawei.id)

        assert neighbors is not None
        assert neighbors["entity"]["name"] == "华为"
        assert len(neighbors["outgoing"]) == 1
        assert neighbors["outgoing"][0]["entity"]["name"] == "小米"
        assert neighbors["outgoing"][0]["relation"]["relation_type"] == "COMPETES_WITH"

    @pytest.mark.asyncio
    async def test_get_entity_path(self, kg_db_session, kg_test_data):
        """The chain 华为 -> 荣耀 -> 华为终端 is returned as two ordered hops."""
        entity_a = KnowledgeEntity(
            knowledge_base_id=kg_test_data["kb_id"],
            name="华为",
            entity_type=EntityType.ORGANIZATION,
        )
        entity_b = KnowledgeEntity(
            knowledge_base_id=kg_test_data["kb_id"],
            name="荣耀",
            entity_type=EntityType.BRAND,
        )
        entity_c = KnowledgeEntity(
            knowledge_base_id=kg_test_data["kb_id"],
            name="华为终端",
            entity_type=EntityType.ORGANIZATION,
        )
        kg_db_session.add_all([entity_a, entity_b, entity_c])
        await kg_db_session.flush()

        # Relation chain A -> B -> C.
        rel1 = KnowledgeRelation(
            source_entity_id=entity_a.id,
            target_entity_id=entity_b.id,
            relation_type=RelationType.PART_OF,
        )
        rel2 = KnowledgeRelation(
            source_entity_id=entity_b.id,
            target_entity_id=entity_c.id,
            relation_type=RelationType.PART_OF,
        )
        kg_db_session.add_all([rel1, rel2])
        await kg_db_session.commit()

        query = GraphQuery()
        path = await query.get_entity_path(
            kg_db_session,
            "华为",
            "华为终端",
            max_hops=3,
        )

        assert len(path) == 2  # two hops
        assert path[0]["from"] == "华为"
        assert path[0]["to"] == "荣耀"
        assert path[1]["from"] == "荣耀"
        assert path[1]["to"] == "华为终端"

    @pytest.mark.asyncio
    async def test_get_statistics(self, kg_db_session, kg_test_data):
        """Statistics report entity/relation counts and the per-type distribution."""
        # Five entities alternating ORGANIZATION (i = 0, 2, 4) and
        # TECHNOLOGY (i = 1, 3): three and two of each respectively.
        entities = [
            KnowledgeEntity(
                knowledge_base_id=kg_test_data["kb_id"],
                name=f"实体{i}",
                entity_type=(
                    EntityType.ORGANIZATION if i % 2 == 0 else EntityType.TECHNOLOGY
                ),
            )
            for i in range(5)
        ]
        kg_db_session.add_all(entities)
        await kg_db_session.flush()

        # Link each entity to the next one: 0->1, 1->2, 2->3, 3->4 (four relations).
        kg_db_session.add_all([
            KnowledgeRelation(
                source_entity_id=source.id,
                target_entity_id=target.id,
                relation_type=RelationType.RELATED_TO,
            )
            for source, target in zip(entities, entities[1:])
        ])
        await kg_db_session.commit()

        query = GraphQuery()
        stats = await query.get_statistics(kg_db_session, kg_test_data["kb_id"])

        assert stats["entity_count"] == 5
        assert stats["relation_count"] == 4
        assert "ORGANIZATION" in stats["entity_type_distribution"]
        assert "TECHNOLOGY" in stats["entity_type_distribution"]
        assert stats["entity_type_distribution"]["ORGANIZATION"] == 3
        assert stats["entity_type_distribution"]["TECHNOLOGY"] == 2
# ============================================================================
# 图谱构建服务测试
# ============================================================================
class TestGraphBuilder:
    """Tests for the ``GraphBuilder`` service and its private helpers."""

    @pytest.mark.asyncio
    async def test_build_from_chunk_requires_valid_chunk(self, kg_db_session):
        """``build_from_chunk`` raises ValueError("Chunk not found") for an unknown id."""
        builder = GraphBuilder()
        # NOTE(review): the id is passed as a str here, while the other tests in
        # this class pass uuid.UUID values -- confirm which type
        # build_from_chunk expects.
        missing_chunk_id = str(uuid.uuid4())
        with pytest.raises(ValueError, match="Chunk not found"):
            await builder.build_from_chunk(
                kg_db_session,
                chunk_id=missing_chunk_id,
            )

    @pytest.mark.asyncio
    async def test_get_chunk_kb_id(self, kg_db_session, kg_test_data):
        """``_get_chunk_kb_id`` resolves a chunk to the id of its knowledge base."""
        builder = GraphBuilder()
        resolved_kb_id = await builder._get_chunk_kb_id(
            kg_db_session, kg_test_data["chunk_id"]
        )
        assert resolved_kb_id == kg_test_data["kb_id"]

    @pytest.mark.asyncio
    async def test_get_or_create_entity_creates_new(self, kg_db_session, kg_test_data):
        """An extracted entity with an unseen name is created and reported as new."""
        from app.services.knowledge.entity_extractor import ExtractedEntity

        builder = GraphBuilder()
        candidate = ExtractedEntity(
            name="新实体",
            entity_type="ORGANIZATION",
            description="测试描述",
            properties={"confidence": "high"},
        )
        entity, was_created = await builder._get_or_create_entity(
            kg_db_session, kg_test_data["chunk_id"], candidate
        )

        assert was_created is True
        assert entity.name == "新实体"
        assert entity.entity_type == EntityType.ORGANIZATION

    @pytest.mark.asyncio
    async def test_get_or_create_entity_returns_existing(self, kg_db_session, kg_test_data):
        """An extracted entity matching a stored one reuses the stored row."""
        from app.services.knowledge.entity_extractor import ExtractedEntity

        stored = KnowledgeEntity(
            knowledge_base_id=kg_test_data["kb_id"],
            name="已存在实体",
            entity_type=EntityType.ORGANIZATION,
        )
        kg_db_session.add(stored)
        await kg_db_session.commit()

        builder = GraphBuilder()
        duplicate = ExtractedEntity(
            name="已存在实体",
            entity_type="ORGANIZATION",
        )
        entity, was_created = await builder._get_or_create_entity(
            kg_db_session, kg_test_data["chunk_id"], duplicate
        )

        assert was_created is False
        assert entity.id == stored.id

    @pytest.mark.asyncio
    async def test_create_relation_creates_new(self, kg_db_session, kg_test_data):
        """``_create_relation`` returns True when no such relation exists yet."""
        from app.services.knowledge.entity_extractor import ExtractedRelation

        source_entity, target_entity = (
            KnowledgeEntity(
                knowledge_base_id=kg_test_data["kb_id"],
                name=label,
                entity_type=EntityType.ORGANIZATION,
            )
            for label in ("源实体", "目标实体")
        )
        kg_db_session.add_all([source_entity, target_entity])
        await kg_db_session.flush()

        builder = GraphBuilder()
        candidate = ExtractedRelation(
            source_entity="源实体",
            target_entity="目标实体",
            relation_type="COMPETES_WITH",
            properties={"confidence": "high"},
        )
        was_created = await builder._create_relation(
            kg_db_session,
            kg_test_data["chunk_id"],
            source_entity.id,
            target_entity.id,
            candidate,
        )

        assert was_created is True

    @pytest.mark.asyncio
    async def test_create_relation_returns_existing(self, kg_db_session, kg_test_data):
        """``_create_relation`` returns False when the same relation is already stored."""
        from app.services.knowledge.entity_extractor import ExtractedRelation

        entity_a, entity_b = (
            KnowledgeEntity(
                knowledge_base_id=kg_test_data["kb_id"],
                name=label,
                entity_type=EntityType.ORGANIZATION,
            )
            for label in ("实体A", "实体B")
        )
        kg_db_session.add_all([entity_a, entity_b])
        await kg_db_session.flush()

        # Pre-existing A -COMPETES_WITH-> B relation that the builder should detect.
        kg_db_session.add(
            KnowledgeRelation(
                source_entity_id=entity_a.id,
                target_entity_id=entity_b.id,
                relation_type=RelationType.COMPETES_WITH,
            )
        )
        await kg_db_session.commit()

        builder = GraphBuilder()
        duplicate = ExtractedRelation(
            source_entity="实体A",
            target_entity="实体B",
            relation_type="COMPETES_WITH",
        )
        was_created = await builder._create_relation(
            kg_db_session,
            kg_test_data["chunk_id"],
            entity_a.id,
            entity_b.id,
            duplicate,
        )

        assert was_created is False
# ============================================================================
# 实体类型枚举测试
# ============================================================================
class TestEntityType:
    """Checks on the string values of the ``EntityType`` enum."""

    def test_entity_type_values(self):
        """Each checked member's value equals its own member name."""
        for member_name in ("ORGANIZATION", "PRODUCT", "PERSON", "TECHNOLOGY", "BRAND"):
            assert getattr(EntityType, member_name).value == member_name

    def test_entity_type_from_string(self):
        """Calling the enum with a value string returns the matching member."""
        assert EntityType("ORGANIZATION") == EntityType.ORGANIZATION
# ============================================================================
# 关系类型枚举测试
# ============================================================================
class TestRelationType:
    """Checks on the string values of the ``RelationType`` enum."""

    def test_relation_type_values(self):
        """Each checked member's value equals its own member name."""
        for member_name in ("COMPETES_WITH", "PARTNERS_WITH", "PRODUCES", "RELATED_TO"):
            assert getattr(RelationType, member_name).value == member_name

    def test_relation_type_from_string(self):
        """Calling the enum with a value string returns the matching member."""
        assert RelationType("COMPETES_WITH") == RelationType.COMPETES_WITH