1123 lines
44 KiB
Python
1123 lines
44 KiB
Python
"""ConfigDrivenAgent - 配置驱动的 Agent 定义
|
||
|
||
核心设计:
|
||
- 从 YAML/Dict 配置自动组装 Agent(Prompt + LLM + Tool + Memory)
|
||
- 支持三种任务模式:llm_generate / tool_call / custom
|
||
- v2: 支持 SkillConfig + ReAct 执行模式 + LLMGateway + Quality Gate
|
||
- 新增 Agent 从写 150 行代码降为 10-20 行配置
|
||
"""
|
||
|
||
import json
|
||
import logging
|
||
import os
|
||
from typing import Callable, Coroutine
|
||
|
||
import yaml
|
||
|
||
from agentkit.core.base import BaseAgent
|
||
from agentkit.core.exceptions import ConfigValidationError
|
||
from agentkit.core.protocol import AgentCapability, TaskMessage
|
||
from agentkit.evolution.lifecycle import EvolutionMixin
|
||
from agentkit.evolution.reflector import Reflector
|
||
from agentkit.prompts.section import PromptSection
|
||
from agentkit.prompts.template import PromptTemplate
|
||
from agentkit.tools.base import Tool
|
||
from agentkit.tools.registry import ToolRegistry
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class AgentConfig:
    """Agent configuration model, built from a YAML file or a dict.

    Every setting is stored as a plain attribute. Omitted optional sections
    are replaced by empty containers, and ``supported_tasks`` falls back to
    ``[agent_type]``. The combination of ``task_mode`` and the section that
    mode relies on is validated at construction time.
    """

    def __init__(
        self,
        name: str,
        agent_type: str,
        version: str = "1.0.0",
        description: str = "",
        task_mode: str = "llm_generate",
        supported_tasks: list[str] | None = None,
        max_concurrency: int = 1,
        input_schema: dict[str, object] | None = None,
        output_schema: dict[str, object] | None = None,
        prompt: dict[str, str] | None = None,
        llm: dict[str, object] | None = None,
        tools: list[str] | None = None,
        memory: dict[str, object] | None = None,
        custom_handler: str | None = None,
    ):
        """Store the settings and validate them.

        Raises:
            ConfigValidationError: if ``task_mode`` is unknown or the section
                required by the chosen mode is missing/empty.
        """
        # Identity and scheduling metadata.
        self.name = name
        self.agent_type = agent_type
        self.version = version
        self.description = description
        self.task_mode = task_mode
        self.supported_tasks = supported_tasks if supported_tasks else [agent_type]
        self.max_concurrency = max_concurrency

        # Schemas are kept exactly as given (None stays None).
        self.input_schema = input_schema
        self.output_schema = output_schema

        # Optional sections: a falsy value becomes an empty container.
        self.prompt = prompt if prompt else {}
        self.llm = llm if llm else {}
        self.tools = tools if tools else []
        self.memory = memory if memory else {}
        self.custom_handler = custom_handler

        self._validate()

    def _validate(self) -> None:
        """Check that ``task_mode`` is known and its required section is set."""
        valid_modes = {"llm_generate", "tool_call", "custom"}
        if self.task_mode not in valid_modes:
            raise ConfigValidationError(
                agent_name=self.name,
                key="task_mode",
                reason=f"Invalid task_mode '{self.task_mode}', must be one of {valid_modes}",
            )

        # mode -> (config key, current value, message used when the value is empty)
        requirements = {
            "llm_generate": (
                "prompt",
                self.prompt,
                "llm_generate mode requires 'prompt' configuration",
            ),
            "tool_call": (
                "tools",
                self.tools,
                "tool_call mode requires at least one tool in 'tools' list",
            ),
            "custom": (
                "custom_handler",
                self.custom_handler,
                "custom mode requires 'custom_handler' (dotted path to callable)",
            ),
        }
        key, value, reason = requirements[self.task_mode]
        if not value:
            raise ConfigValidationError(agent_name=self.name, key=key, reason=reason)

    @classmethod
    def from_dict(cls, data: dict[str, object]) -> "AgentConfig":
        """Create a configuration from a dict.

        ``name`` and ``agent_type`` are mandatory (a missing key raises
        ``KeyError``); every other key is optional.
        """
        kwargs = {"name": data["name"], "agent_type": data["agent_type"]}

        # Keys whose absence is replaced by an explicit default.
        defaults = {
            "version": "1.0.0",
            "description": "",
            "task_mode": "llm_generate",
            "max_concurrency": 1,
        }
        for key, fallback in defaults.items():
            kwargs[key] = data.get(key, fallback)

        # Keys that are simply passed through (None when absent).
        for key in (
            "supported_tasks",
            "input_schema",
            "output_schema",
            "prompt",
            "llm",
            "tools",
            "memory",
            "custom_handler",
        ):
            kwargs[key] = data.get(key)

        return cls(**kwargs)

    @classmethod
    def from_yaml(cls, path: str) -> "AgentConfig":
        """Load a configuration from a YAML file.

        Raises:
            ConfigValidationError: if the document's top level is not a mapping.
        """
        with open(path, "r", encoding="utf-8") as handle:
            loaded = yaml.safe_load(handle)

        if isinstance(loaded, dict):
            return cls.from_dict(loaded)

        raise ConfigValidationError(
            agent_name="unknown",
            key="config",
            reason=f"YAML config must be a mapping, got {type(loaded)}",
        )

    def to_dict(self) -> dict[str, object]:
        """Serialise to a dict; empty optional sections are left out."""
        result = {
            "name": self.name,
            "agent_type": self.agent_type,
            "version": self.version,
            "description": self.description,
            "task_mode": self.task_mode,
            "supported_tasks": self.supported_tasks,
            "max_concurrency": self.max_concurrency,
        }
        optional_keys = (
            "input_schema",
            "output_schema",
            "prompt",
            "llm",
            "tools",
            "memory",
            "custom_handler",
        )
        for key in optional_keys:
            value = getattr(self, key)
            if value:
                result[key] = value
        return result
|
||
|
||
|
||
class ConfigDrivenAgent(BaseAgent, EvolutionMixin):
    """Configuration-driven Agent.

    Assembled automatically from a YAML/dict configuration; supports three task modes:

    - llm_generate: render the prompt -> call the LLM -> parse the JSON output
    - tool_call: call a registered Tool and return its result
    - custom: a custom handler function

    v2 enhancements:

    - Accepts a SkillConfig, automatically creating a Skill and enabling ReAct mode
    - The llm_gateway parameter passes an LLMGateway in directly
    - The llm_client parameter is wrapped into an LLMGateway automatically (backward compatible)
    - Quality Gate is integrated automatically

    Example YAML configuration::

        name: content_generator
        agent_type: content_generation
        task_mode: llm_generate
        description: "Content generation Agent"
        prompt:
          identity: "You are a professional content generation assistant"
          instructions: "Generate high-quality content according to the user's needs"
          output_format: "Output in JSON format"
        llm:
          model: "gpt-4"
          temperature: 0.7
        tools:
          - retrieve_knowledge
    """

    # Security: whitelist of allowed module prefixes for dynamic handler import
    _ALLOWED_HANDLER_PREFIXES = (
        "agentkit.",
        "app.agent_framework.",
    )
|
||
|
||
def __init__(
    self,
    config: AgentConfig,
    tool_registry: ToolRegistry | None = None,
    llm_client: object | None = None,
    custom_handlers: dict[str, Callable[..., Coroutine]] | None = None,
    llm_gateway: object | None = None,  # NEW v2 param: LLMGateway
    mcp_servers: dict[str, str] | None = None,  # NEW v2 param: MCP server URLs
    compressor: object | None = None,  # CompressionStrategy | None
):
    """Assemble the agent from ``config`` and the optional collaborators.

    Args:
        config: Agent configuration. A ``SkillConfig`` instance is also
            recognised; in that case a ``Skill`` is created from it.
        tool_registry: Registry used to resolve the names in ``config.tools``;
            a new ``ToolRegistry`` is created when omitted.
        llm_client: Legacy LLM client; wrapped into a gateway only when
            ``llm_gateway`` is not given.
        custom_handlers: Stored as ``_custom_handlers`` (not used in this
            constructor).
        llm_gateway: Gateway used as-is; takes precedence over ``llm_client``.
        mcp_servers: Server name -> URL mapping, stored for the lazy MCP tool
            registration done by ``_register_mcp_tools``.
        compressor: Stored as ``_compressor``.

    Memory initialisation is best-effort: any exception raised while
    building the memory system is logged as a warning and the agent is
    left without a memory retriever.
    """
    # v2: If SkillConfig, extract skill info
    from agentkit.skills.base import SkillConfig, Skill

    self._skill_config: SkillConfig | None = None
    self._skill_instance: Skill | None = None

    if isinstance(config, SkillConfig):
        self._skill_config = config
        self._skill_instance = Skill(config=config)

    self._config = config
    self._tool_registry = tool_registry or ToolRegistry()
    self._llm_client = llm_client
    self._custom_handlers = custom_handlers or {}
    self._prompt_template: PromptTemplate | None = None

    # Initialise the base agent. Note that the private attributes above are
    # assigned before this call.
    super().__init__(
        name=config.name,
        agent_type=config.agent_type,
        version=config.version,
    )

    # v2: Backward compat — wrap llm_client into LLMGateway if no gateway provided
    if llm_gateway is not None:
        self._llm_gateway = llm_gateway
    elif llm_client is not None:
        self._llm_gateway = self._wrap_llm_client(llm_client)
    else:
        self._llm_gateway = None

    # v2: Set skill on base agent
    if self._skill_instance:
        self._skill = self._skill_instance

    # v2: Initialize ReAct engine if gateway available
    self._react_engine = None
    if self._llm_gateway:
        from agentkit.core.react import ReActEngine

        self._react_engine = ReActEngine(
            llm_gateway=self._llm_gateway,
            # Falls back to 5 when the config object has no max_steps attribute.
            max_steps=getattr(config, "max_steps", 5),
        )

    # v2: Initialize Quality Gate (always available)
    from agentkit.quality.gate import QualityGate

    self._quality_gate = QualityGate()

    # v2: Initialize Evolution if configured
    evolution_config = getattr(config, "evolution", None)
    if evolution_config is not None:
        # Support both dict and EvolutionConfig
        if isinstance(evolution_config, dict):
            is_enabled = evolution_config.get("enabled", False)
        else:
            is_enabled = getattr(evolution_config, "enabled", False)
    else:
        is_enabled = False

    # The mixin is initialised in both branches; only the enabled branch
    # passes a Reflector.
    if is_enabled:
        reflector = Reflector()
        EvolutionMixin.__init__(
            self,
            reflector=reflector,
        )
        self._evolution_enabled = True
    else:
        EvolutionMixin.__init__(self)  # Initialize with no components
        self._evolution_enabled = False

    # v2: Initialize Output Standardizer
    from agentkit.quality.output import OutputStandardizer

    self._output_standardizer = OutputStandardizer()

    # v2: Store compressor for ReAct engine
    self._compressor = compressor

    # Build the prompt template from the configuration
    if config.prompt:
        sections = PromptSection(
            identity=config.prompt.get("identity", ""),
            context=config.prompt.get("context", ""),
            instructions=config.prompt.get("instructions", ""),
            constraints=config.prompt.get("constraints", ""),
            output_format=config.prompt.get("output_format", ""),
            examples=config.prompt.get("examples", ""),
        )
        self._prompt_template = PromptTemplate(
            sections=sections,
            name=config.name,
            version=config.version,
        )

    # Bind the tools listed in the configuration
    self._bind_tools()

    # v2: Merge Skill-bound tools into Agent's tool list
    if self._skill_instance and self._skill_instance.tools:
        for tool in self._skill_instance.tools:
            # Skip skill tools whose name is already bound on the agent.
            if not any(t.name == tool.name for t in self._tools):
                self.use_tool(tool)
                logger.info(f"Merged skill tool '{tool.name}' into agent '{self.name}'")

    # v2: Register MCP tools if mcp_servers provided
    # (only the bookkeeping is set up here; registration itself happens in
    # the async _register_mcp_tools)
    self._mcp_clients: list[object] = []
    self._mcp_servers: dict[str, str] = mcp_servers or {}
    self._mcp_tools_registered = False

    # Memory integration: instantiate a MemoryRetriever automatically from config.memory
    self._memory_retriever: object | None = None
    if config.memory:
        try:
            from agentkit.memory.retriever import MemoryRetriever
            from agentkit.memory.working import WorkingMemory
            from agentkit.memory.semantic import SemanticMemory
            from agentkit.memory.http_rag import HttpRAGService

            working = None
            episodic = None
            semantic = None

            # Working memory: backed by a Redis client built from redis_url.
            if config.memory.get("working", {}).get("enabled"):
                import redis.asyncio as aioredis

                redis_url = config.memory["working"].get("redis_url", "redis://localhost:6379")
                redis_client = aioredis.from_url(redis_url, decode_responses=True)
                working = WorkingMemory(redis=redis_client)

            # Episodic memory: an embedder is created only when an API key is
            # available from the config or the OPENAI_API_KEY environment variable.
            if config.memory.get("episodic", {}).get("enabled"):
                from agentkit.memory.episodic import EpisodicMemory
                from agentkit.memory.embedder import OpenAIEmbedder, EmbeddingCache

                epi_conf = config.memory["episodic"]
                embedder = None
                if epi_conf.get("embedder_api_key") or os.environ.get("OPENAI_API_KEY"):
                    cache = EmbeddingCache(
                        max_size=epi_conf.get("cache_max_size", 1000),
                        ttl=epi_conf.get("cache_ttl", 3600),
                    )
                    embedder = OpenAIEmbedder(
                        api_key=epi_conf.get("embedder_api_key"),
                        model=epi_conf.get("embedder_model", "text-embedding-3-small"),
                        base_url=epi_conf.get("embedder_base_url"),
                        cache=cache,
                    )
                episodic = EpisodicMemory(
                    session_factory=None,  # Set externally when DB session is available
                    episodic_model=None,  # Set externally when ORM model is available
                    embedder=embedder,
                    decay_rate=epi_conf.get("decay_rate", 0.01),
                    alpha=epi_conf.get("alpha", 0.7),
                    retrieve_limit=epi_conf.get("retrieve_limit", 200),
                    pgvector_enabled=epi_conf.get("pgvector_enabled", True),
                    table_name=epi_conf.get("table_name", "episodic_memories"),
                )

            # Semantic memory: "base_url" is required here (a missing key
            # raises KeyError, which the except below turns into a warning).
            if config.memory.get("semantic", {}).get("enabled"):
                sem_conf = config.memory["semantic"]
                rag_service = HttpRAGService(
                    base_url=sem_conf["base_url"],
                    api_key=sem_conf.get("api_key"),
                    knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
                    timeout=sem_conf.get("timeout", 30),
                )
                semantic = SemanticMemory(
                    rag_service=rag_service,
                    knowledge_base_ids=sem_conf.get("knowledge_base_ids", []),
                    search_mode=sem_conf.get("search_mode", "standard"),
                    use_rerank=sem_conf.get("use_rerank", True),
                    use_compression=sem_conf.get("use_compression", False),
                    kb_weights=sem_conf.get("kb_weights"),
                )

            self._memory_retriever = MemoryRetriever(
                working_memory=working,
                episodic_memory=episodic,
                semantic_memory=semantic,
            )

            # Inject into BaseAgent
            # NOTE(review): presumably BaseAgent reads _memory_retriever_ref — confirm.
            self._memory_retriever_ref = self._memory_retriever

            logger.info(f"ConfigDrivenAgent '{self.name}' initialized memory system")
        except Exception as e:
            # Best-effort: a memory failure must not prevent agent construction.
            logger.warning(f"Failed to initialize memory system: {e}")
            self._memory_retriever = None

    # Auto-register retrieve_knowledge tool if semantic memory is configured
    # (this runs whenever a retriever exists; nothing is registered when
    # create_retrieve_tool() returns a falsy value)
    if self._memory_retriever:
        retrieve_tool = self._memory_retriever.create_retrieve_tool()
        if retrieve_tool:
            self.use_tool(retrieve_tool)
|
||
|
||
def get_tools(self) -> list[Tool]:
    """Return the agent's bound tools followed by registry tools not yet bound.

    Tools are de-duplicated by ``name``; a tool already bound on the agent
    wins over a registry tool with the same name.
    """
    merged = [*self._tools]
    registry = self._tool_registry
    if not registry:
        return merged

    known_names = {bound.name for bound in merged}
    for candidate in registry.list_tools():
        if candidate.name in known_names:
            continue
        known_names.add(candidate.name)
        merged.append(candidate)
    return merged
|
||
|
||
def get_model(self) -> str:
    """Return the configured LLM model name, or ``"default"`` when none is set."""
    llm_settings = self._config.llm
    if not llm_settings:
        return "default"
    return llm_settings.get("model", "default")
|
||
|
||
def get_system_prompt(self) -> str | None:
    """Return the system prompt for this agent, including available tools.

    The prompt is the non-empty ``identity``, ``context``, ``instructions``,
    ``constraints`` and ``output_format`` sections of the prompt template
    (in that order), followed by a description of the available tools.

    Returns:
        The parts joined with newlines, or ``None`` when there is neither
        a prompt section nor a tool to describe.
    """
    parts = []

    if self._prompt_template:
        sections = self._prompt_template._sections
        for key in ("identity", "context", "instructions", "constraints", "output_format"):
            val = getattr(sections, key, "")
            if val:
                parts.append(val)

    # Append the tool description so the LLM knows what it can use.
    # get_tools() already merges the agent's bound tools with tools registered
    # on the registry after init, so the merge is not repeated here.
    all_tools = self.get_tools()
    if all_tools:
        tools_desc = self._build_tools_description(all_tools)
        parts.append(f"\n\n## 可用工具\n{tools_desc}")

    return "\n".join(parts) if parts else None
|
||
|
||
def _build_tools_description(self, tools: list | None = None) -> str:
    """Build a text description of available tools for the system prompt.

    Args:
        tools: Tools to describe. Only ``None`` falls back to the agent's
            own ``_tools``; an explicitly passed empty list yields an
            empty description.

    Returns:
        One ``- **name**: description`` line per tool, with the parameter
        names appended when the tool's input schema declares ``properties``.
    """
    # Test for None explicitly: ``tools or self._tools`` would silently
    # replace an empty list supplied by the caller with the agent's tools.
    if tools is None:
        tools = self._tools

    lines = []
    for tool in tools:
        line = f"- **{tool.name}**: {tool.description}"
        schema = tool.input_schema
        if schema and "properties" in schema:
            params = list(schema["properties"])
            if params:
                line += f" (参数: {', '.join(params)})"
        lines.append(line)
    return "\n".join(lines)
|
||
|
||
def get_react_config(self) -> dict:
    """Return the ReAct engine settings as ``max_steps`` / ``timeout_seconds``.

    Without a skill config the result is 10 steps and no timeout; otherwise
    both values are read from the skill config (timeout defaults to ``None``
    when the skill config has no such attribute).
    """
    skill = self._skill_config
    if not skill:
        return {"max_steps": 10, "timeout_seconds": None}
    return {
        "max_steps": skill.max_steps,
        "timeout_seconds": getattr(skill, "timeout_seconds", None),
    }
|
||
|
||
@property
def config(self) -> AgentConfig:
    """The configuration object this agent was constructed with."""
    current = self._config
    return current
|
||
|
||
@property
def prompt_template(self) -> PromptTemplate | None:
    """The prompt template held by the agent, or ``None`` when none is set."""
    template = self._prompt_template
    return template
|
||
|
||
async def on_task_complete(self, task: TaskMessage, output: dict) -> None:
    """Task-complete hook: feed a COMPLETED result to evolution when enabled.

    Does nothing when evolution is disabled. Any exception raised while
    building the result or evolving is logged as a warning, never propagated.
    """
    if not self._evolution_enabled:
        return

    try:
        from agentkit.core.protocol import TaskResult, TaskStatus
        from datetime import datetime, timezone

        # Both timestamps are taken at hook time.
        outcome = TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=TaskStatus.COMPLETED,
            output_data=output,
            error_message=None,
            started_at=datetime.now(timezone.utc),
            completed_at=datetime.now(timezone.utc),
        )
        await self.evolve_after_task(task, outcome)
    except Exception as e:
        logger.warning(f"Evolution after task failed: {e}")
|
||
|
||
async def on_task_failed(self, task: TaskMessage, error: Exception) -> None:
    """Task-failed hook: feed a FAILED result to evolution when enabled.

    Does nothing when evolution is disabled. Any exception raised while
    building the result or evolving is logged as a warning, never propagated.
    """
    if not self._evolution_enabled:
        return

    try:
        from agentkit.core.protocol import TaskResult, TaskStatus
        from datetime import datetime, timezone

        # The failure is recorded with the error's string form and no output.
        outcome = TaskResult(
            task_id=task.task_id,
            agent_name=self.name,
            status=TaskStatus.FAILED,
            output_data=None,
            error_message=str(error),
            started_at=datetime.now(timezone.utc),
            completed_at=datetime.now(timezone.utc),
        )
        await self.evolve_after_task(task, outcome)
    except Exception as e:
        logger.warning(f"Evolution after task failure failed: {e}")
|
||
|
||
def _bind_tools(self) -> None:
    """Bind every tool named in the configuration to this agent.

    Each name is looked up in the tool registry and attached with
    ``use_tool``. A failure for one tool is logged as a warning and does
    not stop the remaining tools from being bound.
    """
    for requested in self._config.tools:
        try:
            self.use_tool(self._tool_registry.get(requested))
            logger.info(f"ConfigDrivenAgent '{self.name}' bound tool '{requested}'")
        except Exception as e:
            logger.warning(
                f"ConfigDrivenAgent '{self.name}' failed to bind tool '{requested}': {e}"
            )
|
||
|
||
def _auto_set_current_module(self) -> None:
    """Create a ``Module`` from the agent configuration and make it current.

    The instruction text is the non-empty ``identity``, ``instructions`` and
    ``constraints`` prompt entries joined by newlines. Input and output
    fields come from the configured schemas, with non-string values
    converted via ``str``; placeholder fields are used when a schema is
    missing or empty. This gives prompt optimization a target to work with.
    """
    from agentkit.evolution.prompt_optimizer import Module, Signature

    def _as_text_fields(schema):
        # Schema values may be arbitrary objects; the signature gets strings.
        if not schema:
            return {}
        return {
            field_name: (field_info if isinstance(field_info, str) else str(field_info))
            for field_name, field_info in schema.items()
        }

    prompt = self._config.prompt or {}
    candidates = [prompt.get(key, "") for key in ("identity", "instructions", "constraints")]
    instruction = "\n".join(text for text in candidates if text)

    input_fields = _as_text_fields(self._config.input_schema)
    output_fields = _as_text_fields(self._config.output_schema)

    signature = Signature(
        input_fields=input_fields or {"input": "task input"},
        output_fields=output_fields or {"output": "task output"},
        instruction=instruction,
    )
    self.set_current_module(Module(name=self.name, signature=signature))
    logger.debug(f"Auto-set _current_module for agent '{self.name}'")
|
||
|
||
async def _register_mcp_tools(self) -> None:
    """Lazily register the tools exposed by the configured MCP servers.

    Called on first task execution so that the async MCP client calls can
    run. The "registered" flag is set before any server is contacted, so
    the registration is attempted at most once; a server that fails is
    logged as a warning and skipped.
    """
    if self._mcp_tools_registered:
        return
    if not self._mcp_servers:
        return

    self._mcp_tools_registered = True
    from agentkit.mcp.client import MCPClient

    for server_name, base_url in self._mcp_servers.items():
        try:
            client = MCPClient(server_url=base_url)
            self._mcp_clients.append(client)

            # Ask the server which tools it offers and wrap each named one.
            advertised = await client.list_tools()
            for tool_info in advertised:
                tool_name = tool_info.get("name", "")
                tool_desc = tool_info.get("description", "")
                if tool_name:
                    self.use_tool(client.as_tool(tool_name, tool_desc))
                    logger.info(
                        f"Agent '{self.name}' registered MCP tool '{tool_name}' "
                        f"from server '{server_name}'"
                    )
        except Exception as e:
            logger.warning(
                f"Agent '{self.name}' failed to connect to MCP server "
                f"'{server_name}' at {base_url}: {e}"
            )
|
||
|
||
def get_capabilities(self) -> AgentCapability:
    """Describe this agent as an ``AgentCapability``.

    Identity fields come from the agent itself; supported tasks,
    concurrency, description and schemas come from the configuration.
    """
    cfg = self._config
    return AgentCapability(
        agent_name=self.name,
        agent_type=self.agent_type,
        version=self.version,
        supported_tasks=cfg.supported_tasks,
        max_concurrency=cfg.max_concurrency,
        description=cfg.description,
        input_schema=cfg.input_schema,
        output_schema=cfg.output_schema,
    )
|
||
|
||
async def handle_task(self, task: TaskMessage) -> dict:
    """Execute a task according to execution_mode and task_mode.

    When a SkillConfig is present its ``execution_mode`` is tried first:

    - react / rewoo / plan_exec / reflexion: engine-based handlers, used
      only when a ReAct engine was created
    - direct: a single LLM call (no ReAct loop)
    - custom: the custom handler

    If there is no SkillConfig, or its mode did not select a handler, the
    legacy ``task_mode`` routing is used.

    Raises:
        ConfigValidationError: if ``task_mode`` is not a known mode.
    """
    # Lazy-register MCP tools on first task execution
    await self._register_mcp_tools()

    # (mode, handler method name, needs the ReAct engine)
    skill_routes = (
        ("react", "_handle_react", True),
        ("rewoo", "_handle_rewoo", True),
        ("plan_exec", "_handle_plan_exec", True),
        ("reflexion", "_handle_reflexion", True),
        ("direct", "_handle_direct", False),
        ("custom", "_handle_custom", False),
    )
    if self._skill_config:
        execution_mode = self._skill_config.execution_mode
        for mode, handler_name, needs_engine in skill_routes:
            if execution_mode != mode:
                continue
            if needs_engine and not self._react_engine:
                continue
            return await getattr(self, handler_name)(task)

    # Fall back to the legacy task_mode routing
    legacy_routes = (
        ("llm_generate", "_handle_llm_generate"),
        ("tool_call", "_handle_tool_call"),
        ("custom", "_handle_custom"),
    )
    for mode, handler_name in legacy_routes:
        if self._config.task_mode == mode:
            return await getattr(self, handler_name)(task)

    raise ConfigValidationError(
        agent_name=self.name,
        key="task_mode",
        reason=f"Unknown task_mode: {self._config.task_mode}",
    )
|
||
|
||
async def _handle_react(self, task: TaskMessage) -> dict:
    """ReAct mode: let the ReAct engine reason over the task autonomously.

    Renders the prompt template (or falls back to the stringified input),
    separates the system prompt from the remaining messages, runs the
    engine, and parses its output with ``_parse_llm_response``.
    """
    # With evolution enabled, make sure a current module exists.
    if self._evolution_enabled and self._current_module is None:
        self._auto_set_current_module()

    # Template variables: the task input plus its type.
    render_vars = task.input_data.copy()
    render_vars["task_type"] = task.task_type

    if self._prompt_template:
        rendered = self._prompt_template.render(variables=render_vars)
    else:
        rendered = [{"role": "user", "content": str(task.input_data)}]

    # The engine receives the system prompt separately from the conversation.
    system_prompt = None
    conversation = []
    for message in rendered:
        if message["role"] != "system":
            conversation.append(message)
        else:
            system_prompt = message["content"]

    # Guarantee at least one non-system message.
    if not conversation:
        conversation.append({"role": "user", "content": str(task.input_data)})

    # Cancellation token registered for this task, if any.
    cancellation_token = self._active_tokens.get(task.task_id)

    # Only a positive task timeout is forwarded; otherwise no timeout is passed.
    timeout_seconds = None
    if task.timeout_seconds > 0:
        timeout_seconds = float(task.timeout_seconds)

    memory_settings = self._config.memory
    retrieval_config = memory_settings.get("retrieval", {}) if memory_settings else {}
    available_tools = self.get_tools() or None
    model_name = self._config.llm.get("model", "default") if self._config.llm else "default"

    result = await self._react_engine.execute(
        messages=conversation,
        tools=available_tools,
        model=model_name,
        agent_name=self.name,
        task_type=task.task_type,
        system_prompt=system_prompt,
        memory_retriever=self._memory_retriever,
        task_id=task.task_id,
        retrieval_config=retrieval_config or None,
        cancellation_token=cancellation_token,
        timeout_seconds=timeout_seconds,
        compressor=self._compressor,
    )

    return self._parse_llm_response(result.output)
|
||
|
||
async def _handle_rewoo(self, task: TaskMessage) -> dict:
    """ReWOO mode: plan all tool calls upfront, then execute them in batch.

    A fresh ``ReWOOEngine`` is built for each call, configured from the
    skill config when one is present, and its output is parsed with
    ``_parse_llm_response``.
    """
    from agentkit.core.rewoo import ReWOOEngine

    render_vars = task.input_data.copy()
    render_vars["task_type"] = task.task_type

    if self._prompt_template:
        rendered = self._prompt_template.render(variables=render_vars)
    else:
        rendered = [{"role": "user", "content": str(task.input_data)}]

    # The engine receives the system prompt separately from the conversation.
    system_prompt = None
    conversation = []
    for message in rendered:
        if message["role"] != "system":
            conversation.append(message)
        else:
            system_prompt = message["content"]

    if not conversation:
        conversation.append({"role": "user", "content": str(task.input_data)})

    cancellation_token = self._active_tokens.get(task.task_id)

    # Only a positive task timeout is forwarded.
    timeout_seconds = None
    if task.timeout_seconds > 0:
        timeout_seconds = float(task.timeout_seconds)

    # Step budget and fallback strategies come from the skill config when present.
    skill = self._skill_config
    step_budget = skill.max_steps if skill else 5
    fallbacks = (skill.fallback_strategies or None) if skill else None

    engine = ReWOOEngine(
        llm_gateway=self._llm_gateway,
        max_plan_steps=step_budget,
        default_timeout=300.0,
        fallback_strategies=fallbacks,
    )

    available_tools = self.get_tools() or None
    model_name = self._config.llm.get("model", "default") if self._config.llm else "default"

    result = await engine.execute(
        messages=conversation,
        tools=available_tools,
        model=model_name,
        agent_name=self.name,
        task_type=task.task_type,
        system_prompt=system_prompt,
        task_id=task.task_id,
        cancellation_token=cancellation_token,
        timeout_seconds=timeout_seconds,
    )

    return self._parse_llm_response(result.output)
|
||
|
||
async def _handle_plan_exec(self, task: TaskMessage) -> dict:
    """Plan-and-Execute mode: decompose the task into a plan, execute the steps, replan if needed.

    A fresh ``PlanExecEngine`` (at most 2 replans, 300 s default timeout) is
    built for each call and its output is parsed with ``_parse_llm_response``.
    """
    from agentkit.core.plan_exec_engine import PlanExecEngine

    render_vars = task.input_data.copy()
    render_vars["task_type"] = task.task_type

    if self._prompt_template:
        rendered = self._prompt_template.render(variables=render_vars)
    else:
        rendered = [{"role": "user", "content": str(task.input_data)}]

    # The engine receives the system prompt separately from the conversation.
    system_prompt = None
    conversation = []
    for message in rendered:
        if message["role"] != "system":
            conversation.append(message)
        else:
            system_prompt = message["content"]

    if not conversation:
        conversation.append({"role": "user", "content": str(task.input_data)})

    cancellation_token = self._active_tokens.get(task.task_id)

    # Only a positive task timeout is forwarded.
    timeout_seconds = None
    if task.timeout_seconds > 0:
        timeout_seconds = float(task.timeout_seconds)

    engine = PlanExecEngine(
        llm_gateway=self._llm_gateway,
        max_replans=2,
        default_timeout=300.0,
    )

    available_tools = self.get_tools() or None
    model_name = self._config.llm.get("model", "default") if self._config.llm else "default"

    result = await engine.execute(
        messages=conversation,
        tools=available_tools,
        model=model_name,
        agent_name=self.name,
        task_type=task.task_type,
        system_prompt=system_prompt,
        task_id=task.task_id,
        cancellation_token=cancellation_token,
        timeout_seconds=timeout_seconds,
    )

    return self._parse_llm_response(result.output)
|
||
|
||
async def _handle_reflexion(self, task: TaskMessage) -> dict:
    """Reflexion mode: ReAct + Evaluate + Reflect + Retry for high-precision tasks.

    A fresh ``ReflexionEngine`` (up to 3 reflections, quality threshold 0.7,
    300 s default timeout) is built for each call and its output is parsed
    with ``_parse_llm_response``.
    """
    from agentkit.core.reflexion import ReflexionEngine

    render_vars = task.input_data.copy()
    render_vars["task_type"] = task.task_type

    if self._prompt_template:
        rendered = self._prompt_template.render(variables=render_vars)
    else:
        rendered = [{"role": "user", "content": str(task.input_data)}]

    # The engine receives the system prompt separately from the conversation.
    system_prompt = None
    conversation = []
    for message in rendered:
        if message["role"] != "system":
            conversation.append(message)
        else:
            system_prompt = message["content"]

    if not conversation:
        conversation.append({"role": "user", "content": str(task.input_data)})

    cancellation_token = self._active_tokens.get(task.task_id)

    # Only a positive task timeout is forwarded.
    timeout_seconds = None
    if task.timeout_seconds > 0:
        timeout_seconds = float(task.timeout_seconds)

    # The step budget comes from the skill config when one is present.
    step_budget = self._skill_config.max_steps if self._skill_config else 5

    engine = ReflexionEngine(
        llm_gateway=self._llm_gateway,
        max_steps=step_budget,
        max_reflections=3,
        quality_threshold=0.7,
        default_timeout=300.0,
    )

    available_tools = self.get_tools() or None
    model_name = self._config.llm.get("model", "default") if self._config.llm else "default"

    result = await engine.execute(
        messages=conversation,
        tools=available_tools,
        model=model_name,
        agent_name=self.name,
        task_type=task.task_type,
        system_prompt=system_prompt,
        task_id=task.task_id,
        cancellation_token=cancellation_token,
        timeout_seconds=timeout_seconds,
    )

    return self._parse_llm_response(result.output)
|
||
|
||
async def _handle_direct(self, task: TaskMessage) -> dict:
    """Direct mode: a single LLM call without the ReAct loop.

    Renders the full prompt template and makes one call through the
    LLM gateway. Without a gateway the task is delegated to
    ``_handle_llm_generate``.
    """
    gateway = self._llm_gateway
    if not gateway:
        return await self._handle_llm_generate(task)

    # Template variables: the task input plus its type.
    render_vars = task.input_data.copy()
    render_vars["task_type"] = task.task_type

    # Without a template the stringified input is sent as the only message.
    template = self._prompt_template
    if template:
        rendered = template.render(variables=render_vars)
    else:
        rendered = [{"role": "user", "content": str(task.input_data)}]

    llm_settings = self._config.llm
    model_name = llm_settings.get("model", "default") if llm_settings else "default"

    reply = await gateway.chat(
        messages=rendered,
        model=model_name,
        agent_name=self.name,
        task_type=task.task_type,
    )
    return self._parse_llm_response(reply.content)
|
||
|
||
async def handle_task_with_feedback(self, task: TaskMessage, feedback: str) -> dict:
    """Re-execute a task with quality feedback added to its input.

    A new ``TaskMessage`` is built that copies the original task's fields
    and carries a copy of its input extended with a ``quality_feedback``
    entry; the original task is not modified.
    """
    augmented_input = task.input_data.copy()
    augmented_input["quality_feedback"] = feedback

    # Fields carried over unchanged from the original task.
    carried_over = (
        "task_id",
        "agent_name",
        "task_type",
        "priority",
        "created_at",
        "callback_url",
        "timeout_seconds",
        "conversation_id",
    )
    fields = {field: getattr(task, field) for field in carried_over}
    fields["input_data"] = augmented_input

    return await self.handle_task(TaskMessage(**fields))
|
||
|
||
def _wrap_llm_client(self, llm_client: object):
    """Wrap a legacy ``llm_client`` into an ``LLMGateway``.

    The client is adapted through a private ``LLMProvider`` registered on a
    new gateway under the name ``"wrapped"``. Each request is forwarded to
    the first of these the client offers: a ``chat`` method, a ``create``
    method, or the client itself when it is callable.

    Args:
        llm_client: Legacy client. The chosen entry point is awaited and
            called with ``messages`` plus ``model``, ``temperature`` and
            ``max_tokens`` keyword arguments.

    Returns:
        The gateway wrapping ``llm_client``.

    Raises:
        ConfigValidationError: raised when a request is made (not at wrap
            time) if the client offers none of the supported entry points.
    """
    from agentkit.llm.gateway import LLMGateway
    from agentkit.llm.protocol import LLMProvider, LLMRequest, LLMResponse, TokenUsage

    # Captured so that adapter errors name the owning agent; inside the
    # adapter ``self`` is the provider, not the agent.
    owner_name = getattr(self, "name", "")

    class ClientProvider(LLMProvider):
        """Adapter: wraps legacy llm_client as an LLMProvider"""

        def __init__(self, raw_client: object):
            self._raw_client = raw_client

        async def chat(self, request: LLMRequest) -> LLMResponse:
            # Extra request options (when the request carries them) are
            # forwarded; the standard sampling parameters override them.
            kwargs = dict(request._extra) if hasattr(request, "_extra") else {}
            kwargs["model"] = request.model
            kwargs["temperature"] = request.temperature
            kwargs["max_tokens"] = request.max_tokens

            if hasattr(self._raw_client, "chat"):
                response = await self._raw_client.chat(messages=request.messages, **kwargs)
            elif hasattr(self._raw_client, "create"):
                response = await self._raw_client.create(messages=request.messages, **kwargs)
            elif callable(self._raw_client):
                response = await self._raw_client(messages=request.messages, **kwargs)
            else:
                raise ConfigValidationError(
                    agent_name=owner_name,
                    key="llm_client",
                    reason="LLM client must have 'chat'/'create' method or be callable",
                )

            # Normalize response to string
            if isinstance(response, str):
                content = response
            elif isinstance(response, dict):
                # Serialise the dict only when it has no "content" key. As a
                # ``dict.get`` default, json.dumps would run unconditionally
                # and could raise on a non-JSON-serialisable payload even
                # though "content" is present.
                if "content" in response:
                    content = response["content"]
                else:
                    content = json.dumps(response)
            elif hasattr(response, "content"):
                content = response.content
            else:
                content = str(response)

            # Token usage is reported as zero; this adapter does not read
            # any usage information from the legacy response.
            return LLMResponse(
                content=content,
                model=request.model,
                usage=TokenUsage(prompt_tokens=0, completion_tokens=0),
            )

    gateway = LLMGateway()
    gateway.register_provider("wrapped", ClientProvider(llm_client))
    return gateway
|
||
|
||
async def _handle_llm_generate(self, task: TaskMessage) -> dict:
|
||
"""LLM 生成模式:渲染 Prompt → 调用 LLM → 解析输出"""
|
||
if not self._prompt_template:
|
||
raise ConfigValidationError(
|
||
agent_name=self.name,
|
||
key="prompt",
|
||
reason="llm_generate mode requires prompt configuration",
|
||
)
|
||
|
||
# 渲染 Prompt
|
||
variables = task.input_data.copy()
|
||
variables["task_type"] = task.task_type
|
||
messages = self._prompt_template.render(variables=variables)
|
||
|
||
# 调用 LLM
|
||
if self._llm_client is None:
|
||
# 无 LLM 客户端时返回渲染后的 Prompt(降级模式)
|
||
return {
|
||
"mode": "llm_generate_no_client",
|
||
"messages": messages,
|
||
"note": "No LLM client configured, returning rendered prompt",
|
||
}
|
||
|
||
# 使用配置的 LLM 参数
|
||
llm_params = self._config.llm.copy()
|
||
response = await self._call_llm(messages, **llm_params)
|
||
return self._parse_llm_response(response)
|
||
|
||
async def _handle_tool_call(self, task: TaskMessage) -> dict:
|
||
"""工具调用模式:调用指定 Tool 并返回结果"""
|
||
if not self._tools:
|
||
raise ConfigValidationError(
|
||
agent_name=self.name,
|
||
key="tools",
|
||
reason="tool_call mode requires at least one bound tool",
|
||
)
|
||
|
||
# 使用第一个绑定的工具(或根据 task_type 匹配)
|
||
tool = self._resolve_tool(task)
|
||
result = await tool.safe_execute(**task.input_data)
|
||
return result
|
||
|
||
async def _handle_custom(self, task: TaskMessage) -> dict:
|
||
"""自定义模式:调用注册的 handler 函数"""
|
||
handler_name = self._config.custom_handler
|
||
if handler_name not in self._custom_handlers:
|
||
# 尝试动态导入
|
||
handler = self._import_handler(handler_name)
|
||
self._custom_handlers[handler_name] = handler
|
||
|
||
handler = self._custom_handlers[handler_name]
|
||
result = await handler(task)
|
||
if not isinstance(result, dict):
|
||
raise ConfigValidationError(
|
||
agent_name=self.name,
|
||
key="custom_handler",
|
||
reason=f"Custom handler '{handler_name}' must return dict, got {type(result)}",
|
||
)
|
||
return result
|
||
|
||
def _resolve_tool(self, task: TaskMessage) -> Tool:
|
||
"""根据任务类型解析要使用的工具"""
|
||
# 优先匹配 task_type 与 tool name
|
||
for tool in self._tools:
|
||
if task.task_type in tool.name or task.task_type.replace("_", "-") in tool.name:
|
||
return tool
|
||
|
||
# 回退到第一个工具
|
||
return self._tools[0]
|
||
|
||
async def _call_llm(self, messages: list[dict[str, str]], **kwargs) -> str:
|
||
"""调用 LLM 客户端"""
|
||
model = kwargs.pop("model", "default")
|
||
temperature = kwargs.pop("temperature", 0.7)
|
||
max_tokens = kwargs.pop("max_tokens", 2000)
|
||
|
||
# 标准化 LLM 客户端调用接口
|
||
if hasattr(self._llm_client, "chat"):
|
||
response = await self._llm_client.chat(
|
||
messages=messages,
|
||
model=model,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
**kwargs,
|
||
)
|
||
elif hasattr(self._llm_client, "create"):
|
||
response = await self._llm_client.create(
|
||
messages=messages,
|
||
model=model,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
**kwargs,
|
||
)
|
||
elif callable(self._llm_client):
|
||
response = await self._llm_client(
|
||
messages=messages,
|
||
model=model,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
**kwargs,
|
||
)
|
||
else:
|
||
raise ConfigValidationError(
|
||
agent_name=self.name,
|
||
key="llm_client",
|
||
reason="LLM client must have 'chat'/'create' method or be callable",
|
||
)
|
||
|
||
# 提取文本内容
|
||
if isinstance(response, str):
|
||
return response
|
||
if isinstance(response, dict):
|
||
return response.get("content", str(response))
|
||
if hasattr(response, "content"):
|
||
return response.content
|
||
return str(response)
|
||
|
||
def _parse_llm_response(self, response: str) -> dict:
|
||
"""解析 LLM 响应为 dict"""
|
||
# 尝试直接解析 JSON
|
||
try:
|
||
return json.loads(response)
|
||
except (json.JSONDecodeError, TypeError):
|
||
pass
|
||
|
||
# 尝试提取 JSON 块
|
||
import re
|
||
|
||
json_match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", response, re.DOTALL)
|
||
if json_match:
|
||
try:
|
||
return json.loads(json_match.group(1))
|
||
except (json.JSONDecodeError, TypeError):
|
||
pass
|
||
|
||
# 降级:包装为文本结果
|
||
return {"text": response}
|
||
|
||
def _import_handler(self, dotted_path: str) -> Callable[..., Coroutine]:
|
||
"""动态导入自定义 handler"""
|
||
# Security: validate module prefix to prevent arbitrary code execution
|
||
if not any(dotted_path.startswith(prefix) for prefix in self._ALLOWED_HANDLER_PREFIXES):
|
||
raise ConfigValidationError(
|
||
agent_name=self.name,
|
||
key="custom_handler",
|
||
reason=f"Handler '{dotted_path}' is not in allowed module prefixes: {self._ALLOWED_HANDLER_PREFIXES}",
|
||
)
|
||
|
||
try:
|
||
module_path, func_name = dotted_path.rsplit(".", 1)
|
||
import importlib
|
||
|
||
module = importlib.import_module(module_path)
|
||
handler = getattr(module, func_name)
|
||
if not callable(handler):
|
||
raise ConfigValidationError(
|
||
agent_name=self.name,
|
||
key="custom_handler",
|
||
reason=f"'{dotted_path}' is not callable",
|
||
)
|
||
return handler
|
||
except (ImportError, AttributeError, ValueError) as e:
|
||
raise ConfigValidationError(
|
||
agent_name=self.name,
|
||
key="custom_handler",
|
||
reason=f"Failed to import custom handler '{dotted_path}': {e}",
|
||
)
|