58 lines
1.9 KiB
Python
58 lines
1.9 KiB
Python
"""Content 业务工具 - 将内容生成相关服务注册为 FunctionTool"""
|
||
|
||
import logging
|
||
from typing import Any
|
||
|
||
from agentkit.tools.function_tool import FunctionTool
|
||
from agentkit.tools.registry import ToolRegistry
|
||
|
||
# Module-scoped logger, named after this module (``__name__``) so its records
# can be filtered by module path.
logger: logging.Logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _empty_knowledge_result() -> dict[str, Any]:
    """Return a fresh "no knowledge found" payload (new objects on every call)."""
    return {"content": "暂无相关知识库内容", "sources": []}


async def retrieve_knowledge(
    knowledge_base_ids: list[str],
    query: str,
    top_k: int = 5,
) -> dict[str, Any]:
    """Retrieve content relevant to *query* from the given knowledge bases.

    This is a best-effort RAG lookup: any exception raised while importing the
    database / RAG modules, searching, or formatting hits is logged and turned
    into the placeholder result, so this tool does not raise to its caller.

    Args:
        knowledge_base_ids: IDs of the knowledge bases to search. An empty
            list short-circuits to the placeholder result without any DB work.
        query: Search text. Empty or whitespace-only queries short-circuit to
            the placeholder result.
        top_k: Number of hits requested from ``RAGService.search``.

    Returns:
        A dict ``{"content": str, "sources": list}``. When hits are found,
        ``content`` contains each hit prefixed by a ``[来源: <title>]`` header,
        joined with ``---`` separators, and ``sources`` lists each hit's title
        in order (one entry per hit, so titles may repeat). Otherwise the
        placeholder payload from ``_empty_knowledge_result`` is returned.
    """
    # `query and query.strip()` also rejects whitespace-only input, which
    # would otherwise trigger a pointless DB / RAG round-trip.
    if not knowledge_base_ids or not (query and query.strip()):
        return _empty_knowledge_result()

    try:
        # Imported lazily inside the try so that an import failure is handled
        # like any other retrieval failure. Presumably also avoids import
        # cycles at tool-registration time -- TODO confirm.
        from app.database import AsyncSessionLocal
        from app.services.knowledge.rag_service import RAGService

        async with AsyncSessionLocal() as session:
            rag = RAGService()
            results = await rag.search(
                session=session,
                query=query,
                knowledge_base_ids=knowledge_base_ids,
                top_k=top_k,
            )
            if results:
                content_parts: list[str] = []
                sources: list[Any] = []
                for r in results:
                    # Use `or` rather than a .get() default so that explicit
                    # None / "" values don't leak into the output as "None"
                    # or as blank titles.
                    title = r.get("document_title") or "未知"
                    body = r.get("content") or ""
                    content_parts.append(f"[来源: {title}]\n{body}")
                    sources.append(title)
                return {"content": "\n\n---\n\n".join(content_parts), "sources": sources}
    except Exception as e:
        # Deliberate best-effort boundary: keep the traceback for diagnosis,
        # then fall through to the placeholder result.
        logger.warning("RAG检索失败: %s", e, exc_info=True)

    return _empty_knowledge_result()
|
||
|
||
|
||
def register_content_tools(registry: ToolRegistry) -> None:
    """Register the content-generation tools on *registry*.

    Currently this wraps the ``retrieve_knowledge`` coroutine in a
    ``FunctionTool`` named ``retrieve_knowledge``, tagged
    ``content`` / ``rag`` / ``knowledge``. It registers that tool and then
    logs an info message.
    """
    knowledge_tool = FunctionTool(
        name="retrieve_knowledge",
        description="从知识库检索相关内容,用于RAG增强生成",
        func=retrieve_knowledge,
        tags=["content", "rag", "knowledge"],
    )
    registry.register(knowledge_tool)

    logger.info("Content tools registered")
|