# geo/backend/app/api/knowledge_graph.py
# (116 lines, 3.1 KiB, Python)

"""知识图谱API"""
from typing import Optional
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.deps import get_db, get_current_user
from app.models.user import User
from app.services.knowledge.graph_builder import GraphBuilder
from app.services.knowledge.graph_query import GraphQuery
# Knowledge-graph endpoints; every path below is mounted under /knowledge-bases.
router = APIRouter(tags=["知识图谱"], prefix="/knowledge-bases")
@router.post("/{kb_id}/graph/build")
async def build_graph(
    kb_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Build the knowledge graph for a whole knowledge base (not implemented yet).

    The intent is to run entity and relation extraction over every chunk in
    the knowledge base. Currently no work is done: the endpoint only returns a
    hint pointing clients at the per-chunk build endpoint.

    Returns:
        ``{"message": <hint naming the per-chunk build route>}``.
    """
    # TODO: implement batch building over all chunks of the knowledge base;
    # until then only single-chunk building is available.
    # NOTE(review): this placeholder responds 200 although nothing was built;
    # consider 501 Not Implemented once clients no longer rely on the 200.
    return {
        # The per-chunk route requires the chunk id as a path segment.
        "message": "Use /graph/build-chunk/{chunk_id} to build from specific chunk"
    }
@router.post("/{kb_id}/graph/build-chunk/{chunk_id}")
async def build_graph_from_chunk(
    kb_id: UUID,
    chunk_id: UUID,
    context: Optional[str] = None,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Build graph data from a single chunk via ``GraphBuilder.build_from_chunk``.

    Args:
        kb_id: Knowledge base id from the URL. NOTE(review): it is not passed
            to the builder, so nothing here checks that ``chunk_id`` belongs
            to this knowledge base — confirm whether GraphBuilder does.
        chunk_id: Chunk to build from (passed to the builder as a string).
        context: Optional text forwarded unchanged to the builder.
        db: Async database session.
        current_user: Resolved by ``get_current_user``; not referenced in the
            body, so no per-knowledge-base ownership check happens here.

    Returns:
        ``{"status": "success", "stats": <value returned by the builder>}``.

    Raises:
        HTTPException: 404 whenever the builder raises ``ValueError``
            (presumably "chunk not found" — TODO confirm; any other
            ValueError from the builder is reported as 404 as well).
    """
    builder = GraphBuilder()
    try:
        stats = await builder.build_from_chunk(db, str(chunk_id), context)
    except ValueError as e:
        # Chain the cause so the original traceback survives in server logs.
        raise HTTPException(status_code=404, detail=str(e)) from e
    return {
        "status": "success",
        "stats": stats,
    }
@router.get("/{kb_id}/graph/statistics")
async def get_graph_statistics(
    kb_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return the graph statistics of knowledge base ``kb_id``.

    The value from ``GraphQuery.get_statistics`` is returned as-is; its
    shape is defined by that method. ``current_user`` is resolved by the
    dependency but not otherwise used here.
    """
    return await GraphQuery().get_statistics(db, str(kb_id))
# Declared before "/{kb_id}/graph/entities/{entity_id}" so that the literal
# segment "search" is matched here instead of being captured as an entity id.
@router.get("/{kb_id}/graph/entities/search")
async def search_entities(
    kb_id: UUID,
    q: str,
    entity_type: Optional[str] = None,
    # Bounded so a client cannot request zero/negative or unbounded result sets.
    limit: int = Query(default=20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Search entities in the graph of knowledge base ``kb_id``.

    Args:
        kb_id: Knowledge base to search (passed to the query as a string).
        q: Search text forwarded to ``GraphQuery.search_entities``.
        entity_type: Optional entity-type filter, forwarded unchanged.
        limit: Maximum number of results, 1-100 (default 20); out-of-range
            values are rejected by FastAPI with a 422 response.
        db: Async database session.
        current_user: Resolved by ``get_current_user``; not otherwise used.

    Returns:
        ``{"entities": <result of GraphQuery.search_entities>}``.
    """
    query = GraphQuery()
    entities = await query.search_entities(db, str(kb_id), q, entity_type, limit)
    return {"entities": entities}
@router.get("/{kb_id}/graph/entities/{entity_id}")
async def get_entity(
    kb_id: UUID,
    entity_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return one entity with its neighbours attached under ``"neighbors"``.

    Raises:
        HTTPException: 404 when ``GraphQuery.get_entity`` returns a falsy value.

    NOTE(review): ``kb_id`` is not used, so an entity belonging to a different
    knowledge base is still returned — confirm whether this needs scoping.
    """
    graph = GraphQuery()
    entity_key = str(entity_id)

    found = await graph.get_entity(db, entity_key)
    if not found:
        raise HTTPException(status_code=404, detail="Entity not found")

    # The entity returned by the query layer is augmented in place.
    found["neighbors"] = await graph.get_entity_neighbors(db, entity_key)
    return found
@router.get("/{kb_id}/graph/path")
async def find_path(
    kb_id: UUID,
    source: str,
    target: str,
    # Bounded: zero/negative hops are meaningless and large values make the
    # traversal arbitrarily expensive. Raise the cap if clients need more.
    max_hops: int = Query(default=3, ge=1, le=5),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Find a path between two entities via ``GraphQuery.get_entity_path``.

    Args:
        kb_id: Knowledge base id from the URL. NOTE(review): it is not passed
            to ``get_entity_path``, so the search is not scoped to this
            knowledge base here — confirm whether that is intended.
        source: Start entity identifier, forwarded unchanged.
        target: End entity identifier, forwarded unchanged.
        max_hops: Upper bound on path length, 1-5 (default 3); out-of-range
            values are rejected by FastAPI with a 422 response.
        db: Async database session.
        current_user: Resolved by ``get_current_user``; not otherwise used.

    Returns:
        ``{"path": <path elements>, "hops": len(path)}``; an empty path when
        nothing was found.
    """
    query = GraphQuery()
    path = await query.get_entity_path(db, source, target, max_hops)
    # Normalise a falsy "no path" result (e.g. None — TODO confirm what
    # get_entity_path returns) to an empty list so len() cannot fail.
    path = path or []
    # NOTE(review): if `path` lists nodes rather than edges, the hop count is
    # len(path) - 1 — confirm against get_entity_path.
    return {"path": path, "hops": len(path)}