"""知识图谱查询服务""" from typing import Optional from sqlalchemy import select, func from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from app.models.knowledge_graph import ( KnowledgeEntity, KnowledgeRelation, ) class GraphQuery: """知识图谱查询服务""" async def get_entity( self, session: AsyncSession, entity_id: str, ) -> Optional[dict]: """根据ID获取实体详情""" entity = await session.get(KnowledgeEntity, entity_id) if not entity: return None return self._entity_to_dict(entity) async def search_entities( self, session: AsyncSession, kb_id: str, query: str, entity_type: Optional[str] = None, limit: int = 20, ) -> list[dict]: """搜索实体""" stmt = select(KnowledgeEntity).where( KnowledgeEntity.knowledge_base_id == kb_id, KnowledgeEntity.name.ilike(f"%{query}%"), ) if entity_type: stmt = stmt.where(KnowledgeEntity.entity_type == entity_type) stmt = stmt.limit(limit) result = await session.execute(stmt) return [self._entity_to_dict(e) for e in result.scalars()] async def get_entity_neighbors( self, session: AsyncSession, entity_id: str, max_depth: int = 1, ) -> dict: """获取实体的邻居(直接关联的实体)""" entity = await session.get(KnowledgeEntity, entity_id) if not entity: return None neighbors = { "entity": self._entity_to_dict(entity), "incoming": [], # 入边(别人指向我) "outgoing": [], # 出边(我指向别人) } # 获取入边 incoming_stmt = ( select(KnowledgeRelation, KnowledgeEntity) .join(KnowledgeEntity, KnowledgeRelation.source_entity_id == KnowledgeEntity.id) .where(KnowledgeRelation.target_entity_id == entity_id) ) incoming_result = await session.execute(incoming_stmt) for rel, source_entity in incoming_result: neighbors["incoming"].append({ "relation": self._relation_to_dict(rel), "entity": self._entity_to_dict(source_entity), }) # 获取出边 outgoing_stmt = ( select(KnowledgeRelation, KnowledgeEntity) .join(KnowledgeEntity, KnowledgeRelation.target_entity_id == KnowledgeEntity.id) .where(KnowledgeRelation.source_entity_id == entity_id) ) outgoing_result = await session.execute(outgoing_stmt) for rel, target_entity in outgoing_result: neighbors["outgoing"].append({ "relation": self._relation_to_dict(rel), "entity": self._entity_to_dict(target_entity), }) return neighbors async def get_entity_path( self, session: AsyncSession, source_name: str, target_name: str, max_hops: int = 3, ) -> list[dict]: """ 查找两个实体之间的路径 使用简单BFS查找路径 """ # 获取实体ID source_stmt = select(KnowledgeEntity).where( KnowledgeEntity.name == source_name ) source_result = await session.execute(source_stmt) source_entity = source_result.scalar_one_or_none() target_stmt = select(KnowledgeEntity).where( KnowledgeEntity.name == target_name ) target_result = await session.execute(target_stmt) target_entity = target_result.scalar_one_or_none() if not source_entity or not target_entity: return [] # BFS查找路径 visited = {str(source_entity.id)} queue = [(str(source_entity.id), [])] while queue: current_id, path = queue.pop(0) if current_id == str(target_entity.id): # 找到路径,返回 return await self._format_path(path, session) if len(path) >= max_hops: continue # 探索邻居 neighbors_stmt = ( select(KnowledgeRelation, KnowledgeEntity) .join(KnowledgeEntity, KnowledgeRelation.target_entity_id == KnowledgeEntity.id) .where(KnowledgeRelation.source_entity_id == current_id) ) neighbors_result = await session.execute(neighbors_stmt) for rel, neighbor in neighbors_result: neighbor_id = str(neighbor.id) if neighbor_id not in visited: visited.add(neighbor_id) new_path = path + [{ "from": current_id, "relation": rel.relation_type.value, "to": neighbor_id, }] queue.append((neighbor_id, new_path)) return [] async def get_statistics( self, session: AsyncSession, kb_id: str, ) -> dict: """获取图谱统计信息""" # 实体数量 entity_count_stmt = select(func.count()).where( KnowledgeEntity.knowledge_base_id == kb_id ) entity_count_result = await session.execute(entity_count_stmt) entity_count = entity_count_result.scalar() or 0 # 关系数量 kb_entities = select(KnowledgeEntity.id).where( KnowledgeEntity.knowledge_base_id == kb_id ) relation_count_stmt = select(func.count()).where( KnowledgeRelation.source_entity_id.in_(kb_entities) ) relation_count_result = await session.execute(relation_count_stmt) relation_count = relation_count_result.scalar() or 0 # 实体类型分布 type_dist_stmt = ( select( KnowledgeEntity.entity_type, func.count() ) .where(KnowledgeEntity.knowledge_base_id == kb_id) .group_by(KnowledgeEntity.entity_type) ) type_dist_result = await session.execute(type_dist_stmt) entity_type_dist = {str(k): v for k, v in type_dist_result} # 关系类型分布 rel_type_dist_stmt = ( select( KnowledgeRelation.relation_type, func.count() ) .where(KnowledgeRelation.source_entity_id.in_(kb_entities)) .group_by(KnowledgeRelation.relation_type) ) rel_type_dist_result = await session.execute(rel_type_dist_stmt) relation_type_dist = {str(k): v for k, v in rel_type_dist_result} return { "entity_count": entity_count, "relation_count": relation_count, "entity_type_distribution": entity_type_dist, "relation_type_distribution": relation_type_dist, } def _entity_to_dict(self, entity: KnowledgeEntity) -> dict: """实体转字典""" return { "id": str(entity.id), "name": entity.name, "entity_type": entity.entity_type.value, "description": entity.description, "properties": entity.properties, "confidence": entity.confidence, } def _relation_to_dict(self, relation: KnowledgeRelation) -> dict: """关系转字典""" return { "id": str(relation.id), "source_id": str(relation.source_entity_id), "target_id": str(relation.target_entity_id), "relation_type": relation.relation_type.value, "properties": relation.properties, "confidence": relation.confidence, } async def _format_path(self, path: list, session: AsyncSession) -> list[dict]: """格式化路径,返回实体名称""" formatted = [] for step in path: # 获取实体名称 from_entity = await session.get(KnowledgeEntity, step["from"]) to_entity = await session.get(KnowledgeEntity, step["to"]) formatted.append({ "from": from_entity.name if from_entity else step["from"], "relation": step["relation"], "to": to_entity.name if to_entity else step["to"], }) return formatted