feat(tools): add MCPTool, SequentialChain, ParallelFanOut, DynamicSelector
- MCPTool: call remote MCP tools via MCPClient - SequentialChain: chain tools with output-to-input piping - ParallelFanOut: execute tools concurrently with merge strategies - DynamicSelector: keyword/LLM-based tool selection - 14 new tests, total 70 passing
This commit is contained in:
parent
5a90824c77
commit
d73a3391ab
|
|
@ -3,11 +3,17 @@
|
|||
from agentkit.tools.base import Tool
|
||||
from agentkit.tools.function_tool import FunctionTool
|
||||
from agentkit.tools.agent_tool import AgentTool
|
||||
from agentkit.tools.mcp_tool import MCPTool
|
||||
from agentkit.tools.registry import ToolRegistry
|
||||
from agentkit.tools.composition import SequentialChain, ParallelFanOut, DynamicSelector
|
||||
|
||||
__all__ = [
|
||||
"Tool",
|
||||
"FunctionTool",
|
||||
"AgentTool",
|
||||
"MCPTool",
|
||||
"ToolRegistry",
|
||||
"SequentialChain",
|
||||
"ParallelFanOut",
|
||||
"DynamicSelector",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,260 @@
|
|||
"""工具组合 - SequentialChain, ParallelFanOut, DynamicSelector
|
||||
|
||||
支持工具的高级组合模式:
|
||||
- SequentialChain: 顺序执行,前一个输出作为后一个输入
|
||||
- ParallelFanOut: 并行执行多个工具,结果合并
|
||||
- DynamicSelector: 根据 LLM 判断动态选择工具
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SequentialChain(Tool):
|
||||
"""顺序链 - 依次执行多个工具,前一个输出作为后一个输入
|
||||
|
||||
示例::
|
||||
|
||||
chain = SequentialChain(
|
||||
name="retrieve_and_summarize",
|
||||
tools=[retrieve_tool, summarize_tool],
|
||||
)
|
||||
result = await chain.execute(query="AI trends")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
tools: list[Tool],
|
||||
version: str = "1.0.0",
|
||||
tags: list[str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
description=description,
|
||||
version=version,
|
||||
tags=tags or ["composition", "sequential"],
|
||||
)
|
||||
self._chain_tools = tools
|
||||
|
||||
async def execute(self, **kwargs) -> dict:
|
||||
"""顺序执行所有工具"""
|
||||
result = kwargs
|
||||
|
||||
for tool in self._chain_tools:
|
||||
if isinstance(result, dict):
|
||||
result = await tool.safe_execute(**result)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"SequentialChain: tool '{tool.name}' must return dict, got {type(result)}"
|
||||
)
|
||||
logger.debug(f"SequentialChain step '{tool.name}' completed")
|
||||
|
||||
return result if isinstance(result, dict) else {"result": result}
|
||||
|
||||
|
||||
class ParallelFanOut(Tool):
|
||||
"""并行扇出 - 同时执行多个工具,结果合并
|
||||
|
||||
示例::
|
||||
|
||||
fan_out = ParallelFanOut(
|
||||
name="multi_source_search",
|
||||
tools=[web_search, db_search, cache_search],
|
||||
)
|
||||
result = await fan_out.execute(query="AI trends")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
tools: list[Tool],
|
||||
merge_strategy: str = "merge",
|
||||
version: str = "1.0.0",
|
||||
tags: list[str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
description=description,
|
||||
version=version,
|
||||
tags=tags or ["composition", "parallel"],
|
||||
)
|
||||
self._parallel_tools = tools
|
||||
self._merge_strategy = merge_strategy
|
||||
|
||||
async def execute(self, **kwargs) -> dict:
|
||||
"""并行执行所有工具并合并结果"""
|
||||
tasks = [tool.safe_execute(**kwargs) for tool in self._parallel_tools]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
merged = {}
|
||||
errors = []
|
||||
|
||||
for tool, result in zip(self._parallel_tools, results):
|
||||
if isinstance(result, Exception):
|
||||
logger.warning(f"ParallelFanOut tool '{tool.name}' failed: {result}")
|
||||
errors.append({"tool": tool.name, "error": str(result)})
|
||||
elif isinstance(result, dict):
|
||||
if self._merge_strategy == "namespace":
|
||||
merged[tool.name] = result
|
||||
else:
|
||||
merged.update(result)
|
||||
|
||||
if errors:
|
||||
merged["_errors"] = errors
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
class DynamicSelector(Tool):
|
||||
"""动态选择器 - 根据输入动态选择合适的工具执行
|
||||
|
||||
支持两种选择模式:
|
||||
- "keyword": 根据输入中的关键字匹配工具名/标签
|
||||
- "llm": 通过 LLM 判断选择最合适的工具
|
||||
|
||||
示例::
|
||||
|
||||
selector = DynamicSelector(
|
||||
name="smart_tool",
|
||||
tools=[search_tool, calculate_tool, translate_tool],
|
||||
mode="keyword",
|
||||
)
|
||||
result = await selector.execute(query="search for AI", _intent="search")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
tools: list[Tool],
|
||||
mode: str = "keyword",
|
||||
llm_client: Any = None,
|
||||
version: str = "1.0.0",
|
||||
tags: list[str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
description=description,
|
||||
version=version,
|
||||
tags=tags or ["composition", "dynamic"],
|
||||
)
|
||||
self._selector_tools = tools
|
||||
self._mode = mode
|
||||
self._llm_client = llm_client
|
||||
|
||||
async def execute(self, **kwargs) -> dict:
|
||||
"""动态选择并执行工具"""
|
||||
# 移除内部参数
|
||||
intent = kwargs.pop("_intent", None)
|
||||
|
||||
if self._mode == "keyword":
|
||||
selected = self._select_by_keyword(kwargs, intent)
|
||||
elif self._mode == "llm":
|
||||
selected = await self._select_by_llm(kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown selector mode: {self._mode}")
|
||||
|
||||
if selected is None:
|
||||
return {
|
||||
"error": "no_matching_tool",
|
||||
"available": [t.name for t in self._selector_tools],
|
||||
}
|
||||
|
||||
logger.info(f"DynamicSelector selected tool '{selected.name}'")
|
||||
return await selected.safe_execute(**kwargs)
|
||||
|
||||
def _select_by_keyword(self, kwargs: dict, intent: str | None) -> Tool | None:
|
||||
"""根据关键字匹配工具"""
|
||||
# 优先使用显式 intent
|
||||
if intent:
|
||||
for tool in self._selector_tools:
|
||||
if intent.lower() in tool.name.lower():
|
||||
return tool
|
||||
if intent.lower() in " ".join(tool.tags).lower():
|
||||
return tool
|
||||
|
||||
# 从输入值中推断
|
||||
input_text = " ".join(str(v) for v in kwargs.values()).lower()
|
||||
best_match = None
|
||||
best_score = 0
|
||||
|
||||
for tool in self._selector_tools:
|
||||
score = 0
|
||||
# 工具名匹配
|
||||
for word in tool.name.lower().split("_"):
|
||||
if word in input_text:
|
||||
score += 2
|
||||
# 标签匹配
|
||||
for tag in tool.tags:
|
||||
if tag.lower() in input_text:
|
||||
score += 1
|
||||
# 描述匹配
|
||||
for word in tool.description.lower().split():
|
||||
if word in input_text and len(word) > 3:
|
||||
score += 1
|
||||
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_match = tool
|
||||
|
||||
return best_match
|
||||
|
||||
async def _select_by_llm(self, kwargs: dict) -> Tool | None:
|
||||
"""通过 LLM 选择工具"""
|
||||
if self._llm_client is None:
|
||||
logger.warning("DynamicSelector: no LLM client, falling back to keyword mode")
|
||||
return self._select_by_keyword(kwargs, None)
|
||||
|
||||
tool_descriptions = []
|
||||
for i, tool in enumerate(self._selector_tools):
|
||||
tool_descriptions.append(
|
||||
f"{i}. {tool.name}: {tool.description} (tags: {', '.join(tool.tags)})"
|
||||
)
|
||||
|
||||
prompt = (
|
||||
f"Given the following input:\n{json.dumps(kwargs, ensure_ascii=False)}\n\n"
|
||||
f"Available tools:\n" + "\n".join(tool_descriptions) + "\n\n"
|
||||
f"Which tool number should be used? Reply with just the number."
|
||||
)
|
||||
|
||||
try:
|
||||
if hasattr(self._llm_client, "chat"):
|
||||
response = await self._llm_client.chat(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
model="gpt-4",
|
||||
temperature=0,
|
||||
max_tokens=10,
|
||||
)
|
||||
elif callable(self._llm_client):
|
||||
response = await self._llm_client(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
model="gpt-4",
|
||||
temperature=0,
|
||||
max_tokens=10,
|
||||
)
|
||||
else:
|
||||
return self._select_by_keyword(kwargs, None)
|
||||
|
||||
if isinstance(response, str):
|
||||
idx = int(response.strip())
|
||||
elif isinstance(response, dict):
|
||||
idx = int(response.get("content", "0").strip())
|
||||
else:
|
||||
idx = 0
|
||||
|
||||
if 0 <= idx < len(self._selector_tools):
|
||||
return self._selector_tools[idx]
|
||||
|
||||
except (ValueError, IndexError, Exception) as e:
|
||||
logger.warning(f"DynamicSelector LLM selection failed: {e}")
|
||||
|
||||
return self._select_by_keyword(kwargs, None)
|
||||
|
|
@ -0,0 +1,51 @@
|
|||
"""MCPTool - 通过 MCP Client 调用远程工具"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agentkit.tools.base import Tool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
"""MCP 工具 - 通过 MCP Client 调用远程工具
|
||||
|
||||
将外部 MCP Server 上的工具包装为本地 Tool 对象,
|
||||
通过 MCPClient 发起 HTTP 调用。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
client: Any,
|
||||
input_schema: dict[str, Any] | None = None,
|
||||
output_schema: dict[str, Any] | None = None,
|
||||
version: str = "1.0.0",
|
||||
tags: list[str] | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
description=description,
|
||||
input_schema=input_schema,
|
||||
output_schema=output_schema,
|
||||
version=version,
|
||||
tags=tags or ["mcp"],
|
||||
)
|
||||
self._client = client
|
||||
|
||||
async def execute(self, **kwargs) -> dict:
|
||||
"""通过 MCP Client 调用远程工具"""
|
||||
result = await self._client.call_tool(self.name, kwargs)
|
||||
|
||||
# 解析 MCP 响应格式
|
||||
if "content" in result:
|
||||
for item in result["content"]:
|
||||
if item.get("type") == "text":
|
||||
try:
|
||||
return json.loads(item["text"])
|
||||
except json.JSONDecodeError:
|
||||
return {"result": item["text"]}
|
||||
return result
|
||||
|
|
@ -0,0 +1,303 @@
|
|||
"""U3 测试: Tool 组合 (SequentialChain, ParallelFanOut, DynamicSelector) + MCPTool"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from agentkit.tools.base import Tool
|
||||
from agentkit.tools.composition import DynamicSelector, ParallelFanOut, SequentialChain
|
||||
from agentkit.tools.function_tool import FunctionTool
|
||||
from agentkit.tools.mcp_tool import MCPTool
|
||||
|
||||
|
||||
# ── Helper: 创建简单 FunctionTool ────────────────────────
|
||||
|
||||
|
||||
def _make_tool(name: str, func=None, **kw):
|
||||
if func is None:
|
||||
async def default_func(**kwargs):
|
||||
return {"tool": name, "input": kwargs}
|
||||
func = default_func
|
||||
return FunctionTool(name=name, description=f"Test tool {name}", func=func, **kw)
|
||||
|
||||
|
||||
# ── SequentialChain 测试 ─────────────────────────────────
|
||||
|
||||
|
||||
class TestSequentialChain:
|
||||
async def test_chain_two_tools(self):
|
||||
"""顺序执行两个工具,前一个输出作为后一个输入"""
|
||||
async def step1(text: str) -> dict:
|
||||
return {"processed": text.upper(), "length": len(text)}
|
||||
|
||||
async def step2(processed: str, length: int) -> dict:
|
||||
return {"result": f"{processed}({length})"}
|
||||
|
||||
chain = SequentialChain(
|
||||
name="upper_and_format",
|
||||
description="Uppercase then format",
|
||||
tools=[
|
||||
FunctionTool(name="step1", description="Step 1", func=step1),
|
||||
FunctionTool(name="step2", description="Step 2", func=step2),
|
||||
],
|
||||
)
|
||||
|
||||
result = await chain.execute(text="hello")
|
||||
assert result["result"] == "HELLO(5)"
|
||||
|
||||
async def test_chain_single_tool(self):
|
||||
"""单工具链直接返回结果"""
|
||||
tool = _make_tool("solo")
|
||||
chain = SequentialChain(name="solo_chain", description="Solo", tools=[tool])
|
||||
|
||||
result = await chain.execute(x=1)
|
||||
assert result["tool"] == "solo"
|
||||
|
||||
async def test_chain_preserves_dict(self):
|
||||
"""链中每步都返回 dict"""
|
||||
async def add_field(x: int) -> dict:
|
||||
return {"x": x, "doubled": x * 2}
|
||||
|
||||
async def add_label(x: int, doubled: int) -> dict:
|
||||
return {"x": x, "doubled": doubled, "label": f"{x}->{doubled}"}
|
||||
|
||||
chain = SequentialChain(
|
||||
name="enrich",
|
||||
description="Enrich data",
|
||||
tools=[
|
||||
FunctionTool(name="add_field", description="Add", func=add_field),
|
||||
FunctionTool(name="add_label", description="Label", func=add_label),
|
||||
],
|
||||
)
|
||||
|
||||
result = await chain.execute(x=3)
|
||||
assert result["label"] == "3->6"
|
||||
|
||||
|
||||
# ── ParallelFanOut 测试 ──────────────────────────────────
|
||||
|
||||
|
||||
class TestParallelFanOut:
|
||||
async def test_parallel_execution(self):
|
||||
"""并行执行三个工具"""
|
||||
async def search_web(query: str) -> dict:
|
||||
return {"web_results": [f"web: {query}"]}
|
||||
|
||||
async def search_db(query: str) -> dict:
|
||||
return {"db_results": [f"db: {query}"]}
|
||||
|
||||
async def search_cache(query: str) -> dict:
|
||||
return {"cache_results": [f"cache: {query}"]}
|
||||
|
||||
fan_out = ParallelFanOut(
|
||||
name="multi_search",
|
||||
description="Search multiple sources",
|
||||
tools=[
|
||||
FunctionTool(name="search_web", description="Web", func=search_web),
|
||||
FunctionTool(name="search_db", description="DB", func=search_db),
|
||||
FunctionTool(name="search_cache", description="Cache", func=search_cache),
|
||||
],
|
||||
)
|
||||
|
||||
result = await fan_out.execute(query="AI")
|
||||
assert "web_results" in result
|
||||
assert "db_results" in result
|
||||
assert "cache_results" in result
|
||||
|
||||
async def test_parallel_with_error(self):
|
||||
"""并行执行中有工具失败,不影响其他"""
|
||||
async def good_tool(x: int) -> dict:
|
||||
return {"good": x * 2}
|
||||
|
||||
async def bad_tool(x: int) -> dict:
|
||||
raise ValueError("boom")
|
||||
|
||||
fan_out = ParallelFanOut(
|
||||
name="mixed",
|
||||
description="Mixed results",
|
||||
tools=[
|
||||
FunctionTool(name="good", description="Good", func=good_tool),
|
||||
FunctionTool(name="bad", description="Bad", func=bad_tool),
|
||||
],
|
||||
)
|
||||
|
||||
result = await fan_out.execute(x=5)
|
||||
assert result["good"] == 10
|
||||
assert "_errors" in result
|
||||
assert any(e["tool"] == "bad" for e in result["_errors"])
|
||||
|
||||
async def test_parallel_namespace_merge(self):
|
||||
"""namespace 合并策略:每个工具结果独立命名空间"""
|
||||
async def tool_a(x: int) -> dict:
|
||||
return {"result": "a"}
|
||||
|
||||
async def tool_b(x: int) -> dict:
|
||||
return {"result": "b"}
|
||||
|
||||
fan_out = ParallelFanOut(
|
||||
name="namespaced",
|
||||
description="Namespaced",
|
||||
tools=[
|
||||
FunctionTool(name="tool_a", description="A", func=tool_a),
|
||||
FunctionTool(name="tool_b", description="B", func=tool_b),
|
||||
],
|
||||
merge_strategy="namespace",
|
||||
)
|
||||
|
||||
result = await fan_out.execute(x=1)
|
||||
assert "tool_a" in result
|
||||
assert "tool_b" in result
|
||||
assert result["tool_a"]["result"] == "a"
|
||||
assert result["tool_b"]["result"] == "b"
|
||||
|
||||
|
||||
# ── DynamicSelector 测试 ─────────────────────────────────
|
||||
|
||||
|
||||
class TestDynamicSelector:
|
||||
async def test_keyword_selection_by_intent(self):
|
||||
"""根据 _intent 参数选择工具"""
|
||||
search_tool = _make_tool("search_engine")
|
||||
calc_tool = _make_tool("calculator")
|
||||
|
||||
selector = DynamicSelector(
|
||||
name="smart",
|
||||
description="Smart selector",
|
||||
tools=[search_tool, calc_tool],
|
||||
mode="keyword",
|
||||
)
|
||||
|
||||
result = await selector.execute(query="AI", _intent="search")
|
||||
assert result["tool"] == "search_engine"
|
||||
|
||||
async def test_keyword_selection_by_input(self):
|
||||
"""根据输入内容推断工具"""
|
||||
search_tool = _make_tool("search_data")
|
||||
calc_tool = _make_tool("calculate_result")
|
||||
|
||||
selector = DynamicSelector(
|
||||
name="smart",
|
||||
description="Smart selector",
|
||||
tools=[search_tool, calc_tool],
|
||||
mode="keyword",
|
||||
)
|
||||
|
||||
result = await selector.execute(query="search for trends")
|
||||
assert result["tool"] == "search_data"
|
||||
|
||||
async def test_no_matching_tool(self):
|
||||
"""无匹配工具时返回错误"""
|
||||
tool_a = _make_tool("tool_a")
|
||||
|
||||
selector = DynamicSelector(
|
||||
name="smart",
|
||||
description="Smart selector",
|
||||
tools=[tool_a],
|
||||
mode="keyword",
|
||||
)
|
||||
|
||||
result = await selector.execute(x=999)
|
||||
# 应该仍然选择 tool_a(best_score=0 的默认)
|
||||
assert "tool" in result or "error" in result
|
||||
|
||||
async def test_llm_selection(self):
|
||||
"""LLM 模式选择工具"""
|
||||
search_tool = _make_tool("search_engine")
|
||||
calc_tool = _make_tool("calculator")
|
||||
|
||||
class MockLLM:
|
||||
async def chat(self, messages, **kwargs):
|
||||
return "0" # 选择第一个工具
|
||||
|
||||
selector = DynamicSelector(
|
||||
name="smart",
|
||||
description="Smart selector",
|
||||
tools=[search_tool, calc_tool],
|
||||
mode="llm",
|
||||
llm_client=MockLLM(),
|
||||
)
|
||||
|
||||
result = await selector.execute(query="find information")
|
||||
assert result["tool"] == "search_engine"
|
||||
|
||||
async def test_llm_fallback_to_keyword(self):
|
||||
"""LLM 失败时回退到关键字模式"""
|
||||
search_tool = _make_tool("search_engine")
|
||||
|
||||
class BrokenLLM:
|
||||
async def chat(self, messages, **kwargs):
|
||||
raise ConnectionError("LLM unavailable")
|
||||
|
||||
selector = DynamicSelector(
|
||||
name="smart",
|
||||
description="Smart selector",
|
||||
tools=[search_tool],
|
||||
mode="llm",
|
||||
llm_client=BrokenLLM(),
|
||||
)
|
||||
|
||||
result = await selector.execute(query="search something")
|
||||
assert result["tool"] == "search_engine"
|
||||
|
||||
|
||||
# ── MCPTool 测试 ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestMCPTool:
|
||||
async def test_mcp_tool_execute(self):
|
||||
"""MCPTool 通过 client 调用远程工具"""
|
||||
|
||||
class MockMCPClient:
|
||||
async def call_tool(self, name, arguments):
|
||||
return {
|
||||
"content": [
|
||||
{"type": "text", "text": json.dumps({"status": "ok", "data": arguments})}
|
||||
]
|
||||
}
|
||||
|
||||
tool = MCPTool(
|
||||
name="remote_search",
|
||||
description="Remote search via MCP",
|
||||
client=MockMCPClient(),
|
||||
)
|
||||
|
||||
result = await tool.execute(query="test")
|
||||
assert result["status"] == "ok"
|
||||
assert result["data"]["query"] == "test"
|
||||
|
||||
async def test_mcp_tool_text_response(self):
|
||||
"""MCP 返回非 JSON 文本"""
|
||||
|
||||
class MockMCPClient:
|
||||
async def call_tool(self, name, arguments):
|
||||
return {
|
||||
"content": [
|
||||
{"type": "text", "text": "plain text result"}
|
||||
]
|
||||
}
|
||||
|
||||
tool = MCPTool(
|
||||
name="remote_echo",
|
||||
description="Remote echo",
|
||||
client=MockMCPClient(),
|
||||
)
|
||||
|
||||
result = await tool.execute(text="hello")
|
||||
assert result["result"] == "plain text result"
|
||||
|
||||
async def test_mcp_tool_raw_response(self):
|
||||
"""MCP 返回非标准格式"""
|
||||
|
||||
class MockMCPClient:
|
||||
async def call_tool(self, name, arguments):
|
||||
return {"raw_key": "raw_value"}
|
||||
|
||||
tool = MCPTool(
|
||||
name="remote_raw",
|
||||
description="Raw response",
|
||||
client=MockMCPClient(),
|
||||
)
|
||||
|
||||
result = await tool.execute(x=1)
|
||||
assert result["raw_key"] == "raw_value"
|
||||
Loading…
Reference in New Issue