176 lines
5.6 KiB
Python
176 lines
5.6 KiB
Python
"""GEO Agent adaptation layer - new architecture built on fischer-agentkit.

Responsibilities:
1. Create ``ConfigDrivenAgent`` instances from YAML configuration files.
2. Provide the agentkit Registry and Dispatcher with their dependencies
   (DB session factory, Redis factory, ORM models) injected.
3. Keep the legacy Agent class import paths working.

Usage:
    # New style (recommended)
    from app.agent_framework import get_agent_registry, get_task_dispatcher
    registry = get_agent_registry()
    dispatcher = get_task_dispatcher()

    # Legacy style (compatibility)
    from app.agent_framework.agents import CitationDetectorAgent
    agent = CitationDetectorAgent()
"""
|
||
|
||
import logging
|
||
from pathlib import Path
|
||
|
||
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
|
||
from agentkit.core.registry import AgentRegistry as AgentkitRegistry
|
||
from agentkit.core.dispatcher import TaskDispatcher as AgentkitDispatcher
|
||
from agentkit.tools.registry import ToolRegistry
|
||
|
||
from app.agent_framework.tools import get_tool_registry
|
||
|
||
logger = logging.getLogger(__name__)

# Module-level caches for the lazily created agentkit objects. Both start as
# None; get_agent_registry() / get_task_dispatcher() fill them in, and later
# calls to those functions return the cached object instead of rebuilding it.
_AGENTKIT_REGISTRY: AgentkitRegistry | None = None
_AGENTKIT_DISPATCHER: AgentkitDispatcher | None = None
|
||
|
||
|
||
def _get_configs_dir() -> Path:
    """Return the directory holding the agent YAML configuration files.

    The location is resolved relative to this module's own directory:
    ``<this module's dir>/agents/configs``.
    """
    module_dir = Path(__file__).parent
    return module_dir.joinpath("agents", "configs")
|
||
|
||
|
||
def _get_llm_client():
    """Return the default LLM client from ``LLMFactory.get_default()``.

    The service import happens at call time rather than at module load
    (presumably to avoid import cycles / heavy imports - TODO confirm).
    """
    from app.services.llm import LLMFactory as _LLMFactory

    default_client = _LLMFactory.get_default()
    return default_client
|
||
|
||
|
||
def _get_custom_handlers() -> dict:
    """Build the mapping of dotted handler path -> custom handler function.

    Each key is the fully qualified import path of a handler function and the
    value is that function, imported at call time. (Presumably agent YAML
    configs refer to custom handlers by these dotted paths - confirm against
    the agentkit config schema.)
    """
    from app.agent_framework.agents.custom_handlers.citation_handler import (
        handle_citation_task,
    )
    from app.agent_framework.agents.custom_handlers.monitor_handler import (
        handle_monitor_task,
    )
    from app.agent_framework.agents.custom_handlers.schema_handler import (
        handle_schema_task,
    )

    # (dotted path, callable) pairs, kept in the same order as before.
    handler_pairs = (
        (
            "app.agent_framework.agents.custom_handlers.citation_handler.handle_citation_task",
            handle_citation_task,
        ),
        (
            "app.agent_framework.agents.custom_handlers.monitor_handler.handle_monitor_task",
            handle_monitor_task,
        ),
        (
            "app.agent_framework.agents.custom_handlers.schema_handler.handle_schema_task",
            handle_schema_task,
        ),
    )
    return dict(handler_pairs)
|
||
|
||
|
||
def _get_session_factory():
    """Return ``app.database.AsyncSessionLocal`` for use as the DB session factory.

    The object itself is returned (it is not called here).
    """
    from app.database import AsyncSessionLocal as session_factory

    return session_factory
|
||
|
||
|
||
def _get_redis_factory():
    """Return the ``app.core.redis.get_redis`` callable as the Redis factory.

    The function object is returned uncalled, so no connection is made here.
    """
    from app.core.redis import get_redis as redis_factory

    return redis_factory
|
||
|
||
|
||
def _get_agent_model():
    """Return the ``AgentRegistry`` ORM model class from ``app.models.agent``.

    Note this is the database model, not agentkit's ``AgentRegistry`` class.
    """
    from app.models.agent import AgentRegistry as AgentRegistryORM

    return AgentRegistryORM
|
||
|
||
|
||
def _get_task_model():
    """Return the ``AgentTask`` ORM model class from ``app.models.agent``."""
    from app.models.agent import AgentTask as AgentTaskORM

    return AgentTaskORM
|
||
|
||
|
||
def _get_task_log_model():
    """Return the ``AgentTaskLog`` ORM model class from ``app.models.agent``."""
    from app.models.agent import AgentTaskLog as AgentTaskLogORM

    return AgentTaskLogORM
|
||
|
||
|
||
def create_agents_from_configs() -> list[ConfigDrivenAgent]:
    """Create one ConfigDrivenAgent per ``*.yaml`` file in the configs directory.

    Files are processed in sorted filename order. All agents share the same
    tool registry, LLM client and custom-handler mapping, which are resolved
    once before the loop. A file that fails to load, or whose agent fails to
    construct, is logged with its traceback and skipped, so one broken config
    does not prevent the rest from loading.

    Returns:
        The successfully created agents in filename order (possibly empty).

    Raises:
        Anything raised by ``get_tool_registry()``, ``_get_llm_client()`` or
        ``_get_custom_handlers()``; those run outside the per-file error
        handling.
    """
    configs_dir = _get_configs_dir()
    tool_registry = get_tool_registry()
    llm_client = _get_llm_client()
    custom_handlers = _get_custom_handlers()

    agents: list[ConfigDrivenAgent] = []
    for yaml_file in sorted(configs_dir.glob("*.yaml")):
        try:
            config = AgentConfig.from_yaml(str(yaml_file))
            agent = ConfigDrivenAgent(
                config=config,
                tool_registry=tool_registry,
                llm_client=llm_client,
                custom_handlers=custom_handlers,
            )
            agents.append(agent)
            logger.info("Created agent '%s' from %s", config.name, yaml_file.name)
        except Exception as e:
            # Deliberately best-effort: keep loading the remaining configs, but
            # use logger.exception so the traceback is kept - the bare message
            # alone is usually not enough to find what is wrong in the file.
            logger.exception("Failed to create agent from %s: %s", yaml_file.name, e)

    if not agents:
        # Make an empty result visible instead of silently returning [].
        logger.warning("No agents were created from configs in %s", configs_dir)
    return agents
|
||
|
||
|
||
def get_agent_registry() -> AgentkitRegistry:
    """Return the shared agentkit AgentRegistry, creating it on first use.

    On the first call the registry is built with the DB session factory and
    the agent ORM model. Every agent from ``create_agents_from_configs()`` is
    then registered on it. Later calls return the cached instance.

    The registry is stored in the module-level cache only after all agents
    have been registered. If setup raises (e.g. the LLM client or tool
    registry cannot be created, or ``register`` fails), the error propagates
    and the next call retries. Previously a half-populated registry was
    cached permanently.

    NOTE: because the cache is set only at the end, calling this function
    from inside agent/handler construction would recurse.
    """
    global _AGENTKIT_REGISTRY
    if _AGENTKIT_REGISTRY is not None:
        return _AGENTKIT_REGISTRY

    registry = AgentkitRegistry(
        session_factory=_get_session_factory(),
        agent_model=_get_agent_model(),
    )
    agents = create_agents_from_configs()
    for agent in agents:
        registry.register(agent)

    # Publish only after successful setup (see docstring).
    _AGENTKIT_REGISTRY = registry
    logger.info("Agentkit AgentRegistry initialized with %d agents", len(agents))
    return _AGENTKIT_REGISTRY
|
||
|
||
|
||
def get_task_dispatcher() -> AgentkitDispatcher:
    """Return the shared agentkit TaskDispatcher, creating it on first use.

    The dispatcher is wired with the Redis factory, the DB session factory and
    the agent-registry, task and task-log ORM models. The instance is cached
    in a module-level global, and subsequent calls return it unchanged.
    """
    global _AGENTKIT_DISPATCHER
    if _AGENTKIT_DISPATCHER is not None:
        return _AGENTKIT_DISPATCHER

    _AGENTKIT_DISPATCHER = AgentkitDispatcher(
        redis_factory=_get_redis_factory(),
        session_factory=_get_session_factory(),
        agent_model=_get_agent_model(),
        task_model=_get_task_model(),
        task_log_model=_get_task_log_model(),
    )
    logger.info("Agentkit TaskDispatcher initialized")
    return _AGENTKIT_DISPATCHER
|
||
|
||
|
||
def get_legacy_agent(name: str):
    """Instantiate a legacy (pre-agentkit) Agent class by its short name.

    Compatibility shim for old code that instantiated Agent classes directly.
    It can be removed once callers have migrated.

    Args:
        name: One of ``citation_detector``, ``competitor_analyzer``,
            ``content_generator``, ``deai_agent``, ``geo_optimizer``,
            ``monitor``, ``schema_advisor`` or ``trend_agent``.

    Returns:
        A new instance of the matching class, constructed with no arguments.

    Raises:
        ValueError: If ``name`` is not one of the names listed above.
    """
    from app.agent_framework.agents import (
        CitationDetectorAgent,
        CompetitorAnalyzerAgent,
        ContentGeneratorAgent,
        DeAIAgent,
        GEOOptimizerAgent,
        MonitorAgent,
        SchemaAdvisorAgent,
        TrendAgent,
    )

    classes_by_name = {
        "citation_detector": CitationDetectorAgent,
        "competitor_analyzer": CompetitorAnalyzerAgent,
        "content_generator": ContentGeneratorAgent,
        "deai_agent": DeAIAgent,
        "geo_optimizer": GEOOptimizerAgent,
        "monitor": MonitorAgent,
        "schema_advisor": SchemaAdvisorAgent,
        "trend_agent": TrendAgent,
    }

    selected = classes_by_name.get(name)
    if not selected:
        raise ValueError(f"Unknown agent name: {name}")
    return selected()
|