From d73a3391ab554feabd3b588f1780310a5e560057 Mon Sep 17 00:00:00 2001 From: chiguyong Date: Thu, 4 Jun 2026 22:42:22 +0800 Subject: [PATCH] feat(tools): add MCPTool, SequentialChain, ParallelFanOut, DynamicSelector - MCPTool: call remote MCP tools via MCPClient - SequentialChain: chain tools with output-to-input piping - ParallelFanOut: execute tools concurrently with merge strategies - DynamicSelector: keyword/LLM-based tool selection - 14 new tests, total 70 passing --- src/agentkit/tools/__init__.py | 6 + src/agentkit/tools/composition.py | 260 ++++++++++++++++++++++++ src/agentkit/tools/mcp_tool.py | 51 +++++ tests/unit/test_tool_composition.py | 303 ++++++++++++++++++++++++++++ 4 files changed, 620 insertions(+) create mode 100644 src/agentkit/tools/composition.py create mode 100644 src/agentkit/tools/mcp_tool.py create mode 100644 tests/unit/test_tool_composition.py diff --git a/src/agentkit/tools/__init__.py b/src/agentkit/tools/__init__.py index 57ac4ac..f136aa6 100644 --- a/src/agentkit/tools/__init__.py +++ b/src/agentkit/tools/__init__.py @@ -3,11 +3,17 @@ from agentkit.tools.base import Tool from agentkit.tools.function_tool import FunctionTool from agentkit.tools.agent_tool import AgentTool +from agentkit.tools.mcp_tool import MCPTool from agentkit.tools.registry import ToolRegistry +from agentkit.tools.composition import SequentialChain, ParallelFanOut, DynamicSelector __all__ = [ "Tool", "FunctionTool", "AgentTool", + "MCPTool", "ToolRegistry", + "SequentialChain", + "ParallelFanOut", + "DynamicSelector", ] diff --git a/src/agentkit/tools/composition.py b/src/agentkit/tools/composition.py new file mode 100644 index 0000000..9cfe069 --- /dev/null +++ b/src/agentkit/tools/composition.py @@ -0,0 +1,260 @@ +"""工具组合 - SequentialChain, ParallelFanOut, DynamicSelector + +支持工具的高级组合模式: +- SequentialChain: 顺序执行,前一个输出作为后一个输入 +- ParallelFanOut: 并行执行多个工具,结果合并 +- DynamicSelector: 根据 LLM 判断动态选择工具 +""" + +import asyncio +import json +import logging +from typing import Any + +from agentkit.tools.base import Tool + +logger = logging.getLogger(__name__) + + +class SequentialChain(Tool): + """顺序链 - 依次执行多个工具,前一个输出作为后一个输入 + + 示例:: + + chain = SequentialChain( + name="retrieve_and_summarize", + tools=[retrieve_tool, summarize_tool], + ) + result = await chain.execute(query="AI trends") + """ + + def __init__( + self, + name: str, + description: str, + tools: list[Tool], + version: str = "1.0.0", + tags: list[str] | None = None, + ): + super().__init__( + name=name, + description=description, + version=version, + tags=tags or ["composition", "sequential"], + ) + self._chain_tools = tools + + async def execute(self, **kwargs) -> dict: + """顺序执行所有工具""" + result = kwargs + + for tool in self._chain_tools: + if isinstance(result, dict): + result = await tool.safe_execute(**result) + else: + raise TypeError( + f"SequentialChain: tool '{tool.name}' must return dict, got {type(result)}" + ) + logger.debug(f"SequentialChain step '{tool.name}' completed") + + return result if isinstance(result, dict) else {"result": result} + + +class ParallelFanOut(Tool): + """并行扇出 - 同时执行多个工具,结果合并 + + 示例:: + + fan_out = ParallelFanOut( + name="multi_source_search", + tools=[web_search, db_search, cache_search], + ) + result = await fan_out.execute(query="AI trends") + """ + + def __init__( + self, + name: str, + description: str, + tools: list[Tool], + merge_strategy: str = "merge", + version: str = "1.0.0", + tags: list[str] | None = None, + ): + super().__init__( + name=name, + description=description, + version=version, + tags=tags or ["composition", "parallel"], + ) + self._parallel_tools = tools + self._merge_strategy = merge_strategy + + async def execute(self, **kwargs) -> dict: + """并行执行所有工具并合并结果""" + tasks = [tool.safe_execute(**kwargs) for tool in self._parallel_tools] + results = await asyncio.gather(*tasks, return_exceptions=True) + + merged = {} + errors = [] + + for tool, result in zip(self._parallel_tools, results): + if isinstance(result, Exception): + logger.warning(f"ParallelFanOut tool '{tool.name}' failed: {result}") + errors.append({"tool": tool.name, "error": str(result)}) + elif isinstance(result, dict): + if self._merge_strategy == "namespace": + merged[tool.name] = result + else: + merged.update(result) + + if errors: + merged["_errors"] = errors + + return merged + + +class DynamicSelector(Tool): + """动态选择器 - 根据输入动态选择合适的工具执行 + + 支持两种选择模式: + - "keyword": 根据输入中的关键字匹配工具名/标签 + - "llm": 通过 LLM 判断选择最合适的工具 + + 示例:: + + selector = DynamicSelector( + name="smart_tool", + tools=[search_tool, calculate_tool, translate_tool], + mode="keyword", + ) + result = await selector.execute(query="search for AI", _intent="search") + """ + + def __init__( + self, + name: str, + description: str, + tools: list[Tool], + mode: str = "keyword", + llm_client: Any = None, + version: str = "1.0.0", + tags: list[str] | None = None, + ): + super().__init__( + name=name, + description=description, + version=version, + tags=tags or ["composition", "dynamic"], + ) + self._selector_tools = tools + self._mode = mode + self._llm_client = llm_client + + async def execute(self, **kwargs) -> dict: + """动态选择并执行工具""" + # 移除内部参数 + intent = kwargs.pop("_intent", None) + + if self._mode == "keyword": + selected = self._select_by_keyword(kwargs, intent) + elif self._mode == "llm": + selected = await self._select_by_llm(kwargs) + else: + raise ValueError(f"Unknown selector mode: {self._mode}") + + if selected is None: + return { + "error": "no_matching_tool", + "available": [t.name for t in self._selector_tools], + } + + logger.info(f"DynamicSelector selected tool '{selected.name}'") + return await selected.safe_execute(**kwargs) + + def _select_by_keyword(self, kwargs: dict, intent: str | None) -> Tool | None: + """根据关键字匹配工具""" + # 优先使用显式 intent + if intent: + for tool in self._selector_tools: + if intent.lower() in tool.name.lower(): + return tool + if intent.lower() in " ".join(tool.tags).lower(): + return tool + + # 从输入值中推断 + input_text = " ".join(str(v) for v in kwargs.values()).lower() + best_match = None + best_score = 0 + + for tool in self._selector_tools: + score = 0 + # 工具名匹配 + for word in tool.name.lower().split("_"): + if word in input_text: + score += 2 + # 标签匹配 + for tag in tool.tags: + if tag.lower() in input_text: + score += 1 + # 描述匹配 + for word in tool.description.lower().split(): + if word in input_text and len(word) > 3: + score += 1 + + if score > best_score: + best_score = score + best_match = tool + + return best_match + + async def _select_by_llm(self, kwargs: dict) -> Tool | None: + """通过 LLM 选择工具""" + if self._llm_client is None: + logger.warning("DynamicSelector: no LLM client, falling back to keyword mode") + return self._select_by_keyword(kwargs, None) + + tool_descriptions = [] + for i, tool in enumerate(self._selector_tools): + tool_descriptions.append( + f"{i}. {tool.name}: {tool.description} (tags: {', '.join(tool.tags)})" + ) + + prompt = ( + f"Given the following input:\n{json.dumps(kwargs, ensure_ascii=False)}\n\n" + f"Available tools:\n" + "\n".join(tool_descriptions) + "\n\n" + f"Which tool number should be used? Reply with just the number." + ) + + try: + if hasattr(self._llm_client, "chat"): + response = await self._llm_client.chat( + messages=[{"role": "user", "content": prompt}], + model="gpt-4", + temperature=0, + max_tokens=10, + ) + elif callable(self._llm_client): + response = await self._llm_client( + messages=[{"role": "user", "content": prompt}], + model="gpt-4", + temperature=0, + max_tokens=10, + ) + else: + return self._select_by_keyword(kwargs, None) + + if isinstance(response, str): + idx = int(response.strip()) + elif isinstance(response, dict): + idx = int(response.get("content", "0").strip()) + else: + idx = 0 + + if 0 <= idx < len(self._selector_tools): + return self._selector_tools[idx] + + except (ValueError, IndexError, Exception) as e: + logger.warning(f"DynamicSelector LLM selection failed: {e}") + + return self._select_by_keyword(kwargs, None) diff --git a/src/agentkit/tools/mcp_tool.py b/src/agentkit/tools/mcp_tool.py new file mode 100644 index 0000000..67622d9 --- /dev/null +++ b/src/agentkit/tools/mcp_tool.py @@ -0,0 +1,51 @@ +"""MCPTool - 通过 MCP Client 调用远程工具""" + +import json +import logging +from typing import Any + +from agentkit.tools.base import Tool + +logger = logging.getLogger(__name__) + + +class MCPTool(Tool): + """MCP 工具 - 通过 MCP Client 调用远程工具 + + 将外部 MCP Server 上的工具包装为本地 Tool 对象, + 通过 MCPClient 发起 HTTP 调用。 + """ + + def __init__( + self, + name: str, + description: str, + client: Any, + input_schema: dict[str, Any] | None = None, + output_schema: dict[str, Any] | None = None, + version: str = "1.0.0", + tags: list[str] | None = None, + ): + super().__init__( + name=name, + description=description, + input_schema=input_schema, + output_schema=output_schema, + version=version, + tags=tags or ["mcp"], + ) + self._client = client + + async def execute(self, **kwargs) -> dict: + """通过 MCP Client 调用远程工具""" + result = await self._client.call_tool(self.name, kwargs) + + # 解析 MCP 响应格式 + if "content" in result: + for item in result["content"]: + if item.get("type") == "text": + try: + return json.loads(item["text"]) + except json.JSONDecodeError: + return {"result": item["text"]} + return result diff --git a/tests/unit/test_tool_composition.py b/tests/unit/test_tool_composition.py new file mode 100644 index 0000000..fae8007 --- /dev/null +++ b/tests/unit/test_tool_composition.py @@ -0,0 +1,303 @@ +"""U3 测试: Tool 组合 (SequentialChain, ParallelFanOut, DynamicSelector) + MCPTool""" + +import asyncio +import json + +import pytest + +from agentkit.tools.base import Tool +from agentkit.tools.composition import DynamicSelector, ParallelFanOut, SequentialChain +from agentkit.tools.function_tool import FunctionTool +from agentkit.tools.mcp_tool import MCPTool + + +# ── Helper: 创建简单 FunctionTool ──────────────────────── + + +def _make_tool(name: str, func=None, **kw): + if func is None: + async def default_func(**kwargs): + return {"tool": name, "input": kwargs} + func = default_func + return FunctionTool(name=name, description=f"Test tool {name}", func=func, **kw) + + +# ── SequentialChain 测试 ───────────────────────────────── + + +class TestSequentialChain: + async def test_chain_two_tools(self): + """顺序执行两个工具,前一个输出作为后一个输入""" + async def step1(text: str) -> dict: + return {"processed": text.upper(), "length": len(text)} + + async def step2(processed: str, length: int) -> dict: + return {"result": f"{processed}({length})"} + + chain = SequentialChain( + name="upper_and_format", + description="Uppercase then format", + tools=[ + FunctionTool(name="step1", description="Step 1", func=step1), + FunctionTool(name="step2", description="Step 2", func=step2), + ], + ) + + result = await chain.execute(text="hello") + assert result["result"] == "HELLO(5)" + + async def test_chain_single_tool(self): + """单工具链直接返回结果""" + tool = _make_tool("solo") + chain = SequentialChain(name="solo_chain", description="Solo", tools=[tool]) + + result = await chain.execute(x=1) + assert result["tool"] == "solo" + + async def test_chain_preserves_dict(self): + """链中每步都返回 dict""" + async def add_field(x: int) -> dict: + return {"x": x, "doubled": x * 2} + + async def add_label(x: int, doubled: int) -> dict: + return {"x": x, "doubled": doubled, "label": f"{x}->{doubled}"} + + chain = SequentialChain( + name="enrich", + description="Enrich data", + tools=[ + FunctionTool(name="add_field", description="Add", func=add_field), + FunctionTool(name="add_label", description="Label", func=add_label), + ], + ) + + result = await chain.execute(x=3) + assert result["label"] == "3->6" + + +# ── ParallelFanOut 测试 ────────────────────────────────── + + +class TestParallelFanOut: + async def test_parallel_execution(self): + """并行执行三个工具""" + async def search_web(query: str) -> dict: + return {"web_results": [f"web: {query}"]} + + async def search_db(query: str) -> dict: + return {"db_results": [f"db: {query}"]} + + async def search_cache(query: str) -> dict: + return {"cache_results": [f"cache: {query}"]} + + fan_out = ParallelFanOut( + name="multi_search", + description="Search multiple sources", + tools=[ + FunctionTool(name="search_web", description="Web", func=search_web), + FunctionTool(name="search_db", description="DB", func=search_db), + FunctionTool(name="search_cache", description="Cache", func=search_cache), + ], + ) + + result = await fan_out.execute(query="AI") + assert "web_results" in result + assert "db_results" in result + assert "cache_results" in result + + async def test_parallel_with_error(self): + """并行执行中有工具失败,不影响其他""" + async def good_tool(x: int) -> dict: + return {"good": x * 2} + + async def bad_tool(x: int) -> dict: + raise ValueError("boom") + + fan_out = ParallelFanOut( + name="mixed", + description="Mixed results", + tools=[ + FunctionTool(name="good", description="Good", func=good_tool), + FunctionTool(name="bad", description="Bad", func=bad_tool), + ], + ) + + result = await fan_out.execute(x=5) + assert result["good"] == 10 + assert "_errors" in result + assert any(e["tool"] == "bad" for e in result["_errors"]) + + async def test_parallel_namespace_merge(self): + """namespace 合并策略:每个工具结果独立命名空间""" + async def tool_a(x: int) -> dict: + return {"result": "a"} + + async def tool_b(x: int) -> dict: + return {"result": "b"} + + fan_out = ParallelFanOut( + name="namespaced", + description="Namespaced", + tools=[ + FunctionTool(name="tool_a", description="A", func=tool_a), + FunctionTool(name="tool_b", description="B", func=tool_b), + ], + merge_strategy="namespace", + ) + + result = await fan_out.execute(x=1) + assert "tool_a" in result + assert "tool_b" in result + assert result["tool_a"]["result"] == "a" + assert result["tool_b"]["result"] == "b" + + +# ── DynamicSelector 测试 ───────────────────────────────── + + +class TestDynamicSelector: + async def test_keyword_selection_by_intent(self): + """根据 _intent 参数选择工具""" + search_tool = _make_tool("search_engine") + calc_tool = _make_tool("calculator") + + selector = DynamicSelector( + name="smart", + description="Smart selector", + tools=[search_tool, calc_tool], + mode="keyword", + ) + + result = await selector.execute(query="AI", _intent="search") + assert result["tool"] == "search_engine" + + async def test_keyword_selection_by_input(self): + """根据输入内容推断工具""" + search_tool = _make_tool("search_data") + calc_tool = _make_tool("calculate_result") + + selector = DynamicSelector( + name="smart", + description="Smart selector", + tools=[search_tool, calc_tool], + mode="keyword", + ) + + result = await selector.execute(query="search for trends") + assert result["tool"] == "search_data" + + async def test_no_matching_tool(self): + """无匹配工具时返回错误""" + tool_a = _make_tool("tool_a") + + selector = DynamicSelector( + name="smart", + description="Smart selector", + tools=[tool_a], + mode="keyword", + ) + + result = await selector.execute(x=999) + # 应该仍然选择 tool_a(best_score=0 的默认) + assert "tool" in result or "error" in result + + async def test_llm_selection(self): + """LLM 模式选择工具""" + search_tool = _make_tool("search_engine") + calc_tool = _make_tool("calculator") + + class MockLLM: + async def chat(self, messages, **kwargs): + return "0" # 选择第一个工具 + + selector = DynamicSelector( + name="smart", + description="Smart selector", + tools=[search_tool, calc_tool], + mode="llm", + llm_client=MockLLM(), + ) + + result = await selector.execute(query="find information") + assert result["tool"] == "search_engine" + + async def test_llm_fallback_to_keyword(self): + """LLM 失败时回退到关键字模式""" + search_tool = _make_tool("search_engine") + + class BrokenLLM: + async def chat(self, messages, **kwargs): + raise ConnectionError("LLM unavailable") + + selector = DynamicSelector( + name="smart", + description="Smart selector", + tools=[search_tool], + mode="llm", + llm_client=BrokenLLM(), + ) + + result = await selector.execute(query="search something") + assert result["tool"] == "search_engine" + + +# ── MCPTool 测试 ───────────────────────────────────────── + + +class TestMCPTool: + async def test_mcp_tool_execute(self): + """MCPTool 通过 client 调用远程工具""" + + class MockMCPClient: + async def call_tool(self, name, arguments): + return { + "content": [ + {"type": "text", "text": json.dumps({"status": "ok", "data": arguments})} + ] + } + + tool = MCPTool( + name="remote_search", + description="Remote search via MCP", + client=MockMCPClient(), + ) + + result = await tool.execute(query="test") + assert result["status"] == "ok" + assert result["data"]["query"] == "test" + + async def test_mcp_tool_text_response(self): + """MCP 返回非 JSON 文本""" + + class MockMCPClient: + async def call_tool(self, name, arguments): + return { + "content": [ + {"type": "text", "text": "plain text result"} + ] + } + + tool = MCPTool( + name="remote_echo", + description="Remote echo", + client=MockMCPClient(), + ) + + result = await tool.execute(text="hello") + assert result["result"] == "plain text result" + + async def test_mcp_tool_raw_response(self): + """MCP 返回非标准格式""" + + class MockMCPClient: + async def call_tool(self, name, arguments): + return {"raw_key": "raw_value"} + + tool = MCPTool( + name="remote_raw", + description="Raw response", + client=MockMCPClient(), + ) + + result = await tool.execute(x=1) + assert result["raw_key"] == "raw_value"