69 lines
1.9 KiB
Python
69 lines
1.9 KiB
Python
"""Trend 业务工具 - 将趋势分析服务注册为 FunctionTool"""
|
|
|
|
import logging
|
|
from typing import Any
|
|
|
|
from agentkit.tools.function_tool import FunctionTool
|
|
from agentkit.tools.registry import ToolRegistry
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
async def trend_insight(
    brand_id: str,
    days: int = 30,
    platforms: list[str] | None = None,
    keywords: list[str] | None = None,
) -> dict:
    """Run the trend insight analysis for a brand.

    Opens a session from ``AsyncSessionLocal``, builds a
    ``TrendAnalyzerService`` on top of it and awaits its
    ``analyze_trends`` method, forwarding every argument unchanged.

    Args:
        brand_id: Identifier of the brand to analyze.
        days: Forwarded as ``days``; defaults to 30 (presumably the size of
            the look-back window in days -- confirm against the service).
        platforms: Optional list of platform names, forwarded as-is
            (``None`` is passed through; its meaning is defined by the
            service).
        keywords: Optional list of keywords, forwarded as-is (``None`` is
            passed through; its meaning is defined by the service).

    Returns:
        Whatever ``TrendAnalyzerService.analyze_trends`` returns
        (annotated as ``dict``).
    """
    # Resolved at call time rather than at module import time
    # (presumably to avoid an import cycle with the app package).
    from app.database import AsyncSessionLocal
    from app.services.trend.trend_analyzer_service import TrendAnalyzerService

    async with AsyncSessionLocal() as session:
        analyzer = TrendAnalyzerService(session)
        return await analyzer.analyze_trends(
            brand_id=brand_id,
            days=days,
            platforms=platforms,
            keywords=keywords,
        )
|
|
|
|
|
|
async def trend_hotspot(
    brand_id: str,
    days: int = 30,
) -> dict:
    """Detect hot topics whose citation volume has surged.

    Opens a session from ``AsyncSessionLocal``, builds a
    ``TrendAnalyzerService`` on top of it and awaits its ``get_hotspots``
    method, forwarding both arguments unchanged.

    Args:
        brand_id: Identifier of the brand to inspect.
        days: Forwarded as ``days``; defaults to 30 (presumably the size of
            the look-back window in days -- confirm against the service).

    Returns:
        Whatever ``TrendAnalyzerService.get_hotspots`` returns
        (annotated as ``dict``).
    """
    # Resolved at call time rather than at module import time
    # (presumably to avoid an import cycle with the app package).
    from app.database import AsyncSessionLocal
    from app.services.trend.trend_analyzer_service import TrendAnalyzerService

    async with AsyncSessionLocal() as session:
        analyzer = TrendAnalyzerService(session)
        return await analyzer.get_hotspots(brand_id=brand_id, days=days)
|
|
|
|
|
|
def register_trend_tools(registry: ToolRegistry) -> None:
    """Register all trend-analysis tools on ``registry``.

    Wraps ``trend_insight`` and ``trend_hotspot`` in ``FunctionTool``
    objects and registers them, in that order, then logs an info message.

    Args:
        registry: The tool registry that receives the two tools.
    """
    # (name, description, callable, tags) for each tool, in registration order.
    tool_specs = (
        (
            "trend_insight",
            "分析品牌引用趋势,推断变化原因",
            trend_insight,
            ["trend", "insight"],
        ),
        (
            "trend_hotspot",
            "检测引用量突增的热点话题",
            trend_hotspot,
            ["trend", "hotspot"],
        ),
    )

    for tool_name, summary, handler, labels in tool_specs:
        tool = FunctionTool(
            name=tool_name,
            description=summary,
            func=handler,
            tags=labels,
        )
        registry.register(tool)

    logger.info("Trend tools registered")
|