geo/backend/app/api/knowledge_graph.py

172 lines
4.8 KiB
Python

"""知识图谱API"""
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_db, get_current_user
from app.models.knowledge import KnowledgeBase
from app.models.knowledge_graph import KnowledgeEntity, EntityType
from app.models.user import User
from app.services.knowledge.graph_builder import GraphBuilder
from app.services.knowledge.graph_query import GraphQuery
# Router for the knowledge-graph endpoints; every route is mounted under
# /knowledge-bases and grouped under the "知识图谱" (knowledge graph) tag.
router = APIRouter(tags=["知识图谱"], prefix="/knowledge-bases")
class EntityCreateRequest(BaseModel):
    """Request payload describing a single entity to create.

    ``entity_type`` is accepted as a raw string; the batch-create endpoint
    converts it with ``EntityType(...)``.
    """

    # Entity display name, at most 500 characters; required.
    name: str = Field(..., max_length=500)
    # Raw entity-type string; expected to be a valid EntityType value.
    entity_type: str
    # Optional free-text description.
    description: Optional[str] = Field(default=None)
    # Optional free-form attribute mapping.
    properties: Optional[dict] = Field(default=None)
def _entity_to_dict(entity: KnowledgeEntity) -> dict:
    """Serialize a ``KnowledgeEntity`` into a JSON-friendly dict.

    The id is converted to ``str`` and the entity type is flattened to its
    enum ``.value``; name, description, properties and confidence are copied
    unchanged.
    """
    serialized = dict(
        id=str(entity.id),
        name=entity.name,
        entity_type=entity.entity_type.value,
        description=entity.description,
        properties=entity.properties,
        confidence=entity.confidence,
    )
    return serialized
@router.post("/{kb_id}/entities/batch")
async def batch_create_entities(
    kb_id: UUID,
    entities: list[EntityCreateRequest],
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create up to 100 knowledge-graph entities in a single request.

    Args:
        kb_id: Knowledge base the new entities are attached to.
        entities: Entities to create; must contain between 1 and 100 items.
        db: Async database session.
        current_user: Authenticated user (only required for authentication here).

    Returns:
        ``{"created_count": int, "entities": [serialized entity, ...]}``.

    Raises:
        HTTPException: 400 if the list is empty, has more than 100 items, or
            contains an ``entity_type`` that is not a valid ``EntityType``;
            404 if the knowledge base does not exist.
    """
    if not entities:
        raise HTTPException(status_code=400, detail="实体列表不能为空")
    if len(entities) > 100:
        raise HTTPException(status_code=400, detail="单次批量创建不能超过100个实体")
    kb = await db.get(KnowledgeBase, kb_id)
    if not kb:
        raise HTTPException(status_code=404, detail="知识库不存在")
    # NOTE(review): current_user is not checked against the knowledge base here;
    # confirm that access control is enforced elsewhere.

    # Convert every entity type before adding anything to the session. An
    # unknown value makes EntityType(...) raise ValueError; unhandled, that was a
    # 500 that also left the earlier entities pending in the session.
    entity_types = []
    for entity_req in entities:
        try:
            entity_types.append(EntityType(entity_req.entity_type))
        except ValueError:
            raise HTTPException(
                status_code=400,
                detail=f"无效的实体类型: {entity_req.entity_type}",
            ) from None

    created = [
        KnowledgeEntity(
            knowledge_base_id=kb_id,
            name=entity_req.name,
            entity_type=entity_type,
            description=entity_req.description,
            properties=entity_req.properties or {},
        )
        for entity_req, entity_type in zip(entities, entity_types)
    ]
    db.add_all(created)
    await db.commit()
    # Reload each entity so that database-generated values (e.g. the id) can be
    # serialized; after commit the attributes may be expired, and an async
    # session cannot lazy-load them implicitly.
    for entity in created:
        await db.refresh(entity)
    return {"created_count": len(created), "entities": [_entity_to_dict(e) for e in created]}
@router.post("/{kb_id}/graph/build")
async def build_graph(
    kb_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Placeholder for building the graph of a whole knowledge base.

    Whole-knowledge-base extraction is NOT implemented: this endpoint does not
    read the knowledge base and runs no entity/relation extraction. It only
    returns a message pointing clients at the per-chunk endpoint
    ``POST /{kb_id}/graph/build-chunk/{chunk_id}``.
    """
    # TODO: run entity/relation extraction over every chunk of the knowledge base.
    return {"message": "Use /graph/build-chunk to build from specific chunk"}
@router.post("/{kb_id}/graph/build-chunk/{chunk_id}")
async def build_graph_from_chunk(
    kb_id: UUID,
    chunk_id: UUID,
    context: Optional[str] = None,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Build graph data from a single chunk via ``GraphBuilder.build_from_chunk``.

    Args:
        kb_id: Knowledge base from the URL path. NOTE(review): it is not passed
            to the builder, so the chunk is not verified to belong to this
            knowledge base; confirm whether that is intended.
        chunk_id: Chunk to process.
        context: Optional context string forwarded to the builder unchanged.
        db: Async database session.
        current_user: Authenticated user (only required for authentication here).

    Returns:
        ``{"status": "success", "stats": <builder stats>}``.

    Raises:
        HTTPException: 404 whenever the builder raises ``ValueError``.
    """
    builder = GraphBuilder()
    try:
        stats = await builder.build_from_chunk(db, str(chunk_id), context)
    except ValueError as e:
        # Chain the original error so its traceback survives in server logs.
        raise HTTPException(status_code=404, detail=str(e)) from e
    return {
        "status": "success",
        "stats": stats,
    }
@router.get("/{kb_id}/graph/statistics")
async def get_graph_statistics(
    kb_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return whatever ``GraphQuery.get_statistics`` reports for this knowledge base.

    The knowledge-base id is passed to the query helper as a string.
    """
    return await GraphQuery().get_statistics(db, str(kb_id))
@router.get("/{kb_id}/graph/entities/search")
async def search_entities(
    kb_id: UUID,
    q: str,
    entity_type: Optional[str] = None,
    limit: int = 20,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Search the entities of a knowledge base.

    Args:
        kb_id: Knowledge base to search in.
        q: Search text, forwarded to ``GraphQuery.search_entities``.
        entity_type: Optional entity-type filter, forwarded as a raw string.
        limit: Maximum number of results; clamped to the range 1..100.
        db: Async database session.
        current_user: Authenticated user (only required for authentication here).

    Returns:
        ``{"entities": <result of GraphQuery.search_entities>}``.
    """
    # `limit` comes straight from the query string. Clamp it so a client cannot
    # request an unbounded result set, and so zero/negative values never reach
    # the query. (It presumably becomes a SQL LIMIT; e.g. SQLite treats a
    # negative LIMIT as "no limit".)
    limit = max(1, min(limit, 100))
    query = GraphQuery()
    entities = await query.search_entities(
        db, str(kb_id), q, entity_type, limit
    )
    return {"entities": entities}
@router.get("/{kb_id}/graph/entities/{entity_id}")
async def get_entity(
    kb_id: UUID,
    entity_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return one entity of a knowledge base together with its neighbors.

    Args:
        kb_id: Knowledge base the entity must belong to.
        entity_id: Entity to fetch.
        db: Async database session.
        current_user: Authenticated user (only required for authentication here).

    Returns:
        The dict produced by ``GraphQuery.get_entity`` with an extra
        ``"neighbors"`` key holding ``GraphQuery.get_entity_neighbors``' result.

    Raises:
        HTTPException: 404 if the entity does not exist or belongs to a
            different knowledge base.
    """
    # GraphQuery.get_entity only receives the entity id, so check separately
    # that the entity belongs to the knowledge base in the URL; otherwise any
    # entity could be read through any kb_id. Compare as strings so this works
    # whether the column holds UUID objects or text.
    owned = await db.get(KnowledgeEntity, entity_id)
    if owned is None or str(owned.knowledge_base_id) != str(kb_id):
        raise HTTPException(status_code=404, detail="Entity not found")

    query = GraphQuery()
    entity = await query.get_entity(db, str(entity_id))
    if not entity:
        raise HTTPException(status_code=404, detail="Entity not found")
    # Attach the adjacent entities to the response dict.
    entity["neighbors"] = await query.get_entity_neighbors(db, str(entity_id))
    return entity
@router.get("/{kb_id}/graph/path")
async def find_path(
    kb_id: UUID,
    source: str,
    target: str,
    max_hops: int = 3,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Look up a path between two entities via ``GraphQuery.get_entity_path``.

    ``source``, ``target`` and ``max_hops`` are forwarded to the query helper
    unchanged. The response is ``{"path": <helper result>, "hops": len(path)}``.
    """
    # NOTE(review): kb_id is not passed to get_entity_path, so the search is not
    # scoped to this knowledge base, and max_hops has no upper bound; confirm both.
    # NOTE(review): if the path lists nodes rather than edges, the hop count
    # would be len(path) - 1; verify against GraphQuery.get_entity_path.
    found = await GraphQuery().get_entity_path(db, source, target, max_hops)
    return {"path": found, "hops": len(found)}