77 lines
2.3 KiB
Python
77 lines
2.3 KiB
Python
"""Citation 业务工具 - 将引用检测服务注册为 FunctionTool"""
|
||
|
||
import logging
|
||
from typing import Any
|
||
|
||
from agentkit.tools.function_tool import FunctionTool
|
||
from agentkit.tools.registry import ToolRegistry
|
||
|
||
# Module-level logger named after this module; register_citation_tools logs through it.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
async def execute_single_platform(
    keyword: str,
    platform: str,
    target_brand: str,
    brand_aliases: list[str] | None = None,
) -> dict:
    """Run one citation-detection pass on a single AI platform.

    Thin delegating wrapper around
    ``app.services.ai_engine.platform_bridge.execute_single_platform``.

    Args:
        keyword: Keyword forwarded to the platform bridge.
        platform: Platform identifier forwarded to the platform bridge.
        target_brand: Brand forwarded to the platform bridge.
        brand_aliases: Optional alias list; ``None`` (or any falsy value) is
            forwarded as an empty list.

    Returns:
        Whatever the platform bridge returns (annotated as ``dict``).
    """
    # Imported at call time rather than module load; presumably to avoid an
    # import cycle or heavy startup cost — confirm before hoisting.
    from app.services.ai_engine.platform_bridge import (
        execute_single_platform as run_on_platform,
    )

    aliases = brand_aliases if brand_aliases else []
    return await run_on_platform(
        keyword=keyword,
        platform=platform,
        target_brand=target_brand,
        brand_aliases=aliases,
    )
|
||
|
||
|
||
async def get_or_create_task(query_id: str, platform: str) -> dict:
    """Fetch the QueryTask for (query_id, platform), creating it if absent.

    Looks up an existing ``QueryTask`` row matching both ``query_id`` and
    ``platform``. When none exists, a new row with status ``"pending"`` is
    added, committed and refreshed.

    Args:
        query_id: String form of the query's UUID.
        platform: Platform identifier stored on the task.

    Returns:
        A plain dict with the task's ``id`` (as a string), ``platform`` and
        ``status``.

    Raises:
        ValueError: If ``query_id`` is not a valid UUID string. This is raised
            before any database module is imported or a session is opened.
    """
    import uuid

    # Parse once, up front: a malformed id fails fast without touching the DB
    # layer, and the parsed value is reused for both the lookup and the insert.
    query_uuid = uuid.UUID(query_id)

    from sqlalchemy import select

    from app.database import AsyncSessionLocal
    from app.models.query_task import QueryTask

    async with AsyncSessionLocal() as db:
        stmt = select(QueryTask).where(
            QueryTask.query_id == query_uuid,
            QueryTask.platform == platform,
        )
        result = await db.execute(stmt)
        task_obj = result.scalar_one_or_none()

        if task_obj is None:
            # NOTE(review): check-then-insert is not atomic; concurrent callers
            # for the same (query_id, platform) could both insert unless the
            # QueryTask table has a unique constraint on that pair — confirm.
            task_obj = QueryTask(
                query_id=query_uuid,
                platform=platform,
                status="pending",
            )
            db.add(task_obj)
            await db.commit()
            # Reload server-populated columns (e.g. the primary key) after commit.
            await db.refresh(task_obj)

        return {"id": str(task_obj.id), "platform": task_obj.platform, "status": task_obj.status}
|
||
|
||
|
||
def register_citation_tools(registry: ToolRegistry) -> None:
    """Register every citation-detection tool on *registry*.

    Two ``FunctionTool`` instances are registered, in this order:
    ``execute_single_platform`` and ``get_or_create_task``. A single
    info-level log line is emitted after both are registered.
    """
    # (name, description, callable, tags) for each tool, in registration order.
    tool_specs = (
        (
            "execute_single_platform",
            "在指定AI平台执行引用检测,返回引用结果",
            execute_single_platform,
            ["citation", "detection"],
        ),
        (
            "get_or_create_task",
            "获取或创建引用检测的查询任务记录",
            get_or_create_task,
            ["citation", "task"],
        ),
    )

    for tool_name, tool_description, tool_func, tool_tags in tool_specs:
        tool = FunctionTool(
            name=tool_name,
            description=tool_description,
            func=tool_func,
            tags=tool_tags,
        )
        registry.register(tool)

    logger.info("Citation tools registered")
|