"""GEO Agent 适配层 - 基于 fischer-agentkit 的新架构 职责: 1. 从 YAML 配置创建 ConfigDrivenAgent 2. 提供依赖注入的 Registry 和 Dispatcher 3. 兼容旧版 Agent 类的导入路径 使用方式: # 新方式(推荐) from app.agent_framework import get_agent_registry, get_task_dispatcher registry = get_agent_registry() dispatcher = get_task_dispatcher() # 旧方式(兼容) from app.agent_framework.agents import CitationDetectorAgent agent = CitationDetectorAgent() """ import logging from pathlib import Path from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent from agentkit.core.registry import AgentRegistry as AgentkitRegistry from agentkit.core.dispatcher import TaskDispatcher as AgentkitDispatcher from agentkit.tools.registry import ToolRegistry from app.agent_framework.tools import get_tool_registry logger = logging.getLogger(__name__) _AGENTKIT_REGISTRY: AgentkitRegistry | None = None _AGENTKIT_DISPATCHER: AgentkitDispatcher | None = None def _get_configs_dir() -> Path: """获取 Agent YAML 配置目录""" return Path(__file__).parent / "agents" / "configs" def _get_llm_client(): """获取 LLM 客户端""" from app.services.llm import LLMFactory return LLMFactory.get_default() def _get_custom_handlers() -> dict: """获取所有自定义 handler""" from app.agent_framework.agents.custom_handlers.citation_handler import handle_citation_task from app.agent_framework.agents.custom_handlers.monitor_handler import handle_monitor_task from app.agent_framework.agents.custom_handlers.schema_handler import handle_schema_task return { "app.agent_framework.agents.custom_handlers.citation_handler.handle_citation_task": handle_citation_task, "app.agent_framework.agents.custom_handlers.monitor_handler.handle_monitor_task": handle_monitor_task, "app.agent_framework.agents.custom_handlers.schema_handler.handle_schema_task": handle_schema_task, } def _get_session_factory(): """获取数据库会话工厂""" from app.database import AsyncSessionLocal return AsyncSessionLocal def _get_redis_factory(): """获取 Redis 连接工厂""" from app.core.redis import get_redis return get_redis def _get_agent_model(): """获取 Agent ORM 模型""" from app.models.agent import AgentRegistry as AgentRegistryModel return AgentRegistryModel def _get_task_model(): """获取 Task ORM 模型""" from app.models.agent import AgentTask as AgentTaskModel return AgentTaskModel def _get_task_log_model(): """获取 TaskLog ORM 模型""" from app.models.agent import AgentTaskLog as AgentTaskLogModel return AgentTaskLogModel def create_agents_from_configs() -> list[ConfigDrivenAgent]: """从 YAML 配置目录创建所有 Agent(新架构)""" configs_dir = _get_configs_dir() tool_registry = get_tool_registry() llm_client = _get_llm_client() custom_handlers = _get_custom_handlers() agents = [] for yaml_file in sorted(configs_dir.glob("*.yaml")): try: config = AgentConfig.from_yaml(str(yaml_file)) agent = ConfigDrivenAgent( config=config, tool_registry=tool_registry, llm_client=llm_client, custom_handlers=custom_handlers, ) agents.append(agent) logger.info(f"Created agent '{config.name}' from {yaml_file.name}") except Exception as e: logger.error(f"Failed to create agent from {yaml_file.name}: {e}") return agents def get_agent_registry() -> AgentkitRegistry: """获取 agentkit AgentRegistry(懒初始化,依赖注入)""" global _AGENTKIT_REGISTRY if _AGENTKIT_REGISTRY is None: _AGENTKIT_REGISTRY = AgentkitRegistry( session_factory=_get_session_factory(), agent_model=_get_agent_model(), ) agents = create_agents_from_configs() for agent in agents: _AGENTKIT_REGISTRY.register(agent) logger.info(f"Agentkit AgentRegistry initialized with {len(agents)} agents") return _AGENTKIT_REGISTRY def get_task_dispatcher() -> AgentkitDispatcher: """获取 agentkit TaskDispatcher(懒初始化,依赖注入)""" global _AGENTKIT_DISPATCHER if _AGENTKIT_DISPATCHER is None: _AGENTKIT_DISPATCHER = AgentkitDispatcher( redis_factory=_get_redis_factory(), session_factory=_get_session_factory(), agent_model=_get_agent_model(), task_model=_get_task_model(), task_log_model=_get_task_log_model(), ) logger.info("Agentkit TaskDispatcher initialized") return _AGENTKIT_DISPATCHER def get_legacy_agent(name: str): """获取旧版 Agent 实例(兼容层) 旧代码可能直接实例化 Agent 类,此方法提供兼容性。 逐步迁移后可移除。 """ from app.agent_framework.agents import ( CitationDetectorAgent, CompetitorAnalyzerAgent, ContentGeneratorAgent, DeAIAgent, GEOOptimizerAgent, MonitorAgent, SchemaAdvisorAgent, TrendAgent, ) legacy_map = { "citation_detector": CitationDetectorAgent, "competitor_analyzer": CompetitorAnalyzerAgent, "content_generator": ContentGeneratorAgent, "deai_agent": DeAIAgent, "geo_optimizer": GEOOptimizerAgent, "monitor": MonitorAgent, "schema_advisor": SchemaAdvisorAgent, "trend_agent": TrendAgent, } agent_cls = legacy_map.get(name) if agent_cls: return agent_cls() raise ValueError(f"Unknown agent name: {name}")