304 lines
10 KiB
Python
304 lines
10 KiB
Python
"""U3 测试: Tool 组合 (SequentialChain, ParallelFanOut, DynamicSelector) + MCPTool"""
|
||
|
||
import asyncio
|
||
import json
|
||
|
||
import pytest
|
||
|
||
from agentkit.tools.base import Tool
|
||
from agentkit.tools.composition import DynamicSelector, ParallelFanOut, SequentialChain
|
||
from agentkit.tools.function_tool import FunctionTool
|
||
from agentkit.tools.mcp_tool import MCPTool
|
||
|
||
|
||
# ── Helper: 创建简单 FunctionTool ────────────────────────
|
||
|
||
|
||
def _make_tool(name: str, func=None, **kw):
    """Build a FunctionTool named *name* with a fixed test description.

    Args:
        name: Tool name; also embedded in the description ``"Test tool <name>"``.
        func: Coroutine function to wrap. When omitted, a stub is used that
            returns ``{"tool": name, "input": <keyword arguments it received>}``,
            so tests can check which tool ran and what input it got.
        **kw: Extra keyword arguments forwarded unchanged to ``FunctionTool``.

    Returns:
        The constructed ``FunctionTool``.
    """
    if func is not None:
        target = func
    else:
        async def echo_input(**kwargs):
            return {"tool": name, "input": kwargs}

        target = echo_input

    return FunctionTool(
        name=name,
        description=f"Test tool {name}",
        func=target,
        **kw,
    )
|
||
|
||
|
||
# ── SequentialChain 测试 ─────────────────────────────────
|
||
|
||
|
||
class TestSequentialChain:
    """SequentialChain: tools run one after another, each fed the previous output."""

    async def test_chain_two_tools(self):
        """step1's output dict becomes step2's input (its keys match step2's parameters)."""

        async def step1(text: str) -> dict:
            return {"processed": text.upper(), "length": len(text)}

        async def step2(processed: str, length: int) -> dict:
            return {"result": f"{processed}({length})"}

        pipeline = [
            FunctionTool(name="step1", description="Step 1", func=step1),
            FunctionTool(name="step2", description="Step 2", func=step2),
        ]
        chain = SequentialChain(
            name="upper_and_format",
            description="Uppercase then format",
            tools=pipeline,
        )

        # "hello" -> {"processed": "HELLO", "length": 5} -> {"result": "HELLO(5)"}
        outcome = await chain.execute(text="hello")
        assert outcome["result"] == "HELLO(5)"

    async def test_chain_single_tool(self):
        """A chain of exactly one tool hands back that tool's own result."""
        chain = SequentialChain(
            name="solo_chain",
            description="Solo",
            tools=[_make_tool("solo")],
        )

        outcome = await chain.execute(x=1)
        assert outcome["tool"] == "solo"

    async def test_chain_preserves_dict(self):
        """Every step returns a dict; the last step's label combines both fields."""

        async def add_field(x: int) -> dict:
            return {"x": x, "doubled": x * 2}

        async def add_label(x: int, doubled: int) -> dict:
            return {"x": x, "doubled": doubled, "label": f"{x}->{doubled}"}

        pipeline = [
            FunctionTool(name="add_field", description="Add", func=add_field),
            FunctionTool(name="add_label", description="Label", func=add_label),
        ]
        chain = SequentialChain(name="enrich", description="Enrich data", tools=pipeline)

        # 3 -> {"x": 3, "doubled": 6} -> label "3->6"
        outcome = await chain.execute(x=3)
        assert outcome["label"] == "3->6"
|
||
|
||
|
||
# ── ParallelFanOut 测试 ──────────────────────────────────
|
||
|
||
|
||
class TestParallelFanOut:
    """ParallelFanOut: the same input is sent to every tool and results are merged."""

    async def test_parallel_execution(self):
        """All three search tools run and each contributes its own result key."""

        async def search_web(query: str) -> dict:
            return {"web_results": [f"web: {query}"]}

        async def search_db(query: str) -> dict:
            return {"db_results": [f"db: {query}"]}

        async def search_cache(query: str) -> dict:
            return {"cache_results": [f"cache: {query}"]}

        sources = [
            FunctionTool(name="search_web", description="Web", func=search_web),
            FunctionTool(name="search_db", description="DB", func=search_db),
            FunctionTool(name="search_cache", description="Cache", func=search_cache),
        ]
        fan_out = ParallelFanOut(
            name="multi_search",
            description="Search multiple sources",
            tools=sources,
        )

        merged = await fan_out.execute(query="AI")
        for expected_key in ("web_results", "db_results", "cache_results"):
            assert expected_key in merged

    async def test_parallel_with_error(self):
        """A failing tool is reported under ``_errors`` without hiding other results."""

        async def good_tool(x: int) -> dict:
            return {"good": x * 2}

        async def bad_tool(x: int) -> dict:
            raise ValueError("boom")

        fan_out = ParallelFanOut(
            name="mixed",
            description="Mixed results",
            tools=[
                FunctionTool(name="good", description="Good", func=good_tool),
                FunctionTool(name="bad", description="Bad", func=bad_tool),
            ],
        )

        merged = await fan_out.execute(x=5)
        # The healthy tool's output survives its sibling's exception.
        assert merged["good"] == 10
        assert "_errors" in merged
        failures = merged["_errors"]
        assert any(failure["tool"] == "bad" for failure in failures)

    async def test_parallel_namespace_merge(self):
        """With ``merge_strategy="namespace"`` each result is nested under its tool name.

        Both tools return the same ``"result"`` key, so the two values can only
        coexist if they are kept apart per tool.
        """

        async def tool_a(x: int) -> dict:
            return {"result": "a"}

        async def tool_b(x: int) -> dict:
            return {"result": "b"}

        fan_out = ParallelFanOut(
            name="namespaced",
            description="Namespaced",
            tools=[
                FunctionTool(name="tool_a", description="A", func=tool_a),
                FunctionTool(name="tool_b", description="B", func=tool_b),
            ],
            merge_strategy="namespace",
        )

        merged = await fan_out.execute(x=1)
        for tool_name, expected in (("tool_a", "a"), ("tool_b", "b")):
            assert tool_name in merged
            assert merged[tool_name]["result"] == expected
|
||
|
||
|
||
# ── DynamicSelector 测试 ─────────────────────────────────
|
||
|
||
|
||
class TestDynamicSelector:
    """DynamicSelector: picks one tool per call, by keyword matching or via an LLM."""

    @staticmethod
    def _smart_selector(tools, **options):
        """Build the ``"smart"`` DynamicSelector shared by the tests in this class.

        ``options`` (e.g. ``mode``, ``llm_client``) are forwarded unchanged.
        """
        return DynamicSelector(
            name="smart",
            description="Smart selector",
            tools=tools,
            **options,
        )

    async def test_keyword_selection_by_intent(self):
        """In keyword mode, an explicit ``_intent`` argument steers the choice."""
        selector = self._smart_selector(
            [_make_tool("search_engine"), _make_tool("calculator")],
            mode="keyword",
        )

        chosen = await selector.execute(query="AI", _intent="search")
        assert chosen["tool"] == "search_engine"

    async def test_keyword_selection_by_input(self):
        """Without ``_intent``, the tool is inferred from the words in the input."""
        selector = self._smart_selector(
            [_make_tool("search_data"), _make_tool("calculate_result")],
            mode="keyword",
        )

        chosen = await selector.execute(query="search for trends")
        assert chosen["tool"] == "search_data"

    async def test_no_matching_tool(self):
        """Input that matches no tool still yields either a tool result or an error.

        NOTE(review): the expectation is deliberately loose. The selector may fall
        back to its only tool (the best_score=0 default) or report an error.
        Tighten this once the intended behavior is settled.
        """
        selector = self._smart_selector([_make_tool("tool_a")], mode="keyword")

        chosen = await selector.execute(x=999)
        assert "tool" in chosen or "error" in chosen

    async def test_llm_selection(self):
        """In llm mode, the LLM's reply decides which tool runs."""

        class MockLLM:
            async def chat(self, messages, **kwargs):
                # Reply "0" selects the first tool (presumably parsed as an
                # index into ``tools`` -- confirm against DynamicSelector).
                return "0"

        selector = self._smart_selector(
            [_make_tool("search_engine"), _make_tool("calculator")],
            mode="llm",
            llm_client=MockLLM(),
        )

        chosen = await selector.execute(query="find information")
        assert chosen["tool"] == "search_engine"

    async def test_llm_fallback_to_keyword(self):
        """If the LLM call raises, selection falls back to keyword matching."""

        class BrokenLLM:
            async def chat(self, messages, **kwargs):
                raise ConnectionError("LLM unavailable")

        selector = self._smart_selector(
            [_make_tool("search_engine")],
            mode="llm",
            llm_client=BrokenLLM(),
        )

        chosen = await selector.execute(query="search something")
        assert chosen["tool"] == "search_engine"
|
||
|
||
|
||
# ── MCPTool 测试 ─────────────────────────────────────────
|
||
|
||
|
||
class TestMCPTool:
    """MCPTool: forwards calls to an MCP client and unpacks the client's response."""

    async def test_mcp_tool_execute(self):
        """JSON text content from the client is decoded into the result dict."""

        class MockMCPClient:
            async def call_tool(self, name, arguments):
                body = json.dumps({"status": "ok", "data": arguments})
                return {"content": [{"type": "text", "text": body}]}

        tool = MCPTool(
            name="remote_search",
            description="Remote search via MCP",
            client=MockMCPClient(),
        )

        reply = await tool.execute(query="test")
        assert reply["status"] == "ok"
        # The mock echoes the call arguments back under "data".
        assert reply["data"]["query"] == "test"

    async def test_mcp_tool_text_response(self):
        """Non-JSON text content is exposed under the ``"result"`` key."""

        class MockMCPClient:
            async def call_tool(self, name, arguments):
                text_item = {"type": "text", "text": "plain text result"}
                return {"content": [text_item]}

        tool = MCPTool(
            name="remote_echo",
            description="Remote echo",
            client=MockMCPClient(),
        )

        reply = await tool.execute(text="hello")
        assert reply["result"] == "plain text result"

    async def test_mcp_tool_raw_response(self):
        """A response without MCP ``content`` items still exposes its own keys."""

        class MockMCPClient:
            async def call_tool(self, name, arguments):
                raw = {"raw_key": "raw_value"}
                return raw

        tool = MCPTool(
            name="remote_raw",
            description="Raw response",
            client=MockMCPClient(),
        )

        reply = await tool.execute(x=1)
        assert reply["raw_key"] == "raw_value"
|