# Listing metadata (from the code viewer): 172 lines, 4.8 KiB, Python
"""知识图谱API"""
|
|
from typing import Optional
|
|
from uuid import UUID
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException
|
|
from pydantic import BaseModel, Field
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.api.deps import get_db, get_current_user
|
|
from app.models.knowledge import KnowledgeBase
|
|
from app.models.knowledge_graph import KnowledgeEntity, EntityType
|
|
from app.models.user import User
|
|
from app.services.knowledge.graph_builder import GraphBuilder
|
|
from app.services.knowledge.graph_query import GraphQuery
|
|
|
|
router = APIRouter(prefix="/knowledge-bases", tags=["知识图谱"])
|
|
|
|
|
|
class EntityCreateRequest(BaseModel):
    """Request payload describing a single knowledge-graph entity to create.

    ``entity_type`` is accepted as a plain string at this layer; the
    batch-create endpoint in this module converts it with ``EntityType(...)``.
    """

    # Required entity name, limited to 500 characters.
    name: str = Field(..., max_length=500)
    # Required raw entity-type value (not validated by this model).
    entity_type: str
    # Optional free-text description; defaults to None.
    description: Optional[str] = Field(default=None)
    # Optional attribute mapping; defaults to None.
    properties: Optional[dict] = Field(default=None)
|
|
|
|
|
|
def _entity_to_dict(entity: KnowledgeEntity) -> dict:
    """Serialize a ``KnowledgeEntity`` into a plain dict for API responses.

    ``id`` is rendered with ``str()`` and ``entity_type`` is reduced to its
    ``.value``; ``description``, ``properties`` and ``confidence`` are copied
    through unchanged.
    """
    payload = {"id": str(entity.id), "name": entity.name}
    payload["entity_type"] = entity.entity_type.value
    # The remaining fields need no conversion.
    for attr in ("description", "properties", "confidence"):
        payload[attr] = getattr(entity, attr)
    return payload
|
|
|
|
|
|
@router.post("/{kb_id}/entities/batch")
async def batch_create_entities(
    kb_id: UUID,
    entities: list[EntityCreateRequest],
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Create several knowledge-graph entities in a single request.

    Args:
        kb_id: ID of the knowledge base the entities are attached to.
        entities: Between 1 and 100 entity payloads.
        db: Async database session (injected).
        current_user: Authenticated user (injected; only used to require
            authentication).

    Returns:
        A dict with ``created_count`` and ``entities`` (each entity
        serialized by ``_entity_to_dict``).

    Raises:
        HTTPException: 400 if the list is empty, has more than 100 items, or
            contains an ``entity_type`` that ``EntityType`` rejects; 404 if
            the knowledge base does not exist.
    """
    if not entities:
        raise HTTPException(status_code=400, detail="实体列表不能为空")
    if len(entities) > 100:
        raise HTTPException(status_code=400, detail="单次批量创建不能超过100个实体")

    kb = await db.get(KnowledgeBase, kb_id)
    if not kb:
        raise HTTPException(status_code=404, detail="知识库不存在")

    # Resolve every entity type before adding anything to the session, so an
    # unknown type yields a client error (400) instead of an unhandled
    # ValueError (HTTP 500) and never leaves a half-populated session behind.
    entity_types = []
    for entity_req in entities:
        try:
            entity_types.append(EntityType(entity_req.entity_type))
        except ValueError as exc:
            raise HTTPException(
                status_code=400,
                detail=f"无效的实体类型: {entity_req.entity_type}",
            ) from exc

    created = [
        KnowledgeEntity(
            knowledge_base_id=kb_id,
            name=entity_req.name,
            entity_type=entity_type,
            description=entity_req.description,
            properties=entity_req.properties or {},
        )
        for entity_req, entity_type in zip(entities, entity_types)
    ]
    db.add_all(created)

    await db.commit()
    # Reload each row after the commit so the attributes read by
    # _entity_to_dict (including the id) are loaded before serialization.
    for entity in created:
        await db.refresh(entity)

    return {
        "created_count": len(created),
        "entities": [_entity_to_dict(e) for e in created],
    }
|
|
|
|
|
|
@router.post("/{kb_id}/graph/build")
async def build_graph(
    kb_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Placeholder endpoint for building a graph from a whole knowledge base.

    The endpoint is meant to run entity and relation extraction over all
    chunks of the knowledge base, but no extraction happens here yet: the
    handler only returns a message pointing callers at the per-chunk
    ``/graph/build-chunk`` endpoint.
    """
    hint = "Use /graph/build-chunk to build from specific chunk"
    return {"message": hint}
|
|
|
|
|
|
@router.post("/{kb_id}/graph/build-chunk/{chunk_id}")
async def build_graph_from_chunk(
    kb_id: UUID,
    chunk_id: UUID,
    context: Optional[str] = None,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Build graph data from a single chunk.

    Delegates to ``GraphBuilder.build_from_chunk`` with the chunk id as a
    string and the optional ``context``. A ``ValueError`` raised by the
    builder is reported as HTTP 404 with the error text as detail.

    NOTE(review): ``kb_id`` is not passed to the builder here, so this handler
    does not itself check that the chunk belongs to the knowledge base.
    """
    graph_builder = GraphBuilder()
    try:
        build_stats = await graph_builder.build_from_chunk(db, str(chunk_id), context)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))
    return {"status": "success", "stats": build_stats}
|
|
|
|
|
|
@router.get("/{kb_id}/graph/statistics")
async def get_graph_statistics(
    kb_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return graph statistics for a knowledge base.

    The result of ``GraphQuery.get_statistics`` is returned unchanged.
    """
    graph = GraphQuery()
    return await graph.get_statistics(db, str(kb_id))
|
|
|
|
|
|
@router.get("/{kb_id}/graph/entities/search")
async def search_entities(
    kb_id: UUID,
    q: str,
    entity_type: Optional[str] = None,
    limit: int = 20,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Search entities within a knowledge base.

    Forwards the search text ``q``, the optional ``entity_type`` filter and
    ``limit`` to ``GraphQuery.search_entities`` and wraps the result under the
    ``entities`` key.

    NOTE(review): ``limit`` is passed through without an upper bound at this
    layer; confirm the service caps it.
    """
    graph = GraphQuery()
    kb_key = str(kb_id)
    matches = await graph.search_entities(db, kb_key, q, entity_type, limit)
    return {"entities": matches}
|
|
|
|
|
|
@router.get("/{kb_id}/graph/entities/{entity_id}")
async def get_entity(
    kb_id: UUID,
    entity_id: UUID,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return one entity's details together with its neighbors.

    Raises:
        HTTPException: 404 when ``GraphQuery.get_entity`` returns a falsy
            result.

    NOTE(review): ``kb_id`` is not used in the lookup, so the entity is not
    checked against the knowledge base in this handler.
    """
    graph = GraphQuery()
    entity_key = str(entity_id)

    entity = await graph.get_entity(db, entity_key)
    if entity:
        # Attach the neighbors to the entity payload before returning it.
        entity["neighbors"] = await graph.get_entity_neighbors(db, entity_key)
        return entity

    raise HTTPException(status_code=404, detail="Entity not found")
|
|
|
|
|
|
@router.get("/{kb_id}/graph/path")
async def find_path(
    kb_id: UUID,
    source: str,
    target: str,
    max_hops: int = 3,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Find a path between two entities.

    Returns the path from ``GraphQuery.get_entity_path`` plus ``hops``, which
    is ``len(path)``.

    NOTE(review): this assumes ``get_entity_path`` always returns a sized
    value; if it can return ``None`` when no path exists, ``len`` will raise.
    ``kb_id`` is not passed to the query here.
    """
    graph = GraphQuery()
    found_path = await graph.get_entity_path(db, source, target, max_hops)
    hop_count = len(found_path)
    return {"path": found_path, "hops": hop_count}
|