397 lines
14 KiB
Python
397 lines
14 KiB
Python
"""Tests for ReAct Prompt, Skill/Agent tool sync, MCP bridge, and execution modes"""
|
|
|
|
import asyncio
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
|
|
from agentkit.core.protocol import TaskMessage, TaskStatus
|
|
from agentkit.skills.base import Skill, SkillConfig
|
|
from agentkit.tools.function_tool import FunctionTool
|
|
from agentkit.tools.registry import ToolRegistry
|
|
|
|
|
|
def _make_skill_config(execution_mode="react", **kwargs) -> SkillConfig:
    """Build a ``SkillConfig`` from a baseline test dictionary.

    The baseline describes a skill named ``test_skill`` with every prompt
    section filled in. The context and instructions contain ``${topic}`` and
    ``${query}`` placeholders. It also sets ``max_steps`` to 3 and includes a
    small intent block.

    Args:
        execution_mode: Value stored under the ``execution_mode`` key.
        **kwargs: Top-level keys that override or extend the baseline. Each
            override replaces the baseline value wholesale (no deep merge).

    Returns:
        The result of ``SkillConfig.from_dict`` on the merged dictionary.
    """
    prompt_sections = {
        "identity": "You are a test assistant.",
        "context": "Context: ${topic}",
        "instructions": "Please help with: ${query}",
        "constraints": "Be concise.",
        "output_format": "Return JSON.",
        "examples": "Example: input -> output",
    }
    intent_block = {
        "keywords": ["test", "demo"],
        "description": "A test skill",
    }
    baseline = {
        "name": "test_skill",
        "agent_type": "test",
        "task_mode": "llm_generate",
        "supported_tasks": ["test_task"],
        "prompt": prompt_sections,
        "execution_mode": execution_mode,
        "max_steps": 3,
        "intent": intent_block,
    }
    # Later mapping wins; existing keys keep their position, new keys append.
    return SkillConfig.from_dict({**baseline, **kwargs})
|
|
|
|
|
|
def _make_task(**kwargs) -> TaskMessage:
    """Build a ``TaskMessage`` populated with test defaults.

    By default the task targets ``test_skill`` / ``test_task`` with
    ``input_data`` providing ``topic`` and ``query``. Any keyword argument
    replaces the corresponding default. ``created_at`` is taken from the
    current UTC time on every call.

    Returns:
        A new ``TaskMessage`` built from the merged fields.
    """
    fields = dict(
        task_id="test-task-1",
        agent_name="test_skill",
        task_type="test_task",
        priority=0,
        input_data={"topic": "AI", "query": "What is AI?"},
        callback_url=None,
        created_at=datetime.now(timezone.utc),
    )
    fields.update(kwargs)
    return TaskMessage(**fields)
|
|
|
|
|
|
class TestReActPromptFullRendering:
    """Test that ReAct mode uses full PromptTemplate.render() output."""

    @staticmethod
    def _mock_gateway(content, total_tokens):
        """Build a mock LLM gateway whose async ``chat`` returns one final answer.

        Args:
            content: Value exposed as ``response.content``.
            total_tokens: Value exposed as ``response.usage.total_tokens``.

        Returns:
            A ``MagicMock`` gateway. Inspect ``gateway.chat`` (an
            ``AsyncMock``) to see what was sent.
        """
        response = MagicMock()
        response.content = content
        response.usage = MagicMock()
        response.usage.total_tokens = total_tokens
        # Pin explicitly: an unset MagicMock attribute is a truthy child mock,
        # which would read as "the model requested tool calls".
        response.has_tool_calls = False
        gateway = MagicMock()
        gateway.chat = AsyncMock(return_value=response)
        return gateway

    @pytest.mark.asyncio
    async def test_react_uses_full_prompt_template(self):
        """ReAct mode should use PromptTemplate.render() to get all prompt sections,
        not just identity.

        In ReAct mode, _handle_react() passes system_prompt to ReActEngine.execute(),
        which prepends it as a system message in the conversation passed to gateway.chat().
        So we check the 'messages' kwarg for a system message containing all sections.
        """
        config = _make_skill_config(execution_mode="react")
        tool_registry = ToolRegistry()

        # Mock LLMGateway to capture what messages are sent
        mock_gateway = self._mock_gateway(json.dumps({"answer": "test"}), 10)

        agent = ConfigDrivenAgent(
            config=config,
            tool_registry=tool_registry,
            llm_gateway=mock_gateway,
        )

        await agent.handle_task(_make_task())

        # Verify the gateway was called
        mock_gateway.chat.assert_called_once()
        call_kwargs = mock_gateway.chat.call_args

        # ReActEngine.execute() puts system_prompt as the first message in conversation
        messages = call_kwargs.kwargs.get("messages") or []
        assert messages, "No messages sent to gateway"

        # First message should be the system message with all prompt sections
        system_msg = messages[0]
        assert system_msg["role"] == "system", f"First message is not system: {system_msg['role']}"
        system_content = system_msg["content"]
        assert "test assistant" in system_content, f"Identity missing from system message: {system_content}"
        # Check the resolved context line itself. A bare "AI" would also match
        # the resolved instructions ("What is AI?") if they leaked in here.
        assert "Context: AI" in system_content, f"Context variable not resolved in system message: {system_content}"
        assert "${topic}" not in system_content, f"Raw context placeholder left in system message: {system_content}"
        assert "concise" in system_content, f"Constraints missing from system message: {system_content}"

        # Check that user messages contain instructions + output_format + examples.
        # `or ""` guards messages whose content is None (e.g. tool-call turns).
        user_content = " ".join(
            (m.get("content") or "") for m in messages if m.get("role") != "system"
        )
        assert "What is AI?" in user_content, f"Instructions variable not resolved: {user_content}"
        assert "${query}" not in user_content, f"Raw instructions placeholder left: {user_content}"
        assert "JSON" in user_content, f"Output format missing: {user_content}"
        assert "Example" in user_content, f"Examples missing: {user_content}"

    @pytest.mark.asyncio
    async def test_react_without_prompt_template(self):
        """ReAct mode without prompt template should use input_data as fallback."""
        config = SkillConfig(
            name="no_prompt_skill",
            agent_type="test",
            task_mode="tool_call",
            supported_tasks=["test"],
            execution_mode="react",
            tools=["mock_tool"],
        )
        tool_registry = ToolRegistry()

        async def mock_func(**kwargs):
            return {"mock": True}

        tool_registry.register(FunctionTool(name="mock_tool", description="Mock", func=mock_func))

        mock_gateway = self._mock_gateway('{"result": "ok"}', 5)

        agent = ConfigDrivenAgent(
            config=config,
            tool_registry=tool_registry,
            llm_gateway=mock_gateway,
        )

        result = await agent.handle_task(_make_task(input_data={"message": "hello"}))
        assert isinstance(result, dict)

        # With no prompt template, input_data is the only source for the
        # request, so its content must reach the LLM.
        mock_gateway.chat.assert_called()
        messages = mock_gateway.chat.call_args.kwargs.get("messages") or []
        sent_text = " ".join((m.get("content") or "") for m in messages)
        assert "hello" in sent_text, f"input_data not forwarded to the LLM: {sent_text}"
|
|
|
|
|
|
class TestSkillAgentToolSync:
    """Test that Skill-bound tools are merged into Agent._tools."""

    def test_skill_tools_merged_into_agent(self):
        """When ConfigDrivenAgent receives a SkillConfig with tools,
        the Skill's bound tools should be merged into Agent._tools."""
        config = _make_skill_config(
            execution_mode="react",
            tools=["tool_a", "tool_b"],
        )
        tool_registry = ToolRegistry()

        async def mock_func(**kwargs):
            return {"mock": True}

        tool_registry.register(FunctionTool(name="tool_a", description="Tool A", func=mock_func))
        tool_registry.register(FunctionTool(name="tool_b", description="Tool B", func=mock_func))

        agent = ConfigDrivenAgent(
            config=config,
            tool_registry=tool_registry,
        )

        # Agent should have both tools from the config
        tool_names = [t.name for t in agent._tools]
        assert "tool_a" in tool_names, f"tool_a not found in agent tools: {tool_names}"
        assert "tool_b" in tool_names, f"tool_b not found in agent tools: {tool_names}"

    def test_skill_instance_tools_merged(self):
        """When a Skill instance has tools bound via bind_tool(),
        those tools should be merged into Agent._tools."""
        config = _make_skill_config(execution_mode="react")
        tool_registry = ToolRegistry()

        agent = ConfigDrivenAgent(
            config=config,
            tool_registry=tool_registry,
        )

        # Manually bind a tool to the skill instance
        async def extra_func(**kwargs):
            return {"extra": True}

        extra_tool = FunctionTool(name="extra_tool", description="Extra", func=extra_func)
        agent._skill_instance.bind_tool(extra_tool)

        # Preconditions: the tool is now on the skill, but was bound after
        # __init__ and so is not yet on the agent. Without this, the "+1"
        # check below could pass or fail for unrelated reasons.
        assert any(t.name == "extra_tool" for t in agent._skill_instance.tools), \
            "bind_tool() did not attach extra_tool to the skill instance"
        assert all(t.name != "extra_tool" for t in agent._tools), \
            "extra_tool unexpectedly present on the agent before merging"

        # Simulate re-creating agent (in real flow, tools are merged during __init__)
        # For this test, verify the merge logic works
        initial_count = len(agent._tools)
        for tool in agent._skill_instance.tools:
            if not any(t.name == tool.name for t in agent._tools):
                agent.use_tool(tool)

        tool_names = [t.name for t in agent._tools]
        assert "extra_tool" in tool_names
        assert len(agent._tools) == initial_count + 1
|
|
|
|
|
|
class TestMCPBridge:
    """Test MCP → ReAct bridge."""

    @staticmethod
    def _mock_gateway(content, total_tokens):
        """Build a mock LLM gateway whose async ``chat`` returns one final answer.

        ``has_tool_calls`` is pinned to ``False``. Left unset, MagicMock would
        return a truthy child mock that reads as a tool-call request.
        """
        response = MagicMock()
        response.content = content
        response.usage = MagicMock()
        response.usage.total_tokens = total_tokens
        response.has_tool_calls = False
        gateway = MagicMock()
        gateway.chat = AsyncMock(return_value=response)
        return gateway

    def test_mcp_servers_parameter_accepted(self):
        """ConfigDrivenAgent should accept mcp_servers parameter."""
        # Plain sync test: nothing here is awaited, and construction outside an
        # event loop is already exercised by the sync tool-sync tests.
        config = _make_skill_config(execution_mode="react")
        tool_registry = ToolRegistry()

        agent = ConfigDrivenAgent(
            config=config,
            tool_registry=tool_registry,
            mcp_servers={"test_server": "http://localhost:8080"},
        )

        assert agent._mcp_servers == {"test_server": "http://localhost:8080"}
        assert agent._mcp_tools_registered is False

    @pytest.mark.asyncio
    async def test_mcp_lazy_registration_on_task(self):
        """MCP tools should be lazily registered on first task execution."""
        config = _make_skill_config(execution_mode="react")
        tool_registry = ToolRegistry()

        mock_gateway = self._mock_gateway('{"result": "ok"}', 5)

        agent = ConfigDrivenAgent(
            config=config,
            tool_registry=tool_registry,
            llm_gateway=mock_gateway,
            mcp_servers={"test_server": "http://localhost:8080"},
        )

        # "Lazy" means nothing is registered at construction time.
        assert agent._mcp_tools_registered is False

        # Mock MCPClient to avoid real HTTP calls
        with patch("agentkit.mcp.client.MCPClient") as MockMCPClient:
            mock_client_instance = MagicMock()
            mock_client_instance.list_tools = AsyncMock(return_value=[
                {"name": "remote_tool", "description": "A remote tool"}
            ])
            mock_mcp_tool = MagicMock()
            mock_mcp_tool.name = "remote_tool"
            mock_client_instance.as_tool = MagicMock(return_value=mock_mcp_tool)
            MockMCPClient.return_value = mock_client_instance

            await agent.handle_task(_make_task())

        # MCP tools should now be registered
        assert agent._mcp_tools_registered is True
        mock_client_instance.list_tools.assert_called_once()

    @pytest.mark.asyncio
    async def test_mcp_registration_failure_graceful(self):
        """MCP registration failure should not prevent task execution."""
        config = _make_skill_config(execution_mode="react")
        tool_registry = ToolRegistry()

        mock_gateway = self._mock_gateway('{"result": "ok"}', 5)

        agent = ConfigDrivenAgent(
            config=config,
            tool_registry=tool_registry,
            llm_gateway=mock_gateway,
            mcp_servers={"bad_server": "http://nonexistent:9999"},
        )

        with patch("agentkit.mcp.client.MCPClient") as MockMCPClient:
            MockMCPClient.return_value.list_tools = AsyncMock(
                side_effect=Exception("Connection refused")
            )

            result = await agent.handle_task(_make_task())

        # Should still complete despite MCP failure, and the task must actually
        # have been executed (reached the LLM) rather than short-circuited.
        assert isinstance(result, dict)
        mock_gateway.chat.assert_called()
|
|
|
|
|
|
class TestExecutionModes:
    """Test execution_mode=react/direct/custom."""

    @staticmethod
    def _mock_gateway(content, total_tokens):
        """Build a mock LLM gateway whose async ``chat`` returns one final answer.

        ``has_tool_calls`` is pinned to ``False``. Left unset, MagicMock would
        return a truthy child mock that reads as a tool-call request.
        """
        response = MagicMock()
        response.content = content
        response.usage = MagicMock()
        response.usage.total_tokens = total_tokens
        response.has_tool_calls = False
        gateway = MagicMock()
        gateway.chat = AsyncMock(return_value=response)
        return gateway

    @pytest.mark.asyncio
    async def test_direct_mode_single_llm_call(self):
        """execution_mode=direct should make a single LLM call without ReAct loop."""
        config = _make_skill_config(execution_mode="direct")
        tool_registry = ToolRegistry()

        mock_gateway = self._mock_gateway(json.dumps({"answer": "direct result"}), 15)

        agent = ConfigDrivenAgent(
            config=config,
            tool_registry=tool_registry,
            llm_gateway=mock_gateway,
        )

        result = await agent.handle_task(_make_task())

        # Should call gateway.chat directly (not ReAct engine)
        mock_gateway.chat.assert_called_once()
        assert result == {"answer": "direct result"}

    @pytest.mark.asyncio
    async def test_custom_mode_with_skill_config(self):
        """execution_mode=custom should use custom handler."""
        config = _make_skill_config(
            execution_mode="custom",
            custom_handler="test.handlers.mock_handler",
        )
        tool_registry = ToolRegistry()

        async def mock_handler(task):
            return {"custom": True, "task_id": task.task_id}

        agent = ConfigDrivenAgent(
            config=config,
            tool_registry=tool_registry,
            custom_handlers={"test.handlers.mock_handler": mock_handler},
        )

        result = await agent.handle_task(_make_task())

        assert result["custom"] is True
        assert result["task_id"] == "test-task-1"

    @pytest.mark.asyncio
    async def test_react_mode_uses_react_engine(self):
        """execution_mode=react should use ReAct engine."""
        config = _make_skill_config(execution_mode="react")
        tool_registry = ToolRegistry()

        mock_gateway = self._mock_gateway(json.dumps({"answer": "react result"}), 20)

        agent = ConfigDrivenAgent(
            config=config,
            tool_registry=tool_registry,
            llm_gateway=mock_gateway,
        )

        result = await agent.handle_task(_make_task())

        assert isinstance(result, dict)
        # With no tool calls requested, the ReAct loop should finish after one
        # LLM round-trip. The engine prepends the rendered system prompt.
        mock_gateway.chat.assert_called_once()
        messages = mock_gateway.chat.call_args.kwargs.get("messages") or []
        assert messages and messages[0]["role"] == "system", \
            f"ReAct conversation did not start with a system message: {messages}"

    @pytest.mark.asyncio
    async def test_fallback_to_task_mode_without_skill_config(self):
        """Without SkillConfig, should fall back to task_mode."""
        config = AgentConfig(
            name="legacy_agent",
            agent_type="test",
            task_mode="llm_generate",
            supported_tasks=["test"],
            prompt={"identity": "Legacy agent"},
        )
        tool_registry = ToolRegistry()

        agent = ConfigDrivenAgent(
            config=config,
            tool_registry=tool_registry,
        )

        result = await agent.handle_task(_make_task())

        # Should return rendered prompt (no LLM client). The previous
        # `"messages" in result or isinstance(result, dict)` was satisfied by
        # any dict, and raised TypeError instead of failing cleanly on None.
        assert isinstance(result, dict), f"Expected dict result, got {type(result).__name__}: {result!r}"
|