# Source: geo/backend/app/services/knowledge/graph_query.py
# (repository viewer metadata: 250 lines, 8.3 KiB, Python)

"""知识图谱查询服务"""
from collections import deque
from typing import Optional

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from app.models.knowledge_graph import (
    KnowledgeEntity,
    KnowledgeRelation,
)
class GraphQuery:
    """Query service for the knowledge graph.

    The class keeps no state of its own: every public method receives the
    ``AsyncSession`` it should run its queries on.
    """

    # Escape character used when user-supplied text is embedded in a LIKE
    # pattern. "/" is used instead of a backslash so the pattern does not
    # depend on backend-specific backslash handling in string literals.
    _LIKE_ESCAPE = "/"

    async def get_entity(
        self,
        session: AsyncSession,
        entity_id: str,
    ) -> Optional[dict]:
        """Return the entity with the given primary key as a dict.

        Args:
            session: Session used for the lookup.
            entity_id: Primary key of the entity.

        Returns:
            The serialized entity, or ``None`` when no such entity exists.
        """
        entity = await session.get(KnowledgeEntity, entity_id)
        if entity is None:
            return None
        return self._entity_to_dict(entity)

    async def search_entities(
        self,
        session: AsyncSession,
        kb_id: str,
        query: str,
        entity_type: Optional[str] = None,
        limit: int = 20,
    ) -> list[dict]:
        """Search entities of one knowledge base by case-insensitive name substring.

        Args:
            session: Session used for the query.
            kb_id: Knowledge base the entities must belong to.
            query: Text that must occur in the entity name. It is matched
                literally: ``%`` and ``_`` are escaped, not treated as
                wildcards.
            entity_type: Optional entity type to filter on.
            limit: Maximum number of rows to return.

        Returns:
            A list of serialized entities (possibly empty).
        """
        # Escape LIKE metacharacters so user input cannot widen the match.
        pattern = f"%{self._escape_like(query, self._LIKE_ESCAPE)}%"
        stmt = select(KnowledgeEntity).where(
            KnowledgeEntity.knowledge_base_id == kb_id,
            KnowledgeEntity.name.ilike(pattern, escape=self._LIKE_ESCAPE),
        )
        if entity_type:
            stmt = stmt.where(KnowledgeEntity.entity_type == entity_type)
        stmt = stmt.limit(limit)

        result = await session.execute(stmt)
        return [self._entity_to_dict(entity) for entity in result.scalars()]

    async def get_entity_neighbors(
        self,
        session: AsyncSession,
        entity_id: str,
        max_depth: int = 1,
    ) -> Optional[dict]:
        """Return an entity together with its directly connected entities.

        Args:
            session: Session used for the queries.
            entity_id: Primary key of the entity at the centre.
            max_depth: Accepted for interface compatibility but not used;
                only direct (one-hop) neighbours are returned.

        Returns:
            ``None`` when the entity does not exist, otherwise a dict with
            the keys ``"entity"`` (the serialized entity), ``"incoming"``
            (relations pointing at the entity) and ``"outgoing"`` (relations
            starting at the entity). Each edge is a dict holding the
            serialized ``"relation"`` and the ``"entity"`` at its other end.
        """
        entity = await session.get(KnowledgeEntity, entity_id)
        if entity is None:
            return None

        # Incoming edges: join on the source side to load the entity that
        # points at us.
        incoming_stmt = (
            select(KnowledgeRelation, KnowledgeEntity)
            .join(
                KnowledgeEntity,
                KnowledgeRelation.source_entity_id == KnowledgeEntity.id,
            )
            .where(KnowledgeRelation.target_entity_id == entity_id)
        )
        # Outgoing edges: join on the target side to load the entity we
        # point at.
        outgoing_stmt = (
            select(KnowledgeRelation, KnowledgeEntity)
            .join(
                KnowledgeEntity,
                KnowledgeRelation.target_entity_id == KnowledgeEntity.id,
            )
            .where(KnowledgeRelation.source_entity_id == entity_id)
        )

        return {
            "entity": self._entity_to_dict(entity),
            "incoming": await self._collect_edges(session, incoming_stmt),
            "outgoing": await self._collect_edges(session, outgoing_stmt),
        }

    async def get_entity_path(
        self,
        session: AsyncSession,
        source_name: str,
        target_name: str,
        max_hops: int = 3,
        kb_id: Optional[str] = None,
    ) -> list[dict]:
        """Find a shortest directed path between two entities given by name.

        A breadth-first search follows relations from source to target, one
        query per expanded entity.

        Args:
            session: Session used for the queries.
            source_name: Exact name of the start entity.
            target_name: Exact name of the end entity.
            max_hops: Maximum number of relations a path may contain.
            kb_id: Optional knowledge base to restrict the name lookup to.
                When several entities share a name, the one with the lowest
                id is used instead of raising.

        Returns:
            The path as a list of ``{"from", "relation", "to"}`` dicts with
            entity names. The list is empty when either entity is missing,
            when source and target are the same entity, or when no path
            within ``max_hops`` exists.
        """
        source_entity = await self._find_entity_by_name(session, source_name, kb_id)
        target_entity = await self._find_entity_by_name(session, target_name, kb_id)
        if source_entity is None or target_entity is None:
            return []

        source_id = str(source_entity.id)
        target_id = str(target_entity.id)
        if source_id == target_id:
            # A zero-length path has no steps to report.
            return []

        visited = {source_id}
        # deque gives O(1) pops from the left; list.pop(0) is O(n).
        queue = deque([(source_id, [])])
        while queue:
            current_id, path = queue.popleft()
            if len(path) >= max_hops:
                continue

            neighbors_stmt = (
                select(KnowledgeRelation, KnowledgeEntity)
                .join(
                    KnowledgeEntity,
                    KnowledgeRelation.target_entity_id == KnowledgeEntity.id,
                )
                .where(KnowledgeRelation.source_entity_id == current_id)
            )
            neighbors_result = await session.execute(neighbors_stmt)
            for rel, neighbor in neighbors_result:
                neighbor_id = str(neighbor.id)
                if neighbor_id in visited:
                    continue
                visited.add(neighbor_id)
                new_path = path + [{
                    "from": current_id,
                    "relation": self._enum_value(rel.relation_type),
                    "to": neighbor_id,
                }]
                # BFS discovers nodes in order of distance, so the first time
                # the target shows up this is already a shortest path; return
                # now instead of issuing further queries.
                if neighbor_id == target_id:
                    return await self._format_path(new_path, session)
                queue.append((neighbor_id, new_path))
        return []

    async def get_statistics(
        self,
        session: AsyncSession,
        kb_id: str,
    ) -> dict:
        """Return counts and type distributions for one knowledge base.

        Relations are attributed to the knowledge base through their source
        entity.

        Args:
            session: Session used for the queries.
            kb_id: Knowledge base to summarize.

        Returns:
            A dict with ``"entity_count"``, ``"relation_count"``,
            ``"entity_type_distribution"`` and
            ``"relation_type_distribution"``. Distribution keys are the type
            values as strings, matching what ``_entity_to_dict`` and
            ``_relation_to_dict`` emit.
        """
        in_kb = KnowledgeEntity.knowledge_base_id == kb_id
        # Ids of all entities in the knowledge base, reused as a subquery.
        kb_entities = select(KnowledgeEntity.id).where(in_kb)
        from_kb = KnowledgeRelation.source_entity_id.in_(kb_entities)

        entity_count_result = await session.execute(
            select(func.count()).select_from(KnowledgeEntity).where(in_kb)
        )
        entity_count = entity_count_result.scalar() or 0

        relation_count_result = await session.execute(
            select(func.count()).select_from(KnowledgeRelation).where(from_kb)
        )
        relation_count = relation_count_result.scalar() or 0

        type_dist_result = await session.execute(
            select(KnowledgeEntity.entity_type, func.count())
            .where(in_kb)
            .group_by(KnowledgeEntity.entity_type)
        )
        entity_type_dist = {
            str(self._enum_value(kind)): count for kind, count in type_dist_result
        }

        rel_type_dist_result = await session.execute(
            select(KnowledgeRelation.relation_type, func.count())
            .where(from_kb)
            .group_by(KnowledgeRelation.relation_type)
        )
        relation_type_dist = {
            str(self._enum_value(kind)): count for kind, count in rel_type_dist_result
        }

        return {
            "entity_count": entity_count,
            "relation_count": relation_count,
            "entity_type_distribution": entity_type_dist,
            "relation_type_distribution": relation_type_dist,
        }

    @staticmethod
    def _escape_like(text: str, escape_char: str = "/") -> str:
        """Escape LIKE wildcards in ``text`` so it is matched literally."""
        # The escape character itself must be doubled first, otherwise the
        # escapes added for "%" and "_" would be escaped again.
        return (
            text.replace(escape_char, escape_char * 2)
            .replace("%", escape_char + "%")
            .replace("_", escape_char + "_")
        )

    @staticmethod
    def _enum_value(value):
        """Return ``value.value`` for enum-like objects, else ``value`` itself."""
        return getattr(value, "value", value)

    async def _find_entity_by_name(
        self,
        session: AsyncSession,
        name: str,
        kb_id: Optional[str] = None,
    ):
        """Return the lowest-id entity with exactly this name, or ``None``."""
        stmt = select(KnowledgeEntity).where(KnowledgeEntity.name == name)
        if kb_id is not None:
            stmt = stmt.where(KnowledgeEntity.knowledge_base_id == kb_id)
        # Order + limit make the choice deterministic when names collide.
        stmt = stmt.order_by(KnowledgeEntity.id).limit(1)
        result = await session.execute(stmt)
        return result.scalars().first()

    async def _collect_edges(self, session: AsyncSession, stmt) -> list[dict]:
        """Run a (relation, entity) query and serialize every row."""
        result = await session.execute(stmt)
        return [
            {
                "relation": self._relation_to_dict(rel),
                "entity": self._entity_to_dict(other),
            }
            for rel, other in result
        ]

    def _entity_to_dict(self, entity: KnowledgeEntity) -> dict:
        """Serialize an entity to a plain dict."""
        return {
            "id": str(entity.id),
            "name": entity.name,
            "entity_type": self._enum_value(entity.entity_type),
            "description": entity.description,
            "properties": entity.properties,
            "confidence": entity.confidence,
        }

    def _relation_to_dict(self, relation: KnowledgeRelation) -> dict:
        """Serialize a relation to a plain dict."""
        return {
            "id": str(relation.id),
            "source_id": str(relation.source_entity_id),
            "target_id": str(relation.target_entity_id),
            "relation_type": self._enum_value(relation.relation_type),
            "properties": relation.properties,
            "confidence": relation.confidence,
        }

    async def _format_path(self, path: list, session: AsyncSession) -> list[dict]:
        """Replace the entity ids in a path with entity names.

        Args:
            path: Steps as ``{"from": id, "relation": str, "to": id}`` dicts.
            session: Session used to look the entities up.

        Returns:
            The same steps with ``"from"``/``"to"`` holding the entity name,
            or the original id when the entity cannot be loaded.
        """
        # Consecutive steps share an entity (one step's "to" is the next
        # step's "from"), so remember resolved labels instead of looking the
        # same entity up twice.
        labels: dict = {}

        async def label_for(entity_id):
            if entity_id not in labels:
                entity = await session.get(KnowledgeEntity, entity_id)
                labels[entity_id] = entity.name if entity else entity_id
            return labels[entity_id]

        formatted = []
        for step in path:
            formatted.append({
                "from": await label_for(step["from"]),
                "relation": step["relation"],
                "to": await label_for(step["to"]),
            })
        return formatted