fischer-agentkit/tests/unit/test_config_driven.py

525 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""U2 测试: ConfigDrivenAgent + YAML 配置驱动"""
import json
import tempfile
from pathlib import Path
import pytest
import yaml
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
from agentkit.core.protocol import TaskMessage, TaskStatus
from agentkit.core.standalone import StandaloneRunner
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.registry import ToolRegistry
# ── Fixtures ──────────────────────────────────────────────
def _make_task(**overrides) -> TaskMessage:
    """Build a ``TaskMessage`` from baseline test fields.

    Every keyword argument replaces the baseline value of the same name (or
    adds a new field) before the payload is passed to
    ``TaskMessage.from_dict``. A fresh payload is built on every call.
    """
    payload = {
        "task_id": "test-task-001",
        "agent_name": "test_agent",
        "task_type": "generate",
        "priority": 1,
        "input_data": {"query": "hello"},
        "callback_url": None,
        "created_at": None,
        **overrides,
    }
    return TaskMessage.from_dict(payload)
def _sample_llm_config() -> dict:
return {
"name": "content_generator",
"agent_type": "content_generation",
"version": "1.0.0",
"description": "内容生成 Agent",
"task_mode": "llm_generate",
"supported_tasks": ["content_generation"],
"max_concurrency": 2,
"prompt": {
"identity": "你是一个专业的内容生成助手",
"instructions": "根据用户需求生成高质量内容",
"output_format": "以 JSON 格式输出 {title, content}",
},
"llm": {
"model": "gpt-4",
"temperature": 0.7,
},
}
def _sample_tool_call_config() -> dict:
return {
"name": "citation_detector",
"agent_type": "citation_detection",
"task_mode": "tool_call",
"description": "引用检测 Agent",
"tools": ["check_citation"],
}
def _sample_custom_config() -> dict:
return {
"name": "monitor",
"agent_type": "monitoring",
"task_mode": "custom",
"description": "监控 Agent",
"custom_handler": "my_handler",
}
# ── AgentConfig 测试 ──────────────────────────────────────
class TestAgentConfig:
    """Parsing, validation and serialization of ``AgentConfig``."""

    def test_from_dict_llm_generate(self):
        config = AgentConfig.from_dict(_sample_llm_config())
        assert config.name == "content_generator"
        assert config.task_mode == "llm_generate"
        assert config.prompt["identity"] == "你是一个专业的内容生成助手"
        assert config.llm["model"] == "gpt-4"

    def test_from_dict_tool_call(self):
        config = AgentConfig.from_dict(_sample_tool_call_config())
        assert config.task_mode == "tool_call"
        assert config.tools == ["check_citation"]

    def test_from_dict_custom(self):
        config = AgentConfig.from_dict(_sample_custom_config())
        assert config.task_mode == "custom"
        assert config.custom_handler == "my_handler"

    # Validation failures are matched on message text; the concrete
    # exception class is not imported in this module.
    def test_invalid_task_mode(self):
        with pytest.raises(Exception, match="Invalid task_mode"):
            AgentConfig(name="x", agent_type="x", task_mode="invalid_mode")

    def test_llm_generate_requires_prompt(self):
        with pytest.raises(Exception, match="llm_generate mode requires"):
            AgentConfig(name="x", agent_type="x", task_mode="llm_generate", prompt=None)

    def test_tool_call_requires_tools(self):
        with pytest.raises(Exception, match="tool_call mode requires"):
            AgentConfig(name="x", agent_type="x", task_mode="tool_call", tools=[])

    def test_custom_requires_handler(self):
        with pytest.raises(Exception, match="custom mode requires"):
            AgentConfig(name="x", agent_type="x", task_mode="custom", custom_handler=None)

    def test_from_yaml(self):
        """A config written to a YAML file loads back with the same fields.

        The file is created inside a ``TemporaryDirectory`` so it is removed
        when the test finishes, and it is fully written and closed before
        ``from_yaml`` opens it.
        """
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "agent.yaml"
            with open(path, "w", encoding="utf-8") as f:
                yaml.dump(_sample_llm_config(), f)
            config = AgentConfig.from_yaml(str(path))
        assert config.name == "content_generator"
        assert config.task_mode == "llm_generate"

    def test_to_dict_roundtrip(self):
        original = _sample_llm_config()
        config = AgentConfig.from_dict(original)
        result = config.to_dict()
        assert result["name"] == original["name"]
        assert result["task_mode"] == original["task_mode"]
        assert result["prompt"] == original["prompt"]
# ── ConfigDrivenAgent 测试 ────────────────────────────────
class TestConfigDrivenAgent:
    """Task handling by ``ConfigDrivenAgent`` across its task modes."""

    @staticmethod
    def _agent_replying(reply: str) -> ConfigDrivenAgent:
        """Build an llm_generate agent whose mock client's ``chat`` always returns *reply*."""

        class MockLLMClient:
            async def chat(self, messages, **kwargs):
                return reply

        return ConfigDrivenAgent(
            config=AgentConfig.from_dict(_sample_llm_config()),
            llm_client=MockLLMClient(),
        )

    async def test_llm_generate_no_client(self):
        """Without an LLM client the agent falls back to returning the rendered prompt."""
        agent = ConfigDrivenAgent(config=AgentConfig.from_dict(_sample_llm_config()))
        outcome = await agent.handle_task(_make_task())
        assert outcome["mode"] == "llm_generate_no_client"
        assert len(outcome["messages"]) == 2  # system + user

    async def test_llm_generate_with_client(self):
        """A JSON reply from the LLM client is parsed into the result dict."""
        agent = self._agent_replying(
            json.dumps({"title": "Test", "content": "Hello world"})
        )
        outcome = await agent.handle_task(_make_task())
        assert outcome["title"] == "Test"
        assert outcome["content"] == "Hello world"

    async def test_llm_generate_with_markdown_json(self):
        """JSON wrapped in a markdown code fence is still parsed."""
        agent = self._agent_replying(
            '```json\n{"title": "Wrapped", "content": "In markdown"}\n```'
        )
        outcome = await agent.handle_task(_make_task())
        assert outcome["title"] == "Wrapped"

    async def test_llm_generate_fallback_text(self):
        """A non-JSON reply is surfaced under the ``text`` key."""
        agent = self._agent_replying("This is plain text response")
        outcome = await agent.handle_task(_make_task())
        assert outcome["text"] == "This is plain text response"

    async def test_tool_call_mode(self):
        """tool_call mode passes the task input to the registered tool."""

        async def check_citation(url: str, **kwargs) -> dict:
            return {"found": True, "url": url}

        registry = ToolRegistry()
        registry.register(
            FunctionTool(
                name="check_citation",
                description="Check citation",
                func=check_citation,
            )
        )
        agent = ConfigDrivenAgent(
            config=AgentConfig.from_dict(_sample_tool_call_config()),
            tool_registry=registry,
        )
        outcome = await agent.handle_task(
            _make_task(input_data={"url": "https://example.com"})
        )
        assert outcome["found"] is True
        assert outcome["url"] == "https://example.com"

    async def test_custom_mode(self):
        """custom mode delegates the task to the handler named in the config."""
        cfg = AgentConfig.from_dict(_sample_custom_config())

        async def my_handler(task):
            return {"status": "monitored", "task_id": task.task_id}

        agent = ConfigDrivenAgent(
            config=cfg,
            custom_handlers={"my_handler": my_handler},
        )
        outcome = await agent.handle_task(_make_task())
        assert outcome["status"] == "monitored"
        assert outcome["task_id"] == "test-task-001"

    async def test_execute_wraps_task_result(self):
        """execute() wraps the handle_task output in a TaskResult."""
        agent = ConfigDrivenAgent(config=AgentConfig.from_dict(_sample_llm_config()))
        task_result = await agent.execute(_make_task())
        assert task_result.status == TaskStatus.COMPLETED
        assert task_result.output_data is not None
        assert task_result.metrics["elapsed_seconds"] >= 0

    def test_get_capabilities(self):
        """The capability declaration mirrors the config fields."""
        agent = ConfigDrivenAgent(config=AgentConfig.from_dict(_sample_llm_config()))
        cap = agent.get_capabilities()
        assert cap.agent_name == "content_generator"
        assert cap.agent_type == "content_generation"
        assert cap.max_concurrency == 2
        assert "content_generation" in cap.supported_tasks

    def test_prompt_template_rendering(self):
        """The prompt template renders a system message and a user message."""
        agent = ConfigDrivenAgent(config=AgentConfig.from_dict(_sample_llm_config()))
        assert agent.prompt_template is not None
        rendered = agent.prompt_template.render(variables={"query": "test"})
        assert len(rendered) == 2
        assert "专业的内容生成助手" in rendered[0]["content"]

    async def test_callable_llm_client(self):
        """A bare async callable is accepted as the LLM client."""

        async def mock_llm(messages, **kwargs):
            return '{"result": "from_callable"}'

        agent = ConfigDrivenAgent(
            config=AgentConfig.from_dict(_sample_llm_config()),
            llm_client=mock_llm,
        )
        outcome = await agent.handle_task(_make_task())
        assert outcome["result"] == "from_callable"
# ── StandaloneRunner 测试 ─────────────────────────────────
class TestStandaloneRunner:
    """Config discovery, agent construction and local execution via ``StandaloneRunner``."""

    @staticmethod
    def _write_yaml(directory: str, filename: str, config: dict) -> None:
        """Dump *config* as YAML into ``directory/filename``."""
        with open(Path(directory) / filename, "w", encoding="utf-8") as f:
            yaml.dump(config, f)

    def test_discover_configs(self):
        """Both ``.yaml`` and ``.yml`` files in the config dir are discovered."""
        with tempfile.TemporaryDirectory() as tmpdir:
            for name in ["agent_a.yaml", "agent_b.yml"]:
                self._write_yaml(
                    tmpdir,
                    name,
                    {
                        "name": name.replace(".", "_"),
                        "agent_type": "test",
                        "task_mode": "llm_generate",
                        "prompt": {"identity": "test", "instructions": "test"},
                    },
                )
            runner = StandaloneRunner(config_dir=tmpdir)
            configs = runner.discover_configs()
            assert len(configs) == 2

    def test_build_agents(self):
        """Agents are built from discovered configs, keyed by config name."""
        with tempfile.TemporaryDirectory() as tmpdir:
            self._write_yaml(tmpdir, "test.yaml", _sample_llm_config())
            runner = StandaloneRunner(config_dir=tmpdir)
            agents = runner.build_agents()
            assert "content_generator" in agents
            assert agents["content_generator"].config.task_mode == "llm_generate"

    def test_add_tool(self):
        """A tool added to the runner is visible in its tool registry."""
        runner = StandaloneRunner()

        async def my_tool(x: int) -> dict:
            return {"doubled": x * 2}

        tool = FunctionTool(name="my_tool", description="Test tool", func=my_tool)
        runner.add_tool(tool)
        assert runner._tool_registry.has_tool("my_tool")

    async def test_execute_task_local(self):
        """A task executed locally against a built agent completes."""
        with tempfile.TemporaryDirectory() as tmpdir:
            self._write_yaml(tmpdir, "test.yaml", _sample_llm_config())
            runner = StandaloneRunner(config_dir=tmpdir)
            runner.build_agents()
            task = _make_task(agent_name="content_generator")
            result = await runner.execute_task("content_generator", task)
            assert result is not None
            assert result["status"] == "completed"

    def test_empty_config_dir(self):
        """An empty config directory yields no configs and no error."""
        with tempfile.TemporaryDirectory() as tmpdir:
            runner = StandaloneRunner(config_dir=tmpdir)
            configs = runner.discover_configs()
            assert len(configs) == 0

    def test_nonexistent_config_dir(self):
        """A config directory that does not exist yields no configs and no error."""
        # Derive a path that is guaranteed not to exist, instead of relying on
        # a hard-coded absolute path being absent on the host.
        with tempfile.TemporaryDirectory() as tmpdir:
            missing = Path(tmpdir) / "does-not-exist"
            runner = StandaloneRunner(config_dir=str(missing))
            configs = runner.discover_configs()
            assert len(configs) == 0
# ── ConfigDrivenAgent public accessor tests ──────────────
class TestConfigDrivenAgentPublicAccessors:
    """U8: Test public accessor methods on ConfigDrivenAgent."""

    def test_get_tools_returns_bound_tools(self):
        """get_tools() returns the tools bound to the agent."""

        async def check_citation(url: str, **kwargs) -> dict:
            return {"found": True, "url": url}

        tool = FunctionTool(
            name="check_citation", description="Check citation", func=check_citation
        )
        registry = ToolRegistry()
        registry.register(tool)
        config = AgentConfig.from_dict(_sample_tool_call_config())
        agent = ConfigDrivenAgent(config=config, tool_registry=registry)
        tools = agent.get_tools()
        assert len(tools) >= 1
        assert any(t.name == "check_citation" for t in tools)

    def test_get_tools_empty_when_no_tools(self):
        """get_tools() returns an empty list when no tools are bound."""
        config = AgentConfig.from_dict(_sample_llm_config())
        agent = ConfigDrivenAgent(config=config)
        assert agent.get_tools() == []

    def test_get_model_returns_configured_model(self):
        """get_model() returns the model from config.llm."""
        config = AgentConfig.from_dict(_sample_llm_config())
        agent = ConfigDrivenAgent(config=config)
        assert agent.get_model() == "gpt-4"

    def test_get_model_default_when_no_llm_config(self):
        """get_model() returns 'default' when the config has no llm section."""
        config = AgentConfig(
            name="test",
            agent_type="test",
            task_mode="llm_generate",
            prompt={"identity": "Test"},
        )
        agent = ConfigDrivenAgent(config=config)
        assert agent.get_model() == "default"

    def test_get_system_prompt_returns_prompt_sections(self):
        """get_system_prompt() includes the configured prompt sections."""
        config = AgentConfig.from_dict(_sample_llm_config())
        agent = ConfigDrivenAgent(config=config)
        prompt = agent.get_system_prompt()
        assert prompt is not None
        assert "专业的内容生成助手" in prompt
        assert "根据用户需求生成高质量内容" in prompt

    def test_get_system_prompt_none_when_no_prompt(self):
        """get_system_prompt() returns None when no prompt is configured."""
        config = AgentConfig(
            name="test",
            agent_type="test",
            task_mode="tool_call",
            tools=["some_tool"],
        )
        agent = ConfigDrivenAgent(config=config)
        assert agent.get_system_prompt() is None

    def test_get_react_config_default_values(self):
        """get_react_config() returns defaults when the config is not a SkillConfig."""
        config = AgentConfig.from_dict(_sample_llm_config())
        agent = ConfigDrivenAgent(config=config)
        react_config = agent.get_react_config()
        assert react_config["max_steps"] == 10
        assert react_config["timeout_seconds"] is None

    def test_get_react_config_with_skill_config(self):
        """get_react_config() reads max_steps from a SkillConfig."""
        from agentkit.skills.base import SkillConfig

        skill_config = SkillConfig(
            name="test_skill",
            agent_type="test",
            task_mode="llm_generate",
            prompt={"identity": "Test"},
            intent={"keywords": ["test"], "description": "Test"},
            max_steps=20,
        )
        agent = ConfigDrivenAgent(config=skill_config)
        react_config = agent.get_react_config()
        assert react_config["max_steps"] == 20
        assert react_config["timeout_seconds"] is None
# ── Handler prefix whitelist tests ───────────────────────
class TestHandlerPrefixWhitelist:
    """U4: ``_import_handler`` enforces a module-prefix whitelist.

    This guards against a config's ``custom_handler`` naming an arbitrary
    callable (e.g. ``os.system``) and thereby executing arbitrary code.
    """

    # Substring of the error message raised on a prefix rejection.
    _REJECTED = "not in allowed module prefixes"

    def _make_agent_with_custom(self, handler_path: str) -> ConfigDrivenAgent:
        """Build a custom-mode agent whose config names *handler_path*."""
        config = AgentConfig(
            name="test_agent",
            agent_type="test",
            task_mode="custom",
            custom_handler=handler_path,
        )
        return ConfigDrivenAgent(config=config)

    def _assert_prefix_allowed(self, handler_path: str) -> None:
        """Assert that *handler_path* is not rejected by the prefix check.

        The target module need not exist. Any other failure (e.g. ImportError
        or AttributeError) is acceptable; only a prefix rejection fails.
        """
        agent = self._make_agent_with_custom(handler_path)
        try:
            agent._import_handler(handler_path)
        except Exception as exc:
            assert self._REJECTED not in str(exc)

    def _assert_prefix_blocked(self, handler_path: str) -> None:
        """Assert that importing *handler_path* is rejected by the prefix check."""
        agent = self._make_agent_with_custom(handler_path)
        with pytest.raises(Exception, match=self._REJECTED):
            agent._import_handler(handler_path)

    def test_allowed_prefix_agentkit(self):
        """agentkit.xxx.handler passes the prefix check."""
        self._assert_prefix_allowed("agentkit.handlers.test_handler")

    def test_allowed_prefix_app_agent_framework(self):
        """app.agent_framework.handlers.xxx passes the prefix check."""
        self._assert_prefix_allowed("app.agent_framework.handlers.xxx_handler")

    def test_blocked_os_system(self):
        """os.system is blocked (ConfigValidationError)."""
        self._assert_prefix_blocked("os.system")

    def test_blocked_subprocess_run(self):
        """subprocess.run is blocked."""
        self._assert_prefix_blocked("subprocess.run")

    def test_blocked_builtins_exec(self):
        """builtins.exec is blocked."""
        self._assert_prefix_blocked("builtins.exec")

    def test_blocked_empty_string(self):
        """An empty path is rejected by the prefix check in ``_import_handler``.

        The agent is built with a placeholder handler path, and
        ``_import_handler`` is then called directly with ``""``.
        """
        agent = self._make_agent_with_custom("agentkit.dummy")
        with pytest.raises(Exception, match=self._REJECTED):
            agent._import_handler("")

    def test_blocked_agentkitx_prefix(self):
        """agentkitx.* is blocked (it is not under the ``agentkit.`` prefix)."""
        self._assert_prefix_blocked("agentkitx.handlers.evil")