525 lines
19 KiB
Python
525 lines
19 KiB
Python
"""U2 测试: ConfigDrivenAgent + YAML 配置驱动"""
|
||
|
||
import json
|
||
import tempfile
|
||
from pathlib import Path
|
||
|
||
import pytest
|
||
import yaml
|
||
|
||
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
|
||
from agentkit.core.protocol import TaskMessage, TaskStatus
|
||
from agentkit.core.standalone import StandaloneRunner
|
||
from agentkit.tools.function_tool import FunctionTool
|
||
from agentkit.tools.registry import ToolRegistry
|
||
|
||
|
||
# ── Fixtures ──────────────────────────────────────────────
|
||
|
||
|
||
def _make_task(**overrides) -> TaskMessage:
    """Build a ``TaskMessage`` from fixed test defaults.

    Keyword arguments override (or extend) the default fields before the
    merged payload is passed to ``TaskMessage.from_dict``.
    """
    base = {
        "task_id": "test-task-001",
        "agent_name": "test_agent",
        "task_type": "generate",
        "priority": 1,
        "input_data": {"query": "hello"},
        "callback_url": None,
        "created_at": None,
    }
    # Later keys win, so caller-supplied overrides replace the defaults.
    return TaskMessage.from_dict({**base, **overrides})
|
||
|
||
|
||
def _sample_llm_config() -> dict:
    """Return a fresh raw config dict for an ``llm_generate`` agent.

    A new dict (including new nested ``prompt``/``llm`` dicts) is built on
    every call, so tests may mutate the result freely.
    """
    prompt_section = dict(
        identity="你是一个专业的内容生成助手",
        instructions="根据用户需求生成高质量内容",
        output_format="以 JSON 格式输出 {title, content}",
    )
    llm_section = dict(model="gpt-4", temperature=0.7)
    return dict(
        name="content_generator",
        agent_type="content_generation",
        version="1.0.0",
        description="内容生成 Agent",
        task_mode="llm_generate",
        supported_tasks=["content_generation"],
        max_concurrency=2,
        prompt=prompt_section,
        llm=llm_section,
    )
|
||
|
||
|
||
def _sample_tool_call_config() -> dict:
    """Return a fresh raw config dict for a ``tool_call`` agent bound to one tool."""
    return dict(
        name="citation_detector",
        agent_type="citation_detection",
        task_mode="tool_call",
        description="引用检测 Agent",
        tools=["check_citation"],
    )
|
||
|
||
|
||
def _sample_custom_config() -> dict:
    """Return a fresh raw config dict for a ``custom`` agent using ``my_handler``."""
    return dict(
        name="monitor",
        agent_type="monitoring",
        task_mode="custom",
        description="监控 Agent",
        custom_handler="my_handler",
    )
|
||
|
||
|
||
# ── AgentConfig 测试 ──────────────────────────────────────
|
||
|
||
|
||
class TestAgentConfig:
    """Parsing, validation and serialization of ``AgentConfig``."""

    def test_from_dict_llm_generate(self):
        """An llm_generate dict populates name, mode, prompt and llm settings."""
        config = AgentConfig.from_dict(_sample_llm_config())
        assert config.name == "content_generator"
        assert config.task_mode == "llm_generate"
        assert config.prompt["identity"] == "你是一个专业的内容生成助手"
        assert config.llm["model"] == "gpt-4"

    def test_from_dict_tool_call(self):
        """A tool_call dict keeps its declared tool list."""
        config = AgentConfig.from_dict(_sample_tool_call_config())
        assert config.task_mode == "tool_call"
        assert config.tools == ["check_citation"]

    def test_from_dict_custom(self):
        """A custom dict keeps its handler name."""
        config = AgentConfig.from_dict(_sample_custom_config())
        assert config.task_mode == "custom"
        assert config.custom_handler == "my_handler"

    def test_invalid_task_mode(self):
        """An unknown task_mode is rejected at construction time."""
        with pytest.raises(Exception, match="Invalid task_mode"):
            AgentConfig(name="x", agent_type="x", task_mode="invalid_mode")

    def test_llm_generate_requires_prompt(self):
        """llm_generate without a prompt is rejected."""
        with pytest.raises(Exception, match="llm_generate mode requires"):
            AgentConfig(name="x", agent_type="x", task_mode="llm_generate", prompt=None)

    def test_tool_call_requires_tools(self):
        """tool_call with an empty tool list is rejected."""
        with pytest.raises(Exception, match="tool_call mode requires"):
            AgentConfig(name="x", agent_type="x", task_mode="tool_call", tools=[])

    def test_custom_requires_handler(self):
        """custom without a handler is rejected."""
        with pytest.raises(Exception, match="custom mode requires"):
            AgentConfig(name="x", agent_type="x", task_mode="custom", custom_handler=None)

    def test_from_yaml(self):
        """``from_yaml`` loads a YAML file written from the sample config.

        The file is created inside a TemporaryDirectory so it is always
        cleaned up (the previous ``NamedTemporaryFile(delete=False)`` leaked a
        file per run). The writer is closed before ``from_yaml`` re-opens the
        path, because re-opening a still-open named temporary file is not
        possible on every platform (e.g. Windows).
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            yaml_path = Path(tmpdir) / "agent.yaml"
            with open(yaml_path, "w", encoding="utf-8") as f:
                yaml.dump(_sample_llm_config(), f)
            config = AgentConfig.from_yaml(str(yaml_path))

        assert config.name == "content_generator"
        assert config.task_mode == "llm_generate"

    def test_to_dict_roundtrip(self):
        """``to_dict`` reproduces the key fields of the dict it was built from."""
        original = _sample_llm_config()
        config = AgentConfig.from_dict(original)
        result = config.to_dict()
        assert result["name"] == original["name"]
        assert result["task_mode"] == original["task_mode"]
        assert result["prompt"] == original["prompt"]
|
||
|
||
|
||
# ── ConfigDrivenAgent 测试 ────────────────────────────────
|
||
|
||
|
||
class TestConfigDrivenAgent:
    """Behaviour of ``ConfigDrivenAgent`` in its llm_generate, tool_call and custom modes.

    NOTE(review): the ``async def`` tests carry no explicit asyncio marker;
    presumably the suite runs pytest-asyncio in auto mode — confirm.
    """

    @staticmethod
    def _client_returning(reply):
        """Return a stub LLM client whose async ``chat`` always yields *reply*."""

        class _StubLLMClient:
            async def chat(self, messages, **kwargs):
                return reply

        return _StubLLMClient()

    async def test_llm_generate_no_client(self):
        """Without an LLM client the agent degrades to returning the rendered prompt."""
        agent = ConfigDrivenAgent(config=AgentConfig.from_dict(_sample_llm_config()))

        outcome = await agent.handle_task(_make_task())

        assert outcome["mode"] == "llm_generate_no_client"
        assert len(outcome["messages"]) == 2  # system + user

    async def test_llm_generate_with_client(self):
        """A JSON reply from the LLM client is parsed into the result dict."""
        reply = json.dumps({"title": "Test", "content": "Hello world"})
        agent = ConfigDrivenAgent(
            config=AgentConfig.from_dict(_sample_llm_config()),
            llm_client=self._client_returning(reply),
        )

        outcome = await agent.handle_task(_make_task())

        assert outcome["title"] == "Test"
        assert outcome["content"] == "Hello world"

    async def test_llm_generate_with_markdown_json(self):
        """JSON wrapped in a markdown code fence is still parsed."""
        fenced = '```json\n{"title": "Wrapped", "content": "In markdown"}\n```'
        agent = ConfigDrivenAgent(
            config=AgentConfig.from_dict(_sample_llm_config()),
            llm_client=self._client_returning(fenced),
        )

        outcome = await agent.handle_task(_make_task())

        assert outcome["title"] == "Wrapped"

    async def test_llm_generate_fallback_text(self):
        """A non-JSON reply is returned under the ``text`` key."""
        plain = "This is plain text response"
        agent = ConfigDrivenAgent(
            config=AgentConfig.from_dict(_sample_llm_config()),
            llm_client=self._client_returning(plain),
        )

        outcome = await agent.handle_task(_make_task())

        assert outcome["text"] == "This is plain text response"

    async def test_tool_call_mode(self):
        """tool_call mode invokes the registered tool with the task input."""

        async def check_citation(url: str, **kwargs) -> dict:
            return {"found": True, "url": url}

        registry = ToolRegistry()
        registry.register(
            FunctionTool(
                name="check_citation",
                description="Check citation",
                func=check_citation,
            )
        )
        agent = ConfigDrivenAgent(
            config=AgentConfig.from_dict(_sample_tool_call_config()),
            tool_registry=registry,
        )

        outcome = await agent.handle_task(_make_task(input_data={"url": "https://example.com"}))

        assert outcome["found"] is True
        assert outcome["url"] == "https://example.com"

    async def test_custom_mode(self):
        """custom mode dispatches to the handler registered under the configured name."""
        config = AgentConfig.from_dict(_sample_custom_config())

        async def my_handler(task):
            return {"status": "monitored", "task_id": task.task_id}

        agent = ConfigDrivenAgent(config=config, custom_handlers={"my_handler": my_handler})

        outcome = await agent.handle_task(_make_task())

        assert outcome["status"] == "monitored"
        assert outcome["task_id"] == "test-task-001"

    async def test_execute_wraps_task_result(self):
        """``execute()`` wraps the ``handle_task`` output into a completed TaskResult."""
        agent = ConfigDrivenAgent(config=AgentConfig.from_dict(_sample_llm_config()))

        wrapped = await agent.execute(_make_task())

        assert wrapped.status == TaskStatus.COMPLETED
        assert wrapped.output_data is not None
        assert wrapped.metrics["elapsed_seconds"] >= 0

    def test_get_capabilities(self):
        """The capability declaration mirrors the config fields."""
        agent = ConfigDrivenAgent(config=AgentConfig.from_dict(_sample_llm_config()))

        capabilities = agent.get_capabilities()

        assert capabilities.agent_name == "content_generator"
        assert capabilities.agent_type == "content_generation"
        assert capabilities.max_concurrency == 2
        assert "content_generation" in capabilities.supported_tasks

    def test_prompt_template_rendering(self):
        """The prompt template renders a system + user message pair."""
        agent = ConfigDrivenAgent(config=AgentConfig.from_dict(_sample_llm_config()))
        template = agent.prompt_template

        assert template is not None
        rendered = template.render(variables={"query": "test"})
        assert len(rendered) == 2
        assert "专业的内容生成助手" in rendered[0]["content"]

    async def test_callable_llm_client(self):
        """A bare async callable is accepted as the LLM client."""

        async def mock_llm(messages, **kwargs):
            return '{"result": "from_callable"}'

        agent = ConfigDrivenAgent(
            config=AgentConfig.from_dict(_sample_llm_config()),
            llm_client=mock_llm,
        )

        outcome = await agent.handle_task(_make_task())

        assert outcome["result"] == "from_callable"
|
||
|
||
|
||
# ── StandaloneRunner 测试 ─────────────────────────────────
|
||
|
||
|
||
class TestStandaloneRunner:
    """Config discovery, agent construction and local execution in ``StandaloneRunner``."""

    @staticmethod
    def _write_yaml(target: Path, data: dict) -> None:
        """Serialize *data* as YAML into the file at *target*."""
        with open(target, "w") as stream:
            yaml.dump(data, stream)

    def test_discover_configs(self):
        """Both ``.yaml`` and ``.yml`` files in the config dir are discovered."""
        with tempfile.TemporaryDirectory() as tmpdir:
            root = Path(tmpdir)
            for filename in ("agent_a.yaml", "agent_b.yml"):
                self._write_yaml(
                    root / filename,
                    {
                        "name": filename.replace(".", "_"),
                        "agent_type": "test",
                        "task_mode": "llm_generate",
                        "prompt": {"identity": "test", "instructions": "test"},
                    },
                )

            discovered = StandaloneRunner(config_dir=tmpdir).discover_configs()

            assert len(discovered) == 2

    def test_build_agents(self):
        """Agents are built from discovered configs and keyed by config name."""
        with tempfile.TemporaryDirectory() as tmpdir:
            self._write_yaml(Path(tmpdir) / "test.yaml", _sample_llm_config())

            built = StandaloneRunner(config_dir=tmpdir).build_agents()

            assert "content_generator" in built
            assert built["content_generator"].config.task_mode == "llm_generate"

    def test_add_tool(self):
        """``add_tool`` registers the tool in the runner's registry."""
        runner = StandaloneRunner()

        async def my_tool(x: int) -> dict:
            return {"doubled": x * 2}

        runner.add_tool(FunctionTool(name="my_tool", description="Test tool", func=my_tool))

        assert runner._tool_registry.has_tool("my_tool")

    async def test_execute_task_local(self):
        """A task executed locally against a built agent reports completion."""
        with tempfile.TemporaryDirectory() as tmpdir:
            self._write_yaml(Path(tmpdir) / "test.yaml", _sample_llm_config())
            runner = StandaloneRunner(config_dir=tmpdir)
            runner.build_agents()

            outcome = await runner.execute_task(
                "content_generator",
                _make_task(agent_name="content_generator"),
            )

            assert outcome is not None
            assert outcome["status"] == "completed"

    def test_empty_config_dir(self):
        """An empty config directory yields no configs and no error."""
        with tempfile.TemporaryDirectory() as tmpdir:
            discovered = StandaloneRunner(config_dir=tmpdir).discover_configs()
            assert len(discovered) == 0

    def test_nonexistent_config_dir(self):
        """A missing config directory yields no configs and no error."""
        discovered = StandaloneRunner(config_dir="/nonexistent/path").discover_configs()
        assert len(discovered) == 0
|
||
|
||
|
||
# ── ConfigDrivenAgent public accessor tests ───────────────
|
||
|
||
|
||
class TestConfigDrivenAgentPublicAccessors:
    """U8: Test public accessor methods on ConfigDrivenAgent"""

    def test_get_tools_returns_bound_tools(self):
        """get_tools() returns list of tools bound to the agent"""
        # FunctionTool is already imported at module scope; no local re-import needed.

        async def check_citation(url: str, **kwargs) -> dict:
            return {"found": True, "url": url}

        tool = FunctionTool(name="check_citation", description="Check citation", func=check_citation)
        registry = ToolRegistry()
        registry.register(tool)

        config = AgentConfig.from_dict(_sample_tool_call_config())
        agent = ConfigDrivenAgent(config=config, tool_registry=registry)

        tools = agent.get_tools()
        assert len(tools) >= 1
        assert any(t.name == "check_citation" for t in tools)

    def test_get_tools_empty_when_no_tools(self):
        """get_tools() returns empty list when no tools bound"""
        config = AgentConfig.from_dict(_sample_llm_config())
        agent = ConfigDrivenAgent(config=config)

        tools = agent.get_tools()
        assert tools == []

    def test_get_model_returns_configured_model(self):
        """get_model() returns the model from config.llm"""
        config = AgentConfig.from_dict(_sample_llm_config())
        agent = ConfigDrivenAgent(config=config)

        assert agent.get_model() == "gpt-4"

    def test_get_model_default_when_no_llm_config(self):
        """get_model() returns 'default' when no llm config"""
        config = AgentConfig(
            name="test",
            agent_type="test",
            task_mode="llm_generate",
            prompt={"identity": "Test"},
        )
        agent = ConfigDrivenAgent(config=config)

        assert agent.get_model() == "default"

    def test_get_system_prompt_returns_prompt_sections(self):
        """get_system_prompt() returns combined prompt sections"""
        config = AgentConfig.from_dict(_sample_llm_config())
        agent = ConfigDrivenAgent(config=config)

        prompt = agent.get_system_prompt()
        assert prompt is not None
        assert "专业的内容生成助手" in prompt
        assert "根据用户需求生成高质量内容" in prompt

    def test_get_system_prompt_none_when_no_prompt(self):
        """get_system_prompt() returns None when no prompt configured"""
        config = AgentConfig(
            name="test",
            agent_type="test",
            task_mode="tool_call",
            tools=["some_tool"],
        )
        agent = ConfigDrivenAgent(config=config)

        assert agent.get_system_prompt() is None

    def test_get_react_config_default_values(self):
        """get_react_config() returns defaults when no SkillConfig"""
        config = AgentConfig.from_dict(_sample_llm_config())
        agent = ConfigDrivenAgent(config=config)

        react_config = agent.get_react_config()
        assert react_config["max_steps"] == 10
        assert react_config["timeout_seconds"] is None

    def test_get_react_config_with_skill_config(self):
        """get_react_config() returns values from SkillConfig"""
        # Imported locally so only this test depends on agentkit.skills.
        from agentkit.skills.base import SkillConfig

        skill_config = SkillConfig(
            name="test_skill",
            agent_type="test",
            task_mode="llm_generate",
            prompt={"identity": "Test"},
            intent={"keywords": ["test"], "description": "Test"},
            max_steps=20,
        )
        agent = ConfigDrivenAgent(config=skill_config)

        react_config = agent.get_react_config()
        assert react_config["max_steps"] == 20
        assert react_config["timeout_seconds"] is None
|
||
|
||
|
||
class TestHandlerPrefixWhitelist:
    """U4: ``_import_handler`` must enforce a module-prefix whitelist.

    The whitelist prevents a config file from importing arbitrary callables
    (e.g. ``os.system``) as custom handlers.
    """

    # Substring expected in the error raised when a prefix is rejected.
    _REJECTION = "not in allowed module prefixes"

    def _make_agent_with_custom(self, handler_path: str) -> ConfigDrivenAgent:
        """Build a custom-mode agent whose configured handler is *handler_path*."""
        return ConfigDrivenAgent(
            config=AgentConfig(
                name="test_agent",
                agent_type="test",
                task_mode="custom",
                custom_handler=handler_path,
            )
        )

    def _assert_passes_prefix_check(self, handler_path: str) -> None:
        """Import *handler_path*; any failure must not be the prefix rejection.

        The target module need not exist, so ImportError/AttributeError are
        tolerated — only the whitelist decision is under test.
        """
        agent = self._make_agent_with_custom(handler_path)
        try:
            agent._import_handler(handler_path)
        except Exception as exc:
            assert self._REJECTION not in str(exc)

    def _assert_blocked(self, agent: ConfigDrivenAgent, handler_path: str) -> None:
        """Importing *handler_path* through *agent* must hit the prefix rejection."""
        with pytest.raises(Exception, match=self._REJECTION):
            agent._import_handler(handler_path)

    def test_allowed_prefix_agentkit(self):
        """agentkit.* paths pass the prefix check."""
        self._assert_passes_prefix_check("agentkit.handlers.test_handler")

    def test_allowed_prefix_app_agent_framework(self):
        """app.agent_framework.handlers.* paths pass the prefix check."""
        self._assert_passes_prefix_check("app.agent_framework.handlers.xxx_handler")

    def test_blocked_os_system(self):
        """os.system is rejected."""
        self._assert_blocked(self._make_agent_with_custom("os.system"), "os.system")

    def test_blocked_subprocess_run(self):
        """subprocess.run is rejected."""
        self._assert_blocked(self._make_agent_with_custom("subprocess.run"), "subprocess.run")

    def test_blocked_builtins_exec(self):
        """builtins.exec is rejected."""
        self._assert_blocked(self._make_agent_with_custom("builtins.exec"), "builtins.exec")

    def test_blocked_empty_string(self):
        """An empty path is rejected directly by the prefix check in _import_handler.

        The agent itself is built with a valid handler path; only the direct
        ``_import_handler("")`` call is under test.
        """
        self._assert_blocked(self._make_agent_with_custom("agentkit.dummy"), "")

    def test_blocked_agentkitx_prefix(self):
        """agentkitx.* is rejected — it merely starts with the letters of agentkit."""
        self._assert_blocked(
            self._make_agent_with_custom("agentkitx.handlers.evil"),
            "agentkitx.handlers.evil",
        )