feat(tools): add MCPTool, SequentialChain, ParallelFanOut, DynamicSelector

- MCPTool: call remote MCP tools via MCPClient
- SequentialChain: chain tools with output-to-input piping
- ParallelFanOut: execute tools concurrently with merge strategies
- DynamicSelector: keyword/LLM-based tool selection
- 14 new tests, total 70 passing
This commit is contained in:
chiguyong 2026-06-04 22:42:22 +08:00
parent 5a90824c77
commit d73a3391ab
4 changed files with 620 additions and 0 deletions

View File

@ -3,11 +3,17 @@
from agentkit.tools.base import Tool
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.agent_tool import AgentTool
from agentkit.tools.mcp_tool import MCPTool
from agentkit.tools.registry import ToolRegistry
from agentkit.tools.composition import SequentialChain, ParallelFanOut, DynamicSelector
__all__ = [
"Tool",
"FunctionTool",
"AgentTool",
"MCPTool",
"ToolRegistry",
"SequentialChain",
"ParallelFanOut",
"DynamicSelector",
]

View File

@ -0,0 +1,260 @@
"""工具组合 - SequentialChain, ParallelFanOut, DynamicSelector
支持工具的高级组合模式
- SequentialChain: 顺序执行前一个输出作为后一个输入
- ParallelFanOut: 并行执行多个工具结果合并
- DynamicSelector: 根据 LLM 判断动态选择工具
"""
import asyncio
import json
import logging
from typing import Any
from agentkit.tools.base import Tool
logger = logging.getLogger(__name__)
class SequentialChain(Tool):
"""顺序链 - 依次执行多个工具,前一个输出作为后一个输入
示例::
chain = SequentialChain(
name="retrieve_and_summarize",
tools=[retrieve_tool, summarize_tool],
)
result = await chain.execute(query="AI trends")
"""
def __init__(
self,
name: str,
description: str,
tools: list[Tool],
version: str = "1.0.0",
tags: list[str] | None = None,
):
super().__init__(
name=name,
description=description,
version=version,
tags=tags or ["composition", "sequential"],
)
self._chain_tools = tools
async def execute(self, **kwargs) -> dict:
"""顺序执行所有工具"""
result = kwargs
for tool in self._chain_tools:
if isinstance(result, dict):
result = await tool.safe_execute(**result)
else:
raise TypeError(
f"SequentialChain: tool '{tool.name}' must return dict, got {type(result)}"
)
logger.debug(f"SequentialChain step '{tool.name}' completed")
return result if isinstance(result, dict) else {"result": result}
class ParallelFanOut(Tool):
"""并行扇出 - 同时执行多个工具,结果合并
示例::
fan_out = ParallelFanOut(
name="multi_source_search",
tools=[web_search, db_search, cache_search],
)
result = await fan_out.execute(query="AI trends")
"""
def __init__(
self,
name: str,
description: str,
tools: list[Tool],
merge_strategy: str = "merge",
version: str = "1.0.0",
tags: list[str] | None = None,
):
super().__init__(
name=name,
description=description,
version=version,
tags=tags or ["composition", "parallel"],
)
self._parallel_tools = tools
self._merge_strategy = merge_strategy
async def execute(self, **kwargs) -> dict:
"""并行执行所有工具并合并结果"""
tasks = [tool.safe_execute(**kwargs) for tool in self._parallel_tools]
results = await asyncio.gather(*tasks, return_exceptions=True)
merged = {}
errors = []
for tool, result in zip(self._parallel_tools, results):
if isinstance(result, Exception):
logger.warning(f"ParallelFanOut tool '{tool.name}' failed: {result}")
errors.append({"tool": tool.name, "error": str(result)})
elif isinstance(result, dict):
if self._merge_strategy == "namespace":
merged[tool.name] = result
else:
merged.update(result)
if errors:
merged["_errors"] = errors
return merged
class DynamicSelector(Tool):
"""动态选择器 - 根据输入动态选择合适的工具执行
支持两种选择模式
- "keyword": 根据输入中的关键字匹配工具名/标签
- "llm": 通过 LLM 判断选择最合适的工具
示例::
selector = DynamicSelector(
name="smart_tool",
tools=[search_tool, calculate_tool, translate_tool],
mode="keyword",
)
result = await selector.execute(query="search for AI", _intent="search")
"""
def __init__(
self,
name: str,
description: str,
tools: list[Tool],
mode: str = "keyword",
llm_client: Any = None,
version: str = "1.0.0",
tags: list[str] | None = None,
):
super().__init__(
name=name,
description=description,
version=version,
tags=tags or ["composition", "dynamic"],
)
self._selector_tools = tools
self._mode = mode
self._llm_client = llm_client
async def execute(self, **kwargs) -> dict:
"""动态选择并执行工具"""
# 移除内部参数
intent = kwargs.pop("_intent", None)
if self._mode == "keyword":
selected = self._select_by_keyword(kwargs, intent)
elif self._mode == "llm":
selected = await self._select_by_llm(kwargs)
else:
raise ValueError(f"Unknown selector mode: {self._mode}")
if selected is None:
return {
"error": "no_matching_tool",
"available": [t.name for t in self._selector_tools],
}
logger.info(f"DynamicSelector selected tool '{selected.name}'")
return await selected.safe_execute(**kwargs)
def _select_by_keyword(self, kwargs: dict, intent: str | None) -> Tool | None:
"""根据关键字匹配工具"""
# 优先使用显式 intent
if intent:
for tool in self._selector_tools:
if intent.lower() in tool.name.lower():
return tool
if intent.lower() in " ".join(tool.tags).lower():
return tool
# 从输入值中推断
input_text = " ".join(str(v) for v in kwargs.values()).lower()
best_match = None
best_score = 0
for tool in self._selector_tools:
score = 0
# 工具名匹配
for word in tool.name.lower().split("_"):
if word in input_text:
score += 2
# 标签匹配
for tag in tool.tags:
if tag.lower() in input_text:
score += 1
# 描述匹配
for word in tool.description.lower().split():
if word in input_text and len(word) > 3:
score += 1
if score > best_score:
best_score = score
best_match = tool
return best_match
async def _select_by_llm(self, kwargs: dict) -> Tool | None:
"""通过 LLM 选择工具"""
if self._llm_client is None:
logger.warning("DynamicSelector: no LLM client, falling back to keyword mode")
return self._select_by_keyword(kwargs, None)
tool_descriptions = []
for i, tool in enumerate(self._selector_tools):
tool_descriptions.append(
f"{i}. {tool.name}: {tool.description} (tags: {', '.join(tool.tags)})"
)
prompt = (
f"Given the following input:\n{json.dumps(kwargs, ensure_ascii=False)}\n\n"
f"Available tools:\n" + "\n".join(tool_descriptions) + "\n\n"
f"Which tool number should be used? Reply with just the number."
)
try:
if hasattr(self._llm_client, "chat"):
response = await self._llm_client.chat(
messages=[{"role": "user", "content": prompt}],
model="gpt-4",
temperature=0,
max_tokens=10,
)
elif callable(self._llm_client):
response = await self._llm_client(
messages=[{"role": "user", "content": prompt}],
model="gpt-4",
temperature=0,
max_tokens=10,
)
else:
return self._select_by_keyword(kwargs, None)
if isinstance(response, str):
idx = int(response.strip())
elif isinstance(response, dict):
idx = int(response.get("content", "0").strip())
else:
idx = 0
if 0 <= idx < len(self._selector_tools):
return self._selector_tools[idx]
except (ValueError, IndexError, Exception) as e:
logger.warning(f"DynamicSelector LLM selection failed: {e}")
return self._select_by_keyword(kwargs, None)

View File

@ -0,0 +1,51 @@
"""MCPTool - 通过 MCP Client 调用远程工具"""
import json
import logging
from typing import Any
from agentkit.tools.base import Tool
logger = logging.getLogger(__name__)
class MCPTool(Tool):
"""MCP 工具 - 通过 MCP Client 调用远程工具
将外部 MCP Server 上的工具包装为本地 Tool 对象
通过 MCPClient 发起 HTTP 调用
"""
def __init__(
self,
name: str,
description: str,
client: Any,
input_schema: dict[str, Any] | None = None,
output_schema: dict[str, Any] | None = None,
version: str = "1.0.0",
tags: list[str] | None = None,
):
super().__init__(
name=name,
description=description,
input_schema=input_schema,
output_schema=output_schema,
version=version,
tags=tags or ["mcp"],
)
self._client = client
async def execute(self, **kwargs) -> dict:
"""通过 MCP Client 调用远程工具"""
result = await self._client.call_tool(self.name, kwargs)
# 解析 MCP 响应格式
if "content" in result:
for item in result["content"]:
if item.get("type") == "text":
try:
return json.loads(item["text"])
except json.JSONDecodeError:
return {"result": item["text"]}
return result

View File

@ -0,0 +1,303 @@
"""U3 测试: Tool 组合 (SequentialChain, ParallelFanOut, DynamicSelector) + MCPTool"""
import asyncio
import json
import pytest
from agentkit.tools.base import Tool
from agentkit.tools.composition import DynamicSelector, ParallelFanOut, SequentialChain
from agentkit.tools.function_tool import FunctionTool
from agentkit.tools.mcp_tool import MCPTool
# ── Helper: 创建简单 FunctionTool ────────────────────────
def _make_tool(name: str, func=None, **kw):
if func is None:
async def default_func(**kwargs):
return {"tool": name, "input": kwargs}
func = default_func
return FunctionTool(name=name, description=f"Test tool {name}", func=func, **kw)
# ── SequentialChain 测试 ─────────────────────────────────
class TestSequentialChain:
async def test_chain_two_tools(self):
"""顺序执行两个工具,前一个输出作为后一个输入"""
async def step1(text: str) -> dict:
return {"processed": text.upper(), "length": len(text)}
async def step2(processed: str, length: int) -> dict:
return {"result": f"{processed}({length})"}
chain = SequentialChain(
name="upper_and_format",
description="Uppercase then format",
tools=[
FunctionTool(name="step1", description="Step 1", func=step1),
FunctionTool(name="step2", description="Step 2", func=step2),
],
)
result = await chain.execute(text="hello")
assert result["result"] == "HELLO(5)"
async def test_chain_single_tool(self):
"""单工具链直接返回结果"""
tool = _make_tool("solo")
chain = SequentialChain(name="solo_chain", description="Solo", tools=[tool])
result = await chain.execute(x=1)
assert result["tool"] == "solo"
async def test_chain_preserves_dict(self):
"""链中每步都返回 dict"""
async def add_field(x: int) -> dict:
return {"x": x, "doubled": x * 2}
async def add_label(x: int, doubled: int) -> dict:
return {"x": x, "doubled": doubled, "label": f"{x}->{doubled}"}
chain = SequentialChain(
name="enrich",
description="Enrich data",
tools=[
FunctionTool(name="add_field", description="Add", func=add_field),
FunctionTool(name="add_label", description="Label", func=add_label),
],
)
result = await chain.execute(x=3)
assert result["label"] == "3->6"
# ── ParallelFanOut 测试 ──────────────────────────────────
class TestParallelFanOut:
async def test_parallel_execution(self):
"""并行执行三个工具"""
async def search_web(query: str) -> dict:
return {"web_results": [f"web: {query}"]}
async def search_db(query: str) -> dict:
return {"db_results": [f"db: {query}"]}
async def search_cache(query: str) -> dict:
return {"cache_results": [f"cache: {query}"]}
fan_out = ParallelFanOut(
name="multi_search",
description="Search multiple sources",
tools=[
FunctionTool(name="search_web", description="Web", func=search_web),
FunctionTool(name="search_db", description="DB", func=search_db),
FunctionTool(name="search_cache", description="Cache", func=search_cache),
],
)
result = await fan_out.execute(query="AI")
assert "web_results" in result
assert "db_results" in result
assert "cache_results" in result
async def test_parallel_with_error(self):
"""并行执行中有工具失败,不影响其他"""
async def good_tool(x: int) -> dict:
return {"good": x * 2}
async def bad_tool(x: int) -> dict:
raise ValueError("boom")
fan_out = ParallelFanOut(
name="mixed",
description="Mixed results",
tools=[
FunctionTool(name="good", description="Good", func=good_tool),
FunctionTool(name="bad", description="Bad", func=bad_tool),
],
)
result = await fan_out.execute(x=5)
assert result["good"] == 10
assert "_errors" in result
assert any(e["tool"] == "bad" for e in result["_errors"])
async def test_parallel_namespace_merge(self):
"""namespace 合并策略:每个工具结果独立命名空间"""
async def tool_a(x: int) -> dict:
return {"result": "a"}
async def tool_b(x: int) -> dict:
return {"result": "b"}
fan_out = ParallelFanOut(
name="namespaced",
description="Namespaced",
tools=[
FunctionTool(name="tool_a", description="A", func=tool_a),
FunctionTool(name="tool_b", description="B", func=tool_b),
],
merge_strategy="namespace",
)
result = await fan_out.execute(x=1)
assert "tool_a" in result
assert "tool_b" in result
assert result["tool_a"]["result"] == "a"
assert result["tool_b"]["result"] == "b"
# ── DynamicSelector 测试 ─────────────────────────────────
class TestDynamicSelector:
async def test_keyword_selection_by_intent(self):
"""根据 _intent 参数选择工具"""
search_tool = _make_tool("search_engine")
calc_tool = _make_tool("calculator")
selector = DynamicSelector(
name="smart",
description="Smart selector",
tools=[search_tool, calc_tool],
mode="keyword",
)
result = await selector.execute(query="AI", _intent="search")
assert result["tool"] == "search_engine"
async def test_keyword_selection_by_input(self):
"""根据输入内容推断工具"""
search_tool = _make_tool("search_data")
calc_tool = _make_tool("calculate_result")
selector = DynamicSelector(
name="smart",
description="Smart selector",
tools=[search_tool, calc_tool],
mode="keyword",
)
result = await selector.execute(query="search for trends")
assert result["tool"] == "search_data"
async def test_no_matching_tool(self):
"""无匹配工具时返回错误"""
tool_a = _make_tool("tool_a")
selector = DynamicSelector(
name="smart",
description="Smart selector",
tools=[tool_a],
mode="keyword",
)
result = await selector.execute(x=999)
# 应该仍然选择 tool_abest_score=0 的默认)
assert "tool" in result or "error" in result
async def test_llm_selection(self):
"""LLM 模式选择工具"""
search_tool = _make_tool("search_engine")
calc_tool = _make_tool("calculator")
class MockLLM:
async def chat(self, messages, **kwargs):
return "0" # 选择第一个工具
selector = DynamicSelector(
name="smart",
description="Smart selector",
tools=[search_tool, calc_tool],
mode="llm",
llm_client=MockLLM(),
)
result = await selector.execute(query="find information")
assert result["tool"] == "search_engine"
async def test_llm_fallback_to_keyword(self):
"""LLM 失败时回退到关键字模式"""
search_tool = _make_tool("search_engine")
class BrokenLLM:
async def chat(self, messages, **kwargs):
raise ConnectionError("LLM unavailable")
selector = DynamicSelector(
name="smart",
description="Smart selector",
tools=[search_tool],
mode="llm",
llm_client=BrokenLLM(),
)
result = await selector.execute(query="search something")
assert result["tool"] == "search_engine"
# ── MCPTool 测试 ─────────────────────────────────────────
class TestMCPTool:
async def test_mcp_tool_execute(self):
"""MCPTool 通过 client 调用远程工具"""
class MockMCPClient:
async def call_tool(self, name, arguments):
return {
"content": [
{"type": "text", "text": json.dumps({"status": "ok", "data": arguments})}
]
}
tool = MCPTool(
name="remote_search",
description="Remote search via MCP",
client=MockMCPClient(),
)
result = await tool.execute(query="test")
assert result["status"] == "ok"
assert result["data"]["query"] == "test"
async def test_mcp_tool_text_response(self):
"""MCP 返回非 JSON 文本"""
class MockMCPClient:
async def call_tool(self, name, arguments):
return {
"content": [
{"type": "text", "text": "plain text result"}
]
}
tool = MCPTool(
name="remote_echo",
description="Remote echo",
client=MockMCPClient(),
)
result = await tool.execute(text="hello")
assert result["result"] == "plain text result"
async def test_mcp_tool_raw_response(self):
"""MCP 返回非标准格式"""
class MockMCPClient:
async def call_tool(self, name, arguments):
return {"raw_key": "raw_value"}
tool = MCPTool(
name="remote_raw",
description="Raw response",
client=MockMCPClient(),
)
result = await tool.execute(x=1)
assert result["raw_key"] == "raw_value"