geo/backend/app/agent_framework/tools/citation_tools.py

77 lines
2.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Citation 业务工具 - 将引用检测服务注册为 FunctionTool"""
import logging
from typing import Any
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.registry import ToolRegistry
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
async def execute_single_platform(
    keyword: str,
    platform: str,
    target_brand: str,
    brand_aliases: list[str] | None = None,
) -> dict:
    """Run one citation-detection check for *keyword* on a single AI platform.

    Thin async wrapper that forwards to
    ``app.services.ai_engine.platform_bridge.execute_single_platform``. The
    import is deferred until call time.

    Args:
        keyword: Query keyword to run on the platform.
        platform: Identifier of the AI platform to query.
        target_brand: Brand whose citation is being checked.
        brand_aliases: Optional alternative names for the brand. ``None`` (or
            any other falsy value) is forwarded as an empty list.

    Returns:
        Whatever dict the platform bridge returns.
    """
    from app.services.ai_engine.platform_bridge import (
        execute_single_platform as bridge_execute,
    )

    # Normalise a missing/empty alias list to [] before forwarding.
    aliases = brand_aliases if brand_aliases else []
    return await bridge_execute(
        keyword=keyword,
        platform=platform,
        target_brand=target_brand,
        brand_aliases=aliases,
    )
async def get_or_create_task(query_id: str, platform: str) -> dict:
    """Fetch the QueryTask for (query_id, platform), creating it if absent.

    Looks up a ``QueryTask`` row whose ``query_id`` and ``platform`` both
    match. If none exists, inserts one with ``status="pending"``, commits,
    and refreshes it.

    Args:
        query_id: UUID of the parent query, as a string. A ``uuid.UUID``
            instance is also accepted.
        platform: Platform identifier stored on the task.

    Returns:
        ``{"id": <task id as str>, "platform": <platform>, "status": <status>}``.

    Raises:
        ValueError: if ``query_id`` is not a valid UUID. This is raised before
            any database module is imported or any session is opened.
    """
    import uuid

    # Parse the id once, up front. A malformed id (easy for a tool-calling
    # LLM to produce) then fails fast with a message naming the bad value,
    # and the parsed UUID is reused for both the lookup and the insert.
    if isinstance(query_id, uuid.UUID):
        query_uuid = query_id
    else:
        try:
            query_uuid = uuid.UUID(str(query_id))
        except ValueError as err:
            raise ValueError(
                f"invalid query_id {query_id!r}: expected a UUID string"
            ) from err

    from sqlalchemy import select

    from app.database import AsyncSessionLocal
    from app.models.query_task import QueryTask

    async with AsyncSessionLocal() as db:
        stmt = select(QueryTask).where(
            QueryTask.query_id == query_uuid,
            QueryTask.platform == platform,
        )
        result = await db.execute(stmt)
        task_obj = result.scalar_one_or_none()
        if task_obj is None:
            # NOTE(review): this select-then-insert is not atomic. Two
            # concurrent callers can both miss and both insert. Whether that
            # yields duplicates or an IntegrityError depends on constraints
            # not visible here. Confirm against the QueryTask model.
            task_obj = QueryTask(
                query_id=query_uuid,
                platform=platform,
                status="pending",
            )
            db.add(task_obj)
            await db.commit()
            # Reload the row after commit so attributes such as the generated
            # id are populated for the dict below. NOTE(review): this presumes
            # the session expires objects on commit; confirm AsyncSessionLocal
            # config.
            await db.refresh(task_obj)
        return {
            "id": str(task_obj.id),
            "platform": task_obj.platform,
            "status": task_obj.status,
        }
def register_citation_tools(registry: ToolRegistry) -> None:
    """Register every citation-detection tool on *registry*.

    Wraps ``execute_single_platform`` and ``get_or_create_task`` in
    ``FunctionTool`` objects, registers them in that order, and logs once
    when done.

    Args:
        registry: Tool registry that receives the tools.
    """
    # (name, description, callable, tags) for each tool, in registration order.
    tool_specs = (
        (
            "execute_single_platform",
            "在指定AI平台执行引用检测返回引用结果",
            execute_single_platform,
            ["citation", "detection"],
        ),
        (
            "get_or_create_task",
            "获取或创建引用检测的查询任务记录",
            get_or_create_task,
            ["citation", "task"],
        ),
    )
    for tool_name, tool_description, tool_func, tool_tags in tool_specs:
        registry.register(
            FunctionTool(
                name=tool_name,
                description=tool_description,
                func=tool_func,
                tags=tool_tags,
            )
        )
    logger.info("Citation tools registered")