"""Integration tests for tool composition patterns end-to-end"""

from datetime import datetime, timezone
from unittest.mock import AsyncMock

import pytest

from agentkit.core.base import BaseAgent
from agentkit.core.config_driven import AgentConfig, ConfigDrivenAgent
from agentkit.core.protocol import AgentCapability, TaskMessage, TaskResult, TaskStatus
from agentkit.tools.agent_tool import AgentTool
from agentkit.tools.composition import DynamicSelector, ParallelFanOut, SequentialChain
from agentkit.tools.function_tool import FunctionTool

# NOTE(review): datetime, timezone, AgentConfig, ConfigDrivenAgent and TaskStatus
# are not referenced anywhere in this module. They are kept so the module
# namespace is unchanged; confirm nothing depends on them and drop them.


# ── Helper Functions ───────────────────────────────────────


def add_prefix(text: str, prefix: str = "hello") -> dict:
    """Add a prefix to text."""
    return {"text": f"{prefix} {text}"}


def make_uppercase(text: str) -> dict:
    """Convert text to uppercase."""
    return {"text": text.upper()}


def multiply(x: int, y: int = 2, **kwargs) -> dict:
    """Multiply two numbers (ignores extra kwargs for chaining)."""
    return {"product": x * y}


def double_product(product: int) -> dict:
    """Double the product value (for chaining after multiply)."""
    return {"total": product * 2}


def search_data(query: str, **kwargs) -> dict:
    """Search for data (ignores extra kwargs)."""
    return {"search_results": [f"result for {query}"]}


def calculate(expression: str, **kwargs) -> dict:
    """Calculate an expression (ignores extra kwargs)."""
    return {"calculation_result": f"calc: {expression}"}


def translate(text: str, **kwargs) -> dict:
    """Translate text, tagging it with ``target_lang`` from kwargs (default "en").

    Any other extra kwargs are ignored, so the helper can receive the full
    keyword set passed to a fan-out.
    """
    return {"translated": f"[{kwargs.get('target_lang', 'en')}] {text}"}


# ── Tests ──────────────────────────────────────────────────


@pytest.mark.integration
async def test_sequential_chain():
    """SequentialChain: two FunctionTools execute in sequence, second receives first's output."""
    tool1 = FunctionTool(
        name="add_prefix",
        description="Add prefix to text",
        func=add_prefix,
    )
    tool2 = FunctionTool(
        name="make_uppercase",
        description="Convert text to uppercase",
        func=make_uppercase,
    )
    chain = SequentialChain(
        name="prefix_then_uppercase",
        description="Add prefix then uppercase",
        tools=[tool1, tool2],
    )

    # add_prefix(text="world") -> {"text": "hello world"}
    # make_uppercase(text="hello world") -> {"text": "HELLO WORLD"}
    result = await chain.safe_execute(text="world")
    assert result["text"] == "HELLO WORLD"


@pytest.mark.integration
async def test_sequential_chain_numeric():
    """SequentialChain with numeric tools: multiply then double_product (chained output)."""
    tool_multiply = FunctionTool(
        name="multiply",
        description="Multiply numbers",
        func=multiply,
    )
    tool_double = FunctionTool(
        name="double_product",
        description="Double the product value",
        func=double_product,
    )
    chain = SequentialChain(
        name="multiply_then_double",
        description="Multiply then double the product",
        tools=[tool_multiply, tool_double],
    )

    # multiply(x=3, y=2) -> {"product": 6}
    # double_product(product=6) -> {"total": 12}
    result = await chain.safe_execute(x=3, y=2)
    assert result["total"] == 12


@pytest.mark.integration
async def test_parallel_fan_out():
    """ParallelFanOut: three FunctionTools execute in parallel, results merged."""
    tool_search = FunctionTool(
        name="search",
        description="Search for data",
        func=search_data,
        tags=["search"],
    )
    tool_calc = FunctionTool(
        name="calculate",
        description="Calculate expression",
        func=calculate,
        tags=["calculate"],
    )
    tool_translate = FunctionTool(
        name="translate",
        description="Translate text",
        func=translate,
        tags=["translate"],
    )
    fan_out = ParallelFanOut(
        name="multi_action",
        description="Run multiple actions in parallel",
        tools=[tool_search, tool_calc, tool_translate],
    )

    # Each helper accepts **kwargs, so every tool can receive the full keyword set.
    result = await fan_out.safe_execute(query="AI trends", expression="2+2", text="hello")

    # All three tools should have contributed to merged result
    assert "search_results" in result
    assert "calculation_result" in result
    assert "translated" in result


@pytest.mark.integration
async def test_parallel_fan_out_namespace_merge():
    """ParallelFanOut with namespace merge strategy."""
    tool_search = FunctionTool(
        name="search",
        description="Search for data",
        func=search_data,
    )
    tool_translate = FunctionTool(
        name="translate",
        description="Translate text",
        func=translate,
    )
    fan_out = ParallelFanOut(
        name="namespace_fanout",
        description="Namespace merge fan-out",
        tools=[tool_search, tool_translate],
        merge_strategy="namespace",
    )

    result = await fan_out.safe_execute(query="test", text="hello")

    # Namespace strategy: results keyed by tool name
    assert "search" in result
    assert "translate" in result
    assert "search_results" in result["search"]
    assert "translated" in result["translate"]


@pytest.mark.integration
async def test_dynamic_selector_keyword_mode():
    """DynamicSelector: keyword-based tool selection."""
    tool_search = FunctionTool(
        name="search_tool",
        description="Search for information",
        func=search_data,
        tags=["search"],
    )
    tool_calc = FunctionTool(
        name="calculate_tool",
        description="Calculate mathematical expressions",
        func=calculate,
        tags=["calculate"],
    )
    tool_translate = FunctionTool(
        name="translate_tool",
        description="Translate text between languages",
        func=translate,
        tags=["translate"],
    )
    selector = DynamicSelector(
        name="smart_tool",
        description="Dynamically select a tool",
        tools=[tool_search, tool_calc, tool_translate],
        mode="keyword",
    )

    # Select search tool via intent
    result = await selector.safe_execute(query="AI trends", _intent="search")
    assert "search_results" in result

    # Select calculate tool via intent
    result = await selector.safe_execute(expression="2+2", _intent="calculate")
    assert "calculation_result" in result


@pytest.mark.integration
async def test_dynamic_selector_llm_mode():
    """DynamicSelector: LLM-based tool selection with mock LLM."""
    tool_search = FunctionTool(
        name="search_tool",
        description="Search for information",
        func=search_data,
        tags=["search"],
    )
    tool_calc = FunctionTool(
        name="calculate_tool",
        description="Calculate mathematical expressions",
        func=calculate,
        tags=["calculate"],
    )

    # Mock LLM that always selects tool index 0 (search_tool)
    mock_llm = AsyncMock()
    mock_llm.chat = AsyncMock(return_value="0")

    selector = DynamicSelector(
        name="llm_smart_tool",
        description="LLM-based dynamic tool selector",
        tools=[tool_search, tool_calc],
        mode="llm",
        llm_client=mock_llm,
    )

    # NOTE(review): index 0 is also the first registered tool. If DynamicSelector
    # falls back to the first tool when the LLM call/parse fails, this assertion
    # would pass without the LLM being consulted. Consider also checking
    # mock_llm.chat was awaited, or selecting a non-zero index — verify against
    # the selector implementation.
    result = await selector.safe_execute(query="test query")
    assert "search_results" in result


@pytest.mark.integration
async def test_agent_tool_wrap_and_call():
    """AgentTool: wrap Agent as Tool and call it."""

    class SimpleAgent(BaseAgent):
        def __init__(self):
            super().__init__(name="simple_agent", agent_type="simple")

        def get_capabilities(self) -> AgentCapability:
            return AgentCapability(
                agent_name=self.name,
                agent_type=self.agent_type,
                version=self.version,
                supported_tasks=["simple"],
                max_concurrency=1,
                description="Simple agent for testing",
            )

        async def handle_task(self, task: TaskMessage) -> dict:
            return {"greeting": f"Hello, {task.input_data.get('name', 'world')}!"}

    # Create a mock dispatcher that routes to the agent directly
    class MockDispatcher:
        def __init__(self, target_agent: BaseAgent):
            self._agent = target_agent
            self._results: dict[str, TaskResult] = {}

        async def dispatch(self, task: TaskMessage):
            # Execute synchronously and record the result keyed by task id,
            # so get_task_status can report it on the next poll.
            result = await self._agent.execute(task)
            self._results[task.task_id] = result

        async def get_task_status(self, task_id: str) -> dict:
            result = self._results.get(task_id)
            if result is None:
                return {"status": "pending"}
            return {
                "status": result.status,
                "output_data": result.output_data,
                "error_message": result.error_message,
            }

    agent = SimpleAgent()
    await agent.start()
    try:
        dispatcher = MockDispatcher(agent)

        agent_tool = AgentTool(
            name="simple_agent_tool",
            description="Call the simple agent",
            agent_name="simple_agent",
            task_type="simple",
        )
        agent_tool.set_dispatcher(dispatcher)

        result = await agent_tool.safe_execute(name="Alice")
        assert result["greeting"] == "Hello, Alice!"
    finally:
        # Always stop the agent, even when the call or assertion fails, so
        # whatever start() set up is torn down instead of leaking past the test.
        await agent.stop()