188 lines
5.6 KiB
Python
188 lines
5.6 KiB
Python
"""知识图谱构建服务"""
|
|
from typing import Optional
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.knowledge_graph import (
|
|
KnowledgeEntity,
|
|
KnowledgeRelation,
|
|
EntityType,
|
|
RelationType,
|
|
)
|
|
from app.models.knowledge import KnowledgeChunk, KnowledgeDocument
|
|
from app.services.knowledge.entity_extractor import EntityExtractor, ExtractionResult
|
|
|
|
|
|
class GraphBuilder:
    """Knowledge-graph construction service.

    Runs entity/relation extraction over a knowledge chunk and persists the
    results as ``KnowledgeEntity`` / ``KnowledgeRelation`` rows.
    """

    def __init__(self):
        self.extractor = EntityExtractor()

    async def build_from_chunk(
        self,
        session: AsyncSession,
        chunk_id: str,
        context: Optional[str] = None,
    ) -> dict:
        """Build knowledge-graph entries from a single chunk.

        Args:
            session: Database session. It is committed before returning.
            chunk_id: ID of the ``KnowledgeChunk`` to process.
            context: Optional extra context passed to the extractor
                (e.g. a brand name).

        Returns:
            Statistics dict with the keys ``entities_created``,
            ``entities_existing``, ``relations_created`` and
            ``relations_existing``.

        Raises:
            ValueError: If the chunk does not exist, or if entities were
                extracted and the chunk's document cannot be found.
        """
        # 1. Load the chunk content.
        chunk = await session.get(KnowledgeChunk, chunk_id)
        if not chunk:
            raise ValueError(f"Chunk not found: {chunk_id}")

        # 2. Extract entities and relations.
        result = await self.extractor.extract(chunk.content, context)

        # 3. Persist the extraction into the graph.
        return await self._store_extraction(session, chunk_id, result)

    async def _store_extraction(
        self,
        session: AsyncSession,
        chunk_id: str,
        result: ExtractionResult,
    ) -> dict:
        """Persist an extraction result and commit the session.

        Entities are deduplicated by name within the chunk's knowledge base.
        Relations are deduplicated by (source, target, relation type).
        Relations whose source or target entity is not among the extracted
        entities are skipped and not counted.

        Returns:
            Statistics dict (same keys as ``build_from_chunk``).
        """
        stats = {
            "entities_created": 0,
            "entities_existing": 0,
            "relations_created": 0,
            "relations_existing": 0,
        }

        # Resolve the knowledge base once rather than once per entity. The
        # lookup is skipped when there are no entities to store.
        kb_id = (
            await self._get_chunk_kb_id(session, chunk_id)
            if result.entities
            else None
        )

        # Entity name -> entity ID, used to resolve relation endpoints.
        entity_map = {}

        # 4. Store entities (reusing existing ones with the same name).
        for extracted_entity in result.entities:
            entity, created = await self._get_or_create_entity(
                session,
                chunk_id,
                extracted_entity,
                kb_id=kb_id,
            )
            entity_map[extracted_entity.name] = entity.id
            if created:
                stats["entities_created"] += 1
            else:
                stats["entities_existing"] += 1

        # 5. Store relations between the entities resolved above.
        for extracted_relation in result.relations:
            source_id = entity_map.get(extracted_relation.source_entity)
            target_id = entity_map.get(extracted_relation.target_entity)
            if not source_id or not target_id:
                continue  # An endpoint entity was not extracted; skip it.

            created = await self._create_relation(
                session,
                chunk_id,
                source_id,
                target_id,
                extracted_relation,
            )
            if created:
                stats["relations_created"] += 1
            else:
                stats["relations_existing"] += 1

        await session.commit()
        return stats

    async def _get_or_create_entity(
        self,
        session: AsyncSession,
        chunk_id: str,
        extracted_entity,
        kb_id: Optional[str] = None,
    ) -> tuple:
        """Return ``(entity, created)`` for an extracted entity.

        Looks the entity up by name within the knowledge base. When it is
        missing, creates it and flushes it so that its ID is populated.

        Args:
            session: Database session.
            chunk_id: Chunk the entity was extracted from. New entities
                record it as ``source_chunk_id``.
            extracted_entity: Extractor output exposing ``name``,
                ``entity_type``, ``description`` and ``properties``.
            kb_id: Knowledge-base ID. When omitted, it is resolved from
                ``chunk_id``.

        Returns:
            ``(entity, True)`` if the entity was created, or
            ``(entity, False)`` if an existing one was found.

        Raises:
            ValueError: If ``kb_id`` has to be resolved and the chunk or its
                document is missing.
        """
        if kb_id is None:
            kb_id = await self._get_chunk_kb_id(session, chunk_id)

        # Look for an entity with the same name in this knowledge base.
        # NOTE(review): scalar_one_or_none() raises if several rows match;
        # this assumes (knowledge_base_id, name) is unique. Confirm in schema.
        stmt = select(KnowledgeEntity).where(
            KnowledgeEntity.knowledge_base_id == kb_id,
            KnowledgeEntity.name == extracted_entity.name,
        )
        existing = (await session.execute(stmt)).scalar_one_or_none()
        if existing:
            return (existing, False)

        # Create a new entity. If the extracted type is not a valid
        # EntityType value, the EntityType(...) conversion raises
        # (presumably ValueError if EntityType is an Enum; confirm).
        properties = extracted_entity.properties
        entity = KnowledgeEntity(
            knowledge_base_id=kb_id,
            name=extracted_entity.name,
            entity_type=EntityType(extracted_entity.entity_type),
            description=extracted_entity.description,
            properties=properties or {},
            source_chunk_id=chunk_id,
            confidence=properties.get("confidence") if properties else None,
        )
        session.add(entity)
        # Flush so the new row gets its ID and is found by later lookups in
        # this transaction (e.g. the same name repeated in one batch).
        await session.flush()
        return (entity, True)

    async def _create_relation(
        self,
        session: AsyncSession,
        chunk_id: str,
        source_id: str,
        target_id: str,
        extracted_relation,
    ) -> bool:
        """Create a relation unless an identical one already exists.

        Two relations are identical when they share the same source entity,
        target entity and relation type.

        Returns:
            ``True`` if a new relation was added, ``False`` if one already
            existed.
        """
        relation_type = RelationType(extracted_relation.relation_type)

        stmt = select(KnowledgeRelation).where(
            KnowledgeRelation.source_entity_id == source_id,
            KnowledgeRelation.target_entity_id == target_id,
            KnowledgeRelation.relation_type == relation_type,
        )
        existing = (await session.execute(stmt)).scalar_one_or_none()
        if existing:
            return False

        properties = extracted_relation.properties
        relation = KnowledgeRelation(
            source_entity_id=source_id,
            target_entity_id=target_id,
            relation_type=relation_type,
            properties=properties or {},
            source_chunk_id=chunk_id,
            confidence=properties.get("confidence") if properties else None,
        )
        session.add(relation)
        # Flush explicitly (as for entities) so the query above sees this
        # relation when the batch repeats it. Otherwise deduplication would
        # depend on the session's autoflush setting.
        await session.flush()
        return True

    async def _get_chunk_kb_id(self, session: AsyncSession, chunk_id: str) -> str:
        """Return the knowledge-base ID of the chunk's parent document.

        Raises:
            ValueError: If the chunk or its document does not exist.
        """
        chunk = await session.get(KnowledgeChunk, chunk_id)
        if not chunk:
            raise ValueError(f"Chunk not found: {chunk_id}")

        # The chunk only references its document; the KB ID lives there.
        doc = await session.get(KnowledgeDocument, chunk.document_id)
        if not doc:
            raise ValueError(f"Document not found for chunk: {chunk_id}")
        return doc.knowledge_base_id
|