# geo/backend/app/services/knowledge/graph_builder.py
#
# File-viewer header (as captured): 188 lines, 5.6 KiB, Python

"""知识图谱构建服务"""
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.knowledge_graph import (
KnowledgeEntity,
KnowledgeRelation,
EntityType,
RelationType,
)
from app.models.knowledge import KnowledgeChunk, KnowledgeDocument
from app.services.knowledge.entity_extractor import EntityExtractor, ExtractionResult
class GraphBuilder:
    """Build a knowledge graph from the content of knowledge chunks.

    Runs an ``EntityExtractor`` over a chunk's content and persists the
    extracted entities and relations as ``KnowledgeEntity`` /
    ``KnowledgeRelation`` rows through the ``AsyncSession`` supplied by
    the caller.
    """

    def __init__(self):
        self.extractor = EntityExtractor()

    async def build_from_chunk(
        self,
        session: AsyncSession,
        chunk_id: str,
        context: Optional[str] = None,
    ) -> dict:
        """Extract entities and relations from one chunk and store them.

        Args:
            session: Database session used for all reads and writes.
            chunk_id: Identifier of the ``KnowledgeChunk`` to process.
            context: Optional extra context handed to the extractor
                (for example a brand name).

        Returns:
            Statistics dict with the keys ``entities_created``,
            ``entities_existing``, ``relations_created`` and
            ``relations_existing``.

        Raises:
            ValueError: If no chunk with ``chunk_id`` exists, or if the
                chunk's document cannot be found while storing entities.
        """
        # 1. Load the chunk whose content will be analysed.
        chunk = await session.get(KnowledgeChunk, chunk_id)
        if not chunk:
            raise ValueError(f"Chunk not found: {chunk_id}")

        # 2. Extract entities and relations from the chunk text.
        result = await self.extractor.extract(chunk.content, context)

        # 3. Persist the extraction result into the graph tables.
        return await self._store_extraction(session, chunk_id, result)

    async def _store_extraction(
        self,
        session: AsyncSession,
        chunk_id: str,
        result: ExtractionResult,
    ) -> dict:
        """Persist an extraction result and commit the session.

        Entities are stored first so that relations can be resolved from
        entity names to entity ids. Relations whose source or target name
        was not among the stored entities are skipped.

        Args:
            session: Database session used for all reads and writes.
            chunk_id: Identifier of the chunk the result was extracted from.
            result: Extraction output providing ``entities`` and
                ``relations``.

        Returns:
            Statistics dict counting created and already-existing
            entities and relations.

        Raises:
            ValueError: If the chunk or its document cannot be found.
        """
        stats = {
            "entities_created": 0,
            "entities_existing": 0,
            "relations_created": 0,
            "relations_existing": 0,
        }

        # Maps an extracted entity name to the id of its stored row.
        entity_map = {}

        # The knowledge-base id is the same for every entity of this chunk,
        # so it is resolved once (lazily, only if there is an entity to
        # store) instead of once per entity.
        kb_id = None

        # 4. Store entities.
        for extracted_entity in result.entities:
            if kb_id is None:
                kb_id = await self._get_chunk_kb_id(session, chunk_id)
            entity, created = await self._get_or_create_entity(
                session,
                chunk_id,
                extracted_entity,
                kb_id=kb_id,
            )
            entity_map[extracted_entity.name] = entity.id
            stats["entities_created" if created else "entities_existing"] += 1

        # 5. Store relations.
        for extracted_relation in result.relations:
            source_id = entity_map.get(extracted_relation.source_entity)
            target_id = entity_map.get(extracted_relation.target_entity)
            if not source_id or not target_id:
                # An endpoint was not among the extracted entities; skip it.
                continue

            created = await self._create_relation(
                session,
                chunk_id,
                source_id,
                target_id,
                extracted_relation,
            )
            stats["relations_created" if created else "relations_existing"] += 1

        await session.commit()
        return stats

    async def _get_or_create_entity(
        self,
        session: AsyncSession,
        chunk_id: str,
        extracted_entity,
        kb_id: Optional[str] = None,
    ) -> tuple:
        """Return the entity matching ``extracted_entity``, creating it if needed.

        An existing entity is looked up by knowledge base id and name.

        Args:
            session: Database session used for the lookup and insert.
            chunk_id: Identifier of the chunk the entity was extracted from;
                recorded as ``source_chunk_id`` on a newly created entity.
            extracted_entity: Extracted entity exposing ``name``,
                ``entity_type``, ``description`` and ``properties``.
            kb_id: Knowledge base id of the chunk, if the caller already
                resolved it. When ``None`` it is looked up from ``chunk_id``.

        Returns:
            ``(entity, created)`` where ``created`` is ``True`` only when a
            new row was added.

        Raises:
            ValueError: If the chunk or its document cannot be found.
                NOTE(review): ``EntityType(...)`` presumably also raises
                ``ValueError`` for an unknown type value if it is a
                standard ``Enum`` - confirm against the model.
        """
        if kb_id is None:
            kb_id = await self._get_chunk_kb_id(session, chunk_id)

        # Look for an existing entity with the same name in this knowledge base.
        stmt = select(KnowledgeEntity).where(
            KnowledgeEntity.knowledge_base_id == kb_id,
            KnowledgeEntity.name == extracted_entity.name,
        )
        existing = (await session.execute(stmt)).scalar_one_or_none()
        if existing:
            return (existing, False)

        # Create a new entity; a missing/empty property bag becomes ``{}``,
        # which also yields ``confidence=None``.
        properties = extracted_entity.properties or {}
        entity = KnowledgeEntity(
            knowledge_base_id=kb_id,
            name=extracted_entity.name,
            entity_type=EntityType(extracted_entity.entity_type),
            description=extracted_entity.description,
            properties=properties,
            source_chunk_id=chunk_id,
            confidence=properties.get("confidence"),
        )
        session.add(entity)
        # Flush before returning: the caller reads ``entity.id`` right away
        # (presumably the id is only populated once the row is flushed).
        await session.flush()
        return (entity, True)

    async def _create_relation(
        self,
        session: AsyncSession,
        chunk_id: str,
        source_id: str,
        target_id: str,
        extracted_relation,
    ) -> bool:
        """Create a relation between two entities unless it already exists.

        A relation counts as existing when a row with the same source
        entity, target entity and relation type is found.

        Args:
            session: Database session used for the lookup and insert.
            chunk_id: Identifier of the chunk the relation was extracted
                from; recorded as ``source_chunk_id``.
            source_id: Id of the source entity.
            target_id: Id of the target entity.
            extracted_relation: Extracted relation exposing
                ``relation_type`` and ``properties``.

        Returns:
            ``True`` if a new relation was added to the session, ``False``
            if a matching relation already existed.
        """
        # Convert once; used for both the lookup and the new row.
        relation_type = RelationType(extracted_relation.relation_type)

        # Check whether the same relation already exists.
        stmt = select(KnowledgeRelation).where(
            KnowledgeRelation.source_entity_id == source_id,
            KnowledgeRelation.target_entity_id == target_id,
            KnowledgeRelation.relation_type == relation_type,
        )
        existing = (await session.execute(stmt)).scalar_one_or_none()
        if existing:
            return False

        # Create the relation; it is only added to the session here, the
        # commit is issued by ``_store_extraction``.
        properties = extracted_relation.properties or {}
        relation = KnowledgeRelation(
            source_entity_id=source_id,
            target_entity_id=target_id,
            relation_type=relation_type,
            properties=properties,
            source_chunk_id=chunk_id,
            confidence=properties.get("confidence"),
        )
        session.add(relation)
        return True

    async def _get_chunk_kb_id(self, session: AsyncSession, chunk_id: str) -> str:
        """Return the id of the knowledge base that owns the given chunk.

        The knowledge base id is read from the chunk's document.

        Args:
            session: Database session used for the lookups.
            chunk_id: Identifier of the chunk.

        Returns:
            The ``knowledge_base_id`` of the chunk's document.

        Raises:
            ValueError: If the chunk, or the document it refers to, does
                not exist.
        """
        chunk = await session.get(KnowledgeChunk, chunk_id)
        if not chunk:
            raise ValueError(f"Chunk not found: {chunk_id}")

        # Resolve the knowledge base through the owning document.
        doc = await session.get(KnowledgeDocument, chunk.document_id)
        if not doc:
            # ``session.get`` returned nothing; fail with a clear error
            # rather than an AttributeError on ``None``.
            raise ValueError(
                f"Document not found: {chunk.document_id} (chunk {chunk_id})"
            )
        return doc.knowledge_base_id