fischer-agentkit/tests/unit/test_config_driven.py

357 lines
12 KiB
Python

"""U2 测试: ConfigDrivenAgent + YAML 配置驱动"""
import json
import tempfile
from pathlib import Path
import pytest
import yaml
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
from agentkit.core.protocol import TaskMessage, TaskStatus
from agentkit.core.standalone import StandaloneRunner
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.registry import ToolRegistry
# ── Fixtures ──────────────────────────────────────────────
def _make_task(**overrides) -> TaskMessage:
    """Build a ``TaskMessage`` populated with test defaults.

    Any keyword argument replaces the default of the same name (or adds a new
    field), and the merged mapping is passed to ``TaskMessage.from_dict``.
    """
    payload = {
        "task_id": "test-task-001",
        "agent_name": "test_agent",
        "task_type": "generate",
        "priority": 1,
        "input_data": {"query": "hello"},
        "callback_url": None,
        "created_at": None,
        # Overrides come last so they win over the defaults above.
        **overrides,
    }
    return TaskMessage.from_dict(payload)
def _sample_llm_config() -> dict:
return {
"name": "content_generator",
"agent_type": "content_generation",
"version": "1.0.0",
"description": "内容生成 Agent",
"task_mode": "llm_generate",
"supported_tasks": ["content_generation"],
"max_concurrency": 2,
"prompt": {
"identity": "你是一个专业的内容生成助手",
"instructions": "根据用户需求生成高质量内容",
"output_format": "以 JSON 格式输出 {title, content}",
},
"llm": {
"model": "gpt-4",
"temperature": 0.7,
},
}
def _sample_tool_call_config() -> dict:
return {
"name": "citation_detector",
"agent_type": "citation_detection",
"task_mode": "tool_call",
"description": "引用检测 Agent",
"tools": ["check_citation"],
}
def _sample_custom_config() -> dict:
return {
"name": "monitor",
"agent_type": "monitoring",
"task_mode": "custom",
"description": "监控 Agent",
"custom_handler": "my_handler",
}
# ── AgentConfig 测试 ──────────────────────────────────────
class TestAgentConfig:
    """Tests for ``AgentConfig`` construction, validation and serialization."""

    def test_from_dict_llm_generate(self):
        """An llm_generate config exposes its prompt and LLM sections."""
        config = AgentConfig.from_dict(_sample_llm_config())
        assert config.name == "content_generator"
        assert config.task_mode == "llm_generate"
        assert config.prompt["identity"] == "你是一个专业的内容生成助手"
        assert config.llm["model"] == "gpt-4"

    def test_from_dict_tool_call(self):
        """A tool_call config keeps its tool list."""
        config = AgentConfig.from_dict(_sample_tool_call_config())
        assert config.task_mode == "tool_call"
        assert config.tools == ["check_citation"]

    def test_from_dict_custom(self):
        """A custom config keeps its handler name."""
        config = AgentConfig.from_dict(_sample_custom_config())
        assert config.task_mode == "custom"
        assert config.custom_handler == "my_handler"

    def test_invalid_task_mode(self):
        """An unknown task_mode is rejected at construction time."""
        # Only the message is pinned; the concrete exception class raised by
        # AgentConfig is not visible from this test module.
        with pytest.raises(Exception, match="Invalid task_mode"):
            AgentConfig(name="x", agent_type="x", task_mode="invalid_mode")

    def test_llm_generate_requires_prompt(self):
        """llm_generate mode without a prompt is rejected."""
        with pytest.raises(Exception, match="llm_generate mode requires"):
            AgentConfig(name="x", agent_type="x", task_mode="llm_generate", prompt=None)

    def test_tool_call_requires_tools(self):
        """tool_call mode with an empty tool list is rejected."""
        with pytest.raises(Exception, match="tool_call mode requires"):
            AgentConfig(name="x", agent_type="x", task_mode="tool_call", tools=[])

    def test_custom_requires_handler(self):
        """custom mode without a handler name is rejected."""
        with pytest.raises(Exception, match="custom mode requires"):
            AgentConfig(name="x", agent_type="x", task_mode="custom", custom_handler=None)

    def test_from_yaml(self):
        """A config dumped to a YAML file is loaded back by ``from_yaml``."""
        # A temporary directory is removed automatically on exit, unlike a
        # NamedTemporaryFile(delete=False), which would leave a stray file
        # behind on every run. The file is also fully written and closed
        # before from_yaml reopens it.
        with tempfile.TemporaryDirectory() as tmpdir:
            path = Path(tmpdir) / "content_generator.yaml"
            with open(path, "w", encoding="utf-8") as f:
                yaml.dump(_sample_llm_config(), f)
            config = AgentConfig.from_yaml(str(path))
        assert config.name == "content_generator"
        assert config.task_mode == "llm_generate"

    def test_to_dict_roundtrip(self):
        """``to_dict`` preserves the fields that ``from_dict`` consumed."""
        original = _sample_llm_config()
        config = AgentConfig.from_dict(original)
        result = config.to_dict()
        assert result["name"] == original["name"]
        assert result["task_mode"] == original["task_mode"]
        assert result["prompt"] == original["prompt"]
# ── ConfigDrivenAgent 测试 ────────────────────────────────
class TestConfigDrivenAgent:
    """Tests for ``ConfigDrivenAgent`` in llm_generate, tool_call and custom modes."""

    async def test_llm_generate_no_client(self):
        """Without an LLM client the agent degrades to returning the rendered prompt."""
        config = AgentConfig.from_dict(_sample_llm_config())
        agent = ConfigDrivenAgent(config=config)
        task = _make_task()
        result = await agent.handle_task(task)
        assert result["mode"] == "llm_generate_no_client"
        assert len(result["messages"]) == 2  # system + user

    async def test_llm_generate_with_client(self):
        """With an LLM client the agent calls it and parses the JSON reply."""
        class MockLLMClient:
            async def chat(self, messages, **kwargs):
                return json.dumps({"title": "Test", "content": "Hello world"})

        config = AgentConfig.from_dict(_sample_llm_config())
        agent = ConfigDrivenAgent(config=config, llm_client=MockLLMClient())
        task = _make_task()
        result = await agent.handle_task(task)
        assert result["title"] == "Test"
        assert result["content"] == "Hello world"

    async def test_llm_generate_with_markdown_json(self):
        """JSON wrapped in a markdown code fence is unwrapped and parsed."""
        class MockLLMClient:
            async def chat(self, messages, **kwargs):
                return '```json\n{"title": "Wrapped", "content": "In markdown"}\n```'

        config = AgentConfig.from_dict(_sample_llm_config())
        agent = ConfigDrivenAgent(config=config, llm_client=MockLLMClient())
        task = _make_task()
        result = await agent.handle_task(task)
        assert result["title"] == "Wrapped"
        # Check every field of the fenced object, not just the first one;
        # otherwise a parser that recovered only part of it would still pass.
        assert result["content"] == "In markdown"

    async def test_llm_generate_fallback_text(self):
        """A non-JSON LLM reply degrades to a plain ``text`` result."""
        class MockLLMClient:
            async def chat(self, messages, **kwargs):
                return "This is plain text response"

        config = AgentConfig.from_dict(_sample_llm_config())
        agent = ConfigDrivenAgent(config=config, llm_client=MockLLMClient())
        task = _make_task()
        result = await agent.handle_task(task)
        assert result["text"] == "This is plain text response"

    async def test_tool_call_mode(self):
        """tool_call mode invokes the tool registered under the configured name."""
        registry = ToolRegistry()

        async def check_citation(url: str, **kwargs) -> dict:
            return {"found": True, "url": url}

        tool = FunctionTool(
            name="check_citation",
            description="Check citation",
            func=check_citation,
        )
        registry.register(tool)
        config = AgentConfig.from_dict(_sample_tool_call_config())
        agent = ConfigDrivenAgent(config=config, tool_registry=registry)
        task = _make_task(input_data={"url": "https://example.com"})
        result = await agent.handle_task(task)
        assert result["found"] is True
        assert result["url"] == "https://example.com"

    async def test_custom_mode(self):
        """custom mode dispatches to the handler named in the config."""
        config = AgentConfig.from_dict(_sample_custom_config())

        async def my_handler(task):
            return {"status": "monitored", "task_id": task.task_id}

        agent = ConfigDrivenAgent(
            config=config,
            custom_handlers={"my_handler": my_handler},
        )
        task = _make_task()
        result = await agent.handle_task(task)
        assert result["status"] == "monitored"
        assert result["task_id"] == "test-task-001"

    async def test_execute_wraps_task_result(self):
        """``execute()`` wraps the ``handle_task`` output in a TaskResult."""
        config = AgentConfig.from_dict(_sample_llm_config())
        agent = ConfigDrivenAgent(config=config)
        task = _make_task()
        result = await agent.execute(task)
        assert result.status == TaskStatus.COMPLETED
        assert result.output_data is not None
        assert result.metrics["elapsed_seconds"] >= 0

    def test_get_capabilities(self):
        """The capability declaration is built from the config fields."""
        config = AgentConfig.from_dict(_sample_llm_config())
        agent = ConfigDrivenAgent(config=config)
        cap = agent.get_capabilities()
        assert cap.agent_name == "content_generator"
        assert cap.agent_type == "content_generation"
        assert cap.max_concurrency == 2
        assert "content_generation" in cap.supported_tasks

    def test_prompt_template_rendering(self):
        """The prompt template renders a system and a user message."""
        config = AgentConfig.from_dict(_sample_llm_config())
        agent = ConfigDrivenAgent(config=config)
        assert agent.prompt_template is not None
        messages = agent.prompt_template.render(variables={"query": "test"})
        assert len(messages) == 2
        assert "专业的内容生成助手" in messages[0]["content"]

    async def test_callable_llm_client(self):
        """A bare async callable is accepted as the LLM client."""
        async def mock_llm(messages, **kwargs):
            return '{"result": "from_callable"}'

        config = AgentConfig.from_dict(_sample_llm_config())
        agent = ConfigDrivenAgent(config=config, llm_client=mock_llm)
        task = _make_task()
        result = await agent.handle_task(task)
        assert result["result"] == "from_callable"
# ── StandaloneRunner 测试 ─────────────────────────────────
class TestStandaloneRunner:
    """Tests for ``StandaloneRunner`` config discovery, agent building and execution."""

    @staticmethod
    def _write_yaml(path: Path, data: dict) -> None:
        """Serialize *data* as YAML into *path*, always using UTF-8."""
        # Explicit encoding keeps the test independent of the platform's
        # default locale encoding.
        path.write_text(yaml.dump(data), encoding="utf-8")

    def test_discover_configs(self):
        """Both ``.yaml`` and ``.yml`` files in the config dir are discovered."""
        with tempfile.TemporaryDirectory() as tmpdir:
            for name in ["agent_a.yaml", "agent_b.yml"]:
                config = {
                    "name": name.replace(".", "_"),
                    "agent_type": "test",
                    "task_mode": "llm_generate",
                    "prompt": {"identity": "test", "instructions": "test"},
                }
                self._write_yaml(Path(tmpdir) / name, config)
            runner = StandaloneRunner(config_dir=tmpdir)
            configs = runner.discover_configs()
            assert len(configs) == 2

    def test_build_agents(self):
        """Agents are built from discovered configs and keyed by config name."""
        with tempfile.TemporaryDirectory() as tmpdir:
            self._write_yaml(Path(tmpdir) / "test.yaml", _sample_llm_config())
            runner = StandaloneRunner(config_dir=tmpdir)
            agents = runner.build_agents()
            assert "content_generator" in agents
            assert agents["content_generator"].config.task_mode == "llm_generate"

    def test_add_tool(self):
        """``add_tool`` registers the tool with the runner's registry."""
        runner = StandaloneRunner()

        async def my_tool(x: int) -> dict:
            return {"doubled": x * 2}

        tool = FunctionTool(name="my_tool", description="Test tool", func=my_tool)
        runner.add_tool(tool)
        assert runner._tool_registry.has_tool("my_tool")

    async def test_execute_task_local(self):
        """A task executed locally against a built agent completes."""
        with tempfile.TemporaryDirectory() as tmpdir:
            self._write_yaml(Path(tmpdir) / "test.yaml", _sample_llm_config())
            runner = StandaloneRunner(config_dir=tmpdir)
            runner.build_agents()
            task = _make_task(agent_name="content_generator")
            result = await runner.execute_task("content_generator", task)
            assert result is not None
            assert result["status"] == "completed"

    def test_empty_config_dir(self):
        """An empty config directory yields no configs and does not raise."""
        with tempfile.TemporaryDirectory() as tmpdir:
            runner = StandaloneRunner(config_dir=tmpdir)
            configs = runner.discover_configs()
            assert len(configs) == 0

    def test_nonexistent_config_dir(self):
        """A missing config directory yields no configs and does not raise."""
        # A child of a fresh temp dir is guaranteed not to exist, unlike a
        # hard-coded absolute path that may be present on some hosts.
        with tempfile.TemporaryDirectory() as tmpdir:
            missing = Path(tmpdir) / "does-not-exist"
            runner = StandaloneRunner(config_dir=str(missing))
            configs = runner.discover_configs()
        assert len(configs) == 0