250 lines
8.3 KiB
Python
250 lines
8.3 KiB
Python
"""知识图谱查询服务"""
|
|
from typing import Optional
|
|
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from app.models.knowledge_graph import (
|
|
KnowledgeEntity,
|
|
KnowledgeRelation,
|
|
)
|
|
|
|
|
|
class GraphQuery:
    """Read-only query helpers over the knowledge-graph tables.

    Every public method receives the caller's ``AsyncSession``; the class keeps
    no instance state of its own.
    """

    # Escape character used when turning user-supplied text into a LIKE pattern.
    _LIKE_ESCAPE = "\\"

    async def get_entity(
        self,
        session: AsyncSession,
        entity_id: str,
    ) -> Optional[dict]:
        """Fetch one entity by primary key.

        Args:
            session: Open async database session.
            entity_id: Primary key of the ``KnowledgeEntity`` to load.

        Returns:
            The entity as produced by ``_entity_to_dict``, or ``None`` when no
            entity has that id.
        """
        entity = await session.get(KnowledgeEntity, entity_id)
        if entity is None:
            return None
        return self._entity_to_dict(entity)

    async def search_entities(
        self,
        session: AsyncSession,
        kb_id: str,
        query: str,
        entity_type: Optional[str] = None,
        limit: int = 20,
    ) -> list[dict]:
        """Case-insensitive substring search on entity names in one knowledge base.

        ``%`` and ``_`` inside ``query`` are matched literally instead of being
        interpreted as SQL LIKE wildcards.

        Args:
            session: Open async database session.
            kb_id: Knowledge base to search in.
            query: Text that must occur somewhere in the entity name.
            entity_type: When truthy, only entities whose ``entity_type``
                equals this value are returned.
            limit: Maximum number of rows returned.

        Returns:
            Up to ``limit`` entity dicts. No ORDER BY is applied, so row order
            is whatever the database returns.
        """
        pattern = f"%{self._escape_like(query)}%"
        stmt = select(KnowledgeEntity).where(
            KnowledgeEntity.knowledge_base_id == kb_id,
            KnowledgeEntity.name.ilike(pattern, escape=self._LIKE_ESCAPE),
        )
        if entity_type:
            stmt = stmt.where(KnowledgeEntity.entity_type == entity_type)

        result = await session.execute(stmt.limit(limit))
        return [self._entity_to_dict(e) for e in result.scalars()]

    async def get_entity_neighbors(
        self,
        session: AsyncSession,
        entity_id: str,
        max_depth: int = 1,
    ) -> Optional[dict]:
        """Return an entity together with its directly connected entities.

        Args:
            session: Open async database session.
            entity_id: Primary key of the centre entity.
            max_depth: Accepted for API compatibility but currently ignored;
                only direct (one-hop) neighbours are returned.

        Returns:
            ``None`` if the entity does not exist. Otherwise a dict with:
              * ``"entity"``: the centre entity;
              * ``"incoming"``: relations whose target is this entity, each
                paired with its source entity;
              * ``"outgoing"``: relations whose source is this entity, each
                paired with its target entity.
        """
        entity = await session.get(KnowledgeEntity, entity_id)
        if entity is None:
            return None

        return {
            "entity": self._entity_to_dict(entity),
            # Edges pointing at this entity; the paired entity is the source.
            "incoming": await self._edges_with_endpoint(
                session,
                match_col=KnowledgeRelation.target_entity_id,
                other_col=KnowledgeRelation.source_entity_id,
                entity_id=entity_id,
            ),
            # Edges leaving this entity; the paired entity is the target.
            "outgoing": await self._edges_with_endpoint(
                session,
                match_col=KnowledgeRelation.source_entity_id,
                other_col=KnowledgeRelation.target_entity_id,
                entity_id=entity_id,
            ),
        }

    async def get_entity_path(
        self,
        session: AsyncSession,
        source_name: str,
        target_name: str,
        max_hops: int = 3,
        kb_id: Optional[str] = None,
    ) -> list[dict]:
        """Find a shortest directed path between two entities looked up by name.

        Relations are only followed from source to target, never backwards.
        If several entities share ``source_name`` (or ``target_name``), the
        search starts from all of them (and stops at any of them).

        Args:
            session: Open async database session.
            source_name: Exact name of the start entity.
            target_name: Exact name of the end entity.
            max_hops: Maximum number of relations in the returned path.
            kb_id: Optional knowledge base to restrict the name lookups to.

        Returns:
            A list of ``{"from": name, "relation": type, "to": name}`` steps in
            source-to-target order. Returns ``[]`` when either name is unknown,
            when no path of at most ``max_hops`` relations exists, or when both
            names resolve to the same entity.
        """
        sources = await self._entities_named(session, source_name, kb_id)
        targets = await self._entities_named(session, target_name, kb_id)
        if not sources or not targets:
            return []

        target_ids = {t.id for t in targets}
        if any(s.id in target_ids for s in sources):
            # Source and target are the same entity: zero-length path.
            return []

        names = {s.id: s.name for s in sources}
        # Reached entity id -> (predecessor id, relation type value); start
        # entities have no predecessor. Doubles as the visited set.
        parents: dict = {s.id: None for s in sources}

        # Level-synchronous BFS: one query expands the whole frontier per hop.
        frontier = list(parents)
        for _ in range(max_hops):
            if not frontier:
                break
            stmt = (
                select(KnowledgeRelation, KnowledgeEntity)
                .join(
                    KnowledgeEntity,
                    KnowledgeRelation.target_entity_id == KnowledgeEntity.id,
                )
                .where(KnowledgeRelation.source_entity_id.in_(frontier))
            )
            result = await session.execute(stmt)

            next_frontier = []
            for rel, neighbor in result:
                if neighbor.id in parents:
                    continue
                parents[neighbor.id] = (
                    rel.source_entity_id,
                    rel.relation_type.value,
                )
                names[neighbor.id] = neighbor.name
                if neighbor.id in target_ids:
                    return self._trace_path(neighbor.id, parents, names)
                next_frontier.append(neighbor.id)
            frontier = next_frontier

        return []

    async def get_statistics(
        self,
        session: AsyncSession,
        kb_id: str,
    ) -> dict:
        """Summarise the size and composition of one knowledge base's graph.

        Relations are attributed to a knowledge base through their source
        entity.

        Args:
            session: Open async database session.
            kb_id: Knowledge base to summarise.

        Returns:
            A dict with ``entity_count``, ``relation_count``,
            ``entity_type_distribution`` and ``relation_type_distribution``.
            Distribution keys use the enum member's ``.value`` (the same form
            ``_entity_to_dict`` / ``_relation_to_dict`` emit).
        """
        entity_count_stmt = (
            select(func.count())
            .select_from(KnowledgeEntity)
            .where(KnowledgeEntity.knowledge_base_id == kb_id)
        )
        entity_count = (await session.execute(entity_count_stmt)).scalar() or 0

        # Ids of every entity in this knowledge base, used as an IN subquery.
        kb_entities = select(KnowledgeEntity.id).where(
            KnowledgeEntity.knowledge_base_id == kb_id
        )

        relation_count_stmt = (
            select(func.count())
            .select_from(KnowledgeRelation)
            .where(KnowledgeRelation.source_entity_id.in_(kb_entities))
        )
        relation_count = (await session.execute(relation_count_stmt)).scalar() or 0

        type_dist_stmt = (
            select(KnowledgeEntity.entity_type, func.count())
            .where(KnowledgeEntity.knowledge_base_id == kb_id)
            .group_by(KnowledgeEntity.entity_type)
        )
        type_dist_result = await session.execute(type_dist_stmt)
        entity_type_dist = {
            self._enum_label(k): v for k, v in type_dist_result
        }

        rel_type_dist_stmt = (
            select(KnowledgeRelation.relation_type, func.count())
            .where(KnowledgeRelation.source_entity_id.in_(kb_entities))
            .group_by(KnowledgeRelation.relation_type)
        )
        rel_type_dist_result = await session.execute(rel_type_dist_stmt)
        relation_type_dist = {
            self._enum_label(k): v for k, v in rel_type_dist_result
        }

        return {
            "entity_count": entity_count,
            "relation_count": relation_count,
            "entity_type_distribution": entity_type_dist,
            "relation_type_distribution": relation_type_dist,
        }

    def _entity_to_dict(self, entity: KnowledgeEntity) -> dict:
        """Serialise a ``KnowledgeEntity`` into a plain dict."""
        return {
            "id": str(entity.id),
            "name": entity.name,
            "entity_type": entity.entity_type.value,
            "description": entity.description,
            "properties": entity.properties,
            "confidence": entity.confidence,
        }

    def _relation_to_dict(self, relation: KnowledgeRelation) -> dict:
        """Serialise a ``KnowledgeRelation`` into a plain dict."""
        return {
            "id": str(relation.id),
            "source_id": str(relation.source_entity_id),
            "target_id": str(relation.target_entity_id),
            "relation_type": relation.relation_type.value,
            "properties": relation.properties,
            "confidence": relation.confidence,
        }

    async def _edges_with_endpoint(
        self,
        session: AsyncSession,
        match_col,
        other_col,
        entity_id: str,
    ) -> list[dict]:
        """Return relations whose ``match_col`` equals ``entity_id``, each paired
        with the entity referenced by ``other_col``.

        The inner join drops relations whose other endpoint has no entity row.
        """
        stmt = (
            select(KnowledgeRelation, KnowledgeEntity)
            .join(KnowledgeEntity, other_col == KnowledgeEntity.id)
            .where(match_col == entity_id)
        )
        result = await session.execute(stmt)
        return [
            {
                "relation": self._relation_to_dict(rel),
                "entity": self._entity_to_dict(other),
            }
            for rel, other in result
        ]

    @staticmethod
    async def _entities_named(
        session: AsyncSession,
        name: str,
        kb_id: Optional[str],
    ) -> list:
        """Return all entities with exactly ``name``, optionally within ``kb_id``."""
        stmt = select(KnowledgeEntity).where(KnowledgeEntity.name == name)
        if kb_id is not None:
            stmt = stmt.where(KnowledgeEntity.knowledge_base_id == kb_id)
        result = await session.execute(stmt)
        return list(result.scalars())

    @staticmethod
    def _trace_path(end_id, parents: dict, names: dict) -> list[dict]:
        """Walk ``parents`` back from ``end_id``; return hops in source-to-target order."""
        steps = []
        node = end_id
        while parents[node] is not None:
            prev, relation = parents[node]
            steps.append({"from": names[prev], "relation": relation, "to": names[node]})
            node = prev
        steps.reverse()
        return steps

    @classmethod
    def _escape_like(cls, text: str) -> str:
        """Escape LIKE metacharacters so ``text`` matches literally.

        The escape character itself is doubled first so existing backslashes
        are not reinterpreted.
        """
        esc = cls._LIKE_ESCAPE
        return (
            text.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )

    @staticmethod
    def _enum_label(value) -> str:
        """Return an enum member's ``.value`` as a string; other values via ``str``."""
        return str(getattr(value, "value", value))

    async def _format_path(self, path: list, session: AsyncSession) -> list[dict]:
        """Replace the entity ids in ``{"from", "relation", "to"}`` steps with names.

        Ids that no longer resolve to an entity are kept as-is. Not used by
        ``get_entity_path`` in this class, which collects names during its
        search; kept so existing callers of this helper keep working.
        """
        formatted = []
        for step in path:
            from_entity = await session.get(KnowledgeEntity, step["from"])
            to_entity = await session.get(KnowledgeEntity, step["to"])
            formatted.append({
                "from": from_entity.name if from_entity else step["from"],
                "relation": step["relation"],
                "to": to_entity.name if to_entity else step["to"],
            })
        return formatted