geo/backend/app/agent_framework/adapter.py

176 lines
5.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""GEO Agent adapter layer - new architecture built on fischer-agentkit.

Responsibilities:
1. Create ConfigDrivenAgent instances from YAML configuration files.
2. Provide the Registry and Dispatcher with their dependencies injected.
3. Keep the import paths of the legacy Agent classes working.

Usage:
    # New way (recommended)
    from app.agent_framework import get_agent_registry, get_task_dispatcher
    registry = get_agent_registry()
    dispatcher = get_task_dispatcher()

    # Old way (compatibility)
    from app.agent_framework.agents import CitationDetectorAgent
    agent = CitationDetectorAgent()
"""
import logging
from pathlib import Path
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
from agentkit.core.registry import AgentRegistry as AgentkitRegistry
from agentkit.core.dispatcher import TaskDispatcher as AgentkitDispatcher
from agentkit.tools.registry import ToolRegistry
from app.agent_framework.tools import get_tool_registry
# Process-wide singletons, created lazily by get_agent_registry() and
# get_task_dispatcher(); None until the first call to each.
_AGENTKIT_REGISTRY: AgentkitRegistry | None = None
_AGENTKIT_DISPATCHER: AgentkitDispatcher | None = None

# Module logger, named after this module.
logger = logging.getLogger(__name__)
def _get_configs_dir() -> Path:
    """Return the directory that holds the agent YAML configuration files.

    The path is resolved relative to this module: ``<module dir>/agents/configs``.
    """
    module_dir = Path(__file__).parent
    return module_dir.joinpath("agents", "configs")
def _get_llm_client():
    """Return the default LLM client, as provided by ``LLMFactory.get_default()``.

    The factory is imported at call time rather than at module import
    (presumably to avoid an import cycle / heavy import - confirm).
    """
    from app.services.llm import LLMFactory as factory

    client = factory.get_default()
    return client
def _get_custom_handlers() -> dict:
    """Return every custom handler, keyed by the handler's dotted import path.

    The dotted-path strings are the lookup keys handed to ``ConfigDrivenAgent``
    as ``custom_handlers``; they must match exactly what the YAML configs
    reference (assumption - confirm against the config files).
    """
    from app.agent_framework.agents.custom_handlers.citation_handler import handle_citation_task
    from app.agent_framework.agents.custom_handlers.monitor_handler import handle_monitor_task
    from app.agent_framework.agents.custom_handlers.schema_handler import handle_schema_task

    handler_entries = (
        ("app.agent_framework.agents.custom_handlers.citation_handler.handle_citation_task", handle_citation_task),
        ("app.agent_framework.agents.custom_handlers.monitor_handler.handle_monitor_task", handle_monitor_task),
        ("app.agent_framework.agents.custom_handlers.schema_handler.handle_schema_task", handle_schema_task),
    )
    return dict(handler_entries)
def _get_session_factory():
    """Return ``AsyncSessionLocal`` from ``app.database`` - the DB session factory.

    Imported at call time, not at module import.
    """
    from app.database import AsyncSessionLocal as session_factory

    return session_factory
def _get_redis_factory():
    """Return the ``get_redis`` callable from ``app.core.redis`` (not a connection).

    The callable itself is returned uninvoked so the consumer decides when to
    obtain a connection.
    """
    from app.core.redis import get_redis as redis_factory

    return redis_factory
def _get_agent_model():
    """Return the ``AgentRegistry`` ORM model class from ``app.models.agent``.

    Aliased locally so it is not confused with agentkit's ``AgentRegistry``.
    """
    from app.models.agent import AgentRegistry as AgentRow

    return AgentRow
def _get_task_model():
    """Return the ``AgentTask`` ORM model class from ``app.models.agent``."""
    from app.models.agent import AgentTask as TaskRow

    return TaskRow
def _get_task_log_model():
    """Return the ``AgentTaskLog`` ORM model class from ``app.models.agent``."""
    from app.models.agent import AgentTaskLog as TaskLogRow

    return TaskLogRow
def create_agents_from_configs() -> list[ConfigDrivenAgent]:
    """Create one ConfigDrivenAgent per YAML file in the configs directory (new architecture).

    Every ``*.yaml`` file directly inside the configs directory is processed in
    sorted filename order. All agents share a single tool registry, LLM client
    and custom-handler mapping, which are resolved once before the loop.

    A file that fails to load, or whose agent fails to construct, is logged
    with its full traceback and skipped, so one broken config does not prevent
    the remaining agents from being created.

    Returns:
        The agents that were created successfully, in filename order. The list
        is empty if the directory is missing or contains no ``*.yaml`` files.
    """
    configs_dir = _get_configs_dir()
    if not configs_dir.is_dir():
        # Path.glob() on a missing directory silently yields nothing; make the
        # "started with zero agents" situation visible in the logs.
        logger.warning("Agent config directory not found: %s", configs_dir)

    tool_registry = get_tool_registry()
    llm_client = _get_llm_client()
    custom_handlers = _get_custom_handlers()

    agents: list[ConfigDrivenAgent] = []
    for yaml_file in sorted(configs_dir.glob("*.yaml")):
        try:
            config = AgentConfig.from_yaml(str(yaml_file))
            agent = ConfigDrivenAgent(
                config=config,
                tool_registry=tool_registry,
                llm_client=llm_client,
                custom_handlers=custom_handlers,
            )
        except Exception:
            # Deliberately best-effort per file, but keep the traceback:
            # logging only str(e) made broken configs very hard to diagnose.
            logger.exception("Failed to create agent from %s", yaml_file.name)
            continue
        agents.append(agent)
        logger.info("Created agent '%s' from %s", config.name, yaml_file.name)
    return agents
def get_agent_registry() -> AgentkitRegistry:
    """Return the shared agentkit AgentRegistry, creating it on first use (lazy init + DI).

    On the first call this builds an ``AgentkitRegistry`` wired to the
    project's session factory and Agent ORM model, then registers every agent
    returned by ``create_agents_from_configs()``. Later calls return the cached
    instance.

    The registry is stored in the module-level cache only after all agents
    have been registered. If construction or registration raises, the
    exception propagates and the next call retries from scratch. Previously,
    the failure left a half-populated registry cached permanently.
    """
    global _AGENTKIT_REGISTRY
    if _AGENTKIT_REGISTRY is not None:
        return _AGENTKIT_REGISTRY

    registry = AgentkitRegistry(
        session_factory=_get_session_factory(),
        agent_model=_get_agent_model(),
    )
    agents = create_agents_from_configs()
    for agent in agents:
        registry.register(agent)
    logger.info("Agentkit AgentRegistry initialized with %d agents", len(agents))

    # Publish only once fully populated.
    _AGENTKIT_REGISTRY = registry
    return registry
def get_task_dispatcher() -> AgentkitDispatcher:
    """Return the shared agentkit TaskDispatcher, creating it on first use (lazy init + DI).

    The dispatcher receives the Redis factory, the DB session factory and the
    Agent / AgentTask / AgentTaskLog ORM models. Later calls return the cached
    instance.
    """
    global _AGENTKIT_DISPATCHER
    if _AGENTKIT_DISPATCHER is not None:
        return _AGENTKIT_DISPATCHER

    # Keyword order matches the original call, so the factories run in the same order.
    dispatcher_kwargs = dict(
        redis_factory=_get_redis_factory(),
        session_factory=_get_session_factory(),
        agent_model=_get_agent_model(),
        task_model=_get_task_model(),
        task_log_model=_get_task_log_model(),
    )
    _AGENTKIT_DISPATCHER = AgentkitDispatcher(**dispatcher_kwargs)
    logger.info("Agentkit TaskDispatcher initialized")
    return _AGENTKIT_DISPATCHER
def get_legacy_agent(name: str):
    """Return a fresh instance of a legacy Agent class (compatibility layer).

    Old code may instantiate the Agent classes directly; this helper keeps that
    working by mapping a short agent name to its class. It can be removed once
    callers have migrated.

    Args:
        name: Short agent name, e.g. ``"citation_detector"`` or ``"trend_agent"``.

    Raises:
        ValueError: If ``name`` is not a known legacy agent.
    """
    from app.agent_framework.agents import (
        CitationDetectorAgent,
        CompetitorAnalyzerAgent,
        ContentGeneratorAgent,
        DeAIAgent,
        GEOOptimizerAgent,
        MonitorAgent,
        SchemaAdvisorAgent,
        TrendAgent,
    )

    classes_by_name = {
        "citation_detector": CitationDetectorAgent,
        "competitor_analyzer": CompetitorAnalyzerAgent,
        "content_generator": ContentGeneratorAgent,
        "deai_agent": DeAIAgent,
        "geo_optimizer": GEOOptimizerAgent,
        "monitor": MonitorAgent,
        "schema_advisor": SchemaAdvisorAgent,
        "trend_agent": TrendAgent,
    }
    if name not in classes_by_name:
        raise ValueError(f"Unknown agent name: {name}")
    return classes_by_name[name]()