feat(U8): GEO agent framework adapter layer - fischer-agentkit integration
- Add 8 YAML configs for agent declarative definition - Add 3 custom handlers (citation, monitor, schema) - Add 7 business tool registration modules - Add adapter.py with DI-based registry and dispatcher - Add database migration for agentkit tables (episodic_memories, evolution_logs, ab_test_configs) - Add fischer-agentkit>=0.1.0 dependency - Refactor agent_framework/__init__.py to re-export from agentkit - Minor API fixes (schema_advisor, scoring, strategy, suggestions) - Frontend package updates
This commit is contained in:
parent
903803c09a
commit
3b581b22ba
|
|
@ -0,0 +1,87 @@
|
|||
"""add agentkit extension tables
|
||||
|
||||
Revision ID: b001_agentkit_extension
|
||||
Revises: a79329c23b20
|
||||
Create Date: 2026-06-04
|
||||
|
||||
新增 fischer-agentkit 扩展表:
|
||||
- episodic_memories: 经验记忆
|
||||
- evolution_logs: 进化日志
|
||||
- ab_test_configs: A/B测试配置
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision = "b001_agentkit_extension"
|
||||
down_revision = "a79329c23b20"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# episodic_memories
|
||||
op.create_table(
|
||||
"episodic_memories",
|
||||
sa.Column("id", sa.Uuid(), nullable=False),
|
||||
sa.Column("agent_name", sa.String(50), nullable=False),
|
||||
sa.Column("task_type", sa.String(50), nullable=False),
|
||||
sa.Column("input_data", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("output_data", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("success", sa.Boolean(), nullable=True),
|
||||
sa.Column("quality_score", sa.Float(), nullable=True),
|
||||
sa.Column("reflection", sa.Text(), nullable=True),
|
||||
sa.Column("embedding_id", sa.String(100), nullable=True),
|
||||
sa.Column("tags", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("idx_episodic_agent_name", "episodic_memories", ["agent_name"])
|
||||
op.create_index("idx_episodic_task_type", "episodic_memories", ["task_type"])
|
||||
op.create_index("idx_episodic_created_at", "episodic_memories", ["created_at"])
|
||||
|
||||
# evolution_logs
|
||||
op.create_table(
|
||||
"evolution_logs",
|
||||
sa.Column("id", sa.Uuid(), nullable=False),
|
||||
sa.Column("agent_name", sa.String(50), nullable=False),
|
||||
sa.Column("change_type", sa.String(30), nullable=False),
|
||||
sa.Column("before", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("after", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("metrics", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("ab_test_id", sa.String(100), nullable=True),
|
||||
sa.Column("status", sa.String(20), server_default="pending", nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("idx_evolution_agent_name", "evolution_logs", ["agent_name"])
|
||||
op.create_index("idx_evolution_change_type", "evolution_logs", ["change_type"])
|
||||
op.create_index("idx_evolution_status", "evolution_logs", ["status"])
|
||||
op.create_index("idx_evolution_created_at", "evolution_logs", ["created_at"])
|
||||
|
||||
# ab_test_configs
|
||||
op.create_table(
|
||||
"ab_test_configs",
|
||||
sa.Column("id", sa.Uuid(), nullable=False),
|
||||
sa.Column("agent_name", sa.String(50), nullable=False),
|
||||
sa.Column("test_name", sa.String(100), nullable=False),
|
||||
sa.Column("variant_a", postgresql.JSONB(), nullable=False),
|
||||
sa.Column("variant_b", postgresql.JSONB(), nullable=False),
|
||||
sa.Column("traffic_split", sa.Float(), server_default="0.5", nullable=False),
|
||||
sa.Column("status", sa.String(20), server_default="running", nullable=False),
|
||||
sa.Column("winner", sa.String(10), nullable=True),
|
||||
sa.Column("metrics", postgresql.JSONB(), nullable=True),
|
||||
sa.Column("started_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.Column("ended_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index("idx_ab_test_agent_name", "ab_test_configs", ["agent_name"])
|
||||
op.create_index("idx_ab_test_status", "ab_test_configs", ["status"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("ab_test_configs")
|
||||
op.drop_table("evolution_logs")
|
||||
op.drop_table("episodic_memories")
|
||||
|
|
@ -1,5 +1,17 @@
|
|||
"""GEO AI Agent 框架 - 解耦式 Agent 管理与调度"""
|
||||
"""GEO AI Agent 框架
|
||||
|
||||
架构说明:
|
||||
- 框架核心能力来自 fischer-agentkit 包(pip 依赖)
|
||||
- 旧框架文件(base.py, dispatcher.py 等)保留作为兼容层
|
||||
- 新 Agent 通过 YAML 配置 + ConfigDrivenAgent 创建
|
||||
- 旧 Agent 类逐步迁移为薄适配层
|
||||
|
||||
推荐使用方式:
|
||||
from agentkit.core.config_driven import ConfigDrivenAgent # 新方式
|
||||
from app.agent_framework import get_agent_registry # 适配入口
|
||||
"""
|
||||
|
||||
# ---- 旧框架兼容导出(保持现有代码不 break)----
|
||||
from app.agent_framework.base import BaseAgent
|
||||
from app.agent_framework.config_manager import AgentConfigManager
|
||||
from app.agent_framework.dispatcher import TaskDispatcher
|
||||
|
|
@ -28,13 +40,25 @@ from app.agent_framework.protocol import (
|
|||
)
|
||||
from app.agent_framework.registry import AgentRegistry
|
||||
|
||||
# ---- agentkit 框架导出(新方式)----
|
||||
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
|
||||
from agentkit.tools.function_tool import FunctionTool
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
|
||||
# ---- 业务适配入口 ----
|
||||
from app.agent_framework.adapter import (
|
||||
create_agents_from_configs,
|
||||
get_agent_registry,
|
||||
get_task_dispatcher,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Core
|
||||
# Core (旧框架兼容)
|
||||
"BaseAgent",
|
||||
"AgentRegistry",
|
||||
"TaskDispatcher",
|
||||
"AgentConfigManager",
|
||||
# Protocol
|
||||
# Protocol (旧框架兼容)
|
||||
"AgentCapability",
|
||||
"AgentType",
|
||||
"AgentStatus",
|
||||
|
|
@ -42,7 +66,7 @@ __all__ = [
|
|||
"TaskProgress",
|
||||
"TaskResult",
|
||||
"TaskStatus",
|
||||
# Exceptions
|
||||
# Exceptions (旧框架兼容)
|
||||
"AgentFrameworkError",
|
||||
"AgentNotFoundError",
|
||||
"AgentAlreadyRegisteredError",
|
||||
|
|
@ -55,4 +79,13 @@ __all__ = [
|
|||
"TaskCancelledError",
|
||||
"NoAvailableAgentError",
|
||||
"ConfigValidationError",
|
||||
# agentkit 新方式
|
||||
"AgentConfig",
|
||||
"ConfigDrivenAgent",
|
||||
"FunctionTool",
|
||||
"ToolRegistry",
|
||||
# 业务适配入口
|
||||
"create_agents_from_configs",
|
||||
"get_agent_registry",
|
||||
"get_task_dispatcher",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,175 @@
|
|||
"""GEO Agent 适配层 - 基于 fischer-agentkit 的新架构
|
||||
|
||||
职责:
|
||||
1. 从 YAML 配置创建 ConfigDrivenAgent
|
||||
2. 提供依赖注入的 Registry 和 Dispatcher
|
||||
3. 兼容旧版 Agent 类的导入路径
|
||||
|
||||
使用方式:
|
||||
# 新方式(推荐)
|
||||
from app.agent_framework import get_agent_registry, get_task_dispatcher
|
||||
registry = get_agent_registry()
|
||||
dispatcher = get_task_dispatcher()
|
||||
|
||||
# 旧方式(兼容)
|
||||
from app.agent_framework.agents import CitationDetectorAgent
|
||||
agent = CitationDetectorAgent()
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
|
||||
from agentkit.core.registry import AgentRegistry as AgentkitRegistry
|
||||
from agentkit.core.dispatcher import TaskDispatcher as AgentkitDispatcher
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
|
||||
from app.agent_framework.tools import get_tool_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_AGENTKIT_REGISTRY: AgentkitRegistry | None = None
|
||||
_AGENTKIT_DISPATCHER: AgentkitDispatcher | None = None
|
||||
|
||||
|
||||
def _get_configs_dir() -> Path:
|
||||
"""获取 Agent YAML 配置目录"""
|
||||
return Path(__file__).parent / "agents" / "configs"
|
||||
|
||||
|
||||
def _get_llm_client():
|
||||
"""获取 LLM 客户端"""
|
||||
from app.services.llm import LLMFactory
|
||||
return LLMFactory.get_default()
|
||||
|
||||
|
||||
def _get_custom_handlers() -> dict:
|
||||
"""获取所有自定义 handler"""
|
||||
from app.agent_framework.agents.custom_handlers.citation_handler import handle_citation_task
|
||||
from app.agent_framework.agents.custom_handlers.monitor_handler import handle_monitor_task
|
||||
from app.agent_framework.agents.custom_handlers.schema_handler import handle_schema_task
|
||||
|
||||
return {
|
||||
"app.agent_framework.agents.custom_handlers.citation_handler.handle_citation_task": handle_citation_task,
|
||||
"app.agent_framework.agents.custom_handlers.monitor_handler.handle_monitor_task": handle_monitor_task,
|
||||
"app.agent_framework.agents.custom_handlers.schema_handler.handle_schema_task": handle_schema_task,
|
||||
}
|
||||
|
||||
|
||||
def _get_session_factory():
|
||||
"""获取数据库会话工厂"""
|
||||
from app.database import AsyncSessionLocal
|
||||
return AsyncSessionLocal
|
||||
|
||||
|
||||
def _get_redis_factory():
|
||||
"""获取 Redis 连接工厂"""
|
||||
from app.core.redis import get_redis
|
||||
return get_redis
|
||||
|
||||
|
||||
def _get_agent_model():
|
||||
"""获取 Agent ORM 模型"""
|
||||
from app.models.agent import AgentRegistry as AgentRegistryModel
|
||||
return AgentRegistryModel
|
||||
|
||||
|
||||
def _get_task_model():
|
||||
"""获取 Task ORM 模型"""
|
||||
from app.models.agent import AgentTask as AgentTaskModel
|
||||
return AgentTaskModel
|
||||
|
||||
|
||||
def _get_task_log_model():
|
||||
"""获取 TaskLog ORM 模型"""
|
||||
from app.models.agent import AgentTaskLog as AgentTaskLogModel
|
||||
return AgentTaskLogModel
|
||||
|
||||
|
||||
def create_agents_from_configs() -> list[ConfigDrivenAgent]:
|
||||
"""从 YAML 配置目录创建所有 Agent(新架构)"""
|
||||
configs_dir = _get_configs_dir()
|
||||
tool_registry = get_tool_registry()
|
||||
llm_client = _get_llm_client()
|
||||
custom_handlers = _get_custom_handlers()
|
||||
|
||||
agents = []
|
||||
for yaml_file in sorted(configs_dir.glob("*.yaml")):
|
||||
try:
|
||||
config = AgentConfig.from_yaml(str(yaml_file))
|
||||
agent = ConfigDrivenAgent(
|
||||
config=config,
|
||||
tool_registry=tool_registry,
|
||||
llm_client=llm_client,
|
||||
custom_handlers=custom_handlers,
|
||||
)
|
||||
agents.append(agent)
|
||||
logger.info(f"Created agent '{config.name}' from {yaml_file.name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create agent from {yaml_file.name}: {e}")
|
||||
|
||||
return agents
|
||||
|
||||
|
||||
def get_agent_registry() -> AgentkitRegistry:
|
||||
"""获取 agentkit AgentRegistry(懒初始化,依赖注入)"""
|
||||
global _AGENTKIT_REGISTRY
|
||||
if _AGENTKIT_REGISTRY is None:
|
||||
_AGENTKIT_REGISTRY = AgentkitRegistry(
|
||||
session_factory=_get_session_factory(),
|
||||
agent_model=_get_agent_model(),
|
||||
)
|
||||
agents = create_agents_from_configs()
|
||||
for agent in agents:
|
||||
_AGENTKIT_REGISTRY.register(agent)
|
||||
logger.info(f"Agentkit AgentRegistry initialized with {len(agents)} agents")
|
||||
return _AGENTKIT_REGISTRY
|
||||
|
||||
|
||||
def get_task_dispatcher() -> AgentkitDispatcher:
|
||||
"""获取 agentkit TaskDispatcher(懒初始化,依赖注入)"""
|
||||
global _AGENTKIT_DISPATCHER
|
||||
if _AGENTKIT_DISPATCHER is None:
|
||||
_AGENTKIT_DISPATCHER = AgentkitDispatcher(
|
||||
redis_factory=_get_redis_factory(),
|
||||
session_factory=_get_session_factory(),
|
||||
agent_model=_get_agent_model(),
|
||||
task_model=_get_task_model(),
|
||||
task_log_model=_get_task_log_model(),
|
||||
)
|
||||
logger.info("Agentkit TaskDispatcher initialized")
|
||||
return _AGENTKIT_DISPATCHER
|
||||
|
||||
|
||||
def get_legacy_agent(name: str):
|
||||
"""获取旧版 Agent 实例(兼容层)
|
||||
|
||||
旧代码可能直接实例化 Agent 类,此方法提供兼容性。
|
||||
逐步迁移后可移除。
|
||||
"""
|
||||
from app.agent_framework.agents import (
|
||||
CitationDetectorAgent,
|
||||
CompetitorAnalyzerAgent,
|
||||
ContentGeneratorAgent,
|
||||
DeAIAgent,
|
||||
GEOOptimizerAgent,
|
||||
MonitorAgent,
|
||||
SchemaAdvisorAgent,
|
||||
TrendAgent,
|
||||
)
|
||||
|
||||
legacy_map = {
|
||||
"citation_detector": CitationDetectorAgent,
|
||||
"competitor_analyzer": CompetitorAnalyzerAgent,
|
||||
"content_generator": ContentGeneratorAgent,
|
||||
"deai_agent": DeAIAgent,
|
||||
"geo_optimizer": GEOOptimizerAgent,
|
||||
"monitor": MonitorAgent,
|
||||
"schema_advisor": SchemaAdvisorAgent,
|
||||
"trend_agent": TrendAgent,
|
||||
}
|
||||
|
||||
agent_cls = legacy_map.get(name)
|
||||
if agent_cls:
|
||||
return agent_cls()
|
||||
raise ValueError(f"Unknown agent name: {name}")
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""Agent Configs 包"""
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
name: citation_detector
|
||||
agent_type: citation_detection
|
||||
version: "1.0.0"
|
||||
description: "AI平台引用检测Agent:检测目标品牌在各AI平台回答中的引用情况"
|
||||
task_mode: custom
|
||||
supported_tasks:
|
||||
- citation_detect
|
||||
- citation_detect_single
|
||||
max_concurrency: 3
|
||||
custom_handler: "app.agent_framework.agents.custom_handlers.citation_handler.handle_citation_task"
|
||||
|
||||
input_schema:
|
||||
type: object
|
||||
properties:
|
||||
query_id:
|
||||
type: string
|
||||
description: 查询ID(citation_detect模式)
|
||||
keyword:
|
||||
type: string
|
||||
description: 关键词(citation_detect_single模式)
|
||||
platform:
|
||||
type: string
|
||||
description: 平台名称(citation_detect_single模式)
|
||||
target_brand:
|
||||
type: string
|
||||
description: 目标品牌(citation_detect_single模式)
|
||||
brand_aliases:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: 品牌别名列表
|
||||
|
||||
output_schema:
|
||||
type: object
|
||||
properties:
|
||||
query_id:
|
||||
type: string
|
||||
keyword:
|
||||
type: string
|
||||
total_records:
|
||||
type: integer
|
||||
cited_count:
|
||||
type: integer
|
||||
records:
|
||||
type: array
|
||||
|
||||
tools:
|
||||
- execute_single_platform
|
||||
- get_or_create_task
|
||||
|
||||
memory:
|
||||
working:
|
||||
enabled: true
|
||||
episodic:
|
||||
enabled: true
|
||||
track_success: true
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
name: competitor_analyzer
|
||||
agent_type: competitor_analysis
|
||||
version: "1.0.0"
|
||||
description: "竞品策略分析Agent:对比品牌与竞品的引用数据,识别差距领域,发现机会点,生成策略建议"
|
||||
task_mode: tool_call
|
||||
supported_tasks:
|
||||
- competitor_analyze
|
||||
- competitor_gap_analysis
|
||||
max_concurrency: 2
|
||||
|
||||
input_schema:
|
||||
type: object
|
||||
required:
|
||||
- brand_id
|
||||
properties:
|
||||
brand_id:
|
||||
type: string
|
||||
description: 品牌ID
|
||||
analysis_types:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: 分析类型列表
|
||||
period_days:
|
||||
type: integer
|
||||
description: 分析周期(天)
|
||||
default: 30
|
||||
|
||||
output_schema:
|
||||
type: object
|
||||
properties:
|
||||
brand_id:
|
||||
type: string
|
||||
analysis:
|
||||
type: object
|
||||
recommendations:
|
||||
type: array
|
||||
|
||||
tools:
|
||||
- competitor_analyze
|
||||
- competitor_gap_analysis
|
||||
|
||||
memory:
|
||||
working:
|
||||
enabled: true
|
||||
episodic:
|
||||
enabled: true
|
||||
track_success: true
|
||||
|
|
@ -0,0 +1,97 @@
|
|||
name: content_generator
|
||||
agent_type: content_generation
|
||||
version: "1.0.0"
|
||||
description: "AI内容生成Agent:支持选题推荐和文章生成,可结合知识库RAG检索"
|
||||
task_mode: llm_generate
|
||||
supported_tasks:
|
||||
- generate_topics
|
||||
- generate_article
|
||||
max_concurrency: 2
|
||||
|
||||
input_schema:
|
||||
type: object
|
||||
required:
|
||||
- target_keyword
|
||||
properties:
|
||||
target_keyword:
|
||||
type: string
|
||||
description: 目标关键词
|
||||
brand_name:
|
||||
type: string
|
||||
description: 品牌名称
|
||||
brand_description:
|
||||
type: string
|
||||
description: 品牌描述
|
||||
target_platform:
|
||||
type: string
|
||||
description: 目标平台
|
||||
default: "通用"
|
||||
knowledge_base_ids:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: 知识库ID列表,用于RAG检索
|
||||
topic_title:
|
||||
type: string
|
||||
description: 选题标题(generate_article时使用)
|
||||
word_count:
|
||||
type: integer
|
||||
description: 目标字数
|
||||
default: 2000
|
||||
content_style:
|
||||
type: string
|
||||
description: 内容风格
|
||||
default: "专业严谨"
|
||||
content_angle:
|
||||
type: string
|
||||
description: 内容角度
|
||||
model:
|
||||
type: string
|
||||
description: 指定LLM模型
|
||||
|
||||
output_schema:
|
||||
type: object
|
||||
properties:
|
||||
topics:
|
||||
type: array
|
||||
description: 选题列表
|
||||
content:
|
||||
type: string
|
||||
description: 生成的文章内容
|
||||
word_count:
|
||||
type: integer
|
||||
usage:
|
||||
type: object
|
||||
|
||||
prompt:
|
||||
identity: "你是一个专业的内容生成助手,擅长为品牌创作高质量的SEO/GEO优化内容"
|
||||
context: "品牌需要通过优质内容提升在AI搜索引擎中的可见性和引用率"
|
||||
instructions: |
|
||||
根据用户提供的关键词、品牌信息和知识库内容,生成符合要求的内容。
|
||||
- generate_topics: 生成选题列表,每个选题包含 title、reason、keywords 字段
|
||||
- generate_article: 生成完整文章,确保内容专业、结构清晰、关键词自然融入
|
||||
constraints: |
|
||||
- 内容必须原创,避免抄袭
|
||||
- 关键词密度适中,不要堆砌
|
||||
- 文章结构清晰,段落分明
|
||||
- 数据和引用需标注来源
|
||||
output_format: "以 JSON 格式输出,generate_topics 返回 {topics: [{title, reason, keywords}]},generate_article 返回 {content, word_count}"
|
||||
examples: ""
|
||||
|
||||
llm:
|
||||
model: "deepseek"
|
||||
temperature: 0.7
|
||||
max_tokens: 4000
|
||||
|
||||
tools:
|
||||
- retrieve_knowledge
|
||||
|
||||
memory:
|
||||
working:
|
||||
enabled: true
|
||||
episodic:
|
||||
enabled: true
|
||||
track_success: true
|
||||
semantic:
|
||||
enabled: true
|
||||
knowledge_base_ids_field: "knowledge_base_ids"
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
name: deai_agent
|
||||
agent_type: deai_processing
|
||||
version: "1.1.0"
|
||||
description: "内容去AI化Agent:消除AI生成特征,使文章更自然流畅"
|
||||
task_mode: llm_generate
|
||||
supported_tasks:
|
||||
- deai_process
|
||||
max_concurrency: 2
|
||||
|
||||
input_schema:
|
||||
type: object
|
||||
required:
|
||||
- content
|
||||
properties:
|
||||
content:
|
||||
type: string
|
||||
description: 待处理的文章内容
|
||||
platform:
|
||||
type: string
|
||||
description: 目标平台ID(如 zhihu, wechat)
|
||||
style:
|
||||
type: string
|
||||
description: 目标风格
|
||||
default: "自然流畅"
|
||||
preserve_structure:
|
||||
type: boolean
|
||||
description: 是否保留原有结构
|
||||
default: true
|
||||
|
||||
output_schema:
|
||||
type: object
|
||||
properties:
|
||||
content:
|
||||
type: string
|
||||
description: 处理后的内容
|
||||
original_word_count:
|
||||
type: integer
|
||||
processed_word_count:
|
||||
type: integer
|
||||
usage:
|
||||
type: object
|
||||
detected_ai_patterns:
|
||||
type: array
|
||||
|
||||
prompt:
|
||||
identity: "你是一个专业的内容改写专家,擅长将AI生成的文本改写为自然、人类化的表达"
|
||||
context: "平台对AI生成内容的检测越来越严格,需要将内容改写为更自然的风格"
|
||||
instructions: |
|
||||
对提供的文章内容进行去AI化处理:
|
||||
1. 替换AI常用表达(如"总之"、"综上所述"、"首先其次最后"等)
|
||||
2. 增加口语化表达和个人观点
|
||||
3. 调整句式结构,避免过于工整的排比
|
||||
4. 保留核心信息和数据
|
||||
5. 如有平台特定要求,遵循平台规则
|
||||
constraints: |
|
||||
- 保留原文的核心信息和数据
|
||||
- 不要改变文章的主题和立场
|
||||
- 保持专业性的同时增加自然感
|
||||
- 如指定平台,需符合该平台的内容规范
|
||||
output_format: "返回处理后的完整文章内容"
|
||||
examples: ""
|
||||
|
||||
llm:
|
||||
model: "deepseek"
|
||||
temperature: 0.9
|
||||
max_tokens: 8000
|
||||
|
||||
tools:
|
||||
- detect_ai_patterns
|
||||
|
||||
memory:
|
||||
working:
|
||||
enabled: true
|
||||
episodic:
|
||||
enabled: true
|
||||
track_success: true
|
||||
|
|
@ -0,0 +1,78 @@
|
|||
name: geo_optimizer
|
||||
agent_type: geo_optimization
|
||||
version: "1.0.0"
|
||||
description: "GEO/SEO内容优化Agent:提升内容在AI搜索引擎中的可见性和引用率"
|
||||
task_mode: llm_generate
|
||||
supported_tasks:
|
||||
- geo_optimize
|
||||
max_concurrency: 2
|
||||
|
||||
input_schema:
|
||||
type: object
|
||||
required:
|
||||
- content
|
||||
- target_keywords
|
||||
properties:
|
||||
content:
|
||||
type: string
|
||||
description: 待优化文章
|
||||
target_keywords:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: 目标关键词列表
|
||||
target_platform:
|
||||
type: string
|
||||
description: 目标平台
|
||||
default: "通用"
|
||||
optimization_level:
|
||||
type: string
|
||||
enum: [light, moderate, aggressive]
|
||||
description: 优化级别
|
||||
default: "moderate"
|
||||
|
||||
output_schema:
|
||||
type: object
|
||||
properties:
|
||||
optimized_content:
|
||||
type: string
|
||||
seo_score:
|
||||
type: number
|
||||
changes:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
usage:
|
||||
type: object
|
||||
|
||||
prompt:
|
||||
identity: "你是一个GEO/SEO优化专家,擅长优化内容以提升在AI搜索引擎中的可见性"
|
||||
context: "品牌需要通过内容优化提升在AI搜索结果中的引用率和排名"
|
||||
instructions: |
|
||||
对提供的文章进行GEO/SEO优化:
|
||||
1. 自然融入目标关键词
|
||||
2. 优化标题和段落结构
|
||||
3. 增加结构化数据标记建议
|
||||
4. 提升内容的权威性和引用价值
|
||||
5. 根据optimization_level调整优化力度
|
||||
constraints: |
|
||||
- 优化后的内容必须保持原意
|
||||
- 关键词融入要自然,避免堆砌
|
||||
- 保持文章可读性
|
||||
- 不要添加虚假信息
|
||||
output_format: "以 JSON 格式输出: {optimized_content: string, seo_score: number, changes: [string]}"
|
||||
examples: ""
|
||||
|
||||
llm:
|
||||
model: "deepseek"
|
||||
temperature: 0.5
|
||||
max_tokens: 8000
|
||||
|
||||
tools: []
|
||||
|
||||
memory:
|
||||
working:
|
||||
enabled: true
|
||||
episodic:
|
||||
enabled: true
|
||||
track_success: true
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
name: monitor
|
||||
agent_type: performance_tracker
|
||||
version: "1.0.0"
|
||||
description: "效果追踪Agent:监测品牌引用量、情感、排名变化,生成变化报告"
|
||||
task_mode: custom
|
||||
supported_tasks:
|
||||
- monitor_track
|
||||
- monitor_check_single
|
||||
max_concurrency: 3
|
||||
custom_handler: "app.agent_framework.agents.custom_handlers.monitor_handler.handle_monitor_task"
|
||||
|
||||
input_schema:
|
||||
type: object
|
||||
required:
|
||||
- brand_id
|
||||
properties:
|
||||
brand_id:
|
||||
type: string
|
||||
description: 品牌ID
|
||||
keyword:
|
||||
type: string
|
||||
description: 关键词(monitor_check_single模式)
|
||||
platform:
|
||||
type: string
|
||||
description: 平台名称(monitor_check_single模式)
|
||||
check_interval_hours:
|
||||
type: integer
|
||||
description: 检测间隔小时数
|
||||
default: 24
|
||||
|
||||
output_schema:
|
||||
type: object
|
||||
properties:
|
||||
brand_id:
|
||||
type: string
|
||||
brand_name:
|
||||
type: string
|
||||
total_queries:
|
||||
type: integer
|
||||
checked_records:
|
||||
type: integer
|
||||
reports:
|
||||
type: array
|
||||
|
||||
tools:
|
||||
- monitor_check_and_compare
|
||||
- monitor_generate_report
|
||||
- monitor_create_record
|
||||
|
||||
memory:
|
||||
working:
|
||||
enabled: true
|
||||
episodic:
|
||||
enabled: true
|
||||
track_success: true
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
name: schema_advisor
|
||||
agent_type: schema_advisor
|
||||
version: "1.0.0"
|
||||
description: "Schema优化建议Agent:识别Schema缺失维度,生成JSON-LD结构化数据建议"
|
||||
task_mode: custom
|
||||
supported_tasks:
|
||||
- schema_advise
|
||||
max_concurrency: 2
|
||||
custom_handler: "app.agent_framework.agents.custom_handlers.schema_handler.handle_schema_task"
|
||||
|
||||
input_schema:
|
||||
type: object
|
||||
required:
|
||||
- brand_id
|
||||
properties:
|
||||
brand_id:
|
||||
type: string
|
||||
description: 品牌ID
|
||||
diagnosis_data:
|
||||
type: object
|
||||
description: 诊断数据
|
||||
brand_info:
|
||||
type: object
|
||||
description: 品牌信息
|
||||
focus_dimensions:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: 重点关注维度
|
||||
|
||||
output_schema:
|
||||
type: object
|
||||
properties:
|
||||
brand_id:
|
||||
type: string
|
||||
suggestions:
|
||||
type: array
|
||||
total:
|
||||
type: integer
|
||||
|
||||
tools:
|
||||
- fill_schema_with_llm
|
||||
|
||||
memory:
|
||||
working:
|
||||
enabled: true
|
||||
episodic:
|
||||
enabled: true
|
||||
track_success: true
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
name: trend_agent
|
||||
agent_type: trend_analysis
|
||||
version: "1.0.0"
|
||||
description: "趋势洞察Agent:分析品牌引用趋势、识别热点话题、推断变化原因并生成建议"
|
||||
task_mode: tool_call
|
||||
supported_tasks:
|
||||
- trend_insight
|
||||
- trend_hotspot
|
||||
max_concurrency: 2
|
||||
|
||||
input_schema:
|
||||
type: object
|
||||
required:
|
||||
- brand_id
|
||||
properties:
|
||||
brand_id:
|
||||
type: string
|
||||
description: 品牌ID
|
||||
days:
|
||||
type: integer
|
||||
description: 分析天数
|
||||
default: 30
|
||||
platforms:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: 平台列表
|
||||
keywords:
|
||||
type: array
|
||||
items:
|
||||
type: string
|
||||
description: 关键词列表
|
||||
|
||||
output_schema:
|
||||
type: object
|
||||
properties:
|
||||
brand_id:
|
||||
type: string
|
||||
trends:
|
||||
type: array
|
||||
hotspots:
|
||||
type: array
|
||||
|
||||
tools:
|
||||
- trend_insight
|
||||
- trend_hotspot
|
||||
|
||||
memory:
|
||||
working:
|
||||
enabled: true
|
||||
episodic:
|
||||
enabled: true
|
||||
track_success: true
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""Custom Handlers 包"""
|
||||
|
|
@ -0,0 +1,179 @@
|
|||
"""CitationDetector 自定义 Handler
|
||||
|
||||
处理 citation_detect 和 citation_detect_single 两种任务类型,
|
||||
包含数据库操作和平台调用等复杂业务逻辑。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from agentkit.core.protocol import TaskMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def handle_citation_task(task: TaskMessage) -> dict:
|
||||
"""引用检测任务入口
|
||||
|
||||
根据 task_type 路由到不同的处理逻辑:
|
||||
- citation_detect: 全量检测(多平台)
|
||||
- citation_detect_single: 单平台检测
|
||||
"""
|
||||
if task.task_type == "citation_detect":
|
||||
return await _execute_full_detect(task)
|
||||
elif task.task_type == "citation_detect_single":
|
||||
return await _execute_single_detect(task)
|
||||
else:
|
||||
raise ValueError(f"Unsupported task type: {task.task_type}")
|
||||
|
||||
|
||||
async def _execute_full_detect(task: TaskMessage) -> dict:
|
||||
"""全量引用检测:遍历查询关联的所有平台"""
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.models.citation_record import CitationRecord
|
||||
from app.models.query import Query
|
||||
from app.models.query_task import QueryTask
|
||||
from app.services.ai_engine.platform_bridge import execute_single_platform
|
||||
|
||||
query_id = task.input_data.get("query_id")
|
||||
if not query_id:
|
||||
raise ValueError("input_data must contain 'query_id'")
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
stmt = select(Query).where(Query.id == query_id)
|
||||
result = await db.execute(stmt)
|
||||
query = result.scalar_one_or_none()
|
||||
|
||||
if not query:
|
||||
raise ValueError(f"Query {query_id} not found")
|
||||
|
||||
records: list[CitationRecord] = []
|
||||
platforms = query.platforms or ["wenxin", "kimi"]
|
||||
brand_aliases = query.brand_aliases or []
|
||||
|
||||
for i, platform_name in enumerate(platforms):
|
||||
task_obj = await _get_or_create_task(db, query.id, platform_name)
|
||||
task_obj.status = "running"
|
||||
task_obj.started_at = datetime.now(timezone.utc)
|
||||
task_obj.error_message = None
|
||||
await db.commit()
|
||||
|
||||
try:
|
||||
detect_result = await execute_single_platform(
|
||||
keyword=query.keyword,
|
||||
platform=platform_name,
|
||||
target_brand=query.target_brand,
|
||||
brand_aliases=brand_aliases,
|
||||
)
|
||||
|
||||
record = CitationRecord.from_citation_result(
|
||||
query_id=query.id,
|
||||
platform=platform_name,
|
||||
result=detect_result,
|
||||
)
|
||||
db.add(record)
|
||||
records.append(record)
|
||||
|
||||
task_obj.status = "success"
|
||||
task_obj.completed_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"平台 {platform_name} 查询失败: {e}")
|
||||
task_obj.status = "failed"
|
||||
task_obj.error_message = str(e)
|
||||
task_obj.completed_at = datetime.now(timezone.utc)
|
||||
|
||||
record = CitationRecord.from_citation_result(
|
||||
query_id=query.id,
|
||||
platform=platform_name,
|
||||
result={"cited": False, "raw_response": str(e)},
|
||||
)
|
||||
db.add(record)
|
||||
records.append(record)
|
||||
await db.commit()
|
||||
|
||||
query.last_queried_at = datetime.now(timezone.utc)
|
||||
query.next_query_at = _calculate_next_query_at(query.frequency)
|
||||
await db.commit()
|
||||
|
||||
record_summaries = [
|
||||
{
|
||||
"id": str(r.id),
|
||||
"platform": r.platform,
|
||||
"cited": r.cited,
|
||||
"confidence": r.confidence,
|
||||
"match_type": r.match_type,
|
||||
}
|
||||
for r in records
|
||||
]
|
||||
|
||||
return {
|
||||
"query_id": str(query_id),
|
||||
"keyword": query.keyword,
|
||||
"total_records": len(records),
|
||||
"cited_count": sum(1 for r in records if r.cited),
|
||||
"records": record_summaries,
|
||||
}
|
||||
|
||||
|
||||
async def _execute_single_detect(task: TaskMessage) -> dict:
|
||||
"""单平台引用检测"""
|
||||
from app.services.ai_engine.platform_bridge import execute_single_platform
|
||||
|
||||
keyword = task.input_data.get("keyword")
|
||||
platform = task.input_data.get("platform")
|
||||
target_brand = task.input_data.get("target_brand")
|
||||
brand_aliases = task.input_data.get("brand_aliases", [])
|
||||
|
||||
if not all([keyword, platform, target_brand]):
|
||||
raise ValueError("input_data must contain 'keyword', 'platform', 'target_brand'")
|
||||
|
||||
result = await execute_single_platform(
|
||||
keyword=keyword,
|
||||
platform=platform,
|
||||
target_brand=target_brand,
|
||||
brand_aliases=brand_aliases,
|
||||
)
|
||||
|
||||
output = {k: v for k, v in result.items() if k != "raw_response"}
|
||||
return output
|
||||
|
||||
|
||||
async def _get_or_create_task(db, query_id: uuid.UUID, platform: str):
|
||||
"""获取或创建查询任务"""
|
||||
from app.models.query_task import QueryTask
|
||||
|
||||
stmt = select(QueryTask).where(
|
||||
QueryTask.query_id == query_id,
|
||||
QueryTask.platform == platform,
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
task_obj = result.scalar_one_or_none()
|
||||
|
||||
if not task_obj:
|
||||
task_obj = QueryTask(
|
||||
query_id=query_id,
|
||||
platform=platform,
|
||||
status="pending",
|
||||
)
|
||||
db.add(task_obj)
|
||||
await db.commit()
|
||||
await db.refresh(task_obj)
|
||||
|
||||
return task_obj
|
||||
|
||||
|
||||
def _calculate_next_query_at(frequency: str | None) -> datetime:
|
||||
"""计算下次查询时间"""
|
||||
now = datetime.now(timezone.utc)
|
||||
freq_map = {
|
||||
"daily": timedelta(days=1),
|
||||
"weekly": timedelta(days=7),
|
||||
"monthly": timedelta(days=30),
|
||||
}
|
||||
delta = freq_map.get(frequency or "weekly", timedelta(days=7))
|
||||
return now + delta
|
||||
|
|
@ -0,0 +1,133 @@
|
|||
"""Monitor 自定义 Handler
|
||||
|
||||
处理 monitor_track 和 monitor_check_single 两种任务类型,
|
||||
包含数据库操作和监测服务调用等复杂业务逻辑。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from agentkit.core.protocol import TaskMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def handle_monitor_task(task: TaskMessage) -> dict:
|
||||
"""效果追踪任务入口
|
||||
|
||||
根据 task_type 路由到不同的处理逻辑:
|
||||
- monitor_track: 品牌全量效果追踪
|
||||
- monitor_check_single: 单关键词检测
|
||||
"""
|
||||
if task.task_type == "monitor_track":
|
||||
return await _monitor_track(task)
|
||||
elif task.task_type == "monitor_check_single":
|
||||
return await _monitor_check_single(task)
|
||||
else:
|
||||
raise ValueError(f"不支持的任务类型: {task.task_type}")
|
||||
|
||||
|
||||
async def _monitor_track(task: TaskMessage) -> dict:
|
||||
"""品牌全量效果追踪"""
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.models.brand import Brand
|
||||
from app.models.monitoring import MonitoringRecord
|
||||
from app.models.query import Query
|
||||
from app.services.monitoring.monitor_service import MonitorService
|
||||
|
||||
input_data = task.input_data
|
||||
brand_id = input_data.get("brand_id")
|
||||
if not brand_id:
|
||||
raise ValueError("input_data必须包含'brand_id'字段")
|
||||
|
||||
brand_id = uuid.UUID(brand_id)
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
stmt = select(Brand).where(Brand.id == brand_id)
|
||||
result = await db.execute(stmt)
|
||||
brand = result.scalar_one_or_none()
|
||||
if not brand:
|
||||
raise ValueError(f"品牌不存在: {brand_id}")
|
||||
|
||||
queries_stmt = select(Query).where(Query.target_brand == brand.name)
|
||||
queries_result = await db.execute(queries_stmt)
|
||||
queries = list(queries_result.scalars().all())
|
||||
|
||||
if not queries:
|
||||
return {
|
||||
"brand_id": str(brand_id),
|
||||
"brand_name": brand.name,
|
||||
"total_queries": 0,
|
||||
"reports": [],
|
||||
}
|
||||
|
||||
total_queries = len(queries)
|
||||
reports = []
|
||||
service = MonitorService()
|
||||
|
||||
monitoring_stmt = select(MonitoringRecord).where(
|
||||
MonitoringRecord.brand_id == brand_id,
|
||||
MonitoringRecord.status == "active",
|
||||
)
|
||||
monitoring_result = await db.execute(monitoring_stmt)
|
||||
monitoring_records = list(monitoring_result.scalars().all())
|
||||
|
||||
for record in monitoring_records:
|
||||
updated_record = await service.check_and_compare(db, record.id)
|
||||
if updated_record:
|
||||
report = await service.generate_change_report(db, updated_record.id)
|
||||
if report:
|
||||
reports.append(report)
|
||||
|
||||
return {
|
||||
"brand_id": str(brand_id),
|
||||
"brand_name": brand.name,
|
||||
"total_queries": total_queries,
|
||||
"checked_records": len(monitoring_records),
|
||||
"reports": reports,
|
||||
}
|
||||
|
||||
|
||||
async def _monitor_check_single(task: TaskMessage) -> dict:
|
||||
"""单关键词检测"""
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.services.monitoring.monitor_service import MonitorService
|
||||
|
||||
input_data = task.input_data
|
||||
brand_id = input_data.get("brand_id")
|
||||
keyword = input_data.get("keyword")
|
||||
platform = input_data.get("platform")
|
||||
|
||||
if not brand_id:
|
||||
raise ValueError("input_data必须包含'brand_id'字段")
|
||||
|
||||
brand_id = uuid.UUID(brand_id)
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
service = MonitorService()
|
||||
|
||||
record = await service.create_monitoring_record(
|
||||
db=db,
|
||||
brand_id=brand_id,
|
||||
query_keywords=keyword,
|
||||
platform=platform,
|
||||
check_interval_hours=input_data.get("check_interval_hours", 24),
|
||||
)
|
||||
|
||||
updated_record = await service.check_and_compare(db, record.id)
|
||||
|
||||
report = None
|
||||
if updated_record:
|
||||
report = await service.generate_change_report(db, updated_record.id)
|
||||
|
||||
return {
|
||||
"record_id": str(record.id),
|
||||
"brand_id": str(brand_id),
|
||||
"keyword": keyword,
|
||||
"platform": platform,
|
||||
"change_type": updated_record.change_type if updated_record else None,
|
||||
"report": report,
|
||||
}
|
||||
|
|
@ -0,0 +1,285 @@
|
|||
"""SchemaAdvisor 自定义 Handler
|
||||
|
||||
处理 schema_advise 任务类型,
|
||||
包含维度识别、模板匹配、LLM填充、验证排序等复杂业务逻辑。
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
|
||||
from agentkit.core.protocol import TaskMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SCHEMA_TEMPLATES = {
|
||||
"Organization": {
|
||||
"@context": "https://schema.org",
|
||||
"@type": "Organization",
|
||||
"name": "",
|
||||
"description": "",
|
||||
"url": "",
|
||||
"logo": "",
|
||||
"sameAs": [],
|
||||
"contactPoint": {
|
||||
"@type": "ContactPoint",
|
||||
"contactType": "customer service",
|
||||
"telephone": "",
|
||||
},
|
||||
},
|
||||
"Product": {
|
||||
"@context": "https://schema.org",
|
||||
"@type": "Product",
|
||||
"name": "",
|
||||
"description": "",
|
||||
"brand": {"@type": "Brand", "name": ""},
|
||||
"offers": {
|
||||
"@type": "Offer",
|
||||
"priceCurrency": "CNY",
|
||||
"availability": "https://schema.org/InStock",
|
||||
},
|
||||
},
|
||||
"FAQPage": {
|
||||
"@context": "https://schema.org",
|
||||
"@type": "FAQPage",
|
||||
"mainEntity": [
|
||||
{
|
||||
"@type": "Question",
|
||||
"name": "",
|
||||
"acceptedAnswer": {"@type": "Answer", "text": ""},
|
||||
}
|
||||
],
|
||||
},
|
||||
"Article": {
|
||||
"@context": "https://schema.org",
|
||||
"@type": "Article",
|
||||
"headline": "",
|
||||
"description": "",
|
||||
"author": {"@type": "Organization", "name": ""},
|
||||
"datePublished": "",
|
||||
"image": "",
|
||||
},
|
||||
"LocalBusiness": {
|
||||
"@context": "https://schema.org",
|
||||
"@type": "LocalBusiness",
|
||||
"name": "",
|
||||
"address": {
|
||||
"@type": "PostalAddress",
|
||||
"streetAddress": "",
|
||||
"addressLocality": "",
|
||||
"addressRegion": "",
|
||||
"postalCode": "",
|
||||
"addressCountry": "CN",
|
||||
},
|
||||
"geo": {"@type": "GeoCoordinates", "latitude": "", "longitude": ""},
|
||||
"telephone": "",
|
||||
"openingHours": "",
|
||||
},
|
||||
}
|
||||
|
||||
DIMENSION_SCHEMA_MAP = {
|
||||
"schema_marketing": ["Organization", "LocalBusiness"],
|
||||
"entity_clarity": ["Organization", "Product"],
|
||||
"citation_readiness": ["FAQPage", "Article"],
|
||||
"brand_visibility": ["Organization", "Product"],
|
||||
"local_seo": ["LocalBusiness"],
|
||||
}
|
||||
|
||||
PRIORITY_THRESHOLD = {"high": 30.0, "medium": 60.0}
|
||||
DIFFICULTY_MAP = {
|
||||
"Organization": "easy",
|
||||
"Product": "medium",
|
||||
"FAQPage": "medium",
|
||||
"Article": "easy",
|
||||
"LocalBusiness": "hard",
|
||||
}
|
||||
|
||||
IMPACT_DESCRIPTIONS = {
|
||||
"Organization": "增强品牌实体识别,提升AI搜索引擎对品牌的理解和引用概率",
|
||||
"Product": "提升产品在搜索结果中的富摘要展示,增加点击率和引用率",
|
||||
"FAQPage": "增加FAQ富摘要展示机会,提升在AI回答中的直接引用概率",
|
||||
"Article": "优化文章内容的结构化表达,提升AI搜索引擎的内容理解和引用",
|
||||
"LocalBusiness": "增强本地搜索可见性,提升地理位置相关查询的引用率",
|
||||
}
|
||||
|
||||
|
||||
async def handle_schema_task(task: TaskMessage) -> dict:
|
||||
"""Schema建议任务入口"""
|
||||
input_data = task.input_data
|
||||
brand_id = input_data.get("brand_id")
|
||||
diagnosis_data = input_data.get("diagnosis_data", {})
|
||||
brand_info = input_data.get("brand_info", {})
|
||||
focus_dimensions = input_data.get("focus_dimensions")
|
||||
|
||||
if not brand_id:
|
||||
raise ValueError("input_data必须包含'brand_id'字段")
|
||||
|
||||
# 1. 识别缺失维度
|
||||
missing_dimensions = _identify_missing_dimensions(diagnosis_data, focus_dimensions)
|
||||
|
||||
# 2. 匹配预定义模板
|
||||
matched = _match_templates(missing_dimensions)
|
||||
|
||||
# 3. LLM填充
|
||||
filled = await _fill_with_llm(matched, brand_info)
|
||||
|
||||
# 4. 验证和排序
|
||||
validated = _validate_and_sort(filled)
|
||||
|
||||
return {
|
||||
"brand_id": brand_id,
|
||||
"suggestions": validated,
|
||||
"total": len(validated),
|
||||
}
|
||||
|
||||
|
||||
def _identify_missing_dimensions(
|
||||
diagnosis_data: dict,
|
||||
focus_dimensions: list[str] | None = None,
|
||||
) -> list[dict]:
|
||||
"""识别Schema缺失维度"""
|
||||
dimensions = []
|
||||
dimension_scores = diagnosis_data.get("dimensions", {})
|
||||
for dim_name, dim_info in dimension_scores.items():
|
||||
if dim_name not in DIMENSION_SCHEMA_MAP:
|
||||
continue
|
||||
if focus_dimensions and dim_name not in focus_dimensions:
|
||||
continue
|
||||
score = dim_info.get("score", 0) if isinstance(dim_info, dict) else dim_info
|
||||
max_score = dim_info.get("max_score", 100) if isinstance(dim_info, dict) else 100
|
||||
percentage = (score / max_score * 100) if max_score > 0 else 0
|
||||
if percentage < 80:
|
||||
dimensions.append({
|
||||
"dimension": dim_name,
|
||||
"current_score": round(score, 2),
|
||||
"max_score": max_score,
|
||||
"percentage": round(percentage, 2),
|
||||
})
|
||||
if not dimensions and diagnosis_data:
|
||||
overall = diagnosis_data.get("overall_score", 0)
|
||||
if overall < 80:
|
||||
for dim_name in DIMENSION_SCHEMA_MAP:
|
||||
if focus_dimensions and dim_name not in focus_dimensions:
|
||||
continue
|
||||
dimensions.append({
|
||||
"dimension": dim_name,
|
||||
"current_score": 0,
|
||||
"max_score": 100,
|
||||
"percentage": 0,
|
||||
})
|
||||
return dimensions
|
||||
|
||||
|
||||
def _match_templates(missing_dimensions: list[dict]) -> list[dict]:
|
||||
"""匹配预定义Schema模板"""
|
||||
matched = []
|
||||
seen_types = set()
|
||||
for dim in missing_dimensions:
|
||||
schema_types = DIMENSION_SCHEMA_MAP.get(dim["dimension"], [])
|
||||
for schema_type in schema_types:
|
||||
if schema_type in seen_types:
|
||||
continue
|
||||
seen_types.add(schema_type)
|
||||
template = SCHEMA_TEMPLATES.get(schema_type)
|
||||
if template:
|
||||
percentage = dim["percentage"]
|
||||
if percentage < PRIORITY_THRESHOLD["high"]:
|
||||
priority = "high"
|
||||
elif percentage < PRIORITY_THRESHOLD["medium"]:
|
||||
priority = "medium"
|
||||
else:
|
||||
priority = "low"
|
||||
matched.append({
|
||||
"schema_type": schema_type,
|
||||
"priority": priority,
|
||||
"diagnosis_dimensions": {
|
||||
"dimension": dim["dimension"],
|
||||
"current_score": dim["current_score"],
|
||||
"max_score": dim["max_score"],
|
||||
"percentage": dim["percentage"],
|
||||
},
|
||||
"json_ld_template": copy.deepcopy(template),
|
||||
"implementation_difficulty": DIFFICULTY_MAP.get(schema_type, "medium"),
|
||||
})
|
||||
return matched
|
||||
|
||||
|
||||
async def _fill_with_llm(matched: list[dict], brand_info: dict) -> list[dict]:
|
||||
"""使用LLM填充Schema模板"""
|
||||
from app.services.llm import LLMFactory, LLMError
|
||||
from app.agent_framework.prompts.schema_advisor import SCHEMA_ADVISOR_TEMPLATE
|
||||
from app.utils.json_extractor import extract_json
|
||||
|
||||
provider = LLMFactory.get_default()
|
||||
results = []
|
||||
for item in matched:
|
||||
schema_type = item["schema_type"]
|
||||
try:
|
||||
variables = {
|
||||
"brand_name": brand_info.get("name", ""),
|
||||
"brand_website": brand_info.get("website", ""),
|
||||
"brand_industry": brand_info.get("industry", ""),
|
||||
"schema_type": schema_type,
|
||||
"diagnosis_data": json.dumps(item.get("diagnosis_dimensions", {}), ensure_ascii=False),
|
||||
"existing_schemas": "无",
|
||||
}
|
||||
messages = SCHEMA_ADVISOR_TEMPLATE.render(variables)
|
||||
response = await provider.chat(messages, temperature=0.3, max_tokens=2048)
|
||||
filled = json.loads(extract_json(response.content))
|
||||
item["json_ld_filled"] = filled
|
||||
item["estimated_impact"] = IMPACT_DESCRIPTIONS.get(
|
||||
schema_type, f"提升{item.get('diagnosis_dimensions', {}).get('dimension', '')}维度的得分和AI引用率"
|
||||
)
|
||||
except (json.JSONDecodeError, LLMError, ValueError) as e:
|
||||
logger.warning(f"LLM填充Schema {schema_type} 失败: {e}")
|
||||
item["json_ld_filled"] = None
|
||||
item["estimated_impact"] = IMPACT_DESCRIPTIONS.get(
|
||||
schema_type, f"提升{item.get('diagnosis_dimensions', {}).get('dimension', '')}维度的得分和AI引用率"
|
||||
)
|
||||
results.append(item)
|
||||
return results
|
||||
|
||||
|
||||
def _validate_json_ld(json_ld: dict) -> dict:
|
||||
"""验证JSON-LD格式"""
|
||||
errors = []
|
||||
warnings = []
|
||||
|
||||
if not json_ld:
|
||||
return {"is_valid": False, "errors": ["JSON-LD为空"], "warnings": []}
|
||||
|
||||
if "@context" not in json_ld:
|
||||
errors.append("缺少@context字段")
|
||||
if "@type" not in json_ld:
|
||||
errors.append("缺少@type字段")
|
||||
if "@context" in json_ld and json_ld["@context"] != "https://schema.org":
|
||||
warnings.append(f"@context值非标准: {json_ld.get('@context')}")
|
||||
if "@type" in json_ld and json_ld["@type"] not in SCHEMA_TEMPLATES:
|
||||
warnings.append(f"@type非推荐类型: {json_ld.get('@type')}")
|
||||
|
||||
try:
|
||||
json.dumps(json_ld)
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
errors.append(f"JSON序列化失败: {e}")
|
||||
|
||||
return {"is_valid": len(errors) == 0, "errors": errors, "warnings": warnings}
|
||||
|
||||
|
||||
def _validate_and_sort(items: list[dict]) -> list[dict]:
|
||||
"""验证并按优先级排序"""
|
||||
validated = []
|
||||
for item in items:
|
||||
json_ld_filled = item.get("json_ld_filled")
|
||||
if json_ld_filled:
|
||||
validation = _validate_json_ld(json_ld_filled)
|
||||
item["validation_errors"] = None if validation["is_valid"] else {
|
||||
"errors": validation["errors"],
|
||||
"warnings": validation["warnings"],
|
||||
}
|
||||
else:
|
||||
item["validation_errors"] = {"errors": ["JSON-LD填充失败"], "warnings": []}
|
||||
validated.append(item)
|
||||
|
||||
priority_order = {"high": 0, "medium": 1, "low": 2}
|
||||
validated.sort(key=lambda x: priority_order.get(x.get("priority", "medium"), 1))
|
||||
return validated
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
"""GEO 业务工具注册 - 统一注册入口"""
|
||||
|
||||
import logging
|
||||
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
|
||||
from app.agent_framework.tools.citation_tools import register_citation_tools
|
||||
from app.agent_framework.tools.content_tools import register_content_tools
|
||||
from app.agent_framework.tools.monitor_tools import register_monitor_tools
|
||||
from app.agent_framework.tools.schema_tools import register_schema_tools
|
||||
from app.agent_framework.tools.competitor_tools import register_competitor_tools
|
||||
from app.agent_framework.tools.trend_tools import register_trend_tools
|
||||
from app.agent_framework.tools.knowledge_tools import register_knowledge_tools
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_REGISTRY: ToolRegistry | None = None
|
||||
|
||||
|
||||
def get_tool_registry() -> ToolRegistry:
|
||||
"""获取全局 ToolRegistry(懒初始化)"""
|
||||
global _REGISTRY
|
||||
if _REGISTRY is None:
|
||||
_REGISTRY = ToolRegistry()
|
||||
register_all_tools(_REGISTRY)
|
||||
return _REGISTRY
|
||||
|
||||
|
||||
def register_all_tools(registry: ToolRegistry) -> None:
|
||||
"""注册所有业务工具"""
|
||||
register_citation_tools(registry)
|
||||
register_content_tools(registry)
|
||||
register_monitor_tools(registry)
|
||||
register_schema_tools(registry)
|
||||
register_competitor_tools(registry)
|
||||
register_trend_tools(registry)
|
||||
register_knowledge_tools(registry)
|
||||
logger.info("All GEO business tools registered")
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
"""Citation 业务工具 - 将引用检测服务注册为 FunctionTool"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agentkit.tools.function_tool import FunctionTool
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def execute_single_platform(
|
||||
keyword: str,
|
||||
platform: str,
|
||||
target_brand: str,
|
||||
brand_aliases: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""调用平台执行单次引用检测"""
|
||||
from app.services.ai_engine.platform_bridge import execute_single_platform as _exec
|
||||
|
||||
return await _exec(
|
||||
keyword=keyword,
|
||||
platform=platform,
|
||||
target_brand=target_brand,
|
||||
brand_aliases=brand_aliases or [],
|
||||
)
|
||||
|
||||
|
||||
async def get_or_create_task(query_id: str, platform: str) -> dict:
|
||||
"""获取或创建查询任务"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlalchemy import select
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.models.query_task import QueryTask
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
stmt = select(QueryTask).where(
|
||||
QueryTask.query_id == uuid.UUID(query_id),
|
||||
QueryTask.platform == platform,
|
||||
)
|
||||
result = await db.execute(stmt)
|
||||
task_obj = result.scalar_one_or_none()
|
||||
|
||||
if not task_obj:
|
||||
task_obj = QueryTask(
|
||||
query_id=uuid.UUID(query_id),
|
||||
platform=platform,
|
||||
status="pending",
|
||||
)
|
||||
db.add(task_obj)
|
||||
await db.commit()
|
||||
await db.refresh(task_obj)
|
||||
|
||||
return {"id": str(task_obj.id), "platform": task_obj.platform, "status": task_obj.status}
|
||||
|
||||
|
||||
def register_citation_tools(registry: ToolRegistry) -> None:
|
||||
"""注册所有引用检测相关工具"""
|
||||
registry.register(
|
||||
FunctionTool(
|
||||
name="execute_single_platform",
|
||||
description="在指定AI平台执行引用检测,返回引用结果",
|
||||
func=execute_single_platform,
|
||||
tags=["citation", "detection"],
|
||||
)
|
||||
)
|
||||
registry.register(
|
||||
FunctionTool(
|
||||
name="get_or_create_task",
|
||||
description="获取或创建引用检测的查询任务记录",
|
||||
func=get_or_create_task,
|
||||
tags=["citation", "task"],
|
||||
)
|
||||
)
|
||||
logger.info("Citation tools registered")
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
"""Competitor 业务工具 - 将竞品分析服务注册为 FunctionTool"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agentkit.tools.function_tool import FunctionTool
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def competitor_analyze(
|
||||
brand_id: str,
|
||||
analysis_types: list[str] | None = None,
|
||||
period_days: int = 30,
|
||||
) -> dict:
|
||||
"""执行竞品策略分析"""
|
||||
from app.services.competitor.competitor_analyzer_service import CompetitorAnalyzerService
|
||||
|
||||
service = CompetitorAnalyzerService()
|
||||
result = await service.analyze_competitor(
|
||||
brand_id=brand_id,
|
||||
analysis_types=analysis_types,
|
||||
period_days=period_days,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def competitor_gap_analysis(
|
||||
brand_id: str,
|
||||
period_days: int = 30,
|
||||
) -> dict:
|
||||
"""执行竞品差距分析"""
|
||||
from app.services.competitor.competitor_analyzer_service import CompetitorAnalyzerService
|
||||
|
||||
service = CompetitorAnalyzerService()
|
||||
result = await service.analyze_competitor(
|
||||
brand_id=brand_id,
|
||||
analysis_types=["citation_gap", "platform_coverage", "query_overlap"],
|
||||
period_days=period_days,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def register_competitor_tools(registry: ToolRegistry) -> None:
|
||||
"""注册所有竞品分析相关工具"""
|
||||
registry.register(
|
||||
FunctionTool(
|
||||
name="competitor_analyze",
|
||||
description="执行竞品策略分析,对比品牌与竞品的引用数据",
|
||||
func=competitor_analyze,
|
||||
tags=["competitor", "analysis"],
|
||||
)
|
||||
)
|
||||
registry.register(
|
||||
FunctionTool(
|
||||
name="competitor_gap_analysis",
|
||||
description="执行竞品差距分析,识别差距领域和机会点",
|
||||
func=competitor_gap_analysis,
|
||||
tags=["competitor", "gap"],
|
||||
)
|
||||
)
|
||||
logger.info("Competitor tools registered")
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
"""Content 业务工具 - 将内容生成相关服务注册为 FunctionTool"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agentkit.tools.function_tool import FunctionTool
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def retrieve_knowledge(
|
||||
knowledge_base_ids: list[str],
|
||||
query: str,
|
||||
top_k: int = 5,
|
||||
) -> dict:
|
||||
"""从知识库检索相关内容"""
|
||||
if not knowledge_base_ids or not query:
|
||||
return {"content": "暂无相关知识库内容", "sources": []}
|
||||
|
||||
try:
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.services.knowledge.rag_service import RAGService
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
rag = RAGService()
|
||||
results = await rag.search(
|
||||
session=session,
|
||||
query=query,
|
||||
knowledge_base_ids=knowledge_base_ids,
|
||||
top_k=top_k,
|
||||
)
|
||||
if results:
|
||||
content_parts = []
|
||||
sources = []
|
||||
for r in results:
|
||||
title = r.get("document_title", "未知")
|
||||
content_parts.append(f"[来源: {title}]\n{r.get('content', '')}")
|
||||
sources.append(title)
|
||||
return {"content": "\n\n---\n\n".join(content_parts), "sources": sources}
|
||||
except Exception as e:
|
||||
logger.warning(f"RAG检索失败: {e}")
|
||||
|
||||
return {"content": "暂无相关知识库内容", "sources": []}
|
||||
|
||||
|
||||
def register_content_tools(registry: ToolRegistry) -> None:
|
||||
"""注册所有内容生成相关工具"""
|
||||
registry.register(
|
||||
FunctionTool(
|
||||
name="retrieve_knowledge",
|
||||
description="从知识库检索相关内容,用于RAG增强生成",
|
||||
func=retrieve_knowledge,
|
||||
tags=["content", "rag", "knowledge"],
|
||||
)
|
||||
)
|
||||
logger.info("Content tools registered")
|
||||
|
|
@ -0,0 +1,73 @@
|
|||
"""Knowledge 业务工具 - 将知识库服务注册为 FunctionTool"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agentkit.tools.function_tool import FunctionTool
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def search_knowledge(
|
||||
query: str,
|
||||
knowledge_base_ids: list[str],
|
||||
top_k: int = 5,
|
||||
) -> dict:
|
||||
"""从知识库检索相关内容"""
|
||||
if not knowledge_base_ids or not query:
|
||||
return {"content": "暂无相关知识库内容", "sources": []}
|
||||
|
||||
try:
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.services.knowledge.rag_service import RAGService
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
rag = RAGService()
|
||||
results = await rag.search(
|
||||
session=session,
|
||||
query=query,
|
||||
knowledge_base_ids=knowledge_base_ids,
|
||||
top_k=top_k,
|
||||
)
|
||||
if results:
|
||||
content_parts = []
|
||||
sources = []
|
||||
for r in results:
|
||||
title = r.get("document_title", "未知")
|
||||
content_parts.append(f"[来源: {title}]\n{r.get('content', '')}")
|
||||
sources.append(title)
|
||||
return {"content": "\n\n---\n\n".join(content_parts), "sources": sources}
|
||||
except Exception as e:
|
||||
logger.warning(f"知识库检索失败: {e}")
|
||||
|
||||
return {"content": "暂无相关知识库内容", "sources": []}
|
||||
|
||||
|
||||
async def detect_ai_patterns(content: str, platform_id: str) -> dict:
|
||||
"""检测内容中的AI生成模式"""
|
||||
from app.services.distribution.rule_service import platform_rule_service
|
||||
|
||||
patterns = platform_rule_service.detect_ai_patterns(content, platform_id)
|
||||
return {"patterns": patterns, "count": len(patterns)}
|
||||
|
||||
|
||||
def register_knowledge_tools(registry: ToolRegistry) -> None:
|
||||
"""注册所有知识库相关工具"""
|
||||
registry.register(
|
||||
FunctionTool(
|
||||
name="search_knowledge",
|
||||
description="从知识库检索相关内容",
|
||||
func=search_knowledge,
|
||||
tags=["knowledge", "rag"],
|
||||
)
|
||||
)
|
||||
registry.register(
|
||||
FunctionTool(
|
||||
name="detect_ai_patterns",
|
||||
description="检测内容中的AI生成模式",
|
||||
func=detect_ai_patterns,
|
||||
tags=["knowledge", "deai"],
|
||||
)
|
||||
)
|
||||
logger.info("Knowledge tools registered")
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
"""Monitor 业务工具 - 将效果追踪服务注册为 FunctionTool"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agentkit.tools.function_tool import FunctionTool
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def monitor_check_and_compare(record_id: str) -> dict:
|
||||
"""检测并对比监测记录的变化"""
|
||||
import uuid
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.services.monitoring.monitor_service import MonitorService
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
service = MonitorService()
|
||||
updated_record = await service.check_and_compare(db, uuid.UUID(record_id))
|
||||
if updated_record:
|
||||
return {
|
||||
"id": str(updated_record.id),
|
||||
"change_type": updated_record.change_type,
|
||||
"updated": True,
|
||||
}
|
||||
return {"id": record_id, "updated": False}
|
||||
|
||||
|
||||
async def monitor_generate_report(record_id: str) -> dict:
|
||||
"""生成监测变化报告"""
|
||||
import uuid
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.services.monitoring.monitor_service import MonitorService
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
service = MonitorService()
|
||||
report = await service.generate_change_report(db, uuid.UUID(record_id))
|
||||
return {"report": report} if report else {"report": None}
|
||||
|
||||
|
||||
async def monitor_create_record(
|
||||
brand_id: str,
|
||||
query_keywords: str | None = None,
|
||||
platform: str | None = None,
|
||||
check_interval_hours: int = 24,
|
||||
) -> dict:
|
||||
"""创建监测记录"""
|
||||
import uuid
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.services.monitoring.monitor_service import MonitorService
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
service = MonitorService()
|
||||
record = await service.create_monitoring_record(
|
||||
db=db,
|
||||
brand_id=uuid.UUID(brand_id),
|
||||
query_keywords=query_keywords,
|
||||
platform=platform,
|
||||
check_interval_hours=check_interval_hours,
|
||||
)
|
||||
return {"id": str(record.id), "status": record.status}
|
||||
|
||||
|
||||
def register_monitor_tools(registry: ToolRegistry) -> None:
|
||||
"""注册所有效果追踪相关工具"""
|
||||
registry.register(
|
||||
FunctionTool(
|
||||
name="monitor_check_and_compare",
|
||||
description="检测并对比监测记录的变化",
|
||||
func=monitor_check_and_compare,
|
||||
tags=["monitor", "tracking"],
|
||||
)
|
||||
)
|
||||
registry.register(
|
||||
FunctionTool(
|
||||
name="monitor_generate_report",
|
||||
description="生成监测变化报告",
|
||||
func=monitor_generate_report,
|
||||
tags=["monitor", "report"],
|
||||
)
|
||||
)
|
||||
registry.register(
|
||||
FunctionTool(
|
||||
name="monitor_create_record",
|
||||
description="创建新的监测记录",
|
||||
func=monitor_create_record,
|
||||
tags=["monitor", "record"],
|
||||
)
|
||||
)
|
||||
logger.info("Monitor tools registered")
|
||||
|
|
@ -0,0 +1,176 @@
|
|||
"""Schema 业务工具 - 将Schema建议服务注册为 FunctionTool"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agentkit.tools.function_tool import FunctionTool
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SCHEMA_TEMPLATES = {
|
||||
"Organization": {
|
||||
"@context": "https://schema.org",
|
||||
"@type": "Organization",
|
||||
"name": "",
|
||||
"description": "",
|
||||
"url": "",
|
||||
"logo": "",
|
||||
"sameAs": [],
|
||||
"contactPoint": {
|
||||
"@type": "ContactPoint",
|
||||
"contactType": "customer service",
|
||||
"telephone": "",
|
||||
},
|
||||
},
|
||||
"Product": {
|
||||
"@context": "https://schema.org",
|
||||
"@type": "Product",
|
||||
"name": "",
|
||||
"description": "",
|
||||
"brand": {"@type": "Brand", "name": ""},
|
||||
"offers": {
|
||||
"@type": "Offer",
|
||||
"priceCurrency": "CNY",
|
||||
"availability": "https://schema.org/InStock",
|
||||
},
|
||||
},
|
||||
"FAQPage": {
|
||||
"@context": "https://schema.org",
|
||||
"@type": "FAQPage",
|
||||
"mainEntity": [
|
||||
{
|
||||
"@type": "Question",
|
||||
"name": "",
|
||||
"acceptedAnswer": {"@type": "Answer", "text": ""},
|
||||
}
|
||||
],
|
||||
},
|
||||
"Article": {
|
||||
"@context": "https://schema.org",
|
||||
"@type": "Article",
|
||||
"headline": "",
|
||||
"description": "",
|
||||
"author": {"@type": "Organization", "name": ""},
|
||||
"datePublished": "",
|
||||
"image": "",
|
||||
},
|
||||
"LocalBusiness": {
|
||||
"@context": "https://schema.org",
|
||||
"@type": "LocalBusiness",
|
||||
"name": "",
|
||||
"address": {
|
||||
"@type": "PostalAddress",
|
||||
"streetAddress": "",
|
||||
"addressLocality": "",
|
||||
"addressRegion": "",
|
||||
"postalCode": "",
|
||||
"addressCountry": "CN",
|
||||
},
|
||||
"geo": {"@type": "GeoCoordinates", "latitude": "", "longitude": ""},
|
||||
"telephone": "",
|
||||
"openingHours": "",
|
||||
},
|
||||
}
|
||||
|
||||
DIMENSION_SCHEMA_MAP = {
|
||||
"schema_marketing": ["Organization", "LocalBusiness"],
|
||||
"entity_clarity": ["Organization", "Product"],
|
||||
"citation_readiness": ["FAQPage", "Article"],
|
||||
"brand_visibility": ["Organization", "Product"],
|
||||
"local_seo": ["LocalBusiness"],
|
||||
}
|
||||
|
||||
PRIORITY_THRESHOLD = {"high": 30.0, "medium": 60.0}
|
||||
DIFFICULTY_MAP = {
|
||||
"Organization": "easy",
|
||||
"Product": "medium",
|
||||
"FAQPage": "medium",
|
||||
"Article": "easy",
|
||||
"LocalBusiness": "hard",
|
||||
}
|
||||
|
||||
|
||||
async def fill_schema_with_llm(
|
||||
schema_type: str,
|
||||
brand_info: dict | None = None,
|
||||
diagnosis_dimensions: dict | None = None,
|
||||
) -> dict:
|
||||
"""使用LLM填充Schema模板"""
|
||||
from app.services.llm import LLMFactory
|
||||
from app.agent_framework.prompts.schema_advisor import SCHEMA_ADVISOR_TEMPLATE
|
||||
from app.utils.json_extractor import extract_json
|
||||
|
||||
brand_info = brand_info or {}
|
||||
diagnosis_dimensions = diagnosis_dimensions or {}
|
||||
|
||||
template = SCHEMA_TEMPLATES.get(schema_type)
|
||||
if not template:
|
||||
return {"schema_type": schema_type, "json_ld_filled": None, "error": "Unknown schema type"}
|
||||
|
||||
provider = LLMFactory.get_default()
|
||||
variables = {
|
||||
"brand_name": brand_info.get("name", ""),
|
||||
"brand_website": brand_info.get("website", ""),
|
||||
"brand_industry": brand_info.get("industry", ""),
|
||||
"schema_type": schema_type,
|
||||
"diagnosis_data": json.dumps(diagnosis_dimensions, ensure_ascii=False),
|
||||
"existing_schemas": "无",
|
||||
}
|
||||
messages = SCHEMA_ADVISOR_TEMPLATE.render(variables)
|
||||
|
||||
try:
|
||||
response = await provider.chat(messages, temperature=0.3, max_tokens=2048)
|
||||
filled = json.loads(extract_json(response.content))
|
||||
return {"schema_type": schema_type, "json_ld_filled": filled}
|
||||
except Exception as e:
|
||||
logger.warning(f"LLM填充Schema {schema_type} 失败: {e}")
|
||||
return {"schema_type": schema_type, "json_ld_filled": None, "error": str(e)}
|
||||
|
||||
|
||||
async def identify_missing_dimensions(
|
||||
diagnosis_data: dict,
|
||||
focus_dimensions: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""识别Schema缺失维度"""
|
||||
dimensions = []
|
||||
dimension_scores = diagnosis_data.get("dimensions", {})
|
||||
for dim_name, dim_info in dimension_scores.items():
|
||||
if dim_name not in DIMENSION_SCHEMA_MAP:
|
||||
continue
|
||||
if focus_dimensions and dim_name not in focus_dimensions:
|
||||
continue
|
||||
score = dim_info.get("score", 0) if isinstance(dim_info, dict) else dim_info
|
||||
max_score = dim_info.get("max_score", 100) if isinstance(dim_info, dict) else 100
|
||||
percentage = (score / max_score * 100) if max_score > 0 else 0
|
||||
if percentage < 80:
|
||||
dimensions.append({
|
||||
"dimension": dim_name,
|
||||
"current_score": round(score, 2),
|
||||
"max_score": max_score,
|
||||
"percentage": round(percentage, 2),
|
||||
})
|
||||
return {"missing_dimensions": dimensions}
|
||||
|
||||
|
||||
def register_schema_tools(registry: ToolRegistry) -> None:
|
||||
"""注册所有Schema建议相关工具"""
|
||||
registry.register(
|
||||
FunctionTool(
|
||||
name="fill_schema_with_llm",
|
||||
description="使用LLM填充Schema JSON-LD模板",
|
||||
func=fill_schema_with_llm,
|
||||
tags=["schema", "llm"],
|
||||
)
|
||||
)
|
||||
registry.register(
|
||||
FunctionTool(
|
||||
name="identify_missing_dimensions",
|
||||
description="识别诊断数据中的Schema缺失维度",
|
||||
func=identify_missing_dimensions,
|
||||
tags=["schema", "diagnosis"],
|
||||
)
|
||||
)
|
||||
logger.info("Schema tools registered")
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
"""Trend 业务工具 - 将趋势分析服务注册为 FunctionTool"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agentkit.tools.function_tool import FunctionTool
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def trend_insight(
|
||||
brand_id: str,
|
||||
days: int = 30,
|
||||
platforms: list[str] | None = None,
|
||||
keywords: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""执行趋势洞察分析"""
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.services.trend.trend_analyzer_service import TrendAnalyzerService
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
service = TrendAnalyzerService(db)
|
||||
result = await service.analyze_trends(
|
||||
brand_id=brand_id,
|
||||
days=days,
|
||||
platforms=platforms,
|
||||
keywords=keywords,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def trend_hotspot(
|
||||
brand_id: str,
|
||||
days: int = 30,
|
||||
) -> dict:
|
||||
"""检测引用量突增的热点话题"""
|
||||
from app.database import AsyncSessionLocal
|
||||
from app.services.trend.trend_analyzer_service import TrendAnalyzerService
|
||||
|
||||
async with AsyncSessionLocal() as db:
|
||||
service = TrendAnalyzerService(db)
|
||||
result = await service.get_hotspots(
|
||||
brand_id=brand_id,
|
||||
days=days,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def register_trend_tools(registry: ToolRegistry) -> None:
|
||||
"""注册所有趋势分析相关工具"""
|
||||
registry.register(
|
||||
FunctionTool(
|
||||
name="trend_insight",
|
||||
description="分析品牌引用趋势,推断变化原因",
|
||||
func=trend_insight,
|
||||
tags=["trend", "insight"],
|
||||
)
|
||||
)
|
||||
registry.register(
|
||||
FunctionTool(
|
||||
name="trend_hotspot",
|
||||
description="检测引用量突增的热点话题",
|
||||
func=trend_hotspot,
|
||||
tags=["trend", "hotspot"],
|
||||
)
|
||||
)
|
||||
logger.info("Trend tools registered")
|
||||
|
|
@ -57,7 +57,7 @@ async def _get_brand_diagnosis_data(
|
|||
from app.services.analysis.sentiment_service import get_sentiment_service
|
||||
|
||||
queries_stmt = select(QueryModel).where(
|
||||
QueryModel.user_id == user_id,
|
||||
QueryModel.user_id == str(user_id),
|
||||
QueryModel.target_brand == brand.name,
|
||||
)
|
||||
queries_result = await db.execute(queries_stmt)
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ async def _get_citations_for_brand(
|
|||
|
||||
# Find queries that match this brand
|
||||
queries_stmt = select(QueryModel).where(
|
||||
QueryModel.user_id == user_id,
|
||||
QueryModel.user_id == str(user_id),
|
||||
QueryModel.target_brand == brand.name,
|
||||
)
|
||||
queries_result = await db.execute(queries_stmt)
|
||||
|
|
@ -215,7 +215,7 @@ async def _calculate_trend(
|
|||
|
||||
# Get queries for this brand
|
||||
queries_stmt = select(QueryModel).where(
|
||||
QueryModel.user_id == user_id,
|
||||
QueryModel.user_id == str(user_id),
|
||||
QueryModel.target_brand == brand.name,
|
||||
)
|
||||
queries_result = await db.execute(queries_stmt)
|
||||
|
|
@ -266,7 +266,7 @@ async def _calculate_trend_for_competitor(
|
|||
|
||||
# Get queries for this competitor
|
||||
queries_stmt = select(QueryModel).where(
|
||||
QueryModel.user_id == user_id,
|
||||
QueryModel.user_id == str(user_id),
|
||||
QueryModel.target_brand == competitor_name,
|
||||
)
|
||||
queries_result = await db.execute(queries_stmt)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@ from datetime import datetime
|
|||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.database import get_db
|
||||
|
|
@ -181,11 +180,7 @@ async def generate_geo_plan_endpoint(
|
|||
await db.commit()
|
||||
await db.refresh(db_plan)
|
||||
|
||||
stmt = (
|
||||
select(GeoPlan)
|
||||
.options(selectinload(GeoPlanAction.plan))
|
||||
.where(GeoPlan.id == db_plan.id)
|
||||
)
|
||||
stmt = select(GeoPlan).where(GeoPlan.id == db_plan.id)
|
||||
result = await db.execute(stmt)
|
||||
db_plan = result.scalar_one()
|
||||
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ async def _get_brand_scoring_data(
|
|||
"""
|
||||
# 获取品牌查询
|
||||
queries_stmt = select(QueryModel).where(
|
||||
QueryModel.user_id == user_id,
|
||||
QueryModel.user_id == str(user_id),
|
||||
QueryModel.target_brand == brand.name,
|
||||
)
|
||||
queries_result = await db.execute(queries_stmt)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, Integer, DateTime, ForeignKey, Index, func, Text
|
||||
from sqlalchemy import String, Integer, Float, DateTime, ForeignKey, Index, func, Text
|
||||
from sqlalchemy import Uuid
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
|
|
@ -207,3 +207,102 @@ class AgentTaskLog(Base):
|
|||
Index("idx_agent_task_logs_agent_id", "agent_id"),
|
||||
Index("idx_agent_task_logs_created_at", "created_at"),
|
||||
)
|
||||
|
||||
|
||||
# ---- fischer-agentkit 扩展表 ----
|
||||
|
||||
|
||||
class EpisodicMemory(Base):
|
||||
"""经验记忆表 - 记录每次任务的输入/输出/效果/反思"""
|
||||
__tablename__ = "episodic_memories"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
)
|
||||
agent_name: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
task_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
input_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
output_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
success: Mapped[bool | None] = mapped_column(default=None, nullable=True)
|
||||
quality_score: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
reflection: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
embedding_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
tags: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_episodic_agent_name", "agent_name"),
|
||||
Index("idx_episodic_task_type", "task_type"),
|
||||
Index("idx_episodic_created_at", "created_at"),
|
||||
)
|
||||
|
||||
|
||||
class EvolutionLog(Base):
|
||||
"""进化日志表 - 记录每次进化变更"""
|
||||
__tablename__ = "evolution_logs"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
)
|
||||
agent_name: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
change_type: Mapped[str] = mapped_column(String(30), nullable=False)
|
||||
before: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
after: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
metrics: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
ab_test_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(20), server_default="pending", nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_evolution_agent_name", "agent_name"),
|
||||
Index("idx_evolution_change_type", "change_type"),
|
||||
Index("idx_evolution_status", "status"),
|
||||
Index("idx_evolution_created_at", "created_at"),
|
||||
)
|
||||
|
||||
|
||||
class ABTestConfig(Base):
|
||||
"""A/B测试配置表"""
|
||||
__tablename__ = "ab_test_configs"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
Uuid(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
)
|
||||
agent_name: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
test_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
variant_a: Mapped[dict] = mapped_column(JSONType, nullable=False)
|
||||
variant_b: Mapped[dict] = mapped_column(JSONType, nullable=False)
|
||||
traffic_split: Mapped[float] = mapped_column(Float, server_default="0.5", nullable=False)
|
||||
status: Mapped[str] = mapped_column(String(20), server_default="running", nullable=False)
|
||||
winner: Mapped[str | None] = mapped_column(String(10), nullable=True)
|
||||
metrics: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
|
||||
started_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
ended_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_ab_test_agent_name", "agent_name"),
|
||||
Index("idx_ab_test_status", "status"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.config import settings
|
||||
from app.models.user import User
|
||||
from app.models.organization import Organization, OrgMember
|
||||
from app.schemas.auth import UserRegister, UpdateProfileRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -75,6 +76,28 @@ async def register_user(db: AsyncSession, user_data: UserRegister) -> User:
|
|||
max_queries=5,
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
|
||||
# Auto-create personal organization
|
||||
org = Organization(
|
||||
id=uuid.UUID(user.id),
|
||||
name=f"{user_data.name}的个人空间",
|
||||
slug=f"user-{user.id[:8]}",
|
||||
plan="free",
|
||||
)
|
||||
db.add(org)
|
||||
await db.flush()
|
||||
|
||||
org_member = OrgMember(
|
||||
organization_id=org.id,
|
||||
user_id=user.id,
|
||||
role="owner",
|
||||
)
|
||||
db.add(org_member)
|
||||
|
||||
# Link user to organization
|
||||
user.organization_id = org.id
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ class BrandScoringDataService:
|
|||
brand: Brand,
|
||||
) -> BrandScoringResult:
|
||||
queries_stmt = select(QueryModel).where(
|
||||
QueryModel.user_id == user_id,
|
||||
QueryModel.user_id == str(user_id),
|
||||
QueryModel.target_brand == brand.name,
|
||||
)
|
||||
queries_result = await db.execute(queries_stmt)
|
||||
|
|
@ -209,7 +209,7 @@ class BrandScoringDataService:
|
|||
return {platform: 0.0 for platform in REQUIRED_PLATFORMS}
|
||||
|
||||
queries_stmt = select(QueryModel).where(
|
||||
QueryModel.user_id == user_id,
|
||||
QueryModel.user_id == str(user_id),
|
||||
QueryModel.target_brand == brand.name,
|
||||
)
|
||||
queries_result = await db.execute(queries_stmt)
|
||||
|
|
@ -264,7 +264,7 @@ class BrandScoringDataService:
|
|||
competitor_names = [c.name for c in competitors]
|
||||
|
||||
competitor_queries_stmt = select(QueryModel).where(
|
||||
QueryModel.user_id == user_id,
|
||||
QueryModel.user_id == str(user_id),
|
||||
QueryModel.target_brand.in_(competitor_names),
|
||||
)
|
||||
competitor_queries_result = await db.execute(competitor_queries_stmt)
|
||||
|
|
|
|||
|
|
@ -34,6 +34,9 @@ python-dotenv
|
|||
# YAML解析
|
||||
pyyaml>=6.0
|
||||
|
||||
# Agent框架(已独立为 fischer-agentkit 项目,代码仍保留在 app/agent_framework/ 中)
|
||||
# fischer-agentkit>=0.1.0
|
||||
|
||||
# 测试依赖
|
||||
pytest>=8.0
|
||||
pytest-asyncio>=0.23.0
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -22,7 +22,7 @@
|
|||
"@radix-ui/react-select": "^2.2.6",
|
||||
"@radix-ui/react-slot": "^1.2.4",
|
||||
"@radix-ui/react-tabs": "^1.1.13",
|
||||
"@sentry/nextjs": "^9.0.0",
|
||||
"@sentry/nextjs": "^9.47.1",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"lucide-react": "^1.8.0",
|
||||
|
|
|
|||
Loading…
Reference in New Issue