466 lines
15 KiB
Python
466 lines
15 KiB
Python
"""GEO 项目的 Tool 注册 — 供 AgentKit Server 使用
|
||
|
||
所有 Tool 通过 HTTP 调用 GEO Backend 的业务 API,不直接 import GEO 服务类。
|
||
"""
|
||
|
||
import logging
|
||
import os
|
||
from typing import Any
|
||
|
||
import httpx
|
||
|
||
from agentkit.tools.function_tool import FunctionTool
|
||
from agentkit.tools.registry import ToolRegistry
|
||
|
||
logger = logging.getLogger(__name__)

# Base URL of the GEO Backend. An empty env value falls back to the local
# default (os.getenv alone would return "" and every request URL would become
# relative). Trailing slashes are stripped because every call site builds URLs
# as f"{GEO_BACKEND_URL}/...", which would otherwise yield a "//" path segment
# that the backend may not route.
GEO_BACKEND_URL = (os.getenv("GEO_BACKEND_URL") or "http://localhost:8000").rstrip("/")

# Token sent as the X-Internal-Token header by _internal_headers(); when empty,
# no token header is sent.
INTERNAL_API_TOKEN = os.getenv("INTERNAL_API_TOKEN", "")
|
||
|
||
|
||
def _internal_headers() -> dict:
|
||
"""获取内部 API 请求头"""
|
||
headers = {"Content-Type": "application/json"}
|
||
if INTERNAL_API_TOKEN:
|
||
headers["X-Internal-Token"] = INTERNAL_API_TOKEN
|
||
return headers
|
||
|
||
|
||
# ─── Citation Tools ───
|
||
|
||
async def execute_single_platform(
    keyword: str,
    platform: str,
    target_brand: str,
    brand_aliases: list[str] | None = None,
) -> dict:
    """Run citation detection for *keyword* on a single AI platform.

    POSTs to the GEO Backend ``/api/v1/ai-engines/execute-single-platform``
    endpoint (120 s timeout) and returns the decoded JSON body. Any failure
    is logged and reported as an error dict that echoes ``keyword`` and
    ``platform`` instead of raising.
    """
    try:
        endpoint = f"{GEO_BACKEND_URL}/api/v1/ai-engines/execute-single-platform"
        payload = {
            "keyword": keyword,
            "platform": platform,
            "target_brand": target_brand,
            # The backend always receives a list, never null.
            "brand_aliases": brand_aliases or [],
        }
        async with httpx.AsyncClient(timeout=120.0) as http:
            response = await http.post(endpoint, json=payload)
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"execute_single_platform 失败: {exc}")
        return {"error": str(exc), "keyword": keyword, "platform": platform}
|
||
|
||
|
||
async def get_or_create_task(query_id: str, platform: str) -> dict:
    """Get (or create) the citation query task for *query_id* on *platform*.

    Calls the internal endpoint ``/internal/citation/get-or-create-task``
    (30 s timeout) with the internal auth headers. Failures are logged and
    returned as an error dict echoing ``query_id`` and ``platform``.
    """
    try:
        endpoint = f"{GEO_BACKEND_URL}/internal/citation/get-or-create-task"
        payload = {"query_id": query_id, "platform": platform}
        async with httpx.AsyncClient(timeout=30.0) as http:
            response = await http.post(endpoint, json=payload, headers=_internal_headers())
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"get_or_create_task 失败: {exc}")
        return {"error": str(exc), "query_id": query_id, "platform": platform}
|
||
|
||
|
||
# ─── Content Tools ───
|
||
|
||
async def retrieve_knowledge(
    knowledge_base_ids: list[str],
    query: str,
    top_k: int = 5,
) -> dict:
    """Search the given knowledge bases for passages relevant to *query*.

    Calls the internal endpoint ``/internal/knowledge/search`` (30 s timeout).
    Each hit is rendered as a ``[来源: <title>]`` header followed by its
    content, and hits are separated by ``---`` rules. ``sources`` lists the
    hit titles in the same order.

    Best-effort: empty inputs, an empty result set, or any request/parsing
    failure all produce the same placeholder result instead of an error.
    Failures are logged at WARNING level.
    """

    def no_knowledge() -> dict:
        # Build a fresh dict on every call so that callers can mutate it.
        return {"content": "暂无相关知识库内容", "sources": []}

    if not (knowledge_base_ids and query):
        return no_knowledge()

    try:
        endpoint = f"{GEO_BACKEND_URL}/internal/knowledge/search"
        payload = {"query": query, "knowledge_base_ids": knowledge_base_ids, "top_k": top_k}
        async with httpx.AsyncClient(timeout=30.0) as http:
            response = await http.post(endpoint, json=payload, headers=_internal_headers())
            response.raise_for_status()
            hits = response.json().get("results", [])
            if not hits:
                return no_knowledge()
            titles = [hit.get("document_title", "未知") for hit in hits]
            sections = [
                f"[来源: {title}]\n{hit.get('content', '')}"
                for title, hit in zip(titles, hits)
            ]
            return {"content": "\n\n---\n\n".join(sections), "sources": titles}
    except Exception as exc:
        logger.warning(f"retrieve_knowledge 失败: {exc}")
        return no_knowledge()
|
||
|
||
|
||
# ─── Monitor Tools ───
|
||
|
||
async def monitor_check_and_compare(record_id: str) -> dict:
    """Check monitoring record *record_id* and compare it against prior state.

    Calls the internal endpoint ``/internal/monitor/check`` (60 s timeout).
    Failures are logged and returned as an error dict echoing ``record_id``.
    """
    try:
        endpoint = f"{GEO_BACKEND_URL}/internal/monitor/check"
        async with httpx.AsyncClient(timeout=60.0) as http:
            response = await http.post(
                endpoint,
                json={"record_id": record_id},
                headers=_internal_headers(),
            )
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"monitor_check_and_compare 失败: {exc}")
        return {"error": str(exc), "record_id": record_id}
|
||
|
||
|
||
async def monitor_generate_report(record_id: str) -> dict:
    """Generate a change report for monitoring record *record_id*.

    Calls the internal endpoint ``/internal/monitor/generate-report``
    (60 s timeout). Failures are logged and returned as an error dict
    echoing ``record_id``.
    """
    try:
        endpoint = f"{GEO_BACKEND_URL}/internal/monitor/generate-report"
        async with httpx.AsyncClient(timeout=60.0) as http:
            response = await http.post(
                endpoint,
                json={"record_id": record_id},
                headers=_internal_headers(),
            )
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"monitor_generate_report 失败: {exc}")
        return {"error": str(exc), "record_id": record_id}
|
||
|
||
|
||
async def monitor_create_record(
    brand_id: str,
    query_keywords: str | None = None,
    platform: str | None = None,
    check_interval_hours: int = 24,
) -> dict:
    """Create a new monitoring record for *brand_id*.

    Calls the internal endpoint ``/internal/monitor/create-record``
    (30 s timeout). Optional fields are forwarded as-is, so ``None`` is
    sent as JSON null. Failures are logged and returned as an error dict
    echoing ``brand_id``.
    """
    try:
        endpoint = f"{GEO_BACKEND_URL}/internal/monitor/create-record"
        payload = {
            "brand_id": brand_id,
            "query_keywords": query_keywords,
            "platform": platform,
            "check_interval_hours": check_interval_hours,
        }
        async with httpx.AsyncClient(timeout=30.0) as http:
            response = await http.post(endpoint, json=payload, headers=_internal_headers())
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"monitor_create_record 失败: {exc}")
        return {"error": str(exc), "brand_id": brand_id}
|
||
|
||
|
||
# ─── Schema Tools ───
|
||
|
||
# Empty schema.org JSON-LD skeletons, keyed by schema type. Every template
# starts with "@context" and "@type", and the remaining fields are blank
# placeholders.
SCHEMA_TEMPLATES = {
    "Organization": {
        "@context": "https://schema.org",
        "@type": "Organization",
        "name": "",
        "description": "",
        "url": "",
        "logo": "",
        "sameAs": [],
    },
    "Product": {
        "@context": "https://schema.org",
        "@type": "Product",
        "name": "",
        "description": "",
        "brand": {"@type": "Brand", "name": ""},
    },
    "FAQPage": {
        "@context": "https://schema.org",
        "@type": "FAQPage",
        "mainEntity": [
            {
                "@type": "Question",
                "name": "",
                "acceptedAnswer": {"@type": "Answer", "text": ""},
            },
        ],
    },
    "Article": {
        "@context": "https://schema.org",
        "@type": "Article",
        "headline": "",
        "description": "",
        "author": {"@type": "Organization", "name": ""},
    },
    "LocalBusiness": {
        "@context": "https://schema.org",
        "@type": "LocalBusiness",
        "name": "",
        "address": {"@type": "PostalAddress"},
    },
}
|
||
|
||
# Diagnosis dimension name -> schema types relevant to that dimension.
# identify_missing_dimensions only considers dimensions listed here.
DIMENSION_SCHEMA_MAP = dict(
    schema_marketing=["Organization", "LocalBusiness"],
    entity_clarity=["Organization", "Product"],
    citation_readiness=["FAQPage", "Article"],
    brand_visibility=["Organization", "Product"],
    local_seo=["LocalBusiness"],
)
|
||
|
||
|
||
async def fill_schema_with_llm(
    schema_type: str,
    brand_info: dict | None = None,
    diagnosis_dimensions: dict | None = None,
) -> dict:
    """Ask the GEO Backend to fill a schema.org JSON-LD template using an LLM.

    Calls the internal endpoint ``/internal/schema/advise`` (60 s timeout).
    Missing optional dicts are sent as ``{}``. Failures are logged and
    returned as an error dict echoing ``schema_type``.
    """
    try:
        endpoint = f"{GEO_BACKEND_URL}/internal/schema/advise"
        payload = {
            "schema_type": schema_type,
            "brand_info": brand_info or {},
            "diagnosis_dimensions": diagnosis_dimensions or {},
        }
        async with httpx.AsyncClient(timeout=60.0) as http:
            response = await http.post(endpoint, json=payload, headers=_internal_headers())
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"fill_schema_with_llm 失败: {exc}")
        return {"error": str(exc), "schema_type": schema_type}
|
||
|
||
|
||
async def identify_missing_dimensions(
|
||
diagnosis_data: dict,
|
||
focus_dimensions: list[str] | None = None,
|
||
) -> dict:
|
||
"""识别 Schema 缺失维度"""
|
||
dimensions = []
|
||
dimension_scores = diagnosis_data.get("dimensions", {})
|
||
for dim_name, dim_info in dimension_scores.items():
|
||
if dim_name not in DIMENSION_SCHEMA_MAP:
|
||
continue
|
||
if focus_dimensions and dim_name not in focus_dimensions:
|
||
continue
|
||
score = dim_info.get("score", 0) if isinstance(dim_info, dict) else dim_info
|
||
max_score = dim_info.get("max_score", 100) if isinstance(dim_info, dict) else 100
|
||
percentage = (score / max_score * 100) if max_score > 0 else 0
|
||
if percentage < 80:
|
||
dimensions.append({
|
||
"dimension": dim_name,
|
||
"current_score": round(score, 2),
|
||
"max_score": max_score,
|
||
"percentage": round(percentage, 2),
|
||
})
|
||
return {"missing_dimensions": dimensions}
|
||
|
||
|
||
# ─── Competitor Tools ───
|
||
|
||
async def competitor_analyze(
    brand_id: str,
    analysis_types: list[str] | None = None,
    period_days: int = 30,
) -> dict:
    """Run competitor strategy analysis for *brand_id*.

    POSTs to the GEO Backend ``/api/v1/competitor/analyze`` endpoint
    (120 s timeout). ``analysis_types`` is forwarded as-is, so ``None`` is
    sent as JSON null. Failures are logged and returned as an error dict
    echoing ``brand_id``.
    """
    try:
        endpoint = f"{GEO_BACKEND_URL}/api/v1/competitor/analyze"
        payload = {
            "brand_id": brand_id,
            "analysis_types": analysis_types,
            "period_days": period_days,
        }
        async with httpx.AsyncClient(timeout=120.0) as http:
            response = await http.post(endpoint, json=payload)
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"competitor_analyze 失败: {exc}")
        return {"error": str(exc), "brand_id": brand_id}
|
||
|
||
|
||
async def competitor_gap_analysis(
    brand_id: str,
    period_days: int = 30,
) -> dict:
    """Run a competitor gap analysis for *brand_id*.

    Thin wrapper over ``competitor_analyze`` that is restricted to the
    citation-gap, platform-coverage and query-overlap analysis types.
    """
    gap_analysis_types = ["citation_gap", "platform_coverage", "query_overlap"]
    return await competitor_analyze(
        brand_id=brand_id,
        analysis_types=gap_analysis_types,
        period_days=period_days,
    )
|
||
|
||
|
||
# ─── Trend Tools ───
|
||
|
||
async def trend_insight(
    brand_id: str,
    days: int = 30,
    platforms: list[str] | None = None,
    keywords: list[str] | None = None,
) -> dict:
    """Run trend insight analysis for *brand_id* over the last *days* days.

    POSTs to the GEO Backend ``/api/v1/trends/insight`` endpoint (120 s
    timeout). Optional filters are forwarded as-is, so ``None`` is sent as
    JSON null. Failures are logged and returned as an error dict echoing
    ``brand_id``.
    """
    try:
        endpoint = f"{GEO_BACKEND_URL}/api/v1/trends/insight"
        payload = {
            "brand_id": brand_id,
            "days": days,
            "platforms": platforms,
            "keywords": keywords,
        }
        async with httpx.AsyncClient(timeout=120.0) as http:
            response = await http.post(endpoint, json=payload)
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"trend_insight 失败: {exc}")
        return {"error": str(exc), "brand_id": brand_id}
|
||
|
||
|
||
async def trend_hotspot(
    brand_id: str,
    days: int = 30,
) -> dict:
    """Detect hotspot topics with a sudden surge in citations for *brand_id*.

    POSTs to the GEO Backend ``/api/v1/trends/hotspot`` endpoint (120 s
    timeout). Failures are logged and returned as an error dict echoing
    ``brand_id``.
    """
    try:
        endpoint = f"{GEO_BACKEND_URL}/api/v1/trends/hotspot"
        async with httpx.AsyncClient(timeout=120.0) as http:
            response = await http.post(endpoint, json={"brand_id": brand_id, "days": days})
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"trend_hotspot 失败: {exc}")
        return {"error": str(exc), "brand_id": brand_id}
|
||
|
||
|
||
# ─── Knowledge Tools ───
|
||
|
||
async def search_knowledge(
    query: str,
    knowledge_base_ids: list[str],
    top_k: int = 5,
) -> dict:
    """Search knowledge bases for content relevant to *query*.

    Exposes ``retrieve_knowledge`` under a separate tool name, with the query
    as the first argument. All behavior, including the best-effort
    placeholder on failure, comes from ``retrieve_knowledge``.
    """
    return await retrieve_knowledge(
        query=query,
        knowledge_base_ids=knowledge_base_ids,
        top_k=top_k,
    )
|
||
|
||
|
||
async def detect_ai_patterns(content: str, platform_id: str) -> dict:
    """Detect AI-generated writing patterns in *content*.

    POSTs to the GEO Backend ``/api/v1/ai-engines/detect-ai-patterns``
    endpoint (30 s timeout). On failure the error is logged and an empty
    result (``patterns=[]``, ``count=0``) is returned together with the error
    message.
    """
    try:
        endpoint = f"{GEO_BACKEND_URL}/api/v1/ai-engines/detect-ai-patterns"
        async with httpx.AsyncClient(timeout=30.0) as http:
            response = await http.post(
                endpoint,
                json={"content": content, "platform_id": platform_id},
            )
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"detect_ai_patterns 失败: {exc}")
        return {"error": str(exc), "patterns": [], "count": 0}
|
||
|
||
|
||
# ─── Registration ───
|
||
|
||
def register_geo_tools(registry: ToolRegistry) -> None:
    """Register every GEO tool defined in this module on *registry*.

    Each entry below becomes a ``FunctionTool`` (name, description,
    callable, tags). Tools are registered in the listed order, and the
    registry's total tool count is logged at the end.
    """
    tool_specs = (
        # Citation
        ("execute_single_platform", "在单个AI平台执行引用检测",
         execute_single_platform, ["citation", "detection"]),
        ("get_or_create_task", "获取或创建引用检测的查询任务",
         get_or_create_task, ["citation", "task"]),
        # Content
        ("retrieve_knowledge", "从知识库检索相关内容",
         retrieve_knowledge, ["content", "rag", "knowledge"]),
        # Monitor
        ("monitor_check_and_compare", "检测并对比监测记录的变化",
         monitor_check_and_compare, ["monitor", "tracking"]),
        ("monitor_generate_report", "生成监测变化报告",
         monitor_generate_report, ["monitor", "report"]),
        ("monitor_create_record", "创建新的监测记录",
         monitor_create_record, ["monitor", "record"]),
        # Schema
        ("fill_schema_with_llm", "使用LLM填充Schema JSON-LD模板",
         fill_schema_with_llm, ["schema", "llm"]),
        ("identify_missing_dimensions", "识别Schema缺失维度",
         identify_missing_dimensions, ["schema", "diagnosis"]),
        # Competitor
        ("competitor_analyze", "执行竞品策略分析",
         competitor_analyze, ["competitor", "analysis"]),
        ("competitor_gap_analysis", "执行竞品差距分析",
         competitor_gap_analysis, ["competitor", "gap"]),
        # Trend
        ("trend_insight", "分析品牌引用趋势",
         trend_insight, ["trend", "insight"]),
        ("trend_hotspot", "检测引用量突增的热点话题",
         trend_hotspot, ["trend", "hotspot"]),
        # Knowledge
        ("search_knowledge", "从知识库检索相关内容",
         search_knowledge, ["knowledge", "rag"]),
        ("detect_ai_patterns", "检测内容中的AI生成模式",
         detect_ai_patterns, ["knowledge", "deai"]),
    )

    for tool_name, tool_description, tool_func, tool_tags in tool_specs:
        registry.register(FunctionTool(
            name=tool_name,
            description=tool_description,
            func=tool_func,
            tags=tool_tags,
        ))

    logger.info(f"GEO tools registered: {len(registry.list_tools())} tools")
|