fischer-agentkit/configs/geo_tools.py

466 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""GEO 项目的 Tool 注册 — 供 AgentKit Server 使用
所有 Tool 通过 HTTP 调用 GEO Backend 的业务 API不直接 import GEO 服务类。
"""
import logging
import os
from typing import Any
import httpx
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.registry import ToolRegistry
logger = logging.getLogger(__name__)

# Base URL of the GEO Backend. Every call site in this module appends a path
# that already starts with "/", so a trailing slash on the configured value
# would produce "//..." request paths; strip it once here.
GEO_BACKEND_URL = os.getenv("GEO_BACKEND_URL", "http://localhost:8000").rstrip("/")

# Token sent as the X-Internal-Token header on /internal/* calls. When it is
# empty, _internal_headers() omits the header entirely.
INTERNAL_API_TOKEN = os.getenv("INTERNAL_API_TOKEN", "")
def _internal_headers() -> dict:
    """Build the request headers used for the backend's internal API.

    Returns:
        A dict that always carries a JSON ``Content-Type`` and, only when
        ``INTERNAL_API_TOKEN`` is non-empty, an ``X-Internal-Token`` entry.
    """
    token_header = {"X-Internal-Token": INTERNAL_API_TOKEN} if INTERNAL_API_TOKEN else {}
    return {"Content-Type": "application/json", **token_header}
# ─── Citation Tools ───
async def execute_single_platform(
    keyword: str,
    platform: str,
    target_brand: str,
    brand_aliases: list[str] | None = None,
) -> dict:
    """Run citation detection on a single AI platform.

    Posts to ``/api/v1/ai-engines/execute-single-platform`` on the GEO Backend
    through an httpx client configured with ``timeout=120.0``.

    Args:
        keyword: Sent to the backend as ``keyword``.
        platform: Sent to the backend as ``platform``.
        target_brand: Sent to the backend as ``target_brand``.
        brand_aliases: Sent as ``brand_aliases``; ``None`` (or any falsy
            value) is sent as an empty list.

    Returns:
        The decoded JSON response. If anything raises, the failure is logged
        at error level and a dict with ``error``, ``keyword`` and ``platform``
        is returned instead.
    """
    endpoint = f"{GEO_BACKEND_URL}/api/v1/ai-engines/execute-single-platform"
    payload = {
        "keyword": keyword,
        "platform": platform,
        "target_brand": target_brand,
        "brand_aliases": brand_aliases or [],
    }
    try:
        async with httpx.AsyncClient(timeout=120.0) as http:
            response = await http.post(endpoint, json=payload)
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"execute_single_platform 失败: {exc}")
        return {"error": str(exc), "keyword": keyword, "platform": platform}
async def get_or_create_task(query_id: str, platform: str) -> dict:
    """Get or create a query task through the backend's internal API.

    Posts ``query_id`` and ``platform`` to
    ``/internal/citation/get-or-create-task`` with the internal API headers,
    through an httpx client configured with ``timeout=30.0``.

    Returns:
        The decoded JSON response. If anything raises, the failure is logged
        at error level and a dict with ``error``, ``query_id`` and
        ``platform`` is returned instead.
    """
    endpoint = f"{GEO_BACKEND_URL}/internal/citation/get-or-create-task"
    try:
        async with httpx.AsyncClient(timeout=30.0) as http:
            response = await http.post(
                endpoint,
                json={"query_id": query_id, "platform": platform},
                headers=_internal_headers(),
            )
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"get_or_create_task 失败: {exc}")
        return {"error": str(exc), "query_id": query_id, "platform": platform}
# ─── Content Tools ───
async def retrieve_knowledge(
    knowledge_base_ids: list[str],
    query: str,
    top_k: int = 5,
) -> dict:
    """Retrieve relevant content from knowledge bases via the internal API.

    Posts to ``/internal/knowledge/search`` and joins the returned results
    into one text block, each result prefixed with its source title.

    Args:
        knowledge_base_ids: Sent to the backend as ``knowledge_base_ids``.
            When empty, no request is made.
        query: Sent to the backend as ``query``. When empty, no request is
            made.
        top_k: Sent to the backend as ``top_k``.

    Returns:
        ``{"content": str, "sources": list}``. When the inputs are empty, the
        backend returns no results, or anything raises (logged at warning
        level), ``content`` is the fixed placeholder text and ``sources`` is
        empty. This function is deliberately best-effort and never raises.
    """
    # Built fresh on every call, so callers may mutate what they receive.
    fallback = {"content": "暂无相关知识库内容", "sources": []}
    if not knowledge_base_ids or not query:
        return fallback
    try:
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                f"{GEO_BACKEND_URL}/internal/knowledge/search",
                json={"query": query, "knowledge_base_ids": knowledge_base_ids, "top_k": top_k},
                headers=_internal_headers(),
            )
            resp.raise_for_status()
            data = resp.json()
            results = data.get("results", [])
            if not results:
                return fallback
            content_parts = []
            sources = []
            for r in results:
                # `or` (not a .get default) so that a JSON null or empty value
                # also falls back, instead of rendering as the text "None".
                title = r.get("document_title") or "未知"
                body = r.get("content") or ""
                content_parts.append(f"[来源: {title}]\n{body}")
                sources.append(title)
            return {"content": "\n\n---\n\n".join(content_parts), "sources": sources}
    except Exception as e:
        logger.warning(f"retrieve_knowledge 失败: {e}")
        return fallback
# ─── Monitor Tools ───
async def monitor_check_and_compare(record_id: str) -> dict:
    """Check a monitoring record and compare it for changes via the internal API.

    Posts ``record_id`` to ``/internal/monitor/check`` with the internal API
    headers, through an httpx client configured with ``timeout=60.0``.

    Returns:
        The decoded JSON response. If anything raises, the failure is logged
        at error level and a dict with ``error`` and ``record_id`` is
        returned instead.
    """
    endpoint = f"{GEO_BACKEND_URL}/internal/monitor/check"
    try:
        async with httpx.AsyncClient(timeout=60.0) as http:
            response = await http.post(
                endpoint,
                json={"record_id": record_id},
                headers=_internal_headers(),
            )
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"monitor_check_and_compare 失败: {exc}")
        return {"error": str(exc), "record_id": record_id}
async def monitor_generate_report(record_id: str) -> dict:
    """Generate a monitoring change report via the internal API.

    Posts ``record_id`` to ``/internal/monitor/generate-report`` with the
    internal API headers, through an httpx client configured with
    ``timeout=60.0``.

    Returns:
        The decoded JSON response. If anything raises, the failure is logged
        at error level and a dict with ``error`` and ``record_id`` is
        returned instead.
    """
    endpoint = f"{GEO_BACKEND_URL}/internal/monitor/generate-report"
    try:
        async with httpx.AsyncClient(timeout=60.0) as http:
            response = await http.post(
                endpoint,
                json={"record_id": record_id},
                headers=_internal_headers(),
            )
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"monitor_generate_report 失败: {exc}")
        return {"error": str(exc), "record_id": record_id}
async def monitor_create_record(
    brand_id: str,
    query_keywords: str | None = None,
    platform: str | None = None,
    check_interval_hours: int = 24,
) -> dict:
    """Create a monitoring record via the internal API.

    Posts to ``/internal/monitor/create-record`` with the internal API
    headers, through an httpx client configured with ``timeout=30.0``.

    Args:
        brand_id: Sent to the backend as ``brand_id``.
        query_keywords: Sent as ``query_keywords``; ``None`` is passed through
            unchanged.
        platform: Sent as ``platform``; ``None`` is passed through unchanged.
        check_interval_hours: Sent as ``check_interval_hours``.

    Returns:
        The decoded JSON response. If anything raises, the failure is logged
        at error level and a dict with ``error`` and ``brand_id`` is returned
        instead.
    """
    endpoint = f"{GEO_BACKEND_URL}/internal/monitor/create-record"
    payload = {
        "brand_id": brand_id,
        "query_keywords": query_keywords,
        "platform": platform,
        "check_interval_hours": check_interval_hours,
    }
    try:
        async with httpx.AsyncClient(timeout=30.0) as http:
            response = await http.post(endpoint, json=payload, headers=_internal_headers())
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"monitor_create_record 失败: {exc}")
        return {"error": str(exc), "brand_id": brand_id}
# ─── Schema Tools ───
# Blank schema.org JSON-LD skeletons, keyed by schema type name.
# Every value field starts out empty.
SCHEMA_TEMPLATES = {
    "Organization": {
        "@context": "https://schema.org",
        "@type": "Organization",
        "name": "",
        "description": "",
        "url": "",
        "logo": "",
        "sameAs": [],
    },
    "Product": {
        "@context": "https://schema.org",
        "@type": "Product",
        "name": "",
        "description": "",
        "brand": {
            "@type": "Brand",
            "name": "",
        },
    },
    "FAQPage": {
        "@context": "https://schema.org",
        "@type": "FAQPage",
        "mainEntity": [
            {
                "@type": "Question",
                "name": "",
                "acceptedAnswer": {
                    "@type": "Answer",
                    "text": "",
                },
            },
        ],
    },
    "Article": {
        "@context": "https://schema.org",
        "@type": "Article",
        "headline": "",
        "description": "",
        "author": {
            "@type": "Organization",
            "name": "",
        },
    },
    "LocalBusiness": {
        "@context": "https://schema.org",
        "@type": "LocalBusiness",
        "name": "",
        "address": {
            "@type": "PostalAddress",
        },
    },
}
# Diagnosis dimension name -> schema type names associated with it.
# identify_missing_dimensions() only considers dimensions that appear as keys here.
DIMENSION_SCHEMA_MAP = {
    "schema_marketing": [
        "Organization",
        "LocalBusiness",
    ],
    "entity_clarity": [
        "Organization",
        "Product",
    ],
    "citation_readiness": [
        "FAQPage",
        "Article",
    ],
    "brand_visibility": [
        "Organization",
        "Product",
    ],
    "local_seo": [
        "LocalBusiness",
    ],
}
async def fill_schema_with_llm(
    schema_type: str,
    brand_info: dict | None = None,
    diagnosis_dimensions: dict | None = None,
) -> dict:
    """Fill a Schema JSON-LD template using an LLM, via the backend's internal API.

    Posts to ``/internal/schema/advise`` with the internal API headers,
    through an httpx client configured with ``timeout=60.0``.

    Args:
        schema_type: Sent to the backend as ``schema_type``.
        brand_info: Sent as ``brand_info``; ``None`` (or any falsy value) is
            sent as an empty dict.
        diagnosis_dimensions: Sent as ``diagnosis_dimensions``; ``None`` (or
            any falsy value) is sent as an empty dict.

    Returns:
        The decoded JSON response. If anything raises, the failure is logged
        at error level and a dict with ``error`` and ``schema_type`` is
        returned instead.
    """
    endpoint = f"{GEO_BACKEND_URL}/internal/schema/advise"
    payload = {
        "schema_type": schema_type,
        "brand_info": brand_info or {},
        "diagnosis_dimensions": diagnosis_dimensions or {},
    }
    try:
        async with httpx.AsyncClient(timeout=60.0) as http:
            response = await http.post(endpoint, json=payload, headers=_internal_headers())
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"fill_schema_with_llm 失败: {exc}")
        return {"error": str(exc), "schema_type": schema_type}
async def identify_missing_dimensions(
diagnosis_data: dict,
focus_dimensions: list[str] | None = None,
) -> dict:
"""识别 Schema 缺失维度"""
dimensions = []
dimension_scores = diagnosis_data.get("dimensions", {})
for dim_name, dim_info in dimension_scores.items():
if dim_name not in DIMENSION_SCHEMA_MAP:
continue
if focus_dimensions and dim_name not in focus_dimensions:
continue
score = dim_info.get("score", 0) if isinstance(dim_info, dict) else dim_info
max_score = dim_info.get("max_score", 100) if isinstance(dim_info, dict) else 100
percentage = (score / max_score * 100) if max_score > 0 else 0
if percentage < 80:
dimensions.append({
"dimension": dim_name,
"current_score": round(score, 2),
"max_score": max_score,
"percentage": round(percentage, 2),
})
return {"missing_dimensions": dimensions}
# ─── Competitor Tools ───
async def competitor_analyze(
    brand_id: str,
    analysis_types: list[str] | None = None,
    period_days: int = 30,
) -> dict:
    """Run a competitor strategy analysis through the GEO Backend API.

    Posts to ``/api/v1/competitor/analyze`` through an httpx client
    configured with ``timeout=120.0``.

    Args:
        brand_id: Sent to the backend as ``brand_id``.
        analysis_types: Sent as ``analysis_types``; ``None`` is passed through
            unchanged.
        period_days: Sent as ``period_days``.

    Returns:
        The decoded JSON response. If anything raises, the failure is logged
        at error level and a dict with ``error`` and ``brand_id`` is returned
        instead.
    """
    endpoint = f"{GEO_BACKEND_URL}/api/v1/competitor/analyze"
    payload = {
        "brand_id": brand_id,
        "analysis_types": analysis_types,
        "period_days": period_days,
    }
    try:
        async with httpx.AsyncClient(timeout=120.0) as http:
            response = await http.post(endpoint, json=payload)
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"competitor_analyze 失败: {exc}")
        return {"error": str(exc), "brand_id": brand_id}
async def competitor_gap_analysis(
    brand_id: str,
    period_days: int = 30,
) -> dict:
    """Run a competitor gap analysis through the GEO Backend API.

    Delegates to ``competitor_analyze`` with the analysis types fixed to
    ``citation_gap``, ``platform_coverage`` and ``query_overlap``, and returns
    whatever that call returns.
    """
    gap_analysis_types = ["citation_gap", "platform_coverage", "query_overlap"]
    result = await competitor_analyze(
        brand_id=brand_id,
        analysis_types=gap_analysis_types,
        period_days=period_days,
    )
    return result
# ─── Trend Tools ───
async def trend_insight(
    brand_id: str,
    days: int = 30,
    platforms: list[str] | None = None,
    keywords: list[str] | None = None,
) -> dict:
    """Run a trend insight analysis through the GEO Backend API.

    Posts to ``/api/v1/trends/insight`` through an httpx client configured
    with ``timeout=120.0``.

    Args:
        brand_id: Sent to the backend as ``brand_id``.
        days: Sent as ``days``.
        platforms: Sent as ``platforms``; ``None`` is passed through unchanged.
        keywords: Sent as ``keywords``; ``None`` is passed through unchanged.

    Returns:
        The decoded JSON response. If anything raises, the failure is logged
        at error level and a dict with ``error`` and ``brand_id`` is returned
        instead.
    """
    endpoint = f"{GEO_BACKEND_URL}/api/v1/trends/insight"
    payload = {
        "brand_id": brand_id,
        "days": days,
        "platforms": platforms,
        "keywords": keywords,
    }
    try:
        async with httpx.AsyncClient(timeout=120.0) as http:
            response = await http.post(endpoint, json=payload)
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"trend_insight 失败: {exc}")
        return {"error": str(exc), "brand_id": brand_id}
async def trend_hotspot(
    brand_id: str,
    days: int = 30,
) -> dict:
    """Detect hotspot topics with a surge in citations, via the GEO Backend API.

    Posts ``brand_id`` and ``days`` to ``/api/v1/trends/hotspot`` through an
    httpx client configured with ``timeout=120.0``.

    Returns:
        The decoded JSON response. If anything raises, the failure is logged
        at error level and a dict with ``error`` and ``brand_id`` is returned
        instead.
    """
    endpoint = f"{GEO_BACKEND_URL}/api/v1/trends/hotspot"
    payload = {"brand_id": brand_id, "days": days}
    try:
        async with httpx.AsyncClient(timeout=120.0) as http:
            response = await http.post(endpoint, json=payload)
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"trend_hotspot 失败: {exc}")
        return {"error": str(exc), "brand_id": brand_id}
# ─── Knowledge Tools ───
async def search_knowledge(
    query: str,
    knowledge_base_ids: list[str],
    top_k: int = 5,
) -> dict:
    """Retrieve relevant content from knowledge bases via the internal API.

    Thin wrapper around ``retrieve_knowledge`` that takes ``query`` as its
    first parameter; all arguments are forwarded by keyword and the result is
    returned unchanged.
    """
    found = await retrieve_knowledge(
        query=query,
        knowledge_base_ids=knowledge_base_ids,
        top_k=top_k,
    )
    return found
async def detect_ai_patterns(content: str, platform_id: str) -> dict:
    """Detect AI-generation patterns in content through the GEO Backend API.

    Posts ``content`` and ``platform_id`` to
    ``/api/v1/ai-engines/detect-ai-patterns`` through an httpx client
    configured with ``timeout=30.0``.

    Returns:
        The decoded JSON response. If anything raises, the failure is logged
        at error level and ``{"error": ..., "patterns": [], "count": 0}`` is
        returned instead.
    """
    endpoint = f"{GEO_BACKEND_URL}/api/v1/ai-engines/detect-ai-patterns"
    payload = {"content": content, "platform_id": platform_id}
    try:
        async with httpx.AsyncClient(timeout=30.0) as http:
            response = await http.post(endpoint, json=payload)
            response.raise_for_status()
            return response.json()
    except Exception as exc:
        logger.error(f"detect_ai_patterns 失败: {exc}")
        return {"error": str(exc), "patterns": [], "count": 0}
# ─── Registration ───
def register_geo_tools(registry: ToolRegistry) -> None:
    """Register every GEO tool defined in this module with ``registry``.

    Each tool is wrapped in a ``FunctionTool`` and registered in the order
    listed below; afterwards the registry's total tool count is logged at
    info level.

    Args:
        registry: The tool registry that receives the tools.
    """
    # (name, description, func, tags), in registration order.
    tool_specs = (
        # Citation
        ("execute_single_platform", "在单个AI平台执行引用检测",
         execute_single_platform, ["citation", "detection"]),
        ("get_or_create_task", "获取或创建引用检测的查询任务",
         get_or_create_task, ["citation", "task"]),
        # Content
        ("retrieve_knowledge", "从知识库检索相关内容",
         retrieve_knowledge, ["content", "rag", "knowledge"]),
        # Monitor
        ("monitor_check_and_compare", "检测并对比监测记录的变化",
         monitor_check_and_compare, ["monitor", "tracking"]),
        ("monitor_generate_report", "生成监测变化报告",
         monitor_generate_report, ["monitor", "report"]),
        ("monitor_create_record", "创建新的监测记录",
         monitor_create_record, ["monitor", "record"]),
        # Schema
        ("fill_schema_with_llm", "使用LLM填充Schema JSON-LD模板",
         fill_schema_with_llm, ["schema", "llm"]),
        ("identify_missing_dimensions", "识别Schema缺失维度",
         identify_missing_dimensions, ["schema", "diagnosis"]),
        # Competitor
        ("competitor_analyze", "执行竞品策略分析",
         competitor_analyze, ["competitor", "analysis"]),
        ("competitor_gap_analysis", "执行竞品差距分析",
         competitor_gap_analysis, ["competitor", "gap"]),
        # Trend
        ("trend_insight", "分析品牌引用趋势",
         trend_insight, ["trend", "insight"]),
        ("trend_hotspot", "检测引用量突增的热点话题",
         trend_hotspot, ["trend", "hotspot"]),
        # Knowledge
        ("search_knowledge", "从知识库检索相关内容",
         search_knowledge, ["knowledge", "rag"]),
        ("detect_ai_patterns", "检测内容中的AI生成模式",
         detect_ai_patterns, ["knowledge", "deai"]),
    )
    for tool_name, tool_description, tool_func, tool_tags in tool_specs:
        registry.register(FunctionTool(
            name=tool_name,
            description=tool_description,
            func=tool_func,
            tags=tool_tags,
        ))
    logger.info(f"GEO tools registered: {len(registry.list_tools())} tools")