300 lines
9.1 KiB
Python
300 lines
9.1 KiB
Python
"""Integration tests for tool composition patterns end-to-end"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock
|
|
|
|
from agentkit.core.base import BaseAgent
|
|
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
|
|
from agentkit.core.protocol import AgentCapability, TaskMessage, TaskResult, TaskStatus
|
|
from agentkit.tools.agent_tool import AgentTool
|
|
from agentkit.tools.composition import DynamicSelector, ParallelFanOut, SequentialChain
|
|
from agentkit.tools.function_tool import FunctionTool
|
|
from datetime import datetime, timezone
|
|
|
|
|
|
# ── Helper Functions ───────────────────────────────────────
|
|
|
|
|
|
def add_prefix(text: str, prefix: str = "hello") -> dict:
    """Return ``{"text": "<prefix> <text>"}``; *prefix* defaults to ``"hello"``."""
    prefixed = "{} {}".format(prefix, text)
    return dict(text=prefixed)
|
|
|
|
|
|
def make_uppercase(text: str) -> dict:
    """Return ``{"text": <text upper-cased>}``."""
    shouted = text.upper()
    return dict(text=shouted)
|
|
|
|
|
|
def multiply(x: int, y: int = 2, **kwargs) -> dict:
    """Return ``{"product": x * y}``.

    Any additional keyword arguments are accepted and discarded, so the
    function tolerates extra inputs when composed in a chain.
    """
    product = x * y
    return dict(product=product)
|
|
|
|
|
|
def double_product(product: int) -> dict:
    """Return ``{"total": product * 2}``; intended to consume multiply()'s output."""
    doubled = product * 2
    return dict(total=doubled)
|
|
|
|
|
|
def search_data(query: str, **kwargs) -> dict:
    """Return one canned search hit for *query*; extra keyword arguments are ignored."""
    hits = [f"result for {query}"]
    return dict(search_results=hits)
|
|
|
|
|
|
def calculate(expression: str, **kwargs) -> dict:
    """Return a fake calculation result that echoes *expression*.

    The expression is not evaluated; the result is the string
    ``"calc: <expression>"``. Extra keyword arguments are ignored.
    """
    summary = f"calc: {expression}"
    return dict(calculation_result=summary)
|
|
|
|
|
|
def translate(text: str, **kwargs) -> dict:
    """Fake translation: tag *text* with the requested target language.

    The optional ``target_lang`` keyword is read from ``kwargs`` and
    defaults to ``"en"``. Every other extra keyword argument is ignored,
    so the function tolerates additional inputs when composed with other
    tools.

    Returns:
        ``{"translated": "[<target_lang>] <text>"}``
    """
    target_lang = kwargs.get("target_lang", "en")
    return {"translated": f"[{target_lang}] {text}"}
|
|
|
|
|
|
# ── Tests ──────────────────────────────────────────────────
|
|
|
|
|
|
@pytest.mark.integration
async def test_sequential_chain():
    """SequentialChain: add_prefix's output is fed into make_uppercase."""
    prefix_step = FunctionTool(
        name="add_prefix",
        description="Add prefix to text",
        func=add_prefix,
    )
    upper_step = FunctionTool(
        name="make_uppercase",
        description="Convert text to uppercase",
        func=make_uppercase,
    )
    pipeline = SequentialChain(
        name="prefix_then_uppercase",
        description="Add prefix then uppercase",
        tools=[prefix_step, upper_step],
    )

    output = await pipeline.safe_execute(text="world")

    # add_prefix -> "hello world", then make_uppercase -> "HELLO WORLD"
    assert output["text"] == "HELLO WORLD"
|
|
|
|
|
|
@pytest.mark.integration
async def test_sequential_chain_numeric():
    """SequentialChain over numbers: multiply's ``product`` feeds double_product."""
    steps = [
        FunctionTool(
            name="multiply",
            description="Multiply numbers",
            func=multiply,
        ),
        FunctionTool(
            name="double_product",
            description="Double the product value",
            func=double_product,
        ),
    ]
    pipeline = SequentialChain(
        name="multiply_then_double",
        description="Multiply then double the product",
        tools=steps,
    )

    # Step 1: multiply(x=3, y=2) -> {"product": 6}
    # Step 2: double_product(product=6) -> {"total": 12}
    output = await pipeline.safe_execute(x=3, y=2)
    assert output["total"] == 12
|
|
|
|
|
|
@pytest.mark.integration
async def test_parallel_fan_out():
    """ParallelFanOut: search, calculate and translate run in parallel; outputs are merged."""
    tools = [
        FunctionTool(
            name="search",
            description="Search for data",
            func=search_data,
            tags=["search"],
        ),
        FunctionTool(
            name="calculate",
            description="Calculate expression",
            func=calculate,
            tags=["calculate"],
        ),
        FunctionTool(
            name="translate",
            description="Translate text",
            func=translate,
            tags=["translate"],
        ),
    ]
    fan_out = ParallelFanOut(
        name="multi_action",
        description="Run multiple actions in parallel",
        tools=tools,
    )

    merged = await fan_out.safe_execute(query="AI trends", expression="2+2", text="hello")

    # No merge_strategy given: each tool's output key should appear at the top level.
    for expected_key in ("search_results", "calculation_result", "translated"):
        assert expected_key in merged
|
|
|
|
|
|
@pytest.mark.integration
async def test_parallel_fan_out_namespace_merge():
    """ParallelFanOut with merge_strategy="namespace" nests outputs under tool names."""
    searcher = FunctionTool(
        name="search",
        description="Search for data",
        func=search_data,
    )
    translator = FunctionTool(
        name="translate",
        description="Translate text",
        func=translate,
    )
    fan_out = ParallelFanOut(
        name="namespace_fanout",
        description="Namespace merge fan-out",
        tools=[searcher, translator],
        merge_strategy="namespace",
    )

    result = await fan_out.safe_execute(query="test", text="hello")

    # Each tool's output lives under its own name rather than being flattened.
    expectations = (("search", "search_results"), ("translate", "translated"))
    for tool_name, output_key in expectations:
        assert tool_name in result
        assert output_key in result[tool_name]
|
|
|
|
|
|
@pytest.mark.integration
async def test_dynamic_selector_keyword_mode():
    """DynamicSelector in keyword mode picks a tool based on the ``_intent`` hint."""
    candidates = [
        FunctionTool(
            name="search_tool",
            description="Search for information",
            func=search_data,
            tags=["search"],
        ),
        FunctionTool(
            name="calculate_tool",
            description="Calculate mathematical expressions",
            func=calculate,
            tags=["calculate"],
        ),
        FunctionTool(
            name="translate_tool",
            description="Translate text between languages",
            func=translate,
            tags=["translate"],
        ),
    ]
    selector = DynamicSelector(
        name="smart_tool",
        description="Dynamically select a tool",
        tools=candidates,
        mode="keyword",
    )

    # (intent hint, tool inputs, key the selected tool's output must contain)
    cases = [
        ("search", {"query": "AI trends"}, "search_results"),
        ("calculate", {"expression": "2+2"}, "calculation_result"),
    ]
    for intent, inputs, expected_key in cases:
        result = await selector.safe_execute(**inputs, _intent=intent)
        assert expected_key in result
|
|
|
|
|
|
@pytest.mark.integration
async def test_dynamic_selector_llm_mode():
    """DynamicSelector in LLM mode, driven by a stubbed LLM client."""
    searcher = FunctionTool(
        name="search_tool",
        description="Search for information",
        func=search_data,
        tags=["search"],
    )
    calculator = FunctionTool(
        name="calculate_tool",
        description="Calculate mathematical expressions",
        func=calculate,
        tags=["calculate"],
    )

    # Stub LLM: every awaited chat() call returns "0", meant to select
    # the first tool (search_tool).
    llm_stub = AsyncMock()
    llm_stub.chat = AsyncMock(return_value="0")

    selector = DynamicSelector(
        name="llm_smart_tool",
        description="LLM-based dynamic tool selector",
        tools=[searcher, calculator],
        mode="llm",
        llm_client=llm_stub,
    )

    result = await selector.safe_execute(query="test query")
    assert "search_results" in result
|
|
|
|
|
|
@pytest.mark.integration
async def test_agent_tool_wrap_and_call():
    """AgentTool: wrap an Agent as a Tool and call it through a dispatcher.

    The agent is stopped in a ``finally`` block. A failing assertion or an
    exception from the tool call therefore cannot leave it running.
    """

    class SimpleAgent(BaseAgent):
        """Minimal agent that greets the ``name`` found in its task input."""

        def __init__(self):
            super().__init__(name="simple_agent", agent_type="simple")

        def get_capabilities(self) -> AgentCapability:
            return AgentCapability(
                agent_name=self.name,
                agent_type=self.agent_type,
                version=self.version,
                supported_tasks=["simple"],
                max_concurrency=1,
                description="Simple agent for testing",
            )

        async def handle_task(self, task: TaskMessage) -> dict:
            return {"greeting": f"Hello, {task.input_data.get('name', 'world')}!"}

    class MockDispatcher:
        """Dispatcher stand-in that routes every task straight to one agent.

        ``dispatch`` awaits the agent's ``execute`` inline and stores the
        result. The first ``get_task_status`` call after a dispatch already
        sees the finished result; unknown task ids report ``"pending"``.
        """

        def __init__(self, target_agent: BaseAgent):
            self._agent = target_agent
            self._results: dict[str, TaskResult] = {}

        async def dispatch(self, task: TaskMessage):
            result = await self._agent.execute(task)
            self._results[task.task_id] = result

        async def get_task_status(self, task_id: str) -> dict:
            result = self._results.get(task_id)
            if result is None:
                return {"status": "pending"}
            return {
                "status": result.status,
                "output_data": result.output_data,
                "error_message": result.error_message,
            }

    agent = SimpleAgent()
    await agent.start()
    # Everything after start() is guarded so the agent is always stopped,
    # even when the tool call raises or the assertion fails.
    try:
        dispatcher = MockDispatcher(agent)

        agent_tool = AgentTool(
            name="simple_agent_tool",
            description="Call the simple agent",
            agent_name="simple_agent",
            task_type="simple",
        )
        agent_tool.set_dispatcher(dispatcher)

        result = await agent_tool.safe_execute(name="Alice")
        assert result["greeting"] == "Hello, Alice!"
    finally:
        await agent.stop()
|