feat(core): add ConfigDrivenAgent with YAML-driven agent definition

- AgentConfig: YAML/dict config model with validation
- ConfigDrivenAgent: 3 task modes (llm_generate, tool_call, custom)
- StandaloneRunner: auto-discover YAML configs and build agents
- 25 new tests covering all modes and edge cases
- Total: 56 tests passing
This commit is contained in:
chiguyong 2026-06-04 22:39:25 +08:00
parent 2ddffcdf37
commit 5a90824c77
5 changed files with 956 additions and 0 deletions

View File

@ -1,6 +1,7 @@
"""Fischer AgentKit - Unified Agent Framework"""
from agentkit.core.base import BaseAgent
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
from agentkit.core.protocol import (
AgentCapability,
AgentStatus,
@ -15,6 +16,8 @@ __version__ = "0.1.0"
__all__ = [
"BaseAgent",
"AgentConfig",
"ConfigDrivenAgent",
"AgentCapability",
"AgentStatus",
"HandoffMessage",

View File

@ -1,6 +1,7 @@
"""AgentKit Core - 基础组件"""
from agentkit.core.base import BaseAgent
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
from agentkit.core.exceptions import (
AgentAlreadyRegisteredError,
AgentFrameworkError,
@ -33,6 +34,8 @@ from agentkit.core.protocol import (
__all__ = [
"BaseAgent",
"AgentConfig",
"ConfigDrivenAgent",
"AgentCapability",
"AgentStatus",
"AgentFrameworkError",

View File

@ -0,0 +1,421 @@
"""ConfigDrivenAgent - 配置驱动的 Agent 定义
核心设计
- YAML/Dict 配置自动组装 AgentPrompt + LLM + Tool + Memory
- 支持三种任务模式llm_generate / tool_call / custom
- 新增 Agent 从写 150 行代码降为 10-20 行配置
"""
import logging
from typing import Any, Callable, Coroutine
import yaml
from agentkit.core.base import BaseAgent
from agentkit.core.exceptions import ConfigValidationError
from agentkit.core.protocol import AgentCapability, TaskMessage
from agentkit.prompts.section import PromptSection
from agentkit.prompts.template import PromptTemplate
from agentkit.tools.base import Tool
from agentkit.tools.registry import ToolRegistry
logger = logging.getLogger(__name__)
class AgentConfig:
"""Agent 配置模型,从 YAML 或 Dict 构建"""
def __init__(
self,
name: str,
agent_type: str,
version: str = "1.0.0",
description: str = "",
task_mode: str = "llm_generate",
supported_tasks: list[str] | None = None,
max_concurrency: int = 1,
input_schema: dict[str, Any] | None = None,
output_schema: dict[str, Any] | None = None,
prompt: dict[str, str] | None = None,
llm: dict[str, Any] | None = None,
tools: list[str] | None = None,
memory: dict[str, Any] | None = None,
custom_handler: str | None = None,
):
self.name = name
self.agent_type = agent_type
self.version = version
self.description = description
self.task_mode = task_mode
self.supported_tasks = supported_tasks or [agent_type]
self.max_concurrency = max_concurrency
self.input_schema = input_schema
self.output_schema = output_schema
self.prompt = prompt or {}
self.llm = llm or {}
self.tools = tools or []
self.memory = memory or {}
self.custom_handler = custom_handler
self._validate()
def _validate(self) -> None:
"""校验配置合法性"""
valid_modes = {"llm_generate", "tool_call", "custom"}
if self.task_mode not in valid_modes:
raise ConfigValidationError(
agent_name=self.name,
key="task_mode",
reason=f"Invalid task_mode '{self.task_mode}', must be one of {valid_modes}",
)
if self.task_mode == "llm_generate" and not self.prompt:
raise ConfigValidationError(
agent_name=self.name,
key="prompt",
reason="llm_generate mode requires 'prompt' configuration",
)
if self.task_mode == "tool_call" and not self.tools:
raise ConfigValidationError(
agent_name=self.name,
key="tools",
reason="tool_call mode requires at least one tool in 'tools' list",
)
if self.task_mode == "custom" and not self.custom_handler:
raise ConfigValidationError(
agent_name=self.name,
key="custom_handler",
reason="custom mode requires 'custom_handler' (dotted path to callable)",
)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "AgentConfig":
"""从字典创建配置"""
return cls(
name=data["name"],
agent_type=data["agent_type"],
version=data.get("version", "1.0.0"),
description=data.get("description", ""),
task_mode=data.get("task_mode", "llm_generate"),
supported_tasks=data.get("supported_tasks"),
max_concurrency=data.get("max_concurrency", 1),
input_schema=data.get("input_schema"),
output_schema=data.get("output_schema"),
prompt=data.get("prompt"),
llm=data.get("llm"),
tools=data.get("tools"),
memory=data.get("memory"),
custom_handler=data.get("custom_handler"),
)
@classmethod
def from_yaml(cls, path: str) -> "AgentConfig":
"""从 YAML 文件加载配置"""
with open(path, "r", encoding="utf-8") as f:
data = yaml.safe_load(f)
if not isinstance(data, dict):
raise ConfigValidationError(
agent_name="unknown",
key="config",
reason=f"YAML config must be a mapping, got {type(data)}",
)
return cls.from_dict(data)
def to_dict(self) -> dict[str, Any]:
"""序列化为字典"""
d = {
"name": self.name,
"agent_type": self.agent_type,
"version": self.version,
"description": self.description,
"task_mode": self.task_mode,
"supported_tasks": self.supported_tasks,
"max_concurrency": self.max_concurrency,
}
if self.input_schema:
d["input_schema"] = self.input_schema
if self.output_schema:
d["output_schema"] = self.output_schema
if self.prompt:
d["prompt"] = self.prompt
if self.llm:
d["llm"] = self.llm
if self.tools:
d["tools"] = self.tools
if self.memory:
d["memory"] = self.memory
if self.custom_handler:
d["custom_handler"] = self.custom_handler
return d
class ConfigDrivenAgent(BaseAgent):
"""配置驱动的 Agent
YAML/Dict 配置自动组装支持三种任务模式
- llm_generate: 渲染 Prompt 调用 LLM 解析 JSON 输出
- tool_call: 调用注册的 Tool 并返回结果
- custom: 自定义 handler 函数
示例 YAML 配置::
name: content_generator
agent_type: content_generation
task_mode: llm_generate
description: "内容生成 Agent"
prompt:
identity: "你是一个专业的内容生成助手"
instructions: "根据用户需求生成高质量内容"
output_format: "以 JSON 格式输出"
llm:
model: "gpt-4"
temperature: 0.7
tools:
- retrieve_knowledge
"""
def __init__(
self,
config: AgentConfig,
tool_registry: ToolRegistry | None = None,
llm_client: Any = None,
custom_handlers: dict[str, Callable[..., Coroutine]] | None = None,
):
super().__init__(
name=config.name,
agent_type=config.agent_type,
version=config.version,
)
self._config = config
self._tool_registry = tool_registry or ToolRegistry()
self._llm_client = llm_client
self._custom_handlers = custom_handlers or {}
self._prompt_template: PromptTemplate | None = None
# 从配置构建 Prompt 模板
if config.prompt:
sections = PromptSection(
identity=config.prompt.get("identity", ""),
context=config.prompt.get("context", ""),
instructions=config.prompt.get("instructions", ""),
constraints=config.prompt.get("constraints", ""),
output_format=config.prompt.get("output_format", ""),
examples=config.prompt.get("examples", ""),
)
self._prompt_template = PromptTemplate(
sections=sections,
name=config.name,
version=config.version,
)
# 从配置绑定 Tool
self._bind_tools()
@property
def config(self) -> AgentConfig:
return self._config
@property
def prompt_template(self) -> PromptTemplate | None:
return self._prompt_template
def _bind_tools(self) -> None:
"""根据配置绑定工具"""
for tool_name in self._config.tools:
try:
tool = self._tool_registry.get(tool_name)
self.use_tool(tool)
logger.info(f"ConfigDrivenAgent '{self.name}' bound tool '{tool_name}'")
except Exception as e:
logger.warning(
f"ConfigDrivenAgent '{self.name}' failed to bind tool '{tool_name}': {e}"
)
def get_capabilities(self) -> AgentCapability:
return AgentCapability(
agent_name=self.name,
agent_type=self.agent_type,
version=self.version,
supported_tasks=self._config.supported_tasks,
max_concurrency=self._config.max_concurrency,
description=self._config.description,
input_schema=self._config.input_schema,
output_schema=self._config.output_schema,
)
async def handle_task(self, task: TaskMessage) -> dict:
"""根据 task_mode 执行任务"""
if self._config.task_mode == "llm_generate":
return await self._handle_llm_generate(task)
elif self._config.task_mode == "tool_call":
return await self._handle_tool_call(task)
elif self._config.task_mode == "custom":
return await self._handle_custom(task)
else:
raise ConfigValidationError(
agent_name=self.name,
key="task_mode",
reason=f"Unknown task_mode: {self._config.task_mode}",
)
async def _handle_llm_generate(self, task: TaskMessage) -> dict:
"""LLM 生成模式:渲染 Prompt → 调用 LLM → 解析输出"""
if not self._prompt_template:
raise ConfigValidationError(
agent_name=self.name,
key="prompt",
reason="llm_generate mode requires prompt configuration",
)
# 渲染 Prompt
variables = task.input_data.copy()
variables["task_type"] = task.task_type
messages = self._prompt_template.render(variables=variables)
# 调用 LLM
if self._llm_client is None:
# 无 LLM 客户端时返回渲染后的 Prompt降级模式
return {
"mode": "llm_generate_no_client",
"messages": messages,
"note": "No LLM client configured, returning rendered prompt",
}
# 使用配置的 LLM 参数
llm_params = self._config.llm.copy()
response = await self._call_llm(messages, **llm_params)
return self._parse_llm_response(response)
async def _handle_tool_call(self, task: TaskMessage) -> dict:
"""工具调用模式:调用指定 Tool 并返回结果"""
if not self._tools:
raise ConfigValidationError(
agent_name=self.name,
key="tools",
reason="tool_call mode requires at least one bound tool",
)
# 使用第一个绑定的工具(或根据 task_type 匹配)
tool = self._resolve_tool(task)
result = await tool.safe_execute(**task.input_data)
return result
async def _handle_custom(self, task: TaskMessage) -> dict:
"""自定义模式:调用注册的 handler 函数"""
handler_name = self._config.custom_handler
if handler_name not in self._custom_handlers:
# 尝试动态导入
handler = self._import_handler(handler_name)
self._custom_handlers[handler_name] = handler
handler = self._custom_handlers[handler_name]
result = await handler(task)
if not isinstance(result, dict):
raise ConfigValidationError(
agent_name=self.name,
key="custom_handler",
reason=f"Custom handler '{handler_name}' must return dict, got {type(result)}",
)
return result
def _resolve_tool(self, task: TaskMessage) -> Tool:
"""根据任务类型解析要使用的工具"""
# 优先匹配 task_type 与 tool name
for tool in self._tools:
if task.task_type in tool.name or task.task_type.replace("_", "-") in tool.name:
return tool
# 回退到第一个工具
return self._tools[0]
async def _call_llm(self, messages: list[dict[str, str]], **kwargs) -> str:
"""调用 LLM 客户端"""
model = kwargs.pop("model", "gpt-4")
temperature = kwargs.pop("temperature", 0.7)
max_tokens = kwargs.pop("max_tokens", 2000)
# 标准化 LLM 客户端调用接口
if hasattr(self._llm_client, "chat"):
response = await self._llm_client.chat(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
**kwargs,
)
elif hasattr(self._llm_client, "create"):
response = await self._llm_client.create(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
**kwargs,
)
elif callable(self._llm_client):
response = await self._llm_client(
messages=messages,
model=model,
temperature=temperature,
max_tokens=max_tokens,
**kwargs,
)
else:
raise ConfigValidationError(
agent_name=self.name,
key="llm_client",
reason="LLM client must have 'chat'/'create' method or be callable",
)
# 提取文本内容
if isinstance(response, str):
return response
if isinstance(response, dict):
return response.get("content", str(response))
if hasattr(response, "content"):
return response.content
return str(response)
def _parse_llm_response(self, response: str) -> dict:
"""解析 LLM 响应为 dict"""
import json
# 尝试直接解析 JSON
try:
return json.loads(response)
except (json.JSONDecodeError, TypeError):
pass
# 尝试提取 JSON 块
import re
json_match = re.search(r"```(?:json)?\s*\n?(.*?)\n?```", response, re.DOTALL)
if json_match:
try:
return json.loads(json_match.group(1))
except (json.JSONDecodeError, TypeError):
pass
# 降级:包装为文本结果
return {"text": response}
def _import_handler(self, dotted_path: str) -> Callable[..., Coroutine]:
"""动态导入自定义 handler"""
try:
module_path, func_name = dotted_path.rsplit(".", 1)
import importlib
module = importlib.import_module(module_path)
handler = getattr(module, func_name)
if not callable(handler):
raise ConfigValidationError(
agent_name=self.name,
key="custom_handler",
reason=f"'{dotted_path}' is not callable",
)
return handler
except (ImportError, AttributeError, ValueError) as e:
raise ConfigValidationError(
agent_name=self.name,
key="custom_handler",
reason=f"Failed to import custom handler '{dotted_path}': {e}",
)

View File

@ -0,0 +1,173 @@
"""Standalone Runner - 自动发现并启动配置驱动的 Agent
扫描 agent_configs/ 目录下的 YAML 文件自动注册和启动 Agent
支持命令行启动python -m agentkit.core.standalone
"""
import asyncio
import logging
import os
import sys
from pathlib import Path
import yaml
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.registry import ToolRegistry
logger = logging.getLogger(__name__)
DEFAULT_CONFIG_DIR = "agent_configs"
class StandaloneRunner:
"""自动发现并启动配置驱动的 Agent
用法::
runner = StandaloneRunner(config_dir="agent_configs")
runner.add_tool(FunctionTool.from_func(my_tool_func))
await runner.start_all(redis_url="redis://localhost:6379")
"""
def __init__(
self,
config_dir: str = DEFAULT_CONFIG_DIR,
tool_registry: ToolRegistry | None = None,
llm_client=None,
custom_handlers: dict | None = None,
):
self._config_dir = config_dir
self._tool_registry = tool_registry or ToolRegistry()
self._llm_client = llm_client
self._custom_handlers = custom_handlers or {}
self._agents: dict[str, ConfigDrivenAgent] = {}
@property
def agents(self) -> dict[str, ConfigDrivenAgent]:
return self._agents
def add_tool(self, tool) -> "StandaloneRunner":
"""添加工具到注册中心"""
self._tool_registry.register(tool)
return self
def add_custom_handler(self, name: str, handler) -> "StandaloneRunner":
"""注册自定义 handler"""
self._custom_handlers[name] = handler
return self
def discover_configs(self) -> list[AgentConfig]:
"""扫描配置目录,发现所有 YAML 配置"""
configs = []
config_path = Path(self._config_dir)
if not config_path.exists():
logger.warning(f"Config directory '{self._config_dir}' not found")
return configs
for yaml_file in sorted(config_path.glob("*.yaml")):
try:
config = AgentConfig.from_yaml(str(yaml_file))
configs.append(config)
logger.info(f"Discovered agent config: {config.name} from {yaml_file.name}")
except Exception as e:
logger.error(f"Failed to load config '{yaml_file}': {e}")
for yml_file in sorted(config_path.glob("*.yml")):
try:
config = AgentConfig.from_yaml(str(yml_file))
configs.append(config)
logger.info(f"Discovered agent config: {config.name} from {yml_file.name}")
except Exception as e:
logger.error(f"Failed to load config '{yml_file}': {e}")
return configs
def build_agents(self) -> dict[str, ConfigDrivenAgent]:
"""从发现的配置构建 Agent 实例"""
configs = self.discover_configs()
self._agents.clear()
for config in configs:
try:
agent = ConfigDrivenAgent(
config=config,
tool_registry=self._tool_registry,
llm_client=self._llm_client,
custom_handlers=self._custom_handlers,
)
self._agents[config.name] = agent
logger.info(f"Built agent: {config.name} (mode={config.task_mode})")
except Exception as e:
logger.error(f"Failed to build agent '{config.name}': {e}")
return self._agents
async def start_all(self, redis_url: str = "") -> None:
"""启动所有已构建的 Agent"""
if not self._agents:
self.build_agents()
for name, agent in self._agents.items():
try:
await agent.start(redis_url=redis_url)
logger.info(f"Agent '{name}' started")
except Exception as e:
logger.error(f"Failed to start agent '{name}': {e}")
async def stop_all(self) -> None:
"""停止所有 Agent"""
for name, agent in self._agents.items():
try:
await agent.stop()
logger.info(f"Agent '{name}' stopped")
except Exception as e:
logger.error(f"Failed to stop agent '{name}': {e}")
async def execute_task(self, agent_name: str, task) -> dict | None:
"""在指定 Agent 上执行任务(本地模式)"""
if agent_name not in self._agents:
logger.error(f"Agent '{agent_name}' not found")
return None
agent = self._agents[agent_name]
result = await agent.execute(task)
return result.to_dict()
async def main():
"""命令行入口"""
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
)
config_dir = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_CONFIG_DIR
redis_url = os.environ.get("REDIS_URL", "")
runner = StandaloneRunner(config_dir=config_dir)
agents = runner.build_agents()
if not agents:
logger.error("No agents discovered. Check your config directory.")
sys.exit(1)
logger.info(f"Discovered {len(agents)} agent(s): {list(agents.keys())}")
try:
await runner.start_all(redis_url=redis_url)
logger.info("All agents started. Press Ctrl+C to stop.")
# 保持运行
while True:
await asyncio.sleep(1)
except KeyboardInterrupt:
logger.info("Shutting down...")
finally:
await runner.stop_all()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,356 @@
"""U2 测试: ConfigDrivenAgent + YAML 配置驱动"""
import json
import tempfile
from pathlib import Path
import pytest
import yaml
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
from agentkit.core.protocol import TaskMessage, TaskStatus
from agentkit.core.standalone import StandaloneRunner
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.registry import ToolRegistry
# ── Fixtures ──────────────────────────────────────────────
def _make_task(**overrides) -> TaskMessage:
defaults = dict(
task_id="test-task-001",
agent_name="test_agent",
task_type="generate",
priority=1,
input_data={"query": "hello"},
callback_url=None,
created_at=None,
)
defaults.update(overrides)
return TaskMessage.from_dict(defaults)
def _sample_llm_config() -> dict:
return {
"name": "content_generator",
"agent_type": "content_generation",
"version": "1.0.0",
"description": "内容生成 Agent",
"task_mode": "llm_generate",
"supported_tasks": ["content_generation"],
"max_concurrency": 2,
"prompt": {
"identity": "你是一个专业的内容生成助手",
"instructions": "根据用户需求生成高质量内容",
"output_format": "以 JSON 格式输出 {title, content}",
},
"llm": {
"model": "gpt-4",
"temperature": 0.7,
},
}
def _sample_tool_call_config() -> dict:
return {
"name": "citation_detector",
"agent_type": "citation_detection",
"task_mode": "tool_call",
"description": "引用检测 Agent",
"tools": ["check_citation"],
}
def _sample_custom_config() -> dict:
return {
"name": "monitor",
"agent_type": "monitoring",
"task_mode": "custom",
"description": "监控 Agent",
"custom_handler": "my_handler",
}
# ── AgentConfig 测试 ──────────────────────────────────────
class TestAgentConfig:
def test_from_dict_llm_generate(self):
config = AgentConfig.from_dict(_sample_llm_config())
assert config.name == "content_generator"
assert config.task_mode == "llm_generate"
assert config.prompt["identity"] == "你是一个专业的内容生成助手"
assert config.llm["model"] == "gpt-4"
def test_from_dict_tool_call(self):
config = AgentConfig.from_dict(_sample_tool_call_config())
assert config.task_mode == "tool_call"
assert config.tools == ["check_citation"]
def test_from_dict_custom(self):
config = AgentConfig.from_dict(_sample_custom_config())
assert config.task_mode == "custom"
assert config.custom_handler == "my_handler"
def test_invalid_task_mode(self):
with pytest.raises(Exception, match="Invalid task_mode"):
AgentConfig(name="x", agent_type="x", task_mode="invalid_mode")
def test_llm_generate_requires_prompt(self):
with pytest.raises(Exception, match="llm_generate mode requires"):
AgentConfig(name="x", agent_type="x", task_mode="llm_generate", prompt=None)
def test_tool_call_requires_tools(self):
with pytest.raises(Exception, match="tool_call mode requires"):
AgentConfig(name="x", agent_type="x", task_mode="tool_call", tools=[])
def test_custom_requires_handler(self):
with pytest.raises(Exception, match="custom mode requires"):
AgentConfig(name="x", agent_type="x", task_mode="custom", custom_handler=None)
def test_from_yaml(self):
with tempfile.NamedTemporaryFile(
mode="w", suffix=".yaml", delete=False
) as f:
yaml.dump(_sample_llm_config(), f)
f.flush()
config = AgentConfig.from_yaml(f.name)
assert config.name == "content_generator"
assert config.task_mode == "llm_generate"
def test_to_dict_roundtrip(self):
original = _sample_llm_config()
config = AgentConfig.from_dict(original)
result = config.to_dict()
assert result["name"] == original["name"]
assert result["task_mode"] == original["task_mode"]
assert result["prompt"] == original["prompt"]
# ── ConfigDrivenAgent 测试 ────────────────────────────────
class TestConfigDrivenAgent:
async def test_llm_generate_no_client(self):
"""无 LLM 客户端时降级返回渲染后的 Prompt"""
config = AgentConfig.from_dict(_sample_llm_config())
agent = ConfigDrivenAgent(config=config)
task = _make_task()
result = await agent.handle_task(task)
assert result["mode"] == "llm_generate_no_client"
assert len(result["messages"]) == 2 # system + user
async def test_llm_generate_with_client(self):
"""有 LLM 客户端时调用 LLM 并解析结果"""
class MockLLMClient:
async def chat(self, messages, **kwargs):
return json.dumps({"title": "Test", "content": "Hello world"})
config = AgentConfig.from_dict(_sample_llm_config())
agent = ConfigDrivenAgent(config=config, llm_client=MockLLMClient())
task = _make_task()
result = await agent.handle_task(task)
assert result["title"] == "Test"
assert result["content"] == "Hello world"
async def test_llm_generate_with_markdown_json(self):
"""LLM 返回 markdown 代码块包裹的 JSON"""
class MockLLMClient:
async def chat(self, messages, **kwargs):
return '```json\n{"title": "Wrapped", "content": "In markdown"}\n```'
config = AgentConfig.from_dict(_sample_llm_config())
agent = ConfigDrivenAgent(config=config, llm_client=MockLLMClient())
task = _make_task()
result = await agent.handle_task(task)
assert result["title"] == "Wrapped"
async def test_llm_generate_fallback_text(self):
"""LLM 返回非 JSON 时降级为文本"""
class MockLLMClient:
async def chat(self, messages, **kwargs):
return "This is plain text response"
config = AgentConfig.from_dict(_sample_llm_config())
agent = ConfigDrivenAgent(config=config, llm_client=MockLLMClient())
task = _make_task()
result = await agent.handle_task(task)
assert result["text"] == "This is plain text response"
async def test_tool_call_mode(self):
"""tool_call 模式调用注册的 Tool"""
registry = ToolRegistry()
async def check_citation(url: str, **kwargs) -> dict:
return {"found": True, "url": url}
tool = FunctionTool(
name="check_citation",
description="Check citation",
func=check_citation,
)
registry.register(tool)
config = AgentConfig.from_dict(_sample_tool_call_config())
agent = ConfigDrivenAgent(config=config, tool_registry=registry)
task = _make_task(input_data={"url": "https://example.com"})
result = await agent.handle_task(task)
assert result["found"] is True
assert result["url"] == "https://example.com"
async def test_custom_mode(self):
"""custom 模式调用自定义 handler"""
config = AgentConfig.from_dict(_sample_custom_config())
async def my_handler(task):
return {"status": "monitored", "task_id": task.task_id}
agent = ConfigDrivenAgent(
config=config,
custom_handlers={"my_handler": my_handler},
)
task = _make_task()
result = await agent.handle_task(task)
assert result["status"] == "monitored"
assert result["task_id"] == "test-task-001"
async def test_execute_wraps_task_result(self):
"""execute() 自动包装 handle_task 结果为 TaskResult"""
config = AgentConfig.from_dict(_sample_llm_config())
agent = ConfigDrivenAgent(config=config)
task = _make_task()
result = await agent.execute(task)
assert result.status == TaskStatus.COMPLETED
assert result.output_data is not None
assert result.metrics["elapsed_seconds"] >= 0
def test_get_capabilities(self):
"""能力声明从配置正确构建"""
config = AgentConfig.from_dict(_sample_llm_config())
agent = ConfigDrivenAgent(config=config)
cap = agent.get_capabilities()
assert cap.agent_name == "content_generator"
assert cap.agent_type == "content_generation"
assert cap.max_concurrency == 2
assert "content_generation" in cap.supported_tasks
def test_prompt_template_rendering(self):
"""Prompt 模板正确渲染"""
config = AgentConfig.from_dict(_sample_llm_config())
agent = ConfigDrivenAgent(config=config)
assert agent.prompt_template is not None
messages = agent.prompt_template.render(variables={"query": "test"})
assert len(messages) == 2
assert "专业的内容生成助手" in messages[0]["content"]
async def test_callable_llm_client(self):
"""LLM 客户端为可调用对象"""
async def mock_llm(messages, **kwargs):
return '{"result": "from_callable"}'
config = AgentConfig.from_dict(_sample_llm_config())
agent = ConfigDrivenAgent(config=config, llm_client=mock_llm)
task = _make_task()
result = await agent.handle_task(task)
assert result["result"] == "from_callable"
# ── StandaloneRunner 测试 ─────────────────────────────────
class TestStandaloneRunner:
def test_discover_configs(self):
"""自动发现 YAML 配置"""
with tempfile.TemporaryDirectory() as tmpdir:
for name in ["agent_a.yaml", "agent_b.yml"]:
config = {
"name": name.replace(".", "_"),
"agent_type": "test",
"task_mode": "llm_generate",
"prompt": {"identity": "test", "instructions": "test"},
}
with open(Path(tmpdir) / name, "w") as f:
yaml.dump(config, f)
runner = StandaloneRunner(config_dir=tmpdir)
configs = runner.discover_configs()
assert len(configs) == 2
def test_build_agents(self):
"""从配置构建 Agent 实例"""
with tempfile.TemporaryDirectory() as tmpdir:
config = _sample_llm_config()
with open(Path(tmpdir) / "test.yaml", "w") as f:
yaml.dump(config, f)
runner = StandaloneRunner(config_dir=tmpdir)
agents = runner.build_agents()
assert "content_generator" in agents
assert agents["content_generator"].config.task_mode == "llm_generate"
def test_add_tool(self):
"""添加工具到注册中心"""
runner = StandaloneRunner()
async def my_tool(x: int) -> dict:
return {"doubled": x * 2}
tool = FunctionTool(name="my_tool", description="Test tool", func=my_tool)
runner.add_tool(tool)
assert runner._tool_registry.has_tool("my_tool")
async def test_execute_task_local(self):
"""本地模式执行任务"""
with tempfile.TemporaryDirectory() as tmpdir:
config = _sample_llm_config()
with open(Path(tmpdir) / "test.yaml", "w") as f:
yaml.dump(config, f)
runner = StandaloneRunner(config_dir=tmpdir)
runner.build_agents()
task = _make_task(agent_name="content_generator")
result = await runner.execute_task("content_generator", task)
assert result is not None
assert result["status"] == "completed"
def test_empty_config_dir(self):
"""空配置目录不报错"""
with tempfile.TemporaryDirectory() as tmpdir:
runner = StandaloneRunner(config_dir=tmpdir)
configs = runner.discover_configs()
assert len(configs) == 0
def test_nonexistent_config_dir(self):
"""不存在的配置目录不报错"""
runner = StandaloneRunner(config_dir="/nonexistent/path")
configs = runner.discover_configs()
assert len(configs) == 0