fischer-agentkit/tests/integration/test_tool_composition.py

300 lines
9.1 KiB
Python

"""Integration tests for tool composition patterns end-to-end"""
import pytest
from unittest.mock import AsyncMock
from agentkit.core.base import BaseAgent
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
from agentkit.core.protocol import AgentCapability, TaskMessage, TaskResult, TaskStatus
from agentkit.tools.agent_tool import AgentTool
from agentkit.tools.composition import DynamicSelector, ParallelFanOut, SequentialChain
from agentkit.tools.function_tool import FunctionTool
from datetime import datetime, timezone
# ── Helper Functions ───────────────────────────────────────
def add_prefix(text: str, prefix: str = "hello") -> dict:
    """Return ``{"text": "<prefix> <text>"}`` (prefix and text joined by one space)."""
    combined = "{} {}".format(prefix, text)
    return dict(text=combined)
def make_uppercase(text: str) -> dict:
    """Return ``{"text": ...}`` holding the upper-cased input."""
    shouted = text.upper()
    return dict(text=shouted)
def multiply(x: int, y: int = 2, **kwargs) -> dict:
    """Return ``{"product": x * y}``.

    Extra keyword arguments are accepted and ignored so the function can sit
    in a tool chain that forwards unrelated values.
    """
    return dict(product=x * y)
def double_product(product: int) -> dict:
    """Return ``{"total": product * 2}``; consumes the ``product`` key emitted by ``multiply``."""
    total = product * 2
    return dict(total=total)
def search_data(query: str, **kwargs) -> dict:
    """Return a single fake search hit for *query*; extra kwargs are ignored."""
    hits = ["result for {}".format(query)]
    return {"search_results": hits}
def calculate(expression: str, **kwargs) -> dict:
    """Return a fake calculation string for *expression*; extra kwargs are ignored."""
    rendered = "calc: {}".format(expression)
    return {"calculation_result": rendered}
def translate(text: str, **kwargs) -> dict:
    """Return *text* tagged with a language code.

    The code comes from the optional ``target_lang`` keyword and defaults to
    ``"en"``. Any other keyword arguments are ignored.
    """
    lang = kwargs.get("target_lang", "en")
    return {"translated": "[{}] {}".format(lang, text)}
# ── Tests ──────────────────────────────────────────────────
@pytest.mark.integration
async def test_sequential_chain():
    """SequentialChain: the second FunctionTool consumes the first tool's output.

    add_prefix turns "world" into "hello world"; make_uppercase then receives
    that ``text`` value and upper-cases it.
    """
    prefix_step = FunctionTool(
        name="add_prefix",
        description="Add prefix to text",
        func=add_prefix,
    )
    upper_step = FunctionTool(
        name="make_uppercase",
        description="Convert text to uppercase",
        func=make_uppercase,
    )
    pipeline = SequentialChain(
        name="prefix_then_uppercase",
        description="Add prefix then uppercase",
        tools=[prefix_step, upper_step],
    )

    output = await pipeline.safe_execute(text="world")

    assert output["text"] == "HELLO WORLD"
@pytest.mark.integration
async def test_sequential_chain_numeric():
    """SequentialChain over numeric tools: multiply, then double the product.

    multiply(x=3, y=2) yields {"product": 6}; that ``product`` value feeds
    double_product, which yields {"total": 12}.
    """
    steps = [
        FunctionTool(
            name="multiply",
            description="Multiply numbers",
            func=multiply,
        ),
        FunctionTool(
            name="double_product",
            description="Double the product value",
            func=double_product,
        ),
    ]
    chain = SequentialChain(
        name="multiply_then_double",
        description="Multiply then double the product",
        tools=steps,
    )

    outcome = await chain.safe_execute(x=3, y=2)

    assert outcome["total"] == 12
@pytest.mark.integration
async def test_parallel_fan_out():
    """ParallelFanOut: three FunctionTools execute in parallel, results merged.

    One call supplies every tool's argument (query, expression, text). The
    helpers accept ``**kwargs`` so each can ignore the arguments meant for
    its siblings. Checking exact values, not just key presence, verifies
    that each tool actually ran on its own input.
    """
    tool_search = FunctionTool(
        name="search",
        description="Search for data",
        func=search_data,
        tags=["search"],
    )
    tool_calc = FunctionTool(
        name="calculate",
        description="Calculate expression",
        func=calculate,
        tags=["calculate"],
    )
    tool_translate = FunctionTool(
        name="translate",
        description="Translate text",
        func=translate,
        tags=["translate"],
    )
    fan_out = ParallelFanOut(
        name="multi_action",
        description="Run multiple actions in parallel",
        tools=[tool_search, tool_calc, tool_translate],
    )

    result = await fan_out.safe_execute(query="AI trends", expression="2+2", text="hello")

    # All three tools must contribute to the merged result, each with the
    # value computed from its own argument.
    assert result["search_results"] == ["result for AI trends"]
    assert result["calculation_result"] == "calc: 2+2"
    # No target_lang is passed, so translate falls back to "en".
    assert result["translated"] == "[en] hello"
@pytest.mark.integration
async def test_parallel_fan_out_namespace_merge():
    """ParallelFanOut with merge_strategy="namespace".

    Each tool's output is expected to be nested under a key equal to that
    tool's name.
    """
    searcher = FunctionTool(name="search", description="Search for data", func=search_data)
    translator = FunctionTool(name="translate", description="Translate text", func=translate)
    fan_out = ParallelFanOut(
        name="namespace_fanout",
        description="Namespace merge fan-out",
        tools=[searcher, translator],
        merge_strategy="namespace",
    )

    merged = await fan_out.safe_execute(query="test", text="hello")

    # One namespace per tool name, each holding that tool's own output key.
    for tool_name, output_key in (("search", "search_results"), ("translate", "translated")):
        assert tool_name in merged
        assert output_key in merged[tool_name]
@pytest.mark.integration
async def test_dynamic_selector_keyword_mode():
    """DynamicSelector in keyword mode picks a tool from the ``_intent`` hint."""
    selector = DynamicSelector(
        name="smart_tool",
        description="Dynamically select a tool",
        tools=[
            FunctionTool(
                name="search_tool",
                description="Search for information",
                func=search_data,
                tags=["search"],
            ),
            FunctionTool(
                name="calculate_tool",
                description="Calculate mathematical expressions",
                func=calculate,
                tags=["calculate"],
            ),
            FunctionTool(
                name="translate_tool",
                description="Translate text between languages",
                func=translate,
                tags=["translate"],
            ),
        ],
        mode="keyword",
    )

    # Each case: call kwargs (including the intent hint) -> key the chosen tool emits.
    cases = (
        ({"query": "AI trends", "_intent": "search"}, "search_results"),
        ({"expression": "2+2", "_intent": "calculate"}, "calculation_result"),
    )
    for call_kwargs, expected_key in cases:
        outcome = await selector.safe_execute(**call_kwargs)
        assert expected_key in outcome
@pytest.mark.integration
async def test_dynamic_selector_llm_mode():
    """DynamicSelector: LLM-based tool selection with a mocked LLM client.

    The mock's ``chat`` coroutine always answers "0", which should select
    search_tool. Index 0 is also the first tool in the list, so the test
    also asserts that ``chat`` was awaited. Otherwise a selector that never
    consulted the LLM (or called some other client method) would still pass.
    """
    tool_search = FunctionTool(
        name="search_tool",
        description="Search for information",
        func=search_data,
        tags=["search"],
    )
    tool_calc = FunctionTool(
        name="calculate_tool",
        description="Calculate mathematical expressions",
        func=calculate,
        tags=["calculate"],
    )
    # Mock LLM that always selects tool index 0 (search_tool)
    mock_llm = AsyncMock()
    mock_llm.chat = AsyncMock(return_value="0")
    selector = DynamicSelector(
        name="llm_smart_tool",
        description="LLM-based dynamic tool selector",
        tools=[tool_search, tool_calc],
        mode="llm",
        llm_client=mock_llm,
    )

    result = await selector.safe_execute(query="test query")

    # The selection must actually have gone through the LLM.
    mock_llm.chat.assert_awaited()
    assert "search_results" in result
@pytest.mark.integration
async def test_agent_tool_wrap_and_call():
    """AgentTool: wrap an Agent as a Tool and call it through a dispatcher.

    The agent is stopped in a ``finally`` block, so it is shut down even when
    the tool call or the assertion fails.
    """

    class SimpleAgent(BaseAgent):
        """Minimal agent that greets the ``name`` found in the task input."""

        def __init__(self):
            super().__init__(name="simple_agent", agent_type="simple")

        def get_capabilities(self) -> AgentCapability:
            return AgentCapability(
                agent_name=self.name,
                agent_type=self.agent_type,
                version=self.version,
                supported_tasks=["simple"],
                max_concurrency=1,
                description="Simple agent for testing",
            )

        async def handle_task(self, task: TaskMessage) -> dict:
            # Falls back to "world" when the task input carries no name.
            return {"greeting": f"Hello, {task.input_data.get('name', 'world')}!"}

    # In-process stand-in for a real dispatcher: it runs the task on the
    # target agent inline and records the result so get_task_status can
    # report it.
    class MockDispatcher:
        def __init__(self, target_agent: BaseAgent):
            self._agent = target_agent
            self._results: dict[str, TaskResult] = {}

        async def dispatch(self, task: TaskMessage):
            """Execute *task* on the wrapped agent and store its TaskResult."""
            result = await self._agent.execute(task)
            self._results[task.task_id] = result

        async def get_task_status(self, task_id: str) -> dict:
            """Return a status dict; ``{"status": "pending"}`` until a result is stored."""
            result = self._results.get(task_id)
            if result is None:
                return {"status": "pending"}
            return {
                "status": result.status,
                "output_data": result.output_data,
                "error_message": result.error_message,
            }

    agent = SimpleAgent()
    await agent.start()
    try:
        dispatcher = MockDispatcher(agent)
        agent_tool = AgentTool(
            name="simple_agent_tool",
            description="Call the simple agent",
            agent_name="simple_agent",
            task_type="simple",
        )
        agent_tool.set_dispatcher(dispatcher)

        result = await agent_tool.safe_execute(name="Alice")

        assert result["greeting"] == "Hello, Alice!"
    finally:
        # Always release the started agent, even on failure.
        await agent.stop()