feat(U8): GEO agent framework adapter layer - fischer-agentkit integration

- Add 8 YAML configs for agent declarative definition
- Add 3 custom handlers (citation, monitor, schema)
- Add 7 business tool registration modules
- Add adapter.py with DI-based registry and dispatcher
- Add database migration for agentkit tables (episodic_memories, evolution_logs, ab_test_configs)
- Add fischer-agentkit>=0.1.0 dependency
- Refactor agent_framework/__init__.py to re-export from agentkit
- Minor API fixes (schema_advisor, scoring, strategy, suggestions)
- Frontend package updates
This commit is contained in:
chiguyong 2026-06-05 19:08:36 +08:00
parent 903803c09a
commit 3b581b22ba
34 changed files with 5549 additions and 1316 deletions

View File

@ -0,0 +1,87 @@
"""add agentkit extension tables
Revision ID: b001_agentkit_extension
Revises: a79329c23b20
Create Date: 2026-06-04
新增 fischer-agentkit 扩展表
- episodic_memories: 经验记忆
- evolution_logs: 进化日志
- ab_test_configs: A/B测试配置
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
revision = "b001_agentkit_extension"
down_revision = "a79329c23b20"
branch_labels = None
depends_on = None
def upgrade() -> None:
# episodic_memories
op.create_table(
"episodic_memories",
sa.Column("id", sa.Uuid(), nullable=False),
sa.Column("agent_name", sa.String(50), nullable=False),
sa.Column("task_type", sa.String(50), nullable=False),
sa.Column("input_data", postgresql.JSONB(), nullable=True),
sa.Column("output_data", postgresql.JSONB(), nullable=True),
sa.Column("success", sa.Boolean(), nullable=True),
sa.Column("quality_score", sa.Float(), nullable=True),
sa.Column("reflection", sa.Text(), nullable=True),
sa.Column("embedding_id", sa.String(100), nullable=True),
sa.Column("tags", postgresql.JSONB(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("idx_episodic_agent_name", "episodic_memories", ["agent_name"])
op.create_index("idx_episodic_task_type", "episodic_memories", ["task_type"])
op.create_index("idx_episodic_created_at", "episodic_memories", ["created_at"])
# evolution_logs
op.create_table(
"evolution_logs",
sa.Column("id", sa.Uuid(), nullable=False),
sa.Column("agent_name", sa.String(50), nullable=False),
sa.Column("change_type", sa.String(30), nullable=False),
sa.Column("before", postgresql.JSONB(), nullable=True),
sa.Column("after", postgresql.JSONB(), nullable=True),
sa.Column("metrics", postgresql.JSONB(), nullable=True),
sa.Column("ab_test_id", sa.String(100), nullable=True),
sa.Column("status", sa.String(20), server_default="pending", nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("idx_evolution_agent_name", "evolution_logs", ["agent_name"])
op.create_index("idx_evolution_change_type", "evolution_logs", ["change_type"])
op.create_index("idx_evolution_status", "evolution_logs", ["status"])
op.create_index("idx_evolution_created_at", "evolution_logs", ["created_at"])
# ab_test_configs
op.create_table(
"ab_test_configs",
sa.Column("id", sa.Uuid(), nullable=False),
sa.Column("agent_name", sa.String(50), nullable=False),
sa.Column("test_name", sa.String(100), nullable=False),
sa.Column("variant_a", postgresql.JSONB(), nullable=False),
sa.Column("variant_b", postgresql.JSONB(), nullable=False),
sa.Column("traffic_split", sa.Float(), server_default="0.5", nullable=False),
sa.Column("status", sa.String(20), server_default="running", nullable=False),
sa.Column("winner", sa.String(10), nullable=True),
sa.Column("metrics", postgresql.JSONB(), nullable=True),
sa.Column("started_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.Column("ended_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
op.create_index("idx_ab_test_agent_name", "ab_test_configs", ["agent_name"])
op.create_index("idx_ab_test_status", "ab_test_configs", ["status"])
def downgrade() -> None:
op.drop_table("ab_test_configs")
op.drop_table("evolution_logs")
op.drop_table("episodic_memories")

View File

@ -1,5 +1,17 @@
"""GEO AI Agent 框架 - 解耦式 Agent 管理与调度"""
"""GEO AI Agent 框架
架构说明
- 框架核心能力来自 fischer-agentkit pip 依赖
- 旧框架文件base.py, dispatcher.py 保留作为兼容层
- Agent 通过 YAML 配置 + ConfigDrivenAgent 创建
- Agent 类逐步迁移为薄适配层
推荐使用方式
from agentkit.core.config_driven import ConfigDrivenAgent # 新方式
from app.agent_framework import get_agent_registry # 适配入口
"""
# ---- 旧框架兼容导出(保持现有代码不 break----
from app.agent_framework.base import BaseAgent
from app.agent_framework.config_manager import AgentConfigManager
from app.agent_framework.dispatcher import TaskDispatcher
@ -28,13 +40,25 @@ from app.agent_framework.protocol import (
)
from app.agent_framework.registry import AgentRegistry
# ---- agentkit 框架导出(新方式)----
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.registry import ToolRegistry
# ---- 业务适配入口 ----
from app.agent_framework.adapter import (
create_agents_from_configs,
get_agent_registry,
get_task_dispatcher,
)
__all__ = [
# Core
# Core (旧框架兼容)
"BaseAgent",
"AgentRegistry",
"TaskDispatcher",
"AgentConfigManager",
# Protocol
# Protocol (旧框架兼容)
"AgentCapability",
"AgentType",
"AgentStatus",
@ -42,7 +66,7 @@ __all__ = [
"TaskProgress",
"TaskResult",
"TaskStatus",
# Exceptions
# Exceptions (旧框架兼容)
"AgentFrameworkError",
"AgentNotFoundError",
"AgentAlreadyRegisteredError",
@ -55,4 +79,13 @@ __all__ = [
"TaskCancelledError",
"NoAvailableAgentError",
"ConfigValidationError",
# agentkit 新方式
"AgentConfig",
"ConfigDrivenAgent",
"FunctionTool",
"ToolRegistry",
# 业务适配入口
"create_agents_from_configs",
"get_agent_registry",
"get_task_dispatcher",
]

View File

@ -0,0 +1,175 @@
"""GEO Agent 适配层 - 基于 fischer-agentkit 的新架构
职责
1. YAML 配置创建 ConfigDrivenAgent
2. 提供依赖注入的 Registry Dispatcher
3. 兼容旧版 Agent 类的导入路径
使用方式
# 新方式(推荐)
from app.agent_framework import get_agent_registry, get_task_dispatcher
registry = get_agent_registry()
dispatcher = get_task_dispatcher()
# 旧方式(兼容)
from app.agent_framework.agents import CitationDetectorAgent
agent = CitationDetectorAgent()
"""
import logging
from pathlib import Path
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
from agentkit.core.registry import AgentRegistry as AgentkitRegistry
from agentkit.core.dispatcher import TaskDispatcher as AgentkitDispatcher
from agentkit.tools.registry import ToolRegistry
from app.agent_framework.tools import get_tool_registry
logger = logging.getLogger(__name__)
_AGENTKIT_REGISTRY: AgentkitRegistry | None = None
_AGENTKIT_DISPATCHER: AgentkitDispatcher | None = None
def _get_configs_dir() -> Path:
"""获取 Agent YAML 配置目录"""
return Path(__file__).parent / "agents" / "configs"
def _get_llm_client():
"""获取 LLM 客户端"""
from app.services.llm import LLMFactory
return LLMFactory.get_default()
def _get_custom_handlers() -> dict:
"""获取所有自定义 handler"""
from app.agent_framework.agents.custom_handlers.citation_handler import handle_citation_task
from app.agent_framework.agents.custom_handlers.monitor_handler import handle_monitor_task
from app.agent_framework.agents.custom_handlers.schema_handler import handle_schema_task
return {
"app.agent_framework.agents.custom_handlers.citation_handler.handle_citation_task": handle_citation_task,
"app.agent_framework.agents.custom_handlers.monitor_handler.handle_monitor_task": handle_monitor_task,
"app.agent_framework.agents.custom_handlers.schema_handler.handle_schema_task": handle_schema_task,
}
def _get_session_factory():
"""获取数据库会话工厂"""
from app.database import AsyncSessionLocal
return AsyncSessionLocal
def _get_redis_factory():
"""获取 Redis 连接工厂"""
from app.core.redis import get_redis
return get_redis
def _get_agent_model():
"""获取 Agent ORM 模型"""
from app.models.agent import AgentRegistry as AgentRegistryModel
return AgentRegistryModel
def _get_task_model():
"""获取 Task ORM 模型"""
from app.models.agent import AgentTask as AgentTaskModel
return AgentTaskModel
def _get_task_log_model():
"""获取 TaskLog ORM 模型"""
from app.models.agent import AgentTaskLog as AgentTaskLogModel
return AgentTaskLogModel
def create_agents_from_configs() -> list[ConfigDrivenAgent]:
"""从 YAML 配置目录创建所有 Agent新架构"""
configs_dir = _get_configs_dir()
tool_registry = get_tool_registry()
llm_client = _get_llm_client()
custom_handlers = _get_custom_handlers()
agents = []
for yaml_file in sorted(configs_dir.glob("*.yaml")):
try:
config = AgentConfig.from_yaml(str(yaml_file))
agent = ConfigDrivenAgent(
config=config,
tool_registry=tool_registry,
llm_client=llm_client,
custom_handlers=custom_handlers,
)
agents.append(agent)
logger.info(f"Created agent '{config.name}' from {yaml_file.name}")
except Exception as e:
logger.error(f"Failed to create agent from {yaml_file.name}: {e}")
return agents
def get_agent_registry() -> AgentkitRegistry:
"""获取 agentkit AgentRegistry懒初始化依赖注入"""
global _AGENTKIT_REGISTRY
if _AGENTKIT_REGISTRY is None:
_AGENTKIT_REGISTRY = AgentkitRegistry(
session_factory=_get_session_factory(),
agent_model=_get_agent_model(),
)
agents = create_agents_from_configs()
for agent in agents:
_AGENTKIT_REGISTRY.register(agent)
logger.info(f"Agentkit AgentRegistry initialized with {len(agents)} agents")
return _AGENTKIT_REGISTRY
def get_task_dispatcher() -> AgentkitDispatcher:
"""获取 agentkit TaskDispatcher懒初始化依赖注入"""
global _AGENTKIT_DISPATCHER
if _AGENTKIT_DISPATCHER is None:
_AGENTKIT_DISPATCHER = AgentkitDispatcher(
redis_factory=_get_redis_factory(),
session_factory=_get_session_factory(),
agent_model=_get_agent_model(),
task_model=_get_task_model(),
task_log_model=_get_task_log_model(),
)
logger.info("Agentkit TaskDispatcher initialized")
return _AGENTKIT_DISPATCHER
def get_legacy_agent(name: str):
"""获取旧版 Agent 实例(兼容层)
旧代码可能直接实例化 Agent 此方法提供兼容性
逐步迁移后可移除
"""
from app.agent_framework.agents import (
CitationDetectorAgent,
CompetitorAnalyzerAgent,
ContentGeneratorAgent,
DeAIAgent,
GEOOptimizerAgent,
MonitorAgent,
SchemaAdvisorAgent,
TrendAgent,
)
legacy_map = {
"citation_detector": CitationDetectorAgent,
"competitor_analyzer": CompetitorAnalyzerAgent,
"content_generator": ContentGeneratorAgent,
"deai_agent": DeAIAgent,
"geo_optimizer": GEOOptimizerAgent,
"monitor": MonitorAgent,
"schema_advisor": SchemaAdvisorAgent,
"trend_agent": TrendAgent,
}
agent_cls = legacy_map.get(name)
if agent_cls:
return agent_cls()
raise ValueError(f"Unknown agent name: {name}")

View File

@ -0,0 +1 @@
"""Agent Configs 包"""

View File

@ -0,0 +1,56 @@
name: citation_detector
agent_type: citation_detection
version: "1.0.0"
description: "AI平台引用检测Agent检测目标品牌在各AI平台回答中的引用情况"
task_mode: custom
supported_tasks:
- citation_detect
- citation_detect_single
max_concurrency: 3
custom_handler: "app.agent_framework.agents.custom_handlers.citation_handler.handle_citation_task"
input_schema:
type: object
properties:
query_id:
type: string
description: 查询IDcitation_detect模式
keyword:
type: string
description: 关键词citation_detect_single模式
platform:
type: string
description: 平台名称citation_detect_single模式
target_brand:
type: string
description: 目标品牌citation_detect_single模式
brand_aliases:
type: array
items:
type: string
description: 品牌别名列表
output_schema:
type: object
properties:
query_id:
type: string
keyword:
type: string
total_records:
type: integer
cited_count:
type: integer
records:
type: array
tools:
- execute_single_platform
- get_or_create_task
memory:
working:
enabled: true
episodic:
enabled: true
track_success: true

View File

@ -0,0 +1,48 @@
name: competitor_analyzer
agent_type: competitor_analysis
version: "1.0.0"
description: "竞品策略分析Agent对比品牌与竞品的引用数据识别差距领域发现机会点生成策略建议"
task_mode: tool_call
supported_tasks:
- competitor_analyze
- competitor_gap_analysis
max_concurrency: 2
input_schema:
type: object
required:
- brand_id
properties:
brand_id:
type: string
description: 品牌ID
analysis_types:
type: array
items:
type: string
description: 分析类型列表
period_days:
type: integer
description: 分析周期(天)
default: 30
output_schema:
type: object
properties:
brand_id:
type: string
analysis:
type: object
recommendations:
type: array
tools:
- competitor_analyze
- competitor_gap_analysis
memory:
working:
enabled: true
episodic:
enabled: true
track_success: true

View File

@ -0,0 +1,97 @@
name: content_generator
agent_type: content_generation
version: "1.0.0"
description: "AI内容生成Agent支持选题推荐和文章生成可结合知识库RAG检索"
task_mode: llm_generate
supported_tasks:
- generate_topics
- generate_article
max_concurrency: 2
input_schema:
type: object
required:
- target_keyword
properties:
target_keyword:
type: string
description: 目标关键词
brand_name:
type: string
description: 品牌名称
brand_description:
type: string
description: 品牌描述
target_platform:
type: string
description: 目标平台
default: "通用"
knowledge_base_ids:
type: array
items:
type: string
description: 知识库ID列表用于RAG检索
topic_title:
type: string
description: 选题标题generate_article时使用
word_count:
type: integer
description: 目标字数
default: 2000
content_style:
type: string
description: 内容风格
default: "专业严谨"
content_angle:
type: string
description: 内容角度
model:
type: string
description: 指定LLM模型
output_schema:
type: object
properties:
topics:
type: array
description: 选题列表
content:
type: string
description: 生成的文章内容
word_count:
type: integer
usage:
type: object
prompt:
identity: "你是一个专业的内容生成助手擅长为品牌创作高质量的SEO/GEO优化内容"
context: "品牌需要通过优质内容提升在AI搜索引擎中的可见性和引用率"
instructions: |
根据用户提供的关键词、品牌信息和知识库内容,生成符合要求的内容。
- generate_topics: 生成选题列表,每个选题包含 title、reason、keywords 字段
- generate_article: 生成完整文章,确保内容专业、结构清晰、关键词自然融入
constraints: |
- 内容必须原创,避免抄袭
- 关键词密度适中,不要堆砌
- 文章结构清晰,段落分明
- 数据和引用需标注来源
output_format: "以 JSON 格式输出generate_topics 返回 {topics: [{title, reason, keywords}]}generate_article 返回 {content, word_count}"
examples: ""
llm:
model: "deepseek"
temperature: 0.7
max_tokens: 4000
tools:
- retrieve_knowledge
memory:
working:
enabled: true
episodic:
enabled: true
track_success: true
semantic:
enabled: true
knowledge_base_ids_field: "knowledge_base_ids"

View File

@ -0,0 +1,76 @@
name: deai_agent
agent_type: deai_processing
version: "1.1.0"
description: "内容去AI化Agent消除AI生成特征使文章更自然流畅"
task_mode: llm_generate
supported_tasks:
- deai_process
max_concurrency: 2
input_schema:
type: object
required:
- content
properties:
content:
type: string
description: 待处理的文章内容
platform:
type: string
description: 目标平台ID如 zhihu, wechat
style:
type: string
description: 目标风格
default: "自然流畅"
preserve_structure:
type: boolean
description: 是否保留原有结构
default: true
output_schema:
type: object
properties:
content:
type: string
description: 处理后的内容
original_word_count:
type: integer
processed_word_count:
type: integer
usage:
type: object
detected_ai_patterns:
type: array
prompt:
identity: "你是一个专业的内容改写专家擅长将AI生成的文本改写为自然、人类化的表达"
context: "平台对AI生成内容的检测越来越严格需要将内容改写为更自然的风格"
instructions: |
对提供的文章内容进行去AI化处理
1. 替换AI常用表达如"总之"、"综上所述"、"首先其次最后"等)
2. 增加口语化表达和个人观点
3. 调整句式结构,避免过于工整的排比
4. 保留核心信息和数据
5. 如有平台特定要求,遵循平台规则
constraints: |
- 保留原文的核心信息和数据
- 不要改变文章的主题和立场
- 保持专业性的同时增加自然感
- 如指定平台,需符合该平台的内容规范
output_format: "返回处理后的完整文章内容"
examples: ""
llm:
model: "deepseek"
temperature: 0.9
max_tokens: 8000
tools:
- detect_ai_patterns
memory:
working:
enabled: true
episodic:
enabled: true
track_success: true

View File

@ -0,0 +1,78 @@
name: geo_optimizer
agent_type: geo_optimization
version: "1.0.0"
description: "GEO/SEO内容优化Agent提升内容在AI搜索引擎中的可见性和引用率"
task_mode: llm_generate
supported_tasks:
- geo_optimize
max_concurrency: 2
input_schema:
type: object
required:
- content
- target_keywords
properties:
content:
type: string
description: 待优化文章
target_keywords:
type: array
items:
type: string
description: 目标关键词列表
target_platform:
type: string
description: 目标平台
default: "通用"
optimization_level:
type: string
enum: [light, moderate, aggressive]
description: 优化级别
default: "moderate"
output_schema:
type: object
properties:
optimized_content:
type: string
seo_score:
type: number
changes:
type: array
items:
type: string
usage:
type: object
prompt:
identity: "你是一个GEO/SEO优化专家擅长优化内容以提升在AI搜索引擎中的可见性"
context: "品牌需要通过内容优化提升在AI搜索结果中的引用率和排名"
instructions: |
对提供的文章进行GEO/SEO优化
1. 自然融入目标关键词
2. 优化标题和段落结构
3. 增加结构化数据标记建议
4. 提升内容的权威性和引用价值
5. 根据optimization_level调整优化力度
constraints: |
- 优化后的内容必须保持原意
- 关键词融入要自然,避免堆砌
- 保持文章可读性
- 不要添加虚假信息
output_format: "以 JSON 格式输出: {optimized_content: string, seo_score: number, changes: [string]}"
examples: ""
llm:
model: "deepseek"
temperature: 0.5
max_tokens: 8000
tools: []
memory:
working:
enabled: true
episodic:
enabled: true
track_success: true

View File

@ -0,0 +1,55 @@
name: monitor
agent_type: performance_tracker
version: "1.0.0"
description: "效果追踪Agent监测品牌引用量、情感、排名变化生成变化报告"
task_mode: custom
supported_tasks:
- monitor_track
- monitor_check_single
max_concurrency: 3
custom_handler: "app.agent_framework.agents.custom_handlers.monitor_handler.handle_monitor_task"
input_schema:
type: object
required:
- brand_id
properties:
brand_id:
type: string
description: 品牌ID
keyword:
type: string
description: 关键词monitor_check_single模式
platform:
type: string
description: 平台名称monitor_check_single模式
check_interval_hours:
type: integer
description: 检测间隔小时数
default: 24
output_schema:
type: object
properties:
brand_id:
type: string
brand_name:
type: string
total_queries:
type: integer
checked_records:
type: integer
reports:
type: array
tools:
- monitor_check_and_compare
- monitor_generate_report
- monitor_create_record
memory:
working:
enabled: true
episodic:
enabled: true
track_success: true

View File

@ -0,0 +1,49 @@
name: schema_advisor
agent_type: schema_advisor
version: "1.0.0"
description: "Schema优化建议Agent识别Schema缺失维度生成JSON-LD结构化数据建议"
task_mode: custom
supported_tasks:
- schema_advise
max_concurrency: 2
custom_handler: "app.agent_framework.agents.custom_handlers.schema_handler.handle_schema_task"
input_schema:
type: object
required:
- brand_id
properties:
brand_id:
type: string
description: 品牌ID
diagnosis_data:
type: object
description: 诊断数据
brand_info:
type: object
description: 品牌信息
focus_dimensions:
type: array
items:
type: string
description: 重点关注维度
output_schema:
type: object
properties:
brand_id:
type: string
suggestions:
type: array
total:
type: integer
tools:
- fill_schema_with_llm
memory:
working:
enabled: true
episodic:
enabled: true
track_success: true

View File

@ -0,0 +1,53 @@
name: trend_agent
agent_type: trend_analysis
version: "1.0.0"
description: "趋势洞察Agent分析品牌引用趋势、识别热点话题、推断变化原因并生成建议"
task_mode: tool_call
supported_tasks:
- trend_insight
- trend_hotspot
max_concurrency: 2
input_schema:
type: object
required:
- brand_id
properties:
brand_id:
type: string
description: 品牌ID
days:
type: integer
description: 分析天数
default: 30
platforms:
type: array
items:
type: string
description: 平台列表
keywords:
type: array
items:
type: string
description: 关键词列表
output_schema:
type: object
properties:
brand_id:
type: string
trends:
type: array
hotspots:
type: array
tools:
- trend_insight
- trend_hotspot
memory:
working:
enabled: true
episodic:
enabled: true
track_success: true

View File

@ -0,0 +1 @@
"""Custom Handlers 包"""

View File

@ -0,0 +1,179 @@
"""CitationDetector 自定义 Handler
处理 citation_detect citation_detect_single 两种任务类型
包含数据库操作和平台调用等复杂业务逻辑
"""
import logging
import uuid
from datetime import datetime, timedelta, timezone
from sqlalchemy import select
from agentkit.core.protocol import TaskMessage
logger = logging.getLogger(__name__)
async def handle_citation_task(task: TaskMessage) -> dict:
"""引用检测任务入口
根据 task_type 路由到不同的处理逻辑
- citation_detect: 全量检测多平台
- citation_detect_single: 单平台检测
"""
if task.task_type == "citation_detect":
return await _execute_full_detect(task)
elif task.task_type == "citation_detect_single":
return await _execute_single_detect(task)
else:
raise ValueError(f"Unsupported task type: {task.task_type}")
async def _execute_full_detect(task: TaskMessage) -> dict:
"""全量引用检测:遍历查询关联的所有平台"""
from app.database import AsyncSessionLocal
from app.models.citation_record import CitationRecord
from app.models.query import Query
from app.models.query_task import QueryTask
from app.services.ai_engine.platform_bridge import execute_single_platform
query_id = task.input_data.get("query_id")
if not query_id:
raise ValueError("input_data must contain 'query_id'")
async with AsyncSessionLocal() as db:
stmt = select(Query).where(Query.id == query_id)
result = await db.execute(stmt)
query = result.scalar_one_or_none()
if not query:
raise ValueError(f"Query {query_id} not found")
records: list[CitationRecord] = []
platforms = query.platforms or ["wenxin", "kimi"]
brand_aliases = query.brand_aliases or []
for i, platform_name in enumerate(platforms):
task_obj = await _get_or_create_task(db, query.id, platform_name)
task_obj.status = "running"
task_obj.started_at = datetime.now(timezone.utc)
task_obj.error_message = None
await db.commit()
try:
detect_result = await execute_single_platform(
keyword=query.keyword,
platform=platform_name,
target_brand=query.target_brand,
brand_aliases=brand_aliases,
)
record = CitationRecord.from_citation_result(
query_id=query.id,
platform=platform_name,
result=detect_result,
)
db.add(record)
records.append(record)
task_obj.status = "success"
task_obj.completed_at = datetime.now(timezone.utc)
await db.commit()
except Exception as e:
logger.error(f"平台 {platform_name} 查询失败: {e}")
task_obj.status = "failed"
task_obj.error_message = str(e)
task_obj.completed_at = datetime.now(timezone.utc)
record = CitationRecord.from_citation_result(
query_id=query.id,
platform=platform_name,
result={"cited": False, "raw_response": str(e)},
)
db.add(record)
records.append(record)
await db.commit()
query.last_queried_at = datetime.now(timezone.utc)
query.next_query_at = _calculate_next_query_at(query.frequency)
await db.commit()
record_summaries = [
{
"id": str(r.id),
"platform": r.platform,
"cited": r.cited,
"confidence": r.confidence,
"match_type": r.match_type,
}
for r in records
]
return {
"query_id": str(query_id),
"keyword": query.keyword,
"total_records": len(records),
"cited_count": sum(1 for r in records if r.cited),
"records": record_summaries,
}
async def _execute_single_detect(task: TaskMessage) -> dict:
"""单平台引用检测"""
from app.services.ai_engine.platform_bridge import execute_single_platform
keyword = task.input_data.get("keyword")
platform = task.input_data.get("platform")
target_brand = task.input_data.get("target_brand")
brand_aliases = task.input_data.get("brand_aliases", [])
if not all([keyword, platform, target_brand]):
raise ValueError("input_data must contain 'keyword', 'platform', 'target_brand'")
result = await execute_single_platform(
keyword=keyword,
platform=platform,
target_brand=target_brand,
brand_aliases=brand_aliases,
)
output = {k: v for k, v in result.items() if k != "raw_response"}
return output
async def _get_or_create_task(db, query_id: uuid.UUID, platform: str):
"""获取或创建查询任务"""
from app.models.query_task import QueryTask
stmt = select(QueryTask).where(
QueryTask.query_id == query_id,
QueryTask.platform == platform,
)
result = await db.execute(stmt)
task_obj = result.scalar_one_or_none()
if not task_obj:
task_obj = QueryTask(
query_id=query_id,
platform=platform,
status="pending",
)
db.add(task_obj)
await db.commit()
await db.refresh(task_obj)
return task_obj
def _calculate_next_query_at(frequency: str | None) -> datetime:
"""计算下次查询时间"""
now = datetime.now(timezone.utc)
freq_map = {
"daily": timedelta(days=1),
"weekly": timedelta(days=7),
"monthly": timedelta(days=30),
}
delta = freq_map.get(frequency or "weekly", timedelta(days=7))
return now + delta

View File

@ -0,0 +1,133 @@
"""Monitor 自定义 Handler
处理 monitor_track monitor_check_single 两种任务类型
包含数据库操作和监测服务调用等复杂业务逻辑
"""
import logging
import uuid
from datetime import datetime, timezone
from sqlalchemy import select
from agentkit.core.protocol import TaskMessage
logger = logging.getLogger(__name__)
async def handle_monitor_task(task: TaskMessage) -> dict:
"""效果追踪任务入口
根据 task_type 路由到不同的处理逻辑
- monitor_track: 品牌全量效果追踪
- monitor_check_single: 单关键词检测
"""
if task.task_type == "monitor_track":
return await _monitor_track(task)
elif task.task_type == "monitor_check_single":
return await _monitor_check_single(task)
else:
raise ValueError(f"不支持的任务类型: {task.task_type}")
async def _monitor_track(task: TaskMessage) -> dict:
"""品牌全量效果追踪"""
from app.database import AsyncSessionLocal
from app.models.brand import Brand
from app.models.monitoring import MonitoringRecord
from app.models.query import Query
from app.services.monitoring.monitor_service import MonitorService
input_data = task.input_data
brand_id = input_data.get("brand_id")
if not brand_id:
raise ValueError("input_data必须包含'brand_id'字段")
brand_id = uuid.UUID(brand_id)
async with AsyncSessionLocal() as db:
stmt = select(Brand).where(Brand.id == brand_id)
result = await db.execute(stmt)
brand = result.scalar_one_or_none()
if not brand:
raise ValueError(f"品牌不存在: {brand_id}")
queries_stmt = select(Query).where(Query.target_brand == brand.name)
queries_result = await db.execute(queries_stmt)
queries = list(queries_result.scalars().all())
if not queries:
return {
"brand_id": str(brand_id),
"brand_name": brand.name,
"total_queries": 0,
"reports": [],
}
total_queries = len(queries)
reports = []
service = MonitorService()
monitoring_stmt = select(MonitoringRecord).where(
MonitoringRecord.brand_id == brand_id,
MonitoringRecord.status == "active",
)
monitoring_result = await db.execute(monitoring_stmt)
monitoring_records = list(monitoring_result.scalars().all())
for record in monitoring_records:
updated_record = await service.check_and_compare(db, record.id)
if updated_record:
report = await service.generate_change_report(db, updated_record.id)
if report:
reports.append(report)
return {
"brand_id": str(brand_id),
"brand_name": brand.name,
"total_queries": total_queries,
"checked_records": len(monitoring_records),
"reports": reports,
}
async def _monitor_check_single(task: TaskMessage) -> dict:
"""单关键词检测"""
from app.database import AsyncSessionLocal
from app.services.monitoring.monitor_service import MonitorService
input_data = task.input_data
brand_id = input_data.get("brand_id")
keyword = input_data.get("keyword")
platform = input_data.get("platform")
if not brand_id:
raise ValueError("input_data必须包含'brand_id'字段")
brand_id = uuid.UUID(brand_id)
async with AsyncSessionLocal() as db:
service = MonitorService()
record = await service.create_monitoring_record(
db=db,
brand_id=brand_id,
query_keywords=keyword,
platform=platform,
check_interval_hours=input_data.get("check_interval_hours", 24),
)
updated_record = await service.check_and_compare(db, record.id)
report = None
if updated_record:
report = await service.generate_change_report(db, updated_record.id)
return {
"record_id": str(record.id),
"brand_id": str(brand_id),
"keyword": keyword,
"platform": platform,
"change_type": updated_record.change_type if updated_record else None,
"report": report,
}

View File

@ -0,0 +1,285 @@
"""SchemaAdvisor 自定义 Handler
处理 schema_advise 任务类型
包含维度识别模板匹配LLM填充验证排序等复杂业务逻辑
"""
import copy
import json
import logging
from agentkit.core.protocol import TaskMessage
logger = logging.getLogger(__name__)
SCHEMA_TEMPLATES = {
"Organization": {
"@context": "https://schema.org",
"@type": "Organization",
"name": "",
"description": "",
"url": "",
"logo": "",
"sameAs": [],
"contactPoint": {
"@type": "ContactPoint",
"contactType": "customer service",
"telephone": "",
},
},
"Product": {
"@context": "https://schema.org",
"@type": "Product",
"name": "",
"description": "",
"brand": {"@type": "Brand", "name": ""},
"offers": {
"@type": "Offer",
"priceCurrency": "CNY",
"availability": "https://schema.org/InStock",
},
},
"FAQPage": {
"@context": "https://schema.org",
"@type": "FAQPage",
"mainEntity": [
{
"@type": "Question",
"name": "",
"acceptedAnswer": {"@type": "Answer", "text": ""},
}
],
},
"Article": {
"@context": "https://schema.org",
"@type": "Article",
"headline": "",
"description": "",
"author": {"@type": "Organization", "name": ""},
"datePublished": "",
"image": "",
},
"LocalBusiness": {
"@context": "https://schema.org",
"@type": "LocalBusiness",
"name": "",
"address": {
"@type": "PostalAddress",
"streetAddress": "",
"addressLocality": "",
"addressRegion": "",
"postalCode": "",
"addressCountry": "CN",
},
"geo": {"@type": "GeoCoordinates", "latitude": "", "longitude": ""},
"telephone": "",
"openingHours": "",
},
}
DIMENSION_SCHEMA_MAP = {
"schema_marketing": ["Organization", "LocalBusiness"],
"entity_clarity": ["Organization", "Product"],
"citation_readiness": ["FAQPage", "Article"],
"brand_visibility": ["Organization", "Product"],
"local_seo": ["LocalBusiness"],
}
PRIORITY_THRESHOLD = {"high": 30.0, "medium": 60.0}
DIFFICULTY_MAP = {
"Organization": "easy",
"Product": "medium",
"FAQPage": "medium",
"Article": "easy",
"LocalBusiness": "hard",
}
IMPACT_DESCRIPTIONS = {
"Organization": "增强品牌实体识别提升AI搜索引擎对品牌的理解和引用概率",
"Product": "提升产品在搜索结果中的富摘要展示,增加点击率和引用率",
"FAQPage": "增加FAQ富摘要展示机会提升在AI回答中的直接引用概率",
"Article": "优化文章内容的结构化表达提升AI搜索引擎的内容理解和引用",
"LocalBusiness": "增强本地搜索可见性,提升地理位置相关查询的引用率",
}
async def handle_schema_task(task: TaskMessage) -> dict:
"""Schema建议任务入口"""
input_data = task.input_data
brand_id = input_data.get("brand_id")
diagnosis_data = input_data.get("diagnosis_data", {})
brand_info = input_data.get("brand_info", {})
focus_dimensions = input_data.get("focus_dimensions")
if not brand_id:
raise ValueError("input_data必须包含'brand_id'字段")
# 1. 识别缺失维度
missing_dimensions = _identify_missing_dimensions(diagnosis_data, focus_dimensions)
# 2. 匹配预定义模板
matched = _match_templates(missing_dimensions)
# 3. LLM填充
filled = await _fill_with_llm(matched, brand_info)
# 4. 验证和排序
validated = _validate_and_sort(filled)
return {
"brand_id": brand_id,
"suggestions": validated,
"total": len(validated),
}
def _identify_missing_dimensions(
diagnosis_data: dict,
focus_dimensions: list[str] | None = None,
) -> list[dict]:
"""识别Schema缺失维度"""
dimensions = []
dimension_scores = diagnosis_data.get("dimensions", {})
for dim_name, dim_info in dimension_scores.items():
if dim_name not in DIMENSION_SCHEMA_MAP:
continue
if focus_dimensions and dim_name not in focus_dimensions:
continue
score = dim_info.get("score", 0) if isinstance(dim_info, dict) else dim_info
max_score = dim_info.get("max_score", 100) if isinstance(dim_info, dict) else 100
percentage = (score / max_score * 100) if max_score > 0 else 0
if percentage < 80:
dimensions.append({
"dimension": dim_name,
"current_score": round(score, 2),
"max_score": max_score,
"percentage": round(percentage, 2),
})
if not dimensions and diagnosis_data:
overall = diagnosis_data.get("overall_score", 0)
if overall < 80:
for dim_name in DIMENSION_SCHEMA_MAP:
if focus_dimensions and dim_name not in focus_dimensions:
continue
dimensions.append({
"dimension": dim_name,
"current_score": 0,
"max_score": 100,
"percentage": 0,
})
return dimensions
def _match_templates(missing_dimensions: list[dict]) -> list[dict]:
"""匹配预定义Schema模板"""
matched = []
seen_types = set()
for dim in missing_dimensions:
schema_types = DIMENSION_SCHEMA_MAP.get(dim["dimension"], [])
for schema_type in schema_types:
if schema_type in seen_types:
continue
seen_types.add(schema_type)
template = SCHEMA_TEMPLATES.get(schema_type)
if template:
percentage = dim["percentage"]
if percentage < PRIORITY_THRESHOLD["high"]:
priority = "high"
elif percentage < PRIORITY_THRESHOLD["medium"]:
priority = "medium"
else:
priority = "low"
matched.append({
"schema_type": schema_type,
"priority": priority,
"diagnosis_dimensions": {
"dimension": dim["dimension"],
"current_score": dim["current_score"],
"max_score": dim["max_score"],
"percentage": dim["percentage"],
},
"json_ld_template": copy.deepcopy(template),
"implementation_difficulty": DIFFICULTY_MAP.get(schema_type, "medium"),
})
return matched
async def _fill_with_llm(matched: list[dict], brand_info: dict) -> list[dict]:
"""使用LLM填充Schema模板"""
from app.services.llm import LLMFactory, LLMError
from app.agent_framework.prompts.schema_advisor import SCHEMA_ADVISOR_TEMPLATE
from app.utils.json_extractor import extract_json
provider = LLMFactory.get_default()
results = []
for item in matched:
schema_type = item["schema_type"]
try:
variables = {
"brand_name": brand_info.get("name", ""),
"brand_website": brand_info.get("website", ""),
"brand_industry": brand_info.get("industry", ""),
"schema_type": schema_type,
"diagnosis_data": json.dumps(item.get("diagnosis_dimensions", {}), ensure_ascii=False),
"existing_schemas": "",
}
messages = SCHEMA_ADVISOR_TEMPLATE.render(variables)
response = await provider.chat(messages, temperature=0.3, max_tokens=2048)
filled = json.loads(extract_json(response.content))
item["json_ld_filled"] = filled
item["estimated_impact"] = IMPACT_DESCRIPTIONS.get(
schema_type, f"提升{item.get('diagnosis_dimensions', {}).get('dimension', '')}维度的得分和AI引用率"
)
except (json.JSONDecodeError, LLMError, ValueError) as e:
logger.warning(f"LLM填充Schema {schema_type} 失败: {e}")
item["json_ld_filled"] = None
item["estimated_impact"] = IMPACT_DESCRIPTIONS.get(
schema_type, f"提升{item.get('diagnosis_dimensions', {}).get('dimension', '')}维度的得分和AI引用率"
)
results.append(item)
return results
def _validate_json_ld(json_ld: dict) -> dict:
"""验证JSON-LD格式"""
errors = []
warnings = []
if not json_ld:
return {"is_valid": False, "errors": ["JSON-LD为空"], "warnings": []}
if "@context" not in json_ld:
errors.append("缺少@context字段")
if "@type" not in json_ld:
errors.append("缺少@type字段")
if "@context" in json_ld and json_ld["@context"] != "https://schema.org":
warnings.append(f"@context值非标准: {json_ld.get('@context')}")
if "@type" in json_ld and json_ld["@type"] not in SCHEMA_TEMPLATES:
warnings.append(f"@type非推荐类型: {json_ld.get('@type')}")
try:
json.dumps(json_ld)
except (json.JSONDecodeError, TypeError) as e:
errors.append(f"JSON序列化失败: {e}")
return {"is_valid": len(errors) == 0, "errors": errors, "warnings": warnings}
def _validate_and_sort(items: list[dict]) -> list[dict]:
"""验证并按优先级排序"""
validated = []
for item in items:
json_ld_filled = item.get("json_ld_filled")
if json_ld_filled:
validation = _validate_json_ld(json_ld_filled)
item["validation_errors"] = None if validation["is_valid"] else {
"errors": validation["errors"],
"warnings": validation["warnings"],
}
else:
item["validation_errors"] = {"errors": ["JSON-LD填充失败"], "warnings": []}
validated.append(item)
priority_order = {"high": 0, "medium": 1, "low": 2}
validated.sort(key=lambda x: priority_order.get(x.get("priority", "medium"), 1))
return validated

View File

@ -0,0 +1,38 @@
"""GEO 业务工具注册 - 统一注册入口"""
import logging
from agentkit.tools.registry import ToolRegistry
from app.agent_framework.tools.citation_tools import register_citation_tools
from app.agent_framework.tools.content_tools import register_content_tools
from app.agent_framework.tools.monitor_tools import register_monitor_tools
from app.agent_framework.tools.schema_tools import register_schema_tools
from app.agent_framework.tools.competitor_tools import register_competitor_tools
from app.agent_framework.tools.trend_tools import register_trend_tools
from app.agent_framework.tools.knowledge_tools import register_knowledge_tools
logger = logging.getLogger(__name__)
_REGISTRY: ToolRegistry | None = None
def get_tool_registry() -> ToolRegistry:
"""获取全局 ToolRegistry懒初始化"""
global _REGISTRY
if _REGISTRY is None:
_REGISTRY = ToolRegistry()
register_all_tools(_REGISTRY)
return _REGISTRY
def register_all_tools(registry: ToolRegistry) -> None:
"""注册所有业务工具"""
register_citation_tools(registry)
register_content_tools(registry)
register_monitor_tools(registry)
register_schema_tools(registry)
register_competitor_tools(registry)
register_trend_tools(registry)
register_knowledge_tools(registry)
logger.info("All GEO business tools registered")

View File

@ -0,0 +1,76 @@
"""Citation 业务工具 - 将引用检测服务注册为 FunctionTool"""
import logging
from typing import Any
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.registry import ToolRegistry
logger = logging.getLogger(__name__)
async def execute_single_platform(
keyword: str,
platform: str,
target_brand: str,
brand_aliases: list[str] | None = None,
) -> dict:
"""调用平台执行单次引用检测"""
from app.services.ai_engine.platform_bridge import execute_single_platform as _exec
return await _exec(
keyword=keyword,
platform=platform,
target_brand=target_brand,
brand_aliases=brand_aliases or [],
)
async def get_or_create_task(query_id: str, platform: str) -> dict:
"""获取或创建查询任务"""
import uuid
from datetime import datetime
from sqlalchemy import select
from app.database import AsyncSessionLocal
from app.models.query_task import QueryTask
async with AsyncSessionLocal() as db:
stmt = select(QueryTask).where(
QueryTask.query_id == uuid.UUID(query_id),
QueryTask.platform == platform,
)
result = await db.execute(stmt)
task_obj = result.scalar_one_or_none()
if not task_obj:
task_obj = QueryTask(
query_id=uuid.UUID(query_id),
platform=platform,
status="pending",
)
db.add(task_obj)
await db.commit()
await db.refresh(task_obj)
return {"id": str(task_obj.id), "platform": task_obj.platform, "status": task_obj.status}
def register_citation_tools(registry: ToolRegistry) -> None:
"""注册所有引用检测相关工具"""
registry.register(
FunctionTool(
name="execute_single_platform",
description="在指定AI平台执行引用检测返回引用结果",
func=execute_single_platform,
tags=["citation", "detection"],
)
)
registry.register(
FunctionTool(
name="get_or_create_task",
description="获取或创建引用检测的查询任务记录",
func=get_or_create_task,
tags=["citation", "task"],
)
)
logger.info("Citation tools registered")

View File

@ -0,0 +1,63 @@
"""Competitor 业务工具 - 将竞品分析服务注册为 FunctionTool"""
import logging
from typing import Any
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.registry import ToolRegistry
logger = logging.getLogger(__name__)
async def competitor_analyze(
brand_id: str,
analysis_types: list[str] | None = None,
period_days: int = 30,
) -> dict:
"""执行竞品策略分析"""
from app.services.competitor.competitor_analyzer_service import CompetitorAnalyzerService
service = CompetitorAnalyzerService()
result = await service.analyze_competitor(
brand_id=brand_id,
analysis_types=analysis_types,
period_days=period_days,
)
return result
async def competitor_gap_analysis(
brand_id: str,
period_days: int = 30,
) -> dict:
"""执行竞品差距分析"""
from app.services.competitor.competitor_analyzer_service import CompetitorAnalyzerService
service = CompetitorAnalyzerService()
result = await service.analyze_competitor(
brand_id=brand_id,
analysis_types=["citation_gap", "platform_coverage", "query_overlap"],
period_days=period_days,
)
return result
def register_competitor_tools(registry: ToolRegistry) -> None:
"""注册所有竞品分析相关工具"""
registry.register(
FunctionTool(
name="competitor_analyze",
description="执行竞品策略分析,对比品牌与竞品的引用数据",
func=competitor_analyze,
tags=["competitor", "analysis"],
)
)
registry.register(
FunctionTool(
name="competitor_gap_analysis",
description="执行竞品差距分析,识别差距领域和机会点",
func=competitor_gap_analysis,
tags=["competitor", "gap"],
)
)
logger.info("Competitor tools registered")

View File

@ -0,0 +1,57 @@
"""Content 业务工具 - 将内容生成相关服务注册为 FunctionTool"""
import logging
from typing import Any
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.registry import ToolRegistry
logger = logging.getLogger(__name__)
async def retrieve_knowledge(
knowledge_base_ids: list[str],
query: str,
top_k: int = 5,
) -> dict:
"""从知识库检索相关内容"""
if not knowledge_base_ids or not query:
return {"content": "暂无相关知识库内容", "sources": []}
try:
from app.database import AsyncSessionLocal
from app.services.knowledge.rag_service import RAGService
async with AsyncSessionLocal() as session:
rag = RAGService()
results = await rag.search(
session=session,
query=query,
knowledge_base_ids=knowledge_base_ids,
top_k=top_k,
)
if results:
content_parts = []
sources = []
for r in results:
title = r.get("document_title", "未知")
content_parts.append(f"[来源: {title}]\n{r.get('content', '')}")
sources.append(title)
return {"content": "\n\n---\n\n".join(content_parts), "sources": sources}
except Exception as e:
logger.warning(f"RAG检索失败: {e}")
return {"content": "暂无相关知识库内容", "sources": []}
def register_content_tools(registry: ToolRegistry) -> None:
"""注册所有内容生成相关工具"""
registry.register(
FunctionTool(
name="retrieve_knowledge",
description="从知识库检索相关内容用于RAG增强生成",
func=retrieve_knowledge,
tags=["content", "rag", "knowledge"],
)
)
logger.info("Content tools registered")

View File

@ -0,0 +1,73 @@
"""Knowledge 业务工具 - 将知识库服务注册为 FunctionTool"""
import logging
from typing import Any
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.registry import ToolRegistry
logger = logging.getLogger(__name__)
async def search_knowledge(
query: str,
knowledge_base_ids: list[str],
top_k: int = 5,
) -> dict:
"""从知识库检索相关内容"""
if not knowledge_base_ids or not query:
return {"content": "暂无相关知识库内容", "sources": []}
try:
from app.database import AsyncSessionLocal
from app.services.knowledge.rag_service import RAGService
async with AsyncSessionLocal() as session:
rag = RAGService()
results = await rag.search(
session=session,
query=query,
knowledge_base_ids=knowledge_base_ids,
top_k=top_k,
)
if results:
content_parts = []
sources = []
for r in results:
title = r.get("document_title", "未知")
content_parts.append(f"[来源: {title}]\n{r.get('content', '')}")
sources.append(title)
return {"content": "\n\n---\n\n".join(content_parts), "sources": sources}
except Exception as e:
logger.warning(f"知识库检索失败: {e}")
return {"content": "暂无相关知识库内容", "sources": []}
async def detect_ai_patterns(content: str, platform_id: str) -> dict:
"""检测内容中的AI生成模式"""
from app.services.distribution.rule_service import platform_rule_service
patterns = platform_rule_service.detect_ai_patterns(content, platform_id)
return {"patterns": patterns, "count": len(patterns)}
def register_knowledge_tools(registry: ToolRegistry) -> None:
"""注册所有知识库相关工具"""
registry.register(
FunctionTool(
name="search_knowledge",
description="从知识库检索相关内容",
func=search_knowledge,
tags=["knowledge", "rag"],
)
)
registry.register(
FunctionTool(
name="detect_ai_patterns",
description="检测内容中的AI生成模式",
func=detect_ai_patterns,
tags=["knowledge", "deai"],
)
)
logger.info("Knowledge tools registered")

View File

@ -0,0 +1,91 @@
"""Monitor 业务工具 - 将效果追踪服务注册为 FunctionTool"""
import logging
from typing import Any
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.registry import ToolRegistry
logger = logging.getLogger(__name__)
async def monitor_check_and_compare(record_id: str) -> dict:
"""检测并对比监测记录的变化"""
import uuid
from app.database import AsyncSessionLocal
from app.services.monitoring.monitor_service import MonitorService
async with AsyncSessionLocal() as db:
service = MonitorService()
updated_record = await service.check_and_compare(db, uuid.UUID(record_id))
if updated_record:
return {
"id": str(updated_record.id),
"change_type": updated_record.change_type,
"updated": True,
}
return {"id": record_id, "updated": False}
async def monitor_generate_report(record_id: str) -> dict:
"""生成监测变化报告"""
import uuid
from app.database import AsyncSessionLocal
from app.services.monitoring.monitor_service import MonitorService
async with AsyncSessionLocal() as db:
service = MonitorService()
report = await service.generate_change_report(db, uuid.UUID(record_id))
return {"report": report} if report else {"report": None}
async def monitor_create_record(
brand_id: str,
query_keywords: str | None = None,
platform: str | None = None,
check_interval_hours: int = 24,
) -> dict:
"""创建监测记录"""
import uuid
from app.database import AsyncSessionLocal
from app.services.monitoring.monitor_service import MonitorService
async with AsyncSessionLocal() as db:
service = MonitorService()
record = await service.create_monitoring_record(
db=db,
brand_id=uuid.UUID(brand_id),
query_keywords=query_keywords,
platform=platform,
check_interval_hours=check_interval_hours,
)
return {"id": str(record.id), "status": record.status}
def register_monitor_tools(registry: ToolRegistry) -> None:
"""注册所有效果追踪相关工具"""
registry.register(
FunctionTool(
name="monitor_check_and_compare",
description="检测并对比监测记录的变化",
func=monitor_check_and_compare,
tags=["monitor", "tracking"],
)
)
registry.register(
FunctionTool(
name="monitor_generate_report",
description="生成监测变化报告",
func=monitor_generate_report,
tags=["monitor", "report"],
)
)
registry.register(
FunctionTool(
name="monitor_create_record",
description="创建新的监测记录",
func=monitor_create_record,
tags=["monitor", "record"],
)
)
logger.info("Monitor tools registered")

View File

@ -0,0 +1,176 @@
"""Schema 业务工具 - 将Schema建议服务注册为 FunctionTool"""
import copy
import json
import logging
from typing import Any
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.registry import ToolRegistry
logger = logging.getLogger(__name__)
SCHEMA_TEMPLATES = {
"Organization": {
"@context": "https://schema.org",
"@type": "Organization",
"name": "",
"description": "",
"url": "",
"logo": "",
"sameAs": [],
"contactPoint": {
"@type": "ContactPoint",
"contactType": "customer service",
"telephone": "",
},
},
"Product": {
"@context": "https://schema.org",
"@type": "Product",
"name": "",
"description": "",
"brand": {"@type": "Brand", "name": ""},
"offers": {
"@type": "Offer",
"priceCurrency": "CNY",
"availability": "https://schema.org/InStock",
},
},
"FAQPage": {
"@context": "https://schema.org",
"@type": "FAQPage",
"mainEntity": [
{
"@type": "Question",
"name": "",
"acceptedAnswer": {"@type": "Answer", "text": ""},
}
],
},
"Article": {
"@context": "https://schema.org",
"@type": "Article",
"headline": "",
"description": "",
"author": {"@type": "Organization", "name": ""},
"datePublished": "",
"image": "",
},
"LocalBusiness": {
"@context": "https://schema.org",
"@type": "LocalBusiness",
"name": "",
"address": {
"@type": "PostalAddress",
"streetAddress": "",
"addressLocality": "",
"addressRegion": "",
"postalCode": "",
"addressCountry": "CN",
},
"geo": {"@type": "GeoCoordinates", "latitude": "", "longitude": ""},
"telephone": "",
"openingHours": "",
},
}
DIMENSION_SCHEMA_MAP = {
"schema_marketing": ["Organization", "LocalBusiness"],
"entity_clarity": ["Organization", "Product"],
"citation_readiness": ["FAQPage", "Article"],
"brand_visibility": ["Organization", "Product"],
"local_seo": ["LocalBusiness"],
}
PRIORITY_THRESHOLD = {"high": 30.0, "medium": 60.0}
DIFFICULTY_MAP = {
"Organization": "easy",
"Product": "medium",
"FAQPage": "medium",
"Article": "easy",
"LocalBusiness": "hard",
}
async def fill_schema_with_llm(
schema_type: str,
brand_info: dict | None = None,
diagnosis_dimensions: dict | None = None,
) -> dict:
"""使用LLM填充Schema模板"""
from app.services.llm import LLMFactory
from app.agent_framework.prompts.schema_advisor import SCHEMA_ADVISOR_TEMPLATE
from app.utils.json_extractor import extract_json
brand_info = brand_info or {}
diagnosis_dimensions = diagnosis_dimensions or {}
template = SCHEMA_TEMPLATES.get(schema_type)
if not template:
return {"schema_type": schema_type, "json_ld_filled": None, "error": "Unknown schema type"}
provider = LLMFactory.get_default()
variables = {
"brand_name": brand_info.get("name", ""),
"brand_website": brand_info.get("website", ""),
"brand_industry": brand_info.get("industry", ""),
"schema_type": schema_type,
"diagnosis_data": json.dumps(diagnosis_dimensions, ensure_ascii=False),
"existing_schemas": "",
}
messages = SCHEMA_ADVISOR_TEMPLATE.render(variables)
try:
response = await provider.chat(messages, temperature=0.3, max_tokens=2048)
filled = json.loads(extract_json(response.content))
return {"schema_type": schema_type, "json_ld_filled": filled}
except Exception as e:
logger.warning(f"LLM填充Schema {schema_type} 失败: {e}")
return {"schema_type": schema_type, "json_ld_filled": None, "error": str(e)}
async def identify_missing_dimensions(
diagnosis_data: dict,
focus_dimensions: list[str] | None = None,
) -> dict:
"""识别Schema缺失维度"""
dimensions = []
dimension_scores = diagnosis_data.get("dimensions", {})
for dim_name, dim_info in dimension_scores.items():
if dim_name not in DIMENSION_SCHEMA_MAP:
continue
if focus_dimensions and dim_name not in focus_dimensions:
continue
score = dim_info.get("score", 0) if isinstance(dim_info, dict) else dim_info
max_score = dim_info.get("max_score", 100) if isinstance(dim_info, dict) else 100
percentage = (score / max_score * 100) if max_score > 0 else 0
if percentage < 80:
dimensions.append({
"dimension": dim_name,
"current_score": round(score, 2),
"max_score": max_score,
"percentage": round(percentage, 2),
})
return {"missing_dimensions": dimensions}
def register_schema_tools(registry: ToolRegistry) -> None:
"""注册所有Schema建议相关工具"""
registry.register(
FunctionTool(
name="fill_schema_with_llm",
description="使用LLM填充Schema JSON-LD模板",
func=fill_schema_with_llm,
tags=["schema", "llm"],
)
)
registry.register(
FunctionTool(
name="identify_missing_dimensions",
description="识别诊断数据中的Schema缺失维度",
func=identify_missing_dimensions,
tags=["schema", "diagnosis"],
)
)
logger.info("Schema tools registered")

View File

@ -0,0 +1,68 @@
"""Trend 业务工具 - 将趋势分析服务注册为 FunctionTool"""
import logging
from typing import Any
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.registry import ToolRegistry
logger = logging.getLogger(__name__)
async def trend_insight(
brand_id: str,
days: int = 30,
platforms: list[str] | None = None,
keywords: list[str] | None = None,
) -> dict:
"""执行趋势洞察分析"""
from app.database import AsyncSessionLocal
from app.services.trend.trend_analyzer_service import TrendAnalyzerService
async with AsyncSessionLocal() as db:
service = TrendAnalyzerService(db)
result = await service.analyze_trends(
brand_id=brand_id,
days=days,
platforms=platforms,
keywords=keywords,
)
return result
async def trend_hotspot(
brand_id: str,
days: int = 30,
) -> dict:
"""检测引用量突增的热点话题"""
from app.database import AsyncSessionLocal
from app.services.trend.trend_analyzer_service import TrendAnalyzerService
async with AsyncSessionLocal() as db:
service = TrendAnalyzerService(db)
result = await service.get_hotspots(
brand_id=brand_id,
days=days,
)
return result
def register_trend_tools(registry: ToolRegistry) -> None:
"""注册所有趋势分析相关工具"""
registry.register(
FunctionTool(
name="trend_insight",
description="分析品牌引用趋势,推断变化原因",
func=trend_insight,
tags=["trend", "insight"],
)
)
registry.register(
FunctionTool(
name="trend_hotspot",
description="检测引用量突增的热点话题",
func=trend_hotspot,
tags=["trend", "hotspot"],
)
)
logger.info("Trend tools registered")

View File

@ -57,7 +57,7 @@ async def _get_brand_diagnosis_data(
from app.services.analysis.sentiment_service import get_sentiment_service
queries_stmt = select(QueryModel).where(
QueryModel.user_id == user_id,
QueryModel.user_id == str(user_id),
QueryModel.target_brand == brand.name,
)
queries_result = await db.execute(queries_stmt)

View File

@ -81,7 +81,7 @@ async def _get_citations_for_brand(
# Find queries that match this brand
queries_stmt = select(QueryModel).where(
QueryModel.user_id == user_id,
QueryModel.user_id == str(user_id),
QueryModel.target_brand == brand.name,
)
queries_result = await db.execute(queries_stmt)
@ -215,7 +215,7 @@ async def _calculate_trend(
# Get queries for this brand
queries_stmt = select(QueryModel).where(
QueryModel.user_id == user_id,
QueryModel.user_id == str(user_id),
QueryModel.target_brand == brand.name,
)
queries_result = await db.execute(queries_stmt)
@ -266,7 +266,7 @@ async def _calculate_trend_for_competitor(
# Get queries for this competitor
queries_stmt = select(QueryModel).where(
QueryModel.user_id == user_id,
QueryModel.user_id == str(user_id),
QueryModel.target_brand == competitor_name,
)
queries_result = await db.execute(queries_stmt)

View File

@ -4,7 +4,6 @@ from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from app.api.deps import get_current_user
from app.database import get_db
@ -181,11 +180,7 @@ async def generate_geo_plan_endpoint(
await db.commit()
await db.refresh(db_plan)
stmt = (
select(GeoPlan)
.options(selectinload(GeoPlanAction.plan))
.where(GeoPlan.id == db_plan.id)
)
stmt = select(GeoPlan).where(GeoPlan.id == db_plan.id)
result = await db.execute(stmt)
db_plan = result.scalar_one()

View File

@ -70,7 +70,7 @@ async def _get_brand_scoring_data(
"""
# 获取品牌查询
queries_stmt = select(QueryModel).where(
QueryModel.user_id == user_id,
QueryModel.user_id == str(user_id),
QueryModel.target_brand == brand.name,
)
queries_result = await db.execute(queries_stmt)

View File

@ -1,7 +1,7 @@
import uuid
from datetime import datetime
from sqlalchemy import String, Integer, DateTime, ForeignKey, Index, func, Text
from sqlalchemy import String, Integer, Float, DateTime, ForeignKey, Index, func, Text
from sqlalchemy import Uuid
from sqlalchemy.orm import Mapped, mapped_column, relationship
@ -207,3 +207,102 @@ class AgentTaskLog(Base):
Index("idx_agent_task_logs_agent_id", "agent_id"),
Index("idx_agent_task_logs_created_at", "created_at"),
)
# ---- fischer-agentkit 扩展表 ----
class EpisodicMemory(Base):
"""经验记忆表 - 记录每次任务的输入/输出/效果/反思"""
__tablename__ = "episodic_memories"
id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
)
agent_name: Mapped[str] = mapped_column(String(50), nullable=False)
task_type: Mapped[str] = mapped_column(String(50), nullable=False)
input_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
output_data: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
success: Mapped[bool | None] = mapped_column(default=None, nullable=True)
quality_score: Mapped[float | None] = mapped_column(Float, nullable=True)
reflection: Mapped[str | None] = mapped_column(Text, nullable=True)
embedding_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
tags: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
__table_args__ = (
Index("idx_episodic_agent_name", "agent_name"),
Index("idx_episodic_task_type", "task_type"),
Index("idx_episodic_created_at", "created_at"),
)
class EvolutionLog(Base):
"""进化日志表 - 记录每次进化变更"""
__tablename__ = "evolution_logs"
id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
)
agent_name: Mapped[str] = mapped_column(String(50), nullable=False)
change_type: Mapped[str] = mapped_column(String(30), nullable=False)
before: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
after: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
metrics: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
ab_test_id: Mapped[str | None] = mapped_column(String(100), nullable=True)
status: Mapped[str] = mapped_column(String(20), server_default="pending", nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
__table_args__ = (
Index("idx_evolution_agent_name", "agent_name"),
Index("idx_evolution_change_type", "change_type"),
Index("idx_evolution_status", "status"),
Index("idx_evolution_created_at", "created_at"),
)
class ABTestConfig(Base):
"""A/B测试配置表"""
__tablename__ = "ab_test_configs"
id: Mapped[uuid.UUID] = mapped_column(
Uuid(as_uuid=True),
primary_key=True,
default=uuid.uuid4,
)
agent_name: Mapped[str] = mapped_column(String(50), nullable=False)
test_name: Mapped[str] = mapped_column(String(100), nullable=False)
variant_a: Mapped[dict] = mapped_column(JSONType, nullable=False)
variant_b: Mapped[dict] = mapped_column(JSONType, nullable=False)
traffic_split: Mapped[float] = mapped_column(Float, server_default="0.5", nullable=False)
status: Mapped[str] = mapped_column(String(20), server_default="running", nullable=False)
winner: Mapped[str | None] = mapped_column(String(10), nullable=True)
metrics: Mapped[dict | None] = mapped_column(JSONType, nullable=True)
started_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
ended_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
nullable=False,
)
__table_args__ = (
Index("idx_ab_test_agent_name", "agent_name"),
Index("idx_ab_test_status", "status"),
)

View File

@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.config import settings
from app.models.user import User
from app.models.organization import Organization, OrgMember
from app.schemas.auth import UserRegister, UpdateProfileRequest
logger = logging.getLogger(__name__)
@ -75,6 +76,28 @@ async def register_user(db: AsyncSession, user_data: UserRegister) -> User:
max_queries=5,
)
db.add(user)
await db.flush()
# Auto-create personal organization
org = Organization(
id=uuid.UUID(user.id),
name=f"{user_data.name}的个人空间",
slug=f"user-{user.id[:8]}",
plan="free",
)
db.add(org)
await db.flush()
org_member = OrgMember(
organization_id=org.id,
user_id=user.id,
role="owner",
)
db.add(org_member)
# Link user to organization
user.organization_id = org.id
await db.commit()
await db.refresh(user)
return user

View File

@ -46,7 +46,7 @@ class BrandScoringDataService:
brand: Brand,
) -> BrandScoringResult:
queries_stmt = select(QueryModel).where(
QueryModel.user_id == user_id,
QueryModel.user_id == str(user_id),
QueryModel.target_brand == brand.name,
)
queries_result = await db.execute(queries_stmt)
@ -209,7 +209,7 @@ class BrandScoringDataService:
return {platform: 0.0 for platform in REQUIRED_PLATFORMS}
queries_stmt = select(QueryModel).where(
QueryModel.user_id == user_id,
QueryModel.user_id == str(user_id),
QueryModel.target_brand == brand.name,
)
queries_result = await db.execute(queries_stmt)
@ -264,7 +264,7 @@ class BrandScoringDataService:
competitor_names = [c.name for c in competitors]
competitor_queries_stmt = select(QueryModel).where(
QueryModel.user_id == user_id,
QueryModel.user_id == str(user_id),
QueryModel.target_brand.in_(competitor_names),
)
competitor_queries_result = await db.execute(competitor_queries_stmt)

View File

@ -34,6 +34,9 @@ python-dotenv
# YAML解析
pyyaml>=6.0
# Agent框架已独立为 fischer-agentkit 项目,代码仍保留在 app/agent_framework/ 中)
# fischer-agentkit>=0.1.0
# 测试依赖
pytest>=8.0
pytest-asyncio>=0.23.0

File diff suppressed because it is too large Load Diff

View File

@ -22,7 +22,7 @@
"@radix-ui/react-select": "^2.2.6",
"@radix-ui/react-slot": "^1.2.4",
"@radix-ui/react-tabs": "^1.1.13",
"@sentry/nextjs": "^9.0.0",
"@sentry/nextjs": "^9.47.1",
"class-variance-authority": "^0.7.1",
"clsx": "^2.1.1",
"lucide-react": "^1.8.0",