feat(core): add ConfigDrivenAgent with YAML-driven agent definition
- AgentConfig: YAML/dict config model with validation - ConfigDrivenAgent: 3 task modes (llm_generate, tool_call, custom) - StandaloneRunner: auto-discover YAML configs and build agents - 25 new tests covering all modes and edge cases - Total: 56 tests passing
This commit is contained in:
parent
2ddffcdf37
commit
5a90824c77
|
|
@ -1,6 +1,7 @@
|
||||||
"""Fischer AgentKit - Unified Agent Framework"""
|
"""Fischer AgentKit - Unified Agent Framework"""
|
||||||
|
|
||||||
from agentkit.core.base import BaseAgent
|
from agentkit.core.base import BaseAgent
|
||||||
|
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
|
||||||
from agentkit.core.protocol import (
|
from agentkit.core.protocol import (
|
||||||
AgentCapability,
|
AgentCapability,
|
||||||
AgentStatus,
|
AgentStatus,
|
||||||
|
|
@ -15,6 +16,8 @@ __version__ = "0.1.0"
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseAgent",
|
"BaseAgent",
|
||||||
|
"AgentConfig",
|
||||||
|
"ConfigDrivenAgent",
|
||||||
"AgentCapability",
|
"AgentCapability",
|
||||||
"AgentStatus",
|
"AgentStatus",
|
||||||
"HandoffMessage",
|
"HandoffMessage",
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""AgentKit Core - 基础组件"""
|
"""AgentKit Core - 基础组件"""
|
||||||
|
|
||||||
from agentkit.core.base import BaseAgent
|
from agentkit.core.base import BaseAgent
|
||||||
|
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
|
||||||
from agentkit.core.exceptions import (
|
from agentkit.core.exceptions import (
|
||||||
AgentAlreadyRegisteredError,
|
AgentAlreadyRegisteredError,
|
||||||
AgentFrameworkError,
|
AgentFrameworkError,
|
||||||
|
|
@ -33,6 +34,8 @@ from agentkit.core.protocol import (
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseAgent",
|
"BaseAgent",
|
||||||
|
"AgentConfig",
|
||||||
|
"ConfigDrivenAgent",
|
||||||
"AgentCapability",
|
"AgentCapability",
|
||||||
"AgentStatus",
|
"AgentStatus",
|
||||||
"AgentFrameworkError",
|
"AgentFrameworkError",
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,421 @@
|
||||||
|
"""ConfigDrivenAgent - 配置驱动的 Agent 定义
|
||||||
|
|
||||||
|
核心设计:
|
||||||
|
- 从 YAML/Dict 配置自动组装 Agent(Prompt + LLM + Tool + Memory)
|
||||||
|
- 支持三种任务模式:llm_generate / tool_call / custom
|
||||||
|
- 新增 Agent 从写 150 行代码降为 10-20 行配置
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any, Callable, Coroutine
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from agentkit.core.base import BaseAgent
|
||||||
|
from agentkit.core.exceptions import ConfigValidationError
|
||||||
|
from agentkit.core.protocol import AgentCapability, TaskMessage
|
||||||
|
from agentkit.prompts.section import PromptSection
|
||||||
|
from agentkit.prompts.template import PromptTemplate
|
||||||
|
from agentkit.tools.base import Tool
|
||||||
|
from agentkit.tools.registry import ToolRegistry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentConfig:
|
||||||
|
"""Agent 配置模型,从 YAML 或 Dict 构建"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
agent_type: str,
|
||||||
|
version: str = "1.0.0",
|
||||||
|
description: str = "",
|
||||||
|
task_mode: str = "llm_generate",
|
||||||
|
supported_tasks: list[str] | None = None,
|
||||||
|
max_concurrency: int = 1,
|
||||||
|
input_schema: dict[str, Any] | None = None,
|
||||||
|
output_schema: dict[str, Any] | None = None,
|
||||||
|
prompt: dict[str, str] | None = None,
|
||||||
|
llm: dict[str, Any] | None = None,
|
||||||
|
tools: list[str] | None = None,
|
||||||
|
memory: dict[str, Any] | None = None,
|
||||||
|
custom_handler: str | None = None,
|
||||||
|
):
|
||||||
|
self.name = name
|
||||||
|
self.agent_type = agent_type
|
||||||
|
self.version = version
|
||||||
|
self.description = description
|
||||||
|
self.task_mode = task_mode
|
||||||
|
self.supported_tasks = supported_tasks or [agent_type]
|
||||||
|
self.max_concurrency = max_concurrency
|
||||||
|
self.input_schema = input_schema
|
||||||
|
self.output_schema = output_schema
|
||||||
|
self.prompt = prompt or {}
|
||||||
|
self.llm = llm or {}
|
||||||
|
self.tools = tools or []
|
||||||
|
self.memory = memory or {}
|
||||||
|
self.custom_handler = custom_handler
|
||||||
|
|
||||||
|
self._validate()
|
||||||
|
|
||||||
|
def _validate(self) -> None:
|
||||||
|
"""校验配置合法性"""
|
||||||
|
valid_modes = {"llm_generate", "tool_call", "custom"}
|
||||||
|
if self.task_mode not in valid_modes:
|
||||||
|
raise ConfigValidationError(
|
||||||
|
agent_name=self.name,
|
||||||
|
key="task_mode",
|
||||||
|
reason=f"Invalid task_mode '{self.task_mode}', must be one of {valid_modes}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.task_mode == "llm_generate" and not self.prompt:
|
||||||
|
raise ConfigValidationError(
|
||||||
|
agent_name=self.name,
|
||||||
|
key="prompt",
|
||||||
|
reason="llm_generate mode requires 'prompt' configuration",
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.task_mode == "tool_call" and not self.tools:
|
||||||
|
raise ConfigValidationError(
|
||||||
|
agent_name=self.name,
|
||||||
|
key="tools",
|
||||||
|
reason="tool_call mode requires at least one tool in 'tools' list",
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.task_mode == "custom" and not self.custom_handler:
|
||||||
|
raise ConfigValidationError(
|
||||||
|
agent_name=self.name,
|
||||||
|
key="custom_handler",
|
||||||
|
reason="custom mode requires 'custom_handler' (dotted path to callable)",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: dict[str, Any]) -> "AgentConfig":
|
||||||
|
"""从字典创建配置"""
|
||||||
|
return cls(
|
||||||
|
name=data["name"],
|
||||||
|
agent_type=data["agent_type"],
|
||||||
|
version=data.get("version", "1.0.0"),
|
||||||
|
description=data.get("description", ""),
|
||||||
|
task_mode=data.get("task_mode", "llm_generate"),
|
||||||
|
supported_tasks=data.get("supported_tasks"),
|
||||||
|
max_concurrency=data.get("max_concurrency", 1),
|
||||||
|
input_schema=data.get("input_schema"),
|
||||||
|
output_schema=data.get("output_schema"),
|
||||||
|
prompt=data.get("prompt"),
|
||||||
|
llm=data.get("llm"),
|
||||||
|
tools=data.get("tools"),
|
||||||
|
memory=data.get("memory"),
|
||||||
|
custom_handler=data.get("custom_handler"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_yaml(cls, path: str) -> "AgentConfig":
|
||||||
|
"""从 YAML 文件加载配置"""
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
data = yaml.safe_load(f)
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
raise ConfigValidationError(
|
||||||
|
agent_name="unknown",
|
||||||
|
key="config",
|
||||||
|
reason=f"YAML config must be a mapping, got {type(data)}",
|
||||||
|
)
|
||||||
|
return cls.from_dict(data)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""序列化为字典"""
|
||||||
|
d = {
|
||||||
|
"name": self.name,
|
||||||
|
"agent_type": self.agent_type,
|
||||||
|
"version": self.version,
|
||||||
|
"description": self.description,
|
||||||
|
"task_mode": self.task_mode,
|
||||||
|
"supported_tasks": self.supported_tasks,
|
||||||
|
"max_concurrency": self.max_concurrency,
|
||||||
|
}
|
||||||
|
if self.input_schema:
|
||||||
|
d["input_schema"] = self.input_schema
|
||||||
|
if self.output_schema:
|
||||||
|
d["output_schema"] = self.output_schema
|
||||||
|
if self.prompt:
|
||||||
|
d["prompt"] = self.prompt
|
||||||
|
if self.llm:
|
||||||
|
d["llm"] = self.llm
|
||||||
|
if self.tools:
|
||||||
|
d["tools"] = self.tools
|
||||||
|
if self.memory:
|
||||||
|
d["memory"] = self.memory
|
||||||
|
if self.custom_handler:
|
||||||
|
d["custom_handler"] = self.custom_handler
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigDrivenAgent(BaseAgent):
|
||||||
|
"""配置驱动的 Agent
|
||||||
|
|
||||||
|
从 YAML/Dict 配置自动组装,支持三种任务模式:
|
||||||
|
- llm_generate: 渲染 Prompt → 调用 LLM → 解析 JSON 输出
|
||||||
|
- tool_call: 调用注册的 Tool 并返回结果
|
||||||
|
- custom: 自定义 handler 函数
|
||||||
|
|
||||||
|
示例 YAML 配置::
|
||||||
|
|
||||||
|
name: content_generator
|
||||||
|
agent_type: content_generation
|
||||||
|
task_mode: llm_generate
|
||||||
|
description: "内容生成 Agent"
|
||||||
|
prompt:
|
||||||
|
identity: "你是一个专业的内容生成助手"
|
||||||
|
instructions: "根据用户需求生成高质量内容"
|
||||||
|
output_format: "以 JSON 格式输出"
|
||||||
|
llm:
|
||||||
|
model: "gpt-4"
|
||||||
|
temperature: 0.7
|
||||||
|
tools:
|
||||||
|
- retrieve_knowledge
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: AgentConfig,
|
||||||
|
tool_registry: ToolRegistry | None = None,
|
||||||
|
llm_client: Any = None,
|
||||||
|
custom_handlers: dict[str, Callable[..., Coroutine]] | None = None,
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
name=config.name,
|
||||||
|
agent_type=config.agent_type,
|
||||||
|
version=config.version,
|
||||||
|
)
|
||||||
|
self._config = config
|
||||||
|
self._tool_registry = tool_registry or ToolRegistry()
|
||||||
|
self._llm_client = llm_client
|
||||||
|
self._custom_handlers = custom_handlers or {}
|
||||||
|
self._prompt_template: PromptTemplate | None = None
|
||||||
|
|
||||||
|
# 从配置构建 Prompt 模板
|
||||||
|
if config.prompt:
|
||||||
|
sections = PromptSection(
|
||||||
|
identity=config.prompt.get("identity", ""),
|
||||||
|
context=config.prompt.get("context", ""),
|
||||||
|
instructions=config.prompt.get("instructions", ""),
|
||||||
|
constraints=config.prompt.get("constraints", ""),
|
||||||
|
output_format=config.prompt.get("output_format", ""),
|
||||||
|
examples=config.prompt.get("examples", ""),
|
||||||
|
)
|
||||||
|
self._prompt_template = PromptTemplate(
|
||||||
|
sections=sections,
|
||||||
|
name=config.name,
|
||||||
|
version=config.version,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 从配置绑定 Tool
|
||||||
|
self._bind_tools()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def config(self) -> AgentConfig:
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prompt_template(self) -> PromptTemplate | None:
|
||||||
|
return self._prompt_template
|
||||||
|
|
||||||
|
def _bind_tools(self) -> None:
|
||||||
|
"""根据配置绑定工具"""
|
||||||
|
for tool_name in self._config.tools:
|
||||||
|
try:
|
||||||
|
tool = self._tool_registry.get(tool_name)
|
||||||
|
self.use_tool(tool)
|
||||||
|
logger.info(f"ConfigDrivenAgent '{self.name}' bound tool '{tool_name}'")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(
|
||||||
|
f"ConfigDrivenAgent '{self.name}' failed to bind tool '{tool_name}': {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_capabilities(self) -> AgentCapability:
|
||||||
|
return AgentCapability(
|
||||||
|
agent_name=self.name,
|
||||||
|
agent_type=self.agent_type,
|
||||||
|
version=self.version,
|
||||||
|
supported_tasks=self._config.supported_tasks,
|
||||||
|
max_concurrency=self._config.max_concurrency,
|
||||||
|
description=self._config.description,
|
||||||
|
input_schema=self._config.input_schema,
|
||||||
|
output_schema=self._config.output_schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def handle_task(self, task: TaskMessage) -> dict:
|
||||||
|
"""根据 task_mode 执行任务"""
|
||||||
|
if self._config.task_mode == "llm_generate":
|
||||||
|
return await self._handle_llm_generate(task)
|
||||||
|
elif self._config.task_mode == "tool_call":
|
||||||
|
return await self._handle_tool_call(task)
|
||||||
|
elif self._config.task_mode == "custom":
|
||||||
|
return await self._handle_custom(task)
|
||||||
|
else:
|
||||||
|
raise ConfigValidationError(
|
||||||
|
agent_name=self.name,
|
||||||
|
key="task_mode",
|
||||||
|
reason=f"Unknown task_mode: {self._config.task_mode}",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_llm_generate(self, task: TaskMessage) -> dict:
|
||||||
|
"""LLM 生成模式:渲染 Prompt → 调用 LLM → 解析输出"""
|
||||||
|
if not self._prompt_template:
|
||||||
|
raise ConfigValidationError(
|
||||||
|
agent_name=self.name,
|
||||||
|
key="prompt",
|
||||||
|
reason="llm_generate mode requires prompt configuration",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 渲染 Prompt
|
||||||
|
variables = task.input_data.copy()
|
||||||
|
variables["task_type"] = task.task_type
|
||||||
|
messages = self._prompt_template.render(variables=variables)
|
||||||
|
|
||||||
|
# 调用 LLM
|
||||||
|
if self._llm_client is None:
|
||||||
|
# 无 LLM 客户端时返回渲染后的 Prompt(降级模式)
|
||||||
|
return {
|
||||||
|
"mode": "llm_generate_no_client",
|
||||||
|
"messages": messages,
|
||||||
|
"note": "No LLM client configured, returning rendered prompt",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 使用配置的 LLM 参数
|
||||||
|
llm_params = self._config.llm.copy()
|
||||||
|
response = await self._call_llm(messages, **llm_params)
|
||||||
|
return self._parse_llm_response(response)
|
||||||
|
|
||||||
|
async def _handle_tool_call(self, task: TaskMessage) -> dict:
|
||||||
|
"""工具调用模式:调用指定 Tool 并返回结果"""
|
||||||
|
if not self._tools:
|
||||||
|
raise ConfigValidationError(
|
||||||
|
agent_name=self.name,
|
||||||
|
key="tools",
|
||||||
|
reason="tool_call mode requires at least one bound tool",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 使用第一个绑定的工具(或根据 task_type 匹配)
|
||||||
|
tool = self._resolve_tool(task)
|
||||||
|
result = await tool.safe_execute(**task.input_data)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def _handle_custom(self, task: TaskMessage) -> dict:
|
||||||
|
"""自定义模式:调用注册的 handler 函数"""
|
||||||
|
handler_name = self._config.custom_handler
|
||||||
|
if handler_name not in self._custom_handlers:
|
||||||
|
# 尝试动态导入
|
||||||
|
handler = self._import_handler(handler_name)
|
||||||
|
self._custom_handlers[handler_name] = handler
|
||||||
|
|
||||||
|
handler = self._custom_handlers[handler_name]
|
||||||
|
result = await handler(task)
|
||||||
|
if not isinstance(result, dict):
|
||||||
|
raise ConfigValidationError(
|
||||||
|
agent_name=self.name,
|
||||||
|
key="custom_handler",
|
||||||
|
reason=f"Custom handler '{handler_name}' must return dict, got {type(result)}",
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _resolve_tool(self, task: TaskMessage) -> Tool:
|
||||||
|
"""根据任务类型解析要使用的工具"""
|
||||||
|
# 优先匹配 task_type 与 tool name
|
||||||
|
for tool in self._tools:
|
||||||
|
if task.task_type in tool.name or task.task_type.replace("_", "-") in tool.name:
|
||||||
|
return tool
|
||||||
|
|
||||||
|
# 回退到第一个工具
|
||||||
|
return self._tools[0]
|
||||||
|
|
||||||
|
async def _call_llm(self, messages: list[dict[str, str]], **kwargs) -> str:
|
||||||
|
"""调用 LLM 客户端"""
|
||||||
|
model = kwargs.pop("model", "gpt-4")
|
||||||
|
temperature = kwargs.pop("temperature", 0.7)
|
||||||
|
max_tokens = kwargs.pop("max_tokens", 2000)
|
||||||
|
|
||||||
|
# 标准化 LLM 客户端调用接口
|
||||||
|
if hasattr(self._llm_client, "chat"):
|
||||||
|
response = await self._llm_client.chat(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
elif hasattr(self._llm_client, "create"):
|
||||||
|
response = await self._llm_client.create(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
elif callable(self._llm_client):
|
||||||
|
response = await self._llm_client(
|
||||||
|
messages=messages,
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ConfigValidationError(
|
||||||
|
agent_name=self.name,
|
||||||
|
key="llm_client",
|
||||||
|
reason="LLM client must have 'chat'/'create' method or be callable",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 提取文本内容
|
||||||
|
if isinstance(response, str):
|
||||||
|
return response
|
||||||
|
if isinstance(response, dict):
|
||||||
|
return response.get("content", str(response))
|
||||||
|
if hasattr(response, "content"):
|
||||||
|
return response.content
|
||||||
|
return str(response)
|
||||||
|
|
||||||
|
def _parse_llm_response(self, response: str) -> dict:
|
||||||
|
"""解析 LLM 响应为 dict"""
|
||||||
|
import json
|
||||||
|
|
||||||
|
# 尝试直接解析 JSON
|
||||||
|
try:
|
||||||
|
return json.loads(response)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 尝试提取 JSON 块
|
||||||
|
import re
|
||||||
|
json_match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", response, re.DOTALL)
|
||||||
|
if json_match:
|
||||||
|
try:
|
||||||
|
return json.loads(json_match.group(1))
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
# 降级:包装为文本结果
|
||||||
|
return {"text": response}
|
||||||
|
|
||||||
|
def _import_handler(self, dotted_path: str) -> Callable[..., Coroutine]:
|
||||||
|
"""动态导入自定义 handler"""
|
||||||
|
try:
|
||||||
|
module_path, func_name = dotted_path.rsplit(".", 1)
|
||||||
|
import importlib
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
handler = getattr(module, func_name)
|
||||||
|
if not callable(handler):
|
||||||
|
raise ConfigValidationError(
|
||||||
|
agent_name=self.name,
|
||||||
|
key="custom_handler",
|
||||||
|
reason=f"'{dotted_path}' is not callable",
|
||||||
|
)
|
||||||
|
return handler
|
||||||
|
except (ImportError, AttributeError, ValueError) as e:
|
||||||
|
raise ConfigValidationError(
|
||||||
|
agent_name=self.name,
|
||||||
|
key="custom_handler",
|
||||||
|
reason=f"Failed to import custom handler '{dotted_path}': {e}",
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,173 @@
|
||||||
|
"""Standalone Runner - 自动发现并启动配置驱动的 Agent
|
||||||
|
|
||||||
|
扫描 agent_configs/ 目录下的 YAML 文件,自动注册和启动 Agent。
|
||||||
|
支持命令行启动:python -m agentkit.core.standalone
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
|
||||||
|
from agentkit.tools.function_tool import FunctionTool
|
||||||
|
from agentkit.tools.registry import ToolRegistry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_CONFIG_DIR = "agent_configs"
|
||||||
|
|
||||||
|
|
||||||
|
class StandaloneRunner:
|
||||||
|
"""自动发现并启动配置驱动的 Agent
|
||||||
|
|
||||||
|
用法::
|
||||||
|
|
||||||
|
runner = StandaloneRunner(config_dir="agent_configs")
|
||||||
|
runner.add_tool(FunctionTool.from_func(my_tool_func))
|
||||||
|
await runner.start_all(redis_url="redis://localhost:6379")
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config_dir: str = DEFAULT_CONFIG_DIR,
|
||||||
|
tool_registry: ToolRegistry | None = None,
|
||||||
|
llm_client=None,
|
||||||
|
custom_handlers: dict | None = None,
|
||||||
|
):
|
||||||
|
self._config_dir = config_dir
|
||||||
|
self._tool_registry = tool_registry or ToolRegistry()
|
||||||
|
self._llm_client = llm_client
|
||||||
|
self._custom_handlers = custom_handlers or {}
|
||||||
|
self._agents: dict[str, ConfigDrivenAgent] = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def agents(self) -> dict[str, ConfigDrivenAgent]:
|
||||||
|
return self._agents
|
||||||
|
|
||||||
|
def add_tool(self, tool) -> "StandaloneRunner":
|
||||||
|
"""添加工具到注册中心"""
|
||||||
|
self._tool_registry.register(tool)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def add_custom_handler(self, name: str, handler) -> "StandaloneRunner":
|
||||||
|
"""注册自定义 handler"""
|
||||||
|
self._custom_handlers[name] = handler
|
||||||
|
return self
|
||||||
|
|
||||||
|
def discover_configs(self) -> list[AgentConfig]:
|
||||||
|
"""扫描配置目录,发现所有 YAML 配置"""
|
||||||
|
configs = []
|
||||||
|
config_path = Path(self._config_dir)
|
||||||
|
|
||||||
|
if not config_path.exists():
|
||||||
|
logger.warning(f"Config directory '{self._config_dir}' not found")
|
||||||
|
return configs
|
||||||
|
|
||||||
|
for yaml_file in sorted(config_path.glob("*.yaml")):
|
||||||
|
try:
|
||||||
|
config = AgentConfig.from_yaml(str(yaml_file))
|
||||||
|
configs.append(config)
|
||||||
|
logger.info(f"Discovered agent config: {config.name} from {yaml_file.name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load config '{yaml_file}': {e}")
|
||||||
|
|
||||||
|
for yml_file in sorted(config_path.glob("*.yml")):
|
||||||
|
try:
|
||||||
|
config = AgentConfig.from_yaml(str(yml_file))
|
||||||
|
configs.append(config)
|
||||||
|
logger.info(f"Discovered agent config: {config.name} from {yml_file.name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to load config '{yml_file}': {e}")
|
||||||
|
|
||||||
|
return configs
|
||||||
|
|
||||||
|
def build_agents(self) -> dict[str, ConfigDrivenAgent]:
|
||||||
|
"""从发现的配置构建 Agent 实例"""
|
||||||
|
configs = self.discover_configs()
|
||||||
|
self._agents.clear()
|
||||||
|
|
||||||
|
for config in configs:
|
||||||
|
try:
|
||||||
|
agent = ConfigDrivenAgent(
|
||||||
|
config=config,
|
||||||
|
tool_registry=self._tool_registry,
|
||||||
|
llm_client=self._llm_client,
|
||||||
|
custom_handlers=self._custom_handlers,
|
||||||
|
)
|
||||||
|
self._agents[config.name] = agent
|
||||||
|
logger.info(f"Built agent: {config.name} (mode={config.task_mode})")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to build agent '{config.name}': {e}")
|
||||||
|
|
||||||
|
return self._agents
|
||||||
|
|
||||||
|
async def start_all(self, redis_url: str = "") -> None:
|
||||||
|
"""启动所有已构建的 Agent"""
|
||||||
|
if not self._agents:
|
||||||
|
self.build_agents()
|
||||||
|
|
||||||
|
for name, agent in self._agents.items():
|
||||||
|
try:
|
||||||
|
await agent.start(redis_url=redis_url)
|
||||||
|
logger.info(f"Agent '{name}' started")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to start agent '{name}': {e}")
|
||||||
|
|
||||||
|
async def stop_all(self) -> None:
|
||||||
|
"""停止所有 Agent"""
|
||||||
|
for name, agent in self._agents.items():
|
||||||
|
try:
|
||||||
|
await agent.stop()
|
||||||
|
logger.info(f"Agent '{name}' stopped")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to stop agent '{name}': {e}")
|
||||||
|
|
||||||
|
async def execute_task(self, agent_name: str, task) -> dict | None:
|
||||||
|
"""在指定 Agent 上执行任务(本地模式)"""
|
||||||
|
if agent_name not in self._agents:
|
||||||
|
logger.error(f"Agent '{agent_name}' not found")
|
||||||
|
return None
|
||||||
|
|
||||||
|
agent = self._agents[agent_name]
|
||||||
|
result = await agent.execute(task)
|
||||||
|
return result.to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""命令行入口"""
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
config_dir = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_CONFIG_DIR
|
||||||
|
redis_url = os.environ.get("REDIS_URL", "")
|
||||||
|
|
||||||
|
runner = StandaloneRunner(config_dir=config_dir)
|
||||||
|
agents = runner.build_agents()
|
||||||
|
|
||||||
|
if not agents:
|
||||||
|
logger.error("No agents discovered. Check your config directory.")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
logger.info(f"Discovered {len(agents)} agent(s): {list(agents.keys())}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await runner.start_all(redis_url=redis_url)
|
||||||
|
logger.info("All agents started. Press Ctrl+C to stop.")
|
||||||
|
|
||||||
|
# 保持运行
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Shutting down...")
|
||||||
|
finally:
|
||||||
|
await runner.stop_all()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
|
|
@ -0,0 +1,356 @@
|
||||||
|
"""U2 测试: ConfigDrivenAgent + YAML 配置驱动"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
|
||||||
|
from agentkit.core.protocol import TaskMessage, TaskStatus
|
||||||
|
from agentkit.core.standalone import StandaloneRunner
|
||||||
|
from agentkit.tools.function_tool import FunctionTool
|
||||||
|
from agentkit.tools.registry import ToolRegistry
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_task(**overrides) -> TaskMessage:
|
||||||
|
defaults = dict(
|
||||||
|
task_id="test-task-001",
|
||||||
|
agent_name="test_agent",
|
||||||
|
task_type="generate",
|
||||||
|
priority=1,
|
||||||
|
input_data={"query": "hello"},
|
||||||
|
callback_url=None,
|
||||||
|
created_at=None,
|
||||||
|
)
|
||||||
|
defaults.update(overrides)
|
||||||
|
return TaskMessage.from_dict(defaults)
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_llm_config() -> dict:
|
||||||
|
return {
|
||||||
|
"name": "content_generator",
|
||||||
|
"agent_type": "content_generation",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"description": "内容生成 Agent",
|
||||||
|
"task_mode": "llm_generate",
|
||||||
|
"supported_tasks": ["content_generation"],
|
||||||
|
"max_concurrency": 2,
|
||||||
|
"prompt": {
|
||||||
|
"identity": "你是一个专业的内容生成助手",
|
||||||
|
"instructions": "根据用户需求生成高质量内容",
|
||||||
|
"output_format": "以 JSON 格式输出 {title, content}",
|
||||||
|
},
|
||||||
|
"llm": {
|
||||||
|
"model": "gpt-4",
|
||||||
|
"temperature": 0.7,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_tool_call_config() -> dict:
|
||||||
|
return {
|
||||||
|
"name": "citation_detector",
|
||||||
|
"agent_type": "citation_detection",
|
||||||
|
"task_mode": "tool_call",
|
||||||
|
"description": "引用检测 Agent",
|
||||||
|
"tools": ["check_citation"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _sample_custom_config() -> dict:
|
||||||
|
return {
|
||||||
|
"name": "monitor",
|
||||||
|
"agent_type": "monitoring",
|
||||||
|
"task_mode": "custom",
|
||||||
|
"description": "监控 Agent",
|
||||||
|
"custom_handler": "my_handler",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── AgentConfig 测试 ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentConfig:
|
||||||
|
def test_from_dict_llm_generate(self):
|
||||||
|
config = AgentConfig.from_dict(_sample_llm_config())
|
||||||
|
assert config.name == "content_generator"
|
||||||
|
assert config.task_mode == "llm_generate"
|
||||||
|
assert config.prompt["identity"] == "你是一个专业的内容生成助手"
|
||||||
|
assert config.llm["model"] == "gpt-4"
|
||||||
|
|
||||||
|
def test_from_dict_tool_call(self):
|
||||||
|
config = AgentConfig.from_dict(_sample_tool_call_config())
|
||||||
|
assert config.task_mode == "tool_call"
|
||||||
|
assert config.tools == ["check_citation"]
|
||||||
|
|
||||||
|
def test_from_dict_custom(self):
|
||||||
|
config = AgentConfig.from_dict(_sample_custom_config())
|
||||||
|
assert config.task_mode == "custom"
|
||||||
|
assert config.custom_handler == "my_handler"
|
||||||
|
|
||||||
|
def test_invalid_task_mode(self):
|
||||||
|
with pytest.raises(Exception, match="Invalid task_mode"):
|
||||||
|
AgentConfig(name="x", agent_type="x", task_mode="invalid_mode")
|
||||||
|
|
||||||
|
def test_llm_generate_requires_prompt(self):
|
||||||
|
with pytest.raises(Exception, match="llm_generate mode requires"):
|
||||||
|
AgentConfig(name="x", agent_type="x", task_mode="llm_generate", prompt=None)
|
||||||
|
|
||||||
|
def test_tool_call_requires_tools(self):
|
||||||
|
with pytest.raises(Exception, match="tool_call mode requires"):
|
||||||
|
AgentConfig(name="x", agent_type="x", task_mode="tool_call", tools=[])
|
||||||
|
|
||||||
|
def test_custom_requires_handler(self):
|
||||||
|
with pytest.raises(Exception, match="custom mode requires"):
|
||||||
|
AgentConfig(name="x", agent_type="x", task_mode="custom", custom_handler=None)
|
||||||
|
|
||||||
|
def test_from_yaml(self):
|
||||||
|
with tempfile.NamedTemporaryFile(
|
||||||
|
mode="w", suffix=".yaml", delete=False
|
||||||
|
) as f:
|
||||||
|
yaml.dump(_sample_llm_config(), f)
|
||||||
|
f.flush()
|
||||||
|
config = AgentConfig.from_yaml(f.name)
|
||||||
|
|
||||||
|
assert config.name == "content_generator"
|
||||||
|
assert config.task_mode == "llm_generate"
|
||||||
|
|
||||||
|
def test_to_dict_roundtrip(self):
|
||||||
|
original = _sample_llm_config()
|
||||||
|
config = AgentConfig.from_dict(original)
|
||||||
|
result = config.to_dict()
|
||||||
|
assert result["name"] == original["name"]
|
||||||
|
assert result["task_mode"] == original["task_mode"]
|
||||||
|
assert result["prompt"] == original["prompt"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── ConfigDrivenAgent 测试 ────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestConfigDrivenAgent:
|
||||||
|
async def test_llm_generate_no_client(self):
|
||||||
|
"""无 LLM 客户端时降级返回渲染后的 Prompt"""
|
||||||
|
config = AgentConfig.from_dict(_sample_llm_config())
|
||||||
|
agent = ConfigDrivenAgent(config=config)
|
||||||
|
|
||||||
|
task = _make_task()
|
||||||
|
result = await agent.handle_task(task)
|
||||||
|
|
||||||
|
assert result["mode"] == "llm_generate_no_client"
|
||||||
|
assert len(result["messages"]) == 2 # system + user
|
||||||
|
|
||||||
|
async def test_llm_generate_with_client(self):
|
||||||
|
"""有 LLM 客户端时调用 LLM 并解析结果"""
|
||||||
|
|
||||||
|
class MockLLMClient:
|
||||||
|
async def chat(self, messages, **kwargs):
|
||||||
|
return json.dumps({"title": "Test", "content": "Hello world"})
|
||||||
|
|
||||||
|
config = AgentConfig.from_dict(_sample_llm_config())
|
||||||
|
agent = ConfigDrivenAgent(config=config, llm_client=MockLLMClient())
|
||||||
|
|
||||||
|
task = _make_task()
|
||||||
|
result = await agent.handle_task(task)
|
||||||
|
|
||||||
|
assert result["title"] == "Test"
|
||||||
|
assert result["content"] == "Hello world"
|
||||||
|
|
||||||
|
async def test_llm_generate_with_markdown_json(self):
|
||||||
|
"""LLM 返回 markdown 代码块包裹的 JSON"""
|
||||||
|
|
||||||
|
class MockLLMClient:
|
||||||
|
async def chat(self, messages, **kwargs):
|
||||||
|
return '```json\n{"title": "Wrapped", "content": "In markdown"}\n```'
|
||||||
|
|
||||||
|
config = AgentConfig.from_dict(_sample_llm_config())
|
||||||
|
agent = ConfigDrivenAgent(config=config, llm_client=MockLLMClient())
|
||||||
|
|
||||||
|
task = _make_task()
|
||||||
|
result = await agent.handle_task(task)
|
||||||
|
|
||||||
|
assert result["title"] == "Wrapped"
|
||||||
|
|
||||||
|
async def test_llm_generate_fallback_text(self):
|
||||||
|
"""LLM 返回非 JSON 时降级为文本"""
|
||||||
|
|
||||||
|
class MockLLMClient:
|
||||||
|
async def chat(self, messages, **kwargs):
|
||||||
|
return "This is plain text response"
|
||||||
|
|
||||||
|
config = AgentConfig.from_dict(_sample_llm_config())
|
||||||
|
agent = ConfigDrivenAgent(config=config, llm_client=MockLLMClient())
|
||||||
|
|
||||||
|
task = _make_task()
|
||||||
|
result = await agent.handle_task(task)
|
||||||
|
|
||||||
|
assert result["text"] == "This is plain text response"
|
||||||
|
|
||||||
|
async def test_tool_call_mode(self):
|
||||||
|
"""tool_call 模式调用注册的 Tool"""
|
||||||
|
registry = ToolRegistry()
|
||||||
|
|
||||||
|
async def check_citation(url: str, **kwargs) -> dict:
|
||||||
|
return {"found": True, "url": url}
|
||||||
|
|
||||||
|
tool = FunctionTool(
|
||||||
|
name="check_citation",
|
||||||
|
description="Check citation",
|
||||||
|
func=check_citation,
|
||||||
|
)
|
||||||
|
registry.register(tool)
|
||||||
|
|
||||||
|
config = AgentConfig.from_dict(_sample_tool_call_config())
|
||||||
|
agent = ConfigDrivenAgent(config=config, tool_registry=registry)
|
||||||
|
|
||||||
|
task = _make_task(input_data={"url": "https://example.com"})
|
||||||
|
result = await agent.handle_task(task)
|
||||||
|
|
||||||
|
assert result["found"] is True
|
||||||
|
assert result["url"] == "https://example.com"
|
||||||
|
|
||||||
|
async def test_custom_mode(self):
|
||||||
|
"""custom 模式调用自定义 handler"""
|
||||||
|
config = AgentConfig.from_dict(_sample_custom_config())
|
||||||
|
|
||||||
|
async def my_handler(task):
|
||||||
|
return {"status": "monitored", "task_id": task.task_id}
|
||||||
|
|
||||||
|
agent = ConfigDrivenAgent(
|
||||||
|
config=config,
|
||||||
|
custom_handlers={"my_handler": my_handler},
|
||||||
|
)
|
||||||
|
|
||||||
|
task = _make_task()
|
||||||
|
result = await agent.handle_task(task)
|
||||||
|
|
||||||
|
assert result["status"] == "monitored"
|
||||||
|
assert result["task_id"] == "test-task-001"
|
||||||
|
|
||||||
|
async def test_execute_wraps_task_result(self):
|
||||||
|
"""execute() 自动包装 handle_task 结果为 TaskResult"""
|
||||||
|
config = AgentConfig.from_dict(_sample_llm_config())
|
||||||
|
agent = ConfigDrivenAgent(config=config)
|
||||||
|
|
||||||
|
task = _make_task()
|
||||||
|
result = await agent.execute(task)
|
||||||
|
|
||||||
|
assert result.status == TaskStatus.COMPLETED
|
||||||
|
assert result.output_data is not None
|
||||||
|
assert result.metrics["elapsed_seconds"] >= 0
|
||||||
|
|
||||||
|
def test_get_capabilities(self):
|
||||||
|
"""能力声明从配置正确构建"""
|
||||||
|
config = AgentConfig.from_dict(_sample_llm_config())
|
||||||
|
agent = ConfigDrivenAgent(config=config)
|
||||||
|
cap = agent.get_capabilities()
|
||||||
|
|
||||||
|
assert cap.agent_name == "content_generator"
|
||||||
|
assert cap.agent_type == "content_generation"
|
||||||
|
assert cap.max_concurrency == 2
|
||||||
|
assert "content_generation" in cap.supported_tasks
|
||||||
|
|
||||||
|
def test_prompt_template_rendering(self):
|
||||||
|
"""Prompt 模板正确渲染"""
|
||||||
|
config = AgentConfig.from_dict(_sample_llm_config())
|
||||||
|
agent = ConfigDrivenAgent(config=config)
|
||||||
|
|
||||||
|
assert agent.prompt_template is not None
|
||||||
|
messages = agent.prompt_template.render(variables={"query": "test"})
|
||||||
|
assert len(messages) == 2
|
||||||
|
assert "专业的内容生成助手" in messages[0]["content"]
|
||||||
|
|
||||||
|
async def test_callable_llm_client(self):
|
||||||
|
"""LLM 客户端为可调用对象"""
|
||||||
|
|
||||||
|
async def mock_llm(messages, **kwargs):
|
||||||
|
return '{"result": "from_callable"}'
|
||||||
|
|
||||||
|
config = AgentConfig.from_dict(_sample_llm_config())
|
||||||
|
agent = ConfigDrivenAgent(config=config, llm_client=mock_llm)
|
||||||
|
|
||||||
|
task = _make_task()
|
||||||
|
result = await agent.handle_task(task)
|
||||||
|
|
||||||
|
assert result["result"] == "from_callable"
|
||||||
|
|
||||||
|
|
||||||
|
# ── StandaloneRunner 测试 ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestStandaloneRunner:
|
||||||
|
def test_discover_configs(self):
|
||||||
|
"""自动发现 YAML 配置"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
for name in ["agent_a.yaml", "agent_b.yml"]:
|
||||||
|
config = {
|
||||||
|
"name": name.replace(".", "_"),
|
||||||
|
"agent_type": "test",
|
||||||
|
"task_mode": "llm_generate",
|
||||||
|
"prompt": {"identity": "test", "instructions": "test"},
|
||||||
|
}
|
||||||
|
with open(Path(tmpdir) / name, "w") as f:
|
||||||
|
yaml.dump(config, f)
|
||||||
|
|
||||||
|
runner = StandaloneRunner(config_dir=tmpdir)
|
||||||
|
configs = runner.discover_configs()
|
||||||
|
|
||||||
|
assert len(configs) == 2
|
||||||
|
|
||||||
|
def test_build_agents(self):
|
||||||
|
"""从配置构建 Agent 实例"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
config = _sample_llm_config()
|
||||||
|
with open(Path(tmpdir) / "test.yaml", "w") as f:
|
||||||
|
yaml.dump(config, f)
|
||||||
|
|
||||||
|
runner = StandaloneRunner(config_dir=tmpdir)
|
||||||
|
agents = runner.build_agents()
|
||||||
|
|
||||||
|
assert "content_generator" in agents
|
||||||
|
assert agents["content_generator"].config.task_mode == "llm_generate"
|
||||||
|
|
||||||
|
def test_add_tool(self):
|
||||||
|
"""添加工具到注册中心"""
|
||||||
|
runner = StandaloneRunner()
|
||||||
|
|
||||||
|
async def my_tool(x: int) -> dict:
|
||||||
|
return {"doubled": x * 2}
|
||||||
|
|
||||||
|
tool = FunctionTool(name="my_tool", description="Test tool", func=my_tool)
|
||||||
|
runner.add_tool(tool)
|
||||||
|
|
||||||
|
assert runner._tool_registry.has_tool("my_tool")
|
||||||
|
|
||||||
|
async def test_execute_task_local(self):
|
||||||
|
"""本地模式执行任务"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
config = _sample_llm_config()
|
||||||
|
with open(Path(tmpdir) / "test.yaml", "w") as f:
|
||||||
|
yaml.dump(config, f)
|
||||||
|
|
||||||
|
runner = StandaloneRunner(config_dir=tmpdir)
|
||||||
|
runner.build_agents()
|
||||||
|
|
||||||
|
task = _make_task(agent_name="content_generator")
|
||||||
|
result = await runner.execute_task("content_generator", task)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result["status"] == "completed"
|
||||||
|
|
||||||
|
def test_empty_config_dir(self):
|
||||||
|
"""空配置目录不报错"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
runner = StandaloneRunner(config_dir=tmpdir)
|
||||||
|
configs = runner.discover_configs()
|
||||||
|
assert len(configs) == 0
|
||||||
|
|
||||||
|
def test_nonexistent_config_dir(self):
|
||||||
|
"""不存在的配置目录不报错"""
|
||||||
|
runner = StandaloneRunner(config_dir="/nonexistent/path")
|
||||||
|
configs = runner.discover_configs()
|
||||||
|
assert len(configs) == 0
|
||||||
Loading…
Reference in New Issue