geo/backend/app/agent_framework/tools/content_tools.py

58 lines
1.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Content 业务工具 - 将内容生成相关服务注册为 FunctionTool"""
import logging
from typing import Any

from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.registry import ToolRegistry

# Module-level logger, named after this module (standard logging convention).
logger = logging.getLogger(__name__)
def _empty_knowledge_result() -> dict[str, Any]:
    """Return a fresh fallback payload used when no knowledge is available."""
    # Built per call so callers mutating the returned dict/list never affect
    # later results.
    return {"content": "暂无相关知识库内容", "sources": []}


async def retrieve_knowledge(
    knowledge_base_ids: list[str],
    query: str,
    top_k: int = 5,
) -> dict[str, Any]:
    """Retrieve content relevant to ``query`` from the given knowledge bases.

    Args:
        knowledge_base_ids: IDs of the knowledge bases to search.
        query: Free-text search query.
        top_k: Maximum number of hits requested from the RAG search.

    Returns:
        A dict with two keys:

        * ``"content"``: the hits joined by ``---`` separators, each prefixed
          with a ``[来源: <title>]`` header.
        * ``"sources"``: the title of each hit, in result order.

        When there is nothing to search, no hit is found, or retrieval fails,
        a fallback payload with a placeholder message and an empty source
        list is returned instead. This function never raises for
        retrieval errors.
    """
    # Nothing to search in or for. A blank/whitespace-only query would only
    # yield meaningless similarity matches, so it is treated as empty too.
    if not knowledge_base_ids or not (query and query.strip()):
        return _empty_knowledge_result()

    try:
        # Imported lazily; NOTE(review): presumably to avoid import cycles or
        # DB setup at tool-registration time. Confirm before hoisting.
        from app.database import AsyncSessionLocal
        from app.services.knowledge.rag_service import RAGService

        async with AsyncSessionLocal() as session:
            rag = RAGService()
            results = await rag.search(
                session=session,
                query=query,
                knowledge_base_ids=knowledge_base_ids,
                top_k=top_k,
            )
            if results:
                content_parts: list[str] = []
                sources: list[str] = []
                # Assumes each hit is mapping-like with optional
                # "document_title" / "content" keys (verify against RAGService).
                for r in results:
                    # `or` also covers keys present with a None/empty value,
                    # which would otherwise render as "None" in the prompt.
                    title = r.get("document_title") or "未知"
                    body = r.get("content") or ""
                    content_parts.append(f"[来源: {title}]\n{body}")
                    sources.append(title)
                return {
                    "content": "\n\n---\n\n".join(content_parts),
                    "sources": sources,
                }
    except Exception as e:
        # Deliberate best-effort: a retrieval failure degrades to the fallback
        # payload rather than aborting the calling agent. Keep the traceback
        # so the failure stays diagnosable.
        logger.warning("RAG检索失败: %s", e, exc_info=True)

    # Reached on retrieval failure and when the search returned no hits.
    return _empty_knowledge_result()
def register_content_tools(registry: ToolRegistry) -> None:
    """Register every content-generation tool on ``registry``.

    At present this is the single ``retrieve_knowledge`` tool, which exposes
    the knowledge-base lookup used for RAG-augmented generation.
    """
    knowledge_tool = FunctionTool(
        name="retrieve_knowledge",
        description="从知识库检索相关内容用于RAG增强生成",
        func=retrieve_knowledge,
        tags=["content", "rag", "knowledge"],
    )
    registry.register(knowledge_tool)
    logger.info("Content tools registered")