116 lines
3.1 KiB
Python
116 lines
3.1 KiB
Python
"""知识图谱API"""
|
|
from typing import Optional
from uuid import UUID

from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.ext.asyncio import AsyncSession

from app.api.deps import get_db, get_current_user
from app.models.user import User
from app.services.knowledge.graph_builder import GraphBuilder
from app.services.knowledge.graph_query import GraphQuery
|
|
|
|
# Router for the knowledge-graph endpoints: every route below is mounted under
# the "/knowledge-bases" prefix and grouped under the "知识图谱" ("knowledge
# graph") tag in the generated OpenAPI docs.
router = APIRouter(prefix="/knowledge-bases", tags=["知识图谱"])
|
|
|
|
|
|
@router.post("/{kb_id}/graph/build")
async def build_graph(
    kb_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Build a knowledge graph from an entire knowledge base.

    Intended to run entity and relation extraction over every chunk in the
    knowledge base. Batch building is not implemented yet: this endpoint does
    no work and only returns a message directing callers to the per-chunk
    build endpoint.
    """
    # TODO: implement batch building over all chunks of ``kb_id``.
    # Until then, graphs are built one chunk at a time via build-chunk.
    hint = "Use /graph/build-chunk to build from specific chunk"
    return {"message": hint}
|
|
|
|
|
|
@router.post("/{kb_id}/graph/build-chunk/{chunk_id}")
async def build_graph_from_chunk(
    kb_id: UUID,
    chunk_id: UUID,
    context: Optional[str] = None,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Build graph data from a single chunk.

    Delegates to ``GraphBuilder.build_from_chunk`` and returns the statistics
    it reports.

    Args:
        kb_id: Knowledge base ID taken from the URL path.
        chunk_id: ID of the chunk to build graph data from.
        context: Optional extra context forwarded unchanged to the builder
            (supplied as a query parameter).

    Returns:
        ``{"status": "success", "stats": <stats returned by the builder>}``.

    Raises:
        HTTPException: 404, with the builder's message as detail, when the
            builder raises ``ValueError``.
    """
    # NOTE(review): kb_id is not passed to the builder, so nothing here checks
    # that chunk_id belongs to this knowledge base (or that current_user may
    # access it) — confirm whether GraphBuilder enforces this.
    builder = GraphBuilder()
    try:
        stats = await builder.build_from_chunk(db, str(chunk_id), context)
    except ValueError as e:
        # Translate the builder's ValueError into a 404; ``from e`` keeps the
        # original exception as __cause__ for logs and tracebacks.
        raise HTTPException(status_code=404, detail=str(e)) from e
    return {"status": "success", "stats": stats}
|
|
|
|
|
|
@router.get("/{kb_id}/graph/statistics")
async def get_graph_statistics(
    kb_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return graph statistics for a knowledge base.

    The response is whatever ``GraphQuery.get_statistics`` produces for
    ``kb_id``, passed through unchanged.
    """
    return await GraphQuery().get_statistics(db, str(kb_id))
|
|
|
|
|
|
@router.get("/{kb_id}/graph/entities/search")
async def search_entities(
    kb_id: UUID,
    q: str,
    entity_type: Optional[str] = None,
    # Bounded because the value comes straight from the client: zero, negative
    # or very large limits would otherwise reach the query layer unchecked.
    # Out-of-range values are rejected by FastAPI with a 422.
    limit: int = Query(20, ge=1, le=100),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Search entities within a knowledge base.

    Args:
        kb_id: Knowledge base to search in.
        q: Search text forwarded to ``GraphQuery.search_entities``.
        entity_type: Optional entity-type filter forwarded to the query.
        limit: Maximum number of results to request (1-100, default 20).

    Returns:
        ``{"entities": <result of GraphQuery.search_entities>}``.
    """
    query = GraphQuery()
    entities = await query.search_entities(db, str(kb_id), q, entity_type, limit)
    return {"entities": entities}
|
|
|
|
|
|
@router.get("/{kb_id}/graph/entities/{entity_id}")
async def get_entity(
    kb_id: UUID,
    entity_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return one entity's details together with its neighbors.

    Raises:
        HTTPException: 404 when ``GraphQuery.get_entity`` returns a falsy value.
    """
    # NOTE(review): the lookup uses entity_id only; kb_id does not scope it,
    # so an entity from another knowledge base may be returned — confirm.
    graph = GraphQuery()
    eid = str(entity_id)

    entity = await graph.get_entity(db, eid)
    if not entity:
        raise HTTPException(status_code=404, detail="Entity not found")

    # Attach neighbors to the object returned by get_entity (mutated in place).
    entity["neighbors"] = await graph.get_entity_neighbors(db, eid)
    return entity
|
|
|
|
|
|
@router.get("/{kb_id}/graph/path")
async def find_path(
    kb_id: UUID,
    source: str,
    target: str,
    # Client-controlled traversal depth: bound it so a single request cannot
    # trigger an arbitrarily deep path search. Out-of-range values are
    # rejected by FastAPI with a 422. The upper bound is a policy choice;
    # adjust it if GraphQuery is known to handle deeper searches cheaply.
    max_hops: int = Query(3, ge=1, le=6),
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Find a path between two entities.

    Args:
        kb_id: Knowledge base ID from the URL path (not used by the lookup).
        source: Start entity, passed as-is to ``GraphQuery.get_entity_path``
            (name or ID — TODO confirm).
        target: End entity, passed the same way.
        max_hops: Maximum path length to search (1-6, default 3).

    Returns:
        ``{"path": <path>, "hops": len(path)}``; an empty path when the query
        returns ``None``.
    """
    # NOTE(review): kb_id does not scope the search, and "hops" is len(path);
    # if path is a list of nodes rather than edges, the hop count is
    # len(path) - 1 — confirm against GraphQuery.get_entity_path.
    query = GraphQuery()
    path = await query.get_entity_path(db, source, target, max_hops)
    if path is None:
        # Guard against a None result so len() cannot fail with a TypeError (500).
        path = []
    return {"path": path, "hops": len(path)}
|