"""Citation 业务工具 - 将引用检测服务注册为 FunctionTool""" import logging from typing import Any from agentkit.tools.function_tool import FunctionTool from agentkit.tools.registry import ToolRegistry logger = logging.getLogger(__name__) async def execute_single_platform( keyword: str, platform: str, target_brand: str, brand_aliases: list[str] | None = None, ) -> dict: """调用平台执行单次引用检测""" from app.services.ai_engine.platform_bridge import execute_single_platform as _exec return await _exec( keyword=keyword, platform=platform, target_brand=target_brand, brand_aliases=brand_aliases or [], ) async def get_or_create_task(query_id: str, platform: str) -> dict: """获取或创建查询任务""" import uuid from datetime import datetime from sqlalchemy import select from app.database import AsyncSessionLocal from app.models.query_task import QueryTask async with AsyncSessionLocal() as db: stmt = select(QueryTask).where( QueryTask.query_id == uuid.UUID(query_id), QueryTask.platform == platform, ) result = await db.execute(stmt) task_obj = result.scalar_one_or_none() if not task_obj: task_obj = QueryTask( query_id=uuid.UUID(query_id), platform=platform, status="pending", ) db.add(task_obj) await db.commit() await db.refresh(task_obj) return {"id": str(task_obj.id), "platform": task_obj.platform, "status": task_obj.status} def register_citation_tools(registry: ToolRegistry) -> None: """注册所有引用检测相关工具""" registry.register( FunctionTool( name="execute_single_platform", description="在指定AI平台执行引用检测,返回引用结果", func=execute_single_platform, tags=["citation", "detection"], ) ) registry.register( FunctionTool( name="get_or_create_task", description="获取或创建引用检测的查询任务记录", func=get_or_create_task, tags=["citation", "task"], ) ) logger.info("Citation tools registered")